
Knowledge Distillation

This experiment started with a strong intuition: when two classes share a set of common features, assigning a hard probability of 100% to one class and 0% to the others should adversely affect the learning of those shared features and reduce the overall accuracy of a deep learning classification model. To reduce this loss of accuracy, the idea was to train a simpler model on the soft class probabilities predicted by a larger, more accurate model, and to compare the results with training on the hard class labels. The results showed that the accuracy of the simpler models improved significantly when trained on the outputs of a pretrained large network instead of the class labels from the dataset. The approach itself, however, is not new; it is a well-researched area of deep learning called Knowledge Distillation [1, 2, 3], although the intuition behind it is discussed differently in the available literature.

Architecture

[Architecture diagram]

In this approach, the MSE loss is calculated on the logits, before the softmax function. Alternatively, a Kullback-Leibler divergence loss can be used on the probability distributions obtained after passing the logits through a softmax function, as sketched below.
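
As a minimal sketch, the two loss options could look like the following in PyTorch. The temperature parameter T in the KL variant is an assumption borrowed from the standard distillation formulation and is not necessarily used in this code base; both functions expect logits of shape (batch_size, num_classes).

```python
import torch.nn.functional as F

def mse_distillation_loss(student_logits, teacher_logits):
    # MSE computed directly on the raw logits, before any softmax.
    return F.mse_loss(student_logits, teacher_logits)

def kl_distillation_loss(student_logits, teacher_logits, T=1.0):
    # KL divergence computed on the probability distributions after softmax.
    # T is a hypothetical temperature that softens both distributions;
    # with T = 1 this is plain KL divergence between student and teacher.
    student_log_probs = F.log_softmax(student_logits / T, dim=1)
    teacher_probs = F.softmax(teacher_logits / T, dim=1)
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (T * T)
```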

This code base can be used for continuing experiments with knowledge distillation. It is a simple framework for experimenting with your own loss functions in a teacher-student scenario for image classification. You can train both the teacher and the student network using the framework and monitor training using TensorBoard [8]. A sketch of such a teacher-student training step is given below.
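
The repository defines its own modules and flags; the following is only an illustrative sketch, under assumed names (DistillationModule, teacher, student, alpha), of how a teacher-student training step can be wired up in PyTorch Lightning, reusing the MSE-on-logits loss described above alongside the usual cross-entropy on the hard labels.

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class DistillationModule(pl.LightningModule):
    """Hypothetical module: `teacher` and `student` are image classifiers returning raw logits."""

    def __init__(self, teacher, student, alpha=0.5, lr=1e-3):
        super().__init__()
        self.teacher = teacher.eval()
        for p in self.teacher.parameters():   # freeze the pretrained teacher
            p.requires_grad = False
        self.student = student
        self.alpha = alpha                    # weight between hard-label and distillation loss
        self.lr = lr

    def training_step(self, batch, batch_idx):
        images, labels = batch
        with torch.no_grad():
            teacher_logits = self.teacher(images)
        student_logits = self.student(images)
        hard_loss = F.cross_entropy(student_logits, labels)
        soft_loss = F.mse_loss(student_logits, teacher_logits)  # MSE on logits, as above
        loss = self.alpha * hard_loss + (1.0 - self.alpha) * soft_loss
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.student.parameters(), lr=self.lr)
```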

Code base

  • Framework - PyTorch [5] and PyTorch Lightning [6]
  • Image resolution - 224 × 224
  • Datasets - any classification dataset that supports this resolution, e.g. Imagenette, Cats vs. Dogs, etc. Adjust the number of classes in the command line arguments.
  • Most of the code should be self-explanatory. Check the command line arguments for the default options.
  • An NVIDIA Docker dockerfile with the necessary dependencies is provided for creating your own Docker image. Please map the application and dataset volumes to train using Docker.
  • Please refer to PyTorch Lightning [6] for saving and loading a checkpoint if needed in your training. This experiment was stopped prematurely because other frameworks already offer similar functionality for experimentation [4].

Usage

  • Install NVIDIA Docker
  • Build the Docker image using the given dockerfile
  • Download a dataset, e.g. the fast.ai Imagenette dataset [7]
  • Set the required command line parameters inside commandline_args.txt
  • Run experiment.py with the required command line parameters inside the Docker container, after mapping the code and dataset directories into the container
  • Set the train_teacher flag to true to train the teacher network
  • Set the distill flag to true to train the student with knowledge distillation enabled
  • Make sure the teacher network training has completed before enabling distillation and training the student
  • Monitor training progress using TensorBoard
  • A default set of parameters is given in commandline_args.txt; the flags specified there are loaded automatically at runtime (a hypothetical example is sketched below)
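
The README does not show the contents of commandline_args.txt, so the snippet below is only a hypothetical illustration of the one-flag-per-line style that argparse-based tools commonly read from a file. Only train_teacher and distill are named above; --dataset_dir and --num_classes are invented placeholders.

```
--dataset_dir=/data/imagenette
--num_classes=10
--train_teacher=False
--distill=True
```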

Results

The following graphs give the validation accuracy of the same student model on the same dataset when trained with and without knowledge distillation. The curve with higher accuracy shows the distillation result.

[Result 1: validation accuracy of the student with and without knowledge distillation]

References

1. C. Buciluă, R. Caruana, and A. Niculescu-Mizil. Model compression. In Proceedings of the 12th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, KDD ’06, pages 535–541, New York, NY, USA, 2006. ACM.

2. G. Hinton, O. Vinyals, and J. Dean. Distilling the knowledge in a neural network. arXiv:1503.02531v1, Mar. 2015.

3. S.-I. Mirzadeh, M. Farajtabar, A. Li, N. Levine, A. Matsukawa, and H. Ghasemzadeh. Improved knowledge distillation via teacher assistant. In AAAI 2020, pages 5191–5198.

4. N. Zmora, G. Jacob, L. Zlotnik, B. Elharar, and G. Novik. Neural Network Distiller: a Python package for DNN compression research. Oct. 2019.
