
mushroom

Image Classification using ViT

Image classification using a Vision Transformer (ViT), implemented in PyTorch (version 1.11.0).

The dataset is the one from the Fungi Classification competition of the FGVC5 workshop at CVPR 2018. It covers 1394 classes of mushrooms and is split into training, validation, and test sets containing 85578, 4182, and 9758 images, respectively.

The scripts read the annotations from JSON files using the COCO class from pycocotools, which keeps data loading simple.
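The COCO class indexes a JSON file made of `images`, `annotations`, and `categories` lists. As a rough, dependency-free illustration of the lookup the scripts rely on, the sketch below builds the same image-to-label mapping with the standard `json` module. It assumes the FGVC5 files follow the standard COCO layout with one annotation per image; the helper name `load_image_labels` is hypothetical, and in the repository itself this work is done by pycocotools.

```python
import json
from typing import Dict, List, Optional, Tuple


def load_image_labels(json_path: str) -> List[Tuple[int, str, Optional[int]]]:
    """Return (image_id, file_name, category_id) triples from a COCO-style file.

    Hypothetical helper: mirrors the index pycocotools' COCO class builds,
    assuming one annotation per image. Images without an annotation
    (e.g. an unlabeled test split) get category_id None.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Map image_id -> category_id from the "annotations" list.
    label_of: Dict[int, int] = {}
    for ann in data.get("annotations", []):
        label_of[ann["image_id"]] = ann["category_id"]

    # Keep the order of the "images" list.
    return [
        (img["id"], img["file_name"], label_of.get(img["id"]))
        for img in data["images"]
    ]
```

With pycocotools the equivalent is roughly `coco = COCO(json_path)` followed by `coco.loadAnns(coco.getAnnIds(imgIds=img_id))` for each image ID.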

Finally, the code can easily be adapted to train, evaluate, and test a ViT on a different image classification task.

Training

To train the model, run in a terminal:

python train.py 

where you can add the following arguments:

  • Training parameters:

    • -tj or --train_json: train json file location (default: ./annotations/train.json)
    • -vj or --val_json: val json file location (default: ./annotations/val.json)
    • -g or --gpu: GPU position (default: 0)
    • -is or --image_shape: new image shape (default: (224, 224))
    • -bs or --batch_size: batch size (default: 32)
    • -nw or --num_workers: num workers (default: 2)
    • -lr or --learning_rate: learning rate (default: 0.001)
    • -wd or --weight_decay: weight decay (default: 0.1)
    • -ne or --num_epochs: number of epochs (default: 100)
    • -cp or --checkpoint_path: checkpoint path (default: ./model/model.pt)
    • -lm or --load_model: load a pre-trained model from a previous training run (default: False)
  • Augmentation parameters:

    • -pa or --perc_augmentation: augmentation percentage (default: 0.7)
    • -phf or --perc_horiz_filp: random horizontal flip percentage (default: 0.5)
    • -pvf or --perc_vert_filp: random vertical flip percentage (default: 0.5)
    • -pr or --perc_rotation: random rotation percentage (default: 0.5)
    • -rr or --rotation_range: random rotation range (default: 60, so [-60, 60) degrees)
    • -pb or --perc_bright: random brightness percentage (default: 0.5)
    • -gr or --gamma_range: random gamma range (default: 0.2, so [1-0.2, 1+0.2))
  • Model parameters:

    • -ps or --patch_size: image patch size (default: 16)
    • -nl or --num_layers: number of encoder layers (default: 12)
    • -nh or --num_heads: number of attention heads in each multi-head attention (MHA) layer (default: 12)
    • -hd or --hidden_dim: hidden dimension (default: 768)
    • -md or --mlp_dim: MLP dimension (default: 3072)
    • -d or --dropout: dropout rate (default: 0.2)
    • -da or --attention_dropout: attention dropout rate (default: 0.2)
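The model parameters interact: each image is cut into non-overlapping patch_size × patch_size patches, so with the defaults a 224 × 224 image yields (224/16)² = 196 patch tokens plus one class token, i.e. a sequence of 197 tokens of width hidden_dim, and hidden_dim must be divisible by num_heads (768/12 = 64 per head). The sketch below computes the sequence length and the parameter count. It assumes the model is laid out like torchvision's `VisionTransformer` (whose constructor arguments share these flag names) with a learned class token, learned position embeddings, and a single linear classification head; if the repository's model differs, so will the count. The function name `vit_sizes` is hypothetical.

```python
from typing import Tuple


def vit_sizes(
    image_shape: Tuple[int, int] = (224, 224),
    patch_size: int = 16,
    num_layers: int = 12,
    num_heads: int = 12,
    hidden_dim: int = 768,
    mlp_dim: int = 3072,
    num_classes: int = 1394,
    in_channels: int = 3,
) -> Tuple[int, int]:
    """Return (sequence_length, trainable_parameter_count) for a ViT.

    Hypothetical helper assuming a torchvision-style VisionTransformer layout.
    Dropout and attention dropout add no parameters.
    """
    h, w = image_shape
    if h % patch_size or w % patch_size:
        raise ValueError("image_shape must be divisible by patch_size")
    if hidden_dim % num_heads:
        raise ValueError("hidden_dim must be divisible by num_heads")

    num_patches = (h // patch_size) * (w // patch_size)
    seq_len = num_patches + 1  # +1 for the class token

    # Patch embedding: a patch_size x patch_size conv with stride patch_size.
    patch_embed = in_channels * patch_size * patch_size * hidden_dim + hidden_dim
    class_token = hidden_dim
    pos_embed = seq_len * hidden_dim
    layer_norm = 2 * hidden_dim  # weight + bias

    # Per encoder layer: two LayerNorms, attention (Q, K, V and output
    # projections, each with bias) and a two-layer MLP.
    attention = 4 * hidden_dim * hidden_dim + 4 * hidden_dim
    mlp = 2 * hidden_dim * mlp_dim + mlp_dim + hidden_dim
    layer = 2 * layer_norm + attention + mlp

    head = hidden_dim * num_classes + num_classes
    total = (patch_embed + class_token + pos_embed
             + num_layers * layer + layer_norm + head)
    return seq_len, total


print(vit_sizes())  # (197, 86870642)
```

With num_classes=1000 the same formula gives 86,567,656, the parameter count torchvision reports for its `vit_b_16`; switching to the 1394 fungi classes only changes the head.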

This script saves the model to the checkpoint path together with its parameters. Therefore, once you load it (load_model = True), as in the evaluation and test phases, you do not need to specify those parameters again. It also saves a plot of the training and validation loss and accuracy to the history.png file.
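Storing the hyperparameters next to the weights is what lets the evaluation and test scripts rebuild the model without repeating the model flags. A minimal, dependency-free sketch of the idea follows. The checkpoint is modelled as a plain dictionary (in a PyTorch script it would typically be written with `torch.save(checkpoint, path)` and read back with `torch.load(path)`), and the key names `model_state_dict` and `params` as well as all helper names are hypothetical, not necessarily those used by train.py. Only the long flag forms are defined here.

```python
import argparse
from typing import Any, Dict

# Model flags documented above and their defaults (long forms only).
MODEL_DEFAULTS: Dict[str, Any] = {
    "patch_size": 16,
    "num_layers": 12,
    "num_heads": 12,
    "hidden_dim": 768,
    "mlp_dim": 3072,
    "dropout": 0.2,
    "attention_dropout": 0.2,
}


def build_parser() -> argparse.ArgumentParser:
    """Parser for the model flags; each flag's type is taken from its default."""
    parser = argparse.ArgumentParser()
    for name, default in MODEL_DEFAULTS.items():
        parser.add_argument(f"--{name}", type=type(default), default=default)
    return parser


def make_checkpoint(state_dict: Dict[str, Any], args: argparse.Namespace) -> Dict[str, Any]:
    """Bundle the weights with the hyperparameters needed to rebuild the model."""
    params = {name: getattr(args, name) for name in MODEL_DEFAULTS}
    return {"model_state_dict": state_dict, "params": params}


def model_params_from(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Read the hyperparameters back, so no model flags are needed on the CLI."""
    return dict(checkpoint["params"])


# Train with a non-default depth, then recover it from the checkpoint alone.
args = build_parser().parse_args(["--num_layers", "6"])
checkpoint = make_checkpoint({"weights": []}, args)
print(model_params_from(checkpoint)["num_layers"])  # 6
```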

Evaluation

To evaluate the model, run in a terminal:

python evaluate.py 

where you can add the following arguments:

  • -vj or --val_json: val json file location (default: ./annotations/val.json)
  • -g or --gpu: GPU position (default: 0)
  • -is or --image_shape: new image shape (default: (224, 224))
  • -cp or --checkpoint_path: checkpoint path (default: ./model/model.pt)

This script computes the accuracy on the input dataset using the trained model saved at the checkpoint path. It fails if checkpoint_path does not exist.
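Accuracy here is the fraction of images whose predicted class (the arg-max of the model's scores) equals the ground-truth class. A framework-free sketch of that computation follows, with plain Python lists standing in for the model's output tensor. The function name is hypothetical, and it assumes labels are already contiguous class indices matching the score columns (the dataset's category IDs may need remapping first).

```python
from typing import Sequence


def top1_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class index equals the label."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        raise ValueError("empty input")
    correct = 0
    for row, label in zip(scores, labels):
        predicted = max(range(len(row)), key=row.__getitem__)  # arg-max
        correct += int(predicted == label)
    return correct / len(labels)


print(top1_accuracy([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], [1, 0, 0]))  # 2 of 3 correct
```

In PyTorch the same quantity is usually written as `(logits.argmax(dim=1) == labels).float().mean()`, accumulated over the validation batches.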

Testing

To test the model, run in a terminal:

python test.py 

where you can add the following arguments:

  • -tj or --val_json: annotation json file location (default: ./annotations/val.json)
  • -cj or --classes_json: classes dictionary location (default: ./classes_id_names.json)
  • -g or --gpu: GPU position (default: 0)
  • -is or --image_shape: new image shape (default: (224, 224))
  • -cp or --checkpoint_path: checkpoint path (default: ./model/model.pt)

This script writes a .csv file with 2 columns: id and predictions.

The id column contains the image IDs of the dataset, whereas predictions contains the model's top 3 predictions. It fails if checkpoint_path does not exist.
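The output format can be sketched with the standard `csv` and `heapq` modules: for each image, the three highest-scoring class indices are mapped to class identifiers and written into the predictions column, best first. The sketch assumes the three predictions are separated by spaces and that score index i corresponds to entry i of a hypothetical `class_ids` list; the actual script may format the column differently.

```python
import csv
import heapq
from typing import IO, Sequence


def write_top3_csv(
    out: IO[str],
    image_ids: Sequence[int],
    scores: Sequence[Sequence[float]],
    class_ids: Sequence[int],
    k: int = 3,
) -> None:
    """Write an `id,predictions` CSV with the top-k classes per image.

    Hypothetical helper: class_ids[i] is assumed to be the class ID for
    score index i; the k predictions are space-separated, best first.
    """
    writer = csv.writer(out)
    writer.writerow(["id", "predictions"])
    for image_id, row in zip(image_ids, scores):
        best = heapq.nlargest(k, range(len(row)), key=row.__getitem__)
        writer.writerow([image_id, " ".join(str(class_ids[i]) for i in best)])


if __name__ == "__main__":
    import io

    buf = io.StringIO()
    write_top3_csv(buf, [7, 8],
                   [[0.1, 0.5, 0.2, 0.9], [0.4, 0.3, 0.2, 0.1]],
                   [10, 11, 12, 13])
    print(buf.getvalue())  # rows: id,predictions / 7,13 11 12 / 8,10 11 12
```

When writing to a real file, open it with `open(path, "w", newline="")`, as the `csv` module documentation recommends.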

Contributors

mdciri
