
Hierarchical Self-supervised Augmented Knowledge Distillation

This project provides source code for our Hierarchical Self-supervised Augmented Knowledge Distillation (HSAKD).

This paper is publicly available at the IJCAI official proceedings: https://www.ijcai.org/proceedings/2021/0168.pdf

Our poster presentation is publicly available at 765_IJCAI_poster.pdf

The slides of our oral presentation are publicly available at 765_IJCAI_slides.pdf

Installation

Requirements

Ubuntu 18.04 LTS

Python 3.8 (Anaconda is recommended)

CUDA 11.1

PyTorch 1.6.0

NCCL for CUDA 11.1

Perform experiments on the CIFAR-100 dataset

Dataset

CIFAR-100: download

Unzip it to the ./data folder
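
To sanity-check the dataset location before training, here is a minimal torchvision sketch (assuming the standard cifar-100-python layout under ./data; the repo's training scripts apply their own augmentations):

    from torchvision import datasets, transforms

    # Verify that CIFAR-100 is readable from ./data.
    transform = transforms.ToTensor()
    trainset = datasets.CIFAR100(root='./data', train=True, download=False, transform=transform)
    testset = datasets.CIFAR100(root='./data', train=False, download=False, transform=transform)
    print(len(trainset), len(testset))  # expected: 50000 10000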

Training baselines

python train_baseline_cifar.py --arch wrn_16_2 --data ./data/  --gpu 0

More commands for training various architectures can be found in train_baseline_cifar.sh

Training teacher networks

(1) Use a pre-trained backbone and train all auxiliary classifiers.

The pre-trained backbone weights are the .pth files released by the repositories of CRD and SSKD.

You should download them from Google Drive before training a teacher network that needs a pre-trained backbone.

python train_teacher_cifar.py \
    --arch wrn_40_2_aux \
    --milestones 30 60 90 --epochs 100 \
    --checkpoint-dir ./checkpoint \
    --data ./data  \
    --gpu 2 --manual 0 \
    --pretrained-backbone ./pretrained_backbones/wrn_40_2.pth \
    --freezed

More commands for training various teacher networks with frozen backbones can be found in train_teacher_freezed.sh

The pre-trained teacher networks can be downloaded from Google Drive
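
Conceptually, --pretrained-backbone initializes the backbone from the given .pth file and --freezed keeps it fixed so that only the auxiliary classifiers are trained. A minimal PyTorch sketch of this idea, assuming (hypothetically) that auxiliary-classifier parameters contain 'aux' in their names; the repo's actual implementation may differ:

    import torch

    def load_and_freeze_backbone(model: torch.nn.Module, ckpt_path: str) -> None:
        # Load pre-trained backbone weights; auxiliary heads stay randomly initialized.
        state = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(state, strict=False)
        # Freeze everything that is not an auxiliary classifier.
        for name, param in model.named_parameters():
            if 'aux' not in name:  # assumption: aux modules are named with 'aux'
                param.requires_grad = False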

(2) Train the backbone and all auxiliary classifiers jointly from scratch. In this case, we no longer need a pre-trained teacher backbone.

According to our empirical study, this can lead to better accuracy for the teacher backbone.

python train_teacher_cifar.py \
    --arch wrn_40_2_aux \
    --checkpoint-dir ./checkpoint \
    --data ./data \
    --gpu 2 --manual 1

The pre-trained teacher networks can be downloaded from Google Drive

To distinguish (1) from (2), we use --manual 0 to indicate case (1) and --manual 1 to indicate case (2)

Training student networks

(1) Train baselines of student networks

python train_baseline_cifar.py --arch wrn_16_2 --data ./data/  --gpu 0

More commands for training various student architectures can be found in train_baseline_cifar.sh

(2) Train student networks with a pre-trained teacher network

Note that the corresponding teacher network must be pre-trained before training the student networks.

python train_student_cifar.py \
    --tarch wrn_40_2_aux \
    --arch wrn_16_2_aux \
    --tcheckpoint ./checkpoint/train_teacher_cifar_arch_wrn_40_2_aux_dataset_cifar100_seed0/wrn_40_2_aux.pth.tar \
    --checkpoint-dir ./checkpoint \
    --data ./data \
    --gpu 0 --manual 0

More commands for training various teacher-student pairs can be found in train_student_cifar.sh
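
For intuition about what train_student_cifar.py optimizes: HSAKD rotates each image by {0°, 90°, 180°, 270°} and treats (class, rotation) as a joint self-supervised augmented label space (4 × 100 = 400 classes on CIFAR-100); the student's auxiliary classifiers then mimic the teacher's softened joint distributions. A simplified sketch of the two core pieces (the temperature value and tensor shapes are illustrative; see the paper and code for the exact loss):

    import torch
    import torch.nn.functional as F

    def add_rotations(x: torch.Tensor) -> torch.Tensor:
        # Stack the four rotated views along the batch dimension: (4B, C, H, W).
        return torch.cat([torch.rot90(x, k, dims=(2, 3)) for k in range(4)], dim=0)

    def kd_kl(student_logits: torch.Tensor, teacher_logits: torch.Tensor,
              T: float = 3.0) -> torch.Tensor:
        # Temperature-scaled KL divergence between softened distributions,
        # applied here over the joint (class, rotation) label space.
        log_p_s = F.log_softmax(student_logits / T, dim=1)
        p_t = F.softmax(teacher_logits / T, dim=1)
        return F.kl_div(log_p_s, p_t, reduction='batchmean') * (T * T)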

Results of the same architecture style between teacher and student networks

| Teacher  | WRN-40-2   | WRN-40-2   | ResNet-56  | ResNet32x4 |
|----------|------------|------------|------------|------------|
| Student  | WRN-16-2   | WRN-40-1   | ResNet-20  | ResNet8x4  |
| Teacher  | 76.45      | 76.45      | 73.44      | 79.63      |
| Teacher* | 80.70      | 80.70      | 77.20      | 83.73      |
| Student  | 73.57±0.23 | 71.95±0.59 | 69.62±0.26 | 72.95±0.24 |
| HSAKD    | 77.20±0.17 | 77.00±0.21 | 72.58±0.33 | 77.26±0.14 |
| HSAKD*   | 78.67±0.20 | 78.12±0.25 | 73.73±0.10 | 77.69±0.05 |

Results of different architecture styles between teacher and student networks

| Teacher  | VGG13       | ResNet50    | WRN-40-2     | ResNet32x4   |
|----------|-------------|-------------|--------------|--------------|
| Student  | MobileNetV2 | MobileNetV2 | ShuffleNetV1 | ShuffleNetV2 |
| Teacher  | 74.64       | 76.34       | 76.45        | 79.63        |
| Teacher* | 78.48       | 83.85       | 80.70        | 83.73        |
| Student  | 73.51±0.26  | 73.51±0.26  | 71.74±0.35   | 72.96±0.33   |
| HSAKD    | 77.45±0.21  | 78.79±0.11  | 78.51±0.20   | 79.93±0.11   |
| HSAKD*   | 79.27±0.12  | 79.43±0.24  | 80.11±0.32   | 80.86±0.15   |
  • Teacher: teacher networks trained by scheme (1).
  • Teacher*: teacher networks trained by scheme (2).
  • HSAKD: student networks trained by Teacher.
  • HSAKD*: student networks trained by Teacher*.

Training student networks under the few-shot scenario

python train_student_few_shot.py \
    --tarch resnet56_aux \
    --arch resnet20_aux \
    --tcheckpoint ./checkpoint/train_teacher_cifar_arch_resnet56_aux_dataset_cifar100_seed0/resnet56_aux.pth.tar \
    --checkpoint-dir ./checkpoint \
    --data ./data/ \
    --few-ratio 0.25 \
    --gpu 2 --manual 0

--few-ratio: the percentage of training samples used
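
As an illustration of how a few-shot split can be realized, here is a sketch using a reproducible random subset (the repo's actual sampling strategy may differ, e.g., it might be class-balanced):

    import torch
    from torch.utils.data import Dataset, Subset

    def few_shot_subset(dataset: Dataset, ratio: float, seed: int = 0) -> Subset:
        # Keep a reproducible random fraction `ratio` of the training samples.
        g = torch.Generator().manual_seed(seed)
        n_keep = int(len(dataset) * ratio)
        indices = torch.randperm(len(dataset), generator=g)[:n_keep]
        return Subset(dataset, indices.tolist())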

| Percentage | 25%        | 50%        | 75%        | 100%       |
|------------|------------|------------|------------|------------|
| Student    | 68.50±0.24 | 72.18±0.41 | 73.26±0.11 | 73.73±0.10 |

Perform transfer experiments on the STL-10 and TinyImageNet datasets

Dataset

STL-10: download

Unzip it to the ./data folder

TinyImageNet: download

Unzip it to the ./data folder

Prepare the TinyImageNet validation dataset as follows:

cd data
python preprocess_tinyimagenet.py
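
For reference, the usual TinyImageNet preprocessing moves each validation image from val/images/ into a per-class subfolder according to val/val_annotations.txt. A hedged sketch of that step (preprocess_tinyimagenet.py may differ in details such as the target folder layout):

    import os
    import shutil

    val_root = './tiny-imagenet-200/val'
    with open(os.path.join(val_root, 'val_annotations.txt')) as f:
        for line in f:
            fname, wnid = line.split('\t')[:2]  # image file name, WordNet class id
            class_dir = os.path.join(val_root, wnid, 'images')
            os.makedirs(class_dir, exist_ok=True)
            shutil.move(os.path.join(val_root, 'images', fname),
                        os.path.join(class_dir, fname))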

Linear classification on STL-10

python eval_rep.py \
    --arch mobilenetV2 \
    --dataset STL-10 \
    --data ./data/  \
    --s-path ./checkpoint/train_student_cifar_tarch_vgg13_bn_aux_arch_mobilenetV2_aux_dataset_cifar100_seed0/mobilenetV2_aux.pth.tar

Linear classification on TinyImageNet

python eval_rep.py \
    --arch mobilenetV2 \
    --dataset TinyImageNet \
    --data ./data/tiny-imagenet-200/  \
    --s-path ./checkpoint/train_student_cifar_tarch_vgg13_bn_aux_arch_mobilenetV2_aux_dataset_cifar100_seed0/mobilenetV2_aux.pth.tar
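
eval_rep.py follows the standard linear-evaluation protocol: the distilled student backbone is frozen as a feature extractor and only a linear classifier is trained on the target dataset. A minimal sketch of one training step (the encoder interface and feature dimensions are placeholders, not the script's actual API):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def linear_probe_step(encoder: nn.Module, head: nn.Linear,
                          x: torch.Tensor, y: torch.Tensor,
                          optimizer: torch.optim.Optimizer) -> torch.Tensor:
        # Frozen encoder, trainable linear head.
        encoder.eval()
        with torch.no_grad():
            feats = encoder(x)  # assumed: one feature vector per image
        loss = F.cross_entropy(head(feats), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.detach()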

| Transferred Dataset | CIFAR-100 → STL-10 | CIFAR-100 → TinyImageNet |
|---------------------|--------------------|--------------------------|
| Student             | 74.66              | 42.57                    |

Perform experiments on the ImageNet dataset

Dataset preparation

  • Download the ImageNet dataset to YOUR_IMAGENET_PATH and move validation images to labeled subfolders

  • Create a datasets subfolder and a symlink to the ImageNet dataset

$ ln -s PATH_TO_YOUR_IMAGENET ./data/

Folder of ImageNet Dataset:

data/ImageNet
├── train
├── val
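
With this layout, both splits can be read by torchvision's ImageFolder, which infers labels from the subfolder names; a minimal sketch (the training scripts add their own augmentations and distributed samplers):

    from torchvision import datasets, transforms

    # ImageFolder maps each class subfolder under train/ and val/ to a label index.
    tf = transforms.Compose([transforms.Resize(256),
                             transforms.CenterCrop(224),
                             transforms.ToTensor()])
    train_set = datasets.ImageFolder('data/ImageNet/train', transform=tf)
    val_set = datasets.ImageFolder('data/ImageNet/val', transform=tf)
    print(len(train_set.classes))  # expected: 1000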

Training teacher networks

(1) Use a pre-trained backbone and train all auxiliary classifiers.

The pre-trained backbone weights for ResNet-34 are the resnet34-333f7ec4.pth file from the official PyTorch model zoo: https://download.pytorch.org/models/resnet34-333f7ec4.pth

sudo python train_teacher_imagenet.py \
    --dist-url 'tcp://127.0.0.1:55515' \
    --data ./data/ImageNet/ \
    --dist-backend 'nccl' \
    --multiprocessing-distributed \
    --checkpoint-dir ./checkpoint/ \
    --pretrained-backbone ./pretrained_backbones/resnet34-333f7ec4.pth \
    --freezed \
    --gpu 0,1,2,3,4,5,6,7 \
    --world-size 1 --rank 0 --manual_seed 0

(2) Train the backbone and all auxiliary classifiers jointly from scratch. In this case, we no longer need a pre-trained teacher backbone.

According to our empirical study, this can lead to better accuracy for the teacher backbone.

sudo python train_teacher_imagenet.py \
    --dist-url 'tcp://127.0.0.1:2222' \
    --data ./data/ImageNet/ \
    --dist-backend 'nccl' \
    --multiprocessing-distributed \
    --checkpoint-dir ./checkpoint/ \
    --gpu 0,1,2,3,4,5,6,7 \
    --world-size 1 --rank 0 --manual_seed 1

Training student networks

(1) Use the teacher network trained with a frozen backbone (case (1) above)

sudo python train_student_imagenet.py \
    --data ./data/ImageNet/ \
    --arch resnet18_imagenet_aux \
    --tarch resnet34_imagenet_aux \
    --tcheckpoint ./checkpoint/train_teacher_imagenet_arch_resnet34_aux_dataset_imagenet_seed0/resnet34_imagenet_aux_best.pth.tar \
    --dist-url 'tcp://127.0.0.1:2222' \
    --dist-backend 'nccl' \
    --multiprocessing-distributed \
    --gpu-id 0,1,2,3,4,5,6,7 \
    --world-size 1 --rank 0 --manual_seed 0

(2) Use the teacher network trained jointly from scratch (case (2) above)

sudo python train_student_imagenet.py \
    --data ./data/ImageNet/ \
    --arch resnet18_imagenet_aux \
    --tarch resnet34_imagenet_aux \
    --tcheckpoint ./checkpoint/train_teacher_imagenet_arch_resnet34_aux_dataset_imagenet_seed1/resnet34_imagenet_aux_best.pth.tar \
    --dist-url 'tcp://127.0.0.1:2222' \
    --dist-backend 'nccl' \
    --multiprocessing-distributed \
    --gpu-id 0,1,2,3,4,5,6,7 \
    --world-size 1 --rank 0 --manual_seed 1

Results on the teacher-student pair of ResNet-34 and ResNet-18

| Accuracy          | Teacher    | Teacher*   | Student  | HSAKD      | HSAKD*     |
|-------------------|------------|------------|----------|------------|------------|
| Top-1             | 73.31      | 75.48      | 69.75    | 72.16      | 72.39      |
| Top-5             | 91.42      | 92.67      | 89.07    | 90.85      | 91.00      |
| Pretrained Models | resnet34_0 | resnet34_1 | resnet18 | resnet18_0 | resnet18_1 |

Citation

@inproceedings{yang2021hsakd,
  title={Hierarchical Self-supervised Augmented Knowledge Distillation},
  author={Yang, Chuanguang and An, Zhulin and Cai, Linhang and Xu, Yongjun},
  booktitle={Proceedings of the Thirtieth International Joint Conference on Artificial Intelligence (IJCAI)},
  pages={1217--1223},
  year={2021}
}
