
mld3 / deep-learning-applied-to-chest-x-rays-exploiting-and-preventing-shortcuts

[MLHC 2020] Deep Learning Applied to Chest X-Rays: Exploiting and Preventing Shortcuts (Jabbour, Fouhey, Kazerooni, Sjoding, Wiens). https://arxiv.org/abs/2009.10132

Home Page: http://proceedings.mlr.press/v126/jabbour20a

Python 30.99% Jupyter Notebook 69.01%
mlhc-2020 deep-learning chest-x-ray algorithmic-bias

Overview

This is the code repository for the manuscript "Deep Learning Applied to Chest X-Rays: Exploiting and Preventing Shortcuts" and its accompanying video.

Directory structures

  • preprocess_mimic_attributes/: contains a script to link MIMIC-CXR and MIMIC-IV to extract attribute features. Note that the link between MIMIC-CXR and MIMIC-IV was not publicly available at the time of publication, so this was added after the fact to enable other researchers to recreate more of our experiments.

  • preprocessing/ contains code for subsampling the datasets or assigning bias labels for the synthetic image preprocessing experiments

  • dataset/ contains data loaders

  • model/ contains the model loader

  • standardizers/ is where the mean and standard deviation should be saved if you're standardizing to a subset of the data (e.g., when there is too much data to load into memory)
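As an illustration of what a saved standardizer might contain, here is a minimal sketch, assuming grayscale images that load as NumPy arrays and a small JSON file with "mean" and "std" keys. The functions compute_standardizer, save_standardizer, and standardize, the load_fn callback, and the file layout are all hypothetical; the repository's own format may differ.

```python
import json
import math
import os
import random

import numpy as np


def compute_standardizer(paths, load_fn, sample_size, seed=0):
    """Estimate pixel mean and std from a random subset of images.

    Only one image is held in memory at a time; running sums are accumulated.
    load_fn(path) -> 2-D NumPy array is a hypothetical loader callback.
    """
    rng = random.Random(seed)
    subset = rng.sample(paths, min(sample_size, len(paths)))
    total = total_sq = 0.0
    count = 0
    for path in subset:
        pixels = load_fn(path).astype(np.float64)
        total += pixels.sum()
        total_sq += np.square(pixels).sum()
        count += pixels.size
    mean = total / count
    # Guard against a tiny negative variance caused by floating-point rounding.
    var = max(total_sq / count - mean ** 2, 0.0)
    return float(mean), math.sqrt(var)


def save_standardizer(mean, std, path):
    """Write the statistics to a JSON file (hypothetical layout under standardizers/)."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w") as f:
        json.dump({"mean": mean, "std": std}, f)


def standardize(img, mean, std):
    """Apply the saved statistics to a single image."""
    return (img.astype(np.float64) - mean) / std
```

Fitting the statistics on a subset is what keeps this feasible when the full dataset does not fit in memory; the trade-off is a small estimation error in the mean and standard deviation.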

Download and Process the Data

Follow the official directions to download the MIMIC-CXR, MIMIC-IV, and CheXpert datasets.

Run resampling_skewed_unskewed.ipynb in preprocessing/ to subsample the data into skewed and unskewed datasets.
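The notebook's exact procedure is not described here. As a rough sketch of the idea, assuming each record has a binary label and a binary attribute (hypothetical "label" and "attribute" keys), a skewed subset can be drawn by fixing a different positive rate in each attribute group, and an unskewed subset by using the same rate in both groups. The subsample function is hypothetical.

```python
import random
from collections import defaultdict


def subsample(records, n_per_group, pos_frac, seed=0):
    """Draw n_per_group records per attribute value with a chosen positive rate.

    records: list of dicts with binary "label" and "attribute" keys (hypothetical schema).
    pos_frac: maps attribute value -> desired fraction of positive labels.
    Raises ValueError (from random.sample) if a bucket has too few records.
    """
    rng = random.Random(seed)
    buckets = defaultdict(list)
    for r in records:
        buckets[(r["attribute"], r["label"])].append(r)
    out = []
    for attr, frac in pos_frac.items():
        n_pos = round(n_per_group * frac)
        n_neg = n_per_group - n_pos
        out += rng.sample(buckets[(attr, 1)], n_pos)
        out += rng.sample(buckets[(attr, 0)], n_neg)
    rng.shuffle(out)
    return out
```

For example, pos_frac={0: 0.2, 1: 0.8} yields a skewed set in which the attribute predicts the label, while pos_frac={0: 0.5, 1: 0.5} yields an unskewed one.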

Run assumption_relaxation_bias_assignment.ipynb and random_bias_assignment.ipynb in preprocessing/synthetic_shortcuts/ to assign bias labels, and then run apply_filter.ipynb to apply the Gaussian filter to the images.
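A minimal sketch of random bias assignment, assuming the synthetic shortcut is a binary bias flag whose rate differs between positive and negative examples. The assign_bias function and its rate_pos and rate_neg parameters are hypothetical, and the notebooks may assign labels differently. Images flagged with bias 1 would then be the ones passed through the Gaussian filter.

```python
import random


def assign_bias(labels, rate_pos, rate_neg, seed=0):
    """Return a 0/1 bias flag per example.

    Exactly round(rate * n) examples of each class are flagged, so the
    fraction of flagged positives is rate_pos and of flagged negatives is rate_neg.
    """
    rng = random.Random(seed)
    bias = [0] * len(labels)
    for cls, rate in ((1, rate_pos), (0, rate_neg)):
        idx = [i for i, y in enumerate(labels) if y == cls]
        for i in rng.sample(idx, round(rate * len(idx))):
            bias[i] = 1
    return bias
```

The further rate_pos and rate_neg are from each other, the more strongly the bias flag (and hence the filter) correlates with the target label.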

Running the code

Config file

Pre-specified arguments can be set in config.json:

Required arguments:

  • csv_file: Path to metadata file.
  • checkpoint: Path to file location where model checkpoints will be saved.
  • labels: column names of the target classes in the metadata file. Multiple labels can be separated by "|" (e.g., CHF|Pneumonia).
  • rotate degrees: degrees of rotation to use for random rotation in image augmentation.
  • disk: if disk = 1, all images will be loaded into memory before training. Otherwise, images will be fetched from disk during training.
  • mask: mask = 1 if a masked loss will be used (i.e., if there are missing labels). All missing labels in the metadata file should be set to -1.
  • early_stop: early_stop = 1 if early stopping criteria will be used. Otherwise, the model will train for 3 epochs.
  • pretrain: whether to initialize from pretrained weights. If pretrain is "yes", ImageNet initialization will be used unless a pretrain_file is specified. Otherwise, pretrain should be "random".
  • pretrain_file: file path to the pretrained model (e.g., a source task model or a model pretrained on MIMIC-CXR and CheXpert).
  • pretrain_classes: number of target labels the pretrained model had.
  • loader_names: list of split names (e.g., ["train", "valid", "test"]). You do not have to include "test".
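To make the labels and mask settings concrete, here is a minimal sketch, assuming a binary cross-entropy loss averaged over observed labels only. The parse_labels and masked_bce helpers are hypothetical, and the repository's actual loss is likely implemented in a deep learning framework rather than plain Python.

```python
import math


def parse_labels(spec):
    """Split the config's "labels" string, e.g. "CHF|Pneumonia", into column names."""
    return [name.strip() for name in spec.split("|") if name.strip()]


def masked_bce(probs, targets):
    """Mean binary cross-entropy over entries whose target is not -1.

    probs, targets: per-example lists with one entry per label.
    Missing labels (-1) contribute neither to the loss nor to the count.
    """
    eps = 1e-7
    total, count = 0.0, 0
    for p_row, t_row in zip(probs, targets):
        for p, t in zip(p_row, t_row):
            if t == -1:
                continue  # missing label: masked out
            p = min(max(p, eps), 1 - eps)  # avoid log(0)
            total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
            count += 1
    return total / count if count else 0.0
```

Averaging over observed entries only, rather than over all entries, keeps examples with many missing labels from shrinking the loss.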

Optional arguments: these specify which layers to train. If none of them are specified, the whole network will be trained.

  • tune_classifier: 1 to train only the final fully connected layer of the network.
  • sensitivity_analysis: 1 to train a variable number of blocks in the network. If 1, the num_blocks argument should be specified.
  • num_blocks: number of dense blocks to train (counting from the end of the network toward the beginning), between 1 and 3.
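The flags above could map onto parameter freezing roughly as follows. This sketch assumes a torchvision-style DenseNet-121, whose parameters are named features.denseblock1 to features.denseblock4, features.transition1 to features.transition3, features.norm5, and classifier, and it assumes that training the last num_blocks dense blocks also trains every layer after them. The is_trainable helper is hypothetical.

```python
import re

BLOCK_RE = re.compile(r"(denseblock|transition)(\d+)")


def is_trainable(param_name, tune_classifier=0, sensitivity_analysis=0, num_blocks=None):
    """Return True if the named parameter should receive gradient updates."""
    if tune_classifier == 1:
        # Train only the final fully connected layer.
        return param_name.startswith("classifier.")
    if sensitivity_analysis == 1:
        if num_blocks is None or not 1 <= num_blocks <= 3:
            raise ValueError("num_blocks must be between 1 and 3")
        first_block = 5 - num_blocks  # num_blocks=1 -> denseblock4 onward
        parts = param_name.split(".")
        if parts[0] == "classifier":
            return True
        if parts[0] == "features" and len(parts) > 1:
            if parts[1] == "norm5":
                return True
            match = BLOCK_RE.fullmatch(parts[1])
            if match:
                # transitionK sits right after denseblockK, so the same cutoff applies.
                return int(match.group(2)) >= first_block
        return False
    # No optional flags: train the whole network.
    return True
```

In PyTorch, the result could be applied by setting requires_grad on each parameter returned by model.named_parameters().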

Training a model

The following example code will train a model using train.py. Each run requires that a model_name and model_type be specified. These must be pre-specified in the config file, along with the other parameters described in the Config file section. Models will be saved in the directory chexpoint/model_type/model_name.

python train.py --model_type example_model_type --model_name example_model_name

Other optional command-line arguments are:

Arguments

  • gpu: the GPU numbers to train on. Default is 0 and 1.
  • budget: number of hyperparameter combinations to try. Default is 50.
  • repeats: number of seed initializations to try. Default is 3.
  • save_every: for pretraining on MIMIC-CXR and CheXpert. Number of iterations to complete before saving a checkpoint. Default is None, which saves a checkpoint after every epoch.
  • save_best_num: for pretraining on MIMIC-CXR and CheXpert. Number of top checkpoints to save (based on the best AUROC on the validation set). Default is 1.
  • optimizer: optimizer to use. Default is "sgd", but "adam" can also be chosen for pretraining on MIMIC-CXR and CheXpert.
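For reference, the arguments above could be declared with argparse roughly as below. This is a hypothetical reconstruction from this README, not train.py's actual parser; in particular, passing GPU numbers as space-separated integers is an assumption.

```python
import argparse


def build_parser():
    """Hypothetical command-line interface mirroring the documented arguments."""
    parser = argparse.ArgumentParser(description="Sketch of train.py arguments")
    parser.add_argument("--model_type", required=True)
    parser.add_argument("--model_name", required=True)
    parser.add_argument("--gpu", type=int, nargs="+", default=[0, 1],
                        help="GPU numbers to train on")
    parser.add_argument("--budget", type=int, default=50,
                        help="number of hyperparameter combinations to try")
    parser.add_argument("--repeats", type=int, default=3,
                        help="number of seed initializations to try")
    parser.add_argument("--save_every", type=int, default=None,
                        help="iterations between checkpoints; None saves every epoch")
    parser.add_argument("--save_best_num", type=int, default=1,
                        help="number of top checkpoints (by validation AUROC) to keep")
    parser.add_argument("--optimizer", choices=["sgd", "adam"], default="sgd")
    return parser
```

Restricting optimizer with choices makes a typo such as --optimizer adma fail immediately instead of partway through a run.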

Pretraining/source task

To train a model on MIMIC-CXR and CheXpert, use the save_every, save_best_num, and optimizer arguments. This trains a model starting from an ImageNet initialization:

python train.py --model_type example_model_type --model_name example_model_name --save_every 4800 --save_best_num 10 --optimizer adam
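The save_every and save_best_num options suggest bookkeeping along these lines. This is a minimal sketch, assuming a checkpoint is evaluated every save_every iterations (or once per epoch when save_every is None) and only the save_best_num checkpoints with the highest validation AUROC are kept. TopKCheckpoints, should_checkpoint, and the save_fn and delete_fn callbacks are hypothetical.

```python
import heapq


def should_checkpoint(iteration, save_every, end_of_epoch):
    """save_every=None falls back to one checkpoint per epoch."""
    if save_every is None:
        return end_of_epoch
    return iteration % save_every == 0


class TopKCheckpoints:
    """Keep the k checkpoints with the highest validation AUROC."""

    def __init__(self, k, save_fn, delete_fn):
        self.k = k
        self.save_fn = save_fn      # hypothetical: writes a checkpoint file
        self.delete_fn = delete_fn  # hypothetical: removes a checkpoint file
        self.heap = []  # min-heap of (auroc, path); heap[0] is the worst kept checkpoint

    def consider(self, auroc, path):
        if len(self.heap) < self.k:
            self.save_fn(path)
            heapq.heappush(self.heap, (auroc, path))
        elif auroc > self.heap[0][0]:
            self.save_fn(path)
            _, worst = heapq.heapreplace(self.heap, (auroc, path))
            self.delete_fn(worst)

    def best(self):
        """Kept checkpoints, highest AUROC first."""
        return sorted(self.heap, reverse=True)
```

A min-heap keeps the weakest retained checkpoint at the root, so each new candidate is compared against it in constant time.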

Training the target task

To train a model after pretraining on either MIMIC-CXR/CheXpert or a source task, you'll need to specify the file location of the pretrained model in the config file.
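A hypothetical config.json for a target task might look like the fragment below. The flat layout, file paths, and values are placeholders (pretrain_classes in particular is illustrative and must match however many labels your pretrained model was trained with), and the rotation setting is omitted because its exact key name is unclear from this README.

```json
{
  "csv_file": "data/target_task_metadata.csv",
  "checkpoint": "checkpoints/",
  "labels": "CHF|Pneumonia",
  "disk": 0,
  "mask": 1,
  "early_stop": 1,
  "pretrain": "yes",
  "pretrain_file": "checkpoints/pretrain/best_model.pt",
  "pretrain_classes": 14,
  "loader_names": ["train", "valid", "test"]
}
```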