
Comments (10)

CarloLucibello commented on June 27, 2024

We fixed this in #1612 and then we unfixed it again in #1868, due to FluxML/Optimisers.jl#46 (comment).

There is some ambiguity in the paper. They call $\alpha$ the learning rate and $\eta_t$ some scheduled coefficient.
So assuming $\eta_t=1$ the current implementation seems correct.
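
If I read the paper correctly, its update (Algorithm 2) is, writing $\hat{m}_t$ and $\hat{v}_t$ for the bias-corrected moments,

$$ \theta_t \leftarrow \theta_{t-1} - \eta_t \left( \alpha \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_{t-1} \right) $$

so $\eta_t$ scales both the Adam step and the weight decay, while $\alpha$ scales only the Adam step.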

On the other hand, the PyTorch implementation seems to match #1612, so I think we should fix AdamW again.

dpaetzel commented on June 27, 2024

Thank you for unravelling that for me and sorry that I didn't notice those issues/PRs in the first place.

Short elaboration for future reference:

The AdamW paper decouples what it calls the “schedule multiplier” $\eta_t$ from the learning rate $\alpha$ (applied to the Adam update) and from the regularization factor/weight decay $\lambda$; $\eta_t$ is applied as an additional factor multiplying both of the other two.

PyTorch only exposes two parameters, $\gamma_\text{torch}$ (the learning rate) and $\lambda_\text{torch}$ (regularization factor/weight decay), with the paper correspondence $\gamma_\text{torch} = \eta_t \alpha$ and $\gamma_\text{torch} \lambda_\text{torch} = \eta_t \lambda$. The implementation in #1612 hardcodes $\alpha$ to 1, which is exactly the same parametrization.

I can't quite tell how important the additional control of an uncoupled $\alpha$ would be; probably not important enough to diverge from the PyTorch implementation, I guess?
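
As a quick sanity check of the correspondence above, here is a small plain-Julia snippet (no Flux/Optimisers code; the numbers are made up and `A` stands in for the bias-corrected Adam direction):

```julia
θ = [0.5, -1.0, 2.0]        # current parameters
A = [0.1, -0.2, 0.05]       # hypothetical Adam update direction m̂ ./ (sqrt.(v̂) .+ ϵ)
η, α, λ = 1.0, 1e-3, 1e-2   # paper-style hyperparameters: schedule, learning rate, weight decay

# Paper parametrization: θ ← θ - η*(α*A + λ*θ)
θ_paper = θ .- η .* (α .* A .+ λ .* θ)

# PyTorch parametrization with γ = η*α and λ_torch = λ/α: θ ← θ - γ*(A + λ_torch*θ)
γ, λ_torch = η * α, λ / α
θ_torch = θ .- γ .* (A .+ λ_torch .* θ)

@assert θ_paper ≈ θ_torch   # both parametrizations give the same single step
```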

CarloLucibello commented on June 27, 2024

We should adhere to PyTorch's implementation, for sure. Would you mind filing PRs here and in Optimisers.jl?

ToucheSir commented on June 27, 2024

I don't have time to comment on this in detail now (will do so later), but the decision to diverge from PyTorch was not made lightly. IIRC it was something about how their handling and interpretation of the learning rate was unintuitive and would trip up people moving from other optimizers to AdamW. I also didn't find their justification particularly compelling.

ToucheSir commented on June 27, 2024

Ok, I did some more digging into why PyTorch decided to couple the learning rate and weight decay coefficient in their AdamW implementation. My best guess is that this comment on one of the AdamW PRs triggered changes which cascaded all the way to the final AdamW PR. I don't find the point super compelling here, because unlike PyTorch, Flux lacks an Adam + coupled L2 penalty constructor. Moreover, changing the calculation would be a breaking change for Flux and Optimisers.jl.

Now for an argument on semantics and usability. I agree that separate scheduling alone is not enough to justify a separate learning rate and weight decay rate. The problem lies more with tweaking hyperparameters. The AdamW paper makes a big point about being able to control both independently. With both coupled as PyTorch does, you have to always remember to tweak the weight decay every time you tweak the learning rate, otherwise you will be increasing/decreasing both simultaneously. We may even have public examples of people not realizing this, e.g. fastai/fastai#1806 (funnily enough, FastAI's AdamW used to not couple the two hyperparams).

There's also a practical concern if we do introduce hyperparam scheduling (i.e. controlling $\eta_t$ using ParameterSchedulers.jl). The previous implementation in #1612 chained together two rules with a learning rate (eta) field, but one of them must remain fixed at eta = 1 for the algorithm to be correct. Optimisers.adjust! will by default adjust both learning rates, and getting it to adjust only one would require a fair amount of extra code.

As such, I think the best path forward would be to add a keyword arg to the AdamW constructor. Call it couple_lr or something, and have it return something closer to #1612 if couple_lr=true. As I noted, we'd likely also need to add a wrapper type for AdamW instead of using OptimiserChain directly.
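
To make that concrete, here is a rough, self-contained sketch (hypothetical names, not the Optimisers.jl interface; a real version would follow the Optimisers.jl rule/state conventions) of what a standalone rule with such a flag could compute:

```julia
# Hypothetical illustration only: `AdamWSketch`, `step`, and the field names are
# made up for this comment and are not Flux / Optimisers.jl API.
Base.@kwdef struct AdamWSketch
    eta::Float64 = 1e-3                     # learning rate
    beta::NTuple{2,Float64} = (0.9, 0.999)  # moment decay rates
    lambda::Float64 = 1e-2                  # weight decay
    epsilon::Float64 = 1e-8
    couple_lr::Bool = true                  # true: PyTorch-style, decay scaled by eta
end

# One update step; `state` is (first moment, second moment, step count).
function step(o::AdamWSketch, θ, grad, state)
    m, v, t = state
    t += 1
    m = o.beta[1] .* m .+ (1 - o.beta[1]) .* grad
    v = o.beta[2] .* v .+ (1 - o.beta[2]) .* grad .^ 2
    # bias-corrected Adam direction
    A = (m ./ (1 - o.beta[1]^t)) ./ (sqrt.(v ./ (1 - o.beta[2]^t)) .+ o.epsilon)
    if o.couple_lr
        θ = θ .- o.eta .* (A .+ o.lambda .* θ)  # coupled: θ ← θ - η*(A + λ*θ)
    else
        θ = θ .- o.eta .* A .- o.lambda .* θ    # decoupled: θ ← θ - η*A - λ*θ
    end
    return θ, (m, v, t)
end
```

With couple_lr = true this should roughly match the #1612 / PyTorch behaviour, and with couple_lr = false the decoupled behaviour; either way there is only a single eta field for adjust! or a scheduler to touch.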

CarloLucibello commented on June 27, 2024

I have to think a bit about it. Another data point is that Optax also couples the two:
https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adamw

ToucheSir commented on June 27, 2024

I actually opened an issue on the Optax repo about this, and they more or less said they wanted to copy PyTorch...

CarloLucibello commented on June 27, 2024

> There's also a practical concern if we do introduce hyperparam scheduling

I think we should simply implement AdamW by copy-pasting the code from Adam.

We can add the couple_lr thing. The default should be couple_lr = true, though; we should do what everybody else is doing. AdamW has become very popular in recent years, and we want experiments in papers to be reproducible; tracking a divergence down to an optimizer flag would be a very frustrating experience.

ToucheSir commented on June 27, 2024

No objections, we'd just have to make a breaking release with it. Anything else we'd want to get in said release?

dpaetzel commented on June 27, 2024

Thank you for the extended discussion!

Just to make sure I understand correctly (I'll try to find time to submit a PR):

couple_lr == false would mean to use the parametrization from the AdamW paper, i.e.

$$ \begin{align*} \theta_t \leftarrow \theta_{t-1} - \eta (\alpha \ A + \lambda \theta_{t-1}) \end{align*} $$

(where $A$ is the update coming from Adam and we expose to the user $\eta$, $\alpha$, $\lambda$) whereas couple_lr == true would mean to use this parametrization:

$$ \begin{align*} \theta_{t-1} - \eta (\alpha \ A + \lambda \theta_{t-1}) & = \theta_{t-1} - \eta \alpha \ A - \eta \lambda \theta_{t-1} \\ & = \theta_{t-1} - \gamma_\text{torch} \ A - \gamma_\text{torch} \lambda_\text{torch} \theta_{t-1} \\ & = \theta_{t-1} - \gamma_\text{torch} (1 \cdot A + \lambda_\text{torch} \theta_{t-1}) \\ \end{align*} $$

where we expose to the user $\eta = \gamma_\text{torch}$ and $\lambda = \lambda_\text{torch}$ (i.e. changing $\alpha$ doesn't do anything as long as couple_lr == true).
