Unlocking Fine-Grained Details with Wavelet-based High-Frequency Enhancement in Transformers
$\text{(\textcolor{teal}{MICCAI 2023 MLMI Workshop})}$

arXiv

Medical image segmentation is a critical task that plays a vital role in diagnosis, treatment planning, and disease monitoring. Accurate segmentation of anatomical structures and abnormalities from medical images can aid in the early detection and treatment of various diseases. In this paper, we address the local feature deficiency of the Transformer model by carefully re-designing the self-attention map to produce accurate dense prediction in medical images. To this end, we first apply the wavelet transformation to decompose the input feature map into low-frequency (LF) and high-frequency (HF) subbands. The LF segment is associated with coarse-grained features while the HF components preserve fine-grained features such as texture and edge information. Next, we reformulate the self-attention operation using the efficient Transformer to perform both spatial and context attention on top of the frequency representation. Furthermore, to intensify the importance of the boundary information, we impose an additional attention map by creating a Gaussian pyramid on top of the HF components. Moreover, we propose a multi-scale context enhancement block within skip connections to adaptively model inter-scale dependencies to overcome the semantic gap among stages of the encoder and decoder modules. Through comprehensive experiments, we demonstrate the effectiveness of our strategy on multi-organ and skin lesion segmentation benchmarks.
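As a rough illustration of the first stage described above, the sketch below decomposes a feature map into one low-frequency approximation and three high-frequency detail subbands with PyWavelets (pinned in requirements.txt). It is only a sketch under assumed tensor shapes and the Haar wavelet; it is not the repository's actual implementation.

import numpy as np
import pywt
import torch

def wavelet_split(feat, wavelet="haar"):
    # feat: (B, C, H, W) feature map; returns the LF approximation and the
    # stacked HF details (horizontal, vertical, diagonal), each halved in H and W.
    x = feat.detach().cpu().numpy()
    cA, (cH, cV, cD) = pywt.dwt2(x, wavelet, axes=(-2, -1))
    lf = torch.from_numpy(cA)                      # coarse-grained content
    hf = torch.from_numpy(np.stack([cH, cV, cD]))  # texture and edge details
    return lf, hf

lf, hf = wavelet_split(torch.randn(2, 64, 32, 32))
print(lf.shape, hf.shape)  # (2, 64, 16, 16) and (3, 2, 64, 16, 16)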

Frequency Enhanced Transformer (FET) model

[Figure: overview of the FET architecture]

[Figure: (a) FET block, (b) Multi-Scale Context Enhancement (MSCE) module]
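The FET block applies the efficient attention formulation to the frequency representation. As a point of reference, a minimal single-head sketch of that linear-complexity attention is given below; the class name, dimensions, and single-head setup are assumptions for illustration, not the code used in this repository.

import torch
import torch.nn as nn

class EfficientAttention(nn.Module):
    # Linear-complexity attention: keys are normalized over positions and
    # queries over channels, so a (d x d) global context summary replaces
    # the usual (N x N) attention map.
    def __init__(self, dim):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

    def forward(self, tokens):
        # tokens: (B, N, d), e.g. the flattened positions of one subband
        q = self.to_q(tokens).softmax(dim=-1)   # normalize over channels
        k = self.to_k(tokens).softmax(dim=1)    # normalize over positions
        v = self.to_v(tokens)
        context = k.transpose(1, 2) @ v         # (B, d, d) global context
        return q @ context                      # (B, N, d), linear in N

attn = EfficientAttention(dim=64)
out = attn(torch.randn(2, 16 * 16, 64))        # e.g. the LF tokens from above
print(out.shape)                               # torch.Size([2, 256, 64])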

Citation

@inproceedings{azad2023unlocking,
  title={Unlocking Fine-Grained Details with Wavelet-based High-Frequency Enhancement in Transformers},
  author={Azad, Reza and Kazerouni, Amirhossein and Sulaiman, Alaa and Bozorgpour, Afshin and Aghdam, Ehsan Khodapanah and Jose, Abin and Merhof, Dorit},
  maintitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  booktitle={Workshop on Machine Learning in Medical Imaging},
  year={2023},
  organization={Springer}
}

News

  • Aug 18, 2023: Accepted at the MICCAI 2023 MLMI Workshop! 🥳

How to use

Requirements

  • Ubuntu 16.04 or higher
  • CUDA 11.1 or higher
  • Python v3.7 or higher
  • PyTorch v1.7 or higher
  • Hardware Spec
    • A single GPU with 12GB memory or larger capacity (we used RTX 3090)

The pinned package versions (requirements.txt):

einops==0.6.1
h5py==3.9.0
imgaug==0.4.0
matplotlib==3.6.2
MedPy==0.4.0
numpy==1.22.2
opencv_python==4.6.0.66
pandas==1.5.2
PyWavelets==1.4.1
scipy==1.6.3
SimpleITK==2.2.1
tensorboardX==2.6.2.2
timm==0.9.5
torch==1.14.0a0+44dac51
torchvision==0.15.0a0
tqdm==4.64.1

pip install -r requirements.txt

Model weights

You can download the pretrained model weights from the link below.

Dataset   Model   Download link
Synapse   FET     [Download]

Training

To train, run train.py with your desired arguments, or use the bash script runs/run_tr_n01.sh and adjust its variables and arguments as needed. A brief description of the arguments is given below, followed by an example invocation.

usage: train.py [-h] [--root_path ROOT_PATH] [--test_path TEST_PATH] [--dataset DATASET] [--dstr_fast DSTR_FAST] [--en_lnum EN_LNUM] [--br_lnum BR_LNUM] [--de_lnum DE_LNUM]
              [--compact COMPACT] [--continue_tr CONTINUE_TR] [--optimizer OPTIMIZER] [--dice_loss_weight DICE_LOSS_WEIGHT] [--list_dir LIST_DIR] [--num_classes NUM_CLASSES]
              [--output_dir OUTPUT_DIR] [--max_iterations MAX_ITERATIONS] [--max_epochs MAX_EPOCHS] [--batch_size BATCH_SIZE] [--num_workers NUM_WORKERS]
              [--eval_interval EVAL_INTERVAL] [--model_name MODEL_NAME] [--n_gpu N_GPU] [--bridge_layers BRIDGE_LAYERS] [--deterministic DETERMINISTIC] [--base_lr BASE_LR]
              [--img_size IMG_SIZE] [--z_spacing Z_SPACING] [--seed SEED] [--opts OPTS [OPTS ...]] [--zip] [--cache-mode {no,full,part}] [--resume RESUME]
              [--accumulation-steps ACCUMULATION_STEPS] [--use-checkpoint] [--amp-opt-level {O0,O1,O2}] [--tag TAG] [--eval] [--throughput]

optional arguments:
-h, --help            show this help message and exit
--root_path ROOT_PATH
                      root dir for data
--test_path TEST_PATH
                      root dir for data
--dataset DATASET     experiment_name
--dstr_fast DSTR_FAST
                      SynapseDatasetFast: will load all data into RAM
--en_lnum EN_LNUM     en_lnum: Laplacian layers (Pyramid) for the encoder
--br_lnum BR_LNUM     br_lnum: Laplacian layers (Pyramid) for the bridge
--de_lnum DE_LNUM     de_lnum: Laplacian layers (Pyramid) for the decoder
--compact COMPACT     compact with 3 blocks instead of 4 blocks
--continue_tr CONTINUE_TR
                      continue the training from the last saved epoch
--optimizer OPTIMIZER
                      optimizer: [SGD, AdamW]
--dice_loss_weight DICE_LOSS_WEIGHT
                      You need to determine <x> (default=0.6): => [loss = (1-x)*ce_loss + x*dice_loss]
--list_dir LIST_DIR   list dir
--num_classes NUM_CLASSES
                      output channel of the network
--output_dir OUTPUT_DIR
                      output dir
--max_iterations MAX_ITERATIONS
                      maximum iteration number to train
--max_epochs MAX_EPOCHS
                      maximum epoch number to train
--batch_size BATCH_SIZE
                      batch_size per GPU
--num_workers NUM_WORKERS
                      num_workers
--eval_interval EVAL_INTERVAL
                      eval_interval
--model_name MODEL_NAME
                      model_name
--n_gpu N_GPU         total gpu
--bridge_layers BRIDGE_LAYERS
                      number of bridge layers
--deterministic DETERMINISTIC
                      whether using deterministic training
--base_lr BASE_LR     segmentation network learning rate
--img_size IMG_SIZE   input patch size of network input
--z_spacing Z_SPACING
                      z_spacing
--seed SEED           random seed
--opts OPTS [OPTS ...]
                      Modify config options by adding 'KEY VALUE' pairs.
--zip                 use zipped dataset instead of folder dataset
--cache-mode {no,full,part}
                      no: no cache, full: cache all data, part: shard the dataset into non-overlapping pieces and cache only one piece
--resume RESUME       resume from checkpoint
--accumulation-steps ACCUMULATION_STEPS
                      gradient accumulation steps
--use-checkpoint      whether to use gradient checkpointing to save memory
--amp-opt-level {O0,O1,O2}
                      mixed precision opt level, if O0, no amp is used
--tag TAG             tag of experiment
--eval                Perform evaluation only
--throughput          Test throughput only
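
For instance, a training run on the Synapse dataset could be launched as follows; the paths, class count, and hyperparameter values are placeholders to adapt to your own setup rather than recommended settings.

python train.py --dataset Synapse --root_path ./data/Synapse/train_npz --test_path ./data/Synapse/test_vol_h5 --num_classes 9 --output_dir ./results/fet_synapse --max_epochs 400 --batch_size 24 --base_lr 0.05 --img_size 224 --optimizer SGD --dice_loss_weight 0.6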

Inference

For inference, run test.py. Most of its parameters are the same as for train.py. You can also use runs/run_te_n01.sh as an example.

To run with arbitrary weights, pass --weights_fpath with the absolute path to the model weights file.
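
A hypothetical invocation might look like the following; the paths and settings are placeholders (only --weights_fpath and the arguments shared with train.py are documented above).

python test.py --dataset Synapse --test_path ./data/Synapse/test_vol_h5 --num_classes 9 --output_dir ./results/fet_synapse --img_size 224 --weights_fpath /absolute/path/to/fet_synapse.pth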

Experiments

Synapse Dataset

[Figure: sample images and segmentation results on the Synapse dataset]

ISIC 2018 Dataset

[Figure: sample images and segmentation results on the ISIC 2018 dataset]
