

pytorch_ema

A small library for computing exponential moving averages of model parameters.
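After each optimizer step, every averaged ("shadow") value s moves toward the current parameter p as s <- decay * s + (1 - decay) * p. A decay close to 1 (such as 0.995) therefore averages over many steps. The plain-Python sketch below illustrates only that rule on lists of floats. It is not the library's code. The helper `ema_step` is hypothetical, and the library's `update()` may additionally warm up the decay early in training (see the `use_num_updates` constructor argument in its docstrings).

```python
def ema_step(shadow, params, decay):
    """Hypothetical helper: one EMA update, moving each shadow value toward its parameter."""
    if not 0.0 <= decay <= 1.0:
        raise ValueError("decay must be between 0 and 1")
    # Equivalent to shadow - (1 - decay) * (shadow - param), elementwise
    return [decay * s + (1.0 - decay) * p for s, p in zip(shadow, params)]


# The shadow starts as a copy of the parameters. Here a single parameter jumps from 0.0 to 1.0.
shadow = [0.0]
for _ in range(3):
    shadow = ema_step(shadow, [1.0], decay=0.5)
print(shadow)  # [0.875]
```

Each step closes half of the remaining gap (0.5, 0.75, 0.875). With decay=0.995, each step would close only 0.5% of the gap.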

This library was originally written for personal use. Nevertheless, if you run into issues or have suggestions for improvement, feel free to open either a new issue or pull request.

Installation

For the stable version from PyPI:

pip install torch-ema

For the latest GitHub version:

pip install -U git+https://github.com/fadel/pytorch_ema

Usage

Example

import torch
import torch.nn.functional as F

from torch_ema import ExponentialMovingAverage

torch.manual_seed(0)
x_train = torch.rand((100, 10))
y_train = torch.rand(100).round().long()
x_val = torch.rand((100, 10))
y_val = torch.rand(100).round().long()
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

# Train for a few epochs
model.train()
for _ in range(20):
    logits = model(x_train)
    loss = F.cross_entropy(logits, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Update the moving average with the new parameters from the last optimizer step
    ema.update()

# Validation: original
model.eval()
logits = model(x_val)
loss = F.cross_entropy(logits, y_val)
print(loss.item())

# Validation: with EMA
# the .average_parameters() context manager
# (1) saves original parameters before replacing with EMA version
# (2) copies EMA parameters to model
# (3) after exiting the `with`, restore original parameters to resume training later
with ema.average_parameters():
    logits = model(x_val)
    loss = F.cross_entropy(logits, y_val)
    print(loss.item())

Manual validation mode

While the average_parameters() context manager is convenient, you can also execute the same series of operations manually:

ema.store()
ema.copy_to()
# ...
ema.restore()
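The plain-Python sketch below shows that the context manager and the manual sequence are equivalent. It uses a hypothetical `ToyEMA` class whose method names mirror `ExponentialMovingAverage` but whose "parameters" are just a list of floats. It is not the library's implementation. The manual version wraps the evaluation in `try`/`finally` so that `restore()` still runs if evaluation raises. This pattern is worth copying when you call `store()`/`copy_to()`/`restore()` yourself.

```python
from contextlib import contextmanager


class ToyEMA:
    """Hypothetical stand-in for ExponentialMovingAverage; "parameters" are a list of floats."""

    def __init__(self, params, decay):
        self.params = params        # live "model" parameters, mutated in place
        self.shadow = list(params)  # moving averages start as a copy of the parameters
        self.decay = decay
        self.collected = None       # backup filled by store()

    def update(self):
        self.shadow = [self.decay * s + (1 - self.decay) * p
                       for s, p in zip(self.shadow, self.params)]

    def store(self):
        self.collected = list(self.params)  # back up the current parameters

    def copy_to(self):
        self.params[:] = self.shadow        # overwrite parameters with the averages

    def restore(self):
        self.params[:] = self.collected     # put the backed-up parameters back

    @contextmanager
    def average_parameters(self):
        self.store()
        self.copy_to()
        try:
            yield
        finally:
            self.restore()                  # runs even if the body raises


params = [1.0, 2.0]
ema = ToyEMA(params, decay=0.5)
params[:] = [3.0, 4.0]  # pretend an optimizer step changed the weights
ema.update()            # shadow becomes [2.0, 3.0]

with ema.average_parameters():
    seen_in_context = list(params)

# Manual equivalent; try/finally guarantees restore() even if evaluation fails
ema.store()
ema.copy_to()
try:
    seen_manually = list(params)
finally:
    ema.restore()

print(seen_in_context, seen_manually, params)  # [2.0, 3.0] [2.0, 3.0] [3.0, 4.0]
```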

Custom parameters

By default, the methods of ExponentialMovingAverage act on the model parameters the object was constructed with. However, any compatible iterable of parameters (the same number of tensors, with matching shapes) can be passed to any method, such as store(), copy_to(), update(), restore(), and average_parameters():

model = torch.nn.Linear(10, 2)
model2 = torch.nn.Linear(10, 2)
ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
# ... train `model` here ...
# calling `ema.update()` with no arguments uses `model.parameters()`
ema.copy_to(model2.parameters())
# model2 now contains the averaged weights

Resuming training

Like a PyTorch optimizer, ExponentialMovingAverage objects have state_dict()/load_state_dict() methods. These let you pause, serialize, and restart training without losing the shadow parameters, the stored parameters, or the update count.

GPU/device support

ExponentialMovingAverage objects have a .to() method (like torch.Tensor) that moves the object's internal state to a different device or floating-point dtype.

For more details on individual methods, please check the docstrings.

Contributors

linux-cpp-lisp, fadel, zehui-lin
