Vision Transformer Attention Benchmark

This repo is a collection of attention mechanisms in vision Transformers. Besides the re-implementations, it provides a benchmark of model parameters, FLOPs, and CPU/GPU throughput.

Requirements

  • PyTorch 1.8+
  • timm
  • ninja
  • einops
  • fvcore
  • matplotlib

Testing Environment

  • NVIDIA RTX 3090
  • Intel® Core™ i9-10900X CPU @ 3.70GHz
  • Memory 32GB
  • Ubuntu 22.04
  • PyTorch 1.8.1 + CUDA 11.1

Settings

  • input: 14 x 14 = 196 tokens (1/16 scale feature maps in common ImageNet-1K training)
  • batch size for speed testing (images/s): 64
  • embedding dimension: 768
  • number of heads: 12
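These settings pin down the size of a standard MSA layer. A quick back-of-envelope check (a sketch, assuming the usual fused-QKV projection plus output projection, both with biases) reproduces the 2.36 M parameter figure reported in the table below:

```python
# Parameter count of one multi-head self-attention (MSA) layer,
# assuming a fused QKV projection and an output projection with biases.
embed_dim = 768  # embedding dimension from the settings above

qkv_params = embed_dim * 3 * embed_dim + 3 * embed_dim  # weight + bias
proj_params = embed_dim * embed_dim + embed_dim         # weight + bias
total = qkv_params + proj_params

print(f"{total:,} params = {total / 1e6:.2f} M")  # 2,362,368 params = 2.36 M
```

Note that the head count does not affect the parameter count, since the per-head projections are slices of the same weight matrices.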

Testing

For example, to test HiLo attention,

cd attentions/
python hilo.py

By default, the script tests the model on both CPU and GPU. FLOPs are measured with fvcore. You may want to edit the source file as needed.

Outputs:

Number of Params: 2.2 M
FLOPs = 298.3 M
throughput averaged with 30 times
batch_size 64 throughput on CPU 1029
throughput averaged with 30 times
batch_size 64 throughput on GPU 5104
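The reported throughput is images per second averaged over repeated timed runs. A minimal, framework-agnostic sketch of such a measurement loop (the repo's actual scripts time PyTorch models and synchronize CUDA when benchmarking on GPU; the callable and names here are stand-ins):

```python
import time

def measure_throughput(model, batch, batch_size=64, repeats=30, warmup=5):
    """Return images/s averaged over `repeats` timed runs (a sketch of the
    benchmarking loop; GPU timing would additionally call
    torch.cuda.synchronize() before reading the clock)."""
    for _ in range(warmup):          # warm up caches before timing
        model(batch)
    start = time.perf_counter()
    for _ in range(repeats):
        model(batch)
    elapsed = time.perf_counter() - start
    return batch_size * repeats / elapsed

# Stand-in workload so the sketch runs without PyTorch installed.
dummy_model = lambda x: sum(x)
throughput = measure_throughput(dummy_model, list(range(196)))
print(f"batch_size 64 throughput on CPU {throughput:.0f}")
```

Averaging over many runs and discarding warmup iterations keeps one-off effects (cache misses, lazy initialization) out of the reported number.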

Supported Attentions

  • MSA: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. [Paper] [Code]
  • Cross Window: CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows. [Paper] [Code]
  • DAT: Vision Transformer with Deformable Attention. [Paper] [Code]
  • Performer: Rethinking Attention with Performers. [Paper] [Code]
  • Linformer: Linformer: Self-Attention with Linear Complexity. [Paper] [Code]
  • SRA: Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions. [Paper] [Code]
  • Local/Shifted Window: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. [Paper] [Code]
  • Focal: Focal Self-attention for Local-Global Interactions in Vision Transformers. [Paper] [Code]
  • XCA: XCiT: Cross-Covariance Image Transformers. [Paper] [Code]
  • QuadTree: QuadTree Attention for Vision Transformers. [Paper] [Code]
  • VAN: Visual Attention Network. [Paper] [Code]
  • HorNet: HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions. [Paper] [Code]
  • HiLo: Fast Vision Transformers with HiLo Attention. [Paper] [Code]

Single Attention Layer Benchmark

| Name | Params (M) | FLOPs (M) | CPU (images/s) | GPU (images/s) | Demo |
| --- | --- | --- | --- | --- | --- |
| MSA | 2.36 | 521.43 | 505 | 4403 | msa.py |
| Cross Window | 2.37 | 493.28 | 325 | 4334 | cross_window.py |
| DAT | 2.38 | 528.69 | 223 | 3074 | dat.py |
| Performer | 2.36 | 617.24 | 181 | 3180 | performer.py |
| Linformer | 2.46 | 616.56 | 518 | 4578 | linformer.py |
| SRA | 4.72 | 419.56 | 710 | 4810 | sra.py |
| Local Window | 2.36 | 477.17 | 631 | 4537 | shifted_window.py |
| Shifted Window | 2.36 | 477.17 | 374 | 4351 | shifted_window.py |
| Focal | 2.44 | 526.85 | 146 | 2842 | focal.py |
| XCA | 2.36 | 481.69 | 583 | 4659 | xca.py |
| QuadTree | 5.33 | 613.25 | 72 | 3978 | quadtree.py |
| VAN | 1.83 | 357.96 | 59 | 4213 | van.py |
| HorNet | 2.23 | 436.51 | 132 | 3996 | hornet.py |
| HiLo | 2.20 | 298.30 | 1029 | 5104 | hilo.py |

Note: Each method has its own hyperparameters. For a fair comparison on 1/16-scale feature maps, all methods in the table above adopt their default 1/16-scale settings, as given in their released code repos. For example, when dealing with 1/16-scale feature maps, HiLo in LITv2 adopts a window size of 2 and an alpha of 0.9. Future work will consider more scales and memory benchmarking.
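As a sanity check on the table, the MSA row can be reproduced by hand. With N = 196 tokens, d = 768, and 12 heads (head dimension 64), and counting one multiply-add per FLOP for matrix multiplies (the convention assumed here), the four matmuls of standard attention give:

```python
# FLOPs of one MSA layer at 14 x 14 = 196 tokens, counting multiply-adds.
N, d, heads = 196, 768, 12
head_dim = d // heads  # 64

qkv = N * d * (3 * d)             # fused QKV projection
attn = heads * N * N * head_dim   # Q @ K^T, per head
out = heads * N * N * head_dim    # softmax(Q K^T) @ V, per head
proj = N * d * d                  # output projection

total = qkv + attn + out + proj
print(f"{total / 1e6:.2f} M FLOPs")  # 521.43 M, matching the MSA row
```

The two attention matmuls are the quadratic-in-N terms; at 196 tokens they account for only about 11% of the total, which is why linear-complexity variants show limited FLOP savings at this scale.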

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Contributors

hubhop
