Comments (2)
The code below is intended to mimic the OpenAI Chat Completions API, but backed by a Hugging Face model:
from langchain.schema import SystemMessage, HumanMessage, AIMessage, BaseMessage
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers import StoppingCriteria, StoppingCriteriaList, LogitsProcessor, LogitsProcessorList
from typing import List, Union, Tuple
import uuid
# Hugging Face Hub checkpoint served by this module (a GPTQ build of
# Mixtral-8x7B-Instruct, judging by its name).
model_name_or_path = "TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ"

# Load the weights from the "main" branch with remote code execution disabled;
# device_map="auto" lets transformers decide where the weights are placed.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    revision="main",
    trust_remote_code=False,
    device_map="auto",
)

# Fast (Rust-backed) tokenizer for the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
class WordStoppingCriteria(StoppingCriteria):
    """Stop a sequence once its last tokens equal a fixed run of token ids.

    The check is done per row, so with several sequences in the batch
    (``num_return_sequences > 1``) only the rows that actually produced the
    stop tokens are reported as finished. Recent transformers releases expect
    a per-sequence BoolTensor from ``StoppingCriteria.__call__``.
    """

    def __init__(self, stop_ids):
        # Token ids of the stop word, in order.
        self.stop_ids = list(stop_ids)

    def __call__(self, input_ids, scores, **kwargs):
        """Return a bool tensor of shape (batch,): True where the row ends with ``stop_ids``.

        ``input_ids`` is indexed as (batch, seq_len); ``scores`` is unused.
        """
        batch_size, seq_len = input_ids.shape
        n = len(self.stop_ids)
        # An empty stop word can never match, and a row shorter than the stop
        # word cannot end with it.
        if n == 0 or seq_len < n:
            return torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
        stop = torch.tensor(self.stop_ids, dtype=input_ids.dtype, device=input_ids.device)
        return (input_ids[:, -n:] == stop).all(dim=-1)
def generate_stopping_critera(stop_word: str):
    """Build a WordStoppingCriteria that fires when ``stop_word``'s tokens are emitted.

    The word is tokenized with ``add_special_tokens=False`` so the id list
    contains only the word's own tokens (no BOS/EOS markers).
    """
    return WordStoppingCriteria(
        tokenizer.encode(stop_word, add_special_tokens=False)
    )
