

STU-Net

STU-Net: Scalable and Transferable Medical Image Segmentation Models Empowered by Large-Scale Supervised Pre-training
Ziyan Huang, Haoyu Wang, Zhongying Deng, Jin Ye, Yanzhou Su, Hui Sun, Junjun He, Yun Gu, Lixu Gu, Shaoting Zhang, Yu Qiao
[Apr. 13, 2023] [arXiv, 2023]

News

  • STU-Net won the championship at the MICCAI 2023 ATLAS Challenge. [leaderboard]
  • STU-Net won the championship at the MICCAI 2023 SPPIN Challenge. [leaderboard]
  • STU-Net was the runner-up at the MICCAI 2023 AutoPET II Challenge (highest DSC value). [leaderboard]
  • At the MICCAI 2023 BraTS 2023 challenge, STU-Net secured one runner-up and two third-place finishes. [BraTS23]
  • STU-Net took third place at the FLARE 2023 competition. [leaderboard]

Key Features

  • Scalability: STU-Net is designed for scalability, offering models of various sizes (S, B, L, H), including STU-Net-H, the largest medical image segmentation model to date with 1.4B parameters.
  • Transferability: STU-Net is pre-trained on the large-scale TotalSegmentator dataset (>100k annotations) and can be fine-tuned for various downstream tasks.
  • Based on nnU-Net: Built upon the widely recognized nnU-Net framework, STU-Net provides a robust and validated foundation for medical image segmentation.

Details

Large-scale models pre-trained on large-scale datasets have profoundly advanced the development of deep learning. However, state-of-the-art models for medical image segmentation are still small, with only tens of millions of parameters, and scaling them up by orders of magnitude has rarely been explored. An overarching goal of exploring large-scale models is to train them on large-scale medical segmentation datasets to obtain better transferability. In this work, we design a series of Scalable and Transferable U-Net (STU-Net) models with parameter counts ranging from 14 million to 1.4 billion. Notably, the 1.4B STU-Net is the largest medical image segmentation model to date. STU-Net is based on the nnU-Net framework because of its popularity and strong performance. We first refine the default convolutional blocks in nnU-Net to make them scalable. Then, we empirically evaluate different combinations of scaling network depth and width, and find that it is optimal to scale depth and width together. We train our scalable STU-Net models on the large-scale TotalSegmentator dataset and find that larger models achieve stronger performance, which suggests that large models are promising for medical image segmentation. Furthermore, we evaluate the transferability of our models on 14 downstream datasets for direct inference and 3 datasets for further fine-tuning, covering various modalities and segmentation targets. Our pre-trained models perform well in both direct inference and fine-tuning.

Main Results

Segmentation performance comparison for different model sizes

With an increase in model size, the segmentation performance on large-scale datasets improves. Furthermore, larger models demonstrate greater data efficiency in medical image segmentation compared to their smaller counterparts.

With an increase in model size, universal models become capable of concurrently segmenting numerous categories, exhibiting significant performance advancements.

With an increase in model size, models trained on large-scale datasets exhibit stronger performance when directly inferring downstream tasks.

With an increase in model size, our STU-Net models pre-trained on the large-scale dataset, TotalSegmentator, demonstrate enhanced performance on downstream tasks, markedly surpassing models trained from scratch.

Dataset Links

We use the TotalSegmentator dataset, which contains 1204 images with 104 anatomical structures (27 organs, 59 bones, 10 muscles, and 8 vessels), for pre-training, and 3 MICCAI 2022 challenge datasets as downstream tasks for further fine-tuning.
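As a quick consistency check, the four category counts add up to the 104 structures, and nnU-Net reserves label 0 for background and predicts it as its own channel, which would account for a 105-channel segmentation output. A minimal Python sketch of that arithmetic (the variable names are illustrative, not taken from the STU-Net code):

```python
# Structure counts per category in TotalSegmentator, as listed above.
CATEGORY_COUNTS = {"organs": 27, "bones": 59, "muscles": 10, "vessels": 8}

# Foreground classes: every annotated anatomical structure.
num_foreground = sum(CATEGORY_COUNTS.values())

# nnU-Net treats label 0 as background and predicts it as a separate channel,
# so the segmentation head needs one more channel than there are foreground classes.
num_output_channels = num_foreground + 1

print(f"{num_foreground} foreground classes -> {num_output_channels} output channels")
# prints "104 foreground classes -> 105 output channels"
```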

Pre-training

Fine-tuning

Get Started

Main Requirements

torch==1.10
nnUNet==1.7.0
torchinfo

Installation

Our models are built on nnU-Net v1. Please make sure you meet nnU-Net's requirements.

git clone https://github.com/Ziyan-Huang/STU-Net.git
cd STU-Net/nnUNet-1.7.1
pip install -e .

If you have already installed nnU-Net v1, you can instead copy the following files from this repository into your nnUNet repository (source paths are relative to this repository, destination paths relative to your nnUNet repository):

copy network_training/* nnunet/training/network_training/
copy network_architecture/* nnunet/network_architecture/
copy run_finetuning.py nnunet/run/
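The same copy steps can be done from Python, which avoids differences between the Windows copy command and POSIX cp. This is a minimal sketch, assuming the source folders sit at the root of this repository as the commands above imply; install_stunet_files and COPY_PLAN are hypothetical names, and nnunet_pkg_dir should point at the nnunet package folder inside your nnUNet repository.

```python
import shutil
from pathlib import Path

# Hypothetical copy plan: (folder in the STU-Net repo, destination inside the nnunet package).
COPY_PLAN = [
    ("network_training", Path("training") / "network_training"),
    ("network_architecture", Path("network_architecture")),
]


def install_stunet_files(stunet_dir, nnunet_pkg_dir):
    """Copy the STU-Net trainers, architectures and fine-tuning script into nnunet."""
    stunet_dir, nnunet_pkg_dir = Path(stunet_dir), Path(nnunet_pkg_dir)
    for src_name, dst_rel in COPY_PLAN:
        # dirs_exist_ok=True (Python 3.8+) merges into existing folders and
        # overwrites files with the same name.
        shutil.copytree(stunet_dir / src_name, nnunet_pkg_dir / dst_rel, dirs_exist_ok=True)
    run_dir = nnunet_pkg_dir / "run"
    run_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy2(stunet_dir / "run_finetuning.py", run_dir / "run_finetuning.py")
```

For example, install_stunet_files("STU-Net", "nnUNet/nnunet") would merge the files into an nnUNet checkout next to the STU-Net clone (both paths are illustrative).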

Pre-trained Models

Models Trained on TotalSegmentator

These models were trained on the TotalSegmentator dataset for 4000 epochs with mirroring data augmentation.

Model Name | Crop Size   | #Params  | FLOPs  | Download Link
STU-Net-S  | 128x128x128 | 14.6M    | 0.13T  | Baidu Netdisk / Google Drive
STU-Net-B  | 128x128x128 | 58.26M   | 0.51T  | Baidu Netdisk / Google Drive
STU-Net-L  | 128x128x128 | 440.30M  | 3.81T  | Baidu Netdisk / Google Drive
STU-Net-H  | 128x128x128 | 1457.33M | 12.60T | Baidu Netdisk / Google Drive
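Dividing the FLOPs column by the parameter count shows that, at the fixed 128x128x128 crop, compute grows roughly in proportion to model size: all four models land at about 8.6k to 8.9k FLOPs per parameter. A small Python sketch of that arithmetic, using only the numbers in the table and assuming both columns refer to the same single-crop setting:

```python
# (#Params in millions, FLOPs in tera) for each model, copied from the table above.
MODELS = {
    "STU-Net-S": (14.6, 0.13),
    "STU-Net-B": (58.26, 0.51),
    "STU-Net-L": (440.30, 3.81),
    "STU-Net-H": (1457.33, 12.60),
}


def flops_per_param(params_m: float, flops_t: float) -> float:
    """Convert both columns to raw units and return FLOPs per parameter."""
    return (flops_t * 1e12) / (params_m * 1e6)


for name, (params_m, flops_t) in MODELS.items():
    print(f"{name}: {flops_per_param(params_m, flops_t):,.0f} FLOPs per parameter")
```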

Fine-tuning on downstream tasks

To perform fine-tuning on downstream tasks, use the following command with the base model as an example:

python run_finetuning.py 3d_fullres STUNetTrainer_base_ft TASKID FOLD -pretrained_weights MODEL

Note that you may need to adjust the learning rate for your specific downstream task. To do so, modify the learning rate in the corresponding trainer (e.g., STUNetTrainer_base_ft).
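The placeholders in the fine-tuning command are the downstream task ID (TASKID), the cross-validation fold (FOLD), and the path to a downloaded checkpoint (MODEL). The sketch below assembles that command as an argument list, for example to pass to subprocess.run(cmd, check=True); build_finetune_cmd is a hypothetical helper, and it only reproduces the arguments shown in the command above.

```python
import sys


def build_finetune_cmd(task_id, fold, pretrained_weights,
                       trainer="STUNetTrainer_base_ft",
                       script="run_finetuning.py"):
    """Hypothetical helper: the README's fine-tuning command as an argv list."""
    return [
        sys.executable, str(script),
        "3d_fullres",                 # network configuration
        trainer,                      # fine-tuning trainer class
        str(task_id),                 # TASKID of the downstream task
        str(fold),                    # FOLD of the cross-validation split
        "-pretrained_weights", str(pretrained_weights),  # MODEL checkpoint path
    ]
```

For example, build_finetune_cmd(501, 0, "base_ep4k.model") targets a hypothetical Task501 on fold 0, starting from the base checkpoint.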

Using Our Models for Inference

To run inference on CT images with our trained models, first organize the files in your RESULTS_FOLDER/nnUNet/3d_fullres/ directory as follows:

- Task101_TotalSegmentator/
  - STUNetTrainer_small__nnUNetPlansv2.1/
    - plans.pkl
    - fold_0/
      - small_ep4k.model
      - small_ep4k.model.pkl
  - STUNetTrainer_base__nnUNetPlansv2.1/
    - plans.pkl
    - fold_0/
      - base_ep4k.model
      - base_ep4k.model.pkl
  - STUNetTrainer_large__nnUNetPlansv2.1/
    - plans.pkl
    - fold_0/
      - large_ep4k.model
      - large_ep4k.model.pkl
  - STUNetTrainer_huge__nnUNetPlansv2.1/
    - plans.pkl
    - fold_0/
      - huge_ep4k.model
      - huge_ep4k.model.pkl
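Before running nnUNet_predict, it can help to confirm that every expected file is in place. A minimal sketch using only pathlib; expected_files and missing_files are hypothetical helpers, and the paths simply mirror the tree above (only fold_0 is assumed).

```python
from pathlib import Path

SIZES = ("small", "base", "large", "huge")
TASK = "Task101_TotalSegmentator"


def expected_files(results_folder, size, task=TASK):
    """Files nnUNet_predict needs for one STU-Net size, following the tree above."""
    trainer_dir = (Path(results_folder) / "nnUNet" / "3d_fullres" / task
                   / f"STUNetTrainer_{size}__nnUNetPlansv2.1")
    fold_dir = trainer_dir / "fold_0"
    return [
        trainer_dir / "plans.pkl",
        fold_dir / f"{size}_ep4k.model",
        fold_dir / f"{size}_ep4k.model.pkl",
    ]


def missing_files(results_folder, sizes=SIZES):
    """Return every expected file that is not present yet."""
    return [p for size in sizes
            for p in expected_files(results_folder, size)
            if not p.is_file()]
```

With nnU-Net v1's RESULTS_FOLDER environment variable set, missing_files(os.environ["RESULTS_FOLDER"], sizes=("base",)) should return an empty list once the base model is in place.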

These pickle files can be found in the plan_files directory of this repository. You can download the models from the links above and set TASKID and TASK_NAME as you prefer.

To run inference, use the following command (base model as an example):

nnUNet_predict -i INPUT_PATH -o OUTPUT_PATH -t 101 -m 3d_fullres -f 0 -tr STUNetTrainer_base -chk base_ep4k

For much faster inference with minimal performance loss, we recommend the following command:

nnUNet_predict -i INPUT_PATH -o OUTPUT_PATH -t 101 -m 3d_fullres -f 0 -tr STUNetTrainer_base -chk base_ep4k --mode fast --disable_tta
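To switch between the four model sizes, only the -tr trainer name and the -chk checkpoint name change, following the STUNetTrainer_<size> and <size>_ep4k pattern of the folder layout above. A sketch of a hypothetical helper that builds either inference command as an argv list; it uses only the flags shown in the two commands above.

```python
SIZES = ("small", "base", "large", "huge")


def build_predict_cmd(input_dir, output_dir, size="base", task_id=101, fold=0, fast=False):
    """Hypothetical helper: nnUNet_predict argv for one STU-Net size."""
    if size not in SIZES:
        raise ValueError(f"unknown model size {size!r}; expected one of {SIZES}")
    cmd = [
        "nnUNet_predict",
        "-i", str(input_dir),
        "-o", str(output_dir),
        "-t", str(task_id),
        "-m", "3d_fullres",
        "-f", str(fold),
        "-tr", f"STUNetTrainer_{size}",   # trainer folder name prefix
        "-chk", f"{size}_ep4k",           # checkpoint name without .model
    ]
    if fast:
        # The faster variant: fast mode and no test-time augmentation.
        cmd += ["--mode", "fast", "--disable_tta"]
    return cmd
```

The resulting list can be handed to subprocess.run(cmd, check=True), which avoids shell-quoting problems with paths that contain spaces.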

The categories corresponding to the label values can be found in the label_orders file within our repository (please note that this differs from the official TotalSegmentator version).

🙋‍♀️ Feedback and Contact

If you have any questions, feel free to contact [email protected].

🛡️ License

This project is licensed under the Apache License 2.0. See LICENSE for details.

🙏 Acknowledgement

Our code is based on the nnU-Net framework.

📝 Citation

If you find this repository useful, please consider citing our paper:

@misc{huang2023stunet,
      title={STU-Net: Scalable and Transferable Medical Image Segmentation Models Empowered by Large-Scale Supervised Pre-training}, 
      author={Ziyan Huang and Haoyu Wang and Zhongying Deng and Jin Ye and Yanzhou Su and Hui Sun and Junjun He and Yun Gu and Lixu Gu and Shaoting Zhang and Yu Qiao},
      year={2023},
      eprint={2304.06716},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}


stu-net's Issues

about output channel and label

Hi! I'm interested in your excellent work. In STU-Net, the final 1×1×1 convolution layer for the segmentation output has 105 channels, but the TotalSegmentator dataset contains 104 anatomical structures (104 classes). So is there a channel that contains a segmentation mask for all foreground?

Hello, I am truly sorry for interrupting your work

Hello, I am truly sorry for interrupting your work, and I really apologize for the disturbance. May I kindly ask about the network you used during the ABUS competition? I am really sorry for any inconvenience caused.

Direct evaluation

Dear authors,

I wonder if you could help clarify how to conduct direct evaluation using STU-Net? I have tried it on several MSD datasets, but the results were clearly wrong. Did you use the original nnUNetplans from the TotalSegmentator dataset?

Thanks!

Request for clarification on the architecture

Hello Ziyan,

Thank you very much for the repo.

I am a beginner and I am trying to understand continuous learning models.

I have a few questions. I have read that continuous deep learning models suffer from catastrophic forgetting. How are you able to overcome this issue? I apologize if my question is naive.

Also, I was not able to download the weights for FLARE23 from Baidu. Could you please upload them to Google Drive? I would like to try training on the TCIA NSCLC dataset.

Regards,
Anil

How to train from scratch(re-pretraining)?

Thank you for sharing awesome work!

I fine-tuned on my dataset with the pre-trained weights (small and base models),
and it gave me great results.

Can you share the training script used to train on the TotalSegmentator dataset?
Or is it sufficient to use the method in this link?
https://github.com/wasserth/TotalSegmentator/blob/master/resources/train_nnunet.md

I would like to add more datasets to the pre-training to maximize performance,
such as VerSe, CTPelvic1K, or more.
Of course, I may have to change their label information.

Could not find trainer class

I have put your files into the corresponding folders of nnunet, but I got this error: raise RuntimeError("Could not find trainer class in nnunet.training.network_training") RuntimeError: Could not find trainer class in nnunet.training.network_training. It seems the program cannot find the trainer class. What should I do?

DP

How do I train on 2 GPUs?

Issues about pkl files

Really nice work!
I'm trying to fine-tune the network on my private dataset, but it requires a pkl file. I tried base_ep4k.model.pkl, but it raised "KeyError: 'plans_per_stage'". Which pkl file should I provide here?

Label overlap

Is there any overlap of labels for different categories in the TotalSegmentator dataset? I noticed that some pixels seem to belong to multiple categories.

Some questions about the setting of initial_lr

The paper mentions that the learning rate for everything except the segmentation head is multiplied by 0.1. How is this reflected in the code? I can't find the corresponding part. If I want the learning rate of the segmentation head to be 0.01 and the rest to be 0.1 times that, should I set self.initial_lr in STUNetTrainer_ft in STUNetTrainer.py to 0.01 or 0.001? Please forgive my poor English. Thank you very much!

Fine-tuning models with nnUNet v2 using multimodal data

Hello,

Thank you very much for the repository and the provided pretrained models. I have seen that you added some scripts for nnUNet v2 to run fine-tuning as well as the STUNetTrainer.
I have a dataset with CT and PET modalities and I was wondering if it is possible to use nnUNet v2 to run fine-tuning using a pretrained model that you provide (e.g. huge_ep4k.model) to do this combining both modalities (concatenating them as input as you showed for the AutoPET challenge).

Best regards,
Sergio

Is crop size the same as patch size?

Is crop size the same as patch size? And what is the relationship between crop size and patch size?
Thanks.

About the data preprocessing parameters for inference

Hi,

Thanks for sharing your great work with us. I just parsed the .pkl file you uploaded a while ago. And I found that there were no params of HU normalization (it was stored with the key "intensityproperties" based on my previous experience of nnunet). Also, I haven't found anything related to the process of HU normalization in your paper. I'm wondering if you did such HU normalization for CT images during inference and how you did that.

Dataset conversion instructions

I want to ask whether the fine-tuning datasets you used in your work simply follow the preprocessing method provided by nnUNet v1. Thank you very much if you can help me!

Missing <model_name>.model.pkl for inference

Really nice work! Would it be possible for you to kindly provide the '<model_name>.model.pkl' file along with each '<model_name>.model' file, as it is required for running inference of your models? Thank you very much.

pickle error occurs when running the Direct Inference command

Dear author, thank you so much for your work.
I'm impressed by your work and tried to run direct inference with your pretrained models, but things didn't work out for me.

I have followed your instructions in direct_inference and set up the environment. Then I used the example command python direct_inference.py STUNetTrainer_small example/Task032_AMOS22_Task1 example/result just as is written in your instructions. At first I thought this process would be smooth because I saw the terminal printed out

'...starting preprocessing generator
starting prediction...'

but suddenly an error occurred:

File "D:\anaconda3\envs\pyTorch39\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <function <lambda> at 0x0000022D9F311550>: attribute lookup <lambda> on nnunet.utilities.nd_softmax failed
[W CudaIPCTypes.cpp:16] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]

File "D:\anaconda3\envs\pyTorch39\lib\multiprocessing\spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
EOFError: Ran out of input

the error message is quite long so I only paste the latest lines here.

I have searched for this error and it is said that

this error typically occurs when the pickle module in Python tries to load an empty file or a file that has been truncated. It means that the end-of-file was reached unexpectedly while there was still data expected to be read. This can happen if the file is empty, corrupted, or if there is a mismatch between how the data was written and how it is being read.

I have checked the RESULT_FOLDER's structure, but it was exactly the same as the structure in your instructions. and your pkl files must be fine. With all the possibilities excluded, now I have no clue why this happens. So I came to ask for your help.

I would appreciate it if you could offer some suggestions.
Thank you very much!

model infer error

Hello, when I run model inference in the nnUNetV1-based project, an error occurs when loading any of the models you provided (small, base, large, or huge). The error is raised at 'all_params = [torch.load(i, map_location=torch.device('cpu')) for i in all_best_model_files]' and reads: '_pickle.UnpicklingError: invalid load key, '%'.' In other words, it fails when reading the provided *_ep4k.model files.
The paper mentions that STU-Net was trained on a GPU with 80 GB of memory, and my device is an RTX 4090 with 24 GB. Is the error caused by insufficient memory (although the nnunet error message says nothing about memory), or were the model files corrupted during upload or download? Thank you so much.

huge model training supermemory problem

Hello!
I used the huge model for fine-tuning. I had 80 GB of GPU memory and still got out-of-memory errors, but when I checked GPU memory usage, the peak only reached 30 GB. How can I solve this problem? Thank you!
RuntimeError: CUDA out of memory. Tried to allocate 1.25 GiB (GPU 0; 44.56 GiB total capacity; 41.63 GiB already allocated; 217.56 MiB free; 42.35 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
Exception in thread Thread-6:

Hyper-parameter of fine-tune

I am trying to reproduce your fine-tuning results on the AMOS 2022 dataset using STU-Net-S. What is the best hyper-parameter setting for this model and dataset? I would much appreciate it if you could share it!

New label

Thank you for sharing! Can I apply the pre-trained parameters and architecture of STU to a new dataset (not the previously mentioned 104 types of segmentation), as the features needed for segmentation are generic and I only need to fine-tune the output layer?
