pytorch-checkpoint

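Gradient checkpointing is a technique for reducing GPU memory cost. Instead of storing every intermediate activation during the forward pass, it stores only some of them and recomputes the rest during the backward pass, trading extra computation for memory.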
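
The sketch below is a toy, framework-free illustration of that trade-off, not this repo's code. It differentiates a chain of scalar functions in two ways. The first stores every intermediate value. The second stores only the input of each segment and recomputes the segment during the backward pass. The layer list and the function names are made up for illustration.

```python
import math

# Hypothetical chain of scalar "layers": each entry is (f, df/dx).
LAYERS = [
    (math.sin, math.cos),
    (lambda x: x * x, lambda x: 2 * x),
    (math.exp, math.exp),
    (lambda x: 3 * x + 1, lambda x: 3.0),
]


def grad_naive(x):
    """Store every activation in the forward pass, then apply the chain rule."""
    acts = [x]
    for f, _ in LAYERS:
        acts.append(f(acts[-1]))
    grad = 1.0
    # Walk the layers backwards, evaluating each derivative at that layer's input.
    for (_, df), a in zip(reversed(LAYERS), reversed(acts[:-1])):
        grad *= df(a)
    return acts[-1], grad, len(acts)  # output, d(out)/dx, values stored


def grad_checkpointed(x, segment=2):
    """Store only each segment's input; recompute the segment in the backward pass."""
    checkpoints = []
    a = x
    for start in range(0, len(LAYERS), segment):
        checkpoints.append((start, a))  # the only activation kept for this segment
        for f, _ in LAYERS[start:start + segment]:
            a = f(a)  # intermediate values are discarded
    out = a
    grad = 1.0
    for start, a in reversed(checkpoints):
        seg = LAYERS[start:start + segment]
        acts = [a]  # recompute this segment's layer inputs from its checkpoint
        for f, _ in seg[:-1]:
            acts.append(f(acts[-1]))
        for (_, df), act in zip(reversed(seg), reversed(acts)):
            grad *= df(act)
    return out, grad, len(checkpoints)  # output, d(out)/dx, checkpoints stored
```

Both functions return the same gradient. Between the two passes, the checkpointed version keeps only the 2 segment inputs instead of all 5 values, at the price of rerunning part of each segment's forward computation.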

Official implementation

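PyTorch provides an official implementation, torch.utils.checkpoint. However, it is much slower with multiple GPUs. In our tests with 2 GPUs and batch size 256, it took 1.41s, compared with 0.27s without checkpointing.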

This implementation

This repo contains a PyTorch implementation that also runs efficiently on multiple GPUs, while saving as much memory as the official one.

Main results

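| Method    | # GPUs | Batch size | Memory | Time  |
|-----------|--------|------------|--------|-------|
| Naive     | 2      | 256        | 5.25G  | 0.27s |
| Official  | 2      | 256        | 2.98G  | 1.41s |
| This repo | 2      | 256        | 2.97G  | 0.31s |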

Documentation

The main functionality is in checkpoint.py:

import checkpoint
checkpoint.CheckpointFunction.apply(function, n, *args)

Parameters:

  • function – describes what to run in the forward pass of the model, or of part of the model. It receives the n inputs as positional arguments and must interpret them correctly. For example, in an LSTM, if the user passes (activation, hidden), function should use the first input as the activation and the second as the hidden state.
  • n – number of inputs to function
  • args – tuple containing the inputs to function AND the parameters to optimize in function. The first n elements must be the inputs to function, in order. All remaining elements are treated as parameters.

Returns:

  • The output of running function on its inputs (the first n elements of args)
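
The following sketch shows how args is split. The helper split_checkpoint_args is hypothetical and does not appear in checkpoint.py. It only mirrors the convention documented above, using strings as stand-ins for tensors.

```python
def split_checkpoint_args(n, *args):
    """Hypothetical helper (not part of checkpoint.py) mirroring the documented
    convention: the first n entries of args are the inputs to `function`, in
    order, and every remaining entry is treated as a parameter to optimize."""
    if not 0 <= n <= len(args):
        raise ValueError(f"n={n} is out of range for {len(args)} args")
    return args[:n], args[n:]


# LSTM-style example from the parameter description: two inputs followed by
# two (hypothetical) parameters, with strings standing in for tensors.
inputs, params = split_checkpoint_args(2, "activation", "hidden", "weight", "bias")
print(inputs)  # ('activation', 'hidden')
print(params)  # ('weight', 'bias')
```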

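Note: We recommend using cp_BatchNorm2d instead of torch.nn.BatchNorm2d together with checkpointing. Because the checkpointed forward pass is run again during the backward pass, a standard batch norm layer would accumulate the same batch statistics into its running estimates more than once.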
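
The toy, framework-free sketch below shows the double counting. ToyRunningMean uses the same momentum update that torch.nn.BatchNorm2d applies to its running mean. Its update_stats flag is a hypothetical mechanism for illustration and is not necessarily how cp_BatchNorm2d avoids the problem.

```python
class ToyRunningMean:
    """Toy stand-in for a batch norm layer's running mean. Uses the momentum
    update of torch.nn.BatchNorm2d:
        running = (1 - momentum) * running + momentum * batch_mean
    The update_stats flag is hypothetical, for illustration only."""

    def __init__(self, momentum=0.1):
        self.momentum = momentum
        self.running_mean = 0.0

    def forward(self, batch, update_stats=True):
        mean = sum(batch) / len(batch)
        if update_stats:
            self.running_mean = ((1 - self.momentum) * self.running_mean
                                 + self.momentum * mean)
        return [x - mean for x in batch]  # output is only centered, not scaled


batch = [1.0, 2.0, 3.0]  # batch mean is 2.0

plain = ToyRunningMean()
plain.forward(batch)  # forward pass
plain.forward(batch)  # recomputation in the backward pass updates the stats again
print(plain.running_mean)  # about 0.38: the same batch was counted twice

guarded = ToyRunningMean()
guarded.forward(batch)
guarded.forward(batch, update_stats=False)  # recompute without touching the stats
print(guarded.running_mean)  # 0.2: the batch was counted once
```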

DenseNet example

We provide an example of applying our checkpointing to the memory-efficient DenseNet. Only a few lines of the original implementation need to change. (The original implementation uses PyTorch's official checkpointing.)

# bn_function concatenates its inputs and applies norm1, relu1 and conv1.
# Naive (no checkpointing): bottleneck_output = bn_function(*prev_features)
# Official implementation (torch.utils.checkpoint): bottleneck_output = cp.checkpoint(bn_function, *prev_features)
# This repo: pass the inputs first, then the parameters to optimize in bn_function,
# i.e. tuple(self.norm1.parameters()) + tuple(self.conv1.parameters()) (relu1 has none).
args = prev_features + tuple(self.norm1.parameters()) + tuple(self.conv1.parameters())
# Here cp is assumed to refer to this repo's checkpoint module, not torch.utils.checkpoint.
bottleneck_output = cp.CheckpointFunction.apply(bn_function, len(prev_features), *args)

Demo

python-fire is not required for checkpointing itself, but the efficient DenseNet demo needs it.

pip install fire
  • our checkpointing demo:
CUDA_VISIBLE_DEVICES=0,1 python cp_demo.py --efficient True --data cifar --save model --batch_size 256
  • the official implementation demo:
CUDA_VISIBLE_DEVICES=0,1 python original_demo.py --efficient True --data cifar --save model --batch_size 256

Environment

This code was tested with PyTorch 1.0.0.dev20181102.

Speed was measured on TITAN X (Pascal) GPUs.

Full results

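| Method    | # GPUs | Batch size | Memory | Time  |
|-----------|--------|------------|--------|-------|
| Naive     | 1      | 256        | 9.93G  | 0.42s |
| Naive     | 2      | 4          | 0.65G  | 0.10s |
| Naive     | 2      | 256        | 5.25G  | 0.27s |
| Naive     | 2      | 512        | 9.93G  | 0.50s |
| Official  | 1      | 256        | 5.38G  | 0.52s |
| Official  | 1      | 512        | 10.1G  | 1.00s |
| Official  | 2      | 4          | 0.62G  | 1.40s |
| Official  | 2      | 256        | 2.98G  | 1.41s |
| Official  | 2      | 512        | 5.39G  | 1.53s |
| This repo | 1      | 256        | 5.37G  | 0.50s |
| This repo | 1      | 512        | 10.1G  | 0.97s |
| This repo | 2      | 4          | 0.62G  | 0.13s |
| This repo | 2      | 256        | 2.97G  | 0.31s |
| This repo | 2      | 512        | 5.37G  | 0.58s |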
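
To read the table more easily, the short script below compares each checkpointing method against the naive baseline for the same number of GPUs and batch size. The numbers are copied from the table above. The dictionary layout and the compare helper are our own illustration.

```python
# (method, #GPUs, batch size) -> (memory in GB, time in s), copied from the table.
RESULTS = {
    ("Naive", 1, 256): (9.93, 0.42),
    ("Naive", 2, 4): (0.65, 0.10),
    ("Naive", 2, 256): (5.25, 0.27),
    ("Naive", 2, 512): (9.93, 0.50),
    ("Official", 1, 256): (5.38, 0.52),
    ("Official", 1, 512): (10.1, 1.00),
    ("Official", 2, 4): (0.62, 1.40),
    ("Official", 2, 256): (2.98, 1.41),
    ("Official", 2, 512): (5.39, 1.53),
    ("This repo", 1, 256): (5.37, 0.50),
    ("This repo", 1, 512): (10.1, 0.97),
    ("This repo", 2, 4): (0.62, 0.13),
    ("This repo", 2, 256): (2.97, 0.31),
    ("This repo", 2, 512): (5.37, 0.58),
}


def compare(method, gpus, batch):
    """Return (fraction of memory saved, time relative to Naive) for one setting."""
    mem, t = RESULTS[(method, gpus, batch)]
    base_mem, base_t = RESULTS[("Naive", gpus, batch)]
    return 1 - mem / base_mem, t / base_t


# Only settings that have a Naive measurement can be compared.
for gpus, batch in [(1, 256), (2, 4), (2, 256), (2, 512)]:
    for method in ("Official", "This repo"):
        saving, slowdown = compare(method, gpus, batch)
        print(f"{method:9} {gpus} GPU(s), batch {batch:3}: "
              f"{saving:.0%} less memory, {slowdown:.2f}x naive time")
```

On one GPU the two checkpointing implementations take similar time. The gap appears with two GPUs, where the official implementation is several times slower than the naive baseline.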

Credits

Parts of checkpoint.py and cp_BatchNorm2d.py are adapted from https://github.com/pytorch/pytorch

The efficient densenet demo is taken from https://github.com/gpleiss/efficient_densenet_pytorch

Contributors

csrhddlam