
SemCKD

Cross-Layer Distillation with Semantic Calibration (AAAI-2021) https://arxiv.org/abs/2012.03236

Existing feature distillation methods can be divided into two categories according to the position where knowledge distillation is performed. As shown in the figure below, one is feature-map distillation and the other is feature-embedding distillation.

[Figure: the two categories of feature distillation (feature-map vs. feature-embedding)]

SemCKD belongs to feature-map distillation and is compatible with state-of-the-art feature-embedding distillation approaches (e.g., CRD), which can further boost student performance.
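As a minimal, hypothetical sketch of the distinction (not code from this repo), a feature-map loss attaches to intermediate activations, while a feature-embedding loss attaches to the penultimate representation:

import torch.nn.functional as F

def feature_map_loss(f_s, f_t):
    # Feature-map distillation: match intermediate activations of shape
    # (B, C, H, W); a projector (e.g., a 1x1 conv) is needed first when the
    # student and teacher channel counts or spatial sizes differ.
    return F.mse_loss(f_s, f_t)

def feature_embedding_loss(e_s, e_t):
    # Feature-embedding distillation: match penultimate embeddings of shape
    # (B, D); cosine similarity is one simple choice of criterion.
    return 1 - F.cosine_similarity(e_s, e_t, dim=1).mean()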

This repo contains the implementation of SemCKD together with the compared approaches: classic KD, feature-map distillation variants such as FitNet, AT, SP, VID, and HKD, and feature-embedding distillation variants such as PKT, RKD, IRG, CC, and CRD.

CIFAR-100 Results

[Figure: CIFAR-100 results]

Here ARI stands for Average Relative Improvement. This metric reflects the extent to which SemCKD further improves upon existing approaches, relative to the improvements those approaches make over the baseline student model.
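The description above suggests the following computation; this is a hedged sketch with illustrative names, not the paper's exact formula:

def average_relative_improvement(acc_combined, acc_method, acc_student):
    # acc_combined: accuracies of "method + SemCKD"; acc_method: accuracies
    # of the method alone; acc_student: accuracies of the vanilla student,
    # all over the same teacher-student pairs. For each pair, relate SemCKD's
    # extra gain to the method's own gain over the baseline, then average.
    ratios = [
        (c - m) / (m - s) * 100.0
        for c, m, s in zip(acc_combined, acc_method, acc_student)
    ]
    return sum(ratios) / len(ratios)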

To get the pretrained teacher models for CIFAR-100:

sh scripts/fetch_pretrained_teachers.sh

For ImageNet, pretrained models from torchvision are used, e.g. ResNet34. Save the model to ./save/models/$MODEL_vanilla/ and use scripts/model_transform.py to make it readable by our code.
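A minimal sketch of this workflow (assumed, not the repo's script; the exact checkpoint layout that scripts/model_transform.py expects is an assumption):

import os
import torch
from torchvision.models import resnet34

save_dir = "./save/models/ResNet34_vanilla"
os.makedirs(save_dir, exist_ok=True)

model = resnet34(pretrained=True)  # downloads the ImageNet weights
torch.save(model.state_dict(), os.path.join(save_dir, "resnet34.pth"))
# Then run scripts/model_transform.py (adjusting its paths if needed) to
# produce the resnet34_transformed.pth used in the ImageNet command below.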

Running SemCKD:

# CIFAR-100
python train_student.py --path-t ./save/models/resnet32x4_vanilla/ckpt_epoch_240.pth --distill semckd --model_s resnet8x4 -r 1 -a 1 -b 400 --trial 0
# ImageNet
python train_student.py --path-t ./save/models/ResNet34_vanilla/resnet34_transformed.pth \
--batch_size 256 --epochs 90 --dataset imagenet --gpu_id 0,1,2,3,4,5,6,7 --dist-url tcp://127.0.0.1:23333 \
--print-freq 100 --num_workers 32 --distill semckd --model_s ResNet18 -r 1 -a 1 -b 50 --trial 0 \
--multiprocessing-distributed --learning_rate 0.1 --lr_decay_epochs 30,60 --weight_decay 1e-4 --dali gpu

Notes:

  • The implementations of the compared methods are based on author-provided code and an open-source benchmark: https://github.com/HobbitLong/RepDistiller. The main difference is that we set the weights of both the classification loss and the logit-level distillation loss to 1 throughout the experiments (-r 1 -a 1), which is the more common practice in knowledge distillation.

  • Note that the wide ResNet models in RepDistiller/models/wrn.py are almost the same as those in resnet.py. For example, wrn_40_2 in wrn.py is nearly identical to resnet38x2 in resnet.py. The only difference is that resnet38x2 has three additional BN layers, which add 2*(16+32+64)*k parameters (each BN layer has a learnable weight and bias per channel, hence the factor of 2); with k=2 in this comparison, that is 448 extra parameters.

  • The three FC layers of the ImageNet VGG are replaced with a single one, so the total layer count is reduced by two on CIFAR-100. For example, VGG-8 actually has 6 layers.

  • Computing Infrastructure:

    • For CIFAR-100, we run experiments on a single machine with one NVIDIA GeForce TITAN X (Pascal) GPU (12 GB of memory at 11.4 Gbps) and 32 Intel(R) Xeon(R) E5-2620 v4 CPU cores @ 2.10 GHz. The CUDA version is 10.2 and the PyTorch version is 1.0.
    • For ImageNet, we run experiments on a single machine with eight NVIDIA GeForce RTX 2080 Ti GPUs (11 GB of memory each at 14 Gbps) and 64 Intel(R) Xeon(R) Silver 4216 CPU cores @ 2.10 GHz. The CUDA version is 10.2 and the PyTorch version is 1.6.
  • The code in this repository was merged from different sources and has not been tested thoroughly. If you have any questions, please do not hesitate to contact us.

Citation

If you find this repository useful, please consider citing the following paper:

@inproceedings{chen2021cross,
  author    = {Defang Chen and Jian{-}Ping Mei and Yuan Zhang and Can Wang and Zhe Wang and Yan Feng and Chun Chen},
  title     = {Cross-Layer Distillation with Semantic Calibration},
  booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence},
  pages     = {7028--7036},
  year      = {2021},
}
