Optimized Spatiotemporal Masked Autoencoders (ST-MAEs)

A lean, optimized implementation of spatiotemporal masked autoencoders (ST-MAEs). The skeleton of the code is recycled from Facebook's ST-MAE repository with various simplifications. The following optimizations are implemented (a rough sketch of how they fit together follows the list):

  • FlashAttention-2
  • torch.compile
  • fused AdamW
  • mixed precision training (torch.cuda.amp)
  • DDP for distributed training
  • selective decoding of videos
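
Below is a minimal sketch, not the repository's actual code, of how most of these pieces are typically wired together in PyTorch 2.x (selective decoding is sketched separately further down). Here, build_model and train_loader are hypothetical placeholders, and the model is assumed to follow the usual MAE convention of returning (loss, prediction, mask):

import os
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# one process per GPU, launched with torchrun
torch.distributed.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = build_model().cuda()                 # hypothetical model constructor
model = DDP(model, device_ids=[local_rank])  # distributed data parallel
model = torch.compile(model)                 # graph capture + kernel fusion

# fused AdamW performs the parameter update in far fewer CUDA kernel launches
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05, fused=True)

scaler = torch.cuda.amp.GradScaler()         # loss scaling for mixed precision

for videos in train_loader:                  # hypothetical data loader
    videos = videos.cuda(non_blocking=True)
    optimizer.zero_grad(set_to_none=True)
    # inside the model, F.scaled_dot_product_attention can dispatch to
    # FlashAttention-2 on H100s when shapes and dtypes permit
    with torch.cuda.amp.autocast():
        loss, _, _ = model(videos, mask_ratio=0.9)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()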

These optimizations yield very high training throughput. For example, on just 4 H100 GPUs, we were able to complete over 160 epochs of training on Kinetics-700 (~536K videos) in roughly one week, using a ViT-H encoder (~633M parameters) with 8x16x16=2048 spatiotemporal input "tokens" (8 tokens along the temporal dimension and 16x16 tokens along the spatial dimensions), a masking ratio of 90%, and an effective batch size of 256 videos (64 videos per GPU).
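
Selective decoding means decoding only the frames a clip actually needs rather than the entire video. A hedged sketch of the idea using torchvision's VideoReader (the repository may implement this differently):

import torch
from torchvision.io import VideoReader

def read_clip(path, num_frames=16, sampling_rate=8, fps=30.0, start_sec=0.0):
    # seek close to the clip start and stop as soon as enough frames are
    # decoded, instead of decoding the whole file
    reader = VideoReader(path, "video")
    reader.seek(start_sec)
    step = sampling_rate / fps               # seconds between sampled frames
    frames, next_t = [], start_sec
    for frame in reader:                     # frame is {"data": CxHxW uint8, "pts": float}
        if frame["pts"] >= next_t:
            frames.append(frame["data"])
            next_t += step
        if len(frames) == num_frames:
            break
    return torch.stack(frames, dim=1)        # C x T x H x W clip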

This implementation also removes the model definitions' dependence on the timm library, so the code is self-contained apart from the standard libraries. The code was tested with pytorch==2.2.0 and torchvision==0.17.0.

Usage examples

  • Training: To train a spatiotemporal MAE model with a ViT-H/14 architecture from scratch on your data, use pretrain.py, e.g.:
python -u pretrain.py \
    --data_dirs DATA_DIRS \
    --datafile_dir DATAFILE_DIR \
    --save_prefix INFORMATIVE_SAVE_PREFIX \
    --output_dir OUTPUT_DIR \
    --model 'mae_vit_huge_patch14' \
    --batch_size_per_gpu 1 \
    --accum_iter 1 \
    --epochs 100000 \
    --num_frames 16 \
    --img_size 224 \
    --decoder_embed_dim 512 \
    --decoder_depth 4 \
    --pin_mem \
    --t_patch_size 2 \
    --repeat_aug 16 \
    --sampling_rate 8 \
    --lr 0.0001 \
    --weight_decay 0.05 \
    --mask_ratio 0.9 \
    --pred_t_dim 16 \
    --clip_grad 0.1

Here, DATA_DIRS is a list of directories containing the video files, DATAFILE_DIR is the directory where a .csv file containing all the training video file paths (optionally, with the corresponding class labels) will be saved, and OUTPUT_DIR is the directory where the checkpoints and training logs will be saved.
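For illustration, a hypothetical layout for the generated .csv, with one video path per row and an optional label column (the actual format written by the script may differ):

/data/videos_1/abseiling/clip_0001.mp4,0
/data/videos_2/archery/clip_0002.mp4,1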

  • Finetuning on videos: To finetune a ViT-H/14 model on a downstream video recognition task, use finetune.py, e.g.:
python -u finetune.py \
    --train_dir TRAIN_DIR \
    --val_dir VAL_DIR \
    --datafile_dir DATAFILE_DIR \
    --save_prefix INFORMATIVE_SAVE_PREFIX \
    --output_dir OUTPUT_DIR \
    --finetune SPATIOTEMPORAL_MAE_CHECKPOINT \
    --num_classes 174 \
    --model 'vit_huge_patch14' \
    --batch_size_per_gpu 4 \
    --accum_iter 1 \
    --epochs 100000 \
    --num_frames 16 \
    --input_size 224 \
    --pin_mem \
    --t_patch_size 2 \
    --repeat_aug 1 \
    --sampling_rate 8 \
    --blr 0.0024 \
    --clip_grad 5.0 \
    --mixup 0 \
    --cutmix 0.0

Here, TRAIN_DIR and VAL_DIR are the directories containing the training and validation videos, respectively, and SPATIOTEMPORAL_MAE_CHECKPOINT is the path to the pretrained spatiotemporal MAE checkpoint the model is initialized with (pass an empty string "" here to train the model from scratch without any pretraining).
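
For reference, a hedged sketch of how a pretrained ST-MAE checkpoint is typically loaded for finetuning; the "model" checkpoint key and the filtering behavior are assumptions, not the script's exact logic, and model here is the hypothetical finetuning ViT:

import torch

checkpoint = torch.load("SPATIOTEMPORAL_MAE_CHECKPOINT", map_location="cpu")
state_dict = checkpoint["model"]             # assumed checkpoint key
# strict=False: pretrained decoder weights are discarded and the new
# classification head stays randomly initialized
msg = model.load_state_dict(state_dict, strict=False)
print(msg.missing_keys, msg.unexpected_keys)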

  • Finetuning on images: To finetune a ViT-H/14 model on a downstream image recognition task (e.g. ImageNet), use finetune_on_image.py, e.g.:
python -u finetune_on_image.py \
    --train_data_path TRAIN_DATA_PATH \
    --val_data_path VAL_DATA_PATH \
    --save_prefix INFORMATIVE_SAVE_PREFIX \
    --output_dir OUTPUT_DIR \
    --finetune SPATIOTEMPORAL_MAE_CHECKPOINT \
    --num_classes 1000 \
    --model 'vit_huge_patch14' \
    --batch_size_per_gpu 4 \
    --accum_iter 1 \
    --epochs 100000 \
    --num_frames 16 \
    --input_size 224 \
    --pin_mem \
    --t_patch_size 2 \
    --blr 0.0024 \
    --clip_grad 5.0 \
    --mixup 0 \
    --cutmix 0.0

Here, TRAIN_DATA_PATH and VAL_DATA_PATH are the directories containing the training and validation images, respectively, and SPATIOTEMPORAL_MAE_CHECKPOINT is the path to the pretrained spatiotemporal MAE checkpoint the model is initialized with. This script effectively makes a static video clip for each image by repeating the image 16 times (num_frames). This allows us to use the pretrained spatiotemporal MAE model as is, without any architectural modifications.
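
A minimal sketch of this image-to-clip trick (assumed shapes; the script's exact tensor layout may differ):

import torch

num_frames = 16
image = torch.rand(3, 224, 224)                        # C x H x W image
clip = image.unsqueeze(1).repeat(1, num_frames, 1, 1)  # C x T x H x W static clip
assert clip.shape == (3, 16, 224, 224)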
