diff-unet's Introduction

Diff-UNet

Diff-UNet: A Diffusion Embedded Network for Volumetric Segmentation. Submitted to MICCAI2023.

https://arxiv.org/pdf/2303.10326.pdf

We design Diff-UNet, which applies a diffusion model to the 3D medical image segmentation problem.

Diff-UNet achieves higher accuracy than other 3D segmentation methods on multiple segmentation tasks.

Dataset

We release code that supports training and testing on two datasets, BraTS2020 and BTCV.

BraTS2020 (4 modalities and 3 segmentation targets): https://www.med.upenn.edu/cbica/brats2020/data.html

BTCV (1 modality and 13 segmentation targets): https://www.synapse.org/#!Synapse:syn3193805/wiki/217789

Once the data is downloaded, you can begin training. See the BraTS2020 and BTCV directories for details.

diff-unet's People

Contributors

ge-xing

diff-unet's Issues

The reason why the author's results cannot be reproduced

Hello author,

I suspect that many people in the issues cannot reproduce your results because your code uses os.listdir to obtain the directories and then splits them into training and test sets. The order of the entries returned by os.listdir can differ from host to host, so the split we end up with is not the same as yours. I printed the directory order returned by os.listdir on two hosts and found that they were inconsistent.
image

Host 1:
image

Host 2:
image

I suggest sorting all_dirs after calling os.listdir and before splitting the dataset! @ge-xing
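
A minimal sketch of an order-independent split, assuming the split is made by slicing a list of case directory names. The function names, split ratios, and seed here are hypothetical and are not taken from the repository. Sorting first removes the dependence on the order in which the file system returns entries.

```python
import os
import random


def split_cases(names, train_rate=0.7, val_rate=0.1, seed=0):
    """Split case names into train/val/test independent of input order.

    The ratios and seed are illustrative assumptions, not the repository's values.
    """
    ordered = sorted(names)                   # fixed order on every host
    random.Random(seed).shuffle(ordered)      # seeded shuffle, reproducible
    n_train = int(len(ordered) * train_rate)
    n_val = int(len(ordered) * val_rate)
    train = ordered[:n_train]
    val = ordered[n_train:n_train + n_val]
    test = ordered[n_train + n_val:]
    return train, val, test


def split_data_dir(data_dir, **kwargs):
    """Apply split_cases to the entries of a directory."""
    return split_cases(os.listdir(data_dir), **kwargs)


# Two hosts may list the same cases in different orders.
host_one = ["case_003", "case_001", "case_002", "case_005", "case_004"]
host_two = ["case_005", "case_004", "case_001", "case_003", "case_002"]
assert split_cases(host_one) == split_cases(host_two)
```

With the sort in place, both hosts produce the same three lists even though the raw listing order differs.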

Segmentation fault

Hi ge-xing,
When I run test.py after training the model with train.py, I get the following error:
image
How can I solve this problem?

training problem

The loss fluctuates heavily when I train on my own dataset. How can I solve this problem?

Can not reproduce BraTS 2020 results.

Hi, thanks for sharing the code base. I tried to reproduce the results on the BraTS 2020 dataset, but the results I got are much worse than those in the paper. Here are the details:

For model training:
wt is 0.8498, tc is 0.4873, et is 0.4150, mean_dice is 0.5840

The tensorboard files are:
brats20-wt
brats20-tc
brats20-et
brats20-mean-dice
brats20-train-loss

The final model files are:
brats20-model-file

My settings are default settings:
env = "DDP"
max_epoch = 300
batch_size = 2
num_gpus = 4
GPU type: A100

Then I use the best model (best_model_0.5975.pt) to do evaluation on the test set, and I got:
brats20-test-dice
brats20-test-hd95

My python environment is:
Python 3.8.10
monai 1.1.0
numpy 1.22.2
SimpleITK 2.2.1
torch 1.13.0a0+936e930

The strangest thing is that the segmentation performance on TC and ET is quite bad. Do you have any idea why the performance is so odd, and could you give me some advice on model training? Also, could you please share the conda env file and your model weights for the BraTS 2020 dataset? A Docker image would be even better. Thanks.

Question on BTCV

self.embed_model = BasicUNetEncoder(3, 1, 2, [64, 64, 128, 256, 512, 64])
self.model = BasicUNetDe(3, 14, 13, [64, 64, 128, 256, 512, 64])

1. Why is the embed_model not BasicUNetEncoder(3, 1, 13, [64, 64, 128, 256, 512, 64])?
2. Why is resample_img() used, and if I run on my own dataset, in which cases should I use it?

Train.py question

image

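I ran into the problem above when using the AbdomenCT-1K dataset. Do you have a solution? I have tried adjusting the MONAI version, but nothing has worked so far.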

How can I change the input channel from 4 to 1

Hello, your work is great!
I have already reproduced the results on BraTS2020 and am now working on my own dataset, which has only one modality (T1). Where can I change the number of input channels from 4 to 1, and what else do I need to change? Thanks!
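
A rough sketch of the channel arithmetic involved, inferred only from the BTCV constructor calls quoted in the "Question on BTCV" issue on this page, where 1 modality and 13 targets appear as BasicUNetEncoder(3, 1, ...) and BasicUNetDe(3, 14, 13, ...). It assumes the positional arguments are (spatial dims, input channels, output channels) and that the denoiser input is the image concatenated with the noisy label map. Both are assumptions, and channel_config is a hypothetical helper, not part of the repository.

```python
def channel_config(num_modalities, num_targets):
    """Return (encoder_in, denoiser_in, denoiser_out) channel counts.

    Assumes the denoiser sees the image and the noisy label map together,
    which matches 14 = 1 + 13 in the BTCV snippet but is not confirmed.
    """
    encoder_in = num_modalities
    denoiser_in = num_modalities + num_targets
    denoiser_out = num_targets
    return encoder_in, denoiser_in, denoiser_out


print(channel_config(1, 13))  # BTCV: (1, 14, 13)
print(channel_config(4, 3))   # BraTS2020: (4, 7, 3)
print(channel_config(1, 3))   # one modality, three targets: (1, 4, 3)
```

If the assumption holds, switching from 4 modalities to 1 changes both the encoder input and the denoiser input, and the data loader must also yield single-channel images.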

Environment file requirement.

Hi. Thank you for your contribution.
Could you please share the environment file for your experiments? I ran into several errors when re-implementing, which I think may be due to differences in environment and package versions.
Looking forward to your reply.

Training results using 2 GPUs

Thanks for releasing the paper and the source code!

I experimented with GPU=2 since I only have 2 GPUs.
I experimented with batch_size=4 and found that the mean_dice during training was abnormally low compared to other people's results.

Results of the experiment with batch_size=4
training mean_dice = 0.5956
wt = 0.9142
tc = 0.6453
et = 0.6193
mean_dice
train_loss
et

Are there any parameters or code that should be changed when reducing the number of GPUs?

Issue with "mpi4py" library

I had an issue installing "mpi4py". Is it possible to avoid using this library, or is there an alternative?

Thanks!

How is the embed_model trained

Hi,

I notice that the optimizer is defined only for the denoising UNet model here:

self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-4, weight_decay=1e-3)

I cannot find an optimizer for the embed_model (the feature encoder in the paper). Does this mean the FE is just randomly initialized and never trained, or is there something I am missing? Thanks.

received 0 items of ancdata

When I executed the train.py, I encountered the following error, how can I solve it?

Traceback (most recent call last):
  File "train.py", line 192, in <module>
    trainer.train(train_dataset=train_ds, val_dataset=val_ds)
  File "/home/hdc/zjh/Diff-UNet-main/LiTS/light_training/trainer.py", line 262, in train
    self.train_epoch(
  File "/home/hdc/zjh/Diff-UNet-main/LiTS/light_training/trainer.py", line 361, in train_epoch
    for idx, batch in enumerate(loader):
  File "/home/hdc/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
    data = self._next_data()
  File "/home/hdc/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1182, in _next_data
    idx, data = self._get_data()
  File "/home/hdc/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1148, in _get_data
    success, data = self._try_get_data()
  File "/home/hdc/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 986, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/home/hdc/anaconda3/envs/pytorch/lib/python3.8/multiprocessing/queues.py", line 116, in get
    return _ForkingPickler.loads(res)
  File "/home/hdc/anaconda3/envs/pytorch/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 282, in rebuild_storage_fd
    fd = df.detach()
  File "/home/hdc/anaconda3/envs/pytorch/lib/python3.8/multiprocessing/resource_sharer.py", line 58, in detach
    return reduction.recv_handle(conn)
  File "/home/hdc/anaconda3/envs/pytorch/lib/python3.8/multiprocessing/reduction.py", line 189, in recv_handle
    return recvfds(s, 1)[0]
  File "/home/hdc/anaconda3/envs/pytorch/lib/python3.8/multiprocessing/reduction.py", line 164, in recvfds
    raise RuntimeError('received %d items of ancdata' %
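
This error is often reported when a process runs out of file descriptors while DataLoader workers hand tensors to the main process. That diagnosis is an assumption here; the traceback alone does not prove it. Under that assumption, one possible workaround on Unix is to raise the soft open-file limit at the top of the training script. The helper name and the target of 4096 below are hypothetical.

```python
import resource


def raise_open_file_limit(target=4096):
    """Raise the soft RLIMIT_NOFILE toward target without exceeding the hard limit.

    Returns the (soft, hard) limits in effect afterwards. Unix only.
    """
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if hard != resource.RLIM_INFINITY:
        target = min(target, hard)            # the soft limit cannot exceed the hard limit
    if soft == resource.RLIM_INFINITY or soft >= target:
        return soft, hard                     # already high enough, leave untouched
    resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
    return resource.getrlimit(resource.RLIMIT_NOFILE)


print(raise_open_file_limit())
```

Another commonly suggested workaround is calling torch.multiprocessing.set_sharing_strategy("file_system") before creating the DataLoader. Whether either one resolves this particular case has not been verified.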

train.py

Hello, I ran the BTCV code on the MSD Task08 Hepatic Vessel dataset and the LiTS dataset, but the Dice coefficients I obtained were both very low. What could be the reason? If convenient, could you share a BTCV model that you have trained? My email is [email protected]. Thank you very much!

Dice Score is low

I trained on the LiTS dataset. The Dice score reaches 0.9035 on the training set but only 0.6898 on the test set. When I visualized the results, I noticed something interesting: the predicted shapes are very similar to the ground truth, but there seems to be some offset. Where might the problem be?
image

Function arguments of get_loader_btcv in test.py

The code below is at line 151 of the BTCV test.py script. It passes batch_size to get_loader_btcv, but the function has no batch_size parameter; its signature is get_loader_btcv(data_dir, cache=True). I assume data_dir=./RawData/Training, because get_loader_btcv reads both image and label files and RawData/Testing has only an img directory. Am I setting the function arguments correctly, or am I using the wrong dataset? Please let me know.

train_ds, val_ds, test_ds = get_loader_btcv(batch_size=batch_size, fold=0, cache=False)
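
The mismatch can be reproduced without the repository. The sketch below uses a hypothetical stand-in that has only the signature quoted in this issue; its body is a placeholder and says nothing about what the real function returns. The call with data_dir follows the poster's guess, not a confirmed fix.

```python
def get_loader_btcv(data_dir, cache=True):
    """Hypothetical stand-in with the signature quoted in the issue."""
    return "train_ds", "val_ds", "test_ds"


# The call as written in test.py: unexpected keywords raise TypeError.
try:
    get_loader_btcv(batch_size=1, fold=0, cache=False)
except TypeError as err:
    print("call from test.py fails:", err)

# A call that matches the quoted signature (data_dir is the poster's assumption).
train_ds, val_ds, test_ds = get_loader_btcv(data_dir="./RawData/Training", cache=False)
```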

Results after training

First of all, thanks for publicly sharing your paper and source-code.

I used train.py to train a model on the BraTS2020 dataset, but my results do not match those reported in the paper. Do you have any idea why?

I got the following Dice scores (after training for 300 epochs):
wt is 0.8935, tc is 0.7762, et is 0.7637, mean_dice is 0.8111

About The Testing

From Figure 1 in the paper, I understand that your method directly predicts x_0 during training. Why, then, is the result inferred step by step at test time? I do not understand.

About Train

def training_step(self, batch):
    image, label = self.get_input(batch)
    x_start = label

    x_start = (x_start) * 2 - 1
    x_t, t, noise = self.model(x=x_start, pred_type="q_sample")
    pred_xstart, pred_y = self.model(x=x_t, step=t, image=image, pred_type="denoise")

    loss_dice = self.dice_loss(pred_xstart, label)
    loss_bce = self.bce(pred_xstart, label)

    pred_xstart = torch.sigmoid(pred_xstart)
    loss_mse = self.mse(pred_xstart, label)

    loss = loss_dice + loss_bce + loss_mse

In the training phase, noise is added by q_sample, and the prediction from the denoising step is compared directly with the label to compute the loss. In my understanding, the prediction from x_t should be compared with x_{t-1}, and x_0 should only be generated at the end as the final target. Why is the loss computed directly against the ground-truth label rather than against the next step?
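
The relation between x_t and x_0 is central to this question. Below is a minimal sketch, assuming the standard DDPM forward process x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise with a linear beta schedule. The repository's schedule and implementation may differ, and all function names here are illustrative; they operate on plain Python lists rather than tensors.

```python
import math
import random


def make_alpha_bars(num_steps=1000, beta_start=1e-4, beta_end=2e-2):
    """Cumulative products of (1 - beta_t) for a linear beta schedule (assumed values)."""
    alpha_bars, prod = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def q_sample(x_start, t, alpha_bars, rng):
    """Draw x_t from q(x_t | x_0) in closed form; returns (x_t, noise)."""
    a = alpha_bars[t]
    noise = [rng.gauss(0.0, 1.0) for _ in x_start]
    x_t = [math.sqrt(a) * x + math.sqrt(1.0 - a) * n for x, n in zip(x_start, noise)]
    return x_t, noise


def recover_x_start(x_t, noise, t, alpha_bars):
    """Invert the closed form: x_0 is fully determined by x_t and the noise."""
    a = alpha_bars[t]
    return [(xt - math.sqrt(1.0 - a) * n) / math.sqrt(a) for xt, n in zip(x_t, noise)]


alpha_bars = make_alpha_bars()
label = [0.0, 1.0, 1.0, 0.0]
x_start = [v * 2 - 1 for v in label]          # same rescaling as in training_step
x_t, noise = q_sample(x_start, 500, alpha_bars, random.Random(0))
recovered = recover_x_start(x_t, noise, 500, alpha_bars)
```

Under this formulation, knowing the noise and knowing x_0 are equivalent given x_t, so a network can be trained to predict either one. This may be why the loss can be taken against the label at every sampled step t, though the sketch does not confirm the author's reasoning.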

AMOS22

Hello, can this model be run on AMOS22? Have you considered combining Mamba with the diffusion model?

window_infer

Hi, thank you for your contributions.

I'm confused by window_infer, which seems to be a MONAI function and appears unable to handle the parameter pred_type="ddim_sample".

Diff-UNet/train.py

Lines 80 to 82 in 14d55bd

self.window_infer = SlidingWindowInferer(roi_size=[96, 96, 96],
                                         sw_batch_size=1,
                                         overlap=0.25)

Diff-UNet/train.py

Lines 123 to 128 in 14d55bd

def validation_step(self, batch):
    image, label = self.get_input(batch)
    output = self.window_infer(image, self.model, pred_type="ddim_sample")
    output = torch.sigmoid(output)
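
The traceback in the "Cannot execute test.py with cuda." issue on this page shows MONAI calling predictor(win_data, *args, **kwargs), which suggests that extra keyword arguments given to the inferer are forwarded to the model rather than consumed by MONAI. The toy sketch below imitates only that forwarding pattern on a 1D list. toy_window_infer and toy_model are hypothetical stand-ins, not MONAI or repository code.

```python
def toy_window_infer(inputs, predictor, window=2, **kwargs):
    """Run predictor on consecutive windows, forwarding extra keyword arguments."""
    outputs = []
    for start in range(0, len(inputs), window):
        patch = inputs[start:start + window]
        outputs.extend(predictor(patch, **kwargs))   # kwargs reach the model untouched
    return outputs


def toy_model(patch, pred_type=None):
    """Stand-in model whose behavior depends on pred_type."""
    if pred_type == "ddim_sample":
        return [value * 2 for value in patch]
    raise ValueError(f"unsupported pred_type: {pred_type!r}")


result = toy_window_infer([1, 2, 3], toy_model, pred_type="ddim_sample")
print(result)  # [2, 4, 6]
```

In this pattern the inferer never needs to understand pred_type; it only passes it through, so the model's forward method is what has to accept it.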

multi-class

May I ask a question, can the model do multi-class segmentation?

BTCV train

Hello, how is the BTCV training set divided, and can the original dataset be used directly without preprocessing?
I used the list of BTCV training data shown below and got an error.
btcv
btcv error

Stuck in an epoch

When I apply this model to the verse2020 dataset, training gets stuck at the ninth epoch every time and then terminates with RuntimeError: DataLoader worker (pid 9063) is killed by signal: killed. Even after switching to a higher-performance GPU and CPU and adjusting the learning rate, batch size, and so on, it still gets stuck at the ninth epoch, showing that it takes ten hours.
image

MSD Liver

Could you share the code for MSD Liver?
I want to reproduce the experimental results.

Cannot execute test.py with cuda.

I always get the error below when I set device = "cuda:0":

  File "test.py", line 170, in <module>
    v_mean, v_out = trainer.validation_single_gpu(val_dataset=test_ds)

Diff-UNet/BTCV/light_training/trainer.py", line 168, in validation_single_gpu
    val_out = self.validation_step(batch)
  File "test.py", line 112, in validation_step
    output = self.window_infer(image, self.model, pred_type="ddim_sample")
  File "/opt/monai/monai/inferers/inferer.py", line 521, in __call__
    return sliding_window_inference(
  File "/opt/monai/monai/inferers/utils.py", line 256, in sliding_window_inference
    seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "test.py", line 72, in forward
    sample_return += sample.cpu()
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

BTCV/test.py label input.

Line 100 of BTCV/test.py, shown below, should use batch["label"] instead of batch["raw_label"], because batch["raw_label"] does not match the volume size of the input.

label = batch["raw_label"] --> label = batch["label"]

image

Visualize the results

Thank you very much for your great work. I'd like to visualize my BraTS test segmentation results. Which parts of the code should I modify?
