Giter Club home page Giter Club logo

scalable-data-valuation-health-care-shapley-value's Introduction

Trustworthy Machine Learning for Health Care: Scalable Data Valuation with the Shapley Value

This repository contains the code used for the paper "Trustworthy Machine Learning for Health Care: Scalable Data Valuation with the Shapley Value". The code was tested on a workstation running with Python 3.8 and PyTorch 1.6.0. For more details and references, please see the paper (open access).

If you use this repository and/or the corresponding paper, please cite it as follows:

@inproceedings{pandl2021trustworthy,
  title={Trustworthy Machine Learning for Health Care: Scalable Data Valuation with the Shapley Value},
  author={Pandl, Konstantin D. and Feiland, Fabian and Thiebes, Scott and Sunyaev, Ali},
  booktitle={Proceedings of the ACM conference on health, inference, and learning},
  year={2021}
}

Project Organization

General code to load and preprocess data

General section
    ├── data                    
    |    ├── CheXpert-v1.0-small   		<- dataset used (not checked in as the size is ~10GB; needs to be downloaded from https://stanfordmlgroup.github.io/competitions/chexpert/)
    |    ├── datasets              		<- CSV files containing the split used for training, validation and testing
    ├── arrays							
    |    ├── deep_features			<- stores the deep features created with get_deep.py
    |    ├── raw_data				<- stores the raw data created with get_array.py
    ├── models							
    |    ├── densenet.py			<- src defining Densenet-121 for classification task
    |    ├── resnet.py				<- src defining ResNet-34 for classification task
    ├── config.json				<- config file stores path to data
    ├── import_utils.py				<- src for importing pictures and labels from dataset
    ├── get_array.py				<- src for loading and storing raw data in arrays
    ├── get_deep.py				<- src for training, obtaining and storing deep features in arrays

Code to conduct various application experiments

Applications
    ├── applications
    |    ├── runtime comparison
    |           ├── DShap.py			<- src for running different data valuation methods
    |           ├── shap_utils.py		<- src for utils of DShap.py
    |           ├── runtime.py			<- src for runtime experiment
    |           ├── fit_module.py		<- src for training and evaluation of ML model

    |    ├── point_removal
    |           ├── knn_shap_calculation.py	<- src for calculating KNN-Shapley values
    |           ├── plot_densenet.py		<- src for training and evaluating model after points are removed
    |           ├── point_removal.py		<- src for point removal experiment
    |           ├── utils.py			<- src for utils of knnn_shap_calculation.py and plot_densenet.py
    
    |    ├── noisy_label					
    |           ├── flip_arrays					
    |                  ├── deep_features	<- stores the new deep features for flipped data created with get_new_deep.py
    |                  ├── raw_data		<- stores the flipped raw data created with generate_data.py
    |           ├── generate_flip_data.py	<- src for generating and storing flipped data in arrays
    |           ├── get_deep_flip.py		<- src for training, obtaining and storing deep features of flipped data in arrays
    |           ├── knn_shap_calculation_flip.py<- src for calculating KNN-Shapley values for flipped data
    |           ├── label_detection.py		<- src for detecting noisy label experiment
    |           ├── utils.py			<- src for utils of knnn_shap_calculation_flip.py and generate_flip_data.py

Code to plot the results

Plotting results
    ├── plot_results						
    |    ├── model
    |           ├── loss_plot.py		<- src for plotting Figure 2
    |           ├── auroc_plot.py		<- src for plotting Figure 3
    |           ├── pred.npy			<- stores predicted labels of trained model on validation set
    |           ├── true.npy			<- stores true labels of validation set

    |    ├── runtime_plot.py			<- src for plotting Figure 4
    
    |    ├── point_removal
    |           ├── point_removal_plot.py	<- src for plotting Figure 6
    |           ├── val_result_HtoL.npz		<- stores results from applications/point_removal/point_removal.py
    
    |    ├── label_detection					
    |           ├── label_plot.py		<- src for plotting Figure 7				
    |           ├── f_knn.pkl			<- stores information about fraction of incorrect labels detected
    |           ├── f_random.pkl		<- stores information about fraction of incorrect labels detected
    |           ├── x_knn.pkl			<- stores information about fraction of data inspected
    |           ├── x_random.pkl		<- stores information about fraction of data inspected

scalable-data-valuation-health-care-shapley-value's People

Contributors

kpandl avatar

Stargazers

 avatar  avatar  avatar

Watchers

 avatar

Forkers

mohamedyd

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.