CLIP training

This repository contains code to train CLIP on MS-COCO captions. It can easily be modified to train on other multimodal datasets (OpenImages, Conceptual Captions, ...).

Requirements

To set up the environment:

# create new env clip_train
$ conda create -n clip_train python=3.8.5

# activate clip_train
$ conda activate clip_train

# install pytorch, torchvision
$ conda install pytorch==1.7.0 torchvision==0.8.0 cudatoolkit=10.2 -c pytorch

# install other dependencies
$ pip install -r requirements.txt

Training

Preparing training dataset

The CLIP model is trained on the MS-COCO training-set images and their captions. To download the dataset:

# create the dataset directory under data/
$ mkdir -p data/mscoco

# download images
$ wget http://images.cocodataset.org/zips/train2017.zip -O data/mscoco/train2017.zip
$ unzip data/mscoco/train2017.zip -d data/mscoco


# download annotations 
$ wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip -O data/mscoco/annotations_trainval2017.zip
$ unzip data/mscoco/annotations_trainval2017.zip -d data/mscoco
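After unzipping, the captions should be in data/mscoco/annotations/captions_train2017.json, which follows the standard COCO layout: an "images" list (each entry with "id" and "file_name") and an "annotations" list (each entry with "image_id" and "caption"). The sketch below shows one way to pair each image path with its captions. The helper names group_captions and load_caption_pairs are hypothetical, and this is not the repository's own dataloader.

```python
import json
import os
from collections import defaultdict


def group_captions(coco, img_dir):
    """Hypothetical helper: turn a parsed COCO captions dict into a list of
    (image_path, [captions]) pairs, sorted by image id."""
    # image id -> file name
    id_to_file = {img["id"]: img["file_name"] for img in coco["images"]}

    # MS-COCO has several captions per image; collect them by image id
    captions = defaultdict(list)
    for ann in coco["annotations"]:
        captions[ann["image_id"]].append(ann["caption"].strip())

    return [
        (os.path.join(img_dir, id_to_file[img_id]), caps)
        for img_id, caps in sorted(captions.items())
    ]


def load_caption_pairs(annotation_file, img_dir):
    """Hypothetical helper: read a COCO captions JSON file from disk."""
    with open(annotation_file, "r", encoding="utf-8") as f:
        return group_captions(json.load(f), img_dir)
```

For the paths used above, a call would look like `load_caption_pairs("data/mscoco/annotations/captions_train2017.json", "data/mscoco/train2017")`.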

To check or update the training parameters, model config and dataset paths, see the following config files:

trainer/train_config.yaml   # training parameters
model/model_config.yaml     # CLIP model config
dataloader/data_config.yaml # training dataset path

To train:

Take the dataset paths from dataloader/data_config.yaml:

$ python train.py

Or pass the dataset paths as command-line arguments:

$ python train.py --train_img_dir <path to training images directory> --train_annotation_file <path to annotation file>
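The two invocations imply a simple precedence rule: a path given on the command line overrides the value from the config file. Below is a minimal sketch of that rule with argparse. It assumes the YAML has already been parsed into a dict, and the key names used for that dict (train_img_dir, train_annotation_file) are hypothetical. train.py may implement the fallback differently.

```python
import argparse


def resolve_dataset_paths(argv, data_config):
    """Return (img_dir, annotation_file), preferring CLI values over config.

    `data_config` stands in for the parsed dataloader/data_config.yaml;
    its key names here are hypothetical.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_img_dir", default=None)
    parser.add_argument("--train_annotation_file", default=None)
    args = parser.parse_args(argv)

    # fall back to the config value only when the flag was not given
    img_dir = (args.train_img_dir if args.train_img_dir is not None
               else data_config["train_img_dir"])
    ann_file = (args.train_annotation_file if args.train_annotation_file is not None
                else data_config["train_annotation_file"])
    return img_dir, ann_file
```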

Training settings:

  • Model config: since MS-COCO is a relatively small dataset, I used ResNet50 as the image encoder instead of a Vision Transformer, and reduced the number of transformer layers in the text encoder to 6. The detailed model config is in model_config.yaml.

  • Batch size: 256. I trained on 4 GTX 1080 GPUs (batch size 64 per GPU).

  • Optimizer: Adam with weight decay.

  • Scheduler: cosine schedule with warmup over the first 20% of gradient update steps. The detailed training config is in train_config.yaml.

  • Temperature parameter clipping: as described in the CLIP paper, the learnable temperature parameter is clipped for training stability so that the logits are never scaled by more than 100.
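The warmup-plus-cosine schedule from the Scheduler bullet can be sketched as a function of the step index. This sketch assumes linear warmup from near zero and cosine decay to exactly zero. The exact shape, and any minimum learning rate set in train_config.yaml, may differ.

```python
import math


def lr_at_step(step, total_steps, base_lr, warmup_frac=0.2):
    """Learning rate at a 0-based gradient-update step: linear warmup over the
    first `warmup_frac` of steps, then cosine decay from base_lr to 0."""
    warmup_steps = max(1, int(warmup_frac * total_steps))
    if step < warmup_steps:
        # ramps from base_lr / warmup_steps up to base_lr
        return base_lr * (step + 1) / warmup_steps
    # fraction of the post-warmup phase completed, clamped to [0, 1]
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
```

With base_lr=1.0 the function returns a multiplier. That makes it usable with PyTorch's torch.optim.lr_scheduler.LambdaLR, for example `LambdaLR(optimizer, lambda s: lr_at_step(s, total_steps, 1.0))`, with the scheduler stepped once per gradient update.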
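The temperature bullet refers to the logit scale in CLIP's symmetric contrastive loss. Each image in a batch is scored against every caption, and cross-entropy is applied in both directions with the matching pairs on the diagonal. The sketch below uses plain Python lists instead of PyTorch tensors and hypothetical function names to show where the clip at 100 enters. It is an illustration, not the repository's training code.

```python
import math

MAX_LOGIT_SCALE = 100.0  # logits are never scaled by more than 100


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def clip_loss(image_embs, text_embs, log_scale):
    """Symmetric contrastive loss for a batch where image i matches caption i.

    `log_scale` plays the role of the learnable temperature parameter:
    the logit scale is exp(log_scale), clipped at MAX_LOGIT_SCALE.
    """
    scale = min(math.exp(log_scale), MAX_LOGIT_SCALE)
    imgs = [l2_normalize(v) for v in image_embs]
    txts = [l2_normalize(v) for v in text_embs]
    n = len(imgs)

    # logits[i][j] = scale * cosine similarity(image i, caption j)
    logits = [[scale * sum(a * b for a, b in zip(imgs[i], txts[j]))
               for j in range(n)] for i in range(n)]

    # image -> text: row i should pick column i
    loss_i2t = sum(cross_entropy(logits[i], i) for i in range(n)) / n
    # text -> image: column j should pick row j
    loss_t2i = sum(cross_entropy([logits[i][j] for i in range(n)], j)
                   for j in range(n)) / n
    return (loss_i2t + loss_t2i) / 2
```

In PyTorch, a common way to enforce the same limit is to clamp a learnable log-scale parameter in place after each optimizer step. For a hypothetical `logit_scale` parameter this would be `logit_scale.data.clamp_(max=math.log(100))`. The README does not say which variant this repository uses.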

Zero-shot classification:

For zero-shot classification, all class names are first converted into sentences using templates (such as "a photo of a {class name}"), and their text embeddings are computed with CLIP. To classify an image, its embedding is computed with CLIP. Its cosine similarity with each class-sentence embedding is then computed, and the class with the highest cosine similarity is predicted.
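The procedure above can be sketched in a few dependency-free lines. This version assumes a hypothetical encode_text callable (standing in for CLIP's text encoder) and a precomputed image embedding. Averaging normalized embeddings over several templates follows the prompt-ensembling idea described in the CLIP paper. The softmax with a logit scale of 100 is one assumed way to turn similarities into the class probabilities the demo reports; the prediction itself depends only on the highest cosine similarity.

```python
import math

# Assumed templates; the demo's actual template list may differ.
TEMPLATES = ["a photo of a {}.", "a blurry photo of a {}."]


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def build_class_embeddings(class_names, encode_text, templates=TEMPLATES):
    """One unit-length embedding per class: the normalized mean of the
    normalized embeddings of every templated sentence."""
    class_embs = []
    for name in class_names:
        embs = [l2_normalize(encode_text(t.format(name))) for t in templates]
        mean = [sum(col) / len(embs) for col in zip(*embs)]
        class_embs.append(l2_normalize(mean))
    return class_embs


def zero_shot_classify(image_emb, class_names, class_embs, logit_scale=100.0):
    """Return (predicted class, per-class probabilities)."""
    img = l2_normalize(image_emb)
    # cosine similarity = dot product of unit vectors
    sims = [sum(a * b for a, b in zip(img, c)) for c in class_embs]
    # softmax over scaled similarities (max subtracted for numerical stability)
    m = max(sims)
    exps = [math.exp(logit_scale * (s - m)) for s in sims]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(sims)), key=lambda k: sims[k])
    return class_names[best], probs
```

Because the class embeddings do not depend on the image, they can be computed once for all CIFAR-100 classes and reused for every image.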

Zero-shot demo:

Trained weights:

  • Download the trained checkpoint from Google Drive: link
  • Or use gdown to download it:
    # first install gdown
    $ pip install gdown
    
    # then download the trained weights to saved_checkpoints/trained_checkpoint.pt
    $ mkdir saved_checkpoints
    $ gdown --id 1BVEY4WeFmQb3wv0A6RaLyVjnc7qmChH2 -O saved_checkpoints/trained_checkpoint.pt  
    

To classify image(s) into CIFAR-100 classes, run:

# to classify a single image
$ python zero_shot_demo.py --checkpoint_path <path_to_trained_checkpoint.pt> --img_path <path_to_img.jpg> --show_prediction

# to classify all images in a directory
$ python zero_shot_demo.py --checkpoint_path <path_to_trained_checkpoint.pt> --img_dir <path_to_img_directory> --show_prediction

# the --show_prediction flag saves a prediction figure with class probabilities
# NOTE: put an even number of images in the image directory to get a well-laid-out prediction figure

Example of running the zero-shot demo:

# first put the trained weights at saved_checkpoints/trained_checkpoint.pt

# for single image
$ python zero_shot_demo.py --checkpoint_path saved_checkpoints/trained_checkpoint.pt --img_path test_images/bicycle.jpeg --show_prediction

# for an image directory
$ python zero_shot_demo.py --checkpoint_path saved_checkpoints/trained_checkpoint.pt --img_dir test_images --show_prediction

# view the prediction figure in the demo_output directory

Zero-shot evaluation on vision datasets + observations:

For evaluation results and instructions on running the evaluation code, see: Observations and Eval results