class LogitsBiasProcessor(LogitsProcessor):
    """Add a fixed bias to the logits of selected tokens (OpenAI ``logit_bias`` semantics).

    ``token_ids_to_bias`` maps a tuple of token ids to a numeric bias; the bias
    is *added* to the logit of every id in the tuple at every decoding step.
    Positive values make the tokens more likely and negative values less
    likely; large negative values (e.g. -100) make them effectively impossible.
    """

    def __init__(self, token_ids_to_bias):
        # {tuple_of_token_ids: bias}
        self.token_ids_to_bias = token_ids_to_bias

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` (modified in place) with the configured biases added.

        ``scores`` is indexed as (batch, vocab); ``input_ids`` is unused.
        """
        for token_ids, bias in self.token_ids_to_bias.items():
            ids = list(token_ids)
            if not ids:
                continue
            # Additive, not multiplicative: multiplying and clamping at 0 would
            # push negative logits *up* to 0 and so boost "banned" tokens.
            scores[:, ids] += bias
        return scores
def generate_custom_logits_processor(mapping: dict):
    """Build a LogitsBiasProcessor from an OpenAI-style ``logit_bias`` mapping.

    Keys may be either:
      * ``str`` -- a word; it is tokenized (without special tokens) and the
        bias is applied to each of its token ids;
      * ``int`` -- a raw token id, as in the OpenAI API.
    Values are the additive biases.
    """
    token_ids_to_bias = {}
    for key, bias in mapping.items():
        if isinstance(key, int):
            token_ids = (key,)
        else:
            token_ids = tuple(tokenizer.encode(key, add_special_tokens=False))
        token_ids_to_bias[token_ids] = bias
    return LogitsBiasProcessor(token_ids_to_bias)
def _normalize_stop(stop):
    """Turn the OpenAI ``stop`` argument (None, str or list of str) into a list of non-empty strings."""
    if stop is None:
        return []
    if isinstance(stop, str):
        stop = [stop]
    return [word for word in stop if word]


def _eos_token_ids():
    """Return the model's EOS id(s) as a list; the config may hold an int, a list, or None."""
    eos = model.config.eos_token_id
    if eos is None:
        return []
    if isinstance(eos, (list, tuple)):
        return list(eos)
    return [eos]


def _generated_length(generated, eos_ids):
    """Count generated tokens before the first EOS/padding token; also report whether one was hit."""
    for idx, token_id in enumerate(generated.tolist()):
        if token_id in eos_ids:
            return idx, True
    return len(generated), False


def _truncate_at_stop(text, stop_words):
    """Cut ``text`` at the earliest stop word (OpenAI omits it from content); report whether one matched."""
    hits = [pos for pos in (text.find(word) for word in stop_words) if pos != -1]
    if not hits:
        return text, False
    return text[:min(hits)], True


def _logprobs_content(sequence, prompt_len, top_n):
    """Build OpenAI-style ``logprobs.content`` entries for the tokens in ``sequence[prompt_len:]``.

    Uses one extra forward pass over the raw model logits, so the values do
    not reflect temperature, top-k/top-p or logit_bias.
    """
    seq_len = sequence.shape[-1]
    if seq_len <= prompt_len:
        return []
    with torch.no_grad():
        logits = model(sequence.unsqueeze(0)).logits[0]
    logprobs = torch.log_softmax(logits.float(), dim=-1)

    def entry(token_id, logprob):
        token = tokenizer.decode([token_id])
        # OpenAI's `bytes` field is the UTF-8 encoding of the token text.
        return {'token': token, 'logprob': logprob, 'bytes': list(token.encode('utf-8'))}

    content = []
    for pos in range(prompt_len, seq_len):
        row = logprobs[pos - 1]  # logits at position pos-1 predict the token at pos
        token_id = int(sequence[pos])
        item = entry(token_id, float(row[token_id]))
        top = row.topk(top_n)  # already sorted by descending logprob
        item['top_logprobs'] = [
            entry(k, v) for v, k in zip(top.values.tolist(), top.indices.tolist())
        ]
        content.append(item)
    return content


def generate(input_text, **args):
    """Generate an OpenAI-style ``chat.completion`` dict from the local Hugging Face model.

    Args:
        input_text: prompt string, tokenized as-is (no chat template is applied).
        **args: supported OpenAI-like parameters:
            max_tokens (default 50, also used when None), temperature (default 1.0;
            0 selects greedy decoding), top_p (default 1.0), top_k (HF extension,
            default 0 = disabled), repetition_penalty (HF extension, default 1.0),
            n (default 1), seed, stop (str or list of str; the stop text is removed
            from the returned content), logit_bias ({word or token id: additive bias}),
            logprobs (bool), top_logprobs (int).

    Returns:
        dict with keys id, object, created, model, choices, usage.
    """
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(model.device)
    prompt_len = input_ids.shape[-1]

    seed = args.get('seed')
    if seed is not None:  # 0 is a valid seed
        torch.manual_seed(seed)

    stop_words = _normalize_stop(args.get('stop'))
    logit_bias = args.get('logit_bias')
    n = args.get('n') or 1
    max_tokens = args.get('max_tokens')
    if max_tokens is None:
        max_tokens = 50
    temperature = args.get('temperature', 1.0)
    if temperature is None:
        temperature = 1.0
    # HF sampling rejects temperature 0; OpenAI treats it as greedy decoding.
    do_sample = temperature > 0

    eos_ids = _eos_token_ids()
    gen_kwargs = dict(
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=max_tokens,
        repetition_penalty=args.get('repetition_penalty', 1.0),
        pad_token_id=eos_ids[0] if eos_ids else None,
        stopping_criteria=StoppingCriteriaList(
            [generate_stopping_critera(word) for word in stop_words]
        ),
        logits_processor=LogitsProcessorList(
            [generate_custom_logits_processor(logit_bias)] if logit_bias else []
        ),
    )
    if do_sample:
        gen_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_k=args.get('top_k', 0),
            top_p=args.get('top_p', 1.0),
            num_return_sequences=n,
        )
    else:
        gen_kwargs.update(do_sample=False)

    outputs = model.generate(input_ids, **gen_kwargs)
    if not do_sample and n > 1:
        # Greedy decoding is deterministic: generate once and replicate.
        outputs = outputs.repeat(n, 1)

    response = {
        'id': str(uuid.uuid4()),
        'object': 'chat.completion',
        'created': int(time.time()),
        'model': getattr(model.config, 'name_or_path', None) or model.__class__.__name__,
        'choices': [],
    }

    want_logprobs = bool(args.get('logprobs', False))
    top_n = args.get('top_logprobs') or 0
    completion_tokens = 0
    for i, output in enumerate(outputs):
        generated = output[prompt_len:]
        # Padding (pad == EOS here) after a finished row is not generated output.
        gen_len, hit_eos = _generated_length(generated, eos_ids)
        completion_tokens += gen_len
        text = tokenizer.decode(generated[:gen_len], skip_special_tokens=True)
        text, hit_stop = _truncate_at_stop(text, stop_words)
        response['choices'].append({
            'index': i,
            'message': {
                'role': 'assistant',
                'content': text,
            },
            # The extra forward pass only runs when logprobs were requested.
            'logprobs': {
                'content': _logprobs_content(output[:prompt_len + gen_len], prompt_len, top_n),
            } if want_logprobs else None,
            'finish_reason': 'stop' if (hit_eos or hit_stop) else 'length',
        })

    response['usage'] = {
        'prompt_tokens': prompt_len,
        'completion_tokens': completion_tokens,
        'total_tokens': completion_tokens + prompt_len,
    }
    return response
from flare.
Hello, I tried to solve this using the https://github.com/xusenlinzy/api-for-open-llm project. However, FLARE uses the OpenAI `Completion` API and reads `completion.choices[0].logprobs` from the response. Most open LLM servers do not expose token log-probabilities, so api-for-open-llm returns `null` for `logprobs`, which causes errors.
So I think most open LLMs will not work well with FLARE out of the box.
from flare.
Related Issues (20)
- prep.py lacks modules import HOT 2
- How to add a new dataset? HOT 1
- Could you share WikiAsp dataset used in the experiments? HOT 2
- How to merge document lists retrieved from multiple queries? HOT 1
- How do you get token probability through openai's api HOT 1
- Readme needs to indicate unzipping wikipedia documents HOT 1
- prep.py processing of psgs_w100.tsv document is incorrect HOT 1
- code for evaluation HOT 2
- 2WikiMultihopQA dataset in the experiment
- can I use open source model like llama, mistral from huggingface ecosystem HOT 1
- StrategyQA and ASQA dataset
- So many errors about JSONDecodeError HOT 1
- Would you release the code for FLARE_instruct?
- So buggy
- Is the “torbo” class model not supported?
- openai API "text-davinci-003" is deprecated HOT 1
- what type of DPR used for psgs_w100.tsv? HOT 1
- Request on data and corresponding config files
- Which version of openai are you using?
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from flare.