
attention-rank-collapse

Attention is not all you need: pure attention loses rank doubly exponentially with depth.

Yihe Dong, Jean-Baptiste Cordonnier, Andreas Loukas.

In this work, we find that the output of pure attention (self-attention without skip connections or MLPs) loses rank doubly exponentially with depth. We analyze how MLPs and skip connections counteract this decay. Our paper (https://arxiv.org/abs/2103.03404) contains further details. This repository contains the code for our experiments.
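As a rough illustration of the claim above, the following sketch stacks randomly initialized single-head attention layers and tracks how far the token matrix is from having identical rows. It is not the code used in our experiments: the weights are untrained, the function names are hypothetical, and the Frobenius-norm residual used here is a simplified stand-in for the residual analyzed in the paper.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax_rows(m):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    out = []
    for row in m:
        top = max(row)
        exps = [math.exp(v - top) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def random_matrix(rng, rows, cols, scale):
    """Matrix with i.i.d. Gaussian entries of standard deviation `scale`."""
    return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]


def attention_layer(x, w_qk, w_v):
    """Single-head pure attention: softmax(X W_qk X^T / sqrt(d)) X W_v."""
    d = len(x[0])
    x_t = [list(col) for col in zip(*x)]
    logits = matmul(matmul(x, w_qk), x_t)
    probs = softmax_rows([[v / math.sqrt(d) for v in row] for row in logits])
    return matmul(matmul(probs, x), w_v)


def relative_residual(x):
    """||X - 1 mean(X)^T||_F / ||X||_F; zero exactly when all rows are equal."""
    n = len(x)
    means = [sum(col) / n for col in zip(*x)]
    res = math.sqrt(sum((v - m) ** 2 for row in x for v, m in zip(row, means)))
    norm = math.sqrt(sum(v * v for row in x for v in row))
    return res / norm


def collapse_curve(depth=6, seq_len=8, dim=4, seed=0):
    """Relative residual of the input and after each of `depth` layers."""
    rng = random.Random(seed)
    x = random_matrix(rng, seq_len, dim, 1.0)
    curve = [relative_residual(x)]
    for _ in range(depth):
        # Assumption: fresh random (untrained) weights in every layer.
        w_qk = random_matrix(rng, dim, dim, 0.1)
        w_v = random_matrix(rng, dim, dim, 1.0 / math.sqrt(dim))
        x = attention_layer(x, w_qk, w_v)
        curve.append(relative_residual(x))
    return curve


if __name__ == "__main__":
    for layer, value in enumerate(collapse_curve()):
        print(f"layer {layer}: relative residual = {value:.3e}")
```

The sketch uses only the standard library so that it runs without the environment below; the printed residual should shrink quickly as layers are added.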

Requirements

To install a working environment:

conda create --name rank-collapse python=3.8
conda activate rank-collapse
pip install git+https://github.com/huggingface/transformers.git@12d7624199e727f37bef7f53d527df7fabdb1fd6
conda install pytorch torchvision torchaudio -c pytorch
conda install -c anaconda scipy scikit-learn black isort matplotlib flake8

Quick start

This repo contains the code to reproduce our experiments. Additional options can be viewed by running python run_sort.py --help or checking out arguments.py.

Attention in several architectures

This notebook contains the driver code to study the interplay of pure attention, skip connections, and MLPs in several common attention-based architectures.

Sorting

This task learns to sort sequences:

python run_sort.py --width 2 --depth 6 --hidden_dim 48 --seed 2 --num_labels 10 --seq_len 8 --n_epochs 65 --path_len 0 --n_paths 5 --n_train_data 1000 --n_repeat 5 --n_eval_data 200 --no_sub_path
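To make the task concrete, here is a minimal sketch of what a sorting example could look like for --seq_len 8 and --num_labels 10. The function names and the (inputs, targets) layout are assumptions made for illustration; the data pipeline in run_sort.py may be organized differently.

```python
import random


def make_sorting_example(rng, seq_len=8, num_labels=10):
    """Return (inputs, targets): random tokens and the same tokens in ascending order."""
    inputs = [rng.randrange(num_labels) for _ in range(seq_len)]
    return inputs, sorted(inputs)


def make_sorting_dataset(n_examples, seq_len=8, num_labels=10, seed=2):
    """Build a reproducible list of (inputs, targets) pairs."""
    rng = random.Random(seed)
    return [make_sorting_example(rng, seq_len, num_labels) for _ in range(n_examples)]


if __name__ == "__main__":
    for inputs, targets in make_sorting_dataset(3):
        print(inputs, "->", targets)
```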

Convex hull

This task learns to predict the convex hull of a set of points on the plane:

python run_convex_hull.py --width 3 --depth 6 --seq_len 10 --hidden_dim 84 --seed 2 --num_labels 2 --n_epochs 70 --path_len 0 --n_paths 5 --n_train_data 10000 --n_repeat 5 --ffn2 --n_eval_data 300 --no_sub_path
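With --num_labels 2, one natural reading is that each point receives a binary label saying whether it is a vertex of the hull. The sketch below computes such labels with Andrew's monotone chain algorithm. The labeling convention and all function names are assumptions for illustration and are not taken from run_convex_hull.py.

```python
import random


def cross(o, a, b):
    """Z-component of (a - o) x (b - o); positive for a counter-clockwise turn."""
    return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])


def convex_hull(points):
    """Andrew's monotone chain; returns hull vertices in counter-clockwise order."""
    pts = sorted(set(points))
    if len(pts) <= 2:
        return pts

    def half_hull(seq):
        chain = []
        for p in seq:
            # Pop while the last two points and p fail to make a strict left turn.
            while len(chain) >= 2 and cross(chain[-2], chain[-1], p) <= 0:
                chain.pop()
            chain.append(p)
        return chain

    lower = half_hull(pts)
    upper = half_hull(reversed(pts))
    # The last point of each half is the first point of the other half.
    return lower[:-1] + upper[:-1]


def hull_labels(points):
    """Binary label per point: 1 if it is a hull vertex, else 0."""
    hull = set(convex_hull(points))
    return [1 if p in hull else 0 for p in points]


def make_example(rng, seq_len=10):
    """Sample `seq_len` points uniformly in the unit square and label them."""
    points = [(rng.random(), rng.random()) for _ in range(seq_len)]
    return points, hull_labels(points)


if __name__ == "__main__":
    points, labels = make_example(random.Random(2))
    for p, y in zip(points, labels):
        print(f"({p[0]:.3f}, {p[1]:.3f}) -> {y}")
```

Points that lie on a hull edge without being a corner are labeled 0 here; that tie-breaking choice is part of the sketch, not a statement about the repository.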

Memorization

This task learns to memorize labels that are randomly assigned to the tokens of natural-language text:

python run_memorization.py --model_name_or_path bert-base-uncased --task_name RTE --do_train --data_dir data/RTE \
    --max_seq_length 128 --per_device_eval_batch_size=32 --per_device_train_batch_size=8 --learning_rate 3e-4 \
    --num_train_epochs 50 --output_dir snap/RTE --overwrite_output_dir --width 2 --depth 6 --hidden_dim 250 \
    --n_repeat 5 --n_paths 5 $1 --n_train_data 500 --no_sub_path --path_len 0

Circle

This experiment applies a single self-attention layer recurrently to learn two circular arcs. It illustrates the rank collapse of the pure attention model and how skip connections and MLPs counteract it. To run:

python run_circle.py --width 2 --depth 1 --num_labels 2 --seq_len 10 --n_epochs 70 --n_train_data 1000 --n_repeat 2 --n_eval_data 300 --hidden_dim 32

The above can be modified to add skip connections or MLP to the self-attention network with the --circle_skip or --do_mlp options, respectively.
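The effect of a skip connection can be previewed without any training. The sketch below places ten points on two circular arcs, applies one untrained attention layer recurrently, and records how spread out the points remain, with and without a skip connection. It is an illustrative assumption-laden toy, not run_circle.py: the weights are random, the value map is fixed to the identity, and the names two_arcs, attend, spread, and run are hypothetical.

```python
import math
import random


def two_arcs(points_per_arc=5):
    """Points on two half-circles of radius 1 and 2 in the upper half-plane."""
    pts = []
    for radius in (1.0, 2.0):
        for k in range(points_per_arc):
            angle = math.pi * (k + 0.5) / points_per_arc
            pts.append([radius * math.cos(angle), radius * math.sin(angle)])
    return pts


def attend(x, w_qk):
    """softmax(X W_qk X^T / sqrt(d)) X, with the value map fixed to the identity."""
    d = len(x[0])
    out = []
    for xi in x:
        q = [sum(xi[a] * w_qk[a][b] for a in range(d)) for b in range(d)]
        logits = [sum(q[b] * xj[b] for b in range(d)) / math.sqrt(d) for xj in x]
        top = max(logits)  # shift for numerical stability
        weights = [math.exp(v - top) for v in logits]
        total = sum(weights)
        out.append([sum(w * xj[b] for w, xj in zip(weights, x)) / total for b in range(d)])
    return out


def spread(x):
    """Sum over coordinates of (max - min) across tokens; 0 iff all tokens coincide."""
    return sum(max(col) - min(col) for col in zip(*x))


def run(steps=6, skip=False, seed=0):
    """Apply the same layer `steps` times and return the spread after each step."""
    rng = random.Random(seed)
    w_qk = [[rng.gauss(0.0, 0.1) for _ in range(2)] for _ in range(2)]
    x = two_arcs()
    history = [spread(x)]
    for _ in range(steps):
        y = attend(x, w_qk)
        x = [[a + b for a, b in zip(xi, yi)] for xi, yi in zip(x, y)] if skip else y
        history.append(spread(x))
    return history


if __name__ == "__main__":
    for name, hist in (("pure", run()), ("skip", run(skip=True))):
        print(name, " ".join(f"{v:.2e}" for v in hist))
```

Without the skip connection every output token is a convex combination of the inputs, so the spread can only shrink; with the skip connection the original token differences are carried forward at each step.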

Paths distribution

This notebook studies the distribution of paths in some commonly used attention-based models.
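For orientation, the number of paths of each length can be counted directly. The sketch assumes that a path picks, at every layer, either one of the attention heads or the skip connection, so a network with L layers and H heads has C(L, l) * H^l paths of length l. The function names are hypothetical, and the notebook itself may measure a different quantity, such as how much each path length contributes in a trained model.

```python
from math import comb


def path_counts(depth, heads):
    """Number of paths of each length 0..depth under the stated assumption."""
    return [comb(depth, length) * heads**length for length in range(depth + 1)]


def path_length_distribution(depth, heads):
    """Fraction of all paths that have each length 0..depth."""
    counts = path_counts(depth, heads)
    total = sum(counts)  # equals (heads + 1) ** depth by the binomial theorem
    return [c / total for c in counts]


if __name__ == "__main__":
    # 12 layers and 12 heads, the size of bert-base-uncased.
    for length, frac in enumerate(path_length_distribution(12, 12)):
        print(f"length {length:2d}: {frac:.4f}")
```

Under this counting, long paths vastly outnumber short ones, which is why the raw count alone says little about which paths a trained model actually relies on.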

Citation

If you find our work useful, please cite as:

@article{rankCollapse2021,
  title  = {Attention is not all you need: pure attention loses rank doubly exponentially with depth},
  author = {Dong, Yihe and Cordonnier, Jean-Baptiste and Loukas, Andreas},
  url    = {https://arxiv.org/abs/2103.03404},
  year   = {2021}
}
