Giter Club home page Giter Club logo

Comments (3)

RahulSinghalChicago avatar RahulSinghalChicago commented on July 23, 2024 3

found this repo today and will be attempting as well

from llmtest_needleinahaystack.

66RING avatar 66RING commented on July 23, 2024 2

@hijkzzz @disperaller @RahulSinghalChicago Hey guys, I have some findings, but the scores I'm getting seem a bit too high. Does anyone have a better idea?

Based on the OpenAI provider, you can add your own provider like this (named local_llama.py in this case):

--- ./providers/openai.py
+++ ./providers/local_llama.py
@@ -1,16 +1,19 @@
 import os
 from operator import itemgetter
 from typing import Optional
+import torch
 
-from openai import AsyncOpenAI
-from langchain_openai import ChatOpenAI  
+from langchain_openai import ChatOpenAI
 from langchain.prompts import PromptTemplate
+
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
 import tiktoken
 
 from .model import ModelProvider
 
 
-class OpenAI(ModelProvider):
+class LocalLlama(ModelProvider):
     """
     A wrapper class for interacting with OpenAI's API, providing methods to encode text, generate prompts,
     evaluate models, and create LangChain runnables for language model interactions.
@@ -25,7 +28,7 @@
                                       temperature = 0)
 
     def __init__(self,
-                 model_name: str = "gpt-3.5-turbo-0125",
+                 model_name: str = "meta/llama-2-7b-chat-hf",
                  model_kwargs: dict = DEFAULT_MODEL_KWARGS):
         """
         Initializes the OpenAI model provider with a specific model.
@@ -37,15 +40,18 @@
         Raises:
             ValueError: If NIAH_MODEL_API_KEY is not found in the environment.
         """
-        api_key = os.getenv('NIAH_MODEL_API_KEY')
-        if (not api_key):
-            raise ValueError("NIAH_MODEL_API_KEY must be in env.")
-
-        self.model_name = model_name
+        self.model_or_path = model_name
+        self.model_name = model_name.split("/")[-1]
         self.model_kwargs = model_kwargs
-        self.api_key = api_key
-        self.model = AsyncOpenAI(api_key=self.api_key)
-        self.tokenizer = tiktoken.encoding_for_model(self.model_name)
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+                self.model_or_path,
+                device_map="auto",
+                torch_dtype=torch.bfloat16,
+                attn_implementation="flash_attention_2",
+            )
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_or_path)
     
     async def evaluate_model(self, prompt: str) -> str:
         """
@@ -57,12 +63,18 @@
         Returns:
             str: The content of the model's response to the prompt.
         """
-        response = await self.model.chat.completions.create(
-                model=self.model_name,
-                messages=prompt,
-                **self.model_kwargs
-            )
-        return response.choices[0].message.content
+        MAX_GEN_LENGTH = 128
+        tokenized_prompts = self.tokenizer(prompt, return_tensors="pt")
+        input_ids = tokenized_prompts.input_ids.cuda()
+
+        generation_output = self.model.generate(
+            input_ids,
+            max_new_tokens=MAX_GEN_LENGTH,
+            use_cache=True,
+            return_dict_in_generate=True)
+
+        output = self.tokenizer.decode(generation_output.sequences[:,input_ids.shape[1]:][0])
+        return output
     
     def generate_prompt(self, context: str, retrieval_question: str) -> str | list[dict[str, str]]:
         """
@@ -75,19 +87,16 @@
         Returns:
             list[dict[str, str]]: A list of dictionaries representing the structured prompt, including roles and content for system and user messages.
         """
-        return [{
-                "role": "system",
-                "content": "You are a helpful AI bot that answers questions for a user. Keep your response short and direct"
-            },
-            {
-                "role": "user",
-                "content": context
-            },
-            {
-                "role": "user",
-                "content": f"{retrieval_question} Don't give information outside the document or repeat your findings"
-            }]
-    
+        return f"""
+<s>[INST] <<SYS>>
+You are a helpful AI bot that answers questions for a user. Keep your response short and direct
+<</SYS>>
+{ context }
+
+{retrieval_question} Don't give information outside the document or repeat your findings
+[/INST]</s>
+"""
+
     def encode_text_to_tokens(self, text: str) -> list[int]:
         """
         Encodes a given text string to a sequence of tokens using the model's tokenizer.

Then add an entry in run.py like this:

diff --git a/needlehaystack/run.py b/needlehaystack/run.py
index 8edbccb..f5b6783 100644
--- a/needlehaystack/run.py
+++ b/needlehaystack/run.py
@@ -6,7 +6,7 @@ from jsonargparse import CLI
 
 from . import LLMNeedleHaystackTester, LLMMultiNeedleHaystackTester
 from .evaluators import Evaluator, LangSmithEvaluator, OpenAIEvaluator
-from .providers import Anthropic, ModelProvider, OpenAI, Cohere
+from .providers import Anthropic, ModelProvider, OpenAI, Cohere, LocalLlama
 
 load_dotenv()
 
@@ -65,6 +65,8 @@ def get_model_to_test(args: CommandArgs) -> ModelProvider:
             return Anthropic(model_name=args.model_name)
         case "cohere":
             return Cohere(model_name=args.model_name)
+        case "local":
+            return LocalLlama(model_name=args.model_name)
         case _:
             raise ValueError(f"Invalid provider: {args.provider}")

from llmtest_needleinahaystack.

disperaller avatar disperaller commented on July 23, 2024

+1

from llmtest_needleinahaystack.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.