Comments (12)
That doesn't sound the same. One could always turn off all dropout completely, but normally we would want the checkpointed recomputation to use the same dropout state as the first forward call, so that the gradients with and without checkpointing are identical.
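To illustrate the concern, a minimal hedged sketch using Flux and `Zygote.checkpointed` (inside `gradient`, Flux's `Dropout` is active by default, and the pullback of `checkpointed` re-runs the forward pass with an RNG whose state has since advanced):

```julia
using Flux, Zygote

m = Chain(Dense(4 => 4, relu), Dropout(0.5), Dense(4 => 1))
x = randn(Float32, 4, 8)

g_plain = gradient(model -> sum(model(x)), m)
g_ckpt  = gradient(model -> sum(Zygote.checkpointed(model, x)), m)
# Within g_ckpt, the pullback's recomputation draws a fresh dropout mask,
# so the returned gradient does not match the forward pass that produced it.
```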
These are good points @ToucheSir. I will come back to this in a two-week timeframe; I am a bit busy with academic work right now.
Sounds good to have!
HF handles it in the forward method of the HF models (equiv. to `Layers.Transformer`). I'm not sure `Checkpointed` as an `AbstractTransformerBlock` is the best place to add the checkpoint functionality. Some alternative ideas I currently have in mind (a rough sketch of 1. follows below):

1. Generalize `Checkpointed{S} <: LayerStruct` and overload `Checkpointed{<:Transformer}` to add checkpointing per block.
2. Modify `Layers.applyblocks` to allow hooks and use `Zygote.checkpointed` as the hook function.
3. Similar to 2., but provide a `HookedTransformerBlock <: AbstractTransformerBlock`.

The wrapping function can be implemented with `postwalk`, like `Layers.set_dropout`.
I will look at your suggestions. `Checkpointed` as an `AbstractTransformerBlock` was a quick-and-dirty trick. I like the `postwalk` trick.
One thing you'll want to think about is stateful layers like Dropout and BatchNorm, which would not behave the same in subsequent calls. For the former, I think some mechanism to snapshot RNG state would be required; for the latter, maybe an explicit overload?
It seems the problem is that we cannot know whether a Dropout or BatchNorm is being executed in a checkpointed environment?
@ToucheSir I have not thought about this. Is there still a switch to toggle train and test mode? That would effectively solve the problem.
If the pullback is only called once, I believe BatchNorm and co. should actually not require any special handling. Otherwise, the approach would be to traverse the model looking for these layers, saving their current train/test status, doing the checkpointing, and then restoring the saved status.

As Peter notes, Dropout is trickier because you still need the RNG state around to create the mask. The most straightforward solution using `struct Checkpointed` would be to recurse through the model looking for `Dropout` layers and snapshot their RNG state beforehand. Then that state can be restored whenever the checkpointed recomputation runs. I haven't quite thought about how this interacts with RNGs shared between layers (as is the default), but that should be solvable.
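A hedged sketch of the snapshot/restore idea, assuming Flux's `Dropout` carries its own `rng` field (true in recent Flux versions) and an RNG type supporting `copy`/`copy!` (e.g. `MersenneTwister` or `Xoshiro`); `Flux.modules` does the traversal:

```julia
using Flux, Random

# Capture the state of every Dropout layer's RNG.
snapshot_dropout_rngs(model) =
    [copy(d.rng) for d in Flux.modules(model) if d isa Flux.Dropout]

# Restore the captured states in the same traversal order.
function restore_dropout_rngs!(model, snaps)
    i = 0
    for d in Flux.modules(model)
        d isa Flux.Dropout && copy!(d.rng, snaps[i += 1])
    end
    return model
end
```

A `Checkpointed` wrapper could then snapshot before the first forward pass and restore just before the recomputation in the pullback, so the same masks are re-drawn.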
In the medium to long term, we may want to consider a mechanism like https://github.com/vchuravy/ScopedValues.jl for exposing whether checkpointing is currently active in Flux itself. Then layers can query that info and change their behaviour accordingly, without a wrapper.
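For reference, a hedged sketch of what that could look like with ScopedValues.jl (the flag name and the idea that layers consult it are assumptions, not an agreed Flux design):

```julia
using ScopedValues

# A scoped flag a framework could expose; layers query it to know whether
# they are being re-run inside a checkpointed pullback.
const RECOMPUTING = ScopedValue(false)

recompute_region(f) = with(f, RECOMPUTING => true)

# Example query a layer could make before deciding to replay a stored
# dropout mask instead of sampling a fresh one:
is_recomputing() = RECOMPUTING[]
```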
@ToucheSir I wonder if we could subtype `AContext` in Zygote with a `CheckpointedContext` and overload the pullback behaviour for dropout and the like?
Maybe, but generally we'd like to avoid coupling Flux to Zygote wherever possible (e.g. no custom pullbacks).
I would say we would only need to couple NNlib to Zygote, since dropout has been moved out of Flux.
Yeah, NNlib has no dep (hard or weak) on Zygote right now, and it'd be better to keep it that way. Porting `Zygote.checkpointed` to use the ChainRules API shouldn't be an issue; we just need to decide whether it lives in Flux or NNlib.
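A hedged sketch of such a port, using only the public ChainRulesCore API (`rrule_via_ad` re-differentiates the wrapped function through whatever reverse-mode AD is driving the rule); the names are illustrative, not a decided design:

```julia
using ChainRulesCore

# Recompute-on-backward: the forward pass stores nothing from f's interior;
# the pullback re-differentiates f from the saved inputs.
checkpointed(f, xs...) = f(xs...)

function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode},
                              ::typeof(checkpointed), f, xs...)
    y = f(xs...)  # plain forward pass; intermediates are discarded
    function checkpointed_pullback(ȳ)
        # Re-run the forward pass under AD to rebuild the pullback, then apply it.
        _, back = rrule_via_ad(cfg, f, xs...)
        return (NoTangent(), back(ȳ)...)
    end
    return y, checkpointed_pullback
end
```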