grokfast's People

Contributors

d0rc, ironjr, majirky

grokfast's Issues

Choosing weight decay?

Great work! This is the first study I've seen that focuses on making the phenomenon practical.

What is a good strategy for choosing the AdamW weight decay value for a new model and dataset? In the paper it seems a very large range of values is used. Is there an approach you use to narrow down the candidate values instead of doing a full hyperparameter search (which is costly if one has to wait for grokking to happen)?
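One way to narrow the range without a full search might be a coarse log-scale sweep using short pilot runs before committing to a long training run. A minimal sketch (not from the authors; build_model and short_train_eval are hypothetical placeholders, not functions from this repository):

import torch

# Coarse log-scale sweep over AdamW weight decay using short pilot runs.
candidates = [10.0 ** e for e in range(-3, 2)]   # 1e-3, 1e-2, 1e-1, 1e0, 1e1
results = {}
for wd in candidates:
    model = build_model()                        # hypothetical model constructor
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=wd)
    # Train for a small fraction of the full budget and record validation accuracy.
    results[wd] = short_train_eval(model, optimizer, steps=2000)
best_wd = max(results, key=results.get)
print(f"Best weight decay from the pilot sweep: {best_wd}")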

AdamW better than grokfast + Adam?

Hi! I've been playing around with the code for days and noticed an interesting phenomenon that I hope someone can help me understand: AdamW seems to be better than Grokfast + Adam in many cases.
Here are the details of my two experiments:

  1. I experimented with main.py on the following two setups: NO3 is AdamW with a small weight decay of 0.01, and NO4 is EMA Grokfast with Adam and the same small weight decay.
[Screenshot of the two run configurations, 2024-07-10]

The results:
NO3 shows the expected grokking behavior and eventually generalizes:
[Plot: acc_NO3_none_wd10e-02_lrx4_optimizerAdamW_start_at1]

However, NO4 fails to generalize:
[Plot: acc_NO4_ema_a0990_l5_wd10e-02_lrx4_optimizerAdam_start_at1]

  2. This time I changed the task to learning x^2 + xy + y^2 instead of simple multiplication, and changed p from 97 to 113.
    Here are the setups:
[Two screenshots of the run configurations, 2024-07-10]

The results:
AdamW with a small weight decay shows mild grokking, as expected, and eventually reaches the expected validation accuracy (as predicted in the original OpenAI grokking paper):
[Plot: acc_test_none_wd50e-03_lrx4]

Again, Grokfast + Adam with a small weight decay fails to learn the task:
[Plot: acc_test_ema_a0980_l2_wd50e-03_lrx4]

So why does AdamW seem to be better than Grokfast + Adam in many cases?
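For reference, a minimal sketch of the two kinds of runs being compared, assuming the gradfilter_ema helper from the repository's grokfast.py; alpha and lamb are read off the NO4 run name above (a0990, l5), while the learning rate and data loop are illustrative placeholders rather than the exact settings:

import torch
from grokfast import gradfilter_ema  # helper defined in the repository's grokfast.py

# NO3-style baseline: plain AdamW with a small decoupled weight decay.
optimizer_no3 = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

# NO4-style run: Adam with the same weight decay plus the EMA gradient filter.
optimizer_no4 = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.01)
grads = None
for input, target in data:
    optimizer_no4.zero_grad()
    loss = loss_fn(model(input), target)
    loss.backward()
    # Amplify the slow (EMA) component of the gradient before the optimizer step.
    grads = gradfilter_ema(model, grads=grads, alpha=0.99, lamb=5.0)
    optimizer_no4.step()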

Exploding Gradients

I keep running into an issue where the gradients (or rather, the gradient norm) keep getting larger and larger until the norm eventually becomes inf:

{'loss': 11.6752, 'grad_norm': 3789384056832.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2064384}
{'loss': 11.4293, 'grad_norm': 5675928780800.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2113536}
{'loss': 11.3192, 'grad_norm': 8501688532992.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2162688}
{'loss': 11.2486, 'grad_norm': 12734252974080.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2211840}
{'loss': 11.111, 'grad_norm': 19073997996032.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2260992}
{'loss': 11.2352, 'grad_norm': 28569986138112.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2310144}
{'loss': 11.3782, 'grad_norm': 42793548382208.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2359296}
{'loss': 11.2332, 'grad_norm': 64098310029312.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2408448}
{'loss': 11.3905, 'grad_norm': 96009648603136.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2457600}
{'loss': 11.2784, 'grad_norm': 143808037650432.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2506752}
{'loss': 11.3188, 'grad_norm': 215402843996160.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2555904}
{'loss': 11.4465, 'grad_norm': 322641097392128.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2605056}
{'loss': 11.3301, 'grad_norm': 483267975315456.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2654208}
{'loss': 11.354, 'grad_norm': 723862782214144.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2703360}
{'loss': 11.2944, 'grad_norm': 1084237784547328.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2752512}
{'loss': 11.1959, 'grad_norm': 1624025381994496.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2801664}
{'loss': 11.3966, 'grad_norm': 2432546264580096.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2850816}
{'loss': 11.2565, 'grad_norm': 3643589066227712.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2899968}
{'loss': 11.3079, 'grad_norm': 5457548907905024.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2949120}
{'loss': 11.2257, 'grad_norm': 8174589242769408.0, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 2998272}
{'loss': 11.084, 'grad_norm': 1.22443075158016e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3047424}
{'loss': 11.2356, 'grad_norm': 1.834013527166157e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3096576}
{'loss': 11.0967, 'grad_norm': 2.747076973900595e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3145728}
{'loss': 11.3159, 'grad_norm': 4.114708807077069e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3194880}
{'loss': 11.2537, 'grad_norm': 6.163215792734208e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3244032}
{'loss': 11.1508, 'grad_norm': 9.231571782257869e+16, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3293184}
{'loss': 11.2386, 'grad_norm': 1.382750805253161e+17, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3342336}
{'loss': 11.2199, 'grad_norm': 2.0711531456181043e+17, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3391488}
{'loss': 11.2959, 'grad_norm': 3.102276524535972e+17, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3440640}
{'loss': 11.3998, 'grad_norm': 4.6467440153985024e+17, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3489792}
{'loss': 11.2736, 'grad_norm': 6.960125757368566e+17, 'learning_rate': 2.5e-05, 'epoch': 0.0, 'num_input_tokens_seen': 3538944}

In gradfilter_ema, if I change

p.grad.data = p.grad.data + grads[n] * lamb

to

p.grad.data = (p.grad.data + grads[n] * lamb) / (1 + lamb)

then the exploding gradients go away. However, I'm not sure of the implications of this change and wanted to reach out to hear your thoughts.
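For reference, a minimal sketch of the proposed normalization applied to the EMA filter, assuming the gradfilter_ema implementation from the repository's grokfast.py (an illustration of the change, not an endorsed fix):

from typing import Dict, Optional
import torch
import torch.nn as nn

def gradfilter_ema_normalized(
    m: nn.Module,
    grads: Optional[Dict[str, torch.Tensor]] = None,
    alpha: float = 0.98,
    lamb: float = 2.0,
) -> Dict[str, torch.Tensor]:
    # Initialize the EMA buffer from the current gradients on the first call.
    if grads is None:
        grads = {n: p.grad.data.detach() for n, p in m.named_parameters()
                 if p.requires_grad and p.grad is not None}
    for n, p in m.named_parameters():
        if p.requires_grad and p.grad is not None:
            # Standard Grokfast EMA update of the slow gradient component.
            grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
            # Proposed change: renormalize so the filtered gradient keeps roughly
            # the original scale instead of growing by a factor of (1 + lamb).
            p.grad.data = (p.grad.data + grads[n] * lamb) / (1 + lamb)
    return grads

Dividing by (1 + lamb) makes the filtered gradient a convex combination of the raw gradient and its EMA, which roughly preserves the original gradient scale; the trade-off is that the effective step size is reduced by the same factor compared to the unnormalized filter.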

Thank you! This idea is really cool! Great work!

Is this specific to transformers?

I believe the original paper first observed the grokking effect in transformers.

I have been experimenting with a seq2seq model for language translation and am not seeing any behavior that would indicate a grokking-style transition on the validation data.

How to use Grokfast with FP16 mixed precision training?

Hi, I'm trying out Grokfast in an LLM scenario. Mixed-precision training is a commonly used technique to reduce GPU memory usage and speed up training. The following code is an example of FP16 training.

import torch
from torch import autocast
from torch.cuda.amp import GradScaler

# model, optimizer, loss_fn, data, epochs, and max_norm are defined elsewhere.
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)
        # Scale the loss so small FP16 gradients do not underflow, then backprop.
        scaler.scale(loss).backward()

        # Unscales the gradients of the optimizer's assigned params in-place.
        scaler.unscale_(optimizer)

        # Since the gradients are now unscaled, clip as usual.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        # The gradients are already unscaled, so scaler.step does not unscale them again,
        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
        scaler.step(optimizer)

        # Updates the loss scale for the next iteration.
        scaler.update()

The question is: where should I put grads = gradfilter_ema(model, grads)? I tried to put it between the scale and unscale steps, but it doesn't work; the loss scale just explodes.
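One possibility (a hedged sketch, not an officially documented recipe) is to apply the filter after scaler.unscale_(optimizer), so that the EMA buffer accumulates gradients in their true scale rather than loss-scaled ones:

import torch
from torch import autocast
from torch.cuda.amp import GradScaler
from grokfast import gradfilter_ema  # assumes grokfast.py from this repository is importable

scaler = GradScaler()
grads = None

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()

        # Unscale first so the EMA buffer stores gradients in their true scale,
        # independent of the (dynamically changing) loss-scale factor.
        scaler.unscale_(optimizer)
        grads = gradfilter_ema(model, grads=grads)

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)   # still skips the step if grads contain infs or NaNs
        scaler.update()

Filtering before unscaling would mix gradients multiplied by different loss-scale factors into the same EMA buffer, which could explain the exploding loss scale.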
