imgx-diffseg's Introduction

Importance of Aligning Training Strategy with Evaluation for Diffusion Models in 3D Multiclass Segmentation

🎉 This work has been accepted at the Deep Generative Models workshop at MICCAI 2023.

📑 An updated manuscript has also been uploaded to arXiv.

🔎 We are working on a follow-up; stay tuned.

figure2

Reproduction

Install the environment and build the dataset following the documentation. Then run one of the following sets of commands.

# 3D diffusion for Prostate MR
imgx_train --config-name config_pelvic_diffusion model.name=unet3d_time
imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5

# 2D diffusion for Prostate MR
imgx_train --config-name config_pelvic_diffusion model.name=unet3d_slice_time
imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5

# 3D diffusion for Abdominal CT
imgx_train --config-name config_amos_diffusion model.name=unet3d_time
imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5

# 2D diffusion for Abdominal CT
imgx_train --config-name config_amos_diffusion model.name=unet3d_slice_time
imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5

# 3D Non-diffusion for Prostate MR
imgx_train --config-name config_pelvic_segmentation model.name=unet3d
imgx_valid --log_dir wandb/latest-run/
imgx_test --log_dir wandb/latest-run/

# 2D Non-diffusion for Prostate MR
imgx_train --config-name config_pelvic_segmentation model.name=unet3d_slice
imgx_valid --log_dir wandb/latest-run/
imgx_test --log_dir wandb/latest-run/

# 3D Non-diffusion for Abdominal CT
imgx_train --config-name config_amos_segmentation model.name=unet3d
imgx_valid --log_dir wandb/latest-run/
imgx_test --log_dir wandb/latest-run/

# 2D Non-diffusion for Abdominal CT
imgx_train --config-name config_amos_segmentation model.name=unet3d_slice
imgx_valid --log_dir wandb/latest-run/
imgx_test --log_dir wandb/latest-run/

The ablation studies can be performed by adding one of the following flags:

  • Predict noise instead of mask: task.diffusion.model_out_type=epsilon
  • Do not use Dice loss: loss.dice=0
  • Do not recycle: task.diffusion.recycle=False
  • Change training denoising steps: task.diffusion.num_timesteps=1000

For instance, to deactivate Dice loss for 3D diffusion on the Prostate MR data set:

imgx_train --config-name config_pelvic_diffusion model.name=unet3d_time loss.dice=0
imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5
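
Each ablation is the base training command with one extra override appended. The following is a minimal Python sketch of that composition, not part of ImgX: the helper `train_command` and the short keys in `ABLATIONS` are hypothetical names chosen for illustration, while the override strings themselves are copied from the list above. It only builds the command string and does not launch training.

```python
import shlex
from typing import Optional

# Override strings copied from the ablation list above.
# The dictionary keys are hypothetical labels, not ImgX options.
ABLATIONS = {
    "predict_noise": "task.diffusion.model_out_type=epsilon",
    "no_dice": "loss.dice=0",
    "no_recycle": "task.diffusion.recycle=False",
    "train_steps_1000": "task.diffusion.num_timesteps=1000",
}


def train_command(
    config_name: str, model_name: str, ablation: Optional[str] = None
) -> str:
    """Build an imgx_train command line, optionally with one ablation override."""
    args = ["imgx_train", "--config-name", config_name, f"model.name={model_name}"]
    if ablation is not None:
        args.append(ABLATIONS[ablation])
    # shlex.join quotes arguments only when the shell would need it.
    return shlex.join(args)


# Print the command that deactivates Dice loss for 3D diffusion on Prostate MR.
print(train_command("config_pelvic_diffusion", "unet3d_time", "no_dice"))
```

The printed line can be pasted into a shell; validation and testing then follow with `imgx_valid` and `imgx_test` as shown above.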

The configurations are under imgx/conf/. For example, with 2 GPUs and one image per GPU, set the batch size to 2:

batch_size: 2 # number of devices * batch size per device (GPU/TPU)
batch_size_per_replica: 1 # batch size per device (GPU/TPU)
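
The relationship between the two values follows from the comments in the config: the global batch size is the number of devices times the per-device batch size. A small Python sketch of that arithmetic is given here; `global_batch_size` is a hypothetical helper, not an ImgX function.

```python
def global_batch_size(num_devices: int, batch_size_per_replica: int) -> int:
    """Return the batch_size value for a given device count and per-device batch."""
    if num_devices < 1 or batch_size_per_replica < 1:
        raise ValueError("both arguments must be positive integers")
    return num_devices * batch_size_per_replica


# Two GPUs with one image each, as in the example above.
print(f"batch_size: {global_batch_size(2, 1)}")
print("batch_size_per_replica: 1")
```

For a different setup, only the two arguments change; for instance, eight devices with one image each give a global batch size of 8.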

Environment Setup

TPU with Docker

The following instructions have been tested only on TPU-v3-8. The Docker container runs as the root user.

  1. Build the docker image inside the repository.

    sudo docker build --build-arg USER_ID=$(id -u) --build-arg GROUP_ID=$(id -g) -f docker/Dockerfile.tpu -t imgx .

    where

    • --build-arg provides build-time argument values.
    • -f specifies the Dockerfile.
    • -t tags the Docker image.
  2. Run the Docker container.

    mkdir -p $(cd ../ && pwd)/tensorflow_datasets
    sudo docker run -it --rm --privileged --network host \
    -v "$(pwd)":/app/ImgX \
    -v "$(cd ../ && pwd)"/tensorflow_datasets:/root/tensorflow_datasets \
    imgx bash
  3. Install the package inside the container.

    pip install -e .

GPU with Docker

The following instructions have been tested only with CUDA == 11.4.1 and CUDNN == 8.2.0. The Docker container runs as a non-root user. The Docker image used may be removed in the future.

  1. Build the docker image inside the repository.

    docker build --build-arg HOST_UID=$(id -u) --build-arg HOST_GID=$(id -g) -f docker/Dockerfile -t imgx .

    where

    • --build-arg provides build-time argument values.
    • -f specifies the Dockerfile.
    • -t tags the Docker image.
  2. Run the Docker container.

    mkdir -p $(cd ../ && pwd)/tensorflow_datasets
    docker run -it --rm --gpus all \
    -v "$(pwd)":/app/ImgX \
    -v "$(cd ../ && pwd)"/tensorflow_datasets:/home/app/tensorflow_datasets \
    imgx bash

    where

    • --rm removes the container once it exits.
    • -v mounts the ImgX folder into the container.
  3. Install the package inside the container.

    pip install -e .

Local with Conda

Install Conda for Mac M1

Download Miniforge from GitHub and install it.

conda install -y -n base conda-libmamba-solver
conda config --set solver libmamba
conda env update -f docker/environment_mac_m1.yml

Install Conda for Linux / Mac Intel

Install Conda and then create the environment.

conda install -y -n base conda-libmamba-solver
conda config --set solver libmamba
conda env update -f docker/environment.yml

Activate Conda Environment

Activate the environment and install the package.

conda activate imgx
pip install -e .

Data Processing

Male Pelvic MR

The data set is generated and processed by TFDS. It is downloaded automatically from Zenodo to the ~/tensorflow_datasets folder.

tfds build imgx/datasets/male_pelvic_mr

Optionally, add the --overwrite flag to overwrite the generated data set.

tfds build imgx/datasets/male_pelvic_mr --overwrite

AMOS CT

The data set is generated and processed by TFDS. It is downloaded automatically from Zenodo to the ~/tensorflow_datasets folder.

tfds build imgx/datasets/amos_ct

Optionally, add the --overwrite flag to overwrite the generated data set.

tfds build imgx/datasets/amos_ct --overwrite

Experiment

Training and Testing

Example commands for training on two GPUs. Each imgx_train line is a separate experiment.

export CUDA_VISIBLE_DEVICES="0,1"
imgx_train --config-name config_pelvic_segmentation
imgx_train --config-name config_pelvic_diffusion
imgx_train --config-name config_amos_segmentation
imgx_train --config-name config_amos_diffusion
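
`CUDA_VISIBLE_DEVICES` is a comma-separated list of device identifiers, so `"0,1"` exposes two GPUs to the process. The sketch below counts the entries in such a value, which is handy when checking that the batch size in the config matches the number of devices. The helper `count_visible_gpus` is hypothetical and only parses the string; it does not query the driver, so it will not notice identifiers that do not exist on the machine.

```python
import os
from typing import Optional


def count_visible_gpus(value: Optional[str]) -> Optional[int]:
    """Count the entries in a CUDA_VISIBLE_DEVICES value.

    Returns None when the variable is unset, in which case every GPU
    on the machine is visible and the count is not known from the string.
    """
    if value is None:
        return None
    entries = [entry.strip() for entry in value.split(",") if entry.strip()]
    return len(entries)


# The value used in the example above exposes two devices.
print(count_visible_gpus("0,1"))
# Inspect the current environment; prints None if the variable is unset.
print(count_visible_gpus(os.environ.get("CUDA_VISIBLE_DEVICES")))
```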

After training, evaluate the trained models on the test data using the checkpoint having the best validation performance.

  1. For non-diffusion models:

    imgx_valid --log_dir wandb/latest-run/
    imgx_test --log_dir wandb/latest-run/
  2. For diffusion models, set --num_seeds when using an ensemble:

    imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5
    imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --num_seeds 5
    imgx_test_ensemble --log_dir wandb/latest-run/

The metrics are stored under wandb/latest-run/files/test_evaluation/
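
To see what an evaluation run produced, the output folder can simply be listed. The sketch below does this with the standard library; `list_evaluation_files` is a hypothetical helper, and because the file names and formats inside `test_evaluation/` are not described here, it makes no assumption about them and only returns the paths it finds.

```python
from pathlib import Path
from typing import List, Union


def list_evaluation_files(log_dir: Union[str, Path] = "wandb/latest-run") -> List[Path]:
    """Return all files under <log_dir>/files/test_evaluation/, sorted by path."""
    eval_dir = Path(log_dir) / "files" / "test_evaluation"
    if not eval_dir.is_dir():
        # Nothing has been evaluated yet, or the log directory is elsewhere.
        return []
    return sorted(path for path in eval_dir.rglob("*") if path.is_file())


files = list_evaluation_files()
if not files:
    print("No evaluation files found; run imgx_test first.")
for path in files:
    print(path)
```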

Code Quality

Pre-commit

Install pre-commit hooks:

pre-commit install
wily build .

Update hooks, and re-verify all files.

pre-commit autoupdate
pre-commit run --all-files

Code Test

Run the command below to run the tests and get a coverage report. Because the JAX tests require two CPUs each and -n 4 runs 4 workers in parallel, 8 CPUs are required in total.

pytest --cov=imgx -n 4 tests
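
The CPU requirement is the number of parallel workers times the CPUs each JAX test needs. A short Python sketch for checking whether the current machine can support a given `-n` value follows; `required_cpus` is a hypothetical helper, and the default of two CPUs per worker is taken from the note above.

```python
import os


def required_cpus(num_workers: int, cpus_per_worker: int = 2) -> int:
    """Return the total CPUs needed for `pytest -n num_workers`."""
    return num_workers * cpus_per_worker


needed = required_cpus(4)
available = os.cpu_count()  # may be None if the count cannot be determined
print(f"needed: {needed}, available: {available}")
if available is not None and available < needed:
    # Suggest the largest worker count this machine can support.
    print(f"Consider lowering -n to {max(1, available // 2)}")
```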

Acknowledgement

This work was supported by the EPSRC grant (EP/T029404/1), the Wellcome/EPSRC Centre for Interventional and Surgical Sciences (203145Z/16/Z), the International Alliance for Cancer Early Detection, an alliance between Cancer Research UK (C28070/A30912, C73666/A31378), Canary Center at Stanford University, the University of Cambridge, OHSU Knight Cancer Institute, University College London and the University of Manchester, and Cloud TPUs from Google's TPU Research Cloud (TRC).

imgx-diffseg's People

Contributors

mathpluscode, yipenghu
