Giter Club home page Giter Club logo

capsulenet's Introduction

๐Ÿ’Š CapsuleNet ๐Ÿ’Š

A PyTorch implementation of CapsuleNet as described in "Dynamic Routing Between Capsules" by Hinton et al.

Includes decoder and pretrained weights for MNIST and Fashion MNIST in checkpoints/.

Instructions

This project uses conda for package management, install miniconda3 here.

Create an environment with all the dependancies:

conda env create -f environment.yml

Activate with

source activate capsnet

To train:

โ–ถ python train.py --help
usage: CapsNet [-h] [--epochs EPOCHS] [--data_path DATA_PATH]
               [--batch_size BATCH_SIZE] [--use_gpu] [--lr LR]
               [--log_interval LOG_INTERVAL] [--visdom] [--dataset DATASET]
               [--load_checkpoint LOAD_CHECKPOINT]
               [--checkpoint_interval CHECKPOINT_INTERVAL]
               [--checkpoint_dir CHECKPOINT_DIR] [--gen_dir GEN_DIR]

Example of CapsNet

optional arguments:
  -h, --help            show this help message and exit
  --epochs EPOCHS
  --data_path DATA_PATH
  --batch_size BATCH_SIZE
  --use_gpu
  --lr LR               ADAM learning rate (0.01)
  --log_interval LOG_INTERVAL
                        number of batches between logging
  --visdom              Whether or not to use visdom for plotting progrss
  --dataset DATASET     The dataset to train on, currently supported: MNIST,
                        Fashion MNIST
  --load_checkpoint LOAD_CHECKPOINT
                        path to load a previously trained model from
  --checkpoint_interval CHECKPOINT_INTERVAL
                        path to load a previously trained model from
  --checkpoint_dir CHECKPOINT_DIR
                        dir to store checkpoints in
  --gen_dir GEN_DIR     folder to store generated images in
(ml)
python train.py --visdom --checkpoint_interval=1 --epochs=10

Visdom for graphing progress

To start the visdom server:

python -m visdom.server
# Now running on http://localhost:8097/

Then run with the --visdom flag

To run using Fashion mnist

Download the dataset from here, place them in a subdirectory of folder entitled raw and run with

python train.py --data_path=<path to download> ...

your fashion-mnist folder should look like this

โ–ถ tree ~/data/fashion-mnist
/home/erikreppel/data/fashion-mnist
โ”œโ”€โ”€ processed
โ”‚ย ย  โ”œโ”€โ”€ test.pt
โ”‚ย ย  โ””โ”€โ”€ training.pt
โ””โ”€โ”€ raw
    โ”œโ”€โ”€ t10k-images-idx3-ubyte
    โ”œโ”€โ”€ t10k-labels-idx1-ubyte
    โ”œโ”€โ”€ train-images-idx3-ubyte
    โ””โ”€โ”€ train-labels-idx1-ubyte

2 directories, 6 files

The processed folder is created after training starts.

The PyTorch mnist dataset class will handle pre-processing.

Results

Dataset Epochs Test loss Test accuracy
MNIST 10 0.04356 98.803
Fashion MNIST 10 0.19429 86.580
MNIST 50 0.03029 99.011
Fashion MNIST 50 0.16416 88.904

Using:

Conv layer:
- input channels: 1
- output channels: 256
- stride: 1
- kernel size: 9x9

Capsule layer 1:
- 8 capsules of size 1152
- input channels: 256
- output channels: 32

Capsule layer 2:
- 10 capsules of size 16 (10 classes in mnist)
- input channels: 32
- 3 iterations of the routing algorithm

Generated images

I have yet to be able to reproduce the sharpness of reproduced images from the paper, I suspect it the reason is because I am decoupling the digit cap results from so that loss from the image generation is not backproped into capsnet.

Results of the decoder:

batch batch

Comparison of original and generated:

compare

Training Graphs

Results of training for 10 epochs on MNIST:

train loss train acc Test Loss train acc

Results of training for 10 epochs on Fashion MNIST:

train loss train acc Test Loss train acc

References

I found these implementations useful when I got stuck

capsulenet's People

Contributors

erikreppel avatar

Watchers

 avatar  avatar  avatar

Forkers

amilanpathirana

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.