
semi-supervised-segmentation-cyclegan's Introduction

Revisiting CycleGAN for semi-supervised segmentation

This repo contains the official PyTorch implementation of the paper: Revisiting CycleGAN for semi-supervised segmentation.

Contents

  1. Summary of the Model
  2. Setup instructions and dependencies
  3. Repository Overview
  4. Running the model
  5. Some results of the paper
  6. Contact
  7. License

1. Summary of the Model

The following figure shows the training procedure for our proposed model.

We propose a training procedure for semi-supervised segmentation based on the principles of image-to-image translation with GANs. The proposed procedure has been evaluated on three segmentation datasets, namely VOC, Cityscapes and ACDC. Our semisupervised models achieve a 2-4% improvement in mean IoU over the supervised model trained on the same amount of labeled data. For further details regarding the model and the training procedure, please refer to the paper.

2. Setup instructions and dependencies

The code has been written in Python 3.6 and PyTorch v1.0 with Torchvision v0.3. You can install all the dependencies for the model by running the following command in a virtual environment:

pip install -r requirements.txt

To train or test the model, you must first download all 3 datasets. You can download the processed versions of the datasets here. To store the validation/testing results, checkpoints and tensorboard logs, the directory structure must be organized as follows:

.
├── arch
│   ├── ...
├── data                     # Place the datasets as shown below
│   ├── ACDC                 # Here the ACDC dataset must be placed
│   └── Cityscape            # Here the Cityscapes dataset must be placed
│   └── VOC2012              # Here the VOC train/val dataset must be placed
│   └── VOC2012test          # Here the VOC test dataset must be placed
├── data_utils
│   ├── ...
├── checkpoints              # Create this directory to store checkpoints   
├── examples                 
├── main.py                  
├── model.py
├── README.md
├── testing.py
├── utils.py
├── validation.py
├── results                  # Create this directory for storing the results
│   ├── supervised           # Directory for storing supervised results  
│   └── unsupervised         # Directory for storing semisupervised results
└── tensorboard_results      # Create this directory to store tensorboard log curves
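The folders marked "Create this directory" in the tree have to exist before training starts. A minimal stdlib sketch that creates them from the repository root; `create_output_dirs` is a hypothetical helper, not part of the repository, and the dataset folders under data/ are assumed to come from the downloaded archives.

```python
import os

# Output folders from the tree above that must be created by hand.
REQUIRED_DIRS = [
    "checkpoints",
    os.path.join("results", "supervised"),
    os.path.join("results", "unsupervised"),
    "tensorboard_results",
]


def create_output_dirs(root="."):
    """Create the checkpoint, result and log folders under `root`.

    Returns the full paths, whether they were just created or already existed.
    """
    paths = []
    for rel in REQUIRED_DIRS:
        path = os.path.join(root, rel)
        os.makedirs(path, exist_ok=True)  # no error if the folder already exists
        paths.append(path)
    return paths
```

Calling `create_output_dirs()` once from the repository root is enough; running it again is harmless because of `exist_ok=True`.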

3. Repository Overview

The following describes the important files and directories in the repository and their functions:

  • arch : The directory stores the architectures for the generators and discriminators used in our model
  • data_utils : The dataloaders and also helper functions for data processing
  • main.py : Stores the various hyperparameter information and default settings
  • model.py : Stores the training procedure for both supervised and semisupervised model, and also checkpointing and logging details
  • utils.py : Stores the helper functions required for training

4. Running the model

You can configure the various defaults specified in the main.py file. The supervision percentage of the dataset can be changed by modifying the dataloader call in the model.py file.
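The dataloader code is not shown in this README, so the following is only a hypothetical sketch of what a supervision percentage means: a fixed fraction of the training samples keeps its labels and the rest is treated as unlabeled. The function `split_by_supervision` and its parameters are illustrative assumptions, not the repository's API.

```python
import random


def split_by_supervision(sample_ids, percent, seed=0):
    """Split sample IDs into (labeled, unlabeled) lists.

    `percent` is the share of samples (0-100) that keep their labels;
    the remaining samples are used without labels.
    A fixed seed makes the split reproducible across runs.
    """
    if not 0 <= percent <= 100:
        raise ValueError("percent must be between 0 and 100")
    ids = list(sample_ids)
    random.Random(seed).shuffle(ids)  # shuffle a copy, not the caller's data
    n_labeled = round(len(ids) * percent / 100)
    return ids[:n_labeled], ids[n_labeled:]


labeled, unlabeled = split_by_supervision(range(100), percent=30)
print(len(labeled), len(unlabeled))  # 30 70
```

Keeping the seed fixed matters when comparing supervised and semisupervised runs, since both should see the same labeled subset.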

To train/validate/test our proposed semisupervised model:

python main.py --model 'semisupervised_cycleGAN' --dataset 'voc2012' --gpu_ids '0' --training True    

Validation and testing can be run with similar commands by replacing --training with --validation and --testing, respectively.
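For scripted runs, a minimal sketch that builds the three command lines from the flags shown above. Only flags that appear in this README are used; the helper name `build_command` is hypothetical, and passing the literal string `True` simply mirrors the README command rather than any documented parsing behaviour of main.py.

```python
import shlex

MODES = ("training", "validation", "testing")


def build_command(mode, model="semisupervised_cycleGAN", dataset="voc2012", gpu_ids="0"):
    """Return the argv list for one run of main.py in the given mode."""
    if mode not in MODES:
        raise ValueError("mode must be one of {}".format(MODES))
    return [
        "python", "main.py",
        "--model", model,
        "--dataset", dataset,
        "--gpu_ids", gpu_ids,
        "--" + mode, "True",  # e.g. --validation True, as in the README
    ]


for mode in MODES:
    # shlex.quote escapes each argument for a POSIX shell
    print(" ".join(shlex.quote(arg) for arg in build_command(mode)))
```

The argv list can also be passed directly to `subprocess.run(cmd, check=True)`, which avoids shell quoting altogether.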

5. Some results of the paper

Some of the results produced by our semisupervised model are as follows. For more results, see the main paper and its supplementary material.

The generated labels are the labels predicted by the semisupervised model, while the generated images are obtained by passing the labels through the label-to-image network.

6. Contact

If you have found our research work helpful, please consider citing the original paper.

If you have any questions regarding the codebase, you can open an issue or email [email protected] / [email protected]

7. License

This repository is licensed under the MIT license.

semi-supervised-segmentation-cyclegan's People

Contributors

aniket-agarwal1999, arnab39, gedomech, josedolz


semi-supervised-segmentation-cyclegan's Issues

network channel problem when use the ACDC dataset

Hi,
when I use the code and the ACDC dataset you provided and directly run the command python3 main.py --model 'semisupervised_cycleGAN' --dataset 'acdc' --gpu_ids '6,7' --training True --batch_size 2, I get this error:

Training semi-supervised cycleGAN
------------Number of Parameters---------------
[Network Gsi] Total number of parameters : 42.795 M
-----------------------------------------------
------------Number of Parameters---------------
[Network Gis] Total number of parameters : 42.724 M
[Network Gsi] Total number of parameters : 42.795 M
[Network Di] Total number of parameters : 2.765 M
[Network Ds] Total number of parameters : 2.766 M
-----------------------------------------------
 [*] No checkpoint!
Found 808 label images
Found 808 unlabel images
Found 286 val images
learning rate = 0.0002000
Traceback (most recent call last):
  File "main.py", line 87, in <module>
    main()
  File "main.py", line 73, in main
    model.train(args)
  File "/home/h/paper_code/code_test/Semi-supervised-segmentation-cycleGAN1/model.py", line 361, in train
    fake_gt = self.Gsi(unl_img.float())  ### having 21 channels
  File "/home/h/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/h/paper_code/code_test/Semi-supervised-segmentation-cycleGAN1/arch/generators.py", line 431, in forward
    x = self.conv1(x)
  File "/home/h/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/h/.local/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 345, in forward
    return self.conv2d_forward(input, self.weight)
  File "/home/h/.local/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 342, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size 64 3 7 7, expected input[2, 1, 256, 256] to have 3 channels, but got 1 channels instead

It seems that the network expects 3 input channels, but the ACDC images have only 1 channel. How should I modify the code? Thank you very much.
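One possible workaround, sketched under the assumption that the generator's first convolution should stay at 3 input channels: replicate the single ACDC channel three times before calling the network. In PyTorch this is `unl_img.repeat(1, 3, 1, 1)` on an N×1×H×W tensor; the pure-Python sketch below shows the same operation on nested lists so it runs without a GPU. The alternative is to build the generators with 1 input channel; which option matches the authors' setup is not stated in this thread, and other places that compare generated images with the 1-channel originals may need the same change.

```python
def repeat_channels(batch, times=3):
    """Replicate the single channel of each image `times` times.

    `batch` is a nested list shaped [N][1][H][W]; the result is [N][times][H][W].
    For a 1-channel PyTorch tensor this matches `batch.repeat(1, times, 1, 1)`.
    """
    out = []
    for image in batch:
        if len(image) != 1:
            raise ValueError("expected single-channel images")
        plane = image[0]
        # copy each row so the new channels do not share list objects
        out.append([[row[:] for row in plane] for _ in range(times)])
    return out


gray = [[[[0.1, 0.2], [0.3, 0.4]]]]  # shape [1][1][2][2]
rgb = repeat_channels(gray)
print(len(rgb[0]))  # 3
```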

Cityscapes

Training with the VOC2012 dataset works fine for me. However, when I switch to the Cityscapes dataset, the GPU runs out of memory quickly. This is the exact error message:

RuntimeError: CUDA out of memory. Tried to allocate 34.00 MiB (GPU 0; 23.65 GiB total capacity; 20.35 GiB already allocated; 93.06 MiB free; 22.01 GiB reserved in total by PyTorch)

I am training on a Titan RTX with CUDA 10.1, in a virtualenv with the requirements from the repository installed.

Does anyone know how to fix this issue?

Channel issue with ACDC dataset

I downloaded the provided ACDC dataset, and when I tried to run the code I got the following error (channel issue):

Traceback (most recent call last):
  File "main.py", line 87, in <module>
    main()
  File "main.py", line 73, in main
    model.train(args)
  File "/project/6037231/amitoj/cycleGAN/model.py", line 362, in train
    fake_gt = self.Gsi(unl_img.float()) ### having 21 channels
  File "/home/amitoj/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/project/6037231/amitoj/cycleGAN/arch/generators.py", line 431, in forward
    x = self.conv1(x)
  File "/home/amitoj/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/amitoj/.local/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 343, in forward
    return self.conv2d_forward(input, self.weight)
  File "/home/amitoj/.local/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 340, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size 64 3 7 7, expected input[5, 1, 256, 256] to have 3 channels, but got 1 channels instead

Is there a fix for this, or do I need to manually edit the code?

Why does 'Sample_from_Pool' work?

Hi,
In your project, 'Sample_from_Pool' is used during the training of D. What difference does it make compared with training D without 'Sample_from_Pool'? Could you give some explanation?
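The thread does not include the repository's `Sample_from_Pool` code. Assuming it follows the image buffer used in the original CycleGAN (the history-of-generated-images strategy of Shrivastava et al.), the idea is that the discriminator is updated with a mix of current and previously generated fakes, which is meant to reduce oscillation because D is not fitted only to the generator's latest output. A minimal stdlib sketch; the class name `SamplePool`, `pool_size=50` and the 0.5 swap probability follow common CycleGAN defaults and are assumptions here.

```python
import random


class SamplePool:
    """History buffer of generated samples used for discriminator updates."""

    def __init__(self, pool_size=50, seed=None):
        self.pool_size = pool_size
        self.samples = []
        self.rng = random.Random(seed)

    def query(self, batch):
        """Return a batch for D where each item is either new or taken from history."""
        if self.pool_size == 0:
            return list(batch)  # buffer disabled: D sees only the current fakes
        result = []
        for item in batch:
            if len(self.samples) < self.pool_size:
                # fill the buffer first; the current item is passed through
                self.samples.append(item)
                result.append(item)
            elif self.rng.random() < 0.5:
                # return an older fake and store the new one in its place
                idx = self.rng.randrange(self.pool_size)
                result.append(self.samples[idx])
                self.samples[idx] = item
            else:
                result.append(item)
        return result


pool = SamplePool(pool_size=2, seed=0)
print(pool.query(["f1", "f2"]))  # ['f1', 'f2'] while the buffer is filling
print(pool.query(["f3", "f4"]))  # a mix of older and newer fakes
```

Only the discriminator input goes through the pool; the generator loss still uses the freshly generated samples.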

ACDC dataset

Hello, I found that the ACDC training_gt dataset is corrupted after downloading. Could you please re-upload the dataset?

ckpt_for_Arnab_loss.ckpt

Dear maintainer,
I found that the file ckpt_for_Arnab_loss.ckpt is missing. Could you share it on GitHub?
Thank you very much!

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Hello, arnab39!
I have a problem.
When I run main.py, I get this error:

Traceback (most recent call last):
  File "main.py", line 88, in <module>
    main()
  File "main.py", line 73, in main
    model.train(args)
  File "/home/lix/Documents/code/Semi-supervised-segmentation-cycleGAN/model.py", line 443, in train
    gen_loss.backward()
  File "/home/lix/miniconda3/envs/ptseg/lib/python3.6/site-packages/torch/tensor.py", line 107, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/lix/miniconda3/envs/ptseg/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True) # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

How can I solve this problem?
I am looking forward to your reply!
Thanks!

Hello! The code cannot be trained

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "F:/SemicycleGAN0/Semi-supervised-segmentation-cycleGAN-master/main.py", line 87, in <module>
    main()
  File "F:/SemicycleGAN0/Semi-supervised-segmentation-cycleGAN-master/main.py", line 73, in main
    model.train(args)
  File "F:\SemicycleGAN0\Semi-supervised-segmentation-cycleGAN-master\model.py", line 439, in train
    fake_gt_discriminator = make_one_hot(fake_gt_discriminator, args.dataset, args.gpu_ids)
  File "F:\SemicycleGAN0\Semi-supervised-segmentation-cycleGAN-master\utils.py", line 347, in make_one_hot
    one_hot = torch.FloatTensor(labels.size(0), C, labels.size(2), labels.size(3)).zero_()
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
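The traceback shows that `make_one_hot` reads `labels.size(3)`, so it expects a 4-D label tensor of shape N×1×H×W, while the tensor it received has only three dimensions (valid indices -3 to 2). If the labels simply lack the channel axis, `labels.unsqueeze(1)` before the call would add it; whether that is the intended fix in this codebase is an assumption. The pure-Python sketch below is not the repository's `make_one_hot`; it only illustrates converting N×H×W class indices into N×C×H×W one-hot maps with an explicit shape check.

```python
def one_hot(labels, num_classes):
    """Convert [N][H][W] integer class maps into [N][C][H][W] one-hot maps."""
    encoded = []
    for label_map in labels:
        # Each row must hold plain integers; an extra channel axis fails here.
        if not all(isinstance(v, int) for row in label_map for v in row):
            raise ValueError("expected [H][W] maps of integer class indices")
        if any(not 0 <= v < num_classes for row in label_map for v in row):
            raise ValueError("class index out of range")
        # one plane per class: 1.0 where the pixel belongs to class c
        encoded.append([
            [[1.0 if v == c else 0.0 for v in row] for row in label_map]
            for c in range(num_classes)
        ])
    return encoded


labels = [[[0, 1], [2, 1]]]  # one 2x2 map with 3 classes
print(one_hot(labels, 3)[0][1])  # [[0.0, 1.0], [0.0, 1.0]]
```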

The output of x is NaN

Hello, arnab39!
I have a problem.

When I run main.py and use resnet101 as the generator, the output of
x = self.layer3(x) (in generators.py)
is NaN.

How can I solve this problem?
I am looking forward to your reply!
Thanks!

I modified your code to train on my crack dataset and encountered some problems

My dataset has n_channels = 2 (crack or no crack). Following the way your code trains on Cityscapes, I changed n_channels to 2, but the program throws an error when the one_hot tensor is fed to the model. Searching online suggests it is a cuDNN version problem, but I can run your original code successfully, so I hope to get your help. Thank you!

About the training

Hello,
I have a problem running the program:

RuntimeError: CUDNN_STATUS_NOT_INITIALIZED

Can you tell me how to solve it?

Results after training for 300 epochs not satisfying

After training for 300 epochs with the standard configuration on the Cityscapes dataset (except for the crop size, which I had to reduce to 256x256), the result is just noise with random pixel values. Has anyone encountered the same issue?


Validation error

I tried to run the validation file and got the following error:

FileNotFoundError: [Errno 2] No such file or directory: './val_results/unsupervised/generated_labels/patient069_frame01_1.png'

Could you please help me solve this error?
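The missing path lies inside `./val_results/unsupervised/generated_labels/`, which does not appear in the directory tree listed in the README, so one plausible cause is that this output folder was never created. A small sketch of creating the parent folder before saving; `ensure_parent_dir` is a hypothetical helper, not a function from the repository.

```python
import os


def ensure_parent_dir(file_path):
    """Create the directory that will contain `file_path`, if it is missing."""
    parent = os.path.dirname(file_path)
    if parent:  # a bare file name has no directory component to create
        os.makedirs(parent, exist_ok=True)
    return file_path
```

Wrapping the save path, for example `ensure_parent_dir('./val_results/unsupervised/generated_labels/patient069_frame01_1.png')`, before writing the image would avoid the error if the folder is the only thing missing.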

The optimization order of G and D

Hi, I am wondering why you optimize G first and then D, which differs from the traditional order. Is there any difference or advantage?
