


CapsNet-Pytorch

A PyTorch implementation of Hinton's paper: Dynamic Routing Between Capsules

Some CapsNet implementations available online have potential problems, and their bugs are hard to notice because MNIST is so simple that even a flawed model can reach satisfying accuracy.

Network

Pipeline: Input > Conv1 > Caps (CNN inside) > Route > Loss
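To make the shapes along this pipeline concrete, the sketch below applies the standard convolution output-size formula. It assumes the hyperparameters of the original paper (28x28 input, 9x9 kernels, 256 Conv1 channels, 32 capsule channels of dimension 8 with stride 2, and 10 output capsules of dimension 16); the values actually used in this repository may differ.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a 2-D convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


# Assumed hyperparameters, taken from the paper rather than from main.py.
IMG = 28            # MNIST images are 28x28
CONV1_CHANNELS = 256
CAPS_CHANNELS = 32  # number of primary-capsule channels ("types")
CAPS_DIM = 8        # length of each primary-capsule vector
NUM_CLASSES = 10
OUT_DIM = 16        # length of each output (class) capsule

conv1 = conv_out(IMG, kernel=9, stride=1)      # 20
caps = conv_out(conv1, kernel=9, stride=2)     # 6
num_primary = CAPS_CHANNELS * caps * caps      # 1152 primary capsules

print(f"Input : 1 x {IMG} x {IMG}")
print(f"Conv1 : {CONV1_CHANNELS} x {conv1} x {conv1}")
print(f"Caps  : {num_primary} capsules of dim {CAPS_DIM}")
print(f"Route : {NUM_CLASSES} capsules of dim {OUT_DIM}")
```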

Screenshots

  • running
  • training loss

Highlights

  • Highly abstracted Caps layer: by rewriting the function create_cell_fn, you can implement your own sub-network inside the Caps layer:
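    import torch.nn as nn

    def create_cell_fn(self):
        """
        Create the sub-network inside a capsule.
        :return: the sub-network (an nn.Module)
        """
        # Convolution from the Conv1 feature maps to this capsule's feature maps.
        conv1 = nn.Conv2d(self.conv1_kernel_num, self.caps1_conv_kernel_num,
                          kernel_size=self.caps1_conv_kernel_size,
                          stride=self.caps1_conv1_stride)
        # Optional: uncomment to return Conv2d followed by ReLU instead.
        # relu = nn.ReLU(inplace=True)
        # net = nn.Sequential(conv1, relu)
        return conv1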
  • Highly abstracted routing layer (class Route): you can combine the Caps layer and the Route layer to build your own network architectures.
  • No DigitCaps layer: we directly use the output of the Route layer.
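The routing-by-agreement procedure that the Route layer is built around can be sketched in NumPy as follows. This is an illustrative sketch of the algorithm from the paper, not the code of class Route; the function names and the shape convention u_hat[i, j] (the prediction of input capsule i for output capsule j) are assumptions.

```python
import numpy as np


def squash(s: np.ndarray, axis: int = -1, eps: float = 1e-8) -> np.ndarray:
    """Squash non-linearity: keeps the direction, maps the length into [0, 1)."""
    sq_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    return (sq_norm / (1.0 + sq_norm)) * s / np.sqrt(sq_norm + eps)


def softmax(x: np.ndarray, axis: int) -> np.ndarray:
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def route(u_hat: np.ndarray, num_iters: int = 3) -> np.ndarray:
    """Routing by agreement.

    u_hat: predictions of shape [num_in, num_out, out_dim].
    Returns the output capsules v with shape [num_out, out_dim].
    """
    assert num_iters >= 1
    num_in, num_out, _ = u_hat.shape
    b = np.zeros((num_in, num_out))                  # routing logits
    for it in range(num_iters):
        c = softmax(b, axis=1)                       # coupling coeffs, sum to 1 over outputs
        s = np.einsum("ij,ijd->jd", c, u_hat)        # weighted sum of predictions
        v = squash(s)                                # output capsules
        if it < num_iters - 1:
            b = b + np.einsum("ijd,jd->ij", u_hat, v)  # agreement = dot product
    return v


# Usage: 1152 input capsules routed to 10 output capsules of dimension 16.
rng = np.random.default_rng(0)
u_hat = rng.normal(size=(1152, 10, 16))
v = route(u_hat, num_iters=3)
lengths = np.linalg.norm(v, axis=-1)   # one length per class, each in [0, 1)
```

The capsule lengths play the role of class probabilities, which is why the squash keeps every length below 1.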

Status

  • We currently train the model for only 30 epochs, so training for more epochs may improve the results further.
  • We do not use the reconstruction loss yet; it will be added later.
  • The critical parts of the code are commented with the tensor dimensions at each step, so you can follow the comments to understand the routing mechanism.
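Without the reconstruction term, training uses only a classification loss on the Route layer's output capsules. Assuming this is the margin loss from the paper (m+ = 0.9, m- = 0.1, lambda = 0.5), it can be sketched as below; the function name margin_loss and its arguments are hypothetical, and the loss is written in NumPy for brevity rather than as the repository's PyTorch code.

```python
import numpy as np


def margin_loss(v: np.ndarray, labels: np.ndarray,
                m_plus: float = 0.9, m_minus: float = 0.1, lam: float = 0.5) -> float:
    """Margin loss from the CapsNet paper, averaged over the batch.

    v:      output capsules, shape [batch, num_classes, dim]
    labels: integer class ids, shape [batch]
    """
    lengths = np.linalg.norm(v, axis=-1)                   # [batch, num_classes]
    t = np.eye(lengths.shape[1])[labels]                   # one-hot targets
    present = t * np.maximum(0.0, m_plus - lengths) ** 2   # true class should be long
    absent = lam * (1 - t) * np.maximum(0.0, lengths - m_minus) ** 2  # others short
    return float(np.mean(np.sum(present + absent, axis=1)))
```

The down-weighting lam of absent classes stops the initial learning from shrinking all capsule lengths.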

TODO

  • add reconstruction loss
  • test on more convincing datasets, such as ImageNet
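For the reconstruction-loss item, the paper masks out every output capsule except the one of the true class, decodes it back to an image, and adds the sum of squared pixel errors, scaled down by 0.0005, to the margin loss. The sketch below shows only the masking and the scaled loss term (the decoder network is omitted); the names mask_by_label and reconstruction_loss are hypothetical.

```python
import numpy as np


def mask_by_label(v: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Zero every capsule except the true class, then flatten.

    v: [batch, num_classes, dim] -> [batch, num_classes * dim] (decoder input)
    """
    mask = np.eye(v.shape[1])[labels][:, :, None]   # [batch, num_classes, 1]
    return (v * mask).reshape(v.shape[0], -1)


def reconstruction_loss(recon: np.ndarray, images: np.ndarray,
                        scale: float = 0.0005) -> float:
    """Per-image sum of squared pixel errors, scaled down, averaged over the batch."""
    diff = recon.reshape(recon.shape[0], -1) - images.reshape(images.shape[0], -1)
    return float(scale * np.mean(np.sum(diff ** 2, axis=1)))
```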

About me

I'm a Research Assistant at the National University of Singapore. Before joining NUS, I was a first-year PhD candidate at Zhejiang University and then left the program. Contact me by email: [email protected] or WeChat: dragen1860

Usage

Step 1. Install Conda, CUDA, cuDNN and PyTorch

conda install pytorch torchvision cuda80 -c soumith

Step 2. Clone the repository locally

git clone https://github.com/dragen1860/CapsNet-Pytorch.git
cd CapsNet-Pytorch

Step 3. Train CapsNet on MNIST

  1. Change the variable glo_batch_size = 125 to a batch size that fits your GPU memory.
  2. Run:

$ python main.py

  3. Start TensorBoard:

$ tensorboard --logdir runs
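As a rough guide when changing glo_batch_size: the MNIST training set has 60,000 images, so one epoch takes 60,000 / batch_size optimizer steps, rounded up when the last partial batch is kept. The helper below is illustrative and not part of main.py.

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of mini-batches in one pass over the data."""
    if drop_last:
        return num_samples // batch_size          # partial last batch discarded
    return math.ceil(num_samples / batch_size)    # partial last batch kept


MNIST_TRAIN = 60_000
for bs in (64, 125, 256):
    print(bs, steps_per_epoch(MNIST_TRAIN, bs))
```

Smaller batches fit into less GPU memory but need more steps per epoch.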

Step 4. Validate CapsNet on MNIST

Alternatively, you can comment out the training code and evaluate the pretrained model from the mdl file.
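When evaluating, a capsule network's predicted class is the output capsule with the longest vector, and the Results table below reports the test error in percent. A minimal NumPy sketch of both steps, assuming the output capsules are available as an array of shape [batch, classes, dim]; predict and error_rate are hypothetical names, and loading the mdl file itself is not shown because it depends on how main.py saves the model.

```python
import numpy as np


def predict(v: np.ndarray) -> np.ndarray:
    """Predicted class = index of the longest output capsule. v: [batch, classes, dim]."""
    return np.argmax(np.linalg.norm(v, axis=-1), axis=1)


def error_rate(v: np.ndarray, labels: np.ndarray) -> float:
    """Classification error in percent over the given batch."""
    return 100.0 * float(np.mean(predict(v) != labels))
```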

Results

Model     Routing  Reconstruction  MNIST test error (%)
Baseline  -        -               0.39
Paper     3        no              0.35
Ours      3        no              0.34

Training takes about 150 s per epoch on a single GTX 970 (4 GB) card.
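At that speed, the 30 epochs mentioned under Status take about 30 x 150 s = 4,500 s, i.e. roughly 75 minutes on the same card. The small helper below (illustrative, not part of the repository) scales this estimate to other epoch counts; actual times will vary with batch size and hardware.

```python
def training_time(epochs: int, seconds_per_epoch: float) -> str:
    """Rough wall-clock estimate, formatted as 'H h M min'."""
    total_min = round(epochs * seconds_per_epoch / 60)
    return f"{total_min // 60} h {total_min % 60} min"


print(training_time(30, 150))    # the 30 epochs used so far
print(training_time(100, 150))   # a longer run
```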

Other Implementations

Contributors

dragen1860
