taylor-pruning

This repository reproduces the ICLR'17 paper "Pruning Convolutional Neural Networks for Resource Efficient Inference" (link).

This project is built on PyTorch.

Install

Required packages are managed with Anaconda.

# install required packages
conda env create -f environment.yml
conda activate taylor-pruning

# install this repository
pip install -e .

Usage

The project is still in progress, but you can experiment with the following commands.

Data preprocessing

Images from the ImageNet (imagenet) and CUB-200 (cub200) datasets are supported.

The dataset directory should contain two sub-directories named train and val. In each of them, all samples should be organised into subdirectories named after their labels.

Please refer to the ImageFolder class in PyTorch for more information.
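The layout described above can be checked before training. The following is a minimal sketch using only the standard library; the helper name list_classes is hypothetical and is not part of this repository.

```python
from pathlib import Path


def list_classes(dataset_dir):
    """Return {split: sorted label names} for a train/val directory layout.

    Hypothetical helper for illustration; not part of this repository.
    """
    root = Path(dataset_dir)
    classes = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing sub-directory: {split_dir}")
        # Each sub-directory of a split is treated as one label.
        classes[split] = sorted(
            p.name for p in split_dir.iterdir() if p.is_dir())
    return classes
```

If the label lists returned for train and val differ, the two splits do not cover the same classes, which is worth fixing before running any of the commands below.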

Input images are transformed by the following code snippet. Training samples are randomly cropped and resized to 224x224, then randomly flipped horizontally. Validation samples are resized to 256 pixels on the shorter side and then center-cropped to 224x224.

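from torchvision import datasets, transforms


def load_dataset(train_dir, val_dir, is_training):
  # Illustrative wrapper: the function name and arguments are assumptions.
  img_size = 224
  # Per-channel mean and standard deviation of ImageNet.
  normalize = transforms.Normalize(
      mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

  if is_training:
    # Random resized crop and horizontal flip for augmentation.
    return datasets.ImageFolder(
        train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(img_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
  else:
    # Deterministic resize and single center crop for evaluation.
    return datasets.ImageFolder(
        val_dir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            normalize,
        ]))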

Transfer a model

Download an ImageNet pre-trained model from TorchVision and fit its weights to the target dataset.

# VGG-16 
python transfer.py \
  -a vgg16 \
  -c PATH/TO/CHECKPOINT \
  -d cub200 \
  --dataset-dir PATH/TO/CUB-200/DATASET \
  --epochs 60 \
  --lr 1e-4 \
  --train-batch 32 \
  --wd 1e-4 \
  --pretrained \
  --gpu-id 0

# ResNet-50
python transfer.py -a resnet50 -c PATH/TO/CHECKPOINT \
  -d cub200 --dataset-dir PATH/TO/DATASET \
  --epochs 90 --schedule 30 60 --lr 1e-3 \
  --train-batch 32 --wd 1e-4 --pretrained --gpu-id 0

The VGG-16 command above transfers an ImageNet pre-trained model to the CUB-200 dataset using the same hyperparameters as the paper: 60 epochs at a fixed learning rate of 1e-4. The batch size of 32 and the weight decay of 1e-4 are inferred from the settings the paper gives for fine-tuning after pruning.

The models transferred are listed as follows:

Model      Dataset  Top-1 Acc. (%)  Download
VGG-16     CUB-200  76.355          link
ResNet-50  CUB-200  81.895          link

Prune a model

Run the pruning loop proposed by the paper.

python prune.py \
  -a vgg16 \
  -d cub200 \
  --dataset-dir $DATASET \
  --resume checkpoints/vgg16 \
  -c checkpoints/pruning/vgg16 \
  --epochs 30 \
  --lr 1e-4 \
  --wd 1e-4 \
  --num-prune-iters 10 \
  --num-channels-per-prune 100

This command runs 10 pruning iterations, each of which prunes 100 activation channels.

Results will be saved under the checkpoint directory provided.
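For orientation, the ranking that drives each pruning iteration can be sketched as follows. This is a minimal illustration of the first-order Taylor criterion as described in the paper, written with plain Python lists rather than tensors. It is not taken from prune.py, and the function name taylor_criterion is hypothetical. It assumes that activations and their gradients with respect to the loss have already been collected for every sample.

```python
def taylor_criterion(activations, gradients):
    """Per-channel first-order Taylor saliency (illustrative sketch).

    Both arguments are nested lists indexed as [sample][channel][position];
    gradients holds d(loss)/d(activation) for the matching entry.
    """
    num_samples = len(activations)
    num_channels = len(activations[0])
    scores = [0.0] * num_channels
    for acts, grads in zip(activations, gradients):
        for k in range(num_channels):
            # Mean of activation * gradient over the feature map, then |.|.
            products = [a * g for a, g in zip(acts[k], grads[k])]
            scores[k] += abs(sum(products) / len(products))
    # Average the per-sample criterion over the mini-batch.
    return [s / num_samples for s in scores]


# One sample, two channels, two spatial positions per channel.
acts = [[[1.0, 2.0], [0.0, 0.0]]]
grads = [[[0.5, -0.5], [1.0, 1.0]]]
print(taylor_criterion(acts, grads))
```

Channels with the lowest scores are the candidates for removal. In this toy input, the second channel scores zero because its activations are all zero.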

Development and Troubleshooting

If you run into any problems, please open an issue.

taylor-pruning's Issues

May remove all channels from a single layer

When selecting channels for pruning, nothing prevents all channels of a single layer from being removed. A possible solution is to filter the first num_channels_to_prune candidates by some restrictive conditions, e.g., they should not contain all channels of any layer.
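One way to express such a restriction is sketched below. This is an assumption about how the fix could look, not the repository's implementation: select_channels and min_keep are hypothetical names, and scores is assumed to hold criterion values only for channels that have not been pruned yet.

```python
def select_channels(scores, num_channels_to_prune, min_keep=1):
    """Pick the lowest-scoring channels while keeping every layer non-empty.

    scores maps a layer name to a list of per-channel criterion values.
    Returns a list of (layer, channel_index) pairs. Hypothetical sketch.
    """
    # Rank all channels of all layers together, least important first.
    candidates = sorted(
        (score, layer, idx)
        for layer, layer_scores in scores.items()
        for idx, score in enumerate(layer_scores))
    remaining = {layer: len(s) for layer, s in scores.items()}
    selected = []
    for _, layer, idx in candidates:
        if len(selected) == num_channels_to_prune:
            break
        # Skip this channel if removing it would leave too few in its layer.
        if remaining[layer] <= min_keep:
            continue
        remaining[layer] -= 1
        selected.append((layer, idx))
    return selected


scores = {"conv1": [0.1, 0.2], "conv2": [0.5, 0.3, 0.9]}
print(select_channels(scores, 3))
```

In this toy input, the second channel of conv1 has the second-lowest score but is skipped, because pruning it would leave conv1 with no channels.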

Export pruned model

Since pruning is implemented with masks, the pruned model needs to be exported in order to evaluate its correctness.

L2 Normalisation

We need to L2-normalise the criterion values so that they are comparable across layers (Section 2.3 of the paper).
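A minimal sketch of layer-wise L2 normalisation, assuming the criterion values of one layer are available as a plain list. The function name l2_normalize is hypothetical, and the handling of an all-zero layer is a choice made here rather than something the paper specifies.

```python
import math


def l2_normalize(layer_scores):
    """Divide each criterion value by the L2 norm of its layer's values."""
    norm = math.sqrt(sum(s * s for s in layer_scores))
    if norm == 0.0:
        # Assumption: leave an all-zero layer unchanged to avoid dividing by 0.
        return [0.0] * len(layer_scores)
    return [s / norm for s in layer_scores]


print(l2_normalize([3.0, 4.0]))
```

Applying this to each layer separately before ranking puts layers with different raw magnitudes on a comparable scale.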
