BSQ

CIFAR-10 experiments

This folder contains the code for inducing mixed-precision quantization schemes with BSQ on the CIFAR-10 dataset. The code for the ResNet models is configured to use the bit representation, so as to support BSQ training and to reproduce the results in the main paper.

Acknowledgement

The training and evaluation code and the model architectures are adapted from bearpaw/pytorch-classification.

Install

Clone recursively:

git clone --recursive https://github.com/yanghr/BSQ.git

Specification of dependencies

This code is tested with Python 3.6.8, PyTorch 1.2.0 and TorchVision 0.4.0. It is recommended to use the provided spec-file.txt to replicate the Anaconda environment used for testing this code, which can be done by:

conda create --name myenv --file spec-file.txt
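
Once the environment is created, activate it before running any of the commands below:

conda activate myenv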

We suggest running this code on a GPU for the best efficiency. Both running on a single GPU and running in parallel on multiple GPUs are supported.
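
Before launching a long run, it may be worth confirming that PyTorch can see the intended GPUs. This check uses only plain PyTorch, with nothing repo-specific assumed:

python -c "import torch; print(torch.cuda.is_available(), torch.cuda.device_count())"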

Usage

Pretrained models

As introduced in Appendix A.1, pretrained models are used to initialize BSQ training. The pretrained models are provided in the checkpoints/cifar10/ folder, where the checkpoint in resnet-20/ is the full-precision pretrained model and the checkpoint in resnet-20-8/ is the 8-bit quantized model in bit representation.

For more details on training the full-precision model, please see the training recipes provided by bearpaw/pytorch-classification. The quantized model is obtained with convert.py, which is introduced later.

BSQ training

Here we perform BSQ training of the ResNet-20 model on the CIFAR-10 dataset.

python cifar_prune_STE.py -a resnet --depth 20 --epochs 350 --lr 0.1 --schedule 250 --gamma 0.1 --wd 1e-4 --model checkpoints/cifar10/resnet-20-8/model_best.pth.tar --decay 0.01 --Prun_Int 100 --thre 0.0 --checkpoint checkpoints/cifar10/xxx --Nbits 8 --act 4 --bin --L1 >xxx.txt

xxx in the command should be replaced with the folder where you want the resulting model to be saved. The model will be saved in bit representation. We suggest redirecting the printed output to a text file with >xxx.txt, both to avoid interfering with the progress-bar display and to keep a record of the training process.
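
As a concrete instantiation, naming the output folder bsq-d001 after the --decay value (the folder name is only illustrative) gives:

python cifar_prune_STE.py -a resnet --depth 20 --epochs 350 --lr 0.1 --schedule 250 --gamma 0.1 --wd 1e-4 --model checkpoints/cifar10/resnet-20-8/model_best.pth.tar --decay 0.01 --Prun_Int 100 --thre 0.0 --checkpoint checkpoints/cifar10/bsq-d001 --Nbits 8 --act 4 --bin --L1 >bsq-d001.txt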

--decay sets the regularization strength $$\alpha$$ in Equation (5), so as to explore the accuracy-model size tradeoff. Results with different values of $$\alpha$$ are shown in Section 4.2.
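
As a hedged sketch of what --decay controls (the precise regularizer is defined in the paper, not reproduced here), the objective in Equation (5) takes the usual penalized form

$$\mathcal{L} = \mathcal{L}_{CE} + \alpha \mathcal{L}_{reg}$$

where $$\mathcal{L}_{CE}$$ is the cross-entropy loss and $$\mathcal{L}_{reg}$$ is the bit-sparsity-inducing regularizer (selected here by --L1); a larger $$\alpha$$ removes more bits at some cost in accuracy.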

--Prun_Int is the number of epochs between consecutive re-quantization and precision-adjustment steps; we suggest setting it to 100. The effect of using other intervals is illustrated in Appendix B.1.

--act indicates the quantization precision of the activations in the model. It should be kept the same in BSQ training and finetuning. The default value is 4.

Evaluating and finetuning the achieved model

The model obtained from BSQ training can be evaluated and finetuned with cifar_finetune.py.

For evaluation, run

python cifar_finetune.py -a resnet --depth 20 --model checkpoints/cifar10/xxx/checkpoint.pth.tar --Nbits 8 --act 4 --bin --evaluate

xxx in the command should be replaced with the folder used to save the BSQ-trained model. Note that only models in bit representation can be evaluated this way. The testing accuracy, the percentage of 1s in each bit of each layer's weights, and the precision assigned to each layer will be printed in the output.
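
A saved checkpoint can also be inspected offline. The minimal sketch below assumes the checkpoints keep the bearpaw/pytorch-classification layout of a dict holding the weights under a 'state_dict' key (an assumption, since this code is adapted from that repository):

import torch

# Load the BSQ-trained checkpoint on CPU; replace xxx as above.
ckpt = torch.load('checkpoints/cifar10/xxx/checkpoint.pth.tar', map_location='cpu')

# Assumption: bearpaw-style checkpoint dict with the weights under 'state_dict'.
state_dict = ckpt.get('state_dict', ckpt)

# Print parameter names and shapes to see the bit-representation layout.
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))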

To further finetune the obtained model, use

python cifar_finetune.py -a resnet --depth 20 --epochs 300 --lr 0.01 --schedule 150 250 --gamma 0.1 --wd 1e-4 --model checkpoints/cifar10/xxx/checkpoint.pth.tar --checkpoint checkpoints/cifar10/xxx-ft --Nbits 8 --act 4 --bin >xxx-ft.txt

The quantization scheme is fixed throughout the finetuning process. At the end of finetuning, the model with the highest testing accuracy is stored in both bit representation and floating-point weights: the bit representation is saved in checkpoints/cifar10/xxx-ft/best_bin.pth.tar and the floating-point model in checkpoints/cifar10/xxx-ft/best_float.pth.tar.
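
The floating-point result can then be loaded like any standard PyTorch model. This is a minimal sketch, assuming the repository keeps bearpaw's models.cifar package, the bearpaw checkpoint layout, and a possible 'module.' prefix from nn.DataParallel training (all three are assumptions):

import torch
import models.cifar as models  # assumption: bearpaw-style model package

# Build a full-precision ResNet-20 for CIFAR-10.
model = models.resnet(depth=20, num_classes=10)

ckpt = torch.load('checkpoints/cifar10/xxx-ft/best_float.pth.tar', map_location='cpu')

# Strip the 'module.' prefix that nn.DataParallel training may have left.
state = {k.replace('module.', '', 1): v for k, v in ckpt['state_dict'].items()}
model.load_state_dict(state)
model.eval()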

Converting full-precision models to bit representation with achieved quantization schemes

To convert a full-precision model, use

python convert.py -a resnet --depth 20 --model checkpoints/cifar10/resnet-20/model_best.pth.tar --dict checkpoints/cifar10/xxx/checkpoint.pth.tar --checkpoint checkpoints/cifar10/xxx-mp --Nbits 8 --act 4 >xxx-mp.txt

If a path is provided with --dict, the model will be converted to the same quantization scheme as the model specified in --dict. Otherwise, the whole model will be quantized to the precision specified in --Nbits. The converted model will be in bit representation and will be saved in the folder specified by --checkpoint. We use this code to obtain the 8-bit quantized model before BSQ training, and to obtain the "train from scratch" models that are further finetuned for the comparison in Table 1.
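
For instance, omitting --dict converts the full-precision model to a uniform 8-bit scheme, which is presumably how the provided resnet-20-8 checkpoint was produced (the output folder and log names here are only illustrative):

python convert.py -a resnet --depth 20 --model checkpoints/cifar10/resnet-20/model_best.pth.tar --checkpoint checkpoints/cifar10/resnet-20-8 --Nbits 8 --act 4 >resnet-20-8.txt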
