spherical-knowledge-distillation's Introduction

Spherical-Knowledge-Distillation

Code for implementing Spherical Knowledge Distillation (SKD): https://arxiv.org/abs/2010.07485

Highlight

  1. Simple to implement and fast to train. SKD adds only two lines of code on top of Hinton distillation.
  2. High accuracy. SKD can train a ResNet18 to 73% top-1 accuracy on ImageNet.
  3. Eases the capacity-gap problem. SKD can train a high-performance ResNet18 (72.7% accuracy) even with a ResNet152 teacher.
  4. Very robust to the temperature hyperparameter.

This code is implemented with Apex mixed-precision training and DALI, both of which speed up training significantly; see https://github.com/NVIDIA/apex and https://github.com/NVIDIA/DALI for details. With both Apex and DALI, ResNet18 can be trained on ImageNet in about 20 hours on four 1080 Ti GPUs.

Model Release

To download the 73.01% accuracy ResNet18:

import torch
from torchvision.models.resnet import resnet18

checkpoint = 'https://github.com/forjiuzhou/Spherical-Knowledge-Distillation/releases/download/v1/resnet18_skd.pth'
model = resnet18()
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False, map_location="cpu", check_hash=True))

Requirement

PyTorch, NVIDIA DALI, NVIDIA Apex

Minimal Codes

Configuring Apex and DALI can be messy. To run SKD, you can simply add two lines of code to a Hinton KD implementation, just after the model's forward pass. Note that the cross-entropy loss must also take the normalized logits as input.

# Normalize student and teacher logits with layer norm, then rescale
output = F.layer_norm(output, torch.Size((num_classes,)), None, None, 1e-7) * multiplier
output_t = F.layer_norm(output_t, torch.Size((num_classes,)), None, None, 1e-7) * multiplier

Layer normalization divides the logits by their standard deviation, so an appropriate multiplier can be estimated from the teacher's logits with torch.std(output_t, dim=1). In most cases, the multiplier can be set between 2 and 3. If you use F.normalize instead, the appropriate multiplier should be estimated with torch.norm(output_t, dim=1).
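For context, here is a minimal sketch of how these two lines slot into a standard Hinton KD loss. The function name and default hyperparameters are illustrative, not taken from the repository:

```python
import torch
import torch.nn.functional as F

def skd_loss(output, output_t, target, num_classes, T=4.0, multiplier=2.0, alpha=0.9):
    """Hinton KD loss computed on sphere-normalized logits (illustrative sketch)."""
    # SKD's two extra lines: normalize student and teacher logits, then rescale
    output = F.layer_norm(output, torch.Size((num_classes,)), None, None, 1e-7) * multiplier
    output_t = F.layer_norm(output_t, torch.Size((num_classes,)), None, None, 1e-7) * multiplier
    # Standard Hinton distillation term, on the normalized logits
    kd = F.kl_div(F.log_softmax(output / T, dim=1),
                  F.softmax(output_t / T, dim=1),
                  reduction='batchmean') * T * T
    # Cross entropy also uses the normalized logits, per the note above
    ce = F.cross_entropy(output, target)
    return alpha * kd + (1.0 - alpha) * ce
```

Both loss terms are non-negative, so the combined loss is as well; the alpha schedule used in training is discussed in the issues below.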

Training

python main.py -a resnet18 --lr 0.01 --distillation --T=4 --epochs 100 --multiplier 2 --fp16 [imagenet-folder with train and val folders]

spherical-knowledge-distillation's People

Contributors

forjiuzhou


spherical-knowledge-distillation's Issues

--LN argument

In Training, there is a --LN argument in the README.md for training ResNet18, but I cannot find it in the code. Besides, -a should be resnet18, not resnet50. Can you address this issue? @forjiuzhou

Setting of learning rate and KD-Loss weight

I find the learning-rate and KD-loss-weight settings in your code quite unusual. Is this the setting you actually used? If so, could you provide the exact training parameters for your ResNet18/ResNet50 runs?

    if epoch < 30:
        args.alpha = 0.9
    elif epoch < 60:
        args.alpha = 0.9
    elif epoch < 80:
        args.alpha = 0.5
    elif epoch < 100:
        args.alpha = 0.1

    factor = epoch // 30
    # factor = epoch // 100
    if epoch >= 80:
        factor = factor + 1
    # if epoch >= 90:
    #     factor = factor + 1
    lr = args.lr * (0.1 ** factor)
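Restated as standalone functions (a sketch of the schedule quoted above, with hypothetical names), the two schedules work out to:

```python
def lr_at(epoch, base_lr=0.1):
    # Step decay: divide by 10 at epochs 30 and 60, and once more from epoch 80
    factor = epoch // 30
    if epoch >= 80:
        factor += 1
    return base_lr * (0.1 ** factor)

def alpha_at(epoch):
    # KD-loss weight: 0.9 for the first 60 epochs, then 0.5, then 0.1
    if epoch < 60:
        return 0.9
    elif epoch < 80:
        return 0.5
    return 0.1
```

So with base_lr=0.1, the learning rate is 0.1 for epochs 0-29, 0.01 for 30-59, 0.001 for 60-79, and 0.0001 from epoch 80 on.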

Question about minimal code for SKD

Hi there, I've read your interesting paper, and thank you for sharing the code.
I'm rewriting the code to fit the RepDistiller format, and I have a question.

As mentioned in the paper, the teacher's average norm (l_avg) is needed to apply your idea:
we then form the new logits (f^_i(x) * l_avg, f^_j(x) * l_avg).
However, your simplified code doesn't seem to perform these operations.

Is this operation replaced by multiplying by a constant between 2 and 3?

Please understand that I may not have fully understood the paper. :)
Thanks.
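Not an authoritative answer, but the README's suggestion can be illustrated numerically: the fixed multiplier stands in for the scale that layer normalization removes, and it can be estimated from the teacher's logits. The data below is synthetic:

```python
import torch

torch.manual_seed(0)
# Hypothetical teacher logits, drawn with per-sample scale ~2.5
output_t = torch.randn(8, 1000) * 2.5

# Layer norm divides each sample by its logit std, so the mean per-sample std
# of the teacher's logits is a natural choice for a fixed multiplier.
multiplier = torch.std(output_t, dim=1).mean().item()
# For this synthetic example the estimate lands close to 2.5,
# i.e. inside the 2-to-3 range the README recommends.
```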

Replicating the Results as Reported in the Paper

Hi there @forjiuzhou,

The idea of normalising logits is an interesting one.

However, I am unable to replicate the results using the code provided in this repo. I set the teacher network to ResNet50 and the student network to ResNet18.

The other change I made was switching the DALI pipeline from reading TFRecords to raw ImageNet data, as discussed here: https://github.com/NVIDIA/DALI/blob/eb712f593d98afb87ea56700be7cfd83f512a5f8/docs/examples/use_cases/pytorch/resnet50/main.py

All other hyper-parameters are set to match those in the paper.
