pytorch-ensembles's Introduction

The official PyTorch implementation of:
Pitfalls of In-Domain Uncertainty Estimation and Ensembling in Deep Learning, ICLR'20

OpenReview / arXiv / Poster video (5 mins) / Blog / bibtex

Poster video (5 mins)

ICLR 2020 Poster Presentation by Dmitry Molchanov

Environment Setup

The following commands create and activate a Python environment with all required dependencies using Miniconda:

conda env create -f condaenv.yml
conda activate megabayes

Logs, Plots, Tables, Pre-trained weights

In the notebooks folder we provide:

  • Saved logs with all computed results
  • Examples of ipython notebooks to reproduce plots, tables, and compute the deep ensemble equivalent (DEE) score
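As a rough illustration of the DEE score mentioned above, here is a minimal sketch. It assumes the definition from the paper as we understand it: the DEE of a method is the smallest deep-ensemble size, at least 1 and possibly fractional, whose score matches the method's score, with linear interpolation between integer sizes. The function name and the list-based input are hypothetical and are not taken from the notebooks.

```python
from typing import Sequence


def deep_ensemble_equivalent(de_scores: Sequence[float], method_score: float) -> float:
    """Smallest (possibly fractional) deep-ensemble size that reaches method_score.

    de_scores[i] is the score (higher is better, e.g. a test log-likelihood)
    of a deep ensemble with i + 1 members. Scores between integer sizes are
    linearly interpolated. This is an illustrative sketch, not the repo's code.
    """
    if not de_scores:
        raise ValueError("de_scores must be non-empty")
    # A single network already matches the method: DEE is clipped at 1.
    if de_scores[0] >= method_score:
        return 1.0
    for i in range(1, len(de_scores)):
        lo, hi = de_scores[i - 1], de_scores[i]
        if hi >= method_score:
            # Here lo < method_score <= hi, so interpolate between sizes i and i + 1.
            return i + (method_score - lo) / (hi - lo)
    # The method beats every measured deep ensemble, so the curve never reaches it.
    return float("inf")
```

The input curve would come from the saved logs (one score per deep-ensemble size); the notebooks remain the reference for how the reported DEE values were actually computed.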

Pre-trained weights of some models are available at the following links: Deep Ensembles on ImageNet (~10G) and Deep Ensembles on CIFARs (~62G), etc. The weights can also be downloaded from the command line with yadisk-direct:

pip3 install wldhx.yadisk-direct

# ImageNet
curl -L "$(yadisk-direct 'https://yadi.sk/d/rdk6ylF5mK8ptw?w=1')" -o deepens_imagenet.zip
unzip deepens_imagenet.zip

# CIFARs
curl -L "$(yadisk-direct 'https://yadi.sk/d/8C5jBz-licWMqQ?w=1')" -o deepens_cifars.zip
unzip deepens_cifars.zip

Pre-trained weights for other models are available on request; open an issue if you need specific models.

Evaluation

Ensembling methods can be evaluated with the scripts in the ens folder, which contains a separate script for each ensembling method. The scripts have the following interface:

ipython -- ens/ens-<method>.py -h
usage: ens-onenet.py [-h] --dataset DATASET [--data_path PATH]
                     [--models_dir PATH] [--aug_test] [--batch_size N]
                     [--num_workers M] [--fname FNAME]

optional arguments:
  -h, --help         show this help message and exit
  --dataset DATASET  dataset name CIFAR10/CIFAR100/ImageNet (default: None)
  --data_path PATH   path to a data-folder (default: ~/data)
  --models_dir PATH  a dir that stores pre-trained models (default: ~/megares)
  --aug_test         enables test-time augmentation (default: False)
  --batch_size N     input batch size (default: 256)
  --num_workers M    number of workers (default: 10)
  --fname FNAME      comment to a log file name (default: unnamed)
  • All scripts assume that pytorch-ensembles is the current working directory (cd pytorch-ensembles).
  • The scripts write .csv logs to pytorch-ensembles/logs with the columns rowid, dataset, architecture, ensemble_method, n_samples, metric, value, info.
  • The notebooks folder contains ipython notebooks that reproduce the tables and plots from these logs.
  • The scripts write the final log-probs of every method in .npy format to the pytorch-ensembles/.megacache folder.
  • The interface for K-FAC-Laplace differs and is described below.
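The .csv logs described above can also be inspected without the notebooks. The sketch below uses only the standard library and relies on the column order listed above. Whether the files start with a header row is not stated here, so the has_header switch is an assumption, and the function names are hypothetical.

```python
import csv
import io
from collections import defaultdict

# Column order as listed above; the absence of a header row is an assumption.
COLUMNS = ["rowid", "dataset", "architecture", "ensemble_method",
           "n_samples", "metric", "value", "info"]


def read_log(text, has_header=False):
    """Parse the text of one log file into a list of dicts keyed by COLUMNS."""
    reader = csv.reader(io.StringIO(text))
    rows = [dict(zip(COLUMNS, (cell.strip() for cell in line)))
            for line in reader if line]
    return rows[1:] if has_header else rows


def metric_by_ensemble_size(rows, metric):
    """Group values of one metric as
    (dataset, architecture, ensemble_method) -> {n_samples: [values]}."""
    table = defaultdict(lambda: defaultdict(list))
    for row in rows:
        if row["metric"] == metric:
            key = (row["dataset"], row["architecture"], row["ensemble_method"])
            table[key][int(row["n_samples"])].append(float(row["value"]))
    return table
```

A typical use would be to read a file from the logs folder with open(path).read(), pass the text to read_log, and then call metric_by_ensemble_size with one of the metric names that actually occur in the metric column.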

Examples:

ipython -- ens/ens-onenet.py  --dataset=CIFAR10/CIFAR100/ImageNet
ipython -- ens/ens-deepens.py --dataset=CIFAR10/CIFAR100/ImageNet
ipython -- ens/ens-sse.py     --dataset=CIFAR10/CIFAR100/ImageNet
ipython -- ens/ens-csgld.py   --dataset=CIFAR10/CIFAR100
ipython -- ens/ens-fge.py     --dataset=CIFAR10/CIFAR100/ImageNet
ipython -- ens/ens-dropout.py --dataset=CIFAR10/CIFAR100
ipython -- ens/ens-vi.py      --dataset=CIFAR10/CIFAR100/ImageNet
ipython -- ens/ens-swag.py    --dataset=CIFAR10/CIFAR100
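For readers unfamiliar with how an ensemble prediction is formed from its members, the following sketch averages member predictions in probability space while staying in log space for numerical stability. It assumes per-member log-probabilities are given as nested lists indexed as [member][object][class]. This layout is hypothetical and is not the documented format of the cached .npy files, and the code is an illustration rather than what the ens scripts do internally.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def ensemble_log_probs(member_log_probs):
    """Log of the mean member probability for every object and class.

    member_log_probs[k][n][c] is the log-probability of class c
    for object n under ensemble member k.
    """
    k = len(member_log_probs)
    n_objects = len(member_log_probs[0])
    n_classes = len(member_log_probs[0][0])
    return [
        [logsumexp([member[n][c] for member in member_log_probs]) - math.log(k)
         for c in range(n_classes)]
        for n in range(n_objects)
    ]


def mean_log_likelihood(log_probs, labels):
    """Average log-probability assigned to the true labels."""
    return sum(lp[y] for lp, y in zip(log_probs, labels)) / len(labels)
```

Averaging in probability space (rather than averaging log-probabilities) keeps the result a valid mixture distribution, which is why the log of the member count is subtracted after the logsumexp.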

Training models

Deep Ensemble members, Variational Inference (CIFAR-10/100)

All CIFAR models are trained on a single GPU. Examples of training commands:

bash train/train_cifar.sh \
--dataset CIFAR10/CIFAR100 \
--arch VGG16BN/PreResNet110/PreResNet164/WideResNet28x10 \
--method regular/vi

Fast Geometric Ensembling, SWA-Gaussian, Snapshot Ensembles, and Cyclical SGLD (CIFAR-10/100)

Examples of training commands:

bash train/train_fge.sh CIFAR10 WideResNet28x10 1 ~/weights ~/datasets
bash train/train_swag.sh CIFAR10 WideResNet28x10 1 ~/weights ~/datasets
bash train/train_sse_mcmc.sh CIFAR10 WideResNet28x10 1 ~/weights ~/datasets SSE
bash train/train_sse_mcmc.sh CIFAR10 WideResNet28x10 1 ~/weights ~/datasets cSGLD

Script parameters, in order: dataset, architecture name, training run id, root directory for saving snapshots (created automatically), and root directory for datasets (downloaded automatically). train_sse_mcmc.sh takes the method (SSE or cSGLD) as an additional last parameter.

K-FAC Laplace (CIFAR-10/100, ImageNet)

Given a checkpoint, ens/ens-kfacl.py builds the Laplace approximation and reports the results of approximate posterior averaging. Use the flags --scale_search and --gs_low LOW --gs_high HIGH --gs_num NUM to search for the optimal posterior noise scale on the validation set. We used the following noise scales (also listed in Table 3, Appendix D of the paper):

Architecture      CIFAR-10   CIFAR-10-aug   CIFAR-100   CIFAR-100-aug
VGG16BN           0.042      0.042          0.100       0.100
PreResNet110      0.213      0.141          0.478       0.401
PreResNet164      0.120      0.105          0.285       0.225
WideResNet28x10   0.022      0.018          0.022       0.004

For ResNet50 on ImageNet, the optimal scale found was 2.0 with test-time augmentation and 6.8 without test-time augmentation.
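The idea behind a scale search over LOW, HIGH and NUM can be sketched as a one-dimensional grid search. The sketch below assumes a log-spaced grid and a user-supplied validation_score callable (higher is better, for example validation log-likelihood at a given noise scale). Both are assumptions made for illustration; the spacing and the criterion actually used by ens/ens-kfacl.py should be checked in the script itself.

```python
import math


def log_spaced_grid(low, high, num):
    """Return num points from low to high, evenly spaced on a log scale."""
    if num < 2:
        return [low]
    step = (math.log(high) - math.log(low)) / (num - 1)
    return [math.exp(math.log(low) + i * step) for i in range(num)]


def search_scale(validation_score, low, high, num):
    """Pick the grid point with the highest validation score.

    validation_score is a hypothetical callable mapping a noise scale
    to a number where higher is better.
    """
    grid = log_spaced_grid(low, high, num)
    return max(grid, key=validation_score)
```

A log-spaced grid is a natural choice when the plausible scales span several orders of magnitude, as the values in the table above do (0.004 to 0.478).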

Refer to ens/ens-kfacl.py for the full list of arguments and default values. Example use:

ipython -- ens/ens-kfacl.py --file CHECKPOINT --data_path DATA --dataset CIFAR10 --model PreResNet110 --scale 0.213

Deep Ensemble members, Variational Inference, Snapshot Ensembles (ImageNet)

Examples of training commands:

bash train/train_imagenet.sh --method regular/sse/fge/vi

We strongly recommend using multi-GPU training for Snapshot Ensembles.

Attribution

Parts of this code are based on the following repositories:

Citation

If you found this code useful, please cite our paper:

@article{ashukha2020pitfalls,
  title={Pitfalls of In-Domain Uncertainty Estimation and Ensembling in Deep Learning},
  author={Ashukha, Arsenii and Lyzhov, Alexander and Molchanov, Dmitry and Vetrov, Dmitry},
  journal={arXiv preprint arXiv:2002.06470},
  year={2020}
}

pytorch-ensembles's People

Contributors

alexlyzhov, da-molchanov, senya-ashukha

pytorch-ensembles's Issues

yadisk-direct returns invalid link for CIFAR ensembles

Hi,

I cannot manage to automatically download the CIFAR ensembles since yadisk-direct https://yadi.sk/d/8C5jBz-licWMqQ?w=1 returns an invalid link.
The link neither works with curl, wget, nor manual copy-pasting into the browser.
Specifically, when trying to open the returned download link in the browser I get:

This site can’t be reached
The web page at https://downloader-default7f.disk.yandex.net/rzip/c561bf3d782fdd0b2905212678d077af348eaee82c48b898a5e1983f75106c49/648b2c54/L2xVNEEwN1Qvb3BhVlJycTZ4aXdnRXFoYWpERTNWUzJ1ZmtDMHNKWGNDRElVdStQd1VHRzBZQm5WR1A4aFNUcXEvSjZicG1SeU9Kb25UM1ZvWG5EYWc9PTo=?uid=0&filename=deepens_cifars.zip&disposition=attachment&hash=/lU4A07T/opaVRrq6xiwgEqhajDE3VS2ufkC0sJXcCDIUu%2BPwUGG0YBnVGP8hSTqq/J6bpmRyOJonT3VoXnDag%3D%3D%3A&limit=0&owner_uid=135223862&tknv=v2&rtoken=WzFX51zMIh93&force_default=no&ycrid=na-17c695b4ab2369b100660733047c9c97-downloader23h
might be temporarily down or it may have moved permanently to a new web address.
ERR_INVALID_RESPONSE

I do not have the same issue with the ImageNet ensembles. They work perfectly fine for me.
Thanks for your help!

Greetings,
Sebastian

Could you share the pretrained model for imagenet?

I think there are no pretrained models available to test the performance.
Could you share the pretrained models for ImageNet and CIFAR-100?
(e.g., in the source: fnames='~/megares/ImageNet-ResNet50-cn-025--1564562952-1.pth.tar')
