ReZero for Deep Neural Networks

ReZero is All You Need: Fast Convergence at Large Depth. Uncertainty in AI (UAI), 2021.
Thomas Bachlechner*, Bodhisattwa Prasad Majumder*, Huanru Henry Mao*, Garrison W. Cottrell, Julian McAuley (* denotes equal contributions)

This repository contains the ReZero-Transformer implementation from the paper. It matches PyTorch's Transformer API and can easily be used as a drop-in replacement.

Abstract

Deep networks have enabled significant performance gains across domains, but they often suffer from vanishing/exploding gradients. This is especially true for Transformer architectures, where depth beyond 12 layers is difficult to train without large datasets and computational budgets. In general, we find that inefficient signal propagation impedes learning in deep networks. In Transformers, multi-head self-attention is the main cause of this poor signal propagation. To facilitate deep signal propagation, we propose ReZero, a simple change to the architecture that initializes an arbitrary layer as the identity map, using a single additional learned parameter per layer. We apply this technique to language modeling and find that we can easily train ReZero-Transformer networks with over a hundred layers. When applied to 12-layer Transformers, ReZero converges 56% faster on enwiki8. ReZero applies beyond Transformers to other residual networks, enabling 1,500% faster convergence for deep fully connected networks and 32% faster convergence for a ResNet-56 trained on CIFAR-10.
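
In code, the idea amounts to one extra scalar per residual block. Below is a minimal, generic PyTorch sketch of the update x_{i+1} = x_i + alpha_i * F(x_i) with alpha_i initialized to zero; it illustrates the idea and is not the exact implementation shipped in this package.

import torch
import torch.nn as nn

class ReZeroBlock(nn.Module):
    """Generic ReZero residual block: x + resweight * F(x), with resweight initialized to 0."""
    def __init__(self, sublayer):
        super().__init__()
        self.sublayer = sublayer                       # an arbitrary layer F
        self.resweight = nn.Parameter(torch.zeros(1))  # the single learned scalar per block

    def forward(self, x):
        # At initialization the block is exactly the identity map,
        # so signals propagate unchanged through arbitrarily deep stacks.
        return x + self.resweight * self.sublayer(x)

block = ReZeroBlock(nn.Linear(512, 512))
x = torch.rand(32, 512)
assert torch.allclose(block(x), x)  # identity at initialization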

Installation

Simply install from pip:

pip install rezero

PyTorch 1.4 or greater is required.

Usage

We provide custom ReZero Transformer layers (RZTX).

For example, this will create a Transformer encoder:

import torch
import torch.nn as nn
from rezero.transformer import RZTXEncoderLayer

encoder_layer = RZTXEncoderLayer(d_model=512, nhead=8)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
src = torch.rand(10, 32, 512)  # (seq_len, batch, d_model)
out = transformer_encoder(src)
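
Since RZTXEncoderLayer is intended as a drop-in replacement for nn.TransformerEncoderLayer, masks should pass through the stack in the same way. The padding-mask call below is a hedged illustration that assumes the layer mirrors PyTorch's forward signature; it is not documented behavior from this repo.

src_key_padding_mask = torch.zeros(32, 10, dtype=torch.bool)  # (batch, seq_len)
src_key_padding_mask[:, -2:] = True                           # e.g. mark the last two positions as padding
out = transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)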

This will create a Transformer decoder:

import torch
import torch.nn as nn
from rezero.transformer import RZTXDecoderLayer

decoder_layer = RZTXDecoderLayer(d_model=512, nhead=8)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
memory = torch.rand(10, 32, 512)  # encoder output: (src_seq_len, batch, d_model)
tgt = torch.rand(20, 32, 512)     # target sequence: (tgt_seq_len, batch, d_model)
out = transformer_decoder(tgt, memory)

Make sure the norm argument is left as None so that LayerNorm is not used in the Transformer.

See https://pytorch.org/docs/master/nn.html#torch.nn.Transformer for details on how to integrate custom Transformer layers into PyTorch.
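
PyTorch's nn.Transformer also accepts prebuilt stacks through its custom_encoder and custom_decoder arguments, so a full encoder-decoder model can be assembled from the ReZero layers. The following is one possible wiring, sketched under that assumption rather than taken from this repo:

import torch
import torch.nn as nn
from rezero.transformer import RZTXEncoderLayer, RZTXDecoderLayer

encoder = nn.TransformerEncoder(RZTXEncoderLayer(d_model=512, nhead=8), num_layers=6)
decoder = nn.TransformerDecoder(RZTXDecoderLayer(d_model=512, nhead=8), num_layers=6)

# norm is left at its default of None in both stacks, per the note above
model = nn.Transformer(d_model=512, nhead=8, custom_encoder=encoder, custom_decoder=decoder)

src = torch.rand(10, 32, 512)  # (src_seq_len, batch, d_model)
tgt = torch.rand(20, 32, 512)  # (tgt_seq_len, batch, d_model)
out = model(src, tgt)          # (20, 32, 512)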

Tutorials

  1. Training a 128-layer ReZero Transformer on WikiText-2 language modeling
  2. Training a 10,000-layer ReZero neural network on CIFAR-10

Watch for more tutorials in this space.

Citation

If you find rezero useful for your research, please cite our paper:

@inproceedings{BacMajMaoCotMcA20,
    title = "ReZero is All You Need: Fast Convergence at Large Depth",
    author = "Bachlechner, Thomas  and
      Majumder, Bodhisattwa Prasad and
      Mao, Huanru Henry and
      Cottrell, Garrison W. and
      McAuley, Julian",
    booktitle = "arXiv",
    year = "2020",
    url = "https://arxiv.org/abs/2003.04887"
}

rezero's Issues

Does it work in not-so-deep architectures?

Thanks for your great work. Have you run experiments applying ReZero to Transformers with different numbers of layers, e.g. a 1-layer Transformer, a 2-layer Transformer, and so on, and measured their performance? Does it also speed up convergence in networks that are not very deep? Thank you.

resweight is almost 0

Hi, thanks for your work. I used the ReZero method to train a 32-layer Transformer and found that, starting from around the 20th layer, resweight is almost 0 (layer 0 is the data input, layer 31 is the data output, dec_attn is the attention sublayer, pos_ff is the feed-forward sublayer, and resweight is the coefficient in x_{i+1} = x_i + resweight * sublayer(x_i)). If resweight is almost 0, that layer effectively does nothing.
Why does this happen?
Thanks for your help

layer    dec_attn.resweight    pos_ff.resweight
 0        0.0962                0.0908
 1        0.1198                0.1206
 2        0.1403                0.1274
 3        0.1621                0.1263
 4        0.2211                0.1438
 5        0.2545                0.1415
 6        0.3898                0.1338
 7        0.2653                0.1012
 8       -0.0499               -0.0796
 9        0.0203               -0.0963
10       -0.0249               -0.0963
11        0.0133                0.0927
12       -0.0243                0.0958
13        0.0287               -0.0868
14       -0.0148                0.0814
15       -0.0198               -0.0581
16        0.0174               -0.0743
17       -0.0107               -0.0619
18       -0.0001               -0.0001
19        0.0061               -0.0000
20        0.0054               -0.0001
21       -0.0001               -0.0000
22       -0.0036                0.0001
23        0.0042               -0.0001
24        0.0017               -0.0000
25       -0.0037               -0.0003
26        0.0003                0.0001
27        0.0004               -0.0001
28       -0.0007                0.0001
29        0.0002               -0.0000
30        0.0008                0.0000
31        0.0008               -0.0000
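
For reference, a dump like the one above can be produced from any model built with these layers by filtering named parameters. A minimal sketch, assuming only that each scalar is registered under a name ending in "resweight" as in the printout:

import torch

def print_resweights(model: torch.nn.Module):
    # Each ReZero sublayer owns a single scalar parameter named "...resweight".
    for name, param in model.named_parameters():
        if name.endswith("resweight"):
            print(name, param.detach().cpu().tolist())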

Is ReZero applicable to fine-tuning?

According to the paper, ReZero initializes each layer to perform the identity operation.
It seems that ReZero is designed for training networks from scratch. Is it also applicable to fine-tuning, and does it improve convergence there?

rezero with norm

Great work! In your paper, ReZero shows two main benefits: training deeper networks and faster convergence. Various forms of normalization and residual connection are listed in Table 1. I am curious about the form of ReZero combined with a norm, e.g. x_{i+1} = x_i + alpha * F(Norm(x_i)). Would it be worse or better?
Thanks
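
For concreteness, the variant asked about here, x_{i+1} = x_i + alpha * F(Norm(x_i)), could be written roughly as below. This is a sketch of the question's formula, not something the paper or this package provides.

import torch
import torch.nn as nn

class ReZeroPreNormBlock(nn.Module):
    """Hypothetical combination of pre-LayerNorm with a ReZero residual branch."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer                       # F
        self.resweight = nn.Parameter(torch.zeros(1))  # alpha, initialized to 0

    def forward(self, x):
        # x_{i+1} = x_i + alpha * F(Norm(x_i))
        return x + self.resweight * self.sublayer(self.norm(x))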

Can ReZero be applied to CNNs?

Hello, I have seen your nice demos for Transformers and fully connected networks.
I wonder, can it be applied to convolutional neural networks? Is there any demo project?

Thanks.
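
The paper's ResNet-56/CIFAR-10 result suggests the same scalar gate carries over to convolutional residual blocks. A rough, self-contained sketch follows; the channel sizes and the two-convolution branch are illustrative and not taken from this repo.

import torch
import torch.nn as nn

class ReZeroConvBlock(nn.Module):
    """Illustrative ReZero residual block for CNNs: x + resweight * branch(x)."""
    def __init__(self, channels):
        super().__init__()
        self.branch = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )
        self.resweight = nn.Parameter(torch.zeros(1))  # starts at 0, so the block starts as identity

    def forward(self, x):
        return x + self.resweight * self.branch(x)

x = torch.rand(8, 64, 32, 32)  # (batch, channels, height, width)
y = ReZeroConvBlock(64)(x)     # equals x at initialization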

weight decay for the resweight?

Hello, I read the paper and found it interesting.
I have a question.

Many implementations, including Huggingface's, exclude LayerNorm parameters and biases from weight decay for better convergence
(huggingface/transformers#492).
Would it also help to exclude the resweight parameters from weight decay?
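
For anyone who wants to try it, the usual PyTorch pattern is two parameter groups with weight_decay set to 0.0 for the excluded names. The sketch below uses illustrative name filters (they depend on how a model names its modules); whether excluding resweight actually helps is exactly the open question here.

import torch

def build_param_groups(model, weight_decay=0.01, skip=("bias", "resweight")):
    # Route parameters whose names contain any of the `skip` substrings
    # (e.g. biases and the ReZero scalars) into a group with no weight decay.
    # Models that still contain LayerNorm would typically add its weights to `skip` as well.
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (no_decay if any(s in name for s in skip) else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# optimizer = torch.optim.AdamW(build_param_groups(model), lr=1e-4)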

does rezero work in machine translation tasks?

Hi guys,
did you run experiments on machine translation tasks, e.g. WMT En-De or En-Fr?
I experimented with ReZero in my machine translation task; when training with fp16, using ReZero makes the loss scale hit its minimum, both with and without LayerNorm. Does that make sense?

Sry guys but your paper is not worth more than zero :)

The thing is, I see one genius, Jürgen Schmidhuber, who invented Highway Networks in May 2015, and here are his works and the subsequent works which you failed to cite:

Highway Networks (2015 May & Nov v2)
https://arxiv.org/abs/1505.00387

Training Very Deep Networks (2015 Jul & Nov v2)
https://arxiv.org/pdf/1507.06228.pdf

And ResNet is only a special case of HighwayNet, where the two gates are constant 1.
Highway and Residual Networks learn Unrolled Iterative Estimation (2016 Dec & 2017 Mar v2&v3)
https://arxiv.org/abs/1612.07771

And here, instead of using a gate tensor as in HighwayNet, they just use a scalar multiplier like you do,
but in 2016, not in 2020... your scientific lag is (significantly) more than zero.
Learning Identity Mappings with Residual Gates (2016 Nov & Dec v2)
https://arxiv.org/pdf/1611.01260v2.pdf

You cited neither the Gated ResNet from 2016 (which actually cites HighwayNet) nor the HighwayNet from 2015, but you do cite Kaiming He's ResNet (which also cites HighwayNet).

Relationship between ReZero and Zero gamma trick

Hello! Thanks for your interesting work and useful codes.

I have one small question. In Table 1 of the paper, the formulation of Residual Network + Pre-Norm is x_{i+1} = x_i + F(Norm(x_i)). From my understanding, the corresponding formulation of Residual Network + Post-Norm should be x_{i+1} = x_i + Norm(F(x_i)), which is also the real practice in ResNet. But the paper refers to a different formulation. Is this a typo, or do I misunderstand something?

In this formulation, a trick called the zero-gamma trick (setting gamma = 0 in the last batch normalization of every residual branch, i.e. the one that feeds back into the main branch) is commonly used [1,2]. The related Fixup Initialization [3] also benefits from this idea and shows the ability to train very deep neural networks. The trick is used by both the PyTorch and TensorFlow reference ResNet implementations. What is the relationship between ReZero and the zero-gamma trick? Thanks!

[1] Goyal et al. Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour.
[2] He et al. Bag of Tricks for Image Classification with Convolutional Neural Networks.
[3] Zhang et al. Fixup Initialization: Residual Learning Without Normalization.
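
For comparison, the zero-gamma trick zero-initializes the scale of the last BatchNorm in each residual branch, which, like ReZero, makes every block an identity map at initialization. A small illustrative sketch of the two initializations (not code from this repo):

import torch
import torch.nn as nn

# Zero-gamma trick: x + BN(F(x)) with the branch's last BatchNorm scale set to 0.
branch = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64))
nn.init.zeros_(branch[-1].weight)  # gamma = 0; the norm layer itself stays in place

# ReZero: x + resweight * F(x) with a single scalar set to 0 and no norm layer at all.
resweight = nn.Parameter(torch.zeros(1))

x = torch.rand(2, 64, 8, 8)
assert torch.allclose(x + branch(x), x)                 # identity at init (gamma and beta are 0)
assert torch.allclose(x + resweight * branch[0](x), x)  # identity at init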

The order of dropout and *resweight

Hi, nice work.

I notice that in the encoder layer you multiply by resweight and then apply dropout, but in the decoder layer you apply dropout and then multiply by resweight. Does the order of dropout and multiplication by resweight matter?

Thanks!
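
Because resweight is a single scalar, it commutes with dropout's elementwise mask, so the two orderings give identical results for the same mask (and the same distribution in general). A quick check of that claim:

import torch
import torch.nn as nn

x = torch.rand(10, 32, 512)
resweight = torch.tensor(0.1)
dropout = nn.Dropout(p=0.1)

torch.manual_seed(0)
a = dropout(resweight * x)   # encoder-style: multiply by resweight, then dropout

torch.manual_seed(0)
b = resweight * dropout(x)   # decoder-style: dropout, then multiply by resweight

print(torch.allclose(a, b))  # True: a scalar factor commutes with the dropout mask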
