
sar's Introduction

🌠 Towards Stable Test-Time Adaptation in Dynamic Wild World

This is the official project repository for Towards Stable Test-Time Adaptation in Dynamic Wild World 🔗 by Shuaicheng Niu, Jiaxiang Wu, Yifan Zhang, Zhiquan Wen, Yaofo Chen, Peilin Zhao and Mingkui Tan (ICLR 2023 Oral, Notable-Top-5%).

  • 1️⃣ SAR conducts model learning at test time to adapt a pre-trained model to test data that has distributional shifts ☀️ 🌧 ❄️, such as corruptions, simulation-to-real discrepancies, and other differences between training and testing data.
  • 2️⃣ SAR aims to adapt a model in a dynamic wild world, i.e., the test data stream may have mixed domain shifts, small batch sizes, and online imbalanced label distribution shifts (as shown in the figure below).

wild_settings

Method: Sharpness-Aware and Reliable Entropy Minimization (SAR)

  • 1️⃣ SAR conducts selective entropy minimization by excluding samples with large, noisy gradients from online adaptation.

  • 2️⃣ SAR optimizes both the entropy and the sharpness of the entropy surface simultaneously, so that the model update is robust to the remaining samples with noisy gradients. A rough sketch of one adaptation step is shown below.
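
Roughly, one adaptation step combines these two ideas as in the sketch below. This is an illustration only, not the official sar.py implementation: it assumes the SAM optimizer's first_step/second_step interface (as used in the Usage section below) and omits SAR's model recovery scheme.

```python
import math
import torch

def softmax_entropy(logits):
    """Per-sample entropy of the softmax prediction."""
    return -(logits.softmax(dim=1) * logits.log_softmax(dim=1)).sum(dim=1)

E0 = 0.4 * math.log(1000)  # entropy margin: 0.4x the maximum entropy over 1000 ImageNet classes

def sar_step(model, optimizer, x, margin=E0):
    """One adaptation step: filter unreliable samples, then a SAM-style two-step entropy update."""
    optimizer.zero_grad()
    entropy = softmax_entropy(model(x))
    reliable = entropy < margin                  # exclude high-entropy samples with noisy gradients
    if not reliable.any():
        return None                              # nothing trustworthy in this batch; skip the update
    entropy[reliable].mean().backward()
    optimizer.first_step(zero_grad=True)         # SAM: perturb weights towards the local worst case

    entropy2 = softmax_entropy(model(x))         # re-evaluate entropy at the perturbed weights
    entropy2[reliable].mean().backward()
    optimizer.second_step(zero_grad=True)        # SAM: restore weights and apply the sharpness-aware update
    return entropy2.detach()
```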

Installation:

SAR depends on Python 3, PyTorch, and timm (the pre-trained ResNet-50 (GN) and VitBase models used below are loaded from timm).

Data preparation:

This repository contains code for evaluation on ImageNet-C 🔗 with ResNet-50 and VitBase. But feel free to use your own data and models!

  • Step 1: Download ImageNet-C 🔗 dataset from here 🔗.

  • Step 2: Put ImageNet-C at the path passed via "--data_corruption" (a minimal loading sketch is shown after these steps).

  • Step 3 [optional, for EATA]: Put ImageNet test/val set at "--data".
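
ImageNet-C stores each corruption and severity level as a standard class-folder tree, so a single split can be loaded with torchvision's ImageFolder, e.g. as below. This is a sketch only; the paths and the preprocessing transform are placeholders and should match your pre-trained model.

```python
import os
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

data_corruption = "/path/to/imagenet-c"          # the directory passed via --data_corruption
corruption, severity = "gaussian_noise", 5

transform = transforms.Compose([                 # placeholder preprocessing; match your model's training transform
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

dataset = datasets.ImageFolder(os.path.join(data_corruption, corruption, str(severity)), transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)
```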

Usage:

import math
import torch
import sar
from sam import SAM

model = TODO_model()  # load your own pre-trained model here

model = sar.configure_model(model)               # configure the model for test-time adaptation
params, param_names = sar.collect_params(model)  # collect the parameters to be adapted (affine params of norm layers)
base_optimizer = torch.optim.SGD
optimizer = SAM(params, base_optimizer, lr=args.lr, momentum=0.9)
adapt_model = sar.SAR(model, optimizer, margin_e0=0.4 * math.log(1000))  # margin is 0.4x the max entropy over 1000 classes

outputs = adapt_model(inputs)  # now it infers and adapts!
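
Note that SAM wraps the base SGD optimizer, so lr and momentum are forwarded to SGD, and 0.4 * math.log(1000) is 40% of the maximum prediction entropy over the 1000 ImageNet classes. Streaming test data through adapt_model then looks roughly like this (a sketch; test_loader is a placeholder for your own DataLoader):

```python
for images, _ in test_loader:
    images = images.cuda()
    outputs = adapt_model(images)  # each call both predicts on the batch and updates the model online
```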

Example: Adapting a pre-trained model on ImageNet-C (Corruption).

Usage:

python3 main.py --data_corruption /path/to/imagenet-c --exp_type [normal/bs1/mix_shifts/label_shifts] --method [no_adapt/tent/eata/sar] --model [resnet50_gn_timm/vitbase_timm] --output /output/dir

'--exp_type' is chosen from:

  • 'normal' means the same mild test stream setting as in prior works (Tent and EATA)

  • 'bs1' means single-sample adaptation, i.e., only one sample arrives at each time step

  • 'mix_shifts' conducts experiments on a mixture of the 15 corruption types in ImageNet-C

  • 'label_shifts' means experiments under online imbalanced label distribution shifts; the imbalance_ratio argument controls the imbalance extent (a rough sketch of how such a sample order can be simulated is shown below)
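
As a rough illustration of the 'label_shifts' setting, the pre-defined sample order shipped in ./dataset essentially draws labels from a time-varying class distribution in which one class dominates at each time step. The sketch below only illustrates the idea; the repository uses its own pre-computed .npy order.

```python
import numpy as np

num_classes, imbalance_ratio = 1000, 500000
minor_p = 1.0 / (imbalance_ratio + num_classes - 1)   # probability of each minor class
major_p = minor_p * imbalance_ratio                   # probability of the dominant class

# At time step t, class t dominates and all other classes are equally rare.
q = np.full((num_classes, num_classes), minor_p)
np.fill_diagonal(q, major_p)                          # each row q[t] sums to 1

rng = np.random.default_rng(2021)
labels_per_step = [rng.choice(num_classes, size=100, p=q[t]) for t in range(num_classes)]
```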

Note: For the EATA method, you also need to set "--data /path/to/imagenet" to the clean ImageNet test/validation set, which is used to compute the parameter importance weights for regularization.

Experimental results:

The Table below shows the results under online imbalanced label distribution shifts. The reported average accuracy is averaged over 15 different corruption types in ImageNet-C (severity level 5).

| Method | ResNet-50 (BN) | ResNet-50 (GN) | VitBase (LN) |
| --- | --- | --- | --- |
| No adapt | 18.0 | 30.6 | 29.9 |
| MEMO | 24.0 | 31.3 | 39.1 |
| DDA | 27.2 | 35.1 | 36.2 |
| Tent | 2.1 | 22.0 | 47.3 |
| EATA | 0.9 | 31.6 | 49.9 |
| SAR (ours) | -- | 37.2 $\pm$ 0.6 | 58.0 $\pm$ 0.5 |

Please see our PAPER 🔗 for more detailed results.

Correspondence

Please contact Shuaicheng Niu at [niushuaicheng at gmail.com] if you have any questions. 📬

Citation

If our SAR method or wild test-time adaptation settings are helpful in your research, please consider citing our paper:

@inproceedings{niu2023towards,
  title={Towards Stable Test-Time Adaptation in Dynamic Wild World},
  author={Niu, Shuaicheng and Wu, Jiaxiang and Zhang, Yifan and Wen, Zhiquan and Chen, Yaofo and Zhao, Peilin and Tan, Mingkui},
  booktitle = {International Conference on Learning Representations},
  year = {2023}
}

Acknowledgment

The code is inspired by Tent 🔗 and EATA 🔗.


sar's Issues

Have you tested the results on other models and datasets? e.g. cifar and wideresnet GN?

I tried to pretrain a WideResNet-28x10 on CIFAR-10 with GroupNorm, and the final accuracy on the clean dataset is about 93%.
Then I employed SAR for adaptation on CIFAR-10-C at severity 5. The adaptation errors are:

[23/07/27 22:28:16] [cifar10c_group_sar.py: 109]: error % [gaussian_noise5]: 57.30%
[23/07/27 22:29:29] [cifar10c_group_sar.py: 109]: error % [shot_noise5]: 51.10%
[23/07/27 22:30:40] [cifar10c_group_sar.py: 109]: error % [impulse_noise5]: 65.13%
[23/07/27 22:31:52] [cifar10c_group_sar.py: 109]: error % [defocus_blur5]: 30.20%
[23/07/27 22:33:05] [cifar10c_group_sar.py: 109]: error % [glass_blur5]: 44.16%
[23/07/27 22:34:18] [cifar10c_group_sar.py: 109]: error % [motion_blur5]: 21.98%
[23/07/27 22:35:29] [cifar10c_group_sar.py: 109]: error % [zoom_blur5]: 28.73%
[23/07/27 22:36:42] [cifar10c_group_sar.py: 109]: error % [snow5]: 17.81%
[23/07/27 22:37:54] [cifar10c_group_sar.py: 109]: error % [frost5]: 22.26%
[23/07/27 22:39:06] [cifar10c_group_sar.py: 109]: error % [fog5]: 19.91%
[23/07/27 22:40:19] [cifar10c_group_sar.py: 109]: error % [brightness5]: 8.45%
[23/07/27 22:41:30] [cifar10c_group_sar.py: 109]: error % [contrast5]: 27.89%
[23/07/27 22:42:43] [cifar10c_group_sar.py: 109]: error % [elastic_transform5]: 27.14%
[23/07/27 22:43:55] [cifar10c_group_sar.py: 109]: error % [pixelate5]: 33.00%
[23/07/27 22:45:03] [cifar10c_group_sar.py: 109]: error % [jpeg_compression5]: 30.21%
[23/07/27 22:45:03] [cifar10c_group_sar.py: 117]: error % [mean5]: 32.35%

Is that normal? Have you tested the results on other models and datasets? Thanks so much.

Normalization used for VitBase (LN)

Hi,

thanks for sharing your work! Regarding the reported VitBase (LN) results, we noticed that there is a mismatch between the normalization you are using and the one used by timm. The timm vision transformer uses a mean of (0.5, 0.5, 0.5) and a standard deviation of (0.5, 0.5, 0.5). As a result, we obtained the following results for label_shifts, corresponding to Table 2 in your paper:

| Acc@1 | Gauss. | Shot | Impul. | Defoc. | Glass | Motion | Zoom | Snow | Frost | Fog | Bright. | Contr. | Elastic | Pixel | JPEG |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| VitBase (LN) no_adapt | 46.9 | 47.7 | 46.9 | 42.9 | 34.2 | 50.8 | 44.8 | 57.0 | 52.5 | 56.5 | 76.1 | 31.9 | 46.6 | 65.5 | 66.1 |
| TENT | 58.4 | 59.9 | 59.8 | 58.9 | 56.6 | 62.5 | 59.4 | 66.8 | 24.5 | 70.9 | 79.1 | 63.1 | 65.7 | 73.8 | 71.7 |
| SAR | 59.1 | 60.3 | 60.5 | 59.3 | 57.7 | 62.9 | 59.9 | 67.6 | 66.5 | 70.8 | 79.2 | 63.7 | 66.5 | 74.0 | 71.8 |
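
For reference, the two normalizations in question correspond to the following torchvision transforms (a sketch; in practice the transform should be resolved from the timm model's default config):

```python
import torchvision.transforms as transforms

# Normalization used by timm's ViT-Base weights, as noted above.
vit_normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

# Standard ImageNet normalization (e.g., torchvision's ResNet-50), for comparison.
imagenet_normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
```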

All the best,
George

experiment reproduction

Hi, there.

Thanks for sharing this great work.

I downloaded the code and tried to reproduce the results in Table 4. However, the results are lower than reported in the paper.

Due to GPU limitations, I only ran two corruptions.

I just run:
python3 main.py --data_corruption ./data/ImageNet-C \
--exp_type bs1 \
--method sar \
--model resnet50_gn_timm \
--output ./output/

For the ResNet50-GN results in Table 4, my result is:
2023-03-01 18:39:05,186 INFO : Result under shot_noise. The adaptation accuracy of SAR is top1: 22.83800 and top5: 40.58400
2023-03-01 18:39:05,198 INFO : acc1s are [20.255998611450195, 22.83799934387207]
2023-03-01 18:39:05,198 INFO : acc5s are [37.077999114990234, 40.58399963378906]

For the VitBase-LN results in Table 4, my result is:
2023-03-01 18:24:27,754 INFO : Result under shot_noise. The adaptation accuracy of SAR is top1: 25.73800 and top5: 44.81800
2023-03-01 18:24:27,820 INFO : acc1s are [32.34600067138672, 25.737998962402344]
2023-03-01 18:24:27,820 INFO : acc5s are [53.21399688720703, 44.817996978759766]

Did I miss something? Looking forward to your reply.

TENT performance under imbalance shift with VIT-LN model

Dear authors,
Thank you for the interesting paper,

I cloned your repo and ran it with the following command (without changing anything in your code):
python main.py --data_corruption ../imagenet --exp_type label_shifts --method tent --model vitbase_timm --output out.
The result of this command should be similar to the TENT result reported in Table 2 of your paper under ViT-Base. However, I found that the performance of TENT is not as bad as reported: my results for the first 5 corruption types (gaussian_noise, shot_noise, impulse_noise, defocus_blur, and so on) are 47.8, 46.8, 48.4, 54.5, 52.3, whereas the results in your paper for these corruption types are 32.7, 1.4, 34.6, 54.4, and 52.3 (only the last two are similar to mine).

Initialization of moving average of entropy

Hello author,
Thanks for the release of the code of your paper.

The (pseudo) code shows that the moving average of the entropy loss is not re-initialized to 0 after model recovery.

I think the value should be re-initialized to 0. Could you clarify?

Thank you in advance:-)
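
For context, the recovery scheme discussed here tracks an exponential moving average of the entropy loss roughly along the lines of the sketch below. The smoothing factor, threshold, and reset behaviour are assumptions for illustration only; the authoritative rule is in the official sar.py.

```python
def recovery_check(entropy_losses, reset_after_recovery=True, alpha=0.9, threshold=0.2):
    """Sketch of a moving-average-based model recovery rule (alpha and threshold are assumed values)."""
    ema = None
    for step, loss in enumerate(entropy_losses):
        ema = loss if ema is None else alpha * ema + (1 - alpha) * loss
        if ema < threshold:
            print(f"step {step}: restore the saved model/optimizer state")  # placeholder for the actual reload
            if reset_after_recovery:   # the question above: should the moving average be reset here?
                ema = None
    return ema

recovery_check([0.6, 0.3, 0.1, 0.05, 0.05], alpha=0.5)  # toy entropy values with a fast-moving average
```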

Experiment results of SAR are lower than reported in the paper

Hi, thanks for your great work.
I cloned the code and tried to reproduce the results in Table 2. However, the SAR results are lower than reported in the paper (the TENT results match the paper).

I ran the experiment on the 4 blur corruptions (defocus_blur, glass_blur, motion_blur, zoom_blur).

My training command is:
python main.py --data_corruption /home/cz/data/imagenet/ --exp_type label_shifts --method tent --model vitbase_timm --output ./outputs/tent
python main.py --data_corruption /home/cz/data/imagenet/ --exp_type label_shifts --method sar --model vitbase_timm --output ./outputs/sar

The results of TENT are:

2023-07-24 01:19:06,132 INFO : this exp is for label shifts, no need to shuffle the dataloader, use our pre-defined sample order
2023-07-24 01:19:06,393 INFO : imbalance ratio is 500000
2023-07-24 01:19:06,394 INFO : label_shifts_indices_path is ./dataset/total_100000_ir_500000_class_order_shuffle_yes.npy
2023-07-24 01:19:18,428 INFO : Namespace(corruption='defocus_blur', d_margin=0.05, data='/dockerdata/imagenet', data_corruption='/home/cz/data/imagenet/', debug=False, e_margin=2.763102111592855, exp_type='label_shifts', fisher_alpha=2000.0, fisher_size=2000, gpu=0, if_shuffle=False, imbalance_ratio=500000, level=5, logger_name='2023-07-24-01-19-06-tent-vitbase_timm-level5-seed2021.txt', lr=0.001, method='tent', model='vitbase_timm', output='./outputs/tent', print_freq=39, sar_margin_e0=2.763102111592855, seed=2021, test_batch_size=64, workers=2)
2023-07-24 01:19:18,432 INFO : ['blocks.0.norm1.weight', 'blocks.0.norm1.bias', 'blocks.0.norm2.weight', 'blocks.0.norm2.bias', 'blocks.1.norm1.weight', 'blocks.1.norm1.bias', 'blocks.1.norm2.weight', 'blocks.1.norm2.bias', 'blocks.2.norm1.weight', 'blocks.2.norm1.bias', 'blocks.2.norm2.weight', 'blocks.2.norm2.bias', 'blocks.3.norm1.weight', 'blocks.3.norm1.bias', 'blocks.3.norm2.weight', 'blocks.3.norm2.bias', 'blocks.4.norm1.weight', 'blocks.4.norm1.bias', 'blocks.4.norm2.weight', 'blocks.4.norm2.bias', 'blocks.5.norm1.weight', 'blocks.5.norm1.bias', 'blocks.5.norm2.weight', 'blocks.5.norm2.bias', 'blocks.6.norm1.weight', 'blocks.6.norm1.bias', 'blocks.6.norm2.weight', 'blocks.6.norm2.bias', 'blocks.7.norm1.weight', 'blocks.7.norm1.bias', 'blocks.7.norm2.weight', 'blocks.7.norm2.bias', 'blocks.8.norm1.weight', 'blocks.8.norm1.bias', 'blocks.8.norm2.weight', 'blocks.8.norm2.bias', 'blocks.9.norm1.weight', 'blocks.9.norm1.bias', 'blocks.9.norm2.weight', 'blocks.9.norm2.bias', 'blocks.10.norm1.weight', 'blocks.10.norm1.bias', 'blocks.10.norm2.weight', 'blocks.10.norm2.bias', 'blocks.11.norm1.weight', 'blocks.11.norm1.bias', 'blocks.11.norm2.weight', 'blocks.11.norm2.bias', 'norm.weight', 'norm.bias']
2023-07-24 01:39:25,591 INFO : Result under defocus_blur. The adapttion accuracy of Tent is top1 54.37700 and top5: 77.98100
2023-07-24 01:39:25,592 INFO : acc1s are [54.37699890136719]
2023-07-24 01:39:25,592 INFO : acc5s are [77.98099517822266]
2023-07-24 01:39:25,846 INFO : imbalance ratio is 500000
2023-07-24 01:39:25,846 INFO : label_shifts_indices_path is ./dataset/total_100000_ir_500000_class_order_shuffle_yes.npy
2023-07-24 01:39:43,015 INFO : Namespace(corruption='glass_blur', d_margin=0.05, data='/dockerdata/imagenet', data_corruption='/home/cz/data/imagenet/', debug=False, e_margin=2.763102111592855, exp_type='label_shifts', fisher_alpha=2000.0, fisher_size=2000, gpu=0, if_shuffle=False, imbalance_ratio=500000, level=5, logger_name='2023-07-24-01-19-06-tent-vitbase_timm-level5-seed2021.txt', lr=0.001, method='tent', model='vitbase_timm', output='./outputs/tent', print_freq=39, sar_margin_e0=2.763102111592855, seed=2021, test_batch_size=64, workers=2)
2023-07-24 01:39:43,019 INFO : ['blocks.0.norm1.weight', 'blocks.0.norm1.bias', 'blocks.0.norm2.weight', 'blocks.0.norm2.bias', 'blocks.1.norm1.weight', 'blocks.1.norm1.bias', 'blocks.1.norm2.weight', 'blocks.1.norm2.bias', 'blocks.2.norm1.weight', 'blocks.2.norm1.bias', 'blocks.2.norm2.weight', 'blocks.2.norm2.bias', 'blocks.3.norm1.weight', 'blocks.3.norm1.bias', 'blocks.3.norm2.weight', 'blocks.3.norm2.bias', 'blocks.4.norm1.weight', 'blocks.4.norm1.bias', 'blocks.4.norm2.weight', 'blocks.4.norm2.bias', 'blocks.5.norm1.weight', 'blocks.5.norm1.bias', 'blocks.5.norm2.weight', 'blocks.5.norm2.bias', 'blocks.6.norm1.weight', 'blocks.6.norm1.bias', 'blocks.6.norm2.weight', 'blocks.6.norm2.bias', 'blocks.7.norm1.weight', 'blocks.7.norm1.bias', 'blocks.7.norm2.weight', 'blocks.7.norm2.bias', 'blocks.8.norm1.weight', 'blocks.8.norm1.bias', 'blocks.8.norm2.weight', 'blocks.8.norm2.bias', 'blocks.9.norm1.weight', 'blocks.9.norm1.bias', 'blocks.9.norm2.weight', 'blocks.9.norm2.bias', 'blocks.10.norm1.weight', 'blocks.10.norm1.bias', 'blocks.10.norm2.weight', 'blocks.10.norm2.bias', 'blocks.11.norm1.weight', 'blocks.11.norm1.bias', 'blocks.11.norm2.weight', 'blocks.11.norm2.bias', 'norm.weight', 'norm.bias']
2023-07-24 01:59:49,538 INFO : Result under glass_blur. The adapttion accuracy of Tent is top1 52.10900 and top5: 75.50600
2023-07-24 01:59:49,538 INFO : acc1s are [54.37699890136719, 52.1089973449707]
2023-07-24 01:59:49,539 INFO : acc5s are [77.98099517822266, 75.50599670410156]
2023-07-24 01:59:49,785 INFO : imbalance ratio is 500000
2023-07-24 01:59:49,785 INFO : label_shifts_indices_path is ./dataset/total_100000_ir_500000_class_order_shuffle_yes.npy
2023-07-24 01:59:55,613 INFO : Namespace(corruption='motion_blur', d_margin=0.05, data='/dockerdata/imagenet', data_corruption='/home/cz/data/imagenet/', debug=False, e_margin=2.763102111592855, exp_type='label_shifts', fisher_alpha=2000.0, fisher_size=2000, gpu=0, if_shuffle=False, imbalance_ratio=500000, level=5, logger_name='2023-07-24-01-19-06-tent-vitbase_timm-level5-seed2021.txt', lr=0.001, method='tent', model='vitbase_timm', output='./outputs/tent', print_freq=39, sar_margin_e0=2.763102111592855, seed=2021, test_batch_size=64, workers=2)
2023-07-24 01:59:55,617 INFO : ['blocks.0.norm1.weight', 'blocks.0.norm1.bias', 'blocks.0.norm2.weight', 'blocks.0.norm2.bias', 'blocks.1.norm1.weight', 'blocks.1.norm1.bias', 'blocks.1.norm2.weight', 'blocks.1.norm2.bias', 'blocks.2.norm1.weight', 'blocks.2.norm1.bias', 'blocks.2.norm2.weight', 'blocks.2.norm2.bias', 'blocks.3.norm1.weight', 'blocks.3.norm1.bias', 'blocks.3.norm2.weight', 'blocks.3.norm2.bias', 'blocks.4.norm1.weight', 'blocks.4.norm1.bias', 'blocks.4.norm2.weight', 'blocks.4.norm2.bias', 'blocks.5.norm1.weight', 'blocks.5.norm1.bias', 'blocks.5.norm2.weight', 'blocks.5.norm2.bias', 'blocks.6.norm1.weight', 'blocks.6.norm1.bias', 'blocks.6.norm2.weight', 'blocks.6.norm2.bias', 'blocks.7.norm1.weight', 'blocks.7.norm1.bias', 'blocks.7.norm2.weight', 'blocks.7.norm2.bias', 'blocks.8.norm1.weight', 'blocks.8.norm1.bias', 'blocks.8.norm2.weight', 'blocks.8.norm2.bias', 'blocks.9.norm1.weight', 'blocks.9.norm1.bias', 'blocks.9.norm2.weight', 'blocks.9.norm2.bias', 'blocks.10.norm1.weight', 'blocks.10.norm1.bias', 'blocks.10.norm2.weight', 'blocks.10.norm2.bias', 'blocks.11.norm1.weight', 'blocks.11.norm1.bias', 'blocks.11.norm2.weight', 'blocks.11.norm2.bias', 'norm.weight', 'norm.bias']
2023-07-24 02:20:02,500 INFO : Result under motion_blur. The adapttion accuracy of Tent is top1 58.14200 and top5: 80.62900
2023-07-24 02:20:02,500 INFO : acc1s are [54.37699890136719, 52.1089973449707, 58.141998291015625]
2023-07-24 02:20:02,500 INFO : acc5s are [77.98099517822266, 75.50599670410156, 80.62899780273438]
2023-07-24 02:20:02,756 INFO : imbalance ratio is 500000
2023-07-24 02:20:02,756 INFO : label_shifts_indices_path is ./dataset/total_100000_ir_500000_class_order_shuffle_yes.npy
2023-07-24 02:20:05,444 INFO : Namespace(corruption='zoom_blur', d_margin=0.05, data='/dockerdata/imagenet', data_corruption='/home/cz/data/imagenet/', debug=False, e_margin=2.763102111592855, exp_type='label_shifts', fisher_alpha=2000.0, fisher_size=2000, gpu=0, if_shuffle=False, imbalance_ratio=500000, level=5, logger_name='2023-07-24-01-19-06-tent-vitbase_timm-level5-seed2021.txt', lr=0.001, method='tent', model='vitbase_timm', output='./outputs/tent', print_freq=39, sar_margin_e0=2.763102111592855, seed=2021, test_batch_size=64, workers=2)
2023-07-24 02:20:05,448 INFO : ['blocks.0.norm1.weight', 'blocks.0.norm1.bias', 'blocks.0.norm2.weight', 'blocks.0.norm2.bias', 'blocks.1.norm1.weight', 'blocks.1.norm1.bias', 'blocks.1.norm2.weight', 'blocks.1.norm2.bias', 'blocks.2.norm1.weight', 'blocks.2.norm1.bias', 'blocks.2.norm2.weight', 'blocks.2.norm2.bias', 'blocks.3.norm1.weight', 'blocks.3.norm1.bias', 'blocks.3.norm2.weight', 'blocks.3.norm2.bias', 'blocks.4.norm1.weight', 'blocks.4.norm1.bias', 'blocks.4.norm2.weight', 'blocks.4.norm2.bias', 'blocks.5.norm1.weight', 'blocks.5.norm1.bias', 'blocks.5.norm2.weight', 'blocks.5.norm2.bias', 'blocks.6.norm1.weight', 'blocks.6.norm1.bias', 'blocks.6.norm2.weight', 'blocks.6.norm2.bias', 'blocks.7.norm1.weight', 'blocks.7.norm1.bias', 'blocks.7.norm2.weight', 'blocks.7.norm2.bias', 'blocks.8.norm1.weight', 'blocks.8.norm1.bias', 'blocks.8.norm2.weight', 'blocks.8.norm2.bias', 'blocks.9.norm1.weight', 'blocks.9.norm1.bias', 'blocks.9.norm2.weight', 'blocks.9.norm2.bias', 'blocks.10.norm1.weight', 'blocks.10.norm1.bias', 'blocks.10.norm2.weight', 'blocks.10.norm2.bias', 'blocks.11.norm1.weight', 'blocks.11.norm1.bias', 'blocks.11.norm2.weight', 'blocks.11.norm2.bias', 'norm.weight', 'norm.bias']
2023-07-24 02:40:12,053 INFO : Result under zoom_blur. The adapttion accuracy of Tent is top1 52.10100 and top5: 75.84200
2023-07-24 02:40:12,053 INFO : acc1s are [54.37699890136719, 52.1089973449707, 58.141998291015625, 52.10099792480469]
2023-07-24 02:40:12,053 INFO : acc5s are [77.98099517822266, 75.50599670410156, 80.62899780273438, 75.84199523925781]

The results of SAR are:

2023-07-24 01:22:36,198 INFO : this exp is for label shifts, no need to shuffle the dataloader, use our pre-defined sample order
2023-07-24 01:22:36,478 INFO : imbalance ratio is 500000
2023-07-24 01:22:36,478 INFO : label_shifts_indices_path is ./dataset/total_100000_ir_500000_class_order_shuffle_yes.npy
2023-07-24 01:22:54,412 INFO : Namespace(corruption='defocus_blur', d_margin=0.05, data='/dockerdata/imagenet', data_corruption='/home/cz/data/imagenet/', debug=False, e_margin=2.763102111592855, exp_type='label_shifts', fisher_alpha=2000.0, fisher_size=2000, gpu=0, if_shuffle=False, imbalance_ratio=500000, level=5, logger_name='2023-07-24-01-22-36-sar-vitbase_timm-level5-seed2021.txt', lr=0.001, method='sar', model='vitbase_timm', output='./outputs/sar', print_freq=39, sar_margin_e0=2.763102111592855, seed=2021, test_batch_size=64, workers=2)
2023-07-24 01:22:54,416 INFO : ['blocks.0.norm1.weight', 'blocks.0.norm1.bias', 'blocks.0.norm2.weight', 'blocks.0.norm2.bias', 'blocks.1.norm1.weight', 'blocks.1.norm1.bias', 'blocks.1.norm2.weight', 'blocks.1.norm2.bias', 'blocks.2.norm1.weight', 'blocks.2.norm1.bias', 'blocks.2.norm2.weight', 'blocks.2.norm2.bias', 'blocks.3.norm1.weight', 'blocks.3.norm1.bias', 'blocks.3.norm2.weight', 'blocks.3.norm2.bias', 'blocks.4.norm1.weight', 'blocks.4.norm1.bias', 'blocks.4.norm2.weight', 'blocks.4.norm2.bias', 'blocks.5.norm1.weight', 'blocks.5.norm1.bias', 'blocks.5.norm2.weight', 'blocks.5.norm2.bias', 'blocks.6.norm1.weight', 'blocks.6.norm1.bias', 'blocks.6.norm2.weight', 'blocks.6.norm2.bias', 'blocks.7.norm1.weight', 'blocks.7.norm1.bias', 'blocks.7.norm2.weight', 'blocks.7.norm2.bias', 'blocks.8.norm1.weight', 'blocks.8.norm1.bias', 'blocks.8.norm2.weight', 'blocks.8.norm2.bias']
2023-07-24 02:03:32,173 INFO : Result under defocus_blur. The adaptation accuracy of SAR is top1: 29.08600 and top5: 48.70900
2023-07-24 02:03:32,173 INFO : acc1s are [29.08599853515625]
2023-07-24 02:03:32,173 INFO : acc5s are [48.70899963378906]
2023-07-24 02:03:32,417 INFO : imbalance ratio is 500000
2023-07-24 02:03:32,417 INFO : label_shifts_indices_path is ./dataset/total_100000_ir_500000_class_order_shuffle_yes.npy
2023-07-24 02:03:40,059 INFO : Namespace(corruption='glass_blur', d_margin=0.05, data='/dockerdata/imagenet', data_corruption='/home/cz/data/imagenet/', debug=False, e_margin=2.763102111592855, exp_type='label_shifts', fisher_alpha=2000.0, fisher_size=2000, gpu=0, if_shuffle=False, imbalance_ratio=500000, level=5, logger_name='2023-07-24-01-22-36-sar-vitbase_timm-level5-seed2021.txt', lr=0.001, method='sar', model='vitbase_timm', output='./outputs/sar', print_freq=39, sar_margin_e0=2.763102111592855, seed=2021, test_batch_size=64, workers=2)
2023-07-24 02:03:40,063 INFO : ['blocks.0.norm1.weight', 'blocks.0.norm1.bias', 'blocks.0.norm2.weight', 'blocks.0.norm2.bias', 'blocks.1.norm1.weight', 'blocks.1.norm1.bias', 'blocks.1.norm2.weight', 'blocks.1.norm2.bias', 'blocks.2.norm1.weight', 'blocks.2.norm1.bias', 'blocks.2.norm2.weight', 'blocks.2.norm2.bias', 'blocks.3.norm1.weight', 'blocks.3.norm1.bias', 'blocks.3.norm2.weight', 'blocks.3.norm2.bias', 'blocks.4.norm1.weight', 'blocks.4.norm1.bias', 'blocks.4.norm2.weight', 'blocks.4.norm2.bias', 'blocks.5.norm1.weight', 'blocks.5.norm1.bias', 'blocks.5.norm2.weight', 'blocks.5.norm2.bias', 'blocks.6.norm1.weight', 'blocks.6.norm1.bias', 'blocks.6.norm2.weight', 'blocks.6.norm2.bias', 'blocks.7.norm1.weight', 'blocks.7.norm1.bias', 'blocks.7.norm2.weight', 'blocks.7.norm2.bias', 'blocks.8.norm1.weight', 'blocks.8.norm1.bias', 'blocks.8.norm2.weight', 'blocks.8.norm2.bias']
2023-07-24 02:44:17,940 INFO : Result under glass_blur. The adaptation accuracy of SAR is top1: 23.36000 and top5: 41.38300
2023-07-24 02:44:17,941 INFO : acc1s are [29.08599853515625, 23.35999870300293]
2023-07-24 02:44:17,941 INFO : acc5s are [48.70899963378906, 41.382999420166016]
2023-07-24 02:44:18,186 INFO : imbalance ratio is 500000
2023-07-24 02:44:18,186 INFO : label_shifts_indices_path is ./dataset/total_100000_ir_500000_class_order_shuffle_yes.npy
2023-07-24 02:44:24,259 INFO : Namespace(corruption='motion_blur', d_margin=0.05, data='/dockerdata/imagenet', data_corruption='/home/cz/data/imagenet/', debug=False, e_margin=2.763102111592855, exp_type='label_shifts', fisher_alpha=2000.0, fisher_size=2000, gpu=0, if_shuffle=False, imbalance_ratio=500000, level=5, logger_name='2023-07-24-01-22-36-sar-vitbase_timm-level5-seed2021.txt', lr=0.001, method='sar', model='vitbase_timm', output='./outputs/sar', print_freq=39, sar_margin_e0=2.763102111592855, seed=2021, test_batch_size=64, workers=2)
2023-07-24 02:44:24,263 INFO : ['blocks.0.norm1.weight', 'blocks.0.norm1.bias', 'blocks.0.norm2.weight', 'blocks.0.norm2.bias', 'blocks.1.norm1.weight', 'blocks.1.norm1.bias', 'blocks.1.norm2.weight', 'blocks.1.norm2.bias', 'blocks.2.norm1.weight', 'blocks.2.norm1.bias', 'blocks.2.norm2.weight', 'blocks.2.norm2.bias', 'blocks.3.norm1.weight', 'blocks.3.norm1.bias', 'blocks.3.norm2.weight', 'blocks.3.norm2.bias', 'blocks.4.norm1.weight', 'blocks.4.norm1.bias', 'blocks.4.norm2.weight', 'blocks.4.norm2.bias', 'blocks.5.norm1.weight', 'blocks.5.norm1.bias', 'blocks.5.norm2.weight', 'blocks.5.norm2.bias', 'blocks.6.norm1.weight', 'blocks.6.norm1.bias', 'blocks.6.norm2.weight', 'blocks.6.norm2.bias', 'blocks.7.norm1.weight', 'blocks.7.norm1.bias', 'blocks.7.norm2.weight', 'blocks.7.norm2.bias', 'blocks.8.norm1.weight', 'blocks.8.norm1.bias', 'blocks.8.norm2.weight', 'blocks.8.norm2.bias']
2023-07-24 03:25:00,744 INFO : Result under motion_blur. The adaptation accuracy of SAR is top1: 33.95400 and top5: 54.65100
2023-07-24 03:25:00,744 INFO : acc1s are [29.08599853515625, 23.35999870300293, 33.95399856567383]
2023-07-24 03:25:00,745 INFO : acc5s are [48.70899963378906, 41.382999420166016, 54.650997161865234]
2023-07-24 03:25:00,988 INFO : imbalance ratio is 500000
2023-07-24 03:25:00,988 INFO : label_shifts_indices_path is ./dataset/total_100000_ir_500000_class_order_shuffle_yes.npy
2023-07-24 03:25:03,661 INFO : Namespace(corruption='zoom_blur', d_margin=0.05, data='/dockerdata/imagenet', data_corruption='/home/cz/data/imagenet/', debug=False, e_margin=2.763102111592855, exp_type='label_shifts', fisher_alpha=2000.0, fisher_size=2000, gpu=0, if_shuffle=False, imbalance_ratio=500000, level=5, logger_name='2023-07-24-01-22-36-sar-vitbase_timm-level5-seed2021.txt', lr=0.001, method='sar', model='vitbase_timm', output='./outputs/sar', print_freq=39, sar_margin_e0=2.763102111592855, seed=2021, test_batch_size=64, workers=2)
2023-07-24 03:25:03,665 INFO : ['blocks.0.norm1.weight', 'blocks.0.norm1.bias', 'blocks.0.norm2.weight', 'blocks.0.norm2.bias', 'blocks.1.norm1.weight', 'blocks.1.norm1.bias', 'blocks.1.norm2.weight', 'blocks.1.norm2.bias', 'blocks.2.norm1.weight', 'blocks.2.norm1.bias', 'blocks.2.norm2.weight', 'blocks.2.norm2.bias', 'blocks.3.norm1.weight', 'blocks.3.norm1.bias', 'blocks.3.norm2.weight', 'blocks.3.norm2.bias', 'blocks.4.norm1.weight', 'blocks.4.norm1.bias', 'blocks.4.norm2.weight', 'blocks.4.norm2.bias', 'blocks.5.norm1.weight', 'blocks.5.norm1.bias', 'blocks.5.norm2.weight', 'blocks.5.norm2.bias', 'blocks.6.norm1.weight', 'blocks.6.norm1.bias', 'blocks.6.norm2.weight', 'blocks.6.norm2.bias', 'blocks.7.norm1.weight', 'blocks.7.norm1.bias', 'blocks.7.norm2.weight', 'blocks.7.norm2.bias', 'blocks.8.norm1.weight', 'blocks.8.norm1.bias', 'blocks.8.norm2.weight', 'blocks.8.norm2.bias']
2023-07-24 04:05:41,512 INFO : Result under zoom_blur. The adaptation accuracy of SAR is top1: 27.04700 and top5: 46.20700
2023-07-24 04:05:41,513 INFO : acc1s are [29.08599853515625, 23.35999870300293, 33.95399856567383, 27.046998977661133]
2023-07-24 04:05:41,513 INFO : acc5s are [48.70899963378906, 41.382999420166016, 54.650997161865234, 46.207000732421875]

SAR only gets [29.1, 23.5, 33.9, 27.0] on the 4 blur corruptions.

Did I miss something? Looking forward to your reply.


Implementation details for ImageNet-R and VisDA

Hi, Dr. Niu:
Thanks for sharing this great work!
I tried to reproduce the TTA results for ImageNet-R and VisDA, but they are lower than reported in your paper.
Could you please share the pre-trained models and hyperparameters for ImageNet-R and VisDA, as in Tables 10-12 of your paper?
Thank you!

Gradient norm

Hi author,
Thanks for your exciting paper!

I would like to ask how you calculated the gradient norm in Figure 2(d). Did you use batch size 1 to estimate each sample's gradient and its corresponding entropy?
Thank you!
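
One way to estimate such per-sample statistics is to backpropagate the entropy of each sample individually and record the resulting gradient norm, as in the sketch below (a sketch under the assumption of batch size 1; not necessarily the exact protocol used for the paper's figure):

```python
import torch

def entropy_and_grad_norm(model, x):
    """Entropy of one sample's prediction and the norm of its gradient w.r.t. the trainable parameters."""
    model.zero_grad()
    logits = model(x.unsqueeze(0))                                 # a batch containing a single sample
    entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum()
    entropy.backward()
    grads = [p.grad.detach().flatten() for p in model.parameters() if p.grad is not None]
    return entropy.item(), torch.cat(grads).norm().item()
```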

use of model.eval() and torch.no_grad()

Hi, first of all, congratulations on your great work and thanks for making the code public!

I have a couple of questions about the usage of model.eval() and torch.no_grad().

In main.py, for 'TENT' and 'EATA', you call the validate function defined on line 34, but for 'SAR' the validation code is written inside the 'SAR' branch (lines 294-322). The difference is that for SAR, model.eval() and torch.no_grad() are not used, while both are used in the validate function.

  1. Is there a particular reason why SAR does not need model.eval() and torch.no_grad()?
  2. Would the batch normalization parameters still be optimized for 'TENT' and 'EATA' under model.eval() and torch.no_grad()?

Thank you!
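
As background for question 1, the adaptation forward pass has to build a computation graph, so it cannot run under torch.no_grad(); a no_grad context is only appropriate for pure inference. The snippet below is a minimal illustration with a toy stand-in model, not the repository's code:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 3)        # toy stand-in for the pre-trained network
x = torch.randn(4, 8)

# Pure inference: no computation graph is built, so nothing can be adapted.
with torch.no_grad():
    preds = model(x).argmax(1)

# Test-time adaptation: the forward pass must run with gradients enabled so the
# entropy loss can be backpropagated into the trainable (normalization) parameters.
logits = model(x)
entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1).mean()
entropy.backward()             # this call would fail if logits were computed under torch.no_grad()
```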

Class Imbalance Testing for CIFAR-10 Dataset

```python
import random
import numpy as np

# Note: generate_sample_indices_and_ys and seed are defined elsewhere in the original script.
for myir in [10]:
    # q_all denotes the label distribution sampled at each time step t
    # (e.g., t = 1, ..., 1000 for ImageNet; each class is a time step)
    shift_proccess_name = "per_class_shift"  # "per_class_shift" or "monotone_shift"
    T = 100000  # total number of samples generated for testing (the simulated shifted test set
                # may contain repeated samples, and some original images may be missing)
    dataset_name = 'imagenet1k'

    if dataset_name == 'imagenet1k':
        num_classes = 1000
    elif dataset_name == 'cifar10':
        num_classes = 10

    if shift_proccess_name == "per_class_shift" and dataset_name == "imagenet1k":
        imbalance_ratio = myir
        shuffle_class_order = "yes"
        minor_class_prob = 1 / (imbalance_ratio + num_classes - 1)
        major_class_prob = minor_class_prob * imbalance_ratio
        q_for_all_classes = np.ones([num_classes, num_classes]) * minor_class_prob
        print(q_for_all_classes.shape)
        for i in range(num_classes):
            q_for_all_classes[i, i] = major_class_prob
        if shuffle_class_order == "yes":
            indices = list(range(num_classes))
            random.shuffle(indices)
            q_for_all_classes = q_for_all_classes[indices, :]

        def shift_proccess(T):
            num_for_repeat_each_q = T // num_classes
            assert num_for_repeat_each_q > 0, "T should be greater than the number of classes"
            return np.concatenate([np.expand_dims(q_for_all_classes[i, :], axis=0)
                                   for i in range(num_classes)
                                   for _ in range(num_for_repeat_each_q)], axis=0)
    else:
        assert False, NotImplementedError

    q_all = shift_proccess(T)
    print(q_all.shape)

    simulated_indices = generate_sample_indices_and_ys(q_all, dataset_name=dataset_name)
    print(simulated_indices.shape)
    print(simulated_indices[:100])
    print(list(simulated_indices[:10]))

    np.save('seed{}_total_{}_ir_{}_class_order_shuffle_{}'.format(seed, T, imbalance_ratio, shuffle_class_order),
            simulated_indices)
```

Why is there no implementation of the probability-vector generation for the CIFAR-10 dataset? Is there any additional algorithmic difference from ImageNet?

code for data generation

Hi,
Thank you for sharing this excellent work. Could you please share the code to generate data in './dataset/total_{}ir{}_class_order_shuffle_yes.npy'.format(100000, ir)'?
Thank you.
