crst's Introduction

Confidence Regularized Self-Training (ICCV19, Oral)

By Yang Zou*, Zhiding Yu*, Xiaofeng Liu, Vijayakumar Bhagavatula, Jinsong Wang (* indicates equal contribution).

[Paper] [Slides] [Poster]

Update

2019-10-10: CBST/CRST PyTorch code for semantic segmentation released

Contents

  1. Introduction
  2. Citation and license
  3. Requirements
  4. Results
  5. Setup
  6. Usage
  7. Note

Introduction

This repository contains the regularized self-training methods described in the ICCV 2019 paper "Confidence Regularized Self-training". Both Class-Balanced Self-Training (CBST) and Confidence Regularized Self-Training (CRST) are implemented.

Citation and license

If you use this code, please cite:

@InProceedings{Zou_2019_ICCV,
  author = {Zou, Yang and Yu, Zhiding and Liu, Xiaofeng and Kumar, B.V.K. Vijaya and Wang, Jinsong},
  title = {Confidence Regularized Self-Training},
  booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
  month = {October},
  year = {2019}
}

@inproceedings{zou2018unsupervised,
  title={Unsupervised Domain Adaptation for Semantic Segmentation via Class-Balanced Self-Training},
  author={Zou, Yang and Yu, Zhiding and Kumar, BVK Vijaya and Wang, Jinsong},
  booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
  pages={289--305},
  year={2018}
}

The model and code are available for non-commercial (NC) research purposes only. If you modify the code and want to redistribute it, please include the CC-BY-NC-SA-4.0 license.

Requirements

The code is implemented with PyTorch 0.4.0, CUDA 9.0, OpenCV 3.2.0, and Python 2.7.12. It was tested on Ubuntu 16.04 with a single 12GB NVIDIA Titan Xp; maximum GPU memory usage is about 11GB.

Results

  1. GTA5 → Cityscapes (GTA2city):

| Case | mIoU | Road | Sidewalk | Build | Wall | Fence | Pole | Traffic Light | Traffic Sign | Veg. | Terrain | Sky | Person | Rider | Car | Truck | Bus | Train | Motor | Bike |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Source | 33.35 | 71.71 | 18.53 | 68.02 | 17.37 | 10.15 | 36.63 | 27.63 | 6.27 | 78.66 | 21.80 | 67.69 | 58.28 | 20.72 | 59.26 | 16.43 | 12.45 | 7.93 | 21.21 | 12.96 |
| CBST | 46.47 | 89.91 | 53.84 | 79.73 | 30.29 | 19.21 | 40.23 | 32.28 | 22.26 | 84.11 | 29.96 | 75.52 | 61.93 | 28.54 | 82.57 | 25.89 | 33.76 | 19.29 | 33.62 | 40.00 |
| CRST-LRENT | 46.51 | 89.98 | 53.86 | 79.81 | 30.27 | 19.15 | 40.30 | 32.22 | 22.24 | 84.09 | 29.81 | 75.45 | 62.09 | 28.66 | 82.76 | 26.02 | 33.61 | 19.42 | 33.69 | 40.34 |
| CRST-MRKLD | 47.39 | 91.30 | 55.64 | 80.04 | 30.22 | 18.85 | 39.27 | 35.96 | 27.09 | 84.52 | 31.81 | 74.55 | 62.59 | 27.90 | 82.43 | 23.81 | 31.10 | 25.36 | 32.60 | 45.43 |

Setup

We assume you are working in the CRST-master folder.

  1. Datasets:
  • Download the GTA5 dataset. Since GTA5 contains images of different resolutions, resize all images to 1052x1914 (see the resizing sketch after this list).
  • Download Cityscapes.
  • Put the downloaded data in the "dataset" folder.
  2. Source pretrained models:
  • Download the source model trained on GTA5 and put it into the "src_model/gta5" folder.
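A minimal resizing sketch using OpenCV. The directory layout below is an assumption for illustration, not part of the original instructions; if you also resize label maps, use nearest-neighbor interpolation to keep class IDs intact.

import cv2
import glob
import os

# Hypothetical paths; adjust to wherever GTA5 was unpacked.
src_dir = "dataset/gta5/images"
out_dir = "dataset/gta5/images_resized"

if not os.path.isdir(out_dir):
    os.makedirs(out_dir)

for path in glob.glob(os.path.join(src_dir, "*.png")):
    image = cv2.imread(path)
    # cv2.resize takes (width, height), so 1052x1914 (H x W) becomes (1914, 1052).
    resized = cv2.resize(image, (1914, 1052))
    cv2.imwrite(os.path.join(out_dir, os.path.basename(path)), resized)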

Usage

  1. To run the self-training, you need to set the data paths of the source data (data-src-dir) and target data (data-tgt-dir) yourself. Other arguments can be kept at their defaults.

  2. Play with self-training for GTA2Cityscapes.

  • CBST:
sh cbst.sh
  • CRST-MRKLD:
sh mrkld.sh
  • CRST-LRENT:
sh lrent.sh
  • For CBST, set "--kc-policy cb --kc-value conf". You can keep them as default (a sketch of this selection policy appears after these steps).
  • Multi-scale testing is implemented in both the self-training code and the evaluation code. Set the scales with "--test-scale".
  • We use a small-class patch mining strategy to mine patches containing small classes. To turn off small-class mining, set "--mine-chance 0.0".

  3. Evaluation
  • Test on Cityscapes with a model compatible with GTA5 (the initial source-trained model, for example). Remember to set the data folder (--data-dir).
sh evaluate.sh

  4. Train in the source domain. Also remember to set the data folder (--data-dir).
  • Train on GTA5:
sh train.sh
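For intuition, here is a minimal numpy sketch of class-balanced pseudo-label selection in the spirit of "--kc-policy cb --kc-value conf". This is not the repository's exact implementation; the function name and the portion parameter are illustrative only.

import numpy as np

def class_balanced_thresholds(conf, pred, num_classes, portion):
    # conf: per-pixel max softmax confidence; pred: per-pixel argmax class.
    # Each class gets its own threshold, taken at the same confidence
    # quantile within that class's own predictions, so frequent classes
    # cannot crowd out rare ones in the pseudo-label set.
    thresholds = np.ones(num_classes, dtype=np.float32)
    for c in range(num_classes):
        cls_conf = np.sort(conf[pred == c])[::-1]  # descending confidences
        if cls_conf.size > 0:
            idx = min(int(cls_conf.size * portion), cls_conf.size - 1)
            thresholds[c] = cls_conf[idx]
    return thresholds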

Note

  • This code is based on DeepLab-ResNet-Pytorch.
  • The code is tested with PyTorch 0.4.0 and Python 2.7. We found that running the code with other PyTorch versions gives different results, so we suggest running it with exactly PyTorch 0.4.0. Other users of this code have reported different performance even on 0.4.1.

Related Works

Contact: [email protected]

crst's Issues

About Eq. 1 in the paper

Why do we use $p(k|x_t, w)/\lambda_k$ in the target domain instead of only $p(k|x_t, w)$? In my opinion, $\lambda_k$ is used to select the confidence threshold as in Eq. 4.
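For reference, the class-balanced self-training objective the question refers to has roughly this form (a reconstruction from the surrounding discussion, so the notation may differ slightly from the paper):

$$
\min_{w,\hat{Y}_T}\; -\sum_{s \in S}\sum_{k=1}^{K} y_{s,k}\log p(k|x_s;w)\;-\;\sum_{t \in T}\sum_{k=1}^{K} \hat{y}_{t,k}\log\frac{p(k|x_t;w)}{\lambda_k}
$$

Folding $\lambda_k$ into the loss means the optimal pseudo-label for $x_t$ picks $\arg\max_k p(k|x_t,w)/\lambda_k$ and keeps it only when that ratio exceeds 1, which is the thresholding behavior that Eq. 4 makes explicit.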

Questions about selecting 0 or $\hat{y}_t$ ?

Hi, thanks for your shared code.
But I cannot find the code that selects between $\hat{y}_t$ and 0 by checking which leads to a lower cost. Could you point to the corresponding position?
Thanks.

ResNet-38 for CRST

Hi, thank you for your great work and contribution.
I am in the middle of reproducing your work with the proposed backbones.
Is there a PyTorch based initial ResNet-38 model that is pre-trained on GTA and SYNTHIA to get the paper results?
Thank you again, and I hope I get a positive response.

Minor problem

weighted_prob = pred_prob/cls_thresh

Dividing by the thresholds first does have a very small chance of changing the subsequent argmax results.
E.g., a pixel over 2 classes has softmax output [0.89, 0.11], but the thresholds happen to be [0.9, 0.09]; then class 2 satisfies weighted_prob > 1 even though the pixel is not even predicted as class 2.
Maybe this is indeed the original intent of the paper, but I'll just point it out here.
I call this a minor problem because the example above is very unlikely to happen in actual training.
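A quick numpy check of the example above:

import numpy as np

pred_prob = np.array([0.89, 0.11])   # softmax output for one pixel
cls_thresh = np.array([0.90, 0.09])  # per-class thresholds

weighted_prob = pred_prob / cls_thresh
print(np.argmax(pred_prob))      # 0: the raw prediction is class 1
print(np.argmax(weighted_prob))  # 1: after division, class 2 wins
print(weighted_prob > 1)         # [False  True]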

Cityscapes and GTA dataset

Hi @yzou2,

Really nice work!
Just wanted to make sure that I am working with the right datasets. We need:
1) leftImg8bit_trainvaltest and gtFine_trainvaltest from Cityscapes, and
2) all 10 parts of the GTA5 dataset (24966 images)?

Regularizer weight in MRKLD

Hi,

In your paper, the regularizer weight mr_weight_kld is set to 0.1 for MRKLD. But I found that when calculating the KLD term, you multiply the log-softmax by another weight, reg_weights, which is also 0.1. So the overall weight for the regularization term is 0.1*0.1 = 0.01.

May I know what this reg_weights is?

kld = torch.sum( -logsoftmax_val/num_class*reg_weights )

reg_ce = ce/valid_num + (mr_weight_kld*kld)/valid_reg_num
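For context, here is a minimal sketch of how such a KLD regularizer typically enters the self-training loss. It mirrors the snippet above in spirit only; the function name and the ignore_index convention are assumptions, not the repository's exact code.

import torch
import torch.nn.functional as F

def mrkld_loss(logits, pseudo_labels, mr_weight_kld=0.1):
    # Cross-entropy on pseudo-labels plus the MRKLD term, which is
    # equivalent (up to a constant) to KL(uniform || p) and discourages
    # over-confident predictions.
    num_class = logits.size(1)
    log_probs = F.log_softmax(logits, dim=1)            # (N, C, H, W)
    ce = F.nll_loss(log_probs, pseudo_labels, ignore_index=255)
    kld = torch.mean(torch.sum(-log_probs, dim=1) / num_class)
    return ce + mr_weight_kld * kld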

Question about train_ClsConfSet.lst

Hello, very impressive work!
I have a question about the train_ClsConfSet.lst in CRST/dataset/list/cityscapes/train_ClsConfSet.lst.
In this training list there are only 505 lines, while the whole Cityscapes training list should have 2975 lines. Would you mind explaining how you got this list? Were the results in your paper produced with this subset? Have you ever tried using the whole set?
Thank you in advance~

Has anyone reproduced the results with PyTorch 0.4.0, Python 3.6.9, OpenCV 4.4?

Thanks for checking this issue.
I attempted to install Python 2.7 + PyTorch 0.4.0 on my machines, but failed every time...
Then I wonder:

  1. What is the difference between Python 3.6.9 + OpenCV 4.4 and Python 2.7 + OpenCV 3.2?
  2. Since Python 2.7 retired in 2020, it is necessary for CRST to be able to reproduce the results in Python 3.6.9. How can we make that work?

Resizing Cityscapes

Hi, the work is impressive. I have a question about resizing Cityscapes. Should I resize Cityscapes to 1052x1914, the same as for GTA5? According to the guide, you only tell us to resize GTA5.

VGG code for cbst

Thank you for the great work. Can you provide the code for VGG16? I want to reproduce the results in your paper. Thank you.

MemoryError

Hello, I am very interested in your project, but when I started CBST on gta2cityscapes something went wrong: it shows CUDA out of memory. And when I set batch_size = 1 and --mine-chance 0.0, another MemoryError happens, as follows:
Traceback (most recent call last):
File "crst_seg.py", line 954, in
main()
File "crst_seg.py", line 327, in main
label_2_id, valid_labels, args, logger)
File "crst_seg.py", line 464, in val
output = 0.5 * ( output + softmax2d(interp(output2)).cpu().data[0].numpy()[:,:,::-1] )
MemoryError
My CPU: Intel(R) Core(TM) i3-8100 CPU @ 3.60GHz
GPU: GeForce GTX 1080 Ti
CUDA 8, PyTorch 0.4.0
Python 2.7.16

Which mode for self-training?

if args.is_training:

Thanks for sharing! I would like to ask whether you use evaluation mode here intentionally. If yes, what is the reason? To my understanding, the condition is false and evaluation mode is used, since I haven't seen "--is-training" set in the shell files.
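For context, the distinction at stake is standard PyTorch behavior (a generic illustration, not the repository's code): train() and eval() change how BatchNorm and Dropout behave.

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Dropout2d(0.5))
model.train()  # BatchNorm uses batch statistics, Dropout is active
model.eval()   # BatchNorm uses running statistics, Dropout is a no-op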

Implementation details for Synthia

Hi,

I'm trying to get the SYNTHIA transfer to work, and I can't seem to do so. I followed the hyperparameter settings in #14 and used the source model in #24, but I can't seem to get close to the reported numbers. I'm currently getting about 4-5% less than reported.

Following #18, I'm also using the following resizing:

label = cv2.resize(label, (2048, 1024), interpolation=cv2.INTER_NEAREST)
image = cv2.resize(image, (2048, 1024))

I can't seem to exceed 38% mIoU (reported 42+). Do you know what might be missing?

Question about GTA5 dataset training list

Hi, thanks for your impressive work.

I have a question regarding the GTA5 training list. As far as I know, the GTA5 dataset has about 25k images, while the training list you provide here has only about 20k lines, which means that about 5k images are filtered out of training. Would you mind explaining how you got this list and why several thousand images are missing? Thanks in advance!

Problem reproducing CRST-MRKLD result

Hi,
I ran the mrkld.sh script with the right data paths and did not change anything else.
Something seems wrong with the training process: the mIoU goes down drastically in every round, all the way down to around 4%.

Is there anything else that needs to be done in order to reproduce the 47% result?

How many iterations of source-only training for the SYNTHIA dataset?

I'm trying to reproduce the SYNTHIA -> Cityscapes task of the CBST paper using ResNet-38; GTA5 -> Cityscapes seemed to work well.

But after about 30k iterations of training on the source (SYNTHIA), the mIoU at the stage below is only about 3.8%.

CRST/crst_seg.py

Lines 326 to 327 in f7e4df0

conf_dict, pred_cls_num, save_prob_path, save_pred_path = val(model, device, save_round_eval_path, round_idx, tgt_num,
label_2_id, valid_labels, args, logger)

Is this mIoU caused by insufficient source-only training?

About train_ClsConfSet.lst

In your code, train_ClsConfSet.lst contains only 505 images. Is there any reason you didn't use all 2975 images for pseudo-label generation?

Strange result?

When I run crst_seg.py with the source model using mrkld.sh, the IoU drops drastically after training with the pseudo-labels.

Hyper-parameter for SYNTHIA dataset

Hi, sorry for the successive questions, and thank you again for your work.

Is it possible for you to let me know the hyper-parameters for the SYNTHIA dataset used for the paper results, specifically "init_src_port" and "Input_size", which are commented "for GTA"?

Thank you and sorry for interrupting again.

Used Models

Hi,
After reading the paper I was expecting a ResNet-38 based implementation, as it yielded better results than DeepLabV2, but if I am not mistaken there is only a DeepLabV2 based implementation. Am I missing something?

Also, were the logged results found in this repository generated with DeepLabV2 training?

Thanks for this contribution.

Implementation detail request for classification task

Hi,

For the implementation details for Office-31 and VisDA-17, do you also follow these two steps?

  1. Pretraining on the source dataset
  2. Self-training on the target dataset

If so, could you please share how many epochs you train on the source dataset in step 1 for the two datasets, and how many epochs you train in step 2?

Best,
Chang

Reproducing the numbers given

Hi,
I'm trying to run your code to reproduce the results. However, I'm running into some issues. Using the same packages (PyTorch and Python versions), I'm getting the following results:

| Method | Weight | Result |
| --- | --- | --- |
| lr_ent | 0.25 | 44.59% |
| cbst | | 2.54% |
| mr_weight_kld | 0.1 | 2.86% |

I've used the defaults prescribed in the codebase itself.

Can you tell me what I might be doing wrong?

Problem with new-dataset label conversion

Hello, I trained a model on the SYNTHIA dataset with the labels noted in labels_synthia.py. But when I evaluate the model on Cityscapes with the labels from labels.py, the result is absolutely wrong. Do you know which step is wrong?
[image: frankfurt_000001_057181_leftImg8bit_color]
By the way, I changed datasets.py as follows to read the SYNTHIA labels, because cv2 couldn't read the label as a single channel.

import cv2
import numpy as np

SYNTHIA_label_map = {3: 0, 4: 1, 2: 2, 21: 3, 5: 4, 7: 5, 15: 6, 9: 7, 6: 8, 1: 9, 10: 10, 17: 11, 8: 12, 19: 13, 12: 14, 11: 15}
# image_size = (640, 360)

def get_label_set(input):
    reshape_list = list(np.reshape(input, (-1,)))
    label_set = set(reshape_list)
    return label_set

def read_SYNTHIA_label(label_path, kv_map):
    raw_label = cv2.imread(label_path, -1)  # read unchanged, keeping all channels
    raw_label_p = raw_label[:, :, -1]       # class IDs live in the last channel
    label = raw_label_p
    label_copy = 255 * np.ones(label.shape, dtype=np.float32)
    for k, v in kv_map.items():
        label_copy[label == k] = v  # unmapped IDs stay 255 (the ignore label)
    return label_copy

label = np.array(read_SYNTHIA_label(datafiles["label"], SYNTHIA_label_map), dtype=np.uint8)

MRL2 and MRENT

Hi @yzou2 ,

Will you be releasing the MRL2 and MRENT code for the baselines?
I would like to reproduce the experiments from your paper.

Thanks!

Implementation of Spatial Prior

Hi,
Thank you for your excellent work.

I'm currently trying to implement the spatial prior (SP) presented in your CBST paper (at the moment, only in the pretraining), since it's apparently not implemented in this PyTorch version.
After applying the SP, I expect the resulting output scores to be a few orders of magnitude smaller than without it, which results in gradients so small that training doesn't proceed well.

I've come up with several possible measures (a sketch of option 2 follows this list):

  1. Simply multiplying the resulting output by a large constant (like 10e+4) before calculating the cross-entropy loss.
  2. Normalizing the output values of a pixel over the classes (via softmax, or simply dividing by the sum).
  3. Increasing the learning rate.
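A minimal sketch of option 2 (renormalizing after applying the prior); the function name and tensor shapes are assumptions for illustration, not code from the repository:

import torch

def apply_spatial_prior(probs, sp, eps=1e-12):
    # probs: (N, C, H, W) softmax outputs; sp: (C, H, W) spatial prior.
    # Multiplying by the prior shrinks every score, so renormalize over
    # the class dimension to restore a proper distribution and keep
    # gradients at a reasonable scale.
    weighted = probs * sp.unsqueeze(0)
    return weighted / (weighted.sum(dim=1, keepdim=True) + eps)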

I looked it up on Google but couldn't find an exact answer to my question.
Could you give me some advice on it?
Sorry for the silly question.

Thanks for your work again, and I'm happy to hear from you.

How to train on our own dataset?

I want to use a synthetic dataset as the source dataset, and I changed the path and the training list.
File "/home/qiu/下载/CRST-master/deeplab/datasets.py", line 178, in __getitem__
img_h, img_w = label.shape
ValueError: too many values to unpack
So must the dataset be the same size as GTA5, or do some parameters need to change?

Loss explodes

When training with the generated pseudo-labels, the loss explodes.
I ran with the hyperparameters you provided for SYNTHIA (#14)
and set my environment to the required versions.
Can you give any advice on solving this problem?

ValueError: Unknown format code 'f' for object of type 'str'

Hello! I successfully ran the first round, but encountered this problem in the next round. What is the reason? Why can the first round succeed while the following ones fail? And what is the reason for multiple rounds? Thank you so much!

no mix domain

I wanted to compute the cross-entropy losses of the source domain and the target domain separately, then sum them up and backpropagate the total loss. But NaN always happens. What's the matter? Due to memory constraints, I set batch_size = 1.

Reproducing the VisDA-17 results

Hi,

I am interested in reproducing the results for the VisDA-17 benchmark. I couldn't find any instructions on how to do this in the repository.
