ViT PyTorch

Quickstart

Install with pip install pytorch_pretrained_vit and load a pretrained ViT with:

from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)

Or find a Google Colab example here.

Overview

This repository contains an op-for-op PyTorch reimplementation of the Vision Transformer (ViT) architecture from Google, along with pre-trained models and examples.

The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects.

At the moment, you can easily:

  • Load pretrained ViT models
  • Evaluate on ImageNet or your own data
  • Finetune ViT on your own dataset

Coming soon:

  • Train ViT from scratch on ImageNet (1K)
  • Export to ONNX for efficient inference

Table of contents

  1. About ViT
  2. About ViT-PyTorch
  3. Installation
  4. Usage
  5. Contributing

About ViT

Vision Transformers (ViT) are a straightforward application of the transformer architecture to image classification. Even in computer vision, it seems, attention is all you need.

The ViT architecture works as follows: (1) it treats an image as a one-dimensional sequence of patches, (2) it prepends a classification token to the sequence, (3) it passes the sequence through a transformer encoder (like BERT), and (4) it passes the first token of the encoder output through a small MLP to obtain the classification logits. ViT is pretrained on a large-scale dataset (ImageNet-21k) with a huge amount of compute.
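The four steps above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the model class shipped in this package: the name `TinyViT`, its default sizes, and the use of `torch.nn.TransformerEncoder` are assumptions made for the example, and the learned position embeddings are a detail the summary above does not list.

```python
import torch
from torch import nn


class TinyViT(nn.Module):
    """Hypothetical, minimal ViT-style classifier (not this package's ViT class)."""

    def __init__(self, image_size=32, patch_size=8, dim=64, num_heads=4,
                 num_layers=2, num_classes=10):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # (1) Cut the image into patches and linearly embed each one.
        #     A strided convolution does both in a single step.
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        # (2) Learnable classification token, prepended to the patch sequence.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        # Learned position embeddings, one per token (assumption: not listed in the text).
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        # (3) Standard transformer encoder.
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads,
                                           dim_feedforward=4 * dim, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        # (4) Small head applied to the first output token.
        self.head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def forward(self, x):
        x = self.patch_embed(x)                # (B, dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)       # (B, num_patches, dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embed
        x = self.encoder(x)                    # (B, num_patches + 1, dim)
        return self.head(x[:, 0])              # logits from the classification token


model = TinyViT().eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)
```

The strided convolution is one common way to implement patch extraction plus linear projection; other implementations reshape the image into patches explicitly and apply a linear layer.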

About ViT-PyTorch

ViT-PyTorch is a PyTorch re-implementation of ViT. It is consistent with the original JAX implementation, so it is easy to load JAX-pretrained weights.

At the same time, we aim to make our PyTorch implementation as simple, flexible, and extensible as possible.

Installation

Install with pip:

pip install pytorch_pretrained_vit

Or from source:

git clone https://github.com/lukemelas/ViT-PyTorch
cd ViT-PyTorch
pip install -e .

Usage

Loading pretrained models

Loading a pretrained model is easy:

from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)

Details about the models are below:

Name             Pretrained on  Finetuned on  Available?
B_16             ImageNet-21k   -
B_32             ImageNet-21k   -
L_16             ImageNet-21k   -             -
L_32             ImageNet-21k   -
B_16_imagenet1k  ImageNet-21k   ImageNet-1k
B_32_imagenet1k  ImageNet-21k   ImageNet-1k
L_16_imagenet1k  ImageNet-21k   ImageNet-1k
L_32_imagenet1k  ImageNet-21k   ImageNet-1k
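The model names follow a simple pattern: a size letter (B for Base, L for Large), the patch size in pixels, and an optional `imagenet1k` suffix for the fine-tuned checkpoints. The sketch below decodes a name and works out the resulting sequence length. Both helpers, `describe_model_name` and `num_tokens`, are hypothetical and written only for this illustration; they are not part of the package.

```python
def describe_model_name(name):
    """Hypothetical helper: split a name such as 'B_16_imagenet1k' into its parts."""
    size, patch, *rest = name.split("_")
    return {
        "variant": {"B": "Base", "L": "Large"}[size],
        "patch_size": int(patch),
        "pretrained_on": "ImageNet-21k",
        "finetuned_on": "ImageNet-1k" if rest == ["imagenet1k"] else None,
    }


def num_tokens(image_size, patch_size):
    """Number of patches in a square image, plus one classification token."""
    return (image_size // patch_size) ** 2 + 1


info = describe_model_name("B_16_imagenet1k")
print(info)
# A 384x384 input with 16x16 patches gives 24 * 24 = 576 patches plus 1 token.
print(num_tokens(384, info["patch_size"]))  # 577
```

Doubling the patch size to 32 cuts the sequence to 145 tokens for the same 384x384 input, which is why the `_32` variants are cheaper to run.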

Custom ViT

Loading custom configurations is just as easy:

from pytorch_pretrained_vit import ViT
# A custom configuration, smaller than B_16
# (ViT-Base uses hidden_size=768, num_heads=12, num_layers=12)
config = dict(hidden_size=512, num_heads=8, num_layers=6)
model = ViT.from_config(config)
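When choosing custom values, keep in mind that multi-head attention splits the hidden dimension evenly across heads, so `hidden_size` should be divisible by `num_heads`. A small check along these lines can catch a bad configuration early. The helper `check_vit_config` is hypothetical and assumes the same dictionary keys as the example above.

```python
def check_vit_config(config):
    """Hypothetical sanity check for a custom ViT configuration dict.

    Returns the per-head dimension, or raises ValueError if the hidden
    size cannot be split evenly across the attention heads.
    """
    hidden_size = config["hidden_size"]
    num_heads = config["num_heads"]
    if hidden_size % num_heads != 0:
        raise ValueError(
            f"hidden_size={hidden_size} is not divisible by num_heads={num_heads}"
        )
    # Each attention head works on an equal slice of the hidden dimension.
    return hidden_size // num_heads


head_dim = check_vit_config(dict(hidden_size=512, num_heads=8, num_layers=6))
print(head_dim)  # 64
```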

Example: Classification

Below is a simple, complete example. It may also be found as a Jupyter notebook in examples/simple or as a Colab Notebook.

import json
from PIL import Image
import torch
from torchvision import transforms

# Load ViT
from pytorch_pretrained_vit import ViT
model = ViT('B_16_imagenet1k', pretrained=True)
model.eval()

# Load image
# NOTE: Assumes an image `img.jpg` exists in the current directory
img = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])(Image.open('img.jpg').convert('RGB')).unsqueeze(0)  # force 3 channels
print(img.shape) # torch.Size([1, 3, 384, 384])

# Classify
with torch.no_grad():
    outputs = model(img)
print(outputs.shape)  # torch.Size([1, 1000])
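The `outputs` tensor holds raw logits, one per class. To turn them into ranked predictions, apply a softmax and take the top entries. The sketch below uses a small hand-written tensor as a stand-in for `outputs` so that it runs without downloading weights; mapping the resulting indices to human-readable class names would need a label file, which is not shown here.

```python
import torch

# Stand-in for `outputs` from the example above: one image, four classes.
outputs = torch.tensor([[0.1, 2.0, -1.0, 0.5]])

# Convert logits to probabilities, then keep the two most likely classes.
probs = torch.softmax(outputs, dim=-1)
top_probs, top_idx = probs.topk(k=2, dim=-1)

for p, i in zip(top_probs[0].tolist(), top_idx[0].tolist()):
    print(f"class {i}: {p:.3f}")
```

With the real model, the same code works unchanged on the `(1, 1000)` output, typically with `k=5`.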

ImageNet

See examples/imagenet for details about evaluating on ImageNet.
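ImageNet results are usually reported as top-1 and top-5 accuracy. As a rough sketch of what such an evaluation computes, the function below measures top-k accuracy for a batch of logits. The name `topk_accuracy` and the toy data are assumptions for illustration; the script in examples/imagenet may be organised differently.

```python
import torch


def topk_accuracy(logits, targets, k=1):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    topk = logits.topk(k, dim=-1).indices                # (N, k)
    hits = (topk == targets.unsqueeze(-1)).any(dim=-1)   # (N,)
    return hits.float().mean().item()


# Toy batch: three samples, three classes.
logits = torch.tensor([[0.10, 0.70, 0.20],
                       [0.80, 0.15, 0.05],
                       [0.20, 0.30, 0.50]])
targets = torch.tensor([1, 1, 2])

top1 = topk_accuracy(logits, targets, k=1)
top2 = topk_accuracy(logits, targets, k=2)
print(top1, top2)
```

Here the second sample is misclassified at top-1 but its true label is the second-highest score, so it counts as correct at top-2.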

Credit

Other great repositories with this model include:

Contributing

If you find a bug, create a GitHub issue, or even better, submit a pull request. Similarly, if you have questions, simply post them as GitHub issues.

I look forward to seeing what the community does with these models!
