
MOTCat

Multimodal Optimal Transport-based Co-Attention Transformer with Global Structure Consistency for Survival Prediction, ICCV 2023. [arxiv] [paper]
Yingxue Xu, Hao Chen

Summary: Here is the official implementation of the paper "Multimodal Optimal Transport-based Co-Attention Transformer with Global Structure Consistency for Survival Prediction".

News

  • [01/04/2024] Upgraded the OT module to its GPU version, which allows a larger Micro-Batch size or removing the Micro-Batch setting entirely; the Pre-requisites have been updated accordingly. We have set the Micro-Batch size to 16384, which notably accelerates training. Preliminary validation shows performance consistent with the previous version; we will report the updated results soon.
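The unbalanced OT computed by the OT module (the `pot-uot-l2` option below, with entropic regularization `--ot_reg` and marginal relaxation `--ot_tau`) can be sketched with plain NumPy Sinkhorn-style scaling updates. This is an illustrative re-derivation of the standard entropic UOT iteration, not the repository's GPU implementation; the function name `sinkhorn_uot` is ours:

```python
import numpy as np

def sinkhorn_uot(a, b, C, reg=0.1, tau=0.5, n_iter=200):
    """Entropic unbalanced OT between marginals a, b with cost matrix C.

    reg plays the role of --ot_reg (entropic regularization) and tau the
    role of --ot_tau (KL penalty on the marginals). Returns the transport
    plan of shape (len(a), len(b)).
    """
    K = np.exp(-C / reg)          # Gibbs kernel
    fi = tau / (tau + reg)        # damping exponent from the KL relaxation
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = (a / (K @ v)) ** fi   # scale rows toward marginal a
        v = (b / (K.T @ u)) ** fi # scale columns toward marginal b
    return u[:, None] * K * v[None, :]
```

With `tau -> inf` the exponent `fi -> 1` and the update reduces to the classical balanced Sinkhorn iteration.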

Pre-requisites (new!!):

python==3.9.19
pot==0.9.3
torch==2.2.1
torchvision==0.17.1
scikit-survival==0.22.2
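The pinned versions above can be verified programmatically before training. This is an optional convenience sketch; the helper name `check_environment` is ours, not part of the repository:

```python
import importlib.metadata as md

# Pinned versions from the Pre-requisites above.
REQUIRED = {
    "pot": "0.9.3",
    "torch": "2.2.1",
    "torchvision": "0.17.1",
    "scikit-survival": "0.22.2",
}

def check_environment(required=REQUIRED):
    """Return a list of (package, expected, found) mismatches; empty means OK."""
    problems = []
    for pkg, want in required.items():
        try:
            have = md.version(pkg)
        except md.PackageNotFoundError:
            have = None  # package not installed at all
        if have != want:
            problems.append((pkg, want, have))
    return problems
```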

Prepare your data

WSIs

  1. Download diagnostic WSIs from TCGA
  2. Use the WSI processing tool provided by CLAM to extract a ResNet-50 pretrained 1024-dimensional feature for each 256 × 256 patch at 20x magnification, saved as one .pt file per WSI. This yields a single pt_files folder storing the .pt files for all WSIs of a study.

The final structure of the datasets should be as follows:

DATA_ROOT_DIR/
    └──pt_files/
        ├── slide_1.pt
        ├── slide_2.pt
        └── ...

DATA_ROOT_DIR is the base directory for a cancer type (e.g. the directory for TCGA_BLCA), which should be passed to the model via the argument --data_root_dir as shown in command.md.
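Given that layout, enumerating the available WSIs is a small directory walk over pt_files. A minimal sketch, assuming only the folder structure shown above (the helper `list_slide_ids` is illustrative, not part of the codebase):

```python
from pathlib import Path

def list_slide_ids(data_root_dir):
    """Enumerate slide IDs from DATA_ROOT_DIR/pt_files (one .pt feature bag per WSI)."""
    pt_dir = Path(data_root_dir) / "pt_files"
    if not pt_dir.is_dir():
        raise FileNotFoundError(f"expected feature folder at {pt_dir}")
    # slide_1.pt -> "slide_1", sorted for reproducible ordering
    return sorted(p.stem for p in pt_dir.glob("*.pt"))
```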

Genomics

In this work, we directly use the preprocessed genomic data provided by MCAT, stored in folder dataset_csv.

Training-Validation Splits

Splits for each cancer type are found in the splits/5foldcv folder; each dataset is randomly partitioned using 5-fold cross-validation, and each folder contains splits_{k}.csv for k = 1 to 5. For a fair comparison, we follow the same splits as MCAT.
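Assuming the MCAT-style split files list case IDs under `train` and `val` columns (an assumption about the CSV layout, not confirmed here), one fold can be read with the standard library; the helper `read_split` is illustrative only:

```python
import csv

def read_split(csv_path):
    """Read one fold's split file into (train_ids, val_ids).

    Assumes MCAT-style columns 'train' and 'val', where the shorter
    column is padded with blanks.
    """
    train, val = [], []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            if row.get("train"):
                train.append(row["train"])
            if row.get("val"):
                val.append(row["val"])
    return train, val
```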

Running Experiments

To train MOTCat, you can edit the arguments in the bash script train_motcat.sh stored in scripts and run the command:

sh scripts/train_motcat.sh

or use the following generic command-line and specify the arguments:

CUDA_VISIBLE_DEVICES=<DEVICE_ID> python main.py \
--data_root_dir <DATA_ROOT_DIR> \
--split_dir <SPLITS_FOR_CANCER_TYPE> \
--model_type motcat \
--bs_micro 256 \
--ot_impl pot-uot-l2 \
--ot_reg <OT_ENTROPIC_REGULARIZATION> --ot_tau 0.5 \
--which_splits 5foldcv \
--apply_sig

Commands for all experiments of MOTCat can be found in the command.md file.
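The --bs_micro argument controls the Micro-Batch size: the OT coupling is computed over chunks of the WSI's patch sequence rather than over all patches at once, which bounds memory. A minimal sketch of that chunking (illustrative only; the function name is ours):

```python
import numpy as np

def iter_micro_batches(patch_feats, bs_micro=256):
    """Yield consecutive micro-batches along the patch axis.

    patch_feats: array of shape (num_patches, feat_dim), e.g. the
    1024-dim CLAM features for one WSI.
    """
    for i in range(0, patch_feats.shape[0], bs_micro):
        yield patch_feats[i:i + bs_micro]
```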

Acknowledgements

Huge thanks to the authors of the following open-source projects:

License & Citation

If you find our work useful in your research, please consider citing our paper at:

@InProceedings{Xu_2023_ICCV,
    author    = {Xu, Yingxue and Chen, Hao},
    title     = {Multimodal Optimal Transport-based Co-Attention Transformer with Global Structure Consistency for Survival Prediction},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2023},
    pages     = {21241-21251}
}

This code is available for non-commercial academic purposes. If you have any questions, feel free to email Yingxue Xu.


motcat's Issues

GBMLGG dataset

Hi @Innse, I am curious how to obtain the GBMLGG dataset. I couldn't find a specific GBMLGG dataset on the TCGA website, nor any notes about how to download it. Could you give me some hints? Thank you.

Duplicate file names for the same patient

Hi author,

I want to ask where your code handles duplicate file names under the same patient. For example, there can be several WSI feature files for one patient under the same case ID. Where is this handled in your dataset-loading code? Thank you in advance!

Being stuck at "runing with mcat coattn" when running the MCAT model.

Hello, thank you for the excellent work you have completed and for providing the code. I wonder if you have ever encountered the issue of being stuck at "runing with mcat coattn" when running the MCAT model. This problem usually occurs during the training of the second fold. If you have experienced this issue before, I would appreciate your guidance on how to resolve it. Thank you in advance for your response.

pre-trained model

Hi author,

I couldn't find where you provide your pre-trained model; could you share it via a link? And could you also provide the .pt feature files extracted with CLAM? Many thanks!

Question of C-Index

Hello author, I would like to ask whether you have encountered the problem where the loss decreases on both the training and validation sets, while the C-Index increases on the training set but remains unchanged or decreases on the validation set. How was it resolved? Looking forward to your reply!

Can you make the project's code open-source in the near future?

I recently had the opportunity to read your article, and I must say that I was highly impressed with both the descriptions and the experimental results presented. Your work has sparked a great deal of interest in me.
With regards to your project, I wanted to inquire about the availability of the code. I am keen on exploring and understanding the implementation details further, and having access to the code would greatly facilitate this process. I was wondering if you have plans to make the project's code open-source in the near future?
Thank you for considering my request, and I'm looking forward to your response.

What script is used in CLAM? How big of a GPU is needed to run MOTCat?

"Use the WSI processing tool provided by CLAM to extract resnet-50 pretrained 1024-dim feature for each 256 × 256 patch (20x), which we then save as .pt files for each WSI. So, we get one pt_files folder storing .pt files for all WSIs of one study."
I use:
python create_patches_fp.py --source DATA_DIRECTORY --save_dir RESULTS_DIRECTORY --patch_size 256 --preset tcga.csv --seg --patch --stitch
and:
python create_patches_fp.py --source DATA_DIRECTORY --save_dir RESULTS_DIRECTORY --patch_size 256 --seg --process_list process_list_edited.csv
Thank you very much for your project, and I look forward to your reply.

Problems in the processing of WSI dataset

I am using the histolab library to process the downloaded WSI dataset, following Otsu thresholding to cut out patches containing at least 50% tissue area (without overlap). However, during the experiments the resulting metrics were not as expected. After checking, I believe the problem is in my dataset-processing code or paradigm, but due to local resource constraints I cannot make multiple attempts. I would like to borrow the processing code from your work, and I would be even more grateful if you could provide the extracted feature files! Looking forward to your reply.

c-index evaluation question

Thank you very much for your work. I would like to ask why the final c-index evaluation metric gets stuck at 0.5000 when I modify the model. Could you please explain why this happens?

Nan values in the extracted feature

Hi there,

Thank you for sharing your nice work!
I met a problem when I tried to train your model; it returned a NaN loss and risk like below:
batch 99, loss: nan, label: 1, event_time: 14.6800, risk: nan

The error trace is:
File "/opt/anaconda3/envs/ct/lib/python3.7/site-packages/sksurv/metrics.py", line 223, in concordance_index_censored
    event_indicator, event_time, estimate)
File "/opt/anaconda3/envs/ct/lib/python3.7/site-packages/sksurv/metrics.py", line 49, in _check_inputs
    estimate = _check_estimate_1d(estimate, event_time)
File "/opt/anaconda3/envs/ct/lib/python3.7/site-packages/sksurv/metrics.py", line 36, in _check_estimate_1d
    estimate = check_array(estimate, ensure_2d=False)
File "/opt/anaconda3/envs/ct/lib/python3.7/site-packages/sklearn/utils/validation.py", line 800, in check_array
    _assert_all_finite(array, allow_nan=force_all_finite == "allow-nan")
File "/opt/anaconda3/envs/ct/lib/python3.7/site-packages/sklearn/utils/validation.py", line 116, in _assert_all_finite
    type_err, msg_dtype if msg_dtype is not None else X.dtype
ValueError: Input contains NaN, infinity or a value too large for dtype('float64').

I checked the input and output of the model and found many NaN values in the features of both the WSI and omic data, which lead to NaN hazards and S. I strictly followed the instructions you provided and am really confused why these NaN values appear. If you have met this problem before, could you tell me how to solve it?

Thanks!
