Comments (10)
Sorry for the delay @DonggeunYu! As you saw from your investigation, the issue is that we are using the weight of the embedding layer, which is offloaded to CPU, meaning that the weights are on the meta device. To fix the issue, we would need to use the layer itself instead of just getting its weights. Another solution would be to make sure that this specific layer stays on CUDA.
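To illustrate the two patterns being contrasted, here is a minimal sketch (toy sizes, no offload involved): calling the layer triggers any forward pre-hooks registered on it, while reading `.weight` and using it functionally bypasses those hooks entirely.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy embedding standing in for the model's embedding layer.
emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
input_ids = torch.tensor([1, 2, 3])

# (a) Calling the layer: forward pre-hooks registered on `emb` run first,
# so a hook like accelerate's pre_forward can move/load the weight in time.
out_layer = emb(input_ids)

# (b) Reading `emb.weight` and using it inside another module's forward:
# no hook on `emb` fires, so an offloaded weight would still be on the
# meta device when it reaches the other module's hook as a plain argument.
out_weight = F.embedding(input_ids, emb.weight)

print(torch.equal(out_layer, out_weight))  # True: same math, different hook behavior
```

Both calls compute the same result here; the difference only matters once hooks are responsible for materializing the weight.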
from transformers.
Hi @DonggeunYu, thanks for reporting!
We'll look into it. Out of interest, how did you discover this? Was it from modifying the tests, or are the tests just an easy way to demonstrate this behaviour?
The tests were just an easy way to demonstrate this behavior.
While using a private model, I discovered that there was a problem with nn.Embedding.
I may be wrong, as I still need to fully understand the transformers and accelerate code.
When offload is used, the weight of the nn.Embedding created in __init__ is placed on the meta device during initialization. If the nn.Embedding is invoked as a callable, accelerate's pre_forward hook matches the devices of the args, kwargs, and the embedding weight. However, because the embedding weight itself is passed into the forward of another module, it enters that module's pre_forward hook while still on the meta device.
To verify this, I logged the module device and the device of the args inside pre_forward in accelerate's hooks.py. Up to the point where the problematic nn.Embedding weight appears, the module device is meta and the args device is cuda. Once the nn.Embedding weight enters another module's args, the module device is meta and the args device is also meta (the embedding weight). An error then occurs when send_to_device(args, self.execution_device) tries to move the meta tensor to cuda.
module.__class__.__name__, device of module, device of args
Linear [device(type='meta')] [device(type='cuda', index=0)]
LayerNorm [device(type='meta')] [device(type='cuda', index=0)]
Linear [device(type='meta')] [device(type='cuda', index=0)]
Linear [device(type='meta')] [device(type='cuda', index=0)]
LayerNorm [device(type='meta')] [device(type='cuda', index=0)]
Linear [device(type='meta')] [device(type='meta')]
(See `def pre_forward` in accelerate's hooks.py and the `nn.Embedding` usage in transformers.)
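The failure in the last log line can be reproduced in miniature without accelerate: a meta tensor has no storage, so any attempt to copy it to a real device raises. This is a minimal sketch under the assumption that send_to_device ultimately calls Tensor.to(execution_device) on each tensor argument.

```python
import torch

# A weight that was "offloaded" during init: allocated on the meta device,
# so it carries shape/dtype metadata but no actual data.
offloaded_weight = torch.empty(3, 4, device="meta")

moved = None
error = None
try:
    # There is no data to copy out of a meta tensor, so this raises,
    # mirroring what happens inside send_to_device(args, self.execution_device).
    moved = offloaded_weight.to("cpu")
except Exception as exc:
    error = exc

print(type(error).__name__)
```

A regular CPU or CUDA tensor argument would move without issue; only the meta-device weight smuggled in as an arg trips the hook.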
@DonggeunYu Thanks for the update. Indeed, the structure of using the embedding weights rather than the layer in the forward pass is quite odd. cc @muellerzr, who knows more about the pre_forward hook of accelerate.
@amyeroberts @muellerzr
Is there any progress on this?
cc @SunMarc
@SunMarc
I would appreciate it if you could check the issue.
I would appreciate it if you could check the issue. 🥲
@SunMarc
Is there any solution?
Is it possible to keep certain layers on CUDA when using accelerate offload?
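One possible direction for pinning specific layers, sketched under assumptions: accelerate/transformers accept an explicit device_map mapping submodule names to devices. The names below ("model.embed_tokens", "model.layers", "lm_head") are illustrative only; check the real names for your model via model.named_modules() before building such a map.

```python
# Hedged sketch: pin one submodule to GPU while offloading the rest.
# The module names here are assumptions for illustration; inspect your
# model's actual names with `dict(model.named_modules()).keys()` first.
device_map = {
    "model.embed_tokens": 0,   # keep the embedding on cuda:0
    "model.layers": "cpu",     # offload the transformer blocks to CPU
    "lm_head": 0,
}

# Then, roughly (not run here):
# model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map)
```

Whether this avoids the meta-device arg depends on the embedding weight actually living on the GPU before any forward that consumes it.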