
lwnet's Introduction

[Figure: the W-Net architecture]

The Little W-Net that Could

You have reached the official repository for our work on retinal vessel segmentation with minimalistic models. The figure above shows the W-Net architecture, which contains roughly 70k parameters and closely matches (or outperforms) more complicated techniques. For more details about our work, you can check the related paper:

The Little W-Net That Could: State-of-the-Art Retinal Vessel Segmentation with Minimalistic Models
Adrian Galdran, André Anjos, Jose Dolz, Hadi Chakor, Hervé Lombaert, Ismail Ben Ayed
https://arxiv.org/abs/2009.01907, Sep. 2020

We would appreciate it if you cited our work if it is useful for you :)

Note: If you are just looking for our results, you can directly download them at this link.

Please find below a table of contents describing what you can find in this repository:

Table of Contents

  1. Dependencies and getting the data ready
  2. Training a W-Net for vessel segmentation
  3. Generating segmentations
  4. Computing Performance
  5. Cross-Dataset Experiments
  6. Training with pseudo-labels and computing performance
  7. Evaluating your own model
  8. Training a W-Net for Artery/Vein segmentation
  9. Generating Artery/Vein segmentations
  10. Generating vessel and A/V segmentations on your own data

1. Dependencies and getting the data ready

First things first, clone this repo somewhere on your computer:

git clone https://github.com/agaldran/lwnet.git .

For full reproducibility, you should use the configuration specified in the requirements.txt file. If you are using conda, you can install the dependencies in one line; just run in a terminal:

conda create --name lwnet --file requirements.txt
conda activate lwnet

We have made an effort to automate the data download and preparation so that everything is as reproducible as possible. Out of the ten datasets we use in the paper, seven are public, and you can get them by simply running:

python get_public_data.py

This will populate the data directory with seven sub-folders. If everything goes right, each sub-folder in data is named after the corresponding dataset and contains at least:

  • Three folders called images, mask, manual
  • A csv file called test_all.csv

If the dataset is used in our work for training a vessel segmentation model (DRIVE, CHASE-DB, and HRF), you will also find:

  • Three csv files called train.csv, val.csv, test.csv

If the dataset also has Artery/Vein annotations, you will also see:

  • A folder called manual_av
  • A csv file called test_all_av.csv

If the dataset is used in our work for training an A/V model (DRIVE and HRF), you will also find:

  • Three csv files called train_av.csv, val_av.csv, test_av.csv

Note: The DRIVE dataset will also contain a folder called ZoneB_manual, which is used to evaluate A/V performance around the optic disc. The HRF dataset will also contain folders called images_resized, manual_resized, mask_resized. These are used only for training.

Note: The LES-AV dataset is still public, but it now needs to be downloaded manually; please see the comments in get_public_data.py (line 400 onward) for details.
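
For reference, the richest case (DRIVE) should end up looking roughly like this after the script finishes (an illustrative layout based on the description above; the contents of each folder are omitted):

data/DRIVE/
    images/
    mask/
    manual/
    manual_av/
    ZoneB_manual/
    train.csv, val.csv, test.csv, test_all.csv
    train_av.csv, val_av.csv, test_av.csv, test_all_av.csv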

2. Training a W-Net for vessel segmentation

Train a model on a given dataset; you also need to supply the path where the model will be saved. Note that training defaults to the CPU, which is feasible due to the small size of our models, although the examples below request a GPU. To reproduce our results in Table 2 of our paper, you need to run:

python train_cyclical.py --csv_train data/DRIVE/train.csv --cycle_lens 20/50
                         --model_name wnet --save_path wnet_drive --device cuda:0
python train_cyclical.py --csv_train data/CHASEDB/train.csv --cycle_lens 40/50
                         --model_name wnet --save_path wnet_chasedb --device cuda:0
python train_cyclical.py --csv_train data/HRF/train.csv --cycle_lens 30/50
                         --model_name wnet --save_path wnet_hrf_1024
                         --im_size 1024 --batch_size 2 --grad_acc_steps 1 --device cuda:0

This will store the model weights in experiments/wnet_drive, experiments/wnet_chasedb, and experiments/wnet_hrf_1024, respectively.

The parameter cycle_lens specifies the length of the training, and it is adjusted depending on the number of images in the training set. For instance, in the DRIVE case, --cycle_lens 20/50 means that we train for 20 cycles, each cycle running for 50 epochs (1,000 epochs in total). As CHASE-DB has fewer training images than DRIVE (8 vs. 16), we double the number of cycles in that case.

Note that we use a batch_size of 4 by default, and that we train on HRF with an image size of 1024x1024. In order to train on a single GPU, we use gradient accumulation in that case.
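
If you are unfamiliar with gradient accumulation, the sketch below illustrates the idea in PyTorch. This is not the actual loop in train_cyclical.py (and the exact semantics of --grad_acc_steps there may differ); it only shows the technique:

import torch

def train_one_epoch(model, loader, optimizer, criterion, grad_acc_steps=2):
    # With batch_size=2 and two accumulation steps, gradients from two
    # mini-batches are summed before a single optimizer step, emulating an
    # effective batch size of 4 at the memory cost of batch size 2.
    model.train()
    optimizer.zero_grad()
    for step, (images, targets) in enumerate(loader):
        loss = criterion(model(images), targets) / grad_acc_steps
        loss.backward()  # gradients accumulate in the .grad buffers
        if (step + 1) % grad_acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()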

3. Generating segmentations

Once the model is trained, you can produce the corresponding segmentations by calling generate_results.py and specifying which dataset should be used:

python generate_results.py --config_file experiments/wnet_drive/config.cfg
                           --dataset DRIVE --device cuda:0
python generate_results.py --config_file experiments/wnet_chasedb/config.cfg
                           --dataset CHASEDB --device cuda:0
python generate_results.py --config_file experiments/wnet_hrf_1024/config.cfg
                           --dataset HRF --im_size 1024 --device cuda:0

The above stores the predictions for those datasets in results/DRIVE/experiments/wnet_drive, results/CHASEDB/experiments/wnet_chasedb, and results/HRF/experiments/wnet_hrf_1024 respectively.

4. Computing Performance

We call analyze_results.py to compute performance. It is important to specify here which dataset was used for training and which one is the test set. For that, you pass the paths to the train/test predictions, and the names of the train/test datasets:

python analyze_results.py --path_train_preds results/DRIVE/experiments/wnet_drive
                          --path_test_preds results/DRIVE/experiments/wnet_drive
                          --train_dataset DRIVE --test_dataset DRIVE

python analyze_results.py --path_train_preds results/CHASEDB/experiments/wnet_chasedb
                          --path_test_preds results/CHASEDB/experiments/wnet_chasedb
                          --train_dataset CHASEDB --test_dataset CHASEDB

python analyze_results.py --path_train_preds results/HRF/experiments/wnet_hrf_1024
                          --path_test_preds results/HRF/experiments/wnet_hrf_1024
                          --train_dataset HRF --test_dataset HRF

The code uses the csv files in each dataset folder to check which images should be used for running an AUC analysis on the training set and for finding an optimal binarizing threshold, which is then applied to the test set images.
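
The idea behind the threshold search can be sketched as follows. This is only an illustration of the technique, not the exact protocol implemented in analyze_results.py:

import numpy as np

def dice(gt, pred_bin):
    # gt, pred_bin: binary 2-D arrays of the same shape
    inter = np.logical_and(gt, pred_bin).sum()
    return 2.0 * inter / (gt.sum() + pred_bin.sum() + 1e-8)

def find_best_threshold(train_probs, train_gts, thresholds=np.linspace(0.1, 0.9, 81)):
    # train_probs: probability maps in [0, 1]; train_gts: matching binary maps
    scores = [np.mean([dice(g, p >= t) for p, g in zip(train_probs, train_gts)])
              for t in thresholds]
    return float(thresholds[int(np.argmax(scores))])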

5. Cross-Dataset Experiments

When a model has been trained on dataset A (say, DRIVE) and we want to test it on dataset B (say, CHASE-DB), we first generate segmentations on both datasets:

python generate_results.py --config_file experiments/wnet_drive/config.cfg
                           --dataset DRIVE  --device cuda:0
python generate_results.py --config_file experiments/wnet_drive/config.cfg
                           --dataset CHASEDB  --device cuda:0

and then we compute performance:

python analyze_results.py --path_train_preds results/DRIVE/experiments/wnet_drive
                          --path_test_preds results/CHASEDB/experiments/wnet_drive
                          --train_dataset DRIVE --test_dataset CHASEDB

6. Training with pseudo-labels and computing performance

  1. Train a model on a source dataset (DRIVE); this will store the model in experiments/wnet_drive:
python train_cyclical.py --csv_train data/DRIVE/train.csv --cycle_lens 20/50
                         --model_name wnet --save_path wnet_drive
                         --device cuda:0
  2. Generate predictions on the target dataset (CHASEDB) with this model; this will store predictions at results/CHASEDB/experiments/wnet_drive:
python generate_results.py --config_file experiments/wnet_drive/config.cfg
                           --dataset CHASEDB --device cuda:0
  3. Train a model on the DRIVE manual segmentations plus the CHASEDB pseudo-segmentations for 10 cycles of one epoch each, with a lower learning rate, starting from the weights of the model trained on DRIVE (a conceptual sketch of pseudo-labeling follows this list). Note that in this case we use the AUC on the training set as the checkpointing criterion. This training is slower because of the AUC computation on a large set of images at the end of each cycle. In this case, we save the new model in a folder called wnet_drive_chasedb_pl:
python train_cyclical.py --save_path wnet_drive_chasedb_pl
                         --checkpoint_folder experiments/wnet_drive
                         --csv_test data/CHASEDB/test_all.csv
                         --path_test_preds results/CHASEDB/experiments/wnet_drive
                         --max_lr 0.0001 --cycle_lens 10/1 --metric tr_auc
                         --device cuda:0
  4. Generate predictions with this new model on the source dataset DRIVE:
python generate_results.py --config_file experiments/wnet_drive_chasedb_pl/config.cfg
                           --dataset DRIVE --device cuda:0
  5. Generate predictions on the target dataset CHASEDB:
python generate_results.py --config_file experiments/wnet_drive_chasedb_pl/config.cfg
                           --dataset CHASEDB --device cuda:0
  6. Analyze the results; we use the DRIVE predictions to find the optimal thresholding value:
python analyze_results.py --path_train_preds results/DRIVE/experiments/wnet_drive
                          --path_test_preds results/CHASEDB/experiments/wnet_drive_chasedb_pl
                          --train_dataset DRIVE --test_dataset CHASEDB
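
For context, pseudo-labeling treats the source model's predictions on target images as if they were ground truth during fine-tuning. A minimal sketch of the idea follows; the actual pipeline is wired up inside train_cyclical.py through --csv_test and --path_test_preds, and binarizing at 0.5 is just one common choice, so the repository may use the predictions differently:

import numpy as np

def make_pseudo_label(prob_map, thresh=0.5):
    # Binarize a probability map predicted by the source model (DRIVE) on an
    # unlabeled target image (CHASEDB); the result is then mixed with the real
    # source labels as a training target.
    return (prob_map >= thresh).astype(np.uint8)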

7. Evaluating your own model

We have also made an effort to make our evaluation protocol easy to use. You just need to build probabilistic segmentations with your own segmentation system and store the training/test predictions in folders called train_preds and test_preds.

Be careful: you need to produce segmentations for the test dataset, and also for the training dataset, which we use to find an optimal threshold. Then you can call our code to compute performance. If you used dataset_A for training and you want to test on dataset_B, you would run:

python analyze_results.py --path_train_preds train_preds --path_test_preds test_preds
                          --train_dataset dataset_A --test_dataset dataset_B

Be very careful to use the same train/test splits as we use here (check the csvs in the corresponding dataset folder), or you might be testing on training data. Also, each prediction must have exactly the same name as the corresponding retinal image, but with a .png extension (otherwise the code will not find it).
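
As an illustration, a hypothetical helper like the one below stores a probability map under the required name (save_prediction is not part of this repository):

import os

import numpy as np
from PIL import Image

def save_prediction(prob_map, image_path, out_dir="test_preds"):
    # prob_map: 2-D float array in [0, 1], saved as an 8-bit grayscale png
    # named after the source image, which is what analyze_results.py expects.
    os.makedirs(out_dir, exist_ok=True)
    name = os.path.splitext(os.path.basename(image_path))[0] + ".png"
    Image.fromarray((255 * prob_map).astype(np.uint8)).save(os.path.join(out_dir, name))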

8. Training a W-Net for Artery/Vein segmentation

In our work we train A/V models on DRIVE and HRF, and we use a larger W-Net for this task. Again, HRF is trained at an image size of 1024x1024:

python train_cyclical.py --csv_train data/DRIVE/train_av.csv --model_name big_wnet
                         --cycle_lens 40/50 --do_not_save False --save_path big_wnet_drive_av
                         --device cuda:0
python train_cyclical.py --csv_train data/HRF/train_av.csv --model_name big_wnet
                         --cycle_lens 40/50 --do_not_save False --save_path big_wnet_hrf_av_1024
                         --im_size 1024 --batch_size 2 --grad_acc_steps 1 --device cuda:0

9. Generating Artery/Vein segmentations

This is similar to the vessel segmentation case, but calling generate_av_results.py instead:

python generate_av_results.py --config_file experiments/big_wnet_drive_av/config.cfg
                              --dataset DRIVE --device cuda:0
python generate_av_results.py --config_file experiments/big_wnet_drive_av/config.cfg
                              --dataset LES_AV --device cuda:0

and remember to set the image size for generating HRF segmentations:

python generate_av_results.py --config_file experiments/big_wnet_hrf_av_1024/config.cfg
                              --dataset HRF --im_size 1024 --device cuda:0

10. Generating vessel and A/V segmentations on your own data

To make it easy to produce segmentations on new data, we have also made pretrained weights available in the experiments/ folder, as well as a script you can call on your own images:

python predict_one_image.py --model_path experiments/wnet_drive/
                            --im_path folder/my_image.jpg
                            --result_path my_results/
                            --mask_path folder/my_mask.jpg
                            --device cuda:0
                            --bin_thresh 0.42

The script uses a model trained on DRIVE by default; you can change it to use a model trained on HRF (larger resolution, but slower, see below). You can optionally pass the path to a FOV mask (if you do not, the code builds one for you), the device used for the forward pass of the network (defaults to CPU), and the binarizing threshold (by default set to the optimal one on the DRIVE training set, 0.42). If, for instance, you want to use a model trained on HRF, you will want to change the image size and the threshold as follows:

python predict_one_image.py --model_path experiments/wnet_hrf_1024/
                            --im_path folder/my_image.jpg
                            --result_path my_results/
                            --device cuda:0
                            --im_size 1024
                            --bin_thresh 0.3725
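
As an aside, the automatic FOV mask mentioned above can be approximated along these lines. This is only a rough sketch: the repository's get_fov relies on scikit-image's threshold_minimum (see the issue further below), but its exact preprocessing may differ:

import numpy as np
from skimage.filters import threshold_minimum

def fov_mask(rgb_image):
    # rgb_image: (H, W, 3) uint8 array. Threshold the overall brightness to
    # separate the circular field of view from the dark background.
    gray = rgb_image.mean(axis=2).astype(np.uint8)
    return gray > threshold_minimum(gray)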

If you are interested in generating A/V segmentations, you can use a second script called predict_one_image_av.py. The usage is very similar (in this case on the CPU, with automatic mask computation):

python predict_one_image_av.py --model_path experiments/big_wnet_drive_av/
                               --im_path folder/my_image.jpg
                               --result_path my_results/

Note that there is no need to supply a threshold in this case, since we take the argmax of the probabilities to generate hard segmentations.
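
Concretely, the hard segmentation can be obtained along these lines (a sketch; the channel order named in the comment is illustrative and may not match the repository's):

import numpy as np

def hard_av_segmentation(probs):
    # probs: (4, H, W) array of per-class probabilities, e.g. background,
    # artery, uncertain, vein. The per-pixel argmax yields a map of class
    # indices, so no binarizing threshold is needed.
    return np.argmax(probs, axis=0)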

If instead you want to use a model trained on HRF (also provided) at a larger resolution of 1024x1024, you would run:

python predict_one_image_av.py --model_path experiments/big_wnet_hrf_av_1024/ 
                               --im_path folder/my_image.jpg 
                               --result_path my_results/
                               --im_size 1024

Using this model should result in finer artery and vein delineations (although not necessarily higher accuracy), which may be desirable if your data is of higher resolution than DRIVE.


lwnet's Issues

raise RuntimeError('Unable to find two maxima in histogram')

Hi!

First of all I wanted to say that your new software runs smoothly and accurately on independent datasets, so cheers and thanks for the awesome work!

For some images, however, I am getting the following error for some reason (independent of which pre-trained dataset I use):

* FOV mask not provided, generating it
Traceback (most recent call last):
  File "predict_one_image_av.py", line 210, in <module>
    mask = get_fov(img)
  File "predict_one_image_av.py", line 83, in get_fov
    thresh = threshold_minimum(im_v)
  File ".../miniconda3/envs/lwnet/lib/python3.7/site-packages/skimage/filters/thresholding.py", line 721, in threshold_minimum
    raise RuntimeError('Unable to find two maxima in histogram')
RuntimeError: Unable to find two maxima in histogram

I wonder what this could be, given that the image in question looks very good (I sent it to you by mail).

Have a nice evening
Michael

a few questions about label/target

1. In get_loaders.py, line 44:

target = self.label_encoding(target)
target = np.array(self.label_encoding(target))

Why is label_encoding run twice?

2. I built my own dataset like manual_av in DRIVE, but why do my manual_av.png files have gray levels of 0, 29, 76, 150 instead of 0, 85, 170, 255 after using convert('L')? So I have to set label_values=[0, 150, 76, 29] instead of label_values=[0, 85, 170, 255].

THANK YOU!!!

why does it consume so much time in cycle 1 epoch 0

For example:
Cycle 1/40
0%| | 0/50 [13:15<?, ?it/
100%|████████████████████████████████████████████████████| 50/50 [13:46<00:00, 16.52s/it, tr_loss_lr=nan/0.000002]
------------------------- End of cycle, evaluating -------------------------
Train/Val Loss: nan/nan -- Train/Val AUC: 0.4871/0.5019 -- Train/Val DICE: 0.0000/0.0000 -- LR=0.01

but after that it seems to be normal (though I have no idea why the loss and DICE are abnormal and the AUC is so low):
Cycle 2/40
100%|██████████████████████████████████████████████████████████████| 50/50 [00:30<00:00, 1.62it/s, tr_loss_lr=nan/0.000002]
------------------------- End of cycle, evaluating -------------------------
Train/Val Loss: nan/nan -- Train/Val AUC: 0.5153/0.5019 -- Train/Val DICE: 0.0000/0.0000 -- LR=0.01
Cycle 24/40
100%|██████████████████████████████████████████████████████████████| 50/50 [00:30<00:00, 1.62it/s, tr_loss_lr=nan/0.000002]
------------------------- End of cycle, evaluating -------------------------
Train/Val Loss: nan/nan -- Train/Val AUC: 0.5067/0.5019 -- Train/Val DICE: 0.0000/0.0000 -- LR=0.01
Cycle 40/40
100%|██████████████████████████████████████████████████████████████| 50/50 [00:30<00:00, 1.62it/s, tr_loss_lr=nan/0.000002]
------------------------- End of cycle, evaluating -------------------------
Train/Val Loss: nan/nan -- Train/Val AUC: 0.5225/0.5019 -- Train/Val DICE: 0.0000/0.0000 -- LR=0.01

AttributeError: module 'torchvision.transforms' has no attribute 'InterpolationMode'

When running the scripts, I get an error:

Traceback (most recent call last):
  File "predict_one_image.py", line 8, in <module>
    from utils import paired_transforms_tv04 as p_tr
  File "/home/jovyan/work/retina-segment/utils/paired_transforms_tv04.py", line 37, in <module>
    T.InterpolationMode.NEAREST: 'PIL.Image.NEAREST',
AttributeError: module 'torchvision.transforms' has no attribute 'InterpolationMode'

Looking at the history, it seems that these lines were changed a while ago. Did that break something, or should some of the environment libraries be pinned to different versions?

What does the cycle stand for

Dear Author:
I have looked at the code several times and I don't understand what a cycle means. I debugged the code and found that there was only one image in the training stage (function train_one_epoch) instead of batch_size: 4.
I know my question is easy; however, I still can't understand it after debugging the code. Could you please answer me?
Thank you!!!

AV segmentation training

Hi Adrian,
First of all, your paper and code are great!
I'm using your code to train an A/V segmentation model. You convert the GT segmentation to 1 channel (in the label encoding function), but the U-Nets' predictions have 4 channels. Is this on purpose? If not, should I convert the GT segmentation to a 4-channel array as well? (And which channel should I assign to each category: vessel, artery, uncertain, background?)

Thanks,
Shvat
