
pytorch-slimming's Introduction

pytorch-slimming

This is a PyTorch re-implementation of the algorithm presented in "Learning Efficient Convolutional Networks Through Network Slimming" (ICCV 2017). The official source code is based on Torch. For more information, visit the author's webpage.

CIFAR10-VGG16BN      Baseline   Trained with Sparsity (1e-4)   Pruned (ratio 0.7)   Fine-tuned (40 epochs)
Top-1 Accuracy (%)   93.62      93.77                          10.00                93.56
Parameters           20.04M     20.04M                         2.42M                2.42M
Pruned Ratio   Top-1 Accuracy (%) without Fine-tuning   Parameters (M) / MACC (M)
0              93.77                                    20.04 / 398.44
0.1            93.72                                    15.9 / 349.22
0.2            93.76                                    12.28 / 307.78
0.3            93.75                                    9.12 / 272.94
0.4            93.75                                    6.74 / 247.86
0.5            93.40                                    4.62 / 231.86
0.6            37.83                                    3.14 / 222.17
0.7            10.00                                    2.42 / 210.84
Pruned Ratio   Architecture (output channels per convolutional layer; 'M' = max pooling)
0 [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512]
0.1 [60, 64, 'M', 128, 128, 'M', 256, 255, 253, 245, 'M', 436, 417, 425, 462, 'M', 463, 465, 472, 424]
0.2 [58, 64, 'M', 128, 128, 'M', 256, 255, 250, 233, 'M', 360, 336, 329, 398, 'M', 420, 412, 435, 341]
0.3 [56, 64, 'M', 128, 128, 'M', 256, 254, 249, 227, 'M', 284, 239, 244, 351, 'M', 369, 364, 384, 255]
0.4 [52, 64, 'M', 128, 128, 'M', 256, 254, 247, 218, 'M', 218, 162, 166, 294, 'M', 317, 315, 318, 165]
0.5 [52, 64, 'M', 128, 128, 'M', 256, 254, 245, 214, 'M', 179, 117, 116, 229, 'M', 228, 220, 210, 111]
0.6 [51, 64, 'M', 128, 128, 'M', 256, 254, 245, 213, 'M', 165, 85, 92, 153, 'M', 83, 86, 87, 111]
0.7 [49, 64, 'M', 128, 128, 'M', 256, 254, 234, 198, 'M', 114, 41, 24, 11, 'M', 14, 13, 19, 104]
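The parameter counts in the tables can be approximated directly from an architecture list like the ones above. The sketch below is a minimal illustration under our own assumptions (not taken from the repository's vgg.py): each integer is a 3x3 convolution with bias=False followed by BatchNorm2d (2 learnable parameters per channel), 'M' is a parameter-free max-pooling layer, and the classifier is a single hypothetical Linear(last_channels, 10) layer.

```python
from typing import List, Union

CfgEntry = Union[int, str]


def count_params(cfg: List[CfgEntry], in_channels: int = 3, num_classes: int = 10) -> int:
    """Count learnable parameters of a VGG-style network described by cfg.

    Assumed layout (hypothetical): int -> 3x3 Conv2d(bias=False) + BatchNorm2d,
    'M' -> max pooling (no parameters), then one Linear(last, num_classes).
    """
    total = 0
    channels = in_channels
    for v in cfg:
        if v == 'M':
            continue  # pooling layers have no learnable parameters
        total += channels * v * 3 * 3  # convolution weights, no bias
        total += 2 * v                 # BatchNorm gamma and beta
        channels = v
    total += channels * num_classes + num_classes  # classifier weight + bias
    return total


baseline = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
            512, 512, 512, 512, 'M', 512, 512, 512, 512]
print(count_params(baseline))  # 20035018, i.e. about 20.04M
```

For the unpruned configuration this gives 20,035,018 parameters, which rounds to the 20.04M in the table; the pruned rows may not match exactly, since the counting convention behind the table is not documented.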

Baseline

python main.py

Trained with Sparsity

python main.py -sr --s 0.0001
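Sparsity training places an L1 penalty on the BatchNorm scaling factors (gamma). A minimal sketch of one common way to apply it, assuming the penalty is added as a subgradient s * sign(gamma) to each BN weight gradient after loss.backward() and before optimizer.step(); the function name update_bn_grad and the toy model are hypothetical, and s plays the role of the --s value above.

```python
import torch
import torch.nn as nn


def update_bn_grad(model: nn.Module, s: float) -> None:
    """Add the L1 subgradient s * sign(gamma) to every BatchNorm2d weight gradient."""
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d) and m.weight.grad is not None:
            m.weight.grad.add_(s * torch.sign(m.weight.detach()))


# Toy training step on a hypothetical conv-BN block (not the repository's VGG).
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1, bias=False), nn.BatchNorm2d(8), nn.ReLU())
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 3, 16, 16)
loss = model(x).mean()  # stand-in for the cross-entropy loss
optimizer.zero_grad()
loss.backward()
update_bn_grad(model, s=1e-4)  # must run between backward() and step()
optimizer.step()
```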

Pruned

python prune.py --model model_best.pth.tar --save pruned.pth.tar --percent 0.7
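The paper prunes with a single global threshold, defined as a percentile of all scaling-factor magnitudes across layers. A minimal sketch of turning a ratio such as --percent into a channel configuration and per-layer masks, assuming the k = int(N * percent) smallest |gamma| values are removed; the helper compute_cfg_and_masks and the toy model are hypothetical and not the code in prune.py.

```python
import torch
import torch.nn as nn


def compute_cfg_and_masks(model: nn.Module, percent: float):
    """Globally prune the `percent` fraction of channels with the smallest |gamma|.

    Returns a VGG-style cfg (channel counts, 'M' for max pooling) and one
    0/1 mask per BatchNorm2d layer.
    """
    bns = [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
    gammas = torch.cat([bn.weight.detach().abs().cpu() for bn in bns])
    k = int(gammas.numel() * percent)  # number of channels to remove
    # A Python float threshold avoids CPU/GPU tensor mismatches when comparing.
    threshold = torch.sort(gammas)[0][k - 1].item() if k > 0 else -1.0
    cfg, masks = [], []
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            mask = m.weight.detach().abs().gt(threshold).float()
            masks.append(mask)
            cfg.append(int(mask.sum().item()))
        elif isinstance(m, nn.MaxPool2d):
            cfg.append('M')
    return cfg, masks


# Toy usage with a hypothetical two-block model.
model = nn.Sequential(
    nn.Conv2d(3, 4, 3, padding=1, bias=False), nn.BatchNorm2d(4), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(4, 4, 3, padding=1, bias=False), nn.BatchNorm2d(4), nn.ReLU(),
)
with torch.no_grad():
    model[1].weight.copy_(torch.tensor([0.1, 0.9, 0.8, 0.05]))
    model[5].weight.copy_(torch.tensor([0.7, 0.02, 0.6, 0.5]))
cfg, masks = compute_cfg_and_masks(model, percent=0.25)
print(cfg)  # [3, 'M', 3]
```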

Fine-tuned

python main.py --refine pruned.pth.tar --epochs 40

Reference

@InProceedings{Liu_2017_ICCV,
    author = {Liu, Zhuang and Li, Jianguo and Shen, Zhiqiang and Huang, Gao and Yan, Shoumeng and Zhang, Changshui},
    title = {Learning Efficient Convolutional Networks Through Network Slimming},
    booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
    month = {Oct},
    year = {2017}
}

pytorch-slimming's People

Contributors

foolwood


pytorch-slimming's Issues

Problem encountered with the pruning script vggprune.py

Traceback (most recent call last):
File "vggprune.py", line 73, in <module>
mask = weight_copy.gt(thre).float().cuda()
RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.FloatTensor for argument #2 'other'
My environment is Python 3.6 with torch 0.4.1.
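The message indicates that a CUDA tensor was compared with a CPU tensor (the threshold). A minimal sketch of a device-safe comparison, assuming `weight` holds the raw gamma values and `thre` is a 0-dim threshold tensor that may live on the CPU; the function name channel_mask is hypothetical. Calling thre.item() and comparing against the resulting Python float would work as well.

```python
import torch


def channel_mask(weight: torch.Tensor, thre: torch.Tensor) -> torch.Tensor:
    """0/1 mask of channels whose |gamma| exceeds `thre`, computed on weight's device."""
    # Older PyTorch refuses to compare a CUDA tensor with a CPU tensor,
    # so move the threshold to the weight's device before calling gt().
    return weight.abs().gt(thre.to(weight.device)).float()


device = 'cuda' if torch.cuda.is_available() else 'cpu'
gamma = torch.tensor([0.3, -0.01, 0.2, 0.005], device=device)
thre = torch.tensor(0.1)  # e.g. a threshold computed on the CPU
print(channel_mask(gamma, thre))  # keeps channels 0 and 2
```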

[59, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 470, 239, 59, 40, 'M', 0, 3, 54, 483]: one convolutional layer has 0 channels. How can this be fixed?

Hello, when pruning the model I got [59, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 470, 239, 59, 40, 'M', 0, 3, 54, 483]: one convolutional layer has 0 channels, and then an error was raised. How can this be fixed?

[59, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 470, 239, 59, 40, 'M', 0, 3, 54, 483]
Traceback (most recent call last):
File "prune.py", line 111, in <module>
newmodel = vgg(cfg=cfg)
File "/home/wxl/temp/pytorch-slimming/vgg.py", line 13, in __init__
self.feature = self.make_layers(cfg, True)
File "/home/wxl/temp/pytorch-slimming/vgg.py", line 30, in make_layers
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
File "/home/wxl/project/py36/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 330, in __init__
False, _pair(0), groups, bias, padding_mode)
File "/home/wxl/project/py36/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 46, in __init__
self.reset_parameters()
File "/home/wxl/project/py36/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 49, in reset_parameters
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
File "/home/wxl/project/py36/lib/python3.6/site-packages/torch/nn/init.py", line 314, in kaiming_uniform_
fan = _calculate_correct_fan(tensor, mode)
File "/home/wxl/project/py36/lib/python3.6/site-packages/torch/nn/init.py", line 283, in _calculate_correct_fan
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
File "/home/wxl/project/py36/lib/python3.6/site-packages/torch/nn/init.py", line 215, in _calculate_fan_in_and_fan_out
receptive_field_size = tensor[0][0].numel()
IndexError: index 0 is out of bounds for dimension 0 with size 0
@foolwood
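A layer ends up with 0 channels when every |gamma| in it falls below the global threshold, and nn.Conv2d cannot be built with zero output channels. One possible workaround, sketched here as our own assumption rather than the repository's fix, is to always keep at least min_keep channels per layer (those with the largest |gamma|) so the resulting cfg never contains 0. The function name mask_with_min_keep is hypothetical.

```python
import torch


def mask_with_min_keep(gamma: torch.Tensor, threshold: float, min_keep: int = 1) -> torch.Tensor:
    """Keep channels with |gamma| > threshold, but never fewer than min_keep."""
    scores = gamma.detach().abs()
    mask = scores.gt(threshold)
    if int(mask.sum().item()) < min_keep:
        # Fall back to the min_keep channels with the largest |gamma|.
        _, top = torch.topk(scores, min(min_keep, scores.numel()))
        mask[top] = True
    return mask.float()


gamma = torch.tensor([0.001, 0.004, 0.002, 0.003])
print(mask_with_min_keep(gamma, threshold=0.01, min_keep=2))  # tensor([0., 1., 0., 1.])
```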

Question about loss

I have a question about the loss function. In this repo, the loss function is only the cross-entropy, but in the paper it is the cross-entropy plus $\lambda \sum_{\gamma \in \Gamma} g(\gamma)$. How should I understand the difference between the implementation and the paper?
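With $g(\gamma) = |\gamma|$, the gradient of the penalty term with respect to each scaling factor is $\lambda\,\mathrm{sign}(\gamma)$, so adding that term to the gradients after back-propagating plain cross-entropy gives the same update as writing the penalty into the loss. Assuming the -sr option works this way (a common approach for this method), the check below compares both routes on a hypothetical toy model, with lam standing in for --s.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

lam = 1e-4  # penalty weight, playing the role of --s


def make_model() -> nn.Module:
    torch.manual_seed(0)  # identical initial weights for both runs
    return nn.Sequential(nn.Conv2d(3, 4, 3, bias=False), nn.BatchNorm2d(4),
                         nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 10))


torch.manual_seed(1)
x = torch.randn(8, 3, 8, 8)
y = torch.randint(0, 10, (8,))

# (a) Penalty written into the loss, as in the paper.
model_a = make_model()
loss_a = F.cross_entropy(model_a(x), y) + lam * model_a[1].weight.abs().sum()
loss_a.backward()

# (b) Plain cross-entropy, then lam * sign(gamma) added to the BN gradients.
model_b = make_model()
F.cross_entropy(model_b(x), y).backward()
model_b[1].weight.grad.add_(lam * torch.sign(model_b[1].weight.detach()))

print(torch.allclose(model_a[1].weight.grad, model_b[1].weight.grad))  # True
```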

IndexError: 'index 0 is out of bounds for dimension 0 with size 0' in prune.py

Hey, thanks for releasing the code.

I followed the steps in the README file: first I ran python main.py, and then I ran python prune.py --model model_best.pth.tar --save pruned.pth.tar --percent 0.7, but I encountered the error below:

Traceback (most recent call last):
File "prune.py", line 110, in <module>
newmodel = vgg(cfg=cfg)
File "/home/amax/users/liuwenzhe/pruning/pytorch-slimming/vgg.py", line 13, in __init__
self.feature = self.make_layers(cfg, True)
File "/home/amax/users/liuwenzhe/pruning/pytorch-slimming/vgg.py", line 31, in make_layers
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
File "/home/amax/anaconda3/envs/lwz36/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 332, in __init__
False, _pair(0), groups, bias, padding_mode)
File "/home/amax/anaconda3/envs/lwz36/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 46, in __init__
self.reset_parameters()
File "/home/amax/anaconda3/envs/lwz36/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 49, in reset_parameters
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
File "/home/amax/anaconda3/envs/lwz36/lib/python3.6/site-packages/torch/nn/init.py", line 310, in kaiming_uniform_
fan = _calculate_correct_fan(tensor, mode)
File "/home/amax/anaconda3/envs/lwz36/lib/python3.6/site-packages/torch/nn/init.py", line 279, in _calculate_correct_fan
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
File "/home/amax/anaconda3/envs/lwz36/lib/python3.6/site-packages/torch/nn/init.py", line 211, in _calculate_fan_in_and_fan_out
receptive_field_size = tensor[0][0].numel()
IndexError: index 0 is out of bounds for dimension 0 with size 0

I printed the cfg passed to newmodel = vgg(cfg=cfg), and it was:

[53, 64, 'M', 128, 128, 'M', 256, 256, 254, 212, 'M', 34, 1, 1, 0, 'M', 0, 0, 1, 263]

I guess the error comes from the 0 entries, which nn.Conv2d() cannot handle.

What can I do to fix it? Is it necessary to train it again?

Thx in advance.
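Before choosing --percent, it can help to compute the largest global ratio that still leaves every BatchNorm layer with at least one channel. Under a global threshold that removes the smallest |gamma| values first, a layer vanishes once the threshold passes its largest |gamma|, so the limit is set by the layer whose largest |gamma| is smallest. The sketch below is an assumption of ours, not part of prune.py, ignores ties, and uses the hypothetical name max_safe_percent.

```python
from typing import List

import torch


def max_safe_percent(gammas_per_layer: List[torch.Tensor]) -> float:
    """Largest fraction of channels removable globally (smallest |gamma| first)
    while every layer keeps at least one channel, ignoring ties."""
    abs_layers = [g.detach().abs().flatten() for g in gammas_per_layer]
    all_gammas = torch.cat(abs_layers)
    # The layer whose largest |gamma| is smallest is the first to be emptied.
    weakest_max = min(float(g.max()) for g in abs_layers)
    # Every channel strictly below that value can be removed safely.
    n_safe = int((all_gammas < weakest_max).sum())
    return n_safe / all_gammas.numel()


layers = [torch.tensor([0.9, 0.5]), torch.tensor([0.05, 0.2, 0.01]), torch.tensor([0.3])]
print(max_safe_percent(layers))  # 0.3333333333333333
```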

Model save problem

There seems to be a problem with prec in the test() function, which causes prec1, and therefore is_best, to stay constant at 0.
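One possible cause of an accuracy that is always 0 (not confirmed for this repository) is integer division, or arithmetic on integer tensors, when the precision is computed. A minimal sketch of an evaluation function that accumulates plain Python ints and returns a float, so that a comparison such as prec1 > best_prec1 behaves as expected; the names evaluate and the toy model and loader are hypothetical.

```python
import torch
import torch.nn as nn


def evaluate(model: nn.Module, loader) -> float:
    """Top-1 accuracy in percent, returned as a plain Python float."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in loader:
            pred = model(data).argmax(dim=1)
            correct += int((pred == target).sum().item())  # Python int, not a tensor
            total += target.size(0)
    return 100.0 * correct / total  # float arithmetic, never integer division


# Toy check: a model that always predicts class 0 on a balanced batch scores 50%.
model = nn.Linear(5, 2)
with torch.no_grad():
    model.weight.zero_()
    model.bias.copy_(torch.tensor([1.0, 0.0]))
loader = [(torch.randn(4, 5), torch.tensor([0, 1, 0, 1]))]
best_prec1 = 0.0
prec1 = evaluate(model, loader)
is_best = prec1 > best_prec1
best_prec1 = max(prec1, best_prec1)
print(prec1, is_best)  # 50.0 True
```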

Problem with sparsity training

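Why, during my sparsity training, do the gamma coefficients of the BN layers oscillate from 1 to -1 and then from -1 back to 1 as training proceeds, with all gamma coefficients oscillating in sync, so that a sparse distribution never emerges? Has anyone else run into this?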
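When debugging sparsity training, it can help to log the distribution of all BN scaling factors once per epoch instead of inspecting individual values; that makes it easy to tell whether the factors are being pushed toward zero or swinging together between positive and negative values. A minimal sketch of such a diagnostic, with the hypothetical name gamma_summary and an arbitrary cutoff `small` for "near zero":

```python
import torch
import torch.nn as nn


def gamma_summary(model: nn.Module, small: float = 1e-2) -> dict:
    """Summarize the BatchNorm2d scaling factors (gamma) of `model`."""
    raw = torch.cat([m.weight.detach().flatten().cpu()
                     for m in model.modules() if isinstance(m, nn.BatchNorm2d)])
    gammas = raw.abs()
    return {
        'count': gammas.numel(),
        'mean_abs': gammas.mean().item(),
        'max_abs': gammas.max().item(),
        'frac_below_small': (gammas < small).float().mean().item(),
        'frac_negative': (raw < 0).float().mean().item(),
    }


bn = nn.BatchNorm2d(4)
with torch.no_grad():
    bn.weight.copy_(torch.tensor([0.001, 0.5, -0.005, 1.0]))
print(gamma_summary(bn))
```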

Pruning error

Thank you for sharing. I encountered the following error when running this command:
python prune.py --model model_best.pth.tar --save pruned.pth.tar --percent 0.7
error:
Traceback (most recent call last):
File "prune.py", line 120, in <module>
m1.weight.data = m0.weight.data[idx1].clone()
TypeError: indexing a tensor with an object of type numpy.ndarray. The only supported types are integers, slices, numpy scalars and torch.cuda.LongTensor or torch.cuda.ByteTensor as the only argument.
I am not very familiar with PyTorch and do not know how to solve this. Can you help me fix this bug? Thanks for your help. I am using Python 2.7.14 and PyTorch 0.2.0.
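The TypeError says that this PyTorch version cannot index a tensor with a NumPy array; the index must be a LongTensor on the same device as the tensor. A minimal sketch of converting the index first, with the hypothetical helper name select_channels; it also flattens the index, since np.squeeze returns a 0-d array when only one channel survives. It uses .to(), which appeared in PyTorch 0.4, so on 0.2 the equivalent would be calling .cuda() on the index when the weight is on the GPU.

```python
import numpy as np
import torch


def select_channels(weight: torch.Tensor, idx) -> torch.Tensor:
    """Index the first dimension of `weight` with a NumPy array of channel indices."""
    # Flatten to 1-D int64 and convert to a LongTensor on weight's device.
    idx = np.asarray(idx, dtype=np.int64).reshape(-1)
    index = torch.from_numpy(idx).to(weight.device)
    return weight[index].clone()


w = torch.arange(12.0).reshape(4, 3)
mask = np.array([1.0, 0.0, 1.0, 0.0])
idx1 = np.squeeze(np.argwhere(mask))  # array([0, 2])
print(select_channels(w, idx1))  # rows 0 and 2 of w
```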

Which version of torch does this code require?

Mine is '1.5.0+cu92', and it keeps raising this error:
RuntimeError:
An attempt has been made to start a new process before the
current process has finished its bootstrapping phase.

    This probably means that you are not using fork to start your
    child processes and you have forgotten to use the proper idiom
    in the main module:

        if __name__ == '__main__':
            freeze_support()
            ...

    The "freeze_support()" line can be omitted if the program
    is not going to be frozen to produce an executable.

I have never used Torch. Has anyone gotten this code to run? How were the results?
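This RuntimeError comes from Python's multiprocessing module on platforms that start child processes with spawn instead of fork (Windows, for example), typically when a DataLoader with num_workers > 0 is created at module level. As the message suggests, the usual remedy is to move the script's entry point under an if __name__ == '__main__': guard. A minimal sketch with a hypothetical main() and a random toy dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def main() -> int:
    dataset = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
    loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)
    batches = 0
    for images, labels in loader:
        batches += 1  # a training step would go here
    return batches


if __name__ == '__main__':
    # Spawned worker processes re-import this module; the guard keeps them
    # from re-running main() and starting workers of their own.
    main()
```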

error:torch.load

When I reload model_best.pth.tar, I get this error:
=> loading checkpoint './model_best.pth.tar'
Traceback (most recent call last):
File "prune.py", line 37, in <module>
model.load_state_dict(checkpoint['state_dict'])
File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 522, in load_state_dict
.format(name))
KeyError: 'unexpected key "module.feature.0.weight" in state_dict'

How can I solve this problem?
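The unexpected key "module.feature.0.weight" suggests the checkpoint was saved from a model wrapped in nn.DataParallel, which prefixes every parameter name with "module.". One way to load it into an unwrapped model, sketched below with the hypothetical helper strip_module_prefix, is to remove that prefix; wrapping the model in nn.DataParallel before calling load_state_dict is an alternative.

```python
from collections import OrderedDict

import torch.nn as nn


def strip_module_prefix(state_dict):
    """Remove a leading 'module.' (added by nn.DataParallel) from every key."""
    prefix = 'module.'
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items()
    )


# With a real checkpoint this would look like:
#   checkpoint = torch.load('model_best.pth.tar', map_location='cpu')
#   model.load_state_dict(strip_module_prefix(checkpoint['state_dict']))
model = nn.Linear(2, 2)
state = nn.DataParallel(model).state_dict()  # keys: 'module.weight', 'module.bias'
model.load_state_dict(strip_module_prefix(state))
```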

Shouldn't a model be specified for sparsity training?

python main.py -sr --s 0.0001
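Shouldn't sparsity training be extra training performed on an already trained model, rather than training from scratch? Could you explain this part? I'm a bit confused.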
