
knowledge-distillation-pytorch

  • Exploring knowledge distillation of DNNs for efficient hardware solutions
  • Author: Haitong Li
  • Framework: PyTorch
  • Dataset: CIFAR-10

Features

  • A framework for exploring "shallow" and "deep" knowledge distillation (KD) experiments
  • Hyperparameters defined universally in "params.json" files (avoiding long argparse commands)
  • Hyperparameter search and result synthesis (as a table)
  • Progress bar, tensorboard support, and checkpoint saving/loading (utils.py)
  • Pretrained teacher models available for download

Install

  • Clone the repo

    git clone https://github.com/peterliht/knowledge-distillation-pytorch.git
    
  • Install the dependencies (including PyTorch)

    pip install -r requirements.txt
    

Organization:

  • ./train.py: main entry point for training/evaluation with or without KD on CIFAR-10
  • ./experiments/: JSON files for each experiment; directories for hyperparameter search
  • ./model/: teacher and student DNNs, knowledge distillation (KD) loss definition, dataloader

Key notes about usage for your experiments:

  • Download the zip file for pretrained teacher model checkpoints from "experiments.zip"
  • Simply move the unzipped subfolders into 'knowledge-distillation-pytorch/experiments/' (replacing the existing ones if necessary; follow the default path naming)
  • Run train.py to train the 5-layer CNN with ResNet-18's dark knowledge, or to train ResNet-18 with knowledge distilled from state-of-the-art deeper models
  • Use search_hyperparams.py for hyperparameter search
  • Hyperparameters are defined universally in params.json files. Refer to the header of search_hyperparams.py for details (a small sketch of reading a params.json follows this list)
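As a quick sketch of how such a file can be read: the key names below (alpha, temperature, learning_rate, ...) are assumptions based on typical experiments here, not a guaranteed schema.

    # Minimal sketch: load one experiment's params.json and print a few typical
    # hyperparameters. The key names are assumptions, not the exact schema.
    import json

    with open("experiments/cnn_distill/params.json") as f:
        params = json.load(f)

    for key in ("alpha", "temperature", "learning_rate", "batch_size", "num_epochs"):
        print(key, "=", params.get(key))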

Train (dataset: CIFAR-10)

Note: all the hyperparameters can be found and modified in 'params.json' under 'model_dir'

-- Train a 5-layer CNN with knowledge distilled from a pre-trained ResNet-18 model

python train.py --model_dir experiments/cnn_distill

-- Train a ResNet-18 model with knowledge distilled from a pre-trained ResNext-29 teacher

python train.py --model_dir experiments/resnet18_distill/resnext_teacher

-- Hyperparameter search for a specified experiment ('parent_dir/params.json')

python search_hyperparams.py --parent_dir experiments/cnn_distill_alpha_temp

-- Synthesize results of recent hyperparameter search experiments

python synthesize_results.py --parent_dir experiments/cnn_distill_alpha_temp

Results: "Shallow" and "Deep" Distillation

Quick takeaways (more details to be added):

  • Knowledge distillation provides regularization for both shallow DNNs and state-of-the-art DNNs
  • Training with unlabeled or partial datasets can benefit from the dark knowledge of teacher models

- Knowledge distillation from ResNet-18 to 5-layer CNN

  Model                      Dropout = 0.5   No Dropout
  5-layer CNN                83.51%          84.74%
  5-layer CNN w/ ResNet-18   84.49%          85.69%

- Knowledge distillation from deeper models to ResNet-18

  Model                   Test Accuracy
  Baseline ResNet-18      94.175%
  + KD WideResNet-28-10   94.333%
  + KD PreResNet-110      94.531%
  + KD DenseNet-100       94.729%
  + KD ResNext-29-8       94.788%

References

H. Li, "Exploring knowledge distillation of Deep neural nets for efficient hardware solutions," CS230 Report, 2018

Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. "Distilling the knowledge in a neural network." arXiv preprint arXiv:1503.02531 (2015).

Romero, A., Ballas, N., Kahou, S. E., Chassang, A., Gatta, C., & Bengio, Y. (2014). Fitnets: Hints for thin deep nets. arXiv preprint arXiv:1412.6550.

https://github.com/cs230-stanford/cs230-stanford.github.io

https://github.com/bearpaw/pytorch-classification


knowledge-distillation-pytorch's Issues

An issue on loss function

I suggest that both the training loss without KD and the one with KD should apply a softmax, because the model outputs are raw (no softmax). Like this:
https://github.com/peterliht/knowledge-distillation-pytorch/blob/e4c40132fed5a45e39a6ef7a77b15e5d389186f8/model/net.py#L100-L114
==>
    KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
                             F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
              F.cross_entropy(F.softmax(outputs, dim=1), labels) * (1. - alpha)

&

https://github.com/peterliht/knowledge-distillation-pytorch/blob/e4c40132fed5a45e39a6ef7a77b15e5d389186f8/model/net.py#L83-L97
==>
return nn.CrossEntropyLoss()(F.softmax(outputs,dim=1), labels)

Separately, why is the first part of the KD loss function in distill_mnist.py multiplied by 2?
https://github.com/peterliht/knowledge-distillation-pytorch/blob/e4c40132fed5a45e39a6ef7a77b15e5d389186f8/mnist/distill_mnist.py#L96-L97

One more thing: it is not necessary to multiply by T*T if we distill using only soft targets.
https://github.com/peterliht/knowledge-distillation-pytorch/blob/e4c40132fed5a45e39a6ef7a77b15e5d389186f8/mnist/distill_mnist_unlabeled.py#L96-L97

reference
Distilling the Knowledge in a Neural Network

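For reference, here is a minimal, self-contained sketch of the softened-softmax KD loss from Hinton et al.; variable names are illustrative and this is not claimed to be the exact code in model/net.py. Note that F.cross_entropy applies log_softmax internally, so it is conventionally fed raw logits.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def kd_loss_sketch(student_logits, teacher_logits, labels, alpha=0.9, T=4.0):
        # Softened KL term: student log-probabilities vs. teacher probabilities,
        # both at temperature T, scaled by alpha * T^2 as in Hinton et al.
        soft = nn.KLDivLoss(reduction="batchmean")(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
        ) * (alpha * T * T)
        # Hard-label term: standard cross-entropy on raw logits.
        hard = F.cross_entropy(student_logits, labels) * (1.0 - alpha)
        return soft + hard

    # illustrative usage with made-up logits and labels
    s, t = torch.randn(8, 10), torch.randn(8, 10)
    y = torch.randint(0, 10, (8,))
    print(kd_loss_sketch(s, t, y).item())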

Requirements.txt is outdated?

I'm unable to run train.py on Python 3.9. The versions stated in requirements.txt are wrong, and after installing the newest libraries there are a bunch of syntax errors in the program. Is there an updated version available?

Are the distilled student models available for download?

Hello @peterliht, the pre-trained teacher models are available, but do you have the corresponding student models (the 5-layer CNN with a ResNet-18 teacher on CIFAR-10) uploaded somewhere? If you could provide them, it would be of great help. Thanks.

Box folder

How can I download the data in the Box folder to a server using Linux commands?

missing training log for base cnn

https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/experiments/base_cnn/train.log

2018-03-09 20:46:06,587:INFO: Loading the datasets...
2018-03-09 20:46:10,074:INFO: - done.
2018-03-09 20:46:10,078:INFO: Starting training for 30 epoch(s)
2018-03-09 20:51:27,485:INFO: Loading the datasets...
2018-03-09 20:51:30,918:INFO: - done.
2018-03-09 20:51:30,922:INFO: Starting training for 30 epoch(s)
2018-03-09 20:54:20,870:INFO: Loading the datasets...
2018-03-09 20:54:24,364:INFO: - done.
2018-03-09 20:54:24,368:INFO: Starting training for 30 epoch(s)
2018-03-09 20:54:24,368:INFO: Epoch 1/30

About "reduction" built in KLDivLoss

The reason your temperature is larger than the original paper's setting (e.g., T = 2) may be caused by KLDivLoss. You could try setting reduction = "batchmean" in KLDivLoss. Just a guess; others are welcome to discuss.
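For context, a minimal illustration of the two reductions (the tensor shapes are made up): the default "mean" averages over every element, while "batchmean" divides the summed KL by the batch size only, so the KD term comes out num_classes times larger.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    T = 4.0                                   # illustrative temperature
    student_logits = torch.randn(128, 10)     # made-up batch of logits
    teacher_logits = torch.randn(128, 10)

    log_p = F.log_softmax(student_logits / T, dim=1)
    q = F.softmax(teacher_logits / T, dim=1)

    loss_mean = nn.KLDivLoss(reduction="mean")(log_p, q)            # sum / (batch * num_classes)
    loss_batchmean = nn.KLDivLoss(reduction="batchmean")(log_p, q)  # sum / batch
    print(loss_mean.item(), loss_batchmean.item())                  # batchmean is 10x mean here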

'Tensor' object is not callable

I modified the code and I get an error; does anybody have any idea why?
I am using CPU.

The error occurs on this line:

—> 10 output_teacher_batch = teacher_model(data_batch).data().numpy()
TypeError: ‘Tensor’ object is not callable

Does anybody have an idea how to solve this?

    def fetch_teacher_outputs(teacher_model, dataloader):
        # set teacher_model to evaluation mode
        teacher_model.eval()
        teacher_outputs = []
        for i, (data_batch, labels_batch) in enumerate(dataloader):
            if torch.cuda.is_available():
                data_batch, labels_batch = data_batch.cuda(async=True), labels_batch.cuda(async=True)
            data_batch, labels_batch = Variable(data_batch), Variable(labels_batch)

            output_teacher_batch = teacher_model(data_batch).data().numpy()  # <-- the line raising the TypeError
            teacher_outputs.append(output_teacher_batch)

        return teacher_outputs
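For what it's worth, .data is an attribute rather than a method, so writing .data() tries to call the tensor itself, which raises exactly this TypeError. A possible replacement for the highlighted line (a hedged sketch; it uses detach()/cpu() rather than the older Variable-style access):

    # .data is an attribute, not a method; detach the teacher output and move it
    # to the CPU before converting to NumPy.
    output_teacher_batch = teacher_model(data_batch).detach().cpu().numpy()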

Box Folder

I can't download the Box folder. Could someone send these files to my email? Thank you so much!

I couldn't show that cnn_distill has higher performance than base_cnn.

This is my situation.
I trained base_cnn in advance on the CIFAR-10 dataset to compare the performance of base_cnn and cnn_distill.

Also, I trained base_resnet18 as a teacher on the same dataset.
Lastly, I trained cnn_distill using resnet18 as the teacher.

I got accuracies of 0.875 for base_cnn and 0.858 for cnn_distill in their respective metrics_val_best_weights.json files.
It looks like base_cnn is better than cnn_distill.

I didn't change any parameters in base_cnn or cnn_distill, except for switching the augmentation value from 'no' to 'yes' in base_cnn's params.json.

I think there would be no reason to use knowledge distillation if base_cnn had higher accuracy.
Please let me know where I was wrong.
Thanks for your time.

Error Cuda

Hi, this is the error I got while executing this command; could you please check it?

    python3 train.py --model_dir experiments/resnet18_distill/resnext_teacher
    Loading the datasets...
    Files already downloaded and verified
    Files already downloaded and verified
    Files already downloaded and verified
    Files already downloaded and verified
    - done.
    /u/halle/yeganeh/home_at/Desktop/git/knowledge-distillation-pytorch/model/resnext.py:82: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
      init.kaiming_normal(self.classifier.weight)
    /u/halle/yeganeh/home_at/Desktop/git/knowledge-distillation-pytorch/model/resnext.py:87: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.
      init.kaiming_normal(self.state_dict()[key], mode='fan_out')
    THCudaCheck FAIL file=/pytorch/aten/src/THC/THCGeneral.cpp line=51 error=30 : unknown error
    Traceback (most recent call last):
      File "train.py", line 421, in <module>
        teacher_model = nn.DataParallel(teacher_model).cuda()
      File "/u/halle/yeganeh/home_at/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 260, in cuda
        return self._apply(lambda t: t.cuda(device))
      File "/u/halle/yeganeh/home_at/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 187, in _apply
        module._apply(fn)
      File "/u/halle/yeganeh/home_at/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 187, in _apply
        module._apply(fn)
      File "/u/halle/yeganeh/home_at/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 193, in _apply
        param.data = fn(param.data)
      File "/u/halle/yeganeh/home_at/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 260, in <lambda>
        return self._apply(lambda t: t.cuda(device))
      File "/u/halle/yeganeh/home_at/.local/lib/python3.6/site-packages/torch/cuda/__init__.py", line 162, in _lazy_init
        torch._C._cuda_init()
    RuntimeError: cuda runtime error (30) : unknown error at /pytorch/aten/src/THC/THCGeneral.cpp:51

How to improve the student model's accuracy?

My teacher model's accuracy is 99%, but when I try to distill its knowledge, my student model's accuracy stays under 10%. It seems that the student model didn't learn anything from the teacher.
I use LeNet-5 as my student model, with alpha = 0.9 and temperature = 1.
Thanks for your help.

boxed folder

I have downloaded the .zip file from the Box folder, but it can't be unzipped successfully: after unzipping, the .tar file becomes a .tar.cpgz file. I have also tried to unzip it with 'unzip' and 'tar xvf' from the terminal on macOS, but failed.
[screenshots from 2020-04-16]
Could you please send the Box folder file to my email? [email protected] Thank you!

Does the student net really learn from the teacher's output?

I printed the first 32 labels of the teacher net's train dataloader and got:
14, 8, 29, 67, 59, 49, 73, 25, 4, 76, 11, 25, 82, 6, 11, 47, 28, 43, 40, 49, 27, 92, 62, 37, 64, 22, 38, 90, 14, 16, 27, 92
while the first 32 labels of the student net's train dataloader are:
86, 40, 14, 73, 50, 43, 40, 27, 1, 51, 11, 47, 32, 76, 28, 83, 32, 4, 52, 77, 3, 64, 24, 36, 80, 93, 96, 72, 26, 75, 47, 79

So it seems that the batches seen by the teacher net and the student net are not in the same order.

kd loss

Why is softmax applied to the teacher output, but log_softmax to the student output?
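For context, this matches the convention of PyTorch's nn.KLDivLoss, which expects the input (student side) as log-probabilities and the target (teacher side) as probabilities. A tiny illustration with made-up logits:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    T = 4.0                                 # illustrative temperature
    student_logits = torch.randn(8, 10)     # made-up logits
    teacher_logits = torch.randn(8, 10)

    # KLDivLoss wants log-probabilities for the input and probabilities for the
    # target, hence log_softmax on the student and softmax on the teacher.
    loss = nn.KLDivLoss(reduction="batchmean")(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
    )
    print(loss.item())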

teacher model in eval() mode but still update gradients?

Hi,

Very useful code and instructions! If I understand it correctly, the teacher model shouldn't be updated with gradients and only the student model will compute gradients during the distillation process. I noticed in the train_and_evaluate_kd() function, the teacher model is set to eval() mode. But I think eval() only alters the behavior of dropout or BatchNorm, it doesn't stop gradient update when loss.backward() is called. I think teacher model's parameters should set require_grad to False.
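For reference, a minimal sketch (with a stand-in teacher model, since the real one isn't defined here) of two common ways to ensure the teacher contributes no gradients:

    import torch
    import torch.nn as nn

    teacher_model = nn.Linear(10, 10)   # stand-in for the real teacher network
    data_batch = torch.randn(4, 10)     # made-up input batch

    # Option 1: freeze the teacher's parameters so loss.backward() can never
    # update them (they are typically not passed to the optimizer anyway).
    for p in teacher_model.parameters():
        p.requires_grad = False

    # Option 2: run the teacher under no_grad so no autograd graph is built
    # for its forward pass at all.
    teacher_model.eval()
    with torch.no_grad():
        teacher_outputs = teacher_model(data_batch)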

experiment result

Hello peterliht,
I ran your code according to the instructions without modifying any parameters, but found that the results vary greatly.
What parameters did you modify before releasing the code?
Below are my experimental results for ResNet-18:
python train.py --model_dir experiments/resnet18_distill/resnext_teacher

My experimental environment is:

python 3.5.2
pytorch 0.4.0
GPU  TITAN Xp

[result screenshots]
