infini-ai-lab / sequoia
Scalable and robust tree-based speculative decoding algorithm
Hi,
I remember that vLLM support was on your TODO list. Have you implemented it yet? Was the main challenge in this direction that tree verification with batch size > 1 is hard to make efficient? Thanks!
Hi,
If I understand the tree_search algorithm correctly, the dynamic programming process should find the optimal expected number of generated tokens given the acceptance-rate vector. Likewise, given the acceptance-rate vector and a candidate tree, the expected number of generated tokens can be computed. But this is just theory: in the paper, the number of generated tokens is measured from experimental runs. I'm wondering whether these experimentally measured numbers agree with the theoretical optimum.
I was trying to verify this, but the repo only contains the tree maps; the acceptance vectors are missing. Have you considered this comparison before? Alternatively, could you share the acceptance vectors so that, together with the corresponding trees, I can quickly verify it?
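For reference, the expectation I have in mind can be sketched as follows. This is a minimal sketch for a simple chain-shaped speculation tree, assuming each drafted token at depth d is accepted independently with probability p[d]; it is the standard speculative-decoding expectation, not necessarily Sequoia's exact estimator.

```python
def expected_generated_tokens(accept_probs):
    """Expected tokens per target step for a chain of drafted tokens:
    1 (the target model always emits at least one token) plus the sum of
    prefix products of the per-depth acceptance probabilities."""
    expected = 1.0
    prefix = 1.0
    for p in accept_probs:
        prefix *= p  # probability the accepted chain survives to this depth
        expected += prefix
    return expected

print(expected_generated_tokens([0.8, 0.8, 0.8]))  # 1 + 0.8 + 0.64 + 0.512 = 2.952
```

A tree generalizes this to a sum over nodes of the product of acceptance probabilities along each root-to-node path.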
Thanks!
The dataset loading code takes too long: it downloads entire huge datasets (the ~70 GB wiki dump, etc.) just to use a handful of examples. Setting split="train[0:2000]" does not help, since slicing happens only after the full download.
Suggestions:
For allenai/c4, load a single shard: load_dataset("allenai/c4", "allenai--c4", data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, split="train")
For wikitext, use the small raw variant: load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
Hi,
Thanks for the great work!
I'm wondering if there is any benchmark that compares Sequoia against vanilla speculative decoding?
Thanks for your good work.
Thanks.
Best regards.
I tried to test the code with the run_A100.sh script but got this error:
$:~/sequoia/tests$ bash run_A100.sh
...
Traceback (most recent call last):
File "/extra_disk_1/optimus/sequoia/tests/testbed.py", line 293, in <module>
simulation_fast(target_model=target_model, draft_model=draft_model, dataloader=dataloader, T=args.T, top_p=args.P,
File "/extra_disk_1/optimus/sequoia/tests/testbed.py", line 68, in simulation_fast
spectree = SpecTree(prefix=input_ids.squeeze(0), device='cuda:0', temperature=T,
File "/extra_disk_1/optimus/sequoia/tests/../Tree/SpecTree.py", line 68, in __init__
draft_model_outputs = self.draft_model_engine.inference(input_ids=self.tokens[:self.num_nodes].unsqueeze(0),
File "/home/optimus/conda/envs/py9/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/extra_disk_1/optimus/sequoia/tests/../Engine/Engine.py", line 242, in inference
return self.engine.model_run(input_ids=input_ids, storage_ids=storage_ids,
File "/home/optimus/conda/envs/py9/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/extra_disk_1/optimus/sequoia/tests/../Engine/Engine.py", line 38, in model_run
logits = self.model(input_ids=input_ids,
File "/home/optimus/conda/envs/py9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/optimus/conda/envs/py9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/extra_disk_1/optimus/sequoia/tests/../Engine/Llama_model.py", line 201, in forward
outputs = self.model(
File "/home/optimus/conda/envs/py9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/optimus/conda/envs/py9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/extra_disk_1/optimus/sequoia/tests/../Engine/Llama_model.py", line 59, in forward
layer_outputs = decoder_layer(
File "/home/optimus/conda/envs/py9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/optimus/conda/envs/py9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/extra_disk_1/optimus/sequoia/tests/../Engine/Llama_modules.py", line 334, in forward
hidden_states = self.self_attn(
File "/home/optimus/conda/envs/py9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/optimus/conda/envs/py9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/extra_disk_1/optimus/sequoia/tests/../Engine/Llama_modules.py", line 127, in forward
attn_output = torch.nn.functional.scaled_dot_product_attention(
RuntimeError: p.attn_bias_ptr is not correctly aligned
my lib versions:
$:~/sequoia/tests$ pip list | grep -e transformers -e torch -e accelerate
accelerate 0.26.1
torch 2.1.0
torchaudio 0.13.1
torchvision 0.14.1
transformers 4.37.2
Hey @dreaming-panda,
This looks really interesting. I wondered if you would be interested in showing an integration with Lit-GPT: https://github.com/Lightning-AI/litgpt
Best,
T.C
Description:
When I tried to reproduce the paper results following the README, an exception was raised:
return forward_call(*args, **kwargs)
File "/data0/xiac/RLHF/Prelim/Sequoia/tests/../Engine/Llama_model.py", line 59, in forward
layer_outputs = decoder_layer(
File "/home/xiac/.conda/envs/rlhf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/xiac/.conda/envs/rlhf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/data0/xiac/RLHF/Prelim/Sequoia/tests/../Engine/Llama_modules.py", line 334, in forward
hidden_states = self.self_attn(
File "/home/xiac/.conda/envs/rlhf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/xiac/.conda/envs/rlhf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/data0/xiac/RLHF/Prelim/Sequoia/tests/../Engine/Llama_modules.py", line 118, in forward
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
File "/home/xiac/.conda/envs/rlhf/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 207, in apply_rotary_pos_emb
q_embed = (q * cos) + (rotate_half(q) * sin)
RuntimeError: The size of tensor a (12) must match the size of tensor b (384) at non-singleton dimension 1
I traced the call chain and enabled the 'debug' flag in engine.model_run. When I ran it again, this assertion failed:
Traceback (most recent call last):
File "/data0/xiac/RLHF/Prelim/Sequoia/tests/testbed.py", line 268, in <module>
draft_model.initialize_cuda_graph(graph_capture_list)
File "/home/xiac/.conda/envs/rlhf/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/data0/xiac/RLHF/Prelim/Sequoia/tests/../Engine/Engine.py", line 189, in initialize_cuda_graph
self.callables[decoding_seqlen] = capture_graph(
File "/data0/xiac/RLHF/Prelim/Sequoia/tests/../Engine/Engine.py", line 141, in capture_graph
static_logits = engine.model_run(
File "/home/xiac/.conda/envs/rlhf/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/data0/xiac/RLHF/Prelim/Sequoia/tests/../Engine/Engine.py", line 34, in model_run
assert attention_mask.shape[0] == input_length
AssertionError
I checked the code and found a suspicious pair of lines in capture_graph:
static_attn_mask = torch.full((decoding_seqlen, engine.max_length), 0, dtype=dtype, device=device)
static_attn_mask = static_attn_mask[None, None, :, :]
The second line reshapes static_attn_mask to (1, 1, decoding_seqlen, engine.max_length), so attention_mask.shape[0] is 1 rather than input_length, and the assertion necessarily fails.
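The shape change can be reproduced in isolation. This is a hypothetical sketch using NumPy in place of torch (the indexing semantics are the same); decoding_seqlen and max_length are illustrative values, not taken from the repo's configs.

```python
import numpy as np

decoding_seqlen, max_length = 4, 16

# Mirror the two lines from capture_graph:
static_attn_mask = np.zeros((decoding_seqlen, max_length))
static_attn_mask = static_attn_mask[None, None, :, :]  # adds two leading axes

print(static_attn_mask.shape)  # (1, 1, 4, 16)

# The debug assertion compares shape[0] to the input length, but after the
# unsqueeze shape[0] is always 1:
input_length = decoding_seqlen
assert static_attn_mask.shape[0] == 1  # so `shape[0] == input_length` fails
```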
Sorry for asking a possibly obvious question, but it would be better if the documentation made this clear.
Hi,
I was trying to reproduce the numbers in the paper, but with demo-config.json, plus the acceptance vector in the repo (or the acceptance vector I measured myself), the generated trees are all very small and somewhat fixed:
0 _ 1 _ 3
\_ 2
or
0 _ 1 _2 _3
But on the other hand, the growmaps in the two folders are generally much larger, typically of size 128, 64, or 32. Do you know why the trees I generate are so small, and how to reproduce the growmaps in those two folders?
Thank you!
Here, max_branch equals K + 1, where K refers to Algorithm 5 in your paper. According to your code in test_accept.py, the (K + 1)-th dimension of p represents the fraction of cases where none of the draft tokens is accepted. So is there a gap between your code and Algorithm 5? Correct me if I misunderstood anything!
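To make sure we are talking about the same convention, here is a minimal sketch of how I read it. The function name and the accepted-index encoding are my own illustrative choices, not code from test_accept.py.

```python
def acceptance_vector(accepted_index_per_step, K):
    """Empirical acceptance vector for K drafted candidates per position.

    accepted_index_per_step: for each step, the index (0..K-1) of the
    accepted candidate, or None if no candidate was accepted.
    Returns p of length K + 1, where p[i] for i < K is the fraction of
    steps that accepted candidate i, and p[K] is the fraction of steps
    that accepted none.
    """
    counts = [0] * (K + 1)
    for idx in accepted_index_per_step:
        counts[K if idx is None else idx] += 1
    total = len(accepted_index_per_step)
    return [c / total for c in counts]

p = acceptance_vector([0, 0, 1, None, 0], K=2)
print(p)  # [0.6, 0.2, 0.2], last entry is the fraction with no acceptance
```

Under this reading, p sums to 1 and the last entry is the "accept none" mass, which is what seems to differ from Algorithm 5.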