raivnlab / str
Soft Threshold Weight Reparameterization for Learnable Sparsity
Home Page: https://homes.cs.washington.edu/~kusupati/#Kusupati20
License: Apache License 2.0
Thank you for your terrific work!
I want to train a STRConv-based ResNet-20 with 80%, 90%, and 95% sparsity, respectively. How should I set the weight-decay value or sInit so that the resulting network meets a specified sparsity target?
Looking forward to your reply.
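For context on why sInit and weight decay control sparsity only indirectly: per the STR paper, each weight tensor is passed through a soft-threshold function with a learned threshold parameter s, and sparsity emerges as weight decay acts on s during training. A minimal sketch (assuming g = sigmoid, as used in the paper's unstructured-sparsity experiments; the exact repo code may differ):

```python
import torch
import torch.nn.functional as F

def str_forward(weight: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    """STR soft-threshold reparameterization: weights whose magnitude falls
    below the learned threshold g(s) = sigmoid(s) become exactly zero."""
    threshold = torch.sigmoid(s)
    return torch.sign(weight) * F.relu(weight.abs() - threshold)

torch.manual_seed(0)
w = torch.randn(1000)
sparsity = lambda t: (t == 0).float().mean().item()

# A large-magnitude negative s (i.e. a very negative sInit) gives a near-zero
# threshold, so training starts dense; as weight decay pushes s upward, the
# threshold grows and more weights are zeroed out.
low = sparsity(str_forward(w, torch.tensor(-10.0)))   # threshold ~ 4.5e-5
high = sparsity(str_forward(w, torch.tensor(0.0)))    # threshold = 0.5
```

In practice this means a more negative sInit and/or a smaller weight decay yields lower final sparsity, and hitting an exact target such as 90% typically requires sweeping those two knobs rather than setting it directly.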
Can you share the code from Section 4.2?
Thanks for your great work and clean code.
I ran STR on ResNet-18 for CIFAR-10,
but the generated sparsity JSON shows that almost no pruning happened:
{"module.conv1": 0.0, "module.layer1.0.conv1": 0.0, "module.layer1.0.conv2": 0.0, "module.layer1.1.conv1": 0.0, "module.layer1.1.conv2": 0.0, "module.layer2.0.conv1": 0.0, "module.layer2.0.conv2": 0.0, "module.layer2.0.downsample.0": 0.0, "module.layer2.1.conv1": 0.0, "module.layer2.1.conv2": 0.0, "module.layer3.0.conv1": 0.0, "module.layer3.0.conv2": 0.0, "module.layer3.0.downsample.0": 0.0, "module.layer3.1.conv1": 0.0, "module.layer3.1.conv2": 0.0, "module.layer4.0.conv1": 0.0, "module.layer4.0.conv2": 0.0, "module.layer4.0.downsample.0": 0.0, "module.layer4.1.conv1": 0.0, "module.layer4.1.conv2": 4.172325134277344e-05, "module.fc": 0.0, "total": 8.562441436765766e-06}
Could the choice of hyperparameters cause this (no sparsity at all)? I used the same YAML config as ResNet-50 on ImageNet.
I also found that you experimented with ResNet-18 on CIFAR-10, CIFAR-100, and TinyImageNet. Could you please share the hyperparameter choices for 90% sparsity?
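One way to rule out a logging problem is to recompute per-layer sparsity directly from the model. The helper below is hypothetical (not part of the STR repo) and reports the fraction of exactly-zero entries per weight tensor, analogous to the JSON above; note that with STR the stored weights are dense, so you would run this on the thresholded (effective) weights:

```python
import torch

def layer_sparsity(model: torch.nn.Module) -> dict:
    """Fraction of exactly-zero entries per weight tensor (dim > 1),
    analogous to the per-layer sparsity JSON the repo writes out."""
    return {name: (p == 0).float().mean().item()
            for name, p in model.named_parameters() if p.dim() > 1}

# Toy check: zero out half the rows of a small layer and confirm 50% sparsity.
lin = torch.nn.Linear(4, 4, bias=False)
with torch.no_grad():
    lin.weight[:2].zero_()
```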
How do I count FLOPs?
Hi, thanks for your great work!
I just ran the evaluation code, but I ran into the following error:
...
File "/home/goqhadl9298/STR/utils/eval_utils.py", line 18, in accuracy
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
Could you please check it?
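For anyone hitting the same error: it comes from calling `.view` on a non-contiguous slice (`correct` is a transposed tensor), and the fix is exactly what the error message suggests, replacing `.view(-1)` with `.reshape(-1)`. A sketch of a corrected accuracy function (this mirrors the standard PyTorch ImageNet-example helper, not necessarily the exact repo code):

```python
import torch

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent. Uses .reshape instead of .view because
    slices of the transposed `correct` tensor are not contiguous."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # .reshape works on non-contiguous tensors; .view raises RuntimeError here.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

# Tiny demo: 2 of 3 predictions match the labels.
logits = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
labels = torch.tensor([1, 0, 0])
top1, = accuracy(logits, labels, topk=(1,))
```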
Hi,
Does STR support MobileNet v2?
Thanks
I hope you are doing well today. I have run into some issues when training models using GMP.
First, when pruning is supposed to begin at epoch args.init_prune_epoch, the following error is thrown from main.py:
Traceback (most recent call last):
File "main.py", line 494, in <module>
main()
File "main.py", line 43, in main
main_worker(args)
File "main.py", line 143, in main_worker
prune_decay = (1 - ((args.curr_prune_epoch - args.init_prune_epoch)/total_prune_epochs))**3
AttributeError: 'Namespace' object has no attribute 'curr_prune_epoch'
I performed a grep search for curr_prune_epoch, and this is the only place it appears in the entire STR repo. Looking at older versions of the code, I also did not find any references to curr_prune_epoch. I suspect that replacing curr_prune_epoch with epoch will produce the intended gradual magnitude pruning schedule. However, even if I make this substitution, there is a second issue: even before pruning begins (at epoch args.init_prune_epoch), the model does not appear to be learning. I tried training smaller models on CIFAR-10 (the Conv2/4/6 architectures from the hidden-networks repo, which are compatible with the STR code) and, after trying 30+ initializations, I have not observed any model learning. I'm not sure whether this is a phenomenon you have observed, but given the first issue it is unclear to me whether this publicly available version of GMP was successfully tested, so I wanted to ask about these issues before trying to debug any further. Below is the output of the 0th epoch for one such run (the loss at subsequent epochs stays around 2.303):
Epoch: [0][ 0/391] Time 1.647 ( 1.647) Data 0.371 ( 0.371) Loss 3.793 (3.793) Acc@1 7.81 ( 7.81) Acc@5 47.66 ( 47.66)
Epoch: [0][ 10/391] Time 0.012 ( 0.162) Data 0.000 ( 0.034) Loss 2.303 (12.470) Acc@1 7.81 ( 10.23) Acc@5 56.25 ( 48.37)
Epoch: [0][ 20/391] Time 0.012 ( 0.091) Data 0.000 ( 0.018) Loss 2.303 (7.628) Acc@1 8.59 ( 10.16) Acc@5 50.00 ( 48.07)
Epoch: [0][ 30/391] Time 0.012 ( 0.066) Data 0.000 ( 0.013) Loss 2.303 (5.910) Acc@1 10.94 ( 10.08) Acc@5 50.78 ( 48.39)
Epoch: [0][ 40/391] Time 0.012 ( 0.053) Data 0.000 ( 0.011) Loss 2.303 (5.030) Acc@1 11.72 ( 10.33) Acc@5 60.16 ( 48.91)
Epoch: [0][ 50/391] Time 0.012 ( 0.046) Data 0.000 ( 0.009) Loss 2.303 (4.496) Acc@1 7.81 ( 9.99) Acc@5 48.44 ( 49.53)
Epoch: [0][ 60/391] Time 0.012 ( 0.041) Data 0.000 ( 0.008) Loss 2.303 (4.136) Acc@1 12.50 ( 10.11) Acc@5 46.09 ( 49.49)
Epoch: [0][ 70/391] Time 0.012 ( 0.037) Data 0.000 ( 0.007) Loss 2.303 (3.878) Acc@1 5.47 ( 9.97) Acc@5 50.00 ( 49.35)
Epoch: [0][ 80/391] Time 0.014 ( 0.034) Data 0.000 ( 0.007) Loss 2.303 (3.683) Acc@1 10.94 ( 10.03) Acc@5 49.22 ( 49.15)
Epoch: [0][ 90/391] Time 0.012 ( 0.032) Data 0.000 ( 0.007) Loss 2.303 (3.532) Acc@1 7.03 ( 10.01) Acc@5 47.66 ( 49.43)
Epoch: [0][100/391] Time 0.012 ( 0.031) Data 0.000 ( 0.006) Loss 2.303 (3.410) Acc@1 5.47 ( 10.00) Acc@5 48.44 ( 49.64)
Epoch: [0][110/391] Time 0.012 ( 0.029) Data 0.000 ( 0.006) Loss 2.303 (3.310) Acc@1 8.59 ( 9.93) Acc@5 46.88 ( 49.61)
Epoch: [0][120/391] Time 0.012 ( 0.028) Data 0.000 ( 0.006) Loss 2.303 (3.227) Acc@1 17.19 ( 10.00) Acc@5 60.94 ( 49.84)
Epoch: [0][130/391] Time 0.012 ( 0.027) Data 0.000 ( 0.006) Loss 2.303 (3.156) Acc@1 7.81 ( 10.03) Acc@5 50.00 ( 49.93)
Epoch: [0][140/391] Time 0.012 ( 0.026) Data 0.000 ( 0.005) Loss 2.303 (3.096) Acc@1 10.16 ( 9.98) Acc@5 51.56 ( 49.92)
Epoch: [0][150/391] Time 0.012 ( 0.025) Data 0.000 ( 0.005) Loss 2.303 (3.043) Acc@1 8.59 ( 10.00) Acc@5 50.00 ( 50.06)
Epoch: [0][160/391] Time 0.012 ( 0.025) Data 0.000 ( 0.005) Loss 2.303 (2.997) Acc@1 12.50 ( 10.11) Acc@5 53.91 ( 50.18)
Epoch: [0][170/391] Time 0.012 ( 0.024) Data 0.000 ( 0.005) Loss 2.303 (2.957) Acc@1 10.94 ( 10.00) Acc@5 47.66 ( 50.09)
Epoch: [0][180/391] Time 0.012 ( 0.024) Data 0.000 ( 0.005) Loss 2.303 (2.920) Acc@1 7.81 ( 10.01) Acc@5 50.00 ( 50.08)
Epoch: [0][190/391] Time 0.012 ( 0.023) Data 0.000 ( 0.005) Loss 2.303 (2.888) Acc@1 9.38 ( 10.02) Acc@5 53.12 ( 50.09)
Epoch: [0][200/391] Time 0.012 ( 0.023) Data 0.000 ( 0.005) Loss 2.303 (2.859) Acc@1 9.38 ( 10.09) Acc@5 55.47 ( 50.13)
Epoch: [0][210/391] Time 0.012 ( 0.022) Data 0.000 ( 0.005) Loss 2.303 (2.833) Acc@1 5.47 ( 10.10) Acc@5 46.09 ( 50.11)
Epoch: [0][220/391] Time 0.013 ( 0.022) Data 0.000 ( 0.005) Loss 2.303 (2.809) Acc@1 7.03 ( 10.12) Acc@5 49.22 ( 50.23)
Epoch: [0][230/391] Time 0.012 ( 0.022) Data 0.000 ( 0.004) Loss 2.303 (2.787) Acc@1 13.28 ( 10.17) Acc@5 53.12 ( 50.28)
Epoch: [0][240/391] Time 0.012 ( 0.021) Data 0.000 ( 0.004) Loss 2.303 (2.767) Acc@1 9.38 ( 10.14) Acc@5 45.31 ( 50.25)
Epoch: [0][250/391] Time 0.012 ( 0.021) Data 0.000 ( 0.004) Loss 2.303 (2.748) Acc@1 12.50 ( 10.15) Acc@5 50.78 ( 50.25)
Epoch: [0][260/391] Time 0.012 ( 0.021) Data 0.000 ( 0.004) Loss 2.303 (2.731) Acc@1 7.03 ( 10.13) Acc@5 48.44 ( 50.17)
Epoch: [0][270/391] Time 0.012 ( 0.021) Data 0.000 ( 0.004) Loss 2.303 (2.715) Acc@1 11.72 ( 10.12) Acc@5 57.03 ( 50.13)
Epoch: [0][280/391] Time 0.012 ( 0.020) Data 0.000 ( 0.004) Loss 2.303 (2.701) Acc@1 12.50 ( 10.12) Acc@5 57.81 ( 50.14)
Epoch: [0][290/391] Time 0.012 ( 0.020) Data 0.000 ( 0.004) Loss 2.303 (2.687) Acc@1 6.25 ( 10.08) Acc@5 44.53 ( 50.11)
Epoch: [0][300/391] Time 0.012 ( 0.020) Data 0.000 ( 0.004) Loss 2.303 (2.674) Acc@1 9.38 ( 10.07) Acc@5 51.56 ( 50.11)
Epoch: [0][310/391] Time 0.012 ( 0.020) Data 0.000 ( 0.004) Loss 2.303 (2.662) Acc@1 14.84 ( 10.07) Acc@5 56.25 ( 50.11)
Epoch: [0][320/391] Time 0.012 ( 0.020) Data 0.000 ( 0.004) Loss 2.303 (2.651) Acc@1 11.72 ( 10.06) Acc@5 47.66 ( 50.07)
Epoch: [0][330/391] Time 0.012 ( 0.020) Data 0.000 ( 0.004) Loss 2.303 (2.640) Acc@1 8.59 ( 10.05) Acc@5 47.66 ( 50.09)
Epoch: [0][340/391] Time 0.012 ( 0.019) Data 0.000 ( 0.004) Loss 2.303 (2.631) Acc@1 7.03 ( 10.03) Acc@5 38.28 ( 50.04)
Epoch: [0][350/391] Time 0.012 ( 0.019) Data 0.000 ( 0.004) Loss 2.303 (2.621) Acc@1 10.16 ( 10.00) Acc@5 52.34 ( 50.05)
Epoch: [0][360/391] Time 0.012 ( 0.019) Data 0.000 ( 0.004) Loss 2.303 (2.612) Acc@1 9.38 ( 10.00) Acc@5 49.22 ( 50.06)
Epoch: [0][370/391] Time 0.012 ( 0.019) Data 0.000 ( 0.004) Loss 2.303 (2.604) Acc@1 11.72 ( 10.05) Acc@5 50.00 ( 50.10)
Epoch: [0][380/391] Time 0.012 ( 0.019) Data 0.000 ( 0.004) Loss 2.303 (2.596) Acc@1 6.25 ( 10.02) Acc@5 46.09 ( 50.05)
Epoch: [0][390/391] Time 0.115 ( 0.019) Data 0.000 ( 0.004) Loss 2.303 (2.589) Acc@1 6.25 ( 10.00) Acc@5 48.75 ( 50.01)
100%|#######################################################################################################################| 391/391 [00:07<00:00, 52.66it/s]
Test: [ 0/79] Time 0.145 ( 0.145) Loss 2.303 (2.303) Acc@1 7.81 ( 7.81) Acc@5 46.09 ( 46.09)
Test: [10/79] Time 0.008 ( 0.022) Loss 2.303 (2.303) Acc@1 11.72 ( 9.80) Acc@5 47.66 ( 49.01)
Test: [20/79] Time 0.017 ( 0.017) Loss 2.303 (2.303) Acc@1 7.81 ( 9.60) Acc@5 55.47 ( 49.70)
Test: [30/79] Time 0.009 ( 0.014) Loss 2.303 (2.303) Acc@1 10.94 ( 9.80) Acc@5 44.53 ( 49.97)
Test: [40/79] Time 0.010 ( 0.014) Loss 2.303 (2.303) Acc@1 11.72 ( 10.12) Acc@5 57.81 ( 50.21)
Test: [50/79] Time 0.008 ( 0.013) Loss 2.303 (2.303) Acc@1 10.94 ( 9.80) Acc@5 52.34 ( 50.15)
Test: [60/79] Time 0.008 ( 0.013) Loss 2.303 (2.303) Acc@1 5.47 ( 9.87) Acc@5 49.22 ( 50.27)
Test: [70/79] Time 0.008 ( 0.013) Loss 2.303 (2.303) Acc@1 5.47 ( 10.01) Acc@5 50.00 ( 50.22)
100%|#########################################################################################################################| 79/79 [00:00<00:00, 81.34it/s]
Test: [79/79] Time 0.049 ( 0.013) Loss 2.303 (2.303) Acc@1 6.25 ( 10.00) Acc@5 43.75 ( 50.00)
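For reference, the cubic decay computed in the traceback above is the standard gradual-magnitude-pruning schedule (Zhu & Gupta, 2017). With the suspected epoch-for-curr_prune_epoch substitution, the target sparsity at each epoch would be (a sketch with hypothetical argument names, not the repo's exact code):

```python
def gmp_target_sparsity(epoch: int, init_prune_epoch: int,
                        final_prune_epoch: int, final_sparsity: float) -> float:
    """Cubic GMP schedule: zero sparsity before init_prune_epoch, ramping up
    to final_sparsity at final_prune_epoch and held constant afterwards."""
    if epoch < init_prune_epoch:
        return 0.0
    if epoch >= final_prune_epoch:
        return final_sparsity
    total_prune_epochs = final_prune_epoch - init_prune_epoch
    # Matches the line in the traceback: decay shrinks cubically from 1 to 0.
    prune_decay = (1 - (epoch - init_prune_epoch) / total_prune_epochs) ** 3
    return final_sparsity * (1 - prune_decay)
```

Incidentally, a loss pinned at 2.303 is ln(10), i.e. uniform-random predictions over CIFAR-10's ten classes, which is consistent with the model not learning at all.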