
mims-harvard / txgnn


TxGNN: Zero-shot prediction of therapeutic use with geometric deep learning and clinician centered design

Home Page: https://zitniklab.hms.harvard.edu/projects/TxGNN

License: MIT License

Jupyter Notebook 88.47% Python 11.51% Shell 0.02%
drug-design drug-discovery geometric-deep-learning graph-neural-networks knowledge-graph therapeutics indications-and-warning precision-medicine

txgnn's Introduction

TxGNN: Zero-shot prediction of therapeutic use with geometric deep learning and human centered design

This repository hosts the official implementation of TxGNN, a model for identifying therapeutic opportunities for diseases with limited treatment options and minimal molecular understanding. TxGNN leverages recent advances in geometric deep learning and human-centered design.

TxGNN is a graph neural network pre-trained on a comprehensive knowledge graph of 17,080 clinically recognized diseases and 7,957 therapeutic candidates. The model handles various therapeutic tasks, such as indication and contraindication prediction, in a unified formulation. We show that, once trained, TxGNN can perform zero-shot inference on new diseases without additional parameters or fine-tuning on ground-truth labels.

The TxGNN Explorer, which displays model predictions and explanations, is available at http://txgnn.org


Installation

conda create --name txgnn_env python=3.8
conda activate txgnn_env
# Install PyTorch via https://pytorch.org/ with your CUDA versions
conda install -c dglteam dgl-cuda${CUDA_VERSION}==0.5.2 # see https://www.dgl.ai/pages/start.html for more info; the version must be DGL 0.5.2
pip install TxGNN

Note that if you want to use the disease-area split, you should also install PyG (PyTorch Geometric) by following its installation instructions, since some legacy data-processing code uses PyG utility functions.
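Because TxGNN pins DGL 0.5.2, a quick check of what is actually importable can save debugging time. The script below is a minimal sketch and not part of the txgnn package; module_version and check_environment are hypothetical helper names, and it assumes each library exposes a __version__ attribute (PyG is checked under its module name torch_geometric).

```python
import importlib
from typing import List, Optional

# Hypothetical helper (not part of txgnn). The pin comes from the
# installation instructions above, which require DGL 0.5.2.
REQUIRED_DGL = "0.5.2"


def module_version(module_name: str) -> Optional[str]:
    """Return module.__version__, or None if it cannot be imported or reports no version."""
    try:
        module = importlib.import_module(module_name)
    except Exception:  # ImportError, or e.g. a missing CUDA/shared library
        return None
    return getattr(module, "__version__", None)


def check_environment() -> List[str]:
    """Collect human-readable problems with the TxGNN environment."""
    problems = []
    if module_version("torch") is None:
        problems.append("PyTorch not found")
    dgl_version = module_version("dgl")
    if dgl_version is None:
        problems.append("DGL not found")
    elif dgl_version.split("+")[0] != REQUIRED_DGL:  # ignore local build tags
        problems.append(f"found DGL {dgl_version}, expected {REQUIRED_DGL}")
    if module_version("torch_geometric") is None:
        problems.append("PyG not found (only needed for disease-area splits)")
    return problems


if __name__ == "__main__":
    for problem in check_environment() or ["environment looks OK"]:
        print(problem)
```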

Core API Interface

Using the API, you can (1) reproduce the results in our paper, (2) train TxGNN on your own drug repurposing dataset with a few lines of code, and (3) generate graph explanations.

from txgnn import TxData, TxGNN, TxEval

# Download/load knowledge graph dataset
TxData = TxData(data_folder_path = './data')
TxData.prepare_split(split = 'complex_disease', seed = 42)
TxGNN = TxGNN(data = TxData, 
              weight_bias_track = False,
              proj_name = 'TxGNN', # wandb project name
              exp_name = 'TxGNN', # wandb experiment name
              device = 'cuda:0' # define your cuda device
              )

# Initialize a new model
TxGNN.model_initialize(n_hid = 100, # number of hidden dimensions
                       n_inp = 100, # number of input dimensions
                       n_out = 100, # number of output dimensions
                       proto = True, # whether to use the metric learning module
                       proto_num = 3, # number of similar diseases to retrieve for augmentation
                       attention = False, # whether to use an attention layer (we set this to False when using graph XAI)
                       sim_measure = 'all_nodes_profile', # disease signature; choose from ['all_nodes_profile', 'protein_profile', 'protein_random_walk']
                       agg_measure = 'rarity', # how to aggregate similar-disease embeddings with the target disease embedding; choose from ['rarity', 'avg']
                       num_walks = 200, # for the protein_random_walk sim_measure: number of sampled walks
                       path_length = 2 # for the protein_random_walk sim_measure: walk path length
                       )
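The proto, proto_num, sim_measure and agg_measure options describe retrieving diseases similar to the target and aggregating their embeddings with it. As an illustration only, here is a pure-Python sketch of that idea. It assumes a disease signature is a binary vector over graph nodes (loosely mirroring 'all_nodes_profile'), cosine similarity for retrieval, and a plain element-wise mean for aggregation (loosely mirroring 'avg'). The 'rarity' weighting is not reproduced, all function names are hypothetical, and this is not TxGNN's implementation.

```python
import math
from typing import Dict, List, Sequence

Vector = Sequence[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k_similar(query: str, profiles: Dict[str, Vector], k: int) -> List[str]:
    """Names of the k diseases whose signatures are most similar to the query's."""
    scored = [(cosine(profiles[query], vec), name)
              for name, vec in profiles.items() if name != query]
    scored.sort(key=lambda item: (-item[0], item[1]))  # best first, ties by name
    return [name for _, name in scored[:k]]


def augment_avg(query: str, embeddings: Dict[str, Vector], neighbors: List[str]) -> List[float]:
    """Element-wise mean of the query embedding and its neighbors' embeddings."""
    vectors = [embeddings[query]] + [embeddings[n] for n in neighbors]
    return [sum(column) / len(vectors) for column in zip(*vectors)]


if __name__ == "__main__":
    # Toy signatures: 1 means the disease is linked to that graph node.
    profiles = {"d0": [1, 1, 0, 0], "d1": [1, 1, 1, 0], "d2": [0, 0, 1, 1], "d3": [1, 0, 0, 0]}
    embeddings = {"d0": [0.0, 0.0], "d1": [3.0, 0.0], "d2": [9.0, 9.0], "d3": [0.0, 3.0]}
    neighbors = top_k_similar("d0", profiles, k=2)
    print(neighbors, augment_avg("d0", embeddings, neighbors))  # prints ['d1', 'd3'] [1.0, 1.0]
```

Retrieval here depends only on signatures, not on known treatments, which is why this kind of augmentation can, in principle, help a disease with zero training treatments.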

Instead of initializing a new model, you can also load a saved model:

TxGNN.load_pretrained('./model_ckpt')

We provide example pre-trained model weights here.

To pre-train with link prediction on all edge types, you can run:

TxGNN.pretrain(n_epoch = 2, 
               learning_rate = 1e-3,
               batch_size = 1024, 
               train_print_per_n = 20)
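Link-prediction pre-training scores (head, relation, tail) triples and contrasts observed edges with corrupted ones. The sketch below is a generic illustration of that objective, not TxGNN's code. It assumes a DistMult-style score (one vector per relation, so each edge type gets its own scorer), random tail corruption for negative sampling, and binary cross-entropy; all names are hypothetical.

```python
import math
import random
from typing import List, Sequence, Tuple

Triple = Tuple[int, int, int]  # (head node, relation/edge type, tail node)


def distmult_score(head: Sequence[float], rel: Sequence[float], tail: Sequence[float]) -> float:
    """DistMult score: sum_i head_i * rel_i * tail_i."""
    return sum(h * r * t for h, r, t in zip(head, rel, tail))


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def bce_link_loss(pos_scores: List[float], neg_scores: List[float]) -> float:
    """Mean binary cross-entropy: positives labeled 1, negatives labeled 0."""
    # log(1 - sigmoid(s)) == log_sigmoid(-s)
    losses = [-log_sigmoid(s) for s in pos_scores] + [-log_sigmoid(-s) for s in neg_scores]
    return sum(losses) / len(losses)


def sample_negative_tails(edges: List[Triple], num_nodes: int, seed: int = 0,
                          max_tries: int = 100) -> List[Triple]:
    """Replace each edge's tail with a random node that does not form a known edge."""
    rng = random.Random(seed)
    known = set(edges)
    negatives = []
    for head, rel, _tail in edges:
        for _attempt in range(max_tries):
            candidate = (head, rel, rng.randrange(num_nodes))
            if candidate not in known:
                negatives.append(candidate)
                break
    return negatives
```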

Next, to fine-tune on drug-disease relations with metric learning, you can run:

finetune_result_path = './finetune_result' # placeholder: choose where to save fine-tuning results
TxGNN.finetune(n_epoch = 500,
               learning_rate = 5e-4,
               train_print_per_n = 5,
               valid_per_n = 20,
               save_name = finetune_result_path)

To save the trained model, you can type:

TxGNN.save_model('./model_ckpt')

To evaluate the model on the entire test set using disease-centric evaluation, you can type:

from txgnn import TxEval
TxEval = TxEval(model = TxGNN)
result = TxEval.eval_disease_centric(disease_idxs = 'test_set', 
                                     show_plot = False, 
                                     verbose = True, 
                                     save_result = True,
                                     return_raw = False,
                                     save_name = 'SAVE_PATH')
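Disease-centric evaluation ranks candidate drugs separately for each disease and then summarizes across diseases. The sketch below assumes per-disease recall@K, averaged over diseases, as the summary; TxEval's actual metrics may differ, and the function and input names are hypothetical.

```python
from statistics import mean
from typing import Dict, Set


def disease_centric_recall(scores: Dict[str, Dict[str, float]],
                           positives: Dict[str, Set[str]],
                           k: int) -> Dict[str, float]:
    """Per-disease recall@k: share of a disease's known drugs ranked in its top k."""
    results = {}
    for disease, drug_scores in scores.items():
        true_drugs = positives.get(disease, set())
        if not true_drugs:
            continue  # nothing to recover; skip rather than divide by zero
        ranked = sorted(drug_scores, key=lambda drug: (-drug_scores[drug], drug))
        hits = len(set(ranked[:k]) & true_drugs)
        results[disease] = hits / len(true_drugs)
    return results


if __name__ == "__main__":
    scores = {"dA": {"x": 0.9, "y": 0.2, "z": 0.5},
              "dB": {"x": 0.1, "y": 0.8, "z": 0.3}}
    positives = {"dA": {"z"}, "dB": {"x", "y"}}
    per_disease = disease_centric_recall(scores, positives, k=2)
    print(per_disease, mean(per_disease.values()))  # prints {'dA': 1.0, 'dB': 0.5} 0.75
```

Averaging per disease, rather than pooling all drug-disease pairs, keeps diseases with many known treatments from dominating the summary.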

If you want to look at specific diseases, you can also do:

result = TxEval.eval_disease_centric(disease_idxs = [9907.0, 12787.0], 
                                     relation = 'indication', 
                                     save_result = False)

After training a satisfactory link prediction model, you can also train the graph XAI model:

TxGNN.train_graphmask(relation = 'indication',
                      learning_rate = 3e-4,
                      allowance = 0.005,
                      epochs_per_layer = 3,
                      penalty_scaling = 1,
                      valid_per_n = 20)
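GraphMask-style training learns a gate per edge and trades off sparsity (closing as many gates as possible) against faithfulness to the original prediction through a Lagrangian. The sketch below shows only that objective. It assumes that allowance acts as the tolerated divergence and penalty_scaling scales the sparsity term; this mapping is our reading, not something the README states. The actual method uses stochastic gates trained by gradient descent, which are omitted here, and both function names are hypothetical.

```python
from typing import Sequence


def graphmask_objective(gate_probs: Sequence[float], divergence: float,
                        allowance: float, multiplier: float,
                        penalty_scaling: float = 1.0) -> float:
    """Scaled expected fraction of open gates + multiplier * (divergence - allowance)."""
    sparsity = penalty_scaling * sum(gate_probs) / len(gate_probs)
    return sparsity + multiplier * (divergence - allowance)


def update_multiplier(multiplier: float, divergence: float,
                      allowance: float, lr: float) -> float:
    """Gradient ascent on the multiplier, kept non-negative.

    The multiplier grows while the masked model diverges more than the
    allowance permits, pushing the gates to keep more edges open.
    """
    return max(0.0, multiplier + lr * (divergence - allowance))
```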

You can retrieve the graph XAI gates (whether or not each edge is important) and save them to a pkl file located at SAVED_PATH/graphmask_output_RELATION.pkl:

gates = TxGNN.retrieve_save_gates('SAVED_PATH')
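The saved gates file can be read back with Python's standard pickle module. This is a minimal sketch that assumes the file naming described above; load_saved_gates is a hypothetical helper, and the structure of the unpickled object depends on the TxGNN version, so inspect it before relying on it.

```python
import os
import pickle
from typing import Any


def load_saved_gates(saved_path: str, relation: str) -> Any:
    """Load SAVED_PATH/graphmask_output_<RELATION>.pkl (hypothetical helper)."""
    path = os.path.join(saved_path, f"graphmask_output_{relation}.pkl")
    with open(path, "rb") as handle:
        return pickle.load(handle)  # only unpickle files you trust
```

For example, load_saved_gates('SAVED_PATH', 'indication') would read the file written for the indication relation.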

You can also save and load the graphmask model:

TxGNN.save_graphmask_model('./graphmask_model_ckpt')
TxGNN.load_pretrained_graphmask('./graphmask_model_ckpt')

Splits

TxGNN provides numerous prepared splits. You can switch among them via the split argument of TxData.prepare_split(split = 'XXX', seed = 42).

  • complex_disease is the systematic split in the paper: we first sample a set of diseases and then move all of their treatments to the test set, so that these diseases have zero treatments in training.
  • Disease area split first obtains a set of diseases in a disease area using the disease ontology and moves all of their treatments to the test set. It then removes a fraction of the local neighborhood around these diseases to simulate the lack of molecular mechanism characterization for them. There are nine disease areas: cell_proliferation, mental_health, cardiovascular, anemia, adrenal_gland, autoimmune, metabolic_disorder, diabetes, neurodigenerative.
  • random is a random split that shuffles individual drug-disease pairs. As a result, most diseases have some treatments in the training set.
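To make the difference between the complex_disease and random splits concrete, here is a minimal sketch over a toy list of (drug, disease) pairs. It is not TxGNN's splitting code: it ignores validation sets, non-treatment edge types, and the neighborhood removal used by the disease-area splits, and the function names are hypothetical.

```python
import random
from typing import List, Set, Tuple

Pair = Tuple[str, str]  # (drug, disease)


def disease_holdout_split(pairs: List[Pair], test_fraction: float,
                          seed: int = 42) -> Tuple[List[Pair], List[Pair], Set[str]]:
    """Hold out whole diseases: every treatment of a sampled disease goes to test."""
    rng = random.Random(seed)
    diseases = sorted({disease for _, disease in pairs})
    n_test = max(1, round(test_fraction * len(diseases)))
    held_out = set(rng.sample(diseases, n_test))
    train = [p for p in pairs if p[1] not in held_out]
    test = [p for p in pairs if p[1] in held_out]
    return train, test, held_out


def random_pair_split(pairs: List[Pair], test_fraction: float,
                      seed: int = 42) -> Tuple[List[Pair], List[Pair]]:
    """Shuffle individual pairs, so a test disease usually also appears in training."""
    rng = random.Random(seed)
    shuffled = list(pairs)
    rng.shuffle(shuffled)
    n_test = max(1, round(test_fraction * len(shuffled)))
    return shuffled[n_test:], shuffled[:n_test]
```

In the first function no held-out disease contributes a single training pair, which is what makes evaluation on it zero-shot; the second gives no such guarantee.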

During deployment, when evaluating a specific disease, you may want to mask only this disease and use all of the other diseases. In this case, you can use TxData.prepare_split(split = 'disease_eval', disease_eval_idx = 'XX'), where disease_eval_idx is the index of the disease of interest.

Another setting is to train the entire network without any disease masking, via split = 'full_graph'. This automatically uses 95% of the data for training and 5% as a validation set for early stopping. No test set is used.
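For the full_graph setting, the 95%/5% train/validation split and early stopping can be sketched as follows. This is an illustration under assumptions: monitoring validation loss with a patience counter is our choice, not documented TxGNN behavior, and both names are hypothetical.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def train_valid_split(edges: Sequence[T], valid_fraction: float = 0.05,
                      seed: int = 42) -> Tuple[List[T], List[T]]:
    """Shuffle edges and hold out valid_fraction of them for validation."""
    rng = random.Random(seed)
    shuffled = list(edges)
    rng.shuffle(shuffled)
    n_valid = round(valid_fraction * len(shuffled))
    return shuffled[n_valid:], shuffled[:n_valid]


class EarlyStopping:
    """Stop after `patience` consecutive validations without a new best loss."""

    def __init__(self, patience: int) -> None:
        self.patience = patience
        self.best = float("inf")
        self.bad_rounds = 0

    def step(self, valid_loss: float) -> bool:
        """Record one validation loss; return True when training should stop."""
        if valid_loss < self.best:
            self.best = valid_loss
            self.bad_rounds = 0
        else:
            self.bad_rounds += 1
        return self.bad_rounds >= self.patience
```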

Cite Us

medRxiv preprint

@article{huang2023zeroshot,
  title={Zero-shot Prediction of Therapeutic Use with Geometric Deep Learning and Clinician Centered Design},
  author={Huang, Kexin and Chandak, Payal and Wang, Qianwen and Havaldar, Shreyas and Vaid, Akhil and Leskovec, Jure and Nadkarni, Girish and Glicksberg, Benjamin and Gehlenborg, Nils and Zitnik, Marinka},
  journal={medRxiv},
  doi={10.1101/2023.03.19.23287458},
  year={2023}
}


txgnn's Issues

Issue Installing

I'm trying to install given the instructions. I make it as far as

conda create --name txgnn_env python=3.8
conda activate txgnn_env
# Install PyTorch via https://pytorch.org/ with your CUDA versions
conda install -c dglteam dgl-cuda{$CUDA_VERSION}==0.5.2 # checkout https://www.dgl.ai/pages/start.html for more info, as long as it is DGL 0.5.2

but pip install TxGNN gives me the pasted error.

Help would be greatly appreciated :)

UPDATE: I ended up doing a manual install, but just FYI.

Creating a DGL graph with `create_dgl_graph` function

I've been looking into the code, and this line seems a bit puzzling (numerical experiments don't show any difference in the model's performance w/ and w/o this line). What is so special about the 'effect/phenotype' and why does it need to be set to 0? Isn't it replaced by a positive number anyways in the succeeding for loop? Are there any edge cases where even though the 'effect/phenotype' is not in the df, we still want to keep this node type in the graph?

Problems when running the cardiovascular split

Hi,

When I try to run the cardiovascular split I encounter the following problems:

  • The object TxData downloads node.csv and not nodes.csv

    data_download_wrapper('https://dataverse.harvard.edu/api/access/datafile/6180617', os.path.join(self.data_folder, 'node.csv'))

  • When the object DataSplitter calls function load_kg and loads nodes.csv the sep='\t' is missing

    nodes = pd.read_csv(pth+'nodes.csv', low_memory=False)

  • When doing this merge, kg.csv does not have the columns x_index and y_index. Are they the same as x_id and y_id?

    select_kg = pd.merge(self.kg, disease_edges, 'right').drop_duplicates()

Thanks!

Empty heterograph is not allowed.

When I try to load the pretrained graph, I get the following error:

TxGNN.load_pretrained('./TxGNNExplorer')

dgl._ffi.base.DGLError: [14:31:29] /tmp/dgl_src/src/graph/heterograph.cc:120: Check failed: !rel_graphs.empty(): Empty heterograph is not allowed.
The same error appears when I am trying to initialize the model.

This could be due to a mismatched DGL version.
I am on a Mac M1, so I can't install DGL 0.5.2.
pip install "dgl==0.5.2"
ERROR: Could not find a version that satisfies the requirement dgl==0.5.2 (from versions: 0.9.0, 0.9.1, 1.0.0, 1.0.1, 1.1.0, 1.1.1, 1.1.2, 1.1.2.post1, 1.1.3, 2.0.0, 2.1.0)
ERROR: No matching distribution found for dgl==0.5.2

It seems that 0.5.2 is also not present in https://data.dgl.ai/wheels/repo.html.

Is it possible to create a new release with latest version of DGL?

Missing disease files for disease area split

Hi,
In process_disease_area_split, to generate disease area splits, you read in disease_files/mental_health.csv. Do you have these files uploaded somewhere? I was not able to find them in the Harvard Dataverse.

disease_list = pd.read_csv(os.path.join(disease_file_path, split + '.csv'))

Thanks!
