
PyTorch Deep Compression

A PyTorch implementation of the iterative pruning method described in Han et al. (2015), Learning both Weights and Connections for Efficient Neural Networks.

Usage

The libs package contains the required utilities, and compressor.py defines a Compressor class that supports pruning a network layer by layer.

The file iterative_pruning.py contains the function iter_prune, which performs the iterative pruning.

An example use of the function is given in the main function of the same file; a minimal script sketch also follows the option list below. Write your own script and add

from iterative_pruning import *

to import all necessary modules, then run your script as follows:

python your_script.py [-h] [--data DIR] [--arch ARCH] [-j N] [-b N]
                            [-o O] [-m E] [-c I] [--lr LR] [--momentum M]
                            [--weight_decay W] [--resume PATH] [--pretrained]
                            [-t T [T ...]] [--cuda]

optional arguments:

  -h, --help            show this help message and exit
  --data DIR, -d DIR    path to dataset
  --arch ARCH, -a ARCH  model architecture: alexnet | densenet121 |
                        densenet161 | densenet169 | densenet201 | inception_v3
                        | resnet101 | resnet152 | resnet18 | resnet34 |
                        resnet50 | squeezenet1_0 | squeezenet1_1 | vgg11 |
                        vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19
                        | vgg19_bn
  -j N, --workers N     number of data loading workers (default: 4)
  -b N, --batch-size N  mini-batch size (default: 256)
  -o O, --optimizer O   optimizers: ASGD | Adadelta | Adagrad | Adam | Adamax
                        | LBFGS | Optimizer | RMSprop | Rprop | SGD |
                        SparseAdam (default: SGD)
  -m E, --max_epochs E  max number of epochs while training
  -c I, --interval I    checkpointing interval
  --lr LR, --learning-rate LR
                        initial learning rate
  --momentum M          momentum
  --weight_decay W, --wd W
                        weight decay
  --resume PATH         path to latest checkpoint (default: none)
  --pretrained          use pre-trained model
  -t T [T ...], --topk T [T ...]
                        Top k precision metrics
  --cuda

Other architectures from the torchvision package can also be chosen, but have not been experimented with. The location of the ImageNet dataset on your machine is passed via the --data flag.
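For concreteness, here is a minimal sketch of such a script. It assumes iterative_pruning.py exposes the main() function mentioned above as its example entry point; treat the exact API as an assumption rather than a confirmed interface.

# your_script.py -- a minimal sketch, not the repo's confirmed entry point.
# Assumes iterative_pruning.py exposes the main() that the README says
# demonstrates iter_prune; adapt it to your own training setup.
from iterative_pruning import *

if __name__ == "__main__":
    main()  # parses the command-line flags documented above and runs pruning

You could then invoke it with the documented flags, for example:

python your_script.py --data /path/to/imagenet --arch alexnet --pretrained --cuda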

Results

Model           Top-1   Top-5   Compression Rate
LeNet-300-100   92%     N/A     92%
LeNet-5         98.8%   N/A     92%
AlexNet         39%     63%     85.99%

Note: To achieve better results, try tweaking the alpha hyper-parameter in the prune() function, which changes the pruning rate of each layer.
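For reference, a common form of this rule (following Han et al. 2015) prunes each layer at a threshold proportional to the standard deviation of its weights. The sketch below illustrates how such an alpha knob typically works; the exact rule inside this repo's prune() may differ.

import torch

def prune_layer(weight, alpha=1.0):
    # Magnitude pruning in the spirit of Han et al. (2015): zero every weight
    # whose absolute value falls below alpha * std(weight). A larger alpha
    # prunes more of the layer; a smaller alpha prunes less.
    threshold = alpha * weight.std()
    mask = (weight.abs() >= threshold).float()
    return weight * mask, mask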

Any comments, thoughts, and improvements are appreciated.


Issues

Question about computation reduction

Very nice implementation! I want to ask about the weights that are masked to zero.

Does the forward pass still perform the full computation with those weight values simply set to zero, or does the computation speed actually improve?
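The behavior asked about can be illustrated concretely: with mask-based pruning the weight tensor stays dense, so the forward pass still performs the full multiply, and the zeros save no time unless sparse storage and kernels are used. A small sketch with hypothetical layer sizes:

import torch
import torch.nn as nn

layer = nn.Linear(1024, 1024)
with torch.no_grad():
    # Mask roughly half of the weights to zero by magnitude.
    mask = (layer.weight.abs() > layer.weight.abs().median()).float()
    layer.weight.mul_(mask)

x = torch.randn(1, 1024)
y = layer(x)  # still a full dense 1024x1024 matmul; the zeros are multiplied too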

Momentum will change the pruned weights

Hi, thanks for the code!

According to Deep Compression, pruned weights should stay at zero in later runs. However, because of momentum, the optimizer will still change the values of newly pruned weights, since their gradients were not zero in the previous epoch.

Do you have any idea how to solve this problem without setting the momentum to zero?
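A commonly used workaround (a sketch, not this repo's code) is to re-apply the pruning mask after every optimizer step and also zero the corresponding entries of SGD's momentum buffer, so stale velocity cannot push pruned weights away from zero:

import torch

def masked_step(optimizer, masks):
    # masks maps each pruned parameter tensor to its binary (0/1) mask.
    optimizer.step()
    with torch.no_grad():
        for group in optimizer.param_groups:
            for p in group["params"]:
                if p in masks:
                    p.mul_(masks[p])  # force pruned weights back to zero
                    buf = optimizer.state[p].get("momentum_buffer")
                    if buf is not None:
                        buf.mul_(masks[p])  # clear their momentum as well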
