
AgileFormer

This repository contains the official implementation of the paper "AgileFormer: Spatially Agile Transformer UNet for Medical Image Segmentation".

News 🔥

  • April 12, 2024: The code for 2D segmentation is ready to run. You are welcome to evaluate the pretrained models on the Synapse dataset.
  • April 18, 2024: The code now supports deformable convolution implemented in mmcv and plain PyTorch. However, using these implementations requires retraining the model yourself.

Abstract. In the past decades, deep neural networks, particularly convolutional neural networks, have achieved state-of-the-art performance in a variety of medical image segmentation tasks. Recently, the introduction of the vision transformer (ViT) has significantly altered the landscape of deep segmentation models. There has been a growing focus on ViTs, driven by their excellent performance and scalability. However, we argue that the current design of the vision transformer-based UNet (ViT-UNet) segmentation models may not effectively handle the heterogeneous appearance (e.g., varying shapes and sizes) of objects of interest in medical image segmentation tasks. To tackle this challenge, we present a structured approach to introduce spatially dynamic components to the ViT-UNet. This adaptation enables the model to effectively capture features of target objects with diverse appearances. This is achieved by three main components: (i) deformable patch embedding; (ii) spatially dynamic multi-head attention; (iii) deformable positional encoding. These components were integrated into a novel architecture, termed AgileFormer. AgileFormer is a spatially agile ViT-UNet designed for medical image segmentation. Experiments in three segmentation tasks using publicly available datasets demonstrated the effectiveness of the proposed method.

Architecture

[Figure: Method]

1. Prepare data

Put the datasets into the folder "data/" under the main "AgileFormer" directory, e.g., "data/Synapse", "data/ACDC".
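Before training or testing, a quick check that the dataset folders sit where the scripts expect them can save a failed run. The following is a minimal sketch that assumes only the folder layout named above (data/Synapse and data/ACDC under the AgileFormer root); the helper name `missing_dataset_dirs` is hypothetical and not part of the repository.

```python
from pathlib import Path


def missing_dataset_dirs(repo_root, datasets=("Synapse", "ACDC")):
    """Return the expected data/<dataset> folders that do not exist.

    Hypothetical helper: it only checks the folder layout described in
    the README and does not validate the files inside each folder.
    """
    data_dir = Path(repo_root) / "data"
    return [data_dir / name for name in datasets if not (data_dir / name).is_dir()]


if __name__ == "__main__":
    missing = missing_dataset_dirs(".")
    if missing:
        for path in missing:
            print(f"missing: {path}")
    else:
        print("all dataset folders found")
```

Run it from the main "AgileFormer" directory; pass a different `datasets` tuple if you only use one of the datasets.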

2. Environment

  • We recommend an environment with Python >= 3.8, and then installing the following dependencies:
pip install -r requirements.txt
  • We recommend installing Neighborhood Attention (NATTEN) and Deformable Convolution manually to avoid compatibility issues:

    • [NATTEN] Please refer to https://shi-labs.com/natten to install NATTEN with the correct CUDA and PyTorch versions (Note: we trained the model using CUDA 12.1 + PyTorch 2.2 and NATTEN 0.15.1). For example, NATTEN for PyTorch 2.2 and CUDA 12.1 can be installed with:
    pip3 install natten==0.15.1+torch220cu121 -f https://shi-labs.com/natten/wheels/
    
    • [Deformable Convolution] There are many implementations of deformable convolution:
      • [tvdcn] We recommend the implementation in tvdcn (https://github.com/inspiros/tvdcn), as it provides CUDA implementations of both 2D and 3D deformable convolution (the 2D implementation in tvdcn should be the same as the one provided by PyTorch). [Note: We used tvdcn for our experiments.] For example, the latest tvdcn for PyTorch >= 2.1 and CUDA >= 12.1 can be installed with:
      pip install tvdcn
      
      • [mmcv] We also provide an alternative implementation of deformable convolution based on mmcv (https://github.com/open-mmlab/mmcv). This is the most widely used version, but it only provides a 2D CUDA implementation. Installing mmcv is straightforward (you may need to check your PyTorch and CUDA versions as well):
      pip install -U openmim 
      mim install mmcv
      
      • [vanilla PyTorch] We also support the implementation provided by official PyTorch.
      • Note: Our code searches the three options above in order: if tvdcn is installed, it is used; otherwise, if mmcv is installed, mmcv is used; otherwise, the implementation provided by PyTorch is used.
  • Final Takeaway: We suggest installing PyTorch >= 2.1 and CUDA >= 12.1 for better compatibility across all packages (especially tvdcn and NATTEN). It is also possible to install these two packages with lower PyTorch and CUDA versions, but they may need to be built from source.
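The backend search order in the note above can be sketched as follows, together with a quick report of installed versions to compare against the recommended PyTorch >= 2.1. This is a minimal sketch, not the repository's actual detection code: it only tests whether each package is importable, the helper names (`select_deform_conv_backend`, `installed_version`, `release_tuple`) are hypothetical, and the "pytorch" fallback presumably corresponds to `torchvision.ops.deform_conv2d` (an assumption).

```python
import importlib.util
from importlib.metadata import PackageNotFoundError, version

# Search order from the README note: tvdcn, then mmcv, then plain PyTorch.
BACKEND_PRIORITY = ("tvdcn", "mmcv")


def select_deform_conv_backend():
    """Return the first importable deformable-convolution backend.

    Sketch: it only checks that each top-level module can be found and
    does not verify that its CUDA kernels actually load.
    """
    for module_name in BACKEND_PRIORITY:
        if importlib.util.find_spec(module_name) is not None:
            return module_name
    return "pytorch"  # fall back to the implementation shipped with PyTorch


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def release_tuple(ver, length=2):
    """Parse the leading numeric parts, e.g. '2.2.0+cu121' -> (2, 2)."""
    numbers = []
    for piece in ver.split("+")[0].split(".")[:length]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as 'rc1' or 'a0'
            digits += ch
        numbers.append(int(digits) if digits else 0)
    while len(numbers) < length:
        numbers.append(0)  # pad short versions such as '3' -> (3, 0)
    return tuple(numbers)


if __name__ == "__main__":
    for dist in ("torch", "torchvision", "natten", "tvdcn", "mmcv"):
        print(f"{dist:12s} {installed_version(dist) or 'not installed'}")
    torch_version = installed_version("torch")
    if torch_version and release_tuple(torch_version) < (2, 1):
        print("warning: PyTorch < 2.1; tvdcn/NATTEN may need to be built from source")
    print("deformable conv backend:", select_deform_conv_backend())
```

Because the check is import-based, a package that is installed but built against the wrong PyTorch or CUDA version will still be reported as available; it will only fail when its compiled extension is loaded.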

3. Evaluate Pretrained Models

We provide pretrained models for the Tiny and Base versions of AgileFormer, as listed below.

| Task | Model size | Resolution | DSC (%) | Config | Pretrained weights |
|---|---|---|---|---|---|
| Synapse multi-organ | Tiny | 224x224 | 83.59 | config | GoogleDrive / OneDrive |
| Synapse multi-organ | Base | 224x224 | 85.74 | config | GoogleDrive / OneDrive |
| ACDC cardiac | Tiny | 224x224 | 91.76 | config | |
| ACDC cardiac | Base | 224x224 | 92.55 | config | |
| Decathlon brain tumor | Tiny | 96x96x96 | 85.7 | config | |
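The DSC column is the Dice similarity coefficient, DSC = 2|P ∩ G| / (|P| + |G|) for a predicted mask P and a ground-truth mask G, typically averaged over the foreground classes. As a reference point only (this is a sketch, not necessarily how test.py computes the reported numbers), a per-class mean Dice over integer label maps could look like this in NumPy. Treating a class that is absent from both prediction and ground truth as Dice 1.0 is an assumed convention; other tools handle that case differently.

```python
import numpy as np


def dice_per_class(pred, gt, num_classes):
    """Dice for each foreground class (1..num_classes-1) of two label maps.

    pred and gt are integer arrays of the same shape. A class absent from
    both maps is assigned Dice 1.0 (an assumed convention).
    """
    pred = np.asarray(pred)
    gt = np.asarray(gt)
    if pred.shape != gt.shape:
        raise ValueError("pred and gt must have the same shape")
    scores = []
    for c in range(1, num_classes):  # class 0 is treated as background
        p = pred == c
        g = gt == c
        denom = p.sum() + g.sum()
        if denom == 0:
            scores.append(1.0)
        else:
            scores.append(float(2.0 * np.logical_and(p, g).sum() / denom))
    return scores


def mean_dice(pred, gt, num_classes):
    """Mean foreground Dice, expressed as a percentage like the DSC column."""
    return 100.0 * float(np.mean(dice_per_class(pred, gt, num_classes)))
```

For 3D volumes the same functions apply unchanged, since the comparisons and sums run over all voxels regardless of array shape.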

Put the pretrained weights into the folder "pretrained_ckpt/[dataset_name (e.g., Synapse)]" under the main "AgileFormer" directory, then run:

python test.py --cfg [pretrained_config_file in configs]

For example, for Synapse base model, run the following command:

python test.py --cfg configs/agileFormer_base_synapse_pretrained_w_DS.yaml
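If you evaluate several configs, a small wrapper that checks for the config file and checkpoint folder before calling test.py can help. This is a minimal sketch that assumes only the `--cfg` flag and the pretrained_ckpt/<dataset> layout shown above, and that it is run from the main "AgileFormer" directory; `build_eval_command` and `run_eval` are hypothetical helpers, not part of the repository.

```python
import subprocess
import sys
from pathlib import Path


def build_eval_command(cfg_path, dataset, repo_root="."):
    """Return the argv for `python test.py --cfg <cfg_path>`.

    Raises FileNotFoundError if the config file or the
    pretrained_ckpt/<dataset> folder is missing.
    """
    root = Path(repo_root)
    if not (root / cfg_path).is_file():
        raise FileNotFoundError(f"config not found: {root / cfg_path}")
    ckpt_dir = root / "pretrained_ckpt" / dataset
    if not ckpt_dir.is_dir():
        raise FileNotFoundError(f"checkpoint folder not found: {ckpt_dir}")
    # Use the current interpreter so the same environment (NATTEN, tvdcn) is used.
    return [sys.executable, "test.py", "--cfg", str(cfg_path)]


def run_eval(cfg_path, dataset, repo_root="."):
    """Run test.py from the repository root and return its exit code."""
    cmd = build_eval_command(cfg_path, dataset, repo_root)
    return subprocess.run(cmd, cwd=repo_root, check=False).returncode
```

For example, `run_eval("configs/agileFormer_base_synapse_pretrained_w_DS.yaml", "Synapse")` runs the same evaluation as the command above, but fails early with a readable message if a path is wrong.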

4. Train From Scratch

a. Download pre-trained deformable attention weights (DAT++)

| Model | Resolution | Pretrained weights |
|---|---|---|
| Tiny | 224x224 | OneDrive / TsinghuaCloud |
| Base | 224x224 | OneDrive / TsinghuaCloud |

If you are interested in more pretrained weights (e.g., with different resolutions, model sizes, and tasks), please check the official DAT++ repository (https://github.com/LeapLabTHU/DAT).

Put the pretrained weights into the folder "pretrained_ckpt/" under the main "AgileFormer" directory.
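DAT++ checkpoints were trained for image classification, so some entries (a classification head, for instance) may have no counterpart in the segmentation model, and mismatched entries are a common cause of weight-loading errors. A common PyTorch pattern for this situation, which is not necessarily what train.py does, is to keep only the entries whose names and shapes match the model and load them non-strictly. The sketch below works on any mapping from names to objects with a `.shape`, so it runs without a checkpoint; the function name `filter_compatible_weights` is hypothetical.

```python
def filter_compatible_weights(pretrained, model_state):
    """Split a pretrained state dict into loadable and skipped entries.

    pretrained and model_state map parameter names to tensors (or any
    objects with a .shape). An entry is kept only if the model has a
    parameter with the same name and the same shape.
    """
    kept, skipped = {}, []
    for name, value in pretrained.items():
        target = model_state.get(name)
        if target is not None and tuple(target.shape) == tuple(value.shape):
            kept[name] = value
        else:
            skipped.append(name)
    return kept, skipped
```

With PyTorch, one would typically load the file with `torch.load(path, map_location="cpu")`, unwrap a nested dictionary if the checkpoint stores weights under a sub-key (the key name depends on how it was saved), then call `kept, skipped = filter_compatible_weights(state, model.state_dict())` followed by `model.load_state_dict(kept, strict=False)`. Printing `skipped` shows which entries were ignored, which makes silent mismatches visible.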

b. Run the training script

python train.py --cfg [config_file in configs]

For example, for training Synapse tiny model, run the following command:

python train.py --cfg configs/agileFormer_tiny.yaml 

Future Updates

  • Release the tentative code for 2D segmentation.
  • Release the pretrained models for 2D segmentation.
  • Support the implementation of deformable convolution in mmcv and PyTorch.
  • Reorganize the tentative code for easier usage (maybe).
  • Release the code for 3D segmentation.
  • Release the pretrained models for 3D segmentation.

Acknowledgements

This code is built on top of Swin UNet and DAT; we thank the authors for their efficient and neat codebases.

Citation

If you find our work useful in your research, please consider giving us a star ⭐ and citing:

@article{qiu2024agileformer,
  title={AgileFormer: Spatially Agile Transformer UNet for Medical Image Segmentation},
  author={Qiu, Peijie and Yang, Jin and Kumar, Sayantan and Ghosh, Soumyendu Sekhar and Sotiras, Aristeidis},
  journal={arXiv preprint arXiv:2404.00122},
  year={2024}
}

Contributors

peijie-chiu

Issues

package

I encountered an issue while installing the package: "RuntimeError: Couldn't load custom C++ops. Recompile C++extension with: Python setup.py build_ext -- in-place". I understand that the torchvision version needs to be greater than 0.3 to use ops, but the torchvision version compatible with CUDA 12.1 + PyTorch 2.2 is 0.17. Could you please provide the complete list of packages and versions? Thanks.

Importing the na2d_qk_with_bias function from natten.functional fails because the function does not exist

Dear author, I was very excited after reading your paper and feel that it is very well written. However, I encountered this issue when reproducing the code. My natten and torch are both for 2.2.0 and cu118, and there are no other issues. However, natten.functional does not include na2d_qk_with_bias:

  File "/root/data1/zhangxuanxuan/AgileFormer/networks/nat_2d.py", line 30, in <module>
    from natten.functional import na2d_av, na2d_qk_with_bias
ImportError: cannot import name 'na2d_qk_with_bias' from 'natten.functional' (/root/anaconda3/envs/AF/lib/python3.10/site-packages/natten/functional.py)

Dice problem

Isn't there something wrong with this Dice calculation?
[image]
This is how the open-source framework MONAI computes it:
[image]

Thank you very much for your code, which has helped me a lot! But I have a question for you. Thank you

/home/liubn/anaconda3/envs/swin_umamba/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/liubn/anaconda3/envs/swin_umamba/lib/python3.10/site-packages/torchvision/image.so: undefined symbol: _ZN3c1017RegisterOperatorsD1Ev'If you don't plan on using image functionality from torchvision.io, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have libjpeg or libpng installed before building torchvision from source?
warn(
tvdcn is installed, using it for deformable convolution
Traceback (most recent call last):
  File "/home/liubn/anaconda3/envs/swin_umamba/lib/python3.10/site-packages/natten/functional.py", line 32, in <module>
    from natten import libnatten # type: ignore
ImportError: /home/liubn/anaconda3/envs/swin_umamba/lib/python3.10/site-packages/natten/libnatten.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4cuda19getDevicePropertiesEl

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/liubn/0-149liubaoning/43-AgileFormer-main/train.py", line 12, in <module>
    from networks.agileFormer_2d import AgileFormer2D as ViT_seg
  File "/home/liubn/0-149liubaoning/43-AgileFormer-main/networks/agileFormer_2d.py", line 3, in <module>
    from .agileFormer_sys_2d import AgileFormerSys2D
  File "/home/liubn/0-149liubaoning/43-AgileFormer-main/networks/agileFormer_sys_2d.py", line 9, in <module>
    from .nat_2d import NeighborhoodAttention2D
  File "/home/liubn/0-149liubaoning/43-AgileFormer-main/networks/nat_2d.py", line 30, in <module>
    from natten.functional import na2d_av, na2d_qk_with_bias
  File "/home/liubn/anaconda3/envs/swin_umamba/lib/python3.10/site-packages/natten/__init__.py", line 24, in <module>
    from .functional import (
  File "/home/liubn/anaconda3/envs/swin_umamba/lib/python3.10/site-packages/natten/functional.py", line 34, in <module>
    raise ImportError(
ImportError: Failed to import NATTEN's CPP backend. This could be due to an invalid/incomplete install. Please uninstall NATTEN (pip uninstall natten) and re-install with the correct torch build: shi-labs.com/natten .

Weight-loading error and configuration changes affecting accuracy: Base model training inquiry

Thank you so much for your great work. I did get 86.1 accuracy with your base_best weights, using an RTX 3090 GPU (instead of your V100) with the same environment. When I trained the model directly with your commands, I got a weight-loading error. I then modified your configuration file and trained the Base model from scratch, but the accuracy was only 84.7. Is there an error in your configuration file, or could you please tell me how you trained the Base model to 85.7? Thank you very much.
