DeiT-TF (Data-efficient Image Transformers)

This repository provides TensorFlow / Keras implementations of different DeiT [1] variants from Touvron et al. It also provides the TensorFlow / Keras models populated with the original DeiT pre-trained parameters available from [2]. These models are not black-box SavedModels; they can be fully expanded into tf.keras.Model objects, and all the usual utility functions can be called on them (for example, .summary()).

As of today, all the TensorFlow / Keras variants of the DeiT models listed here are available in this repository.

Refer to the "Using the models" section to get started. You can also follow along with this tutorial: https://keras.io/examples/vision/deit/.

Conversion

TensorFlow / Keras implementations are available in vit/vit_models.py and vit/deit_models.py. Conversion utilities are in convert.py.
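
As a rough illustration of what the conversion involves (a hypothetical helper, not the actual convert.py API): PyTorch stores Linear weights as (out_features, in_features) while Keras Dense kernels are (input_dim, units), so dense kernels need to be transposed when porting:

import numpy as np

def port_linear(pt_weight: np.ndarray, pt_bias: np.ndarray, tf_dense):
    # Transpose PyTorch's (out_features, in_features) layout into Keras's
    # (input_dim, units) layout before assigning the parameters.
    tf_dense.kernel.assign(pt_weight.transpose())
    tf_dense.bias.assign(pt_bias)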

Models

Find the models on TF-Hub here: https://tfhub.dev/sayakpaul/collections/deit/1. You can fully inspect the architecture of the TF-Hub models like so:

import tensorflow as tf

# Load the model directly from its GCS location on TF-Hub.
model_gcs_path = "gs://tfhub-modules/sayakpaul/deit_tiny_patch16_224/1/uncompressed"
model = tf.keras.models.load_model(model_gcs_path)

# Build the model with a dummy forward pass, then inspect it.
dummy_inputs = tf.ones((2, 224, 224, 3))
_ = model(dummy_inputs)
model.summary(expand_nested=True)

Results

Results are on the ImageNet-1k validation set (top-1 and top-5 accuracy).

| model_name | top1_acc(%) | top5_acc(%) | orig_top1_acc(%) | orig_top5_acc(%) |
|---|---|---|---|---|
| deit_tiny_patch16_224 | 72.136 | 91.128 | 72.2 | 91.1 |
| deit_tiny_distilled_patch16_224 | 74.522 | 91.896 | 74.5 | 91.9 |
| deit_small_patch16_224 | 79.828 | 94.954 | 79.9 | 95.0 |
| deit_small_distilled_patch16_224 | 81.172 | 95.414 | 81.2 | 95.4 |
| deit_base_patch16_224 | 81.798 | 95.592 | 81.8 | 95.6 |
| deit_base_patch16_384 | 82.894 | 96.234 | 82.9 | 96.2 |
| deit_base_distilled_patch16_224 | 83.326 | 96.496 | 83.4 | 96.5 |
| deit_base_distilled_patch16_384 | 85.238 | 97.172 | 85.2 | 97.2 |

Results can be verified with the code in i1k_eval. Original results were sourced from [2].

Using the models

Pre-trained models:

These models also output attention weights from each of the Transformer blocks. Refer to this notebook for more details; it also shows how to visualize the attention maps for a given image.
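
For a quick start, here is a minimal inference sketch (it reuses the GCS path pattern from the Models section; input preprocessing and the exact output signature vary per export, so consult the notebook and the TF-Hub model pages):

import tensorflow as tf

# Load a pre-trained DeiT classifier (same pattern as in the Models section).
model = tf.keras.models.load_model(
    "gs://tfhub-modules/sayakpaul/deit_tiny_patch16_224/1/uncompressed"
)

# Stand-in for a batch of preprocessed 224x224 RGB images.
images = tf.random.uniform((1, 224, 224, 3))
outputs = model(images)
# Depending on the export, `outputs` is the ImageNet-1k logits tensor or a
# structure that also carries the per-block attention scores; see the
# notebook referenced above for the exact signature.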

Randomly initialized models:

from vit.model_configs import base_config
from vit.deit_models import ViTDistilled

import tensorflow as tf

# Fetch the configuration for the tiny distilled DeiT variant and
# instantiate a randomly initialized model from it.
distilled_tiny_tf_config = base_config.get_config(
    name="deit_tiny_distilled_patch16_224"
)
deit_tiny_distilled_patch16_224 = ViTDistilled(distilled_tiny_tf_config)

# Build the model with a dummy forward pass, then inspect it.
dummy_inputs = tf.ones((2, 224, 224, 3))
_ = deit_tiny_distilled_patch16_224(dummy_inputs)
deit_tiny_distilled_patch16_224.summary(expand_nested=True)

To initialize a network with, say, 5 classes, do:

with distilled_tiny_tf_config.unlocked():
    distilled_tiny_tf_config.num_classes = 5  # override the head size
deit_tiny_distilled_patch16_224 = ViTDistilled(distilled_tiny_tf_config)
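
You can then sanity-check the new head with a dummy forward pass (the expected shape below assumes the model returns a single logits tensor at inference time):

outputs = deit_tiny_distilled_patch16_224(tf.ones((2, 224, 224, 3)))
print(outputs.shape)  # expected: (2, 5)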

To view different model configurations, refer to convert_all_models.py.

Training with DeiT

You can refer to the notebooks/deit-trainer.ipynb notebook to get a sense of how distillation is actually performed using DeiT. It also provides code in case you want to train a model from scratch instead of using distillation.
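
As background, the hard-label distillation objective from the DeiT paper [1] can be sketched like this (illustrative names; the notebook contains the actual training code):

import tensorflow as tf

cce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

def hard_distillation_loss(y_true, cls_logits, dist_logits, teacher_logits):
    # DeiT hard distillation: the class-token head learns from the ground
    # truth while the distillation-token head learns from the teacher's
    # hard predictions; the two cross-entropy terms are weighted equally.
    teacher_labels = tf.argmax(teacher_logits, axis=-1)
    return 0.5 * cce(y_true, cls_logits) + 0.5 * cce(teacher_labels, dist_logits)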

References

[1] DeiT paper: https://arxiv.org/abs/2012.12877

[2] Official DeiT code: https://github.com/facebookresearch/deit

deit-tf's Issues

Did you train the TF DeiT models from scratch?

Hi Sayak,

Thanks a lot for sharing this TensorFlow version of the DeiT models.
I was wondering whether you trained these TF DeiT models from scratch to reach the same level of accuracy as the original PyTorch DeiT models.
Have you tried converting the original PyTorch model weights into TF format and loading them into your TF DeiT models? Were you able to match the PyTorch models' accuracy that way?
If you had to train these TF DeiT models from scratch or fine-tune them from the original PyTorch weights, could you share your training or fine-tuning recipe as well?
Thanks a lot for your help again.
