Comments (10)
A lot has changed since this issue, and I'd like to summarize:
There are two ways to consider scaling architectures
- Split layers onto devices manually
- Split all layers equally onto devices
Option 1 is extremely difficult to get right and to keep efficient when architectures are large and complicated. Option 2, which in recent years has become more prominent via DeepSpeed and now FairScale, offers an elegant way to scale model architectures with minimal annotation.
Fully Sharded Data Parallel has been merged, and offers the ability to leverage option 2 and, in most cases, solve the underlying scaling issue. I have a PR for FSDP documentation #7791 which will hopefully explain more about how this works :) Once merged, we should be able to close this!
EDIT: code example:

```python
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from fairscale.nn import wrap


class MyModule(pl.LightningModule):
    def configure_sharded_model(self):
        # layers will be sharded across all devices
        model_a = wrap(SomeModel())
        layer_1 = wrap(nn.Linear(...))
        layer_2 = wrap(nn.Linear(...))
        self.model = nn.Sequential(model_a, layer_1, layer_2)

    def forward(self, x):
        x = self.model(x)
        return x


model = MyModule()
trainer = Trainer(gpus=4, plugins='fsdp')
trainer.fit(model)
```
Could use something similar to this to approximate mem usage per layer/module and then balance accordingly.
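As a rough sketch of that idea (plain PyTorch, no Lightning helpers; the greedy balancer here is an illustration, not a library function), per-module parameter memory can be summed and then distributed across devices:

```python
import torch.nn as nn


def param_bytes(module: nn.Module) -> int:
    # total parameter storage (weights + biases) for this module, in bytes
    return sum(p.numel() * p.element_size() for p in module.parameters())


def greedy_partition(sizes, num_devices):
    # naive balancing: place each block on the currently lightest device
    loads = [0] * num_devices
    assignment = []
    for size in sizes:
        device = loads.index(min(loads))
        assignment.append(device)
        loads[device] += size
    return assignment


model = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 10))
sizes = [param_bytes(m) for m in model]
devices = greedy_partition(sizes, num_devices=2)
```

This only counts weights, which is why the follow-up comments below point out that inputs, outputs, and optimizer state matter too.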
That's helpful. You also need to account for the size of the input and output, including taking batch size into account. Sometimes the problem is that a layer's output blows up the RAM, so we'd probably need to try/catch a few passes through each block and calculate its full memory requirement.
The memory requirement is weights + input + output, and GPU 0 has the added overhead of the optimizer, which in the case of Adam holds the grads.
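The weights + input + output estimate can be sketched as a single forward pass per block (an illustration only; a real measurement would also need gradients and optimizer state, as noted above):

```python
import torch
import torch.nn as nn


def block_memory_bytes(block: nn.Module, sample: torch.Tensor) -> int:
    # weights + input + output for one block; gradients and optimizer
    # state are deliberately excluded from this rough estimate
    weight_bytes = sum(p.numel() * p.element_size() for p in block.parameters())
    with torch.no_grad():
        output = block(sample)
    in_bytes = sample.numel() * sample.element_size()
    out_bytes = output.numel() * output.element_size()
    return weight_bytes + in_bytes + out_bytes


# batch size matters: activation memory grows linearly with it
x = torch.randn(32, 256)
estimate = block_memory_bytes(nn.Linear(256, 1024), x)
```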
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
@SeanNaren, FairScale should partially provide this feature with OSS, right?
@tchaton, not exactly! This is covered by #4443, which introduces the pipe accelerator (it allows you to split a model across GPUs). The self-balancing part isn't easy, but can be done via functions like this in FairScale:
https://github.com/facebookresearch/fairscale/blob/7c5203eb772d7c67e45ed6ff6b66579b8e5cbc6c/fairscale/nn/pipe/balance/__init__.py#L100
I've been looking into the pipe accelerator but there are a few nice changes coming up with this PR: facebookresearch/fairscale#156
Would be nice to get them in first before adding the plugin/accelerator for this :)
Has there been any progress on this feature? I see that there’s a Beta section on the documentation here: https://pytorch-lightning.readthedocs.io/en/latest/multi_gpu.html#model-parallelism-beta but I don’t know if this works with DDP
Any updates on this issue?
Hey @tchaton, a small update :)
It's been a while, and supporting transparent self-balancing architectures with no friction hasn't been solved, primarily because engineering such balancing is difficult.
In most cases it requires a lot of engineering effort, and even our pipe implementation is very specific and provides little flexibility in use.
The current roadmap leans towards Fully Sharded Data Parallel replacing the need for self-balancing, by allowing the user to annotate layers (or automate the annotation) with FSDP, signalling that these layers should be loaded into memory, perform any necessary computation, and be de-allocated as soon as possible. This makes it possible to scale model size drastically, trading off time. If anyone is interested, have a look at our initial integration, which we're working on with the FairScale team to prove out and to ensure we have rigid tests/benchmarks in place: #6152
Closing this super old issue. `strategy="fsdp"` is your friend.
You can find guides at https://lightning.ai/docs/pytorch/latest/advanced/model_parallel.html for the Trainer and https://lightning.ai/docs/fabric/latest/advanced/model_parallel/fsdp.html for Fabric.
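For reference, a minimal sketch of the modern API (assuming a `LitModel` LightningModule is defined elsewhere; `LitModel` is a hypothetical name, not from this thread):

```python
import lightning as L

# `LitModel` is a hypothetical LightningModule defined elsewhere
trainer = L.Trainer(accelerator="gpu", devices=4, strategy="fsdp")
trainer.fit(LitModel())
```

The old `plugins='fsdp'` spelling from the earlier code example was replaced by the `strategy` argument in later Lightning releases.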