Comments (13)
Great RFC! Have you had the chance to verify this with PP? #4412
Yes, definitely. Actually, ideally we should design these together. One possibility is to merge #4412 in two parts: one for changes to worker-local execution, and then a second PR to add the control plane to glue different workers together. That way, we can follow the interfaces that I'm proposing here and keep the control plane separate from model execution.
Concretely, the first PR would then contain these changes:
- support multiple cache engines in one Worker
- support workers executing one shard of a model

The proposed `Worker.execute_model_local` function would directly return the hidden states from the model shard instead of using send/recv calls inside the model definition. We would then glue together all of the p2p connections in the `Worker` class as a separate PR.
This way, it will be easier to try different control plane methods. We can use the approach you have in #4412. Another option is a new backend in Ray that we've been developing to improve performance for static task graphs. I wrote an integration for this based on an earlier version of #4412: graph definition and how to call it.

---

Does something like this make sense:
```python
class WorkerBase:
    class WorkerInput:
        def get_broadcastable_data(self):
            pass

        @staticmethod
        def from_broadcast_data(data):
            pass

    def prepare_worker_input(self, seq_group_metadata_list) -> WorkerInput:
        pass

    def execute_worker(self, input_data: WorkerInput):
        pass

    def execute_model(self, seq_group_metadata_list):
        if self.is_driver_worker:
            # Driver: prepare both inputs locally, then send them to the
            # other workers in a single broadcast.
            worker_input = self.prepare_worker_input(seq_group_metadata_list)
            self.execute_worker(worker_input)
            model_input = self.model_runner.prepare_model_input(
                seq_group_metadata_list)
            data_to_broadcast = worker_input.get_broadcastable_data()
            data_to_broadcast.update(model_input.get_broadcastable_data())
            broadcast_tensor_dict(data_to_broadcast, src=0)
        else:
            # Non-driver: reconstruct both inputs from the broadcast data.
            data_to_broadcast = broadcast_tensor_dict(src=0)
            worker_input = self.WorkerInput.from_broadcast_data(data_to_broadcast)
            self.execute_worker(worker_input)
            model_input = self.model_runner.ModelRunnerInput.from_broadcast_data(
                data_to_broadcast)
        return self.model_runner.execute_model(model_input)


class ModelRunnerBase:
    class ModelRunnerInput:
        def get_broadcastable_data(self):
            pass

        @staticmethod
        def from_broadcast_data(data):
            pass

    def prepare_model_input(self, seq_group_metadata_list) -> ModelRunnerInput:
        pass

    def execute_model(self, input_data: ModelRunnerInput):
        pass
```
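To make the round trip above concrete, here is a toy, runnable sketch (not vLLM code; `ToyWorkerInput`, `ToyModelRunnerInput`, and the key prefixes are all illustrative) of how the two inputs could share a single broadcast payload, with a plain dict standing in for `broadcast_tensor_dict`:

```python
from dataclasses import dataclass


@dataclass
class ToyWorkerInput:
    num_blocks_to_swap: int

    def get_broadcastable_data(self):
        # Prefix keys (an assumption) so the two inputs' keys don't collide.
        return {"worker.num_blocks_to_swap": self.num_blocks_to_swap}

    @staticmethod
    def from_broadcast_data(data):
        return ToyWorkerInput(data["worker.num_blocks_to_swap"])


@dataclass
class ToyModelRunnerInput:
    num_tokens: int

    def get_broadcastable_data(self):
        return {"model.num_tokens": self.num_tokens}

    @staticmethod
    def from_broadcast_data(data):
        return ToyModelRunnerInput(data["model.num_tokens"])


# Driver side: merge both payloads into one dict (one broadcast instead of two).
data_to_broadcast = ToyWorkerInput(2).get_broadcastable_data()
data_to_broadcast.update(ToyModelRunnerInput(16).get_broadcastable_data())

# Non-driver side: rebuild both inputs from the same dict.
worker_input = ToyWorkerInput.from_broadcast_data(data_to_broadcast)
model_input = ToyModelRunnerInput.from_broadcast_data(data_to_broadcast)
print(worker_input.num_blocks_to_swap, model_input.num_tokens)  # 2 16
```

The point of merging the payloads is that the driver issues one collective per step rather than one per input type.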
Then, say we want to add a `GPUWorker` and `GPUModelRunner`:
```python
class GPUWorker(WorkerBase):
    class WorkerInput(WorkerBase.WorkerInput):
        def get_broadcastable_data(self):
            pass

        @staticmethod
        def from_broadcast_data(data):
            pass

    def prepare_worker_input(self, seq_group_metadata_list) -> WorkerInput:
        pass

    def execute_worker(self, input_data: WorkerInput):
        pass


class GPUModelRunner(ModelRunnerBase):
    class ModelRunnerInput(ModelRunnerBase.ModelRunnerInput):
        def get_broadcastable_data(self):
            pass

        @staticmethod
        def from_broadcast_data(data):
            pass

    def prepare_model_input(self, seq_group_metadata_list) -> ModelRunnerInput:
        pass

    def execute_model(self, input_data: ModelRunnerInput):
        pass
```
The control-plane communication is centralized in `WorkerBase.execute_model`.

---

Yes, will reach out to find some time to chat!
> We would have to bubble up `hidden_states` and residuals through the model definition to `ModelRunner` and then to `Worker`. If we calculated `SamplingOutputs` in `ModelRunner` the way it's done now in this PR as well, we would also need to bubble that up to `Worker`, or calculate `SamplingOutputs` in `Worker` instead. I'm not opposed to doing it this way, but this is an example of one of those details it would be good to discuss further and brainstorm if we can do it in a more elegant fashion.
For this, I don't think we need to bubble up the sampling procedure to `Worker`. Let's discuss more offline, but I'm imagining the following interface, where `ModelRunner.execute_model` would return either a `SamplingOutput` or an `IntermediateOutput`, which can be a dict of tensors (or a separate class). Then the `Worker` can have the logic to either return the output directly or send it to a different worker.
```python
from typing import Dict, List, Union

import torch

IntermediateOutput = Dict[str, torch.Tensor]


class ModelRunner:
    def execute_model(
        self, model_input: ModelInput
    ) -> Union[List[SamplingOutput], IntermediateOutput]:
        pass
```
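To illustrate, here is a minimal sketch of that branching logic on the `Worker` side. All names here are hypothetical stand-ins: plain lists represent `SamplingOutput`s, a dict of lists represents the tensor dict, and `send_to_next_pp_rank` represents whatever p2p send the `Worker` ends up using.

```python
def worker_execute_model(model_output, send_to_next_pp_rank):
    """Forward intermediate tensors; return final sampling output as-is."""
    if isinstance(model_output, dict):
        # Intermediate pipeline stage: ship the tensor dict to the next rank.
        send_to_next_pp_rank(model_output)
        return None
    # Last stage: ModelRunner already produced the SamplingOutputs.
    return model_output


sent = []
print(worker_execute_model(["sampled"], sent.append))  # ['sampled']
print(worker_execute_model({"hidden_states": [0.1]}, sent.append))  # None
print(sent)  # [{'hidden_states': [0.1]}]
```

This keeps sampling inside `ModelRunner` while still letting the `Worker` own all p2p communication.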

---

> Great RFC! Have you had the chance to verify this with PP? #4412
Yes agreed, we should chat more about this; what you're suggesting makes sense to me. There are 3-4 optimizations that I know of that we can do on top of #4412, but my current plan is to have #4412 merged as a base on top of the current logic, in order to have basic PP fully functional as soon as possible before moving on to performance refactoring.

---

> Yes agreed, we should chat more about this; what you're suggesting makes sense to me. There are 3-4 optimizations that I know of that we can do on top of #4412, but my current plan is to have #4412 merged as a base on top of the current logic, in order to have basic PP fully functional as soon as possible before moving on to performance refactoring.
Great! Can you say a bit more about what optimizations you were thinking of? The reason I suggested splitting #4412 is that I think it will be easier to introduce some optimizations for PP if we can merge in this refactor first.

---

@youkaichao that sounds good to me. I can make those changes in #5408.

---

> Great! Can you say a bit more about what optimizations you were thinking of? The reason I suggested splitting #4412 is that I think it will be easier to introduce some optimizations for PP if we can merge in this refactor first.
Off the top of my head, the following optimizations are possible:
- Merge cache engines and scheduler to take advantage of prefix caching.
- Reduce RPC overhead by sending/recving metadata instead of having multiple driver workers.
- Reduce CUDAGraph overhead with PP to make CUDAGraph more viable.
In general, I agree that it would potentially make optimizations easier (particularly 2 above). My concern is that we are prioritizing optimization prematurely here and further delaying the already delayed PP feature. IMO we should have the functionality available fully first, and then implement this refactoring on top of that.
cc: @zhuohan123 - this is what we chatted about last week

---

Sounds good. Yes, I also don't want to block PP; I just think that it may actually be faster long-term to merge a version that's compatible with this refactor.
If you do not want to split #4412, at least I think we need to move the p2p communication out of the model definitions and into the `Worker` class. That will make the merge with this refactor smoother, plus it makes #4412 unit-testable.

---

> Sounds good. Yes, I also don't want to block PP; I just think that it may actually be faster long-term to merge a version that's compatible with this refactor.
Makes sense! I'm not opposed to splitting #4412, but I just think that if we are going to do so, it's a good idea to coordinate more closely on the details so we have a good plan to get everything in with low friction. We can meet next week to talk in more detail if you are free.
> If you do not want to split #4412, at least I think we need to move the p2p communication out of the model definitions and into the `Worker` class. That will make the merge with this refactor smoother, plus it makes #4412 unit-testable.
To give more context, I had thought about this initially but decided against it at first, since different models send/recv different numbers of tensors. For example, `gpt2` sends only the hidden states, while `LLaMa` sends both hidden states and residuals. We would need to modify each model file to return a list of tensors to the `ModelRunner` layer, as opposed to just `hidden_states`, which I found more intrusive than what I have now, which is localized to each model definition - and at the time there wasn't a good reason to do so. It makes sense to me to make this change given this new refactor.
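For illustration, the per-model difference could look like this (the function names are hypothetical, and plain lists stand in for tensors): each model shard returns a dict, so `ModelRunner` and the `Worker` can forward whatever tensors are present without knowing the architecture.

```python
def gpt2_shard_forward(hidden_states):
    # GPT-2 only needs to hand off the hidden states.
    return {"hidden_states": hidden_states}


def llama_shard_forward(hidden_states, residual):
    # LLaMA hands off both the hidden states and the residual stream.
    return {"hidden_states": hidden_states, "residual": residual}


# The Worker can then send every tensor in the dict over p2p uniformly,
# regardless of which model produced it:
for name, tensor in llama_shard_forward([1.0], [0.5]).items():
    print(name, tensor)
```

The dict keys give the receiving shard enough information to unpack the tensors, which is what makes a model-agnostic `Worker`-level p2p layer possible.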
Given what you proposed above it also gets a little more complicated in the context of this PR if I understand correctly:
> The proposed `Worker.execute_model_local` function would directly return the hidden states from the model shard instead of using send/recv calls inside the model definition. We would then glue together all of the p2p connections in the `Worker` class as a separate PR.
We would have to bubble up `hidden_states` and residuals through the model definition to `ModelRunner` and then to `Worker`. If we calculated `SamplingOutputs` in `ModelRunner` the way it's done now in this PR as well, we would also need to bubble that up to `Worker`, or calculate `SamplingOutputs` in `Worker` instead. I'm not opposed to doing it this way, but this is an example of one of those details it would be good to discuss further and brainstorm if we can do it in a more elegant fashion.

---

I think this RFC makes a lot of sense. It's a great idea to put all of the communication logic in the same place. I previously misunderstood this as a bigger scope change that changes how we do control-plane communication. Some smaller questions about this RFC:
- What is the difference and relationship between `prepare_model_input_local` and `prepare_model_input`? Can we just keep `prepare_model_input`?
- What is the relationship between this RFC and the Ray DAG change you are proposing?

---

> I think this RFC makes a lot of sense. It's a great idea to put all of the communication logic in the same place. I previously misunderstood this as a bigger scope change that changes how we do control-plane communication. Some smaller questions about this RFC:
Yes! For now this is just proposing a refactor that would allow changing the control plane more easily. This RFC doesn't propose any behavior changes (except for squashing some broadcasts).
> - What is the difference and relationship between `prepare_model_input_local` and `prepare_model_input`? Can we just keep `prepare_model_input`?
Ah, yes, I actually updated the RFC since @youkaichao suggested something similar. Now there is only `prepare_model_input`.
> - What is the relationship between this RFC and the Ray DAG change you are proposing?
This would make it easier to integrate Ray DAG / any other control plane method that broadcasts the `ExecuteModelRequest` to all workers instead of broadcasting the `ModelInput`. If we want to support Ray DAG right now, we need to update the control flow in the worker and model runner, e.g., to skip the tensor broadcasts. We can do that for the main `Worker` codepath, but it makes the code pretty messy, and we'd have to do the same thing for every other worker and model runner that we want to support. With the new APIs, we can just override `Worker.execute_model` and only call the worker-local methods.
```python
class RayDAGWorkerBase(LocalOrDistributedWorkerBase):
    def execute_model(self, execute_model_req: ExecuteModelRequest):
        worker_input = self.prepare_worker_input(execute_model_req)
        self.execute_worker(worker_input)
        model_input = self.model_runner.prepare_model_input(execute_model_req)
        return self.model_runner.execute_model(model_input)
```

---

Merged