
Weight-Averaged Sharpness-Aware Minimization (WASAM)


A minimal working example for incorporating WASAM into an image classification pipeline implemented in PyTorch.
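
Conceptually, WASAM combines two ingredients: a SAM-style update that descends using the gradient at adversarially perturbed weights, and SWA-style averaging of the weights late in training. The toy sketch below illustrates both ingredients on a quadratic loss; it is a conceptual illustration with assumed hyperparameters, not the package's implementation.

import torch

# Toy setup: a two-dimensional parameter vector and a quadratic loss.
w = torch.zeros(2, requires_grad=True)
target = torch.tensor([3.0, -1.0])
loss_fn = lambda p: ((p - target) ** 2).sum()

rho, lr = 0.05, 0.1  # SAM radius and learning rate (illustrative values)
swa_avg, n_avg = torch.zeros(2), 0

for step in range(100):
    # SAM part: compute the gradient, then ascend to the worst-case
    # point within an L2 ball of radius rho around the current weights.
    loss_fn(w).backward()
    with torch.no_grad():
        eps = rho * w.grad / (w.grad.norm() + 1e-12)
        w_adv = (w + eps).detach().requires_grad_()
    # The gradient at the perturbed point drives the actual update.
    loss_fn(w_adv).backward()
    with torch.no_grad():
        w -= lr * w_adv.grad
        w.grad.zero_()
    # SWA part: maintain a running average of the weights late in training.
    if step >= 75:
        n_avg += 1
        swa_avg += (w.detach() - swa_avg) / n_avg

print("final weights:", w.detach(), "averaged weights:", swa_avg)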

Usage

Simple option: Single closure-based step function with a single set of averaged weights

from wasam import WASAM
...

model = YourModel()
base_optimizer = torch.optim.SGD  # optimizer class for the "sharpness-aware" update; lr/momentum are passed to WASAM below
optimizer = WASAM(model.parameters(), base_optimizer, rho=0.05, lr=0.1, momentum=0.9)
max_epochs, swa_start_coeff = 200, 0.75
swa_start_epoch = int(max_epochs * swa_start_coeff)
...
for epoch in range(1, max_epochs+1):
    # train one epoch
    for input, target in loader:
        def closure():
            # re-evaluate the loss at the (perturbed) weights
            loss = loss_function(target, model(input))
            loss.backward()
            return loss

        loss = loss_function(target, model(input))
        loss.backward()
        optimizer.step(closure)  # performs the two-step model update and zeros gradients internally
    # towards the end of training, average the weights
    if epoch >= swa_start_epoch:
        optimizer.update_swa()
    # before model evaluation, swap weights with averaged weights
    optimizer.swap_swa_sgd()
    evaluate_model(model)
    # after model evaluation, swap them back (if training continues)
    optimizer.swap_swa_sgd()
...
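
For reference, a closure-based step of this kind typically unfolds into the two sub-steps exposed by the advanced option below. The following is a rough sketch of that control flow, not necessarily WASAM's exact implementation:

import torch

def sam_style_step(optimizer, closure):
    # Hypothetical outline of a closure-based two-step update.
    closure = torch.enable_grad()(closure)  # the closure must build a graph
    optimizer.first_step(zero_grad=True)    # ascend to the perturbed weights
    closure()                               # recompute gradients there
    optimizer.second_step(zero_grad=True)   # descend with the base optimizer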

Advanced option: Two forward-backward passes in the training loop and multiple sets of averaged weights

This option is slightly more involved but offers more flexibility.

There are two differences:

  1. We perform both forward and backward passes directly in the training loop
  2. We store and update multiple averaged models, each starting at a different epoch

from wasam import WASAM
from swa_utils import MultipleSWAModels
...
device = torch.device("cuda:0")
model = YourModel()
base_optimizer = torch.optim.SGD  # optimizer class for the "sharpness-aware" update; lr/momentum are passed to WASAM below
optimizer = WASAM(model.parameters(), base_optimizer, rho=0.05, lr=0.1, momentum=0.9)
max_epochs = 200
swa_starts = [0.5, 0.6, 0.75, 0.9]  # fractions of training after which each averaged model starts
swa_models = MultipleSWAModels(model, device, max_epochs, swa_starts)
...
for epoch in range(1, max_epochs+1):
    # train one epoch
    for input, target in loader:
        # first forward-backward pass
        loss = loss_function(target, model(input))  # use this loss for any training statistics
        loss.backward()
        optimizer.first_step(zero_grad=True)
        # second forward-backward pass
        loss_function(target, model(input)).backward()  # make sure to do a full forward pass
        optimizer.second_step(zero_grad=True)
    # average weights 
    swa_models.update_parameters(model, epoch) # checks if epoch >= swa_start internally
    # for model evaluation, you can loop over all averaged models
    for model_dict in swa_models.models:
        swa_model, swa_start = model_dict["model"], model_dict["start"]
        if epoch >= swa_start:
            evaluate_model(swa_model)
...
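
If you wonder how such a container can be organized, here is a minimal sketch built on torch.optim.swa_utils.AveragedModel; the actual MultipleSWAModels class in this repository's swa_utils may differ in its details.

from torch.optim.swa_utils import AveragedModel

class MultipleSWAModelsSketch:
    # Hypothetical re-implementation: one averaged copy per start fraction.
    def __init__(self, model, device, max_epochs, swa_starts):
        self.models = [
            {"model": AveragedModel(model).to(device),
             "start": int(max_epochs * coeff)}
            for coeff in swa_starts
        ]

    def update_parameters(self, model, epoch):
        # Average only once training has passed each model's start epoch.
        for entry in self.models:
            if epoch >= entry["start"]:
                entry["model"].update_parameters(model)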

Tips

BatchNorm Layers

If your model contains BatchNorm layers, you have to update the averaged model's activation statistics before using it. Here is a modified version of the simple-option example.

from wasam import WASAM
...

model = YourModel()
base_optimizer = torch.optim.SGD  # optimizer class for the "sharpness-aware" update; lr/momentum are passed to WASAM below
optimizer = WASAM(model.parameters(), base_optimizer, rho=0.05, lr=0.1, momentum=0.9)
max_epochs, swa_start_coeff = 200, 0.75
swa_start_epoch = int(max_epochs * swa_start_coeff)
...
for epoch in range(1, max_epochs + 1):
    train_model(loader, model, optimizer)
    # towards the end of training, average the weights
    if epoch >= swa_start_epoch:
        optimizer.update_swa()
    # before model evaluation, swap weights with averaged weights
    optimizer.swap_swa_sgd()
    optimizer.bn_update(loader, model) # <-------------- Update batchnorm statistics
    evaluate_model(model)
    # after model evaluation, swap them back (if training continues)
    optimizer.swap_swa_sgd()
...

Similarly, in the advanced option, the statistics can be updated as follows:

# for model evaluation, you can loop over all averaged models
for model_dict in swa_models.models:
    swa_model, swa_start = model_dict["model"], model_dict["start"]
    if epoch >= swa_start:
        optimizer.bn_update(loader, swa_model) # <-------------- Update batchnorm statistics
        evaluate_model(swa_model)
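
Alternatively, PyTorch ships a built-in helper for this: torch.optim.swa_utils.update_bn performs one pass over the data to recompute BatchNorm running statistics. Assuming loader, swa_model, and device from the example above:

from torch.optim.swa_utils import update_bn

update_bn(loader, swa_model, device=device)  # refreshes BatchNorm statistics with one pass over the data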

Installation

Install the required packages with:

pip install -r requirements.txt

Then, you can run

cd example
python main.py

Citation

If you find this repository useful, please consider citing the paper.

@inproceedings{
    kaddour2022when,
    title={When Do Flat Minima Optimizers Work?},
    author={Jean Kaddour and Linqing Liu and Ricardo Silva and Matt Kusner},
    booktitle={Advances in Neural Information Processing Systems},
    editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
    year={2022},
    url={https://openreview.net/forum?id=vDeh2yxTvuh}
}

Acknowledgements

This codebase builds on other repositories; thanks a lot to their authors!
