Hi @ratan, thank you for trying out turnkey! We currently have a limitation with respect to auto-regressive models: when you call `model.generate(prompt)`, the model is invoked many times, each with a different input size. In turn, this triggers many builds, and therefore many folders with one ONNX file each.
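To make the limitation concrete, here is a toy sketch in plain Python. The names `toy_forward` and `toy_generate` are hypothetical, and this is not how turnkey or `transformers` implement generation. It assumes no KV cache, so the whole sequence is re-fed on every step. Real `generate()` usually feeds only the newest token alongside a growing cache, but in both cases the per-call input sizes change. If each distinct input size triggers its own build, one `generate()` call yields one build per new token, while a single forward call yields one.

```python
from typing import List, Tuple


def toy_forward(input_ids: List[int], seen_shapes: List[Tuple[int, int]]) -> int:
    """Stand-in for one model invocation: records the (batch, seq_len)
    shape it was called with and returns a fake "next token" id."""
    seen_shapes.append((1, len(input_ids)))
    return (sum(input_ids) + len(input_ids)) % 100  # arbitrary deterministic token


def toy_generate(prompt_ids: List[int], max_new_tokens: int,
                 seen_shapes: List[Tuple[int, int]]) -> List[int]:
    """Greedy decoding loop without a KV cache: one model call per new token."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        ids.append(toy_forward(ids, seen_shapes))
    return ids


shapes: List[Tuple[int, int]] = []
toy_generate([5, 7, 9], max_new_tokens=4, seen_shapes=shapes)
print(shapes)  # [(1, 3), (1, 4), (1, 5), (1, 6)]: four calls, four distinct shapes

single: List[Tuple[int, int]] = []
toy_forward([5, 7, 9], single)  # the model() workaround: exactly one call
print(single)  # [(1, 3)]
```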
I took the Mistral-7B-Instruct example and replaced the `model.generate()` call with a `model()` call, which runs a single forward pass (predicting only the next token). This worked as expected and generated a single ONNX file. Here is my example code:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

messages = [
    {"role": "user", "content": "What is your favourite condiment?"},
    {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
    {"role": "user", "content": "Do you have mayonnaise recipes?"},
]

# Tokenize the chat into a PyTorch tensor of token ids
encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
# A single forward pass (instead of model.generate()) produces a single ONNX file
result = model(encodeds)
```
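Note that `model(encodeds)` returns logits rather than text. With `transformers`, the greedy next token is usually read from the last position as `result.logits[:, -1, :].argmax(dim=-1)`. The sketch below mirrors that selection on a nested Python list shaped `[batch][seq_len][vocab]`, so it runs without downloading a model. The helper name `greedy_next_token` is hypothetical.

```python
from typing import List


def greedy_next_token(logits: List[List[List[float]]]) -> List[int]:
    """For each batch item, return the argmax over the vocabulary at the
    last sequence position (what logits[:, -1, :].argmax(-1) computes)."""
    next_ids = []
    for seq in logits:   # iterate over the batch
        last = seq[-1]   # scores for the token that follows the prompt
        next_ids.append(max(range(len(last)), key=last.__getitem__))
    return next_ids


# Batch of 1, sequence length 2, vocabulary of 4 tokens
fake_logits = [[[0.1, 0.5, 0.2, 0.0],
                [0.3, 0.1, 0.9, 0.4]]]
print(greedy_next_token(fake_logits))  # [2]
```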
Here is an additional example from turnkey's model corpus showing how to get a single ONNX file with a desired input shape: https://github.com/onnx/turnkeyml/blob/main/models/transformers/mistral_7b.py
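One way to pin the input shape is to pad or truncate the prompt's token ids to a fixed length and build a matching attention mask. This is sketched here as an assumption, not as what the linked `mistral_7b.py` does. The helper `pad_or_truncate`, the default `pad_id=0`, left padding, and keeping the *last* tokens on truncation are all hypothetical design choices. With Hugging Face tokenizers, a similar effect is normally obtained with `padding="max_length"`, `truncation=True`, and `max_length=...`.

```python
from typing import List, Tuple


def pad_or_truncate(ids: List[int], length: int,
                    pad_id: int = 0) -> Tuple[List[int], List[int]]:
    """Return (input_ids, attention_mask), both exactly `length` long.
    Short inputs are left-padded (common for decoder-only models);
    long inputs keep their last `length` tokens."""
    if len(ids) >= length:
        return ids[-length:], [1] * length
    n_pad = length - len(ids)
    return [pad_id] * n_pad + ids, [0] * n_pad + [1] * len(ids)


print(pad_or_truncate([11, 12, 13], 5))        # ([0, 0, 11, 12, 13], [0, 0, 1, 1, 1])
print(pad_or_truncate([1, 2, 3, 4, 5, 6], 4))  # ([3, 4, 5, 6], [1, 1, 1, 1])
```

Because every prompt then has the shape `(1, length)`, a single forward call always sees the same input size.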
@ratan, does this workaround work for your use case?
cc @danielholanda for visibility
Hi @jeremyfowers, thanks for the quick reply. I tried your suggestions and they worked!
By replacing the `model.generate()` call with a `model()` call, I am able to generate a single ONNX file.
The sample below works for Mistral-7B-Instruct-v0.2:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Minimal setup required to convert a PyTorch LLM into ONNX format.
# The resulting ONNX file can be viewed in the Netron app.

# Load the tokenizer and model on the CPU
torch.set_default_device("cpu")
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

text_prompt = [
    {"role": "user", "content": "What is your favourite condiment?"},
    {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
    {"role": "user", "content": "Do you have mayonnaise recipes?"},
]

# Tokenize the chat prompt into a PyTorch tensor of token ids
inputs = tokenizer.apply_chat_template(text_prompt, return_tensors="pt")

# Run a single forward pass: calling model() instead of model.generate()
# invokes the auto-regressive model only once, so only one ONNX file is built
# https://github.com/onnx/turnkeyml/issues/95
outputs = model(inputs)
```
The sample below works for the LLM360/CrystalChat, microsoft/phi-2, and adept/fuyu-8b models:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Minimal setup required to convert a PyTorch LLM into ONNX format.
# The resulting ONNX file can be viewed in the Netron app.

# Load the tokenizer and model on the CPU
torch.set_default_device("cpu")
model_id = "adept/fuyu-8b"  # also works with "LLM360/CrystalChat" or "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

text_prompt = "<s> <|sys_start|> You are an AI assistant. You will be given a task. You must generate a detailed and long answer. <|sys_end|> <|im_start|> Write a python function that takes a list of integers and returns the squared sum of the list. <|im_end|>"

# Tokenize the prompt into PyTorch tensors (a dict-like object with input_ids, attention_mask, ...)
inputs = tokenizer(text_prompt, return_tensors="pt")

# Run a single forward pass: calling model() instead of model.generate()
# invokes the auto-regressive model only once, so only one ONNX file is built
# https://github.com/onnx/turnkeyml/issues/95
outputs = model(**inputs)
```
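The two samples call the model differently. With its default `tokenize=True` and `return_tensors="pt"`, `apply_chat_template` returns the token-id tensor itself, so it is passed positionally as `model(inputs)`. Calling the tokenizer directly returns a dict-like `BatchEncoding` (typically `input_ids` and `attention_mask`), so it is unpacked with `model(**inputs)`. The toy below uses a hypothetical `toy_model` and plain lists instead of tensors to show both call styles reaching the same parameters.

```python
from typing import Dict, List, Optional


def toy_model(input_ids: List[int], attention_mask: Optional[List[int]] = None) -> str:
    """Hypothetical stand-in for model(...): reports what it received."""
    mask = attention_mask if attention_mask is not None else [1] * len(input_ids)
    return f"ids={input_ids} mask={mask}"


# Like apply_chat_template(..., return_tensors="pt"): just the ids, passed positionally
ids_only = [1, 2, 3]
# Like tokenizer(text, return_tensors="pt"): dict-like, unpacked as keyword arguments
batch: Dict[str, List[int]] = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]}

print(toy_model(ids_only))  # ids=[1, 2, 3] mask=[1, 1, 1]
print(toy_model(**batch))   # ids=[1, 2, 3] mask=[1, 1, 1]
```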
You may close this issue. Thanks again.
I'm glad to have been able to help! Please reach out if you have any more questions :)