Found this repo today and will be attempting this as well.
from llmtest_needleinahaystack.
@hijkzzz @disperaller @RahulSinghalChicago Hey guys, I have some findings, but the scores may be a bit too high. Do you have a better idea?
Based on the OpenAI provider, you can add your own provider like this (named local_llama.py in this case):
--- ./providers/openai.py
+++ ./providers/local_llama.py
@@ -1,16 +1,19 @@
import os
from operator import itemgetter
from typing import Optional
+import torch
-from openai import AsyncOpenAI
-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
+
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
import tiktoken
from .model import ModelProvider
-class OpenAI(ModelProvider):
+class LocalLlama(ModelProvider):
     """
     A wrapper class for interacting with OpenAI's API, providing methods to encode text, generate prompts,
     evaluate models, and create LangChain runnables for language model interactions.
@@ -25,7 +28,7 @@
                                       temperature = 0)
     def __init__(self,
-                 model_name: str = "gpt-3.5-turbo-0125",
+                 model_name: str = "meta-llama/Llama-2-7b-chat-hf",
                  model_kwargs: dict = DEFAULT_MODEL_KWARGS):
         """
         Initializes the OpenAI model provider with a specific model.
@@ -37,15 +40,18 @@
         Raises:
             ValueError: If NIAH_MODEL_API_KEY is not found in the environment.
         """
-        api_key = os.getenv('NIAH_MODEL_API_KEY')
-        if (not api_key):
-            raise ValueError("NIAH_MODEL_API_KEY must be in env.")
-
-        self.model_name = model_name
+        self.model_or_path = model_name
+        self.model_name = model_name.split("/")[-1]
         self.model_kwargs = model_kwargs
-        self.api_key = api_key
-        self.model = AsyncOpenAI(api_key=self.api_key)
-        self.tokenizer = tiktoken.encoding_for_model(self.model_name)
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_or_path,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+            attn_implementation="flash_attention_2",
+        )
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_or_path)
     async def evaluate_model(self, prompt: str) -> str:
         """
@@ -57,12 +63,18 @@
         Returns:
             str: The content of the model's response to the prompt.
         """
-        response = await self.model.chat.completions.create(
-                model=self.model_name,
-                messages=prompt,
-                **self.model_kwargs
-            )
-        return response.choices[0].message.content
+        MAX_GEN_LENGTH = 128
+        tokenized_prompts = self.tokenizer(prompt, return_tensors="pt")
+        input_ids = tokenized_prompts.input_ids.to(self.model.device)
+
+        generation_output = self.model.generate(
+            input_ids,
+            max_new_tokens=MAX_GEN_LENGTH,
+            do_sample=False,  # greedy decoding, matching temperature = 0
+            return_dict_in_generate=True)
+
+        output = self.tokenizer.decode(generation_output.sequences[0, input_ids.shape[1]:], skip_special_tokens=True)
+        return output
     def generate_prompt(self, context: str, retrieval_question: str) -> str | list[dict[str, str]]:
         """
@@ -75,19 +87,16 @@
         Returns:
             list[dict[str, str]]: A list of dictionaries representing the structured prompt, including roles and content for system and user messages.
         """
-        return [{
-                "role": "system",
-                "content": "You are a helpful AI bot that answers questions for a user. Keep your response short and direct"
-            },
-            {
-                "role": "user",
-                "content": context
-            },
-            {
-                "role": "user",
-                "content": f"{retrieval_question} Don't give information outside the document or repeat your findings"
-            }]
-
+        # Llama 2 chat format. No literal <s>/</s>: the tokenizer prepends <s> (BOS) itself,
+        # and </s> belongs after the answer, not after the prompt.
+        return f"""[INST] <<SYS>>
+You are a helpful AI bot that answers questions for a user. Keep your response short and direct
+<</SYS>>
+
+{context}
+
+{retrieval_question} Don't give information outside the document or repeat your findings [/INST]"""
+
     def encode_text_to_tokens(self, text: str) -> list[int]:
         """
         Encodes a given text string to a sequence of tokens using the model's tokenizer.
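To sanity-check the template on its own, here is a minimal sketch of the prompt generate_prompt is meant to produce, assuming Meta's single-turn Llama 2 chat layout ([INST] <<SYS>> ... <</SYS>> ... [/INST]). The helper build_llama2_prompt and the two constants are hypothetical names for illustration, not part of the repo, and the sketch leaves out <s>/</s> on the assumption that the Hugging Face tokenizer adds the BOS token itself.

```python
# Hypothetical standalone version of LocalLlama.generate_prompt (the name
# build_llama2_prompt is illustrative, not part of the repo). It follows the
# single-turn Llama 2 chat layout:
#   [INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST]
# No literal <s>/</s> is emitted: the Hugging Face tokenizer prepends <s> (BOS)
# by default, and </s> only follows the model's answer.

SYSTEM_PROMPT = (
    "You are a helpful AI bot that answers questions for a user. "
    "Keep your response short and direct"
)

QUESTION_SUFFIX = " Don't give information outside the document or repeat your findings"


def build_llama2_prompt(context: str, retrieval_question: str,
                        system_prompt: str = SYSTEM_PROMPT) -> str:
    """Return a single-turn Llama 2 chat prompt: system message, haystack, question."""
    user_message = f"{context}\n\n{retrieval_question}{QUESTION_SUFFIX}"
    return f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_message} [/INST]"


if __name__ == "__main__":
    # Toy haystack and question, for illustration only.
    prompt = build_llama2_prompt(
        context="Paris is the capital of France. The sky was grey all week.",
        retrieval_question="What is the capital of France?",
    )
    print(prompt)
```

Printing the prompt once before a full run makes stray special tokens or missing line breaks easy to spot.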
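Then re-export the class in needlehaystack/providers/__init__.py (from .local_llama import LocalLlama) so that run.py can import it, and add an entry in run.py like this: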
diff --git a/needlehaystack/run.py b/needlehaystack/run.py
index 8edbccb..f5b6783 100644
--- a/needlehaystack/run.py
+++ b/needlehaystack/run.py
@@ -6,7 +6,7 @@ from jsonargparse import CLI
from . import LLMNeedleHaystackTester, LLMMultiNeedleHaystackTester
from .evaluators import Evaluator, LangSmithEvaluator, OpenAIEvaluator
-from .providers import Anthropic, ModelProvider, OpenAI, Cohere
+from .providers import Anthropic, ModelProvider, OpenAI, Cohere, LocalLlama
load_dotenv()
@@ -65,6 +65,8 @@ def get_model_to_test(args: CommandArgs) -> ModelProvider:
             return Anthropic(model_name=args.model_name)
         case "cohere":
             return Cohere(model_name=args.model_name)
+        case "local":
+            return LocalLlama(model_name=args.model_name)
         case _:
             raise ValueError(f"Invalid provider: {args.provider}")
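run.py dispatches on the provider name with a match statement. As a hedged alternative, the same routing can be sketched as a lookup table, so registering a provider like "local" is a single dictionary entry. The classes below are stubs that only store model_name (so the sketch runs without API keys or model weights), and PROVIDERS is a hypothetical name.

```python
# Hypothetical table-driven variant of get_model_to_test from run.py.
# The classes below are stand-ins for the real providers: they only store
# model_name, so the sketch runs without API keys or model weights.
from dataclasses import dataclass


@dataclass
class StubProvider:
    """Stand-in for needlehaystack's ModelProvider subclasses."""
    model_name: str


class OpenAI(StubProvider):
    pass


class Anthropic(StubProvider):
    pass


class Cohere(StubProvider):
    pass


class LocalLlama(StubProvider):
    pass


# Maps the provider string to the class to instantiate; "local" mirrors
# the case added to run.py for the Llama provider.
PROVIDERS: dict[str, type[StubProvider]] = {
    "openai": OpenAI,
    "anthropic": Anthropic,
    "cohere": Cohere,
    "local": LocalLlama,
}


def get_model_to_test(provider: str, model_name: str) -> StubProvider:
    """Instantiate the provider registered under `provider` (case-insensitive)."""
    try:
        provider_cls = PROVIDERS[provider.lower()]
    except KeyError:
        # Same message as the `case _` fallback in run.py.
        raise ValueError(f"Invalid provider: {provider}") from None
    return provider_cls(model_name=model_name)


if __name__ == "__main__":
    model = get_model_to_test("local", "meta-llama/Llama-2-7b-chat-hf")
    print(type(model).__name__, model.model_name)
```

Translating the KeyError into a ValueError keeps the caller-facing behavior the same as the case _ branch.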
+1