# Distillation-Based Training for Multi-Exit Architectures (Mary Phuong, Christoph H. Lampert, ICCV 2019)
We present a new method for training multi-exit architectures: networks with classifier blocks ("early exits") attached to intermediate convolutional blocks. Early exits are usually less accurate than the last exit, but faster to evaluate. This makes multi-exit architectures useful for trading off accuracy for speed at test time, e.g. when the inference budget varies per example.
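For concreteness, here is a minimal sketch of what such a network can look like in PyTorch. The stage sizes, exit placement, and class count are illustrative assumptions, not the architectures used in the paper:

```python
import torch.nn as nn
import torch.nn.functional as F

class MultiExitNet(nn.Module):
    """Toy multi-exit CNN: a classifier head ("exit") after every backbone stage."""

    def __init__(self, num_classes=100):
        super().__init__()
        stages, exits = [], []
        in_ch = 3
        for out_ch in (32, 64, 128):  # three backbone stages (sizes are arbitrary)
            stages.append(nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)))
            exits.append(nn.Linear(out_ch, num_classes))
            in_ch = out_ch
        self.stages = nn.ModuleList(stages)
        self.exits = nn.ModuleList(exits)

    def forward(self, x):
        logits = []
        for stage, head in zip(self.stages, self.exits):
            x = stage(x)
            # Each exit classifies from a globally pooled intermediate feature map.
            logits.append(head(F.adaptive_avg_pool2d(x, 1).view(x.size(0), -1)))
        return logits  # one logits tensor per exit, earliest first
```

At inference time, evaluation can stop at any exit, which is what enables the accuracy/speed trade-off.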
We propose to train such architectures by transferring knowledge from the late exits to the early exits via so-called distillation, and show that the early exits in particular benefit substantially.
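Concretely, distillation here means that each early exit is trained not only on the ground-truth labels but also to match the temperature-softened output distribution of a later exit. Below is a minimal sketch of such a combined loss, using the last exit as the teacher; the temperature `T` and mixing weight `alpha` are illustrative hyperparameters, not necessarily the paper's settings:

```python
import torch.nn.functional as F

def distillation_loss(logits_per_exit, labels, T=4.0, alpha=0.5):
    """Cross-entropy at every exit, plus a distillation term for the early exits."""
    teacher = logits_per_exit[-1].detach()            # last exit acts as the teacher
    soft_targets = F.softmax(teacher / T, dim=1)      # softened teacher predictions
    loss = F.cross_entropy(logits_per_exit[-1], labels)  # last exit: labels only
    for logits in logits_per_exit[:-1]:
        ce = F.cross_entropy(logits, labels)
        kd = F.kl_div(F.log_softmax(logits / T, dim=1), soft_targets,
                      reduction='batchmean') * T * T  # T^2 rescales the gradients
        loss = loss + (1 - alpha) * ce + alpha * kd
    return loss
```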
Read more in the paper; this repository provides the accompanying code.
- Install the following (though other setups may work too):
  - python 3.6.3
  - torch 0.4.0
  - torchvision 0.2.1
  - pandas
  - sacred
- Create sub-directories `data` and `snapshots` in the repo root directory.
- Download `torchvision.datasets.CIFAR100` into `data`. You can do this by running the following from the repo root directory:

  ```
  python -c 'import torchvision; torchvision.datasets.CIFAR100("./data", download=True)'
  ```
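To sanity-check the download, you can load the dataset the same way a training script would; the `ToTensor` transform here is just an illustrative default:

```python
import torchvision
import torchvision.transforms as transforms

# Load the already-downloaded CIFAR-100 training split from ./data.
dataset = torchvision.datasets.CIFAR100(
    "./data", train=True, download=False, transform=transforms.ToTensor())
print(len(dataset))        # 50000 training images
image, label = dataset[0]
print(image.shape, label)  # torch.Size([3, 32, 32]) and an integer label in [0, 100)
```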
To train a multi-exit network with distillation-based training:
- Specify hyperparameters and other options by editing the script `train_cifar.py`. (Sensible default values are provided.)
- Run `python train_cifar.py`.
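For orientation, the core of what such a script does can be sketched by wiring together the toy model and loss from above. This is a hypothetical bare-bones loop, not the repo's actual training code (which, given the dependencies above, presumably also handles logging via sacred and saving to `snapshots`):

```python
import torch
import torchvision
import torchvision.transforms as transforms

# Reuses MultiExitNet and distillation_loss from the sketches above.
loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100("./data", train=True,
                                  transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)

model = MultiExitNet(num_classes=100)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

model.train()
for images, labels in loader:  # one epoch
    optimizer.zero_grad()
    loss = distillation_loss(model(images), labels)
    loss.backward()
    optimizer.step()
```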
To evaluate a trained network on test data:
- Specify options by editing the script `eval.py`.
- Run `python eval.py`.
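Since the point of a multi-exit network is the per-exit accuracy/speed trade-off, evaluation typically reports accuracy separately for each exit. A minimal sketch of such a measurement, again reusing the toy model above (the repo's actual metrics live in `eval.py`):

```python
import torch

@torch.no_grad()
def accuracy_per_exit(model, loader):
    """Top-1 accuracy of every exit over a test loader."""
    model.eval()
    correct, total = None, 0
    for images, labels in loader:
        logits_per_exit = model(images)
        if correct is None:
            correct = [0] * len(logits_per_exit)
        for i, logits in enumerate(logits_per_exit):
            correct[i] += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return [c / total for c in correct]  # accuracies, earliest exit first
```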