
mnist-benchmark

This repository implements models described in recent computer vision literature, focusing on a simple classification task on a classic dataset (MNIST). Three base models are explored: spatial transformer networks, vision transformers, and SpinalNets. We also implement new variations of two of these three models by replacing standard convolutional layers with CoordConv layers.

  • Spatial transformer networks (STN)
  • Spatial transformer networks + CoordConv layers
  • Vision transformers
  • SpinalNet
  • SpinalNet + STN + CoordConv layers
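What distinguishes a CoordConv layer from a standard convolution is that its input is first augmented with extra channels holding each pixel's normalized row and column position. This lets the filters condition on where they are in the image. The sketch below shows only that augmentation step. It is written in NumPy for portability and assumes NCHW arrays with coordinates scaled to [-1, 1]. The `add_coords` name is hypothetical and is not taken from this repository, whose layer may use a different layout, scaling, or an optional radius channel.

```python
import numpy as np


def add_coords(x: np.ndarray) -> np.ndarray:
    """Append normalized row/column coordinate channels to an NCHW batch.

    Hypothetical helper illustrating the CoordConv input augmentation;
    an input of shape (N, C, H, W) becomes (N, C + 2, H, W).
    """
    n, _, h, w = x.shape
    # Evenly spaced coordinates in [-1, 1] along each spatial axis.
    ys = np.linspace(-1.0, 1.0, h, dtype=x.dtype)
    xs = np.linspace(-1.0, 1.0, w, dtype=x.dtype)
    # yy varies down the rows, xx varies across the columns; both are (H, W).
    yy, xx = np.meshgrid(ys, xs, indexing="ij")
    # Stack into (2, H, W) and reuse the same maps for every sample in the batch.
    coords = np.broadcast_to(np.stack([yy, xx]), (n, 2, h, w))
    return np.concatenate([x, coords], axis=1)


# Example: a batch of 8 single-channel 28x28 MNIST-sized images.
batch = np.random.rand(8, 1, 28, 28).astype(np.float32)
augmented = add_coords(batch)
print(augmented.shape)  # (8, 3, 28, 28)
```

A CoordConv layer then applies an ordinary convolution whose input channel count is C + 2. In PyTorch, for instance, this would be an `nn.Conv2d(in_channels + 2, ...)` applied to the augmented tensor.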

A complete run of the experiments, together with results, comments and references, is available in MNIST_benchmarks.ipynb. It can also be reproduced in the following Colab notebook:

Open In Colab

A standalone script, run_experiment.py, is also provided to reproduce the experiments. Its dependencies are listed in requirements.txt.

usage: run_experiment.py [-h] [--device {gpu,cpu}] [--workers WORKERS]
                         [--bs BS] [--maxepochs MAX_EPOCHS]
                         [--patience PATIENCE] [--mindelta MIN_DELTA]
                         [--model {stn,stncoordconv,vit,spinal,spinalstn}]
                         [--localization] [--lr LR] [--logs LOGPATH]

MNIST-benchmarks

optional arguments:
  -h, --help            show this help message and exit
  --device {gpu,cpu}    Device on which to run the experiments. (default: cpu)
  --workers WORKERS     Number of workers for dataloaders. (default: 2)
  --bs BS               Batch size. (default: 64)
  --maxepochs MAX_EPOCHS
                        Maximum number of epochs to run the experiment for.
                        (default: 20)
  --patience PATIENCE   Number of epochs with no improvement before triggering
                        early stopping. (default: 5)
  --mindelta MIN_DELTA  Required improvement in the validation loss for early
                        stopping. (default: 0.005)
  --model {stn,stncoordconv,vit,spinal,spinalstn}
                        Type of model to train. (default: stn)
  --localization        Whether to use CoordConv in the localization network.
                        (default: False)
  --lr LR               Learning rate for SGD. (default: 0.01)
  --logs LOGPATH        Directory to store tensorboard logs. (default: logs/)
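The options above map naturally onto Python's `argparse`. Two of them, `--patience` and `--mindelta`, together define an early-stopping rule. The following is a minimal sketch, not the repository's code, of a parser that reproduces the documented flags and defaults. It is paired with a hypothetical `EarlyStopping` helper. Its semantics are assumed: an epoch counts as an improvement only if the validation loss drops below the best value seen so far by more than `min_delta`, and training stops after `patience` consecutive epochs without improvement. The actual script may implement either part differently, for example through a training framework's callback.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical reconstruction of the documented CLI (not the repository's code)."""
    parser = argparse.ArgumentParser(
        description="MNIST-benchmarks",
        # Appends "(default: ...)" to each help string, as in the usage output.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--device", choices=["gpu", "cpu"], default="cpu",
                        help="Device on which to run the experiments.")
    parser.add_argument("--workers", type=int, default=2,
                        help="Number of workers for dataloaders.")
    parser.add_argument("--bs", type=int, default=64, help="Batch size.")
    # dest= sets the attribute name; argparse derives the displayed metavar
    # (MAX_EPOCHS, MIN_DELTA, LOGPATH) from it.
    parser.add_argument("--maxepochs", dest="max_epochs", type=int, default=20,
                        help="Maximum number of epochs to run the experiment for.")
    parser.add_argument("--patience", type=int, default=5,
                        help="Number of epochs with no improvement before "
                             "triggering early stopping.")
    parser.add_argument("--mindelta", dest="min_delta", type=float, default=0.005,
                        help="Required improvement in the validation loss for "
                             "early stopping.")
    parser.add_argument("--model", default="stn",
                        choices=["stn", "stncoordconv", "vit", "spinal", "spinalstn"],
                        help="Type of model to train.")
    parser.add_argument("--localization", action="store_true",
                        help="Whether to use CoordConv in the localization network.")
    parser.add_argument("--lr", type=float, default=0.01,
                        help="Learning rate for SGD.")
    parser.add_argument("--logs", dest="logpath", default="logs/",
                        help="Directory to store tensorboard logs.")
    return parser


class EarlyStopping:
    """Hypothetical early-stopping rule driven by --patience and --mindelta."""

    def __init__(self, patience: int, min_delta: float) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")  # lowest validation loss seen so far
        self.bad_epochs = 0       # consecutive epochs without sufficient improvement

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            # Improvement: the loss beats the best so far by more than min_delta.
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

As an illustration with made-up numbers (not results from this repository), take `--patience 2` and the default `--mindelta 0.005`. The validation-loss sequence 0.50, 0.30, 0.298, 0.299 makes `step` return True on the fourth value, because neither 0.298 nor 0.299 improves on 0.30 by more than 0.005.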

TensorBoard is used to save the training and validation logs and metrics. By default, the logs are saved in logs/. To launch TensorBoard, run the following command. More details on TensorBoard are available here:

tensorboard --logdir=logs/ --port <port> --host <host>

Contributors

manuel-munoz-aguirre