Giter Club home page Giter Club logo

wide-resnet-cifar100-kaggle's Introduction

Wide Residual Networks Using Ensemble for CIFAR100 Kaggle submission (.csv is generated)

This is taken from https://github.com/meliketoy/wide-residual-network/ and modified to create solution for STAT946 kaggle challenge - Deep Learning and gives 0.83971 on the private leaderboard: https://www.kaggle.com/c/fall2017stat946/leaderboard

NOTE: Before submitting the solution, make sure to remove "true_labels" column from the csv file generated.

All my experiments were run on GeForce GTX 1080.

Wide-residual network implementations for cifar100

Torch Implementation of Sergey Zagoruyko's Wide Residual Networks.

In order to figure out what 'width' & 'height' does on wide-residual networks, several experiments were conducted on different settings of weights and heights. It turns out that increasing the number of filters(increasing width) gave more positive influence to the model than making the model deeper.

Last but not least, simply averaging a few models with different parameter settings showed a significant increase in both top1 and top5 accuracy. The CIFAR dataset test results approached to 84.19% for CIFAR-100 with only meanstd normalization.

Requirements

See the installation instruction for a step-by-step installation guide. See the server instruction for server setup.

$ luarocks install cutorch
$ luarocks install xlua
$ luarocks install optnet

Directions and datasets

  • modelState : The best model will be saved in this directory
  • datasets : Data preparation & preprocessing directory
  • networks : Wide-residual network model structure file directory
  • gen : Generated t7 file for each dataset will be saved in this directory
  • scripts : Directory where the run file scripts are contained

Best Results

Adapting weight adjustments for each model will promise a more improved accuracy.

You can see that the ensemble network improves the results of single WRNs.

Test error (%, random flip, meanstd normaliztion, median of 5 runs) on CIFAR:

Dataset network Top1 Err(%)
CIFAR-100 WRN-28x10 18.85
CIFAR-100 Ensemble-WRN 15.81

How to run

You can train each dataset of cifar100 by editing and running the script below.

$ ./scripts/train.sh

You can test your own trained model of cifar100 by editing and running the script below. This will print test results and save the .csv file for submission to kaggle

$ ./scripts/test.sh

To train an ensemble follow the steps below.

$ vi ensemble.lua
# Press :32 in vi, which will move your cursor to line 32
ens_depth         = torch.Tensor({28, 28, 28, 28, 40, 40})
ens_widen_factor  = torch.Tensor({10, 20, 20, 20, 10, 14})
ens_nExperiment   = torch.Tensor({ 1,  1,  2,  3,  1,  1})

then edit and run

$ ./scripts/ensemble_train.sh

This will create .csv file for kaggle submission

If you want to just create the ensemble of already trained models, edit ensemble.lua as mentioned above and then edit and run scripts/ensemble_test.sh

$ ./scripts/ensemble_test.sh

Implementation Details

  • CIFAR-100
epoch learning rate weight decay Optimizer Momentum Nesterov
0 ~ 60 0.1 0.0005 Momentum 0.9 true
61 ~ 120 0.02 0.0005 Momentum 0.9 true
121 ~ 160 0.004 0.0005 Momentum 0.9 true
161 ~ 200 0.0008 0.0005 Momentum 0.9 true

CIFAR-100 Results

alt tag

Below is the result of the test set accuracy for CIFAR-100 dataset training.

NOTE: 28x20 configuration consumes the highest GPU memory ~11 GB while running on a single GPU, if you get error, it is possibly due to GPU memory insufficiency. Also note, while running script/ensemble_test.sh, you might hit into memory issues, in which case, either reduce number of models in the ensemble by modifying, ensemble.lua or edit test.lua to load models differently.

Accuracy is the average of 5 runs

network dropout preprocess GPU:0 GPU:1 per epoch Top1 acc(%) Top5 acc(%)
pre-ResNet-1001 0 meanstd - - 3 min 25 sec 77.29 93.44
wide-resnet 28x10 0 ZCA 5.90G - 2 min 03 sec 80.03 95.01
wide-resnet 28x10 0 meanstd 5.90G - 2 min 03 sec 81.01 95.44
wide-resnet 28x10 0.3 meanstd 5.90G - 2 min 03 sec 81.47 95.53
wide-resnet 28x20 0.3 meanstd 8.13G 6.93G 4 min 05 sec 82.43 96.02
wide-resnet 40x10 0.3 meanstd 8.93G - 3 min 06 sec 81.47 95.65
wide-resnet 40x14 0.3 meanstd 7.39G 6.46G 3 min 23 sec 81.83 95.50

wide-resnet-cifar100-kaggle's People

Contributors

sverneka avatar

Stargazers

 avatar  avatar

Watchers

 avatar  avatar  avatar

Forkers

amirunpri2018

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.