FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization

This is the official repository of

FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization. Pavan Kumar Anasosalu Vasu, James Gabriel, Jeff Zhu, Oncel Tuzel, Anurag Ranjan. ICCV 2023

FastViT Performance

All models are trained on ImageNet-1K and benchmarked on an iPhone 12 Pro using the ModelBench app.

Setup

conda create -n fastvit python=3.9
conda activate fastvit
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
pip install -r requirements.txt

Usage

To use our model, follow the code snippet below:

import torch
import models
from timm.models import create_model
from models.modules.mobileone import reparameterize_model

# To train from scratch or fine-tune
model = create_model("fastvit_t8")
# ... train ...

# Load unfused pre-trained checkpoint for fine-tuning
# or for downstream task training like detection/segmentation
checkpoint = torch.load('/path/to/unfused_checkpoint.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
# ... train ...

# For inference
model.eval()
model_inf = reparameterize_model(model)
# Use model_inf at test-time
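
For intuition about what the reparameterization step does, the sketch below folds an inference-mode batch-norm layer into the linear operation that precedes it, which is one ingredient of structural reparameterization. It is a minimal pure-Python illustration and not the repository's `reparameterize_model` implementation; the helper names (`linear`, `batchnorm`, `fold_batchnorm`) and the toy weights are assumptions made for this example.

```python
import math

def linear(x, weight, bias):
    """y[i] = sum_j weight[i][j] * x[j] + bias[i]."""
    return [sum(w * xj for w, xj in zip(row, x)) + b
            for row, b in zip(weight, bias)]

def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode batch norm using fixed running statistics."""
    return [g * (yi - m) / math.sqrt(v + eps) + bt
            for yi, g, bt, m, v in zip(y, gamma, beta, mean, var)]

def fold_batchnorm(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Return (weight, bias) of a single linear layer equal to linear -> batchnorm."""
    fused_weight, fused_bias = [], []
    for row, b, g, bt, m, v in zip(weight, bias, gamma, beta, mean, var):
        scale = g / math.sqrt(v + eps)          # per-output-channel scale
        fused_weight.append([w * scale for w in row])
        fused_bias.append((b - m) * scale + bt)
    return fused_weight, fused_bias

# Toy layer with 2 outputs and 3 inputs (hypothetical values).
weight = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
bias = [0.1, -0.2]
gamma, beta = [1.2, 0.8], [0.05, -0.1]
mean, var = [0.3, -0.4], [0.9, 1.6]

x = [1.0, 2.0, -1.0]
two_step = batchnorm(linear(x, weight, bias), gamma, beta, mean, var)
fused_weight, fused_bias = fold_batchnorm(weight, bias, gamma, beta, mean, var)
one_step = linear(x, fused_weight, fused_bias)

# Both paths agree up to floating-point rounding.
assert all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
           for a, b in zip(two_step, one_step))
```

The folding is only valid with fixed batch-norm statistics, which is consistent with calling `model.eval()` before reparameterizing.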

FastViT Model Zoo

Image Classification

Models trained on ImageNet-1K

Model         Top-1 Acc. (%)  Latency (ms)  PyTorch Checkpoint (url)  CoreML Model
FastViT-T8    76.2            0.8           T8(unfused)               fastvit_t8.mlpackage.zip
FastViT-T12   79.3            1.2           T12(unfused)              fastvit_t12.mlpackage.zip
FastViT-S12   79.9            1.4           S12(unfused)              fastvit_s12.mlpackage.zip
FastViT-SA12  80.9            1.6           SA12(unfused)             fastvit_sa12.mlpackage.zip
FastViT-SA24  82.7            2.6           SA24(unfused)             fastvit_sa24.mlpackage.zip
FastViT-SA36  83.6            3.5           SA36(unfused)             fastvit_sa36.mlpackage.zip
FastViT-MA36  83.9            4.6           MA36(unfused)             fastvit_ma36.mlpackage.zip

Models trained on ImageNet-1K with knowledge distillation.

Model         Top-1 Acc. (%)  Latency (ms)  PyTorch Checkpoint (url)  CoreML Model
FastViT-T8    77.2            0.8           T8(unfused)               fastvit_t8.mlpackage.zip
FastViT-T12   80.3            1.2           T12(unfused)              fastvit_t12.mlpackage.zip
FastViT-S12   81.1            1.4           S12(unfused)              fastvit_s12.mlpackage.zip
FastViT-SA12  81.9            1.6           SA12(unfused)             fastvit_sa12.mlpackage.zip
FastViT-SA24  83.4            2.6           SA24(unfused)             fastvit_sa24.mlpackage.zip
FastViT-SA36  84.2            3.5           SA36(unfused)             fastvit_sa36.mlpackage.zip
FastViT-MA36  84.6            4.6           MA36(unfused)             fastvit_ma36.mlpackage.zip
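
As an illustration of how the table can guide model choice, the sketch below picks the most accurate distilled model that fits a latency budget. The accuracy and latency values are copied from the distillation table above; the function name `best_under_budget` is hypothetical, and treating the latency column as milliseconds is an assumption.

```python
# (top-1 accuracy in %, latency) copied from the distillation table above.
DISTILLED = {
    "FastViT-T8": (77.2, 0.8),
    "FastViT-T12": (80.3, 1.2),
    "FastViT-S12": (81.1, 1.4),
    "FastViT-SA12": (81.9, 1.6),
    "FastViT-SA24": (83.4, 2.6),
    "FastViT-SA36": (84.2, 3.5),
    "FastViT-MA36": (84.6, 4.6),
}

def best_under_budget(models, max_latency):
    """Return the most accurate model whose latency is <= max_latency, or None."""
    candidates = [(acc, name) for name, (acc, latency) in models.items()
                  if latency <= max_latency]
    return max(candidates)[1] if candidates else None

print(best_under_budget(DISTILLED, 2.0))  # FastViT-SA12
```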

Latency Benchmarking

Latency of all models was measured on an iPhone 12 Pro using the ModelBench app. For further details, please contact James Gabriel and Jeff Zhu. All reported numbers are rounded to one decimal place.

Training

Image Classification

Dataset Preparation

Download the ImageNet-1K dataset and structure the data as follows:

/path/to/imagenet-1k/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  validation/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
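
Before launching a long training run, it can help to confirm that the folder layout matches the structure above. The following is a minimal sketch using only the standard library; the function names (`summarize_split`, `check_layout`) and the accepted image suffixes are assumptions for illustration and are not part of this repository.

```python
from pathlib import Path

# Assumed set of image extensions; extend it if your copy uses others.
IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}

def summarize_split(root, split):
    """Map each class folder in root/split to its number of image files."""
    split_dir = Path(root) / split
    if not split_dir.is_dir():
        raise FileNotFoundError(f"missing split folder: {split_dir}")
    return {
        class_dir.name: sum(
            1 for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
        for class_dir in sorted(split_dir.iterdir())
        if class_dir.is_dir()
    }

def check_layout(root):
    """Return a list of problems; an empty list means the layout looks right."""
    problems = []
    train = summarize_split(root, "train")
    val = summarize_split(root, "validation")
    if set(train) != set(val):
        problems.append("train/ and validation/ contain different class folders")
    for split, counts in (("train", train), ("validation", val)):
        for name, count in counts.items():
            if count == 0:
                problems.append(f"{split}/{name} has no images")
    return problems
```

Calling `check_layout("/path/to/imagenet-1k")` returns an empty list when both splits exist, share the same class folders, and every class folder holds at least one image.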

To train a FastViT variant, use the corresponding command below:

FastViT-T8
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t8 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256

# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t8 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 \
--distillation-type "hard"

FastViT-T12
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256

# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 \
--distillation-type "hard"

FastViT-S12
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_s12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256

# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_s12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 \
--distillation-type "hard"

FastViT-SA12
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.1

# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa12 -b 128 --lr 1e-3 \
--native-amp --output /path/to/save/results \
--input-size 3 256 256 \
--distillation-type "hard"

FastViT-SA24
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa24 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.1

# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa24 -b 128 --lr 1e-3 \
--native-amp --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.05 \
--distillation-type "hard"

FastViT-SA36
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa36 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.2

# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa36 -b 128 --lr 1e-3 \
--native-amp --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.1 \
--distillation-type "hard"

FastViT-MA36
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_ma36 -b 128 --lr 1e-3 \
--native-amp --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.35

# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_ma36 -b 128 --lr 1e-3 \
--native-amp --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.2 \
--distillation-type "hard"
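
The commands above differ only in the model name and in the `--mixup` and `--drop-path` values, so they can be generated from a small lookup table. The sketch below is one hedged way to do that; the function `build_train_command` and the two dictionaries are hypothetical helpers, the flag values are transcribed from the commands in this section, and the model name is assumed to follow the `fastvit_<variant>` pattern (for example `fastvit_ma36`).

```python
import shlex

# (mixup, drop_path) per variant, transcribed from the commands above.
# None means the flag is omitted from the command.
BASE = {
    "t8": (0.2, None), "t12": (0.2, None), "s12": (0.2, None),
    "sa12": (0.2, 0.1), "sa24": (0.2, 0.1), "sa36": (0.2, 0.2),
    "ma36": (None, 0.35),
}
DISTILL = {
    "t8": (0.2, None), "t12": (0.2, None), "s12": (0.2, None),
    "sa12": (None, None), "sa24": (None, 0.05), "sa36": (None, 0.1),
    "ma36": (None, 0.2),
}

def build_train_command(variant, data_dir, output_dir, distill=False, gpus=8):
    """Return the training command for one variant as a list of arguments."""
    mixup, drop_path = (DISTILL if distill else BASE)[variant]
    cmd = ["python", "-m", "torch.distributed.launch",
           f"--nproc_per_node={gpus}", "train.py", data_dir,
           "--model", f"fastvit_{variant}", "-b", "128", "--lr", "1e-3",
           "--native-amp"]
    if mixup is not None:
        cmd += ["--mixup", str(mixup)]
    cmd += ["--output", output_dir, "--input-size", "3", "256", "256"]
    if drop_path is not None:
        cmd += ["--drop-path", str(drop_path)]
    if distill:
        cmd += ["--distillation-type", "hard"]
    return cmd

# Print a shell-ready command line for the distilled SA24 variant.
print(shlex.join(build_train_command(
    "sa24", "/path/to/ImageNet/dataset", "/path/to/save/results", distill=True)))
```

Returning a list rather than a string means the result can be passed directly to `subprocess.run` without shell quoting concerns.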

Evaluation

To run evaluation on ImageNet, follow the example command below:

FastViT-T8
# Evaluate unfused checkpoint
python validate.py /path/to/ImageNet/dataset --model fastvit_t8 \
--checkpoint /path/to/pretrained_checkpoints/fastvit_t8.pth.tar

# Evaluate fused checkpoint
python validate.py /path/to/ImageNet/dataset --model fastvit_t8 \
--checkpoint /path/to/pretrained_checkpoints/fastvit_t8_reparam.pth.tar \
--use-inference-mode

Model Export

To export a Core ML package file from a PyTorch checkpoint, follow the example command below:

FastViT-T8
python export_model.py --variant fastvit_t8 --output-dir /path/to/save/exported_model \
--checkpoint /path/to/pretrained_checkpoints/fastvit_t8_reparam.pth.tar

Citation

@inproceedings{vasufastvit2023,
  author = {Pavan Kumar Anasosalu Vasu and James Gabriel and Jeff Zhu and Oncel Tuzel and Anurag Ranjan},
  title = {FastViT:  A Fast Hybrid Vision Transformer using Structural Reparameterization},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  year = {2023}
}

Acknowledgements

Our codebase is built using multiple open-source contributions. Please see ACKNOWLEDGEMENTS for more details.
