ironjr / grokfast
Official repository for the paper "Grokfast: Accelerated Grokking by Amplifying Slow Gradients"
Home Page: https://arxiv.org/abs/2405.20233
License: MIT License
I'll try my best, but I thought I'd check whether anyone else wants to try this in the context of the transformers Trainer.
Hey, everybody. Can you tell me which hyperparameters to start trying?
Great work! This is the first study I've seen that focuses on making the phenomenon practical.
What is a good strategy for choosing the AdamW weight decay value for a new model and dataset? The paper seems to use a very large range of values. Is there an approach you use to narrow down values that work, instead of doing full hyperparameter tuning (which is costly if one has to wait for grokking to happen)?
How will this interact with the schedule_free optimizer (https://github.com/facebookresearch/schedule_free)? Any gotchas to think about?
Hi! I've been playing around with the code for days and noticed an interesting phenomenon that I hope someone can help me understand: AdamW seems to perform better than Grokfast + Adam in many cases.
Here are the details of my two experiments.
The results: experiment No. 3 shows the grokking delay as expected, but eventually generalizes.
AdamW + small weight decay shows (not too significant) grokking as expected, and it finally reaches the expected validation accuracy (as predicted in the original OpenAI grokking paper).
Again, Grokfast + Adam + small weight decay fails to learn the task.
So why is it that AdamW seems to perform better than Grokfast + Adam in many cases?
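For reference, a Grokfast + Adam run wires the filter between the backward pass and the optimizer step, following the usage pattern in this repository's README. A minimal sketch (model, optimizer, data, and loss_fn are placeholders; alpha=0.98 and lamb=2.0 are the repository defaults):

import torch
from grokfast import gradfilter_ema

grads = None  # persistent EMA state, carried across steps
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0)

for input, target in data:
    optimizer.zero_grad()
    loss = loss_fn(model(input), target)
    loss.backward()  # compute the raw gradients first
    # Amplify the slow (low-frequency) gradient component in-place.
    grads = gradfilter_ema(model, grads=grads, alpha=0.98, lamb=2.0)
    optimizer.step()  # Adam steps on the filtered gradients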
I keep running into an issue where the gradients (or rather, the grad norm) keep getting larger and larger until the norm eventually becomes 'inf':
{'loss': 11.6752, 'grad_norm': 3789384056832.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2064384}
{'loss': 11.4293, 'grad_norm': 5675928780800.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2113536}
{'loss': 11.3192, 'grad_norm': 8501688532992.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2162688}
{'loss': 11.2486, 'grad_norm': 12734252974080.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2211840}
{'loss': 11.111, 'grad_norm': 19073997996032.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2260992}
{'loss': 11.2352, 'grad_norm': 28569986138112.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2310144}
{'loss': 11.3782, 'grad_norm': 42793548382208.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2359296}
{'loss': 11.2332, 'grad_norm': 64098310029312.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2408448}
{'loss': 11.3905, 'grad_norm': 96009648603136.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2457600}
{'loss': 11.2784, 'grad_norm': 143808037650432.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2506752}
{'loss': 11.3188, 'grad_norm': 215402843996160.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2555904}
{'loss': 11.4465, 'grad_norm': 322641097392128.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2605056}
{'loss': 11.3301, 'grad_norm': 483267975315456.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2654208}
{'loss': 11.354, 'grad_norm': 723862782214144.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2703360}
{'loss': 11.2944, 'grad_norm': 1084237784547328.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2752512}
{'loss': 11.1959, 'grad_norm': 1624025381994496.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2801664}
{'loss': 11.3966, 'grad_norm': 2432546264580096.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2850816}
{'loss': 11.2565, 'grad_norm': 3643589066227712.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2899968}
{'loss': 11.3079, 'grad_norm': 5457548907905024.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2949120}
{'loss': 11.2257, 'grad_norm': 8174589242769408.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2998272}
{'loss': 11.084, 'grad_norm': 1.22443075158016e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3047424}
{'loss': 11.2356, 'grad_norm': 1.834013527166157e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3096576}
{'loss': 11.0967, 'grad_norm': 2.747076973900595e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3145728}
{'loss': 11.3159, 'grad_norm': 4.114708807077069e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3194880}
{'loss': 11.2537, 'grad_norm': 6.163215792734208e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3244032}
{'loss': 11.1508, 'grad_norm': 9.231571782257869e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3293184}
{'loss': 11.2386, 'grad_norm': 1.382750805253161e+17, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3342336}
{'loss': 11.2199, 'grad_norm': 2.0711531456181043e+17, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3391488}
{'loss': 11.2959, 'grad_norm': 3.102276524535972e+17, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3440640}
{'loss': 11.3998, 'grad_norm': 4.6467440153985024e+17, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3489792}
{'loss': 11.2736, 'grad_norm': 6.960125757368566e+17, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3538944}
In gradfilter_ema, if I change:
p.grad.data = p.grad.data + grads[n] * lamb
to
p.grad.data = (p.grad.data + grads[n] * lamb) / (1 + lamb)
then this solves the exploding gradients. However, I'm not sure of the implications of this change and wanted to reach out to hear your thoughts.
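For concreteness, here is a minimal sketch of the full filter with that change applied, assuming the gradfilter_ema signature used in this repository (alpha is the EMA decay, lamb the amplification factor). Dividing by (1 + lamb) turns the sum of the raw gradient and the amplified EMA into a weighted average, so the filtered gradient keeps roughly the same scale as the raw gradient:

import torch
import torch.nn as nn

def gradfilter_ema_normalized(m: nn.Module, grads=None, alpha: float = 0.98, lamb: float = 2.0):
    # Initialize the EMA state from the current gradients on the first call.
    if grads is None:
        grads = {n: p.grad.data.detach()
                 for n, p in m.named_parameters()
                 if p.requires_grad and p.grad is not None}
    for n, p in m.named_parameters():
        if p.requires_grad and p.grad is not None:
            # Update the exponential moving average of the raw gradients.
            grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
            # Proposed change: normalize by (1 + lamb) so the filtered
            # gradient is a weighted average rather than an amplified sum.
            p.grad.data = (p.grad.data + grads[n] * lamb) / (1 + lamb)
    return grads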
Thank you! This idea is really cool! Great work!
I think the original paper was the first to report the grokking effect in transformers.
I have been experimenting with a seq2seq model for language translation and am not seeing any behavior that would indicate a state transition on the validation data.
Hi, I'm trying out Grokfast in an LLM scenario. Mixed-precision training is a commonly used technique to save GPU memory and speed up training. The following code is an example of FP16 training.
import torch
from torch.amp import GradScaler, autocast

scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Unscales the gradients of optimizer's assigned params in-place.
        scaler.unscale_(optimizer)

        # Since the gradients of optimizer's assigned params are unscaled, clip as usual:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()
The question is: where should I put grads = gradfilter_ema(model, grads)? I tried placing it between scale and unscale, but it doesn't work; the loss scale just explodes.
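One placement that should avoid this, sketched below assuming the gradfilter_ema signature from this repository: call the filter after scaler.unscale_(optimizer) and before clipping, so the EMA state accumulates true (unscaled) FP32 gradients. Between scale and unscale, p.grad still carries the current loss scale, so the EMA mixes gradients taken at different loss scales:

scaler = GradScaler()
grads = None  # EMA state for gradfilter_ema

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Unscale first: p.grad now holds the true FP32 gradients.
        scaler.unscale_(optimizer)

        # Filter the unscaled gradients, so the EMA is scale-consistent across steps.
        grads = gradfilter_ema(model, grads=grads, alpha=0.98, lamb=2.0)

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        scaler.step(optimizer)  # still skipped if gradients contain infs or NaNs
        scaler.update()

One caveat: if the scaler skips a step because of inf/NaN gradients, those values have already entered the EMA state; a more careful version would check for non-finite gradients before calling the filter.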