
An Auto-Encoder Strategy for Adaptive Image Segmentation

Abstract

Deep neural networks are powerful tools for biomedical image segmentation. These models are often trained with heavy supervision, relying on pairs of images and corresponding voxel-level labels. However, obtaining segmentations of anatomical regions on a large number of cases can be prohibitively expensive. Furthermore, models trained with heavy supervision are often sensitive to shifts in image characteristics, for instance, due to a routine upgrade in scanner software. Thus, there is a strong need for deep learning-based segmentation tools that do not require heavy supervision and can continuously adapt. In this paper, we propose a novel perspective on segmentation as a discrete representation learning problem, and present a variational autoencoder segmentation strategy that is flexible and adaptive. Our method, called Segmentation Auto-Encoder (SAE), leverages all available unlabeled scans and merely requires a segmentation prior, which can be a single unpaired segmentation image. In experiments, we apply SAE to brain MRI scans. Our results show that SAE can produce good quality segmentations, particularly when the prior is good. We demonstrate that a Markov Random Field prior can yield significantly better results than a spatially independent prior.

Requirements

The code was tested on:

  • python 3.6
  • pytorch 1.1
  • torchvision 0.3.0
  • scikit-image 0.15.0
  • scikit-learn 0.19.1
  • matplotlib 3.0.2
  • numpy 1.15.4
  • tqdm 4.38.0

Instructions

Preprocessing

T1-weighted 3D brain MRI scans were preprocessed using FreeSurfer [1]. This includes skull stripping, bias-field correction, intensity normalization, affine registration to Talairach space, and resampling to 1 mm³ isotropic resolution. The original images had a shape of 256×256×256 and were further cropped by ((48, 48), (31, 33), (3, 29)) voxels along the sides to eliminate empty space.
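The following is a minimal NumPy sketch for checking the crop. It assumes each pair gives (voxels removed at the start, voxels removed at the end) of the corresponding array axis, which matches the `pad_width` convention of `numpy.pad`. `crop_volume` and `CROP` are illustrative names and are not part of the repository. Under that reading, a 256×256×256 volume becomes 160×192×224.

```python
import numpy as np

# Crop widths from the text, read as (before, after) per axis (an assumption).
CROP = ((48, 48), (31, 33), (3, 29))

def crop_volume(vol, crop=CROP):
    """Drop `before` voxels at the start and `after` voxels at the end of each axis."""
    slices = tuple(slice(before, dim - after)
                   for (before, after), dim in zip(crop, vol.shape))
    return vol[slices]

vol = np.zeros((256, 256, 256), dtype=np.uint8)
print(crop_volume(vol).shape)  # (160, 192, 224)
```

Because the pairs follow the `pad_width` convention, `np.pad(cropped, CROP)` restores the 256×256×256 shape with a zero border. This can be convenient when mapping predictions back to the uncropped grid.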

Unfortunately, Buckner40 and its manual segmentations are not publicly available. However, we provide some example OASIS [2] volumes in ./data/vols/ and their automatic segmentations in ./data/labels/. The label legend can be found on the FreeSurfer website. The data folder also contains the probabilistic atlas that we used in our experiments. See functions/dataset.py for details on how we convert these numpy files into torch tensors.
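The sketch below shows one common way to turn an integer label volume into the channel-first, one-hot layout that segmentation code typically expects. It illustrates the general idea only and is not the code in functions/dataset.py. `one_hot` is a hypothetical helper, and `label_ids` stands for whichever FreeSurfer label values you choose to keep.

```python
import numpy as np

def one_hot(labels, label_ids):
    """Channel-first one-hot encoding of an integer label volume.

    labels:    integer array of any spatial shape, e.g. (X, Y, Z).
    label_ids: label values to keep, one output channel each (hypothetical choice).
    Voxels whose value is not listed get all-zero channels.
    """
    ids = np.asarray(label_ids).reshape((-1,) + (1,) * labels.ndim)
    # Broadcasting (C, 1, ..., 1) against (1, X, Y, Z) gives a (C, X, Y, Z) boolean array.
    return (labels[np.newaxis] == ids).astype(np.uint8)

labels = np.array([[0, 2], [3, 2]])
print(one_hot(labels, [0, 2, 3]).shape)  # (3, 2, 2)
```

Wrapping the result with `torch.from_numpy(...)` gives a `torch.uint8` tensor that shares memory with the NumPy array.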

Training

Once your MRI scans have been preprocessed, the next step is to obtain a good initialization for the encoder. This is accomplished by first mapping your training brain MRI scans to the probabilistic atlas. As an example, we provide our initialization in ./weights/pretrained_encoder.pth.tar. However, you should pretrain your own encoder on your dataset to obtain good results.
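The README does not spell out how the encoder is mapped to the atlas. One plausible reading, sketched below purely as an assumption, is to pretrain the encoder so that its per-voxel softmax matches the atlas probabilities. This is done by minimizing a soft-label cross-entropy. Only the loss is shown here. `log_softmax` and `atlas_cross_entropy` are illustrative NumPy helpers; in practice you would compute the same quantity in PyTorch so that it can be backpropagated.

```python
import numpy as np

def log_softmax(logits, axis=0):
    # Subtract the max before exponentiating for numerical stability.
    z = logits - logits.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

def atlas_cross_entropy(logits, atlas):
    """Mean over voxels of -sum_c atlas[c] * log softmax(logits)[c].

    logits: (C, X, Y, Z) encoder outputs before the softmax.
    atlas:  (C, X, Y, Z) atlas probabilities, summing to 1 over C.
    """
    return float(-(atlas * log_softmax(logits, axis=0)).sum(axis=0).mean())
```

This loss is smallest when the softmax output equals the atlas at every voxel, where it reduces to the atlas entropy. That is what "mapping to the atlas" would require under this reading.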

To train SAE, run python train.py 0 2 --compute_dice

Choose which GPUs to use through args.gpus. Note that the current model only supports --args.batch_size=1 because of memory and computational constraints.

One important parameter in the script is args.sigma. Setting args.sigma = 2 lets you estimate the variance σ² in the reconstruction term 1/(2σ²) ||x − x'||, as described in the paper. Setting args.sigma = 0 instead applies a fixed weight α to the reconstruction term, α||x − x'||.
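The sketch below illustrates the two reconstruction modes behind args.sigma. It assumes the usual Gaussian form with a squared norm, where the maximum-likelihood variance is simply the mean squared residual. The function name and the `"estimate"`/`"fixed"` mode strings are hypothetical stand-ins for args.sigma = 2 and args.sigma = 0. The repository may estimate σ differently.

```python
import numpy as np

def reconstruction_term(x, x_rec, mode, alpha=1.0):
    """Reconstruction loss in one of two hypothetical modes.

    "estimate": Gaussian negative log-likelihood up to a constant,
                1/(2 sigma^2) * ||x - x_rec||^2 + (n/2) * log(sigma^2),
                with sigma^2 set to its maximum-likelihood value (the MSE).
    "fixed":    alpha * ||x - x_rec||^2 with a user-chosen weight alpha.
    """
    sq_err = float(np.sum((x - x_rec) ** 2))
    n = x.size
    if mode == "estimate":
        sigma2 = max(sq_err / n, 1e-12)  # guard against a perfect reconstruction
        return sq_err / (2.0 * sigma2) + 0.5 * n * np.log(sigma2)
    if mode == "fixed":
        return alpha * sq_err
    raise ValueError(f"unknown mode: {mode!r}")
```

When σ² is estimated this way, the data term adapts to the image noise level. The fixed-α mode instead leaves the balance between reconstruction and the prior terms to you.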

The parameter --args.beta sets the weight of L_mrf, the Markov Random Field prior term. If this weight is nonzero, --args.k determines the size of the neighborhood used by the constraint.
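The exact form of L_mrf is given in the paper, not in this README. As a rough illustration only, the sketch below computes a Potts-style smoothness penalty. This is a swapped-in technique, not the authors' loss. For every pair of voxels that are 1 to k steps apart along one axis, it adds the probability that independent label draws at the two voxels disagree. Treating `k` as a maximum axis-aligned offset is an assumption about what "size of the neighborhood" means. `potts_disagreement` is a hypothetical name.

```python
import numpy as np

def potts_disagreement(probs, k=1):
    """Hypothetical Potts-style smoothness penalty (not the paper's L_mrf).

    probs: (C, *spatial) array of per-voxel label probabilities.
    For every voxel pair d = 1..k steps apart along one spatial axis, adds
    1 - sum_c p_i[c] * p_j[c], the chance that independent draws disagree.
    Returns the mean over all such pairs.
    """
    total, count = 0.0, 0
    for axis in range(1, probs.ndim):  # skip the channel axis
        n = probs.shape[axis]
        for d in range(1, min(k, n - 1) + 1):
            a = np.take(probs, np.arange(0, n - d), axis=axis)
            b = np.take(probs, np.arange(d, n), axis=axis)
            agree = (a * b).sum(axis=0)  # per-pair agreement probability
            total += float((1.0 - agree).sum())
            count += agree.size
    return total / count if count else 0.0
```

In this sketch, a larger k penalizes label changes over longer distances and so favors smoother segmentations. A weight such as β would then scale the term relative to the reconstruction loss.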

Finally, --compute_dice tracks the Dice score of your prediction against the ground-truth label. It is not used for training or model selection, but it can help you debug your code during training.
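For reference, the Dice score compares the predicted and ground-truth regions for each label. A minimal NumPy version is sketched below. The function name is illustrative. Scoring a label that is absent from both volumes as 1.0 is a convention chosen here, not necessarily the repository's.

```python
import numpy as np

def dice_per_label(pred, target, labels):
    """Dice = 2 |A and B| / (|A| + |B|) for each label value in `labels`."""
    scores = {}
    for lab in labels:
        a = pred == lab
        b = target == lab
        denom = int(a.sum() + b.sum())
        # Assumed convention: a label missing from both volumes counts as a perfect match.
        scores[lab] = 1.0 if denom == 0 else 2.0 * int(np.logical_and(a, b).sum()) / denom
    return scores
```

The Evaluation section reports the complementary Dice loss, 1 − Dice.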

Evaluation

Run python test.py 0. It will return a dictionary, which can be loaded using torch.load('result.pth.tar'). Its 'stat' entry has 13 columns. The first 12 columns contain the Dice loss (1 − Dice) for each of 12 regions of interest (ROIs), namely pallidum, amygdala, caudate, cerebral cortex, hippocampus, thalamus, putamen, cerebral WM, cerebellum cortex, lateral ventricle, cerebellum WM and brain-stem. The last column is the average over the ROIs.
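The sketch below turns the Dice losses into mean Dice scores per ROI. It assumes that `stat` is, or converts to, a 2D array with one row per test scan and the 13 columns in the order listed above. The row layout is an assumption. `summarize` and `ROIS` are illustrative names.

```python
import numpy as np

ROIS = ["pallidum", "amygdala", "caudate", "cerebral cortex", "hippocampus",
        "thalamus", "putamen", "cerebral wm", "cerebellum cortex",
        "lateral ventricle", "cerebellum wm", "brain-stem"]

def summarize(stat):
    """Mean Dice per ROI from an (n_scans, 13) array of Dice losses (1 - Dice)."""
    stat = np.asarray(stat, dtype=float)
    assert stat.shape[1] == len(ROIS) + 1, "expected 12 ROI columns + 1 average column"
    mean_dice = 1.0 - stat.mean(axis=0)
    summary = dict(zip(ROIS, mean_dice[:-1]))
    summary["average"] = mean_dice[-1]
    return summary

# With the real output (assuming torch is installed and 'stat' is a CPU tensor or array):
# stat = torch.load('result.pth.tar')['stat']
# print(summarize(np.asarray(stat)))
```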

Contact

Feel free to open an issue on GitHub for any problems or questions.

References

[1] Bruce Fischl. FreeSurfer. NeuroImage, 62(2):774–781, 2012.

[2] Marcus et al. Open Access Series of Imaging Studies (OASIS): cross-sectional MRI data in young, middle aged, nondemented, and demented older adults. Journal of Cognitive Neuroscience, 2007.


Contributors

evanmy, varunj


sae's Issues

'EnforcePriorBackward' object has no attribute 'prior'

Dear author,

Thank you for your great paper!

I am trying to run your code (python3 test.py 0), but I got the following error:

 Traceback (most recent call last):
   File "test.py", line 144, in <module>
     stat, file_names = run(loader= [dataset, test_idx],
   File "test.py", line 67, in run
     out = m.enforcer(prior, out) 
   File "/vinbrain/pqdung_train/sae/functions/models.py", line 107, in enforcer
     return EnforcePrior(prior)(x)
   File "/opt/conda/lib/python3.8/site-packages/torch/autograd/function.py", line 159, in __call__
     raise RuntimeError(
 RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method. (Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)

Then I tried to fix this bug by modifying your code in functions/models.py as below:

[Screenshot of the modified functions/models.py code]

Unfortunately, I got another error:

Traceback (most recent call last):
   File "test.py", line 144, in <module>
     stat, file_names = run(loader= [dataset, test_idx],
   File "test.py", line 67, in run
     out = m.enforcer(prior, out) 
   File "/vinbrain/pqdung_train/sae/functions/models.py", line 110, in enforcer
     return EnforcePrior(prior).apply(x)
   File "/vinbrain/pqdung_train/sae/functions/models.py", line 98, in forward
     self.forbidden_regs = (self.prior < self.eps).float()
 AttributeError: 'EnforcePriorBackward' object has no attribute 'prior'

Could you please help me to find out the problem?

Thank you in advance!

Training loss backpropagation error

I get the following error when I run python train.py 0 2 --beta 0:

 Traceback (most recent call last):
   File "train.py", line 312, in <module>
     compute_dice= args.compute_dice)
   File "train.py", line 153, in run
     loss.backward()
   File "/opt/conda/lib/python3.7/site-packages/torch/_tensor.py", line 255, in backward
     torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
   File "/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py", line 149, in backward
     allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
 RuntimeError: function EnforcePriorBackward returned an incorrect number of gradients (expected 2, got 1)
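The two EnforcePrior tracebacks above are consistent with `EnforcePrior` being written as a legacy autograd Function. In the new style, `forward` and `backward` are static methods that receive a context object `ctx` instead of `self`. The Function is then called as `Cls.apply(...)` rather than on an instance.

This explains the `AttributeError`. `self` inside `forward` is actually the `ctx` node (`EnforcePriorBackward`), which never received `prior` from `__init__`. It also explains "expected 2, got 1": `backward` must return one value per input to `forward`, using `None` for inputs that need no gradient.

The sketch below shows the pattern under an assumed behavior: zero `x` and its gradient wherever `prior < eps`. The class, the `enforcer` helper and that behavior are hypothetical, so carry over the real logic from functions/models.py.

```python
import torch

class EnforcePriorSketch(torch.autograd.Function):
    """New-style autograd Function: static forward/backward receiving ctx.

    Hypothetical behavior: keep x where prior >= eps, zero it elsewhere,
    and block the gradient in the same places. prior and eps get no gradient.
    """

    @staticmethod
    def forward(ctx, x, prior, eps):
        ctx.save_for_backward(prior)  # tensors go through save_for_backward
        ctx.eps = eps                 # plain Python values can be stored on ctx
        return x * (prior >= eps).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (prior,) = ctx.saved_tensors
        allowed = (prior >= ctx.eps).to(grad_output.dtype)
        # One return value per forward input (x, prior, eps).
        return grad_output * allowed, None, None


def enforcer(prior, x, eps=1e-6):
    # Call .apply on the class itself and pass prior as an argument.
    return EnforcePriorSketch.apply(x, prior, eps)
```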

How to train our own dataset

Thanks for this interesting repo.
This is not an issue but a request for advice: can we apply this model to other biomedical datasets for tissue and cell segmentation?

Thanks

'AssertionError: The prior should be one-hot encoded byte' during training

Dear author,

Thank you for your contribution! I've read your paper and I am trying to run the code !python train.py 0 2 --compute_dice on Colab. However, I got an error:

[Screenshot of the error message]

Then I tested the function argmax_ch, and it turned out that its output has data type torch.bool, not torch.uint8:

[Screenshot of the argmax_ch test output]

I'm a little confused. Could you please help me fix this problem?

Thank you in advance!
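The dtype mismatch reported in this issue is likely a PyTorch version effect. Since version 1.2, comparison operators return `torch.bool` tensors instead of `torch.uint8`, whereas the code was tested on PyTorch 1.1. Two possible workarounds are to cast the one-hot result explicitly, or to relax the assertion so that it also accepts `torch.bool`.

The sketch below shows the first option using a hypothetical stand-in for `argmax_ch` that one-hot encodes the argmax over the channel dimension. It is not the repository's implementation.

```python
import torch

def argmax_onehot_uint8(probs, dim=1):
    """One-hot encode the argmax of `probs` along `dim`, returned as torch.uint8."""
    idx = probs.argmax(dim=dim, keepdim=True)
    onehot = torch.zeros_like(probs, dtype=torch.uint8)
    onehot.scatter_(dim, idx, 1)  # write a 1 at the winning channel of each voxel
    return onehot

# If an existing helper returns torch.bool, an explicit cast also works:
# mask = mask.to(torch.uint8)
```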
