
FixMatch-pytorch

Unofficial PyTorch code for "FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence," NeurIPS'20.
This implementation reproduces the CIFAR10 and CIFAR100 results reported in the paper.
It also includes models trained in both semi-supervised and fully supervised manners (download them from the links below).

Requirements

  • python 3.6
  • pytorch 1.6.0
  • torchvision 0.7.0
  • tensorboard 2.3.0
  • pillow

Results: Classification Accuracy (%)

In addition to the semi-supervised results from the paper, we attach extra results for fully supervised learning (50000 labels, sup only) and fully supervised learning with consistency regularization (50000 labels, sup + consistency).
Consistency regularization improves classification accuracy even when all labels are provided.
Evaluation uses an EMA (exponential moving average) of the models along the SGD training trajectory.
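
For reference, a minimal sketch of such an EMA update (ema_m=0.999 is this repo's default; the Linear modules below are stand-ins for the real networks):

import torch

# theta_ema <- m * theta_ema + (1 - m) * theta, applied after every SGD step.
ema_m = 0.999
train_model = torch.nn.Linear(10, 2)  # stand-in for the trained network
eval_model = torch.nn.Linear(10, 2)   # stand-in for its EMA copy

with torch.no_grad():
    for p_ema, p in zip(eval_model.parameters(), train_model.parameters()):
        p_ema.mul_(ema_m).add_(p, alpha=1 - ema_m)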

CIFAR10

| #Labels | 40 | 250 | 4000 | sup + consistency | sup only |
| --- | --- | --- | --- | --- | --- |
| Paper (RA) | 86.19 ± 3.37 | 94.93 ± 0.65 | 95.74 ± 0.05 | - | - |
| kekmodel | - | - | 94.72 | - | - |
| valencebond | 89.63 (85.65) | 93.08 | 94.72 | - | - |
| Ours | 87.11 | 94.61 | 95.62 | 96.86 | 94.98 |
| Trained Models | checkpoint | checkpoint | checkpoint | checkpoint | checkpoint |

CIFAR100

| #Labels | 400 | 2500 | 10000 | sup + consistency | sup only |
| --- | --- | --- | --- | --- | --- |
| Paper (RA) | 51.15 ± 1.75 | 71.71 ± 0.11 | 77.40 ± 0.12 | - | - |
| kekmodel | - | - | - | - | - |
| valencebond | 53.74 | 67.3169 | 73.26 | - | - |
| Ours | 48.96 | 71.50 | 78.27 | 83.86 | 80.57 |
| Trained Models | checkpoint | checkpoint | checkpoint | checkpoint | checkpoint |

In the case of CIFAR100@400, our result does not reach the paper's result and falls outside the reported confidence interval.
That said, accuracy with a small number of labels depends heavily on the label selection and other hyperparameters.
For example, we found that changing the momentum of batch normalization can give better results, close to the reported accuracies.

Evaluation of Checkpoints

Download Checkpoints

Here we attach Google Drive links that include the training logs and trained models.
Because of Google Drive security restrictions, downloading the checkpoints in the result tables with curl/wget may fail.
In that case, use gdown to download them instead.
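
For example (replace [FILE_ID] with the ID from a checkpoint link):

pip install gdown
gdown https://drive.google.com/uc?id=[FILE_ID]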

All checkpoints are included in this directory

Evaluation Example

After unzipping the checkpoints into your own path, you can run:

python eval.py --load_path saved_models/cifar10_400/model_best.pth --dataset cifar10 --num_classes 10

How to Train

Important Notes

For the detailed explanations of arguments, see here.

  • In training, the model is saved at os.path.join(args.save_dir, args.save_name) after a new directory is created. If the path already exists, the code raises an error to prevent accidentally overwriting trained models. If you want to overwrite the files, pass --overwrite.
  • By default, FixMatch uses hard (one-hot) pseudo-labels. If you want to use soft pseudo-labels with sharpening, pass --hard_label False. You can also adjust the sharpening temperature with --T [YOUR_OWN_VALUE].
  • This code treats training as a single epoch, but the number of iterations is 2**20.
  • To restart training, use --resume --load_path [YOUR_CHECKPOINT_PATH]. The checkpoint is then loaded into the model, and training continues from the saved iteration. See here and the related method.
  • The number of workers for DataLoader is set for distributed training on a single node with 4 V100 GPUs.
  • To change the confidence threshold used to generate masks in consistency regularization, set --p_cutoff.
  • With 4 GPUs, for fast updates, the running statistics of BN are not gathered across processes in distributed training. However, a larger number of GPUs with the same batch size might affect overall accuracy. In that case, you can 1) replace BN with SyncBN (see here) or 2) use torch.distributed.all_reduce on the BN buffers before this line; a sketch of both options follows this list.
  • We checked that SyncBN slightly improves accuracy, but training time increases considerably, so this code does not include it.
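
A minimal sketch of both options, assuming an initialized process group (function and variable names are illustrative, not this repo's API):

import torch
import torch.distributed as dist

# Option 1: replace every BN layer with SyncBN (built into PyTorch).
# model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

# Option 2: average BN running statistics across processes by hand.
def all_reduce_bn_buffers(model):
    world_size = dist.get_world_size()
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            dist.all_reduce(module.running_mean, op=dist.ReduceOp.SUM)
            dist.all_reduce(module.running_var, op=dist.ReduceOp.SUM)
            module.running_mean /= world_size
            module.running_var /= world_size  # mean of per-GPU variances (approximate)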

Use single GPU

python train.py --rank 0 --gpu [0/1/...] @@@other args@@@

Use multiple GPUs (with DataParallel)

python train.py --world-size 1 --rank 0 @@@other args@@@

Use multiple GPUs (with distributed training)

When you use multiple GPUs, we strongly recommend distributed training (even on a single node) for high performance.

With V100x4 GPUs, CIFAR10 training takes about 16 hours (0.7 days), and CIFAR100 training takes about 62 hours (2.6 days).

  • single node
python train.py --world-size 1 --rank 0 --multiprocessing-distributed @@@other args@@@
  • multiple nodes (assuming two nodes)
# at node 0
python train.py --world-size 2 --rank 0 --dist_url [rank 0's url] --multiprocessing-distributed @@@@other args@@@@
# at node 1
python train.py --world-size 2 --rank 1 --dist_url [rank 0's url] --multiprocessing-distributed @@@@other args@@@@

Run Examples (with single node & multi-GPUs)

CIFAR10

python train.py --world-size 1 --rank 0 --multiprocessing-distributed --num_labels 4000 --save_name cifar10_4000 --dataset cifar10 --num_classes 10

CIFAR100

python train.py --world-size 1 --rank 0 --multiprocessing-distributed --num_labels 10000 --save_name cifar100_10000 --dataset cifar100 --num_classes 100 --widen_factor 8 --weight_decay 0.001

To reproduce the results on CIFAR100, --widen_factor has to be increased to 8 (see this issue in the official repo) and --weight_decay set to 0.001.

Change the backbone networks

In this repo, we use WideResNet with LeakyReLU activations, implemented in models/net/wrn.py.
When you use the WideResNet, you can change widen_factor, leaky_slope, and dropRate via command-line arguments.

For example, if you want to use ReLU, just pass --leaky_slope 0.0.

We also support various backbone networks from torchvision.models.
If you want to use another torchvision backbone, change the arguments to

--net [MODEL's NAME in torchvision] --net_from_name True

When --net_from_name True is set, all model arguments except --net are ignored.
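
For example, to train with a torchvision ResNet-50 (resnet50 is a valid torchvision model name; the save name here is an arbitrary choice):

python train.py --world-size 1 --rank 0 --multiprocessing-distributed --net resnet50 --net_from_name True --num_labels 4000 --save_name cifar10_resnet50 --dataset cifar10 --num_classes 10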

Mixed Precision Training

If you want to use mixed-precision training for a speed-up, add --amp to the arguments.
We checked that the per-iteration training time is reduced by about 20-30%.
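
For context, --amp enables the generic torch.cuda.amp pattern (available since PyTorch 1.6); the minimal sketch below is illustrative, not this repo's actual training loop:

import torch

# Minimal mixed-precision training step (requires a CUDA device).
model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(16, 10).cuda()
y = torch.randint(0, 2, (16,)).cuda()

optimizer.zero_grad()
with torch.cuda.amp.autocast():  # forward pass runs in mixed precision
    loss = torch.nn.functional.cross_entropy(model(x), y)
scaler.scale(loss).backward()    # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)           # unscale gradients, then take the optimizer step
scaler.update()                  # adjust the loss scale for the next iteration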

Tensorboard

We track various metrics, including training accuracy, prefetch & run times, the mask ratio of unlabeled data, and learning rates. See the details here. You can view the metrics in TensorBoard:

tensorboard --logdir=[SAVE PATH] --port=[YOUR PORT]


Contributors

  • leedoyup
  • yeongjae


fixmatch-pytorch's Issues

Error using single gpu for training

Thanks for the work you have done.

I encountered the following error using single-GPU training:
ValueError: num_samples should be a positive integer value, but got num_samples=-67108864

The command I am using is: python train.py --rank 0 --gpu 0

Can you please assist?

Thanks

Questions about the EMA process for BN statistics

Thanks for your nice work! I am a little confused about why we need to copy the BN statistics (mean/var) from the student net to the EMA teacher net. I tried removing the EMA copy for BN in the teacher net by
removing these two lines (BN EMA Update), but this led to much worse performance. Could you please show me why this happens? Thanks!
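
(For reference, the copy in question is roughly of this form; this is a paraphrase of the described behavior with stand-in modules, not the repo's exact lines:)

import torch

train_model = torch.nn.BatchNorm1d(4)  # stand-in for the student network
eval_model = torch.nn.BatchNorm1d(4)   # stand-in for the EMA teacher network

# Copy BN buffers (running mean/var) directly, rather than EMA-averaging them.
for buf_ema, buf in zip(eval_model.buffers(), train_model.buffers()):
    buf_ema.copy_(buf)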

mask and mean

Hello! I enjoyed reading your code!

In the part linked below, the loss is multiplied by the mask and then averaged with mean. I was wondering whether, in this case, the entries zeroed out by the mask are still counted in the average, so the resulting value becomes somewhat smaller.

masked_loss = ce_loss(logits_s, pseudo_label, use_hard_labels) * mask

Thank you!
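
(For illustration, the two averaging conventions the question contrasts, as a small self-contained example:)

import torch

loss = torch.tensor([1.0, 2.0, 3.0, 4.0])
mask = torch.tensor([1.0, 0.0, 1.0, 0.0])

masked_loss = loss * mask
print(masked_loss.mean())              # tensor(1.) -> divides by all 4 entries
print(masked_loss.sum() / mask.sum())  # tensor(2.) -> divides by masked-in entries only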

About the version of python

Thank you for your nice work!

I have one minor question: when I use Python 3.6, it says AttributeError: module 'contextlib' has no attribute 'nullcontext'. I think the required Python version should be 3.7 (contextlib.nullcontext was added in Python 3.7)?

Current weak transformation in ssl_dataset.py only correct for 32x32 images

Hey!

Thank you for the great code, it's very nice! We found one minor problem when adapting it to other datasets; it can probably be fixed easily, but it can mess up results quite a bit (without raising any error):

In ssl_dataset.py line 25 you define one of the "weak" transformations as

transforms.RandomCrop(32, padding=4)

We suspect this is intended as the translation augmentation from the original code. However, for larger images this is a (probably unintentionally) strong augmentation. It should probably be adaptive to the image size or use a different transformation, e.g. as sketched below.
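
(For instance, a minimal size-adaptive sketch, assuming square images of side img_size; the 12.5% padding ratio mirrors padding=4 at size 32:)

import torchvision.transforms as transforms

def weak_random_crop(img_size):
    # Keep the padding proportional to the image size (4/32 = 12.5%).
    return transforms.RandomCrop(img_size, padding=int(img_size * 0.125))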

Might save others some time :)

Thanks!

Impossible to avoid including labelled data in unlabelled data

While the flag 'include_lb_to_ulb' in get_ssl_dset is meant to control this, 'ulb_dset' is constructed with 'data' instead of 'ulb_data'.

This:
ulb_dset = BasicDataset(data, targets, num_classes, transform, use_strong_transform, strong_transform, onehot)

should be:
ulb_dset = BasicDataset(ulb_data, targets, num_classes, transform, use_strong_transform, strong_transform, onehot)

What are the important factors to improve performance?

Hi, thanks for the amazing and well-documented code!

I have been using the code from valencebond, but cannot reproduce the paper's results, especially for CIFAR-100.

What do you think are the important factors to reproduce the paper's results? Is it because of some implementation details of the EMA model?

Unsupervised loss

Why does the unsupervised loss first increase a lot and then decrease only slightly? And why, at the end of training, is the unsupervised loss greater than it was initially?

No Validation Set

Hi, for supervised and semi-supervised methods, it is generally advised to use a separate validation set. From the code, it looks like the best test-set accuracy is reported.

Is there any specific reason that a separate validation set is not used?

NCCL error when train with single node & multi gpus

Hi,
I tried to run the command

python train.py --world-size 1 --rank 0 --multiprocessing-distributed --num_labels 4000 --save_name cifar10_4000 --dataset cifar10 --num_classes 10

but the following exception was raised:

Traceback (most recent call last):
  File "train.py", line 316, in <module>
    main(args)
  File "train.py", line 62, in main
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
  File "/home/yuansong/.local/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 200, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/yuansong/.local/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 158, in start_processes
    while not context.join():
  File "/home/yuansong/.local/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 119, in join
    raise Exception(msg)
Exception:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/yuansong/.local/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
    fn(i, *args)
  File "/data/FixMatch-pytorch/train.py", line 155, in main_worker
    device_ids=[args.gpu])
  File "/home/yuansong/.local/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 333, in __init__
    self.broadcast_bucket_size)
  File "/home/yuansong/.local/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 549, in _distributed_broadcast_coalesced
    dist._broadcast_coalesced(self.process_group, tensors, buffer_size)
RuntimeError: NCCL error in: /pytorch/torch/lib/c10d/ProcessGroupNCCL.cpp:492, unhandled system error, NCCL version 2.4.8

I am not sure whether you met a similar problem before. If so, how did you solve it? Thanks!

how to cite your repo?

I reproduced Pi-Model, Pseudo-Label, Mean Teacher, MixMatch, UDA, ReMixMatch, and FixMatch based on your open-source code. May I ask how to cite your repo?

Training time

Nice work!! The results seem really attractive.
It takes me more than one hour to train 10K iterations on ONE M40 GPU, so the total training time for ONE CIFAR-10 experiment is about 2^20 / 10K * 1 h ≈ 100 h, i.e. about 4 days!! Is this normal??

As you said, CIFAR10 training takes about 16 hours (0.7 days) with V100x4 GPUs. Does your training fully utilize all FOUR GPUs, or run on just ONE of them??

How can I use [sup + consistency] or [sup only]?

Hello, I am just starting to study FixMatch, and I am wondering about fully supervised learning (50000 labels, sup only) and (50000 labels, sup + consistency).
How can I change the settings to use them?
What is the default setup when I use --num_labels 50000?

Thanks

Training time

The training time of 16 hours for CIFAR10 sounds like a lot – is this summed over all experiments, or for a single training run (and with how many labeled data points)?

What's the bottleneck? (Supervised CIFAR10 training is on the order of minutes.)

resuming checkpoints seems broken

[2021-01-30 22:38:52,025 INFO] 520000 iteration, USE_EMA: True, {'train/sup_loss': tensor(1.7649e-05, device='cuda:0'), 'train/unsup_loss': tensor(0.1915, device='cuda:0'), 'train/total_loss': tensor(0.1915, device='cuda:0'), 'train/mask_ratio': tensor(0.0089, device='cuda:0'), 'lr': 0.006183012144624228, 'train/prefecth_time': 0.00017820799350738524, 'train/run_time': 0.07982937622070313, 'eval/loss': tensor(0.2036, device='cuda:0'), 'eval/top-1-acc': tensor(0.9535, device='cuda:0')}, BEST_EVAL_ACC: 0.9553999900817871, at 480000 iters
[2021-01-30 22:44:52,212 INFO] model saved: ./saved_models/cifar10_4000/latest_model.pth
[2021-01-30 23:18:34,497 WARNING] USE GPU: 3 for training
[2021-01-30 23:18:34,497 WARNING] USE GPU: 2 for training
[2021-01-30 23:18:34,497 WARNING] USE GPU: 1 for training
[2021-01-30 23:18:34,501 WARNING] USE GPU: 0 for training
[2021-01-30 23:18:34,712 INFO] Number of Trainable Params: 1467610
[2021-01-30 23:18:34,816 INFO] model_arch: <models.fixmatch.fixmatch.FixMatch object at 0x2aaafa2ed0d0>
[2021-01-30 23:18:34,816 INFO] Arguments: Namespace(T=0.5, amp=True, auto_flood=False, batch_size=16, bn_momentum=0.0010000000000000009, data_dir='./data', dataset='cifar10', depth=28, dist_backend='nccl', dist_url='tcp://127.0.0.1:10001', distributed=True, dropout=0.0, ema_m=0.999, epoch=1, eval_batch_size=1024, gpu=0, hard_label=True, leaky_slope=0.1, load_path='./saved_models/cifar10_4000/latest_model.pth', lr=0.03, momentum=0.9, multiprocessing_distributed=True, net='WideResNet', net_from_name=False, num_classes=10, num_eval_iter=10000, num_labels=4000, num_train_iter=1048576, num_workers=1, overwrite=True, p_cutoff=0.95, rank=0, resume=True, save_dir='./saved_models', save_name='cifar10_4000', seed=0, train_sampler='RandomSampler', ulb_loss_ratio=1.0, uratio=7, weight_decay=0.0005, widen_factor=2, world_size=4)
[2021-01-30 23:18:37,483 INFO] [!] data loader keys: dict_keys(['train_lb', 'train_ulb', 'eval'])
[2021-01-30 23:18:38,172 INFO] Check Point Loading: train_model is LOADED
[2021-01-30 23:18:38,179 INFO] Check Point Loading: eval_model is LOADED
[2021-01-30 23:18:38,187 INFO] Check Point Loading: optimizer is LOADED
[2021-01-30 23:18:38,188 INFO] Check Point Loading: scheduler is LOADED
[2021-01-30 23:18:38,188 INFO] Check Point Loading: it is LOADED
[2021-01-30 23:22:26,994 INFO] 525000 iteration, USE_EMA: True, {'train/sup_loss': tensor(1.0072, device='cuda:0'), 'train/unsup_loss': tensor(0.1679, device='cuda:0'), 'train/total_loss': tensor(1.1751, device='cuda:0'), 'train/mask_ratio': tensor(0.7411, device='cuda:0'), 'lr': 0.023172516693920196, 'train/prefecth_time': 0.00016364799439907074, 'train/run_time': 0.08022198486328125, 'eval/loss': tensor(2.0034, device='cuda:0'), 'eval/top-1-acc': tensor(0.2563, device='cuda:0')}, BEST_EVAL_ACC: 0.2563000023365021, at 525000 iters
[2021-01-30 23:22:27,077 INFO] model saved: ./saved_models/cifar10_4000/model_best.pth
[2021-01-30 23:23:54,587 INFO] 526000 iteration, USE_EMA: True, {'train/sup_loss': tensor(0.3614, device='cuda:0'), 'train/unsup_loss': tensor(0.1964, device='cuda:0'), 'train/total_loss': tensor(0.5578, device='cuda:0'), 'train/mask_ratio': tensor(0.4464, device='cuda:0'), 'lr': 0.023147521998440744, 'train/prefecth_time': 0.00017708800733089446, 'train/run_time': 0.08024336242675781, 'eval/loss': tensor(1.2628, device='cuda:0'), 'eval/top-1-acc': tensor(0.5530, device='cuda:0')}, BEST_EVAL_ACC: 0.5529999732971191, at 526000 iters

Training time

Thanks for providing the well-documented code! It seems that every 1000 iterations take about 5-6 minutes (on a single NVIDIA 2080Ti GPU). For MixMatch, I used the code here, and every 1000 iterations take only about 1 minute.

I agree that consistency-regularization-based SSL methods take much longer to train.
The fundamental reason is that FixMatch does not use an external dataset or a pretrained model, but instead learns hidden representations from the unlabeled data of the downstream task (CIFAR10).
In addition, FixMatch requires 2^20 iterations, which is far more than typical supervised learning (150 epochs with batch_size = 128 is about 60,000 iterations).

In fact, MixMatch also uses consistency regularization, and its number of training iterations is the same as FixMatch's. What do you think causes the slow training of FixMatch compared to MixMatch?
