
gears's Introduction

GEARS: Predicting transcriptional outcomes of novel multi-gene perturbations

This repository hosts the official implementation of GEARS, a method that can predict transcriptional response to both single and multi-gene perturbations using single-cell RNA-sequencing data from perturbational screens.


Installation

Install PyG, then run pip install cell-gears.

[New] Updates in v0.1.1

  • Fixed training breakpoint bug from v0.1.0
  • Preprocessed dataloader now available for Replogle 2022 RPE1 and K562 essential datasets
  • Added custom split, fixed no-test split

A note on usage:

  • GEARS is currently not designed to handle training across multiple cell types or cross-cell-type transfer of predictions.
  • GEARS has not been tested for training with bulk sequencing data.
  • When trained on single-gene perturbation data alone, GEARS cannot reliably predict outcomes for combinatorial perturbations. The model must be trained on some combinatorial perturbation data to make such predictions.
  • GEARS has been tested using datasets that contain multiple perturbation types, and multiple cells within each perturbation type. Datasets with too few cells per perturbation or too few perturbations may not work well with our model.

Core API Interface

Using the API, you can (1) reproduce the results in our paper and (2) train GEARS on your perturbation dataset using a few lines of code.

from gears import PertData, GEARS

# get data
pert_data = PertData('./data')
# load dataset in paper: norman, adamson, dixit.
pert_data.load(data_name = 'norman')
# specify data split
pert_data.prepare_split(split = 'simulation', seed = 1)
# get dataloader with batch size
pert_data.get_dataloader(batch_size = 32, test_batch_size = 128)

# set up and train a model
gears_model = GEARS(pert_data, device = 'cuda:8')  # GPU index 8; use e.g. 'cuda:0' or 'cpu' to match your machine
gears_model.model_initialize(hidden_size = 64)
gears_model.train(epochs = 20)

# save/load model
gears_model.save_model('gears')
gears_model.load_pretrained('gears')

# predict
gears_model.predict([['CBL', 'CNN1'], ['FEV']])
gears_model.GI_predict(['CBL', 'CNN1'], GI_genes_file=None)

To use your own dataset, create a scanpy AnnData object with a gene_name column in adata.var and two columns, condition and cell_type, in adata.obs. Then run:

pert_data.new_data_process(dataset_name = 'XXX', adata = adata)
# to load the processed data
pert_data.load(data_path = './data/XXX')
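The condition labels appear to follow a convention used in the bundled datasets and discussed in the issues below: control cells are labeled ctrl, single perturbations GENE+ctrl, and double perturbations GENE1+GENE2. The sketch below builds such labels from each cell's list of perturbed genes. make_condition is a hypothetical helper, not part of the GEARS API, and the convention itself is an assumption worth checking against a bundled dataset such as norman.

```python
def make_condition(perturbed_genes):
    """Build a condition label under the assumed GEARS naming convention."""
    genes = [g for g in perturbed_genes if g and g != "ctrl"]
    if not genes:
        return "ctrl"              # unperturbed (wild-type) cell
    if len(genes) == 1:
        return genes[0] + "+ctrl"  # single perturbation padded with ctrl
    return "+".join(genes)         # combinatorial perturbation

# Hypothetical per-cell lists of perturbed genes
cells = [[], ["CBL"], ["CBL", "CNN1"], ["FEV"]]
conditions = [make_condition(c) for c in cells]
print(conditions)  # ['ctrl', 'CBL+ctrl', 'CBL+CNN1', 'FEV+ctrl']
```

The resulting list could then be assigned to adata.obs['condition'] before calling new_data_process.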

Demos

  • Dataset Tutorial: how to use the dataset loader and read custom data
  • Model Tutorial: how to train GEARS
  • Plot top 20 DE genes: how to plot the top 20 differentially expressed (DE) genes
  • Uncertainty: how to train an uncertainty-aware GEARS model

Colab

  • Using Trained Model: use a model trained on Norman et al. 2019 to make predictions (needs Colab Pro)

Cite Us

@article{roohani2023predicting,
  title={Predicting transcriptional outcomes of novel multigene perturbations with gears},
  author={Roohani, Yusuf and Huang, Kexin and Leskovec, Jure},
  journal={Nature Biotechnology},
  year={2023},
  publisher={Nature Publishing Group US New York}
}

Paper: Link

Code for reproducing figures: Link

gears's People

Contributors

ekernf01, kexinhuang12345, yhr91


gears's Issues

About the preparation of split

Hello, dear developers.
I ran the code below twice, but I think something went wrong with file generation in the split preparation step:

pert_data = PertData('./data')
pert_data.load(data_name = 'norman')
pert_data.prepare_split(split = 'simulation', seed = 1)
pert_data.get_dataloader(batch_size = 32, test_batch_size = 128)

The error message was posted as a screenshot, which is not preserved here.
How can I fix this problem? Thanks for your work!

Questions about split and the experiment results

Hi !
Thank you for publishing such an excellent paper !

I'm trying to follow your work on GEARS, but I have some questions due to my limited knowledge, and I would very much appreciate your help. The questions are as follows:

  1. What does simulation mean as a split, and what is the difference between simulation_single and single? My understanding is:
  • split single is used for the datasets replogle_k562_essential and replogle_rpe1_essential

  • splits combo_seen0, combo_seen1 and combo_seen2 are used for the dataset norman

    Please correct me if I'm wrong.

  2. Where can I find the detailed experimental results behind the figures in your paper? The paper only shows the gap between GEARS and the other baselines, without reporting the specific metric values for each.

Thank you very much for your consideration!
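Judging by their names, the combo_seen0, combo_seen1 and combo_seen2 splits group held-out two-gene combinations by how many of their genes were perturbed individually in the training set. The sketch below illustrates only that bookkeeping; combo_seen_category is a hypothetical helper and is not how PertData.prepare_split is implemented internally.

```python
def combo_seen_category(condition, train_single_genes):
    """Label a two-gene combination 'combo_seenK', where K is the number of
    its genes that were seen as single perturbations during training."""
    genes = [g for g in condition.split("+") if g != "ctrl"]
    if len(genes) != 2:
        raise ValueError(f"not a two-gene combination: {condition}")
    k = sum(g in train_single_genes for g in genes)
    return f"combo_seen{k}"

train_singles = {"CBL", "CNN1", "FEV"}  # hypothetical training single perturbations
for cond in ["CBL+CNN1", "CBL+UBASH3B", "SET+KLF1"]:
    print(cond, combo_seen_category(cond, train_singles))
# CBL+CNN1 combo_seen2
# CBL+UBASH3B combo_seen1
# SET+KLF1 combo_seen0
```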

More than 2 perturbations per cell in training data

Hi!
I've been playing with the tool for some time now and ran into this: from the paper, I understood that the tool supports any number of perturbations in a single cell. Nevertheless, looking into the code, I found this line:

adata.obs.loc[:, 'dose_val'] = adata.obs.condition.apply(lambda x: '1+1' if len(x.split('+')) == 2 else '1')

which to my understanding means there can only be 1 or 2 perturbations in a single cell in the training dataset. I changed the line to:

adata.obs.loc[:, 'dose_val'] = adata.obs.condition.apply(lambda x: '+'.join('1'*len(x.split('+'))))

to get correct dose strings for cells with more perturbations and everything seems to work.

Is this a reasonable thing to do? And if so, what is the logic behind adding ctrl to cells with fewer perturbations? In the notebook, you mention that only wild-type cells should have a single label (ctrl), that cells with 1 perturbation should have ctrl added (gene1+ctrl), and that cells with 2 perturbations do NOT have ctrl added (gene1+gene2). How should I extend this rule if I have, say, cells with between 0 and 6 perturbations? Thanks!
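The difference between the two dose_val lambdas quoted above can be checked on plain strings. The sketch wraps each lambda in a named function (the names are hypothetical) and shows that the original version maps any condition with three or more genes to '1', the same value as control, while the modified version emits one '1' per '+'-separated part. Note that the ctrl padding counts as a part, so GENE+ctrl gets '1+1' under both versions. Whether anything downstream relies on dose_val beyond labeling is not established here.

```python
def dose_original(condition):
    # Quoted line: exactly two '+'-separated parts -> '1+1', anything else -> '1'
    return '1+1' if len(condition.split('+')) == 2 else '1'

def dose_modified(condition):
    # Proposed change: one '1' per '+'-separated part ('111' joined by '+')
    return '+'.join('1' * len(condition.split('+')))

examples = ['ctrl', 'A+ctrl', 'A+B', 'A+B+C']
for cond in examples:
    print(cond, dose_original(cond), dose_modified(cond))
# ctrl 1 1
# A+ctrl 1+1 1+1
# A+B 1+1 1+1
# A+B+C 1 1+1+1
```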

KeyError: 'val'

Hi, thanks for sharing. I encountered the following error (posted as a screenshot, not preserved here).

My adata.obs was also posted as a screenshot.

Error when trying to load the pre-trained model

Hi! Thank you for providing an opportunity to use your amazing tool!

I ran into an issue while running the tutorial for using the trained model. I'm loading my own data instead of the Norman dataset; creating the dataloader works without problems and I don't modify any further code. However, I get an error when I load the pretrained model with load_pretrained. Here is the relevant part of the code:

from zipfile import ZipFile
from gears import GEARS
from gears.utils import dataverse_download

## Download model from dataverse
dataverse_download('https://dataverse.harvard.edu/api/access/datafile/6979956', 'model.zip')

## Extract and set up model directory
with ZipFile('model.zip', 'r') as zf:
    zf.extractall(path = './')
model_name = 'gears_misc_umi_no_test'

gears_model = GEARS(pert_data, device = 'cpu',
                    weight_bias_track = False,
                    proj_name = 'gears',
                    exp_name = model_name)
gears_model.load_pretrained('./model_ckpt')

And the error is

TypeError                                 Traceback (most recent call last)
<ipython-input-32-2b23a346f02a> in <cell line: 7>()
      5                         proj_name = 'gears',
      6                         exp_name = model_name)
----> 7 gears_model.load_pretrained('./model_ckpt')

/content/GEARS/gears/gears.py in load_pretrained(self, path)
    255 
    256         del config['device'], config['num_genes'], config['num_perts']
--> 257         self.model_initialize(**config)
    258         self.config = config
    259 

TypeError: GEARS.model_initialize() got an unexpected keyword argument 'cell_fitness_pred'

It seems to me that the problem is the following: the dictionary config.pkl (part of the trained model directory) contains the item 'cell_fitness_pred': False, while 'cell_fitness_pred' is not among the arguments of model_initialize(). However, simply removing this item from the dictionary gives a different error (here 'model_modified_ckpt' is a copy of the 'model_ckpt' directory; the only difference is that I removed that item from config.pkl):

gears_model.load_pretrained('./model_modified_ckpt')
RuntimeError                              Traceback (most recent call last)
<ipython-input-36-62ddcd5f7a40> in <cell line: 7>()
      5                         proj_name = 'gears',
      6                         exp_name = model_name)
----> 7 gears_model.load_pretrained('./model_modified_ckpt')

1 frames
/content/GEARS/gears/gears.py in load_pretrained(self, path)
    268             state_dict = new_state_dict
    269 
--> 270         self.model.load_state_dict(state_dict)
    271         self.model = self.model.to(self.device)
    272         self.best_model = self.model

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   2039 
   2040         if len(error_msgs) > 0:
-> 2041             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042                                self.__class__.__name__, "\n\t".join(error_msgs)))
   2043         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for GEARS_Model:
	Unexpected key(s) in state_dict: "cell_fitness_mlp.network.0.weight", "cell_fitness_mlp.network.0.bias", "cell_fitness_mlp.network.1.weight", "cell_fitness_mlp.network.1.bias", "cell_fitness_mlp.network.1.running_mean", "cell_fitness_mlp.network.1.running_var", "cell_fitness_mlp.network.1.num_batches_tracked", "cell_fitness_mlp.network.3.weight", "cell_fitness_mlp.network.3.bias", "cell_fitness_mlp.network.4.weight", "cell_fitness_mlp.network.4.bias", "cell_fitness_mlp.network.4.running_mean", "cell_fitness_mlp.network.4.running_var", "cell_fitness_mlp.network.4.num_batches_tracked", "cell_fitness_mlp.network.6.weight", "cell_fitness_mlp.network.6.bias", "cell_fitness_mlp.network.7.weight", "cell_fitness_mlp.network.7.bias", "cell_fitness_mlp.network.7.running_mean", "cell_fitness_mlp.network.7.running_var", "cell_fitness_mlp.network.7.num_batches_tracked". 
	size mismatch for indv_w1: copying a param with shape torch.Size([5054, 64, 1]) from checkpoint, the shape in current model is torch.Size([5000, 64, 1]).
	size mismatch for indv_b1: copying a param with shape torch.Size([5054, 1]) from checkpoint, the shape in current model is torch.Size([5000, 1]).
	size mismatch for indv_w2: copying a param with shape torch.Size([1, 5054, 65]) from checkpoint, the shape in current model is torch.Size([1, 5000, 65]).
	size mismatch for indv_b2: copying a param with shape torch.Size([1, 5054]) from checkpoint, the shape in current model is torch.Size([1, 5000]).
	size mismatch for gene_emb.weight: copying a param with shape torch.Size([5054, 64]) from checkpoint, the shape in current model is torch.Size([5000, 64]).
	size mismatch for emb_pos.weight: copying a param with shape torch.Size([5054, 64]) from checkpoint, the shape in current model is torch.Size([5000, 64]).
	size mismatch for cross_gene_state.network.0.weight: copying a param with shape torch.Size([64, 5054]) from checkpoint, the shape in current model is torch.Size([64, 5000]).

Is there some simple way to resolve this issue?
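Both errors above come from checkpoint entries that the installed code no longer knows about. A minimal sketch of the two filtering steps follows, using plain dicts: drop config keys that model_initialize does not accept, and drop state-dict keys under the obsolete cell_fitness_mlp. prefix before calling the real PyTorch method model.load_state_dict. filter_kwargs and drop_prefix are hypothetical helpers, and the model_initialize below is a stand-in, not the GEARS signature. Neither step fixes the size mismatches: the checkpoint was trained on 5054 genes while this dataset has 5000, so those weights only fit data processed with the same gene set (load_state_dict(..., strict=False) skips missing or unexpected keys but still raises on shape mismatches).

```python
import inspect

def filter_kwargs(func, config):
    """Keep only config entries that func accepts by name
    (assumes func has no **kwargs catch-all)."""
    params = inspect.signature(func).parameters
    return {k: v for k, v in config.items() if k in params}

def drop_prefix(state_dict, prefix):
    """Remove state-dict entries whose key starts with prefix."""
    return {k: v for k, v in state_dict.items() if not k.startswith(prefix)}

# Stand-in for an installed model_initialize that lacks cell_fitness_pred
def model_initialize(hidden_size=64):
    return {"hidden_size": hidden_size}

config = {"hidden_size": 64, "cell_fitness_pred": False}  # as read from config.pkl
clean = filter_kwargs(model_initialize, config)
print(clean)                       # {'hidden_size': 64}
print(model_initialize(**clean))   # no TypeError

# Strings stand in for the tensors in a real state dict
state = {"gene_emb.weight": "W", "cell_fitness_mlp.network.0.weight": "M"}
print(drop_prefix(state, "cell_fitness_mlp."))  # {'gene_emb.weight': 'W'}
```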

problems with dataloader between versions

Hello!

Thank you for the great package. I have some code that relies on cell-gears 0.0.1. The latest version of cell-gears makes many things much easier, including importing new datasets.

When I update to the current version, the dataloader seems to return X with a single column, [batch_size*n_genes, 1], whereas it used to return a matrix of size [batch_size*n_genes, 2], where the first column held the expression values and the second the perturbation mask. Can you help me figure out which part of the code changed, so that I can update to the latest version and keep things working?
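To keep older code working, one option is to rebuild the two-column layout from the single expression column. The NumPy sketch below assumes the newer batches expose the expression column (shape [batch_size*n_genes, 1]) plus, for each cell, the indices of its perturbed genes; the names x_new and pert_idx and the helper rebuild_two_column are illustrative assumptions, so check which attributes your installed version's batch objects actually carry.

```python
import numpy as np

def rebuild_two_column(x_new, pert_idx, n_genes):
    """Recreate the old [batch_size*n_genes, 2] layout:
    column 0 = expression, column 1 = 0/1 perturbation flag.
    x_new: array of shape [batch_size*n_genes, 1]
    pert_idx: one list of perturbed gene indices per cell ([] for control)
    """
    batch_size = len(pert_idx)
    if x_new.shape != (batch_size * n_genes, 1):
        raise ValueError("x_new does not match batch_size * n_genes")
    flags = np.zeros((batch_size, n_genes), dtype=x_new.dtype)
    for cell, genes in enumerate(pert_idx):
        if genes:                       # control cells keep all-zero flags
            flags[cell, genes] = 1.0
    return np.hstack([x_new, flags.reshape(-1, 1)])

# Two hypothetical cells, three genes each; cell 0 perturbs gene 1
x_new = np.arange(6, dtype=np.float32).reshape(-1, 1)
x_old = rebuild_two_column(x_new, pert_idx=[[1], []], n_genes=3)
print(x_old.shape)  # (6, 2)
```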

Thank you!

Demo notebook raises gradient error for torch 1.12

Hi, I'm using torch 1.12.0 and python 3.9, and when I run the demo notebook https://github.com/snap-stanford/GEARS/blob/master/demo/model_tutorial.ipynb, it raises this error when running gears_model.train.


RuntimeError Traceback (most recent call last)
File :1, in <cell line: 1>()
----> 1 gears_model.train(epochs = 1, lr = 1e-3)

File /anaconda/envs/Gears/lib/python3.9/site-packages/gears/gears.py:298, in GEARS.train(self, epochs, lr, weight_decay)
293 pred = self.model(batch)
294 loss = loss_fct(pred, y, batch.pert,
295 ctrl = self.ctrl_expression,
296 dict_filter = self.dict_filter,
297 direction_lambda = self.config['direction_lambda'])
--> 298 loss.backward()
299 nn.utils.clip_grad_value_(self.model.parameters(), clip_value=1.0)
300 optimizer.step()

File /anaconda/envs/Gears/lib/python3.9/site-packages/torch/_tensor.py:396, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
387 if has_torch_function_unary(self):
388 return handle_torch_function(
389 Tensor.backward,
390 (self,),
(...)
394 create_graph=create_graph,
395 inputs=inputs)
--> 396 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File /anaconda/envs/Gears/lib/python3.9/site-packages/torch/autograd/__init__.py:173, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
168 retain_graph = create_graph
170 # The reason we repeat same the comment below is that
171 # some Python versions print out the first line of a multi-line function
172 # calls in the traceback and some print out the last line
--> 173 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
174 tensors, grad_tensors_, retain_graph, create_graph, inputs,
175 allow_unreachable=True, accumulate_grad=True)

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [161440, 64]], which is output 0 of ReluBackward0, is at version 76; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

I think this is related to some inplace operation in the model.
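The message says a tensor produced by a ReLU was modified in place after autograd had saved it for the backward pass. The standalone reproduction below (not taken from the GEARS code) shows the pattern and the usual fix of replacing the in-place update with an out-of-place one; which line inside GEARS triggers it under torch 1.12 is not identified here.

```python
import torch

def forward(x, inplace):
    h = torch.relu(x)   # ReluBackward0 saves its output h for the backward pass
    if inplace:
        h += 1.0        # in-place: bumps h's version counter, invalidating the saved output
    else:
        h = h + 1.0     # out-of-place: creates a new tensor, saved output stays intact
    return h.sum()

x = torch.randn(4, 3, requires_grad=True)
forward(x, inplace=False).backward()   # works

failed = False
try:
    forward(x, inplace=True).backward()
except RuntimeError as err:
    failed = "modified by an inplace operation" in str(err)
print("in-place version fails:", failed)  # True
```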

prediction interpretation

The prediction tutorial notebook generates values between 0 and 5. Are those absolute expression changes or post-perturbation expression values?
If they are the final expression values, what's the fastest way to retrieve the unperturbed values?
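Unperturbed reference values can be taken as the mean expression over control cells. The sketch below uses a small NumPy matrix in place of adata.X and a list in place of adata.obs['condition'] (a sparse adata.X would first need .toarray()). It assumes, without confirmation from the GEARS documentation, that a prediction is post-perturbation expression on the same scale as adata.X, so the change is the prediction minus the control mean; control_mean is a hypothetical helper.

```python
import numpy as np

def control_mean(X, conditions):
    """Mean expression over the rows (cells) labeled 'ctrl'."""
    mask = np.asarray(conditions) == "ctrl"
    return X[mask].mean(axis=0)

# Three hypothetical cells x three genes
X = np.array([[1.0, 0.0, 2.0],
              [3.0, 0.0, 2.0],
              [5.0, 4.0, 0.0]])
conditions = ["ctrl", "ctrl", "CBL+ctrl"]

ctrl = control_mean(X, conditions)   # [2., 0., 2.]
pred = np.array([2.5, 1.0, 0.5])     # hypothetical predicted vector
delta = pred - ctrl                  # [0.5, 1., -1.5]
print(ctrl, delta)
```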

Support for combinations of gene activation and repression

Hi,

Thanks for releasing this package! One addition that would be great is support for datasets that contain combinations of activations and repressions, including within the same perturbation. My understanding from @yhr91 is that this is not currently supported, but is not complex to add.

Thanks again!

Reproduce figure 2 b-g

Hi, could you provide code to reproduce figure 2b-g in your preprint? Or could you at least explain how to rerun the analysis workflow?

Interpretation of Prediction

Hi author,

I ran the tutorial to test the GEARS model and found that the prediction output does not look like post-perturbation gene expression; it is just a vector over genes.

How can I obtain the post-perturbation gene expression?

Thanks.

Questions on pert_data.prepare_split

Hi Team, I'm using the GEARS package to load my local dataset. I ran pert_data.prepare_split(split=split, seed=1) and found that the program automatically assigns perturbations to the training and test sets. Which parameter lets me put more perturbation data into the training set rather than the test or validation sets?
Thanks in advance!
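The idea behind a perturbation-level split with an adjustable training fraction can be sketched as below. split_genes and train_frac are hypothetical names, not GEARS parameters, and this is not the prepare_split implementation; running help(pert_data.prepare_split) on your installed version should show which argument, if any, controls the training share.

```python
import random

def split_genes(genes, train_frac=0.75, seed=1):
    """Randomly assign perturbed genes to train/test.
    Raising train_frac moves more perturbations into training."""
    genes = sorted(set(genes))      # deterministic order before shuffling
    rng = random.Random(seed)
    rng.shuffle(genes)
    n_train = int(round(train_frac * len(genes)))
    return set(genes[:n_train]), set(genes[n_train:])

genes = [f"G{i}" for i in range(20)]   # hypothetical perturbed genes
train, test = split_genes(genes, train_frac=0.9)
print(len(train), len(test))  # 18 2
```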

Publication of model parameters

Hi!

Thanks again for making this code public. I have found the hyperparameter settings that you have used to train your model in the supplementary notes of the paper. However, I could not find a link to download the parameters of the trained model which has produced the performance metrics and predictions listed in the paper. I am asking since my training run on Norman et al. data looks fishy, i.e. loss decreases only minimally and then starts oscillating, and using the optimal parameters from the manuscript would greatly increase the confidence in my results and the method.

Since it is good scientific practice to publish the model parameters, is there any chance we can download them? Perhaps there is a link somewhere and I have simply missed it; in that case I would be very grateful if you could post it here.

Thank you

tutorial_inference_Norman.ipynb is out-of-date

When trying to run the tutorial_inference_Norman.ipynb file, loading the pretrained model generates the following error:

TypeError                                 Traceback (most recent call last)
Cell In[111], line 6
      1 gears_model = GEARS(pert_data, 
      2                     device = 'cpu',
      3                     weight_bias_track = False,
      4                     proj_name = 'gears',
      5                     exp_name = 'test')
----> 6 gears_model.load_pretrained('../model/model_ckpt')

File ~/miniconda3/envs/gears/lib/python3.10/site-packages/gears/gears.py:257, in GEARS.load_pretrained(self, path)
    254     config = pickle.load(f)
    256 del config['device'], config['num_genes'], config['num_perts']
--> 257 self.model_initialize(**config)
    258 self.config = config
    260 state_dict = torch.load(os.path.join(path, 'model.pt'), map_location = torch.device('cpu'))

TypeError: GEARS.model_initialize() got an unexpected keyword argument 'cell_fitness_pred'

This is because the model.py file was updated with the commit 8da5e36, which removed the cell_fitness_pred argument, but the model checkpoint provided via
dataverse_download('https://dataverse.harvard.edu/api/access/datafile/6979956', './model.zip') does not use the updated model.

A question about the direction loss

Hi, it seems the direction loss is not working, since torch.sign blocks gradient backpropagation.

https://github.com/snap-stanford/GEARS/blob/master/gears/utils.py#L388

Here is a toy experiment on my local machine with torch version 2.0.0:

import torch

for i in range(-5, 5):
    x = torch.tensor(i, dtype=float, requires_grad=True)
    y = torch.sign(x)
    y.backward()
    print(x.grad)

which prints:

tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
tensor(0., dtype=torch.float64)
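The toy run confirms that torch.sign passes back a zero gradient. A common differentiable stand-in, shown here only as an illustration and not as what GEARS does, is tanh(k·x), which approaches sign(x) as the sharpness k grows while keeping nonzero gradients. soft_sign and k are hypothetical names.

```python
import torch

def soft_sign(x, k=10.0):
    """Smooth approximation of sign(x); its gradient is k * (1 - tanh(k*x)**2)."""
    return torch.tanh(k * x)

x = torch.tensor([-0.5, 0.05, 0.5], requires_grad=True)

torch.sign(x).sum().backward()
sign_grad = x.grad.clone()
print(sign_grad)   # tensor([0., 0., 0.])

x.grad = None      # reset before the second backward pass
soft_sign(x).sum().backward()
soft_grad = x.grad
print(soft_grad)   # every entry is nonzero
```

Larger k tracks sign(x) more closely but makes gradients vanish faster away from zero, so k trades fidelity against gradient signal.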

Trouble using the outputs obtained from the GEARS model

Hi!

First of all, thanks for producing reproducible code that allows researchers to use your method. However, with such a high-class publication I would have expected some more verbose documentation. I am trying to simulate perturbations of individual genes and am stuck at the last step, evaluating the outputs. So far I couldn't find any mention of it in the tutorials.

After training and loading the model, typing

gears_model.predict([["PARK7"]])

gives me the following output:

{'PARK7': array([9.3185983e-04, 1.2710856e-02, 4.3317977e-02, ..., 3.7023327e+00,
       4.8783151e-03, 2.0390714e-06], dtype=float32)}

It has shape (5045,). I understand that this vector should contain the expression changes of 5045 genes. How can I find out which value corresponds to which gene? This information would be very helpful.
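If the prediction vector is ordered like the genes of the processed dataset (an assumption worth checking against pert_data.gene_names or adata.var['gene_name'] in your version, for example by comparing lengths), it can be labeled by zipping the two. The sketch uses made-up gene names and values; label_prediction and top_changed are hypothetical helpers, and the control vector could come from the mean over ctrl cells.

```python
import numpy as np

def label_prediction(pred, gene_names):
    """Pair each predicted value with its gene name."""
    if len(pred) != len(gene_names):
        raise ValueError("prediction length does not match number of genes")
    return dict(zip(gene_names, np.asarray(pred).tolist()))

def top_changed(pred, ctrl, gene_names, k=2):
    """Genes with the largest absolute change relative to control."""
    delta = np.asarray(pred) - np.asarray(ctrl)
    order = np.argsort(-np.abs(delta))[:k]
    return [(gene_names[i], float(delta[i])) for i in order]

gene_names = ["GENE_A", "GENE_B", "GENE_C"]            # hypothetical
pred = np.array([0.1, 3.7, 0.0], dtype=np.float32)     # e.g. one value of predict()'s dict
ctrl = np.array([0.1, 2.0, 0.5], dtype=np.float32)     # e.g. control mean expression

labeled = label_prediction(pred, gene_names)
top = top_changed(pred, ctrl, gene_names)
print(labeled)
print(top)   # GENE_B (about +1.7), then GENE_C (-0.5)
```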

Secondly, if I try to plot the perturbation with the following line:

gears_model.plot_perturbation("PARK7", "PARK7.png")

I get the following error:

Traceback (most recent call last):
  File "/home/ubuntu/gears/testrun.py", line 34, in <module>
    gears_model.plot_perturbation("PARK7", "PARK7.png")
  File "/mnt/storage/anaconda3/envs/gears/lib/python3.10/site-packages/gears/gears.py", line 441, in plot_perturbation
    adata.uns['top_non_dropout_de_20'][cond2name[query]]]
KeyError: 'PARK7'

PARK7 must be in the gene set, otherwise the previous prediction wouldn't have worked. I also tried PARK7+ctrl and ctrl+PARK7 but got the same error. Any help on how I could get this running?
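One plausible cause, not confirmed by the maintainers here, is that plot_perturbation reads per-condition metadata (the traceback shows adata.uns['top_non_dropout_de_20']), which only exists for conditions with measured cells. A gene can be predicted without having any observed cells in the dataset. The hypothetical helper below checks which spelling of a query, if any, occurs among the observed condition labels, for example set(adata.obs['condition']).

```python
def find_observed_condition(query, observed_conditions):
    """Return the observed condition label matching query, or None."""
    genes = [g for g in query.split("+") if g != "ctrl"]
    candidates = {query}
    if len(genes) == 1:
        candidates |= {genes[0] + "+ctrl", "ctrl+" + genes[0]}
    elif len(genes) == 2:
        candidates |= {"+".join(genes), "+".join(reversed(genes))}
    observed = set(observed_conditions)
    for cand in sorted(candidates):   # sorted for a deterministic result
        if cand in observed:
            return cand
    return None

observed = ["ctrl", "CBL+ctrl", "CBL+CNN1"]   # hypothetical observed conditions
print(find_observed_condition("CBL", observed))        # CBL+ctrl
print(find_observed_condition("CNN1+CBL", observed))   # CBL+CNN1
print(find_observed_condition("PARK7", observed))      # None
```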

Directly implementing gears on scRNA-seq data

Hi,
Thank you for publishing such an excellent paper! I'm new to Perturb-seq and only have single-cell transcriptome data. I wonder whether I can use GEARS directly on scRNA-seq data to infer the effect of perturbations, or whether I must provide my own Perturb-seq data.
Thanks!

Predictions: expression or delta expression?

Hi Yusuf et al., I have a really simple question. Running a small example for 5 epochs, I notice about 20% of the predictions are negative, even though the training data are all nonnegative. Does GEARS predict expression directly, or additive change in expression over the control? Example below.

from gears import PertData, GEARS
pert_data = PertData('./data', default_pert_graph=False)
pert_data.load(data_name = "dixit")
pert_data.prepare_split(split = 'simulation', seed = 5) 
pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) 
# set up and train a model
gears_model = GEARS(pert_data, device = 'cpu')
gears_model.model_initialize(hidden_size = 64)
gears_model.train(epochs = 5)

# predict
y = gears_model.predict([['USP13', 'USP15'], ['UTP6']])
(list(y.values())[0]>0).mean() # 0.82

error: model_initialize() got an unexpected keyword argument 'cell_fitness_pred' when running "Using Trained Model" tutorial

Hi,

Thank you for your wonderful work! However, when I follow the tutorial "Using Trained Model", I encounter the error

model_initialize() got an unexpected keyword argument 'cell_fitness_pred'

Perhaps this is because of the differences in version of gears. I found that the parameter cell_fitness_pred disappeared in the current version. Can you give some suggestions for this situation?

A question about directional loss.

Thank you for your impressive work, especially some of the loss designs. However, the sign function used in the directional loss appears to be non-differentiable, resulting in zero gradients. How does this loss contribute to the final result?

Train/Test GEARS on One Gene Perturbation

Hello,
I'm currently attempting to train, validate, and test GEARS using just a single gene perturbation. Unfortunately, I have not been successful in achieving this setup. From what I understand about the available split options, it seems that the described modeling scenario is not feasible with the current version of GEARS. Could you please confirm whether my understanding is correct?

 available_splits = ['simulation', 'simulation_single', 'combo_seen0', 
	 'combo_seen1', 'combo_seen2', 'single', 'no_test', 
	 'no_split'] 

Thank you.

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation

Hello there,
I managed to get the package installed, and I am trying to reproduce the results in your paper using the following commands:

from gears import PertData, GEARS

# get data
pert_data = PertData('./data')
# load dataset in paper: norman, adamson, dixit.
pert_data.load(data_name = 'norman')
# specify data split
pert_data.prepare_split(split = 'simulation', seed = 1)
# get data loader with batch size
pert_data.get_dataloader(batch_size = 32, test_batch_size = 128)

# set up and train a model
gears_model = GEARS(pert_data, device = 'cpu')
gears_model.model_initialize(hidden_size = 64)
gears_model.train(epochs = 20)

When I reach the last command, gears_model.train(epochs = 20), the program throws the following error:
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
Do you have any idea how to overcome this problem?

Note: I am using the CPU as the device; I don't have CUDA.

Thanks in advance.
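This error means some code updates a tensor that autograd treats as a leaf (for example a parameter or an input created with requires_grad=True) in place while gradients are being recorded. The standalone sketch below reproduces the message and shows the usual pattern for intentional in-place updates, wrapping them in torch.no_grad(); which GEARS line triggers it on CPU is not established here.

```python
import torch

w = torch.zeros(3, requires_grad=True)   # a leaf tensor, like a model parameter

failed = False
try:
    w += 1.0                             # in-place update of a leaf that requires grad
except RuntimeError as err:
    failed = "leaf Variable" in str(err)
print("plain in-place update fails:", failed)  # True

with torch.no_grad():                    # allowed while autograd is not recording
    w += 1.0
print(w.requires_grad)                   # True: w is still a trainable leaf
```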

split questions

Hi there, I'm wondering what exactly the "simulation" split means? There are a few different splits that I can't quite infer without a definition.

Also, how are ctrl cells paired with perturbed cells for each example?

Additionally, are the downloaded datasets preprocessed (normalized TPM and log transformed, looks like it)?

Overall, I'm trying to get this information annotated in the pert_data.adata.obs dataframe rather than use the GEARS data representations.

Thanks!

Question about installation

Hello there,
I am facing problems installing the package. I tried installing it in a virtual environment, but it fails. Do you recommend a specific Python version? I am following your instructions, but the CUDA version appears as "None" and I don't know why.

Any recommendations and advice are appreciated.

Thanks in advance!
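A CUDA version reported as None usually means a CPU-only PyTorch build is installed. The snippet below uses standard PyTorch attributes to print what the installed build supports and picks a device string that could then be passed as GEARS(pert_data, device=device).

```python
import torch

print("torch version:", torch.__version__)
print("built with CUDA:", torch.version.cuda)       # None for CPU-only builds
print("CUDA available:", torch.cuda.is_available())

# Fall back to the CPU when no usable GPU is present
device = "cuda" if torch.cuda.is_available() else "cpu"
print("using device:", device)
```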

How many epochs to train the model for on the Replogle K562 essential dataset

Hi,
thanks for the great job!
First, how many epochs should training take on the K562 essential and RPE1 datasets? Does performance keep improving with more epochs? How many epochs did you use in your study?

Second, Perturb-seq datasets are currently often small, with only tens of perturbations. How many perturbations does the model usually need to reach satisfactory performance? More concretely, if I train GEARS on a dataset with only around 20 single-gene perturbations, will it perform satisfactorily?

Thanks for your patience!

IndexError: index 0 is out of bounds for axis 0 with size 0

Hi team, I'm using the pert_data.new_data_process and facing some problems:
IndexError Traceback (most recent call last)
/tmp/ipykernel_1143058/1783411457.py in
1 pert_data = PertData("./data/")
----> 2 pert_data.new_data_process(dataset_name = 'xialab_perturb', adata = tmp_adata)

~/miniconda3/envs/scgpt/lib/python3.7/site-packages/gears/pertdata.py in new_data_process(self, dataset_name, adata)
94 dataset_fname = os.path.join(pyg_path, 'cell_graphs.pkl')
95 print_sys("Creating pyg object for each cell in the data...")
---> 96 self.dataset_processed = self.create_dataset_file()
97 print_sys("Saving new dataset pyg object at " + dataset_fname)
98 pickle.dump(self.dataset_processed, open(dataset_fname, "wb"))

~/miniconda3/envs/scgpt/lib/python3.7/site-packages/gears/pertdata.py in create_dataset_file(self)
246 dl = {}
247 for p in tqdm(self.adata.obs['condition'].unique()):
--> 248 cell_graph_dataset = self.create_cell_graph_dataset(self.adata, p, num_samples=1)
249 dl[p] = cell_graph_dataset
250 return dl

~/miniconda3/envs/scgpt/lib/python3.7/site-packages/gears/pertdata.py in create_cell_graph_dataset(self, split_adata, pert_category, num_samples)
283 if pert_category != 'ctrl':
284 # Get the indices of applied perturbation
--> 285 pert_idx = self.get_pert_idx(pert_category, adata_)
286
287 # Store list of genes that are most differentially expressed for testing

~/miniconda3/envs/scgpt/lib/python3.7/site-packages/gears/pertdata.py in get_pert_idx(self, pert_category, adata_)
252 def get_pert_idx(self, pert_category, adata_):
253 pert_idx = [np.where(p == self.gene_names)[0][0]
--> 254 for p in pert_category.split('+')
255 if p != 'ctrl']
256 return pert_idx

~/miniconda3/envs/scgpt/lib/python3.7/site-packages/gears/pertdata.py in (.0)
253 pert_idx = [np.where(p == self.gene_names)[0][0]
254 for p in pert_category.split('+')
--> 255 if p != 'ctrl']
256 return pert_idx
257

IndexError: index 0 is out of bounds for axis 0 with size 0

I checked .obs, .var and .X; their types match those of the Adamson dataset you provided, but the error persists.
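The traceback points at get_pert_idx, which looks up each gene of a condition with np.where(p == self.gene_names)[0][0]. That indexing raises exactly this IndexError when a perturbed gene is missing from the gene list. A pre-check is sketched below; missing_perturbation_genes is a hypothetical helper, and the two lists stand in for adata.obs['condition'] and adata.var['gene_name'].

```python
import numpy as np

def missing_perturbation_genes(conditions, gene_names):
    """Map each gene named in a condition but absent from gene_names
    to the set of conditions that mention it."""
    known = set(np.asarray(gene_names).tolist())
    missing = {}
    for cond in set(conditions):
        for g in cond.split("+"):
            if g != "ctrl" and g not in known:
                missing.setdefault(g, set()).add(cond)
    return missing

conditions = ["ctrl", "CBL+ctrl", "KLF1+ctrl", "CBL+KLF1"]   # hypothetical
gene_names = ["CBL", "CNN1", "FEV"]                          # KLF1 was filtered out
result = missing_perturbation_genes(conditions, gene_names)
print(result)  # KLF1 -> {'KLF1+ctrl', 'CBL+KLF1'} (set order may vary)

# The same lookup pattern as get_pert_idx fails for the missing gene
lookup_failed = False
try:
    np.where("KLF1" == np.asarray(gene_names))[0][0]
except IndexError:
    lookup_failed = True
```

Any gene reported this way needs its condition labels renamed to match the gene_name symbols, or its cells dropped, before new_data_process is run.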

Issue with new_data_process function

Hi,

I am trying to load my own data into GEARS and am running into an error, even though I believe my scanpy object is formatted correctly. I also tried using the Ensembl IDs as the index of the .var dataframe, but this triggered the same error. Any solutions? FYI, I deleted .raw because I couldn't save my h5ad file otherwise; the formatting in .raw differs from the formatting you require in the .obs and .var dataframes.

Here is the line of code that is triggering an error:

pert_data.new_data_process(dataset_name = 'my_data', adata = adata_final) # specific dataset name and adata object

Here is the error being triggered:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[140], line 2
      1 pert_data = PertData('./data') # specific saved folder
----> 2 pert_data.new_data_process(dataset_name = 'my_data', adata = adata_final) # specific dataset name and adata object
      3 # pert_data.load(data_path = './data/my_data') # load the processed data, the path is saved folder + dataset_name
      4 # pert_data.prepare_split(split = 'simulation', seed = 1) # get data split with seed
      5 # pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) # prepare data loader

File ~/opt/anaconda3/envs/GraphPerturb/lib/python3.10/site-packages/gears/pertdata.py:250, in PertData.new_data_process(self, dataset_name, adata, skip_calc_de)
    248     os.mkdir(save_data_folder)
    249 self.dataset_path = save_data_folder
--> 250 self.adata = get_DE_genes(adata, skip_calc_de)
    251 if not skip_calc_de:
    252     self.adata = get_dropout_non_zero_genes(self.adata)

File ~/opt/anaconda3/envs/GraphPerturb/lib/python3.10/site-packages/gears/data_utils.py:64, in get_DE_genes(adata, skip_calc_de)
     62 adata.obs = adata.obs.astype('category')
     63 if not skip_calc_de:
---> 64     rank_genes_groups_by_cov(adata, 
     65                      groupby='condition_name', 
     66                      covariate='cell_type', 
     67                      control_group='ctrl_1', 
     68                      n_genes=len(adata.var),
     69                      key_added = 'rank_genes_groups_cov_all')
...
    109     )
    111 adata_comp = adata
    112 if layer is not None:
    [112](https://file+.vscode-resource.vscode-cdn.net/Users/shayecarver/GraphPerturb/~/opt/anaconda3/envs/GraphPerturb/lib/python3.10/site-packages/scanpy/tools/_rank_genes_groups.py:112) if layer is not None:

ValueError: Could not calculate statistics for groups ct1_DLG5-AS1+PPM1D_1+1, ct1_IRF2+TIRAP_1+1,...

Here is the structure of my count matrix:

AnnData object with n_obs × n_vars = 15744 × 5000
    obs: 'condition', 'cell_type'
    var: 'gene_name'
    uns: 'hvg'

Here is the structure of adata_final.var:

     gene_name
0   AL669831.5
1    LINC00115
2       FAM41C
10        HES4
11       ISG15

Here is the structure of adata_final.obs:

                         condition   cell_type
AAACCCAAGTGGCAGT-1_1  SBDS+TNFSF15  ct1
AAACGAACAATGTCTG-1_1     NIPBL-AS1  ct1
AAACGAATCCTCTCTT-1_1          ctrl  ct1
AAACGCTCAAGTTCCA-1_1   ctrl+RAD54B  ct1
AAACGCTCAGGTCAAG-1_1      HLA-DRB6  ct1
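In scanpy, `rank_genes_groups` raises this ValueError for groups that contain only a single cell, and the listed group names appear to combine the cell type, the condition and a suffix, which suggests that some conditions in `adata_final` have just one cell. A minimal sketch for spotting such conditions, assuming a single cell type so that counting by `condition` matches the grouping used for the DE test (the function name is hypothetical):

```python
import pandas as pd


def small_conditions(obs, min_cells=2, key="condition"):
    """Return (sorted) conditions that have fewer than min_cells cells.

    obs: an AnnData-style .obs DataFrame.
    key: column holding GEARS-style condition strings.
    """
    counts = obs[key].astype(str).value_counts()
    return sorted(counts[counts < min_cells].index)


# Toy .obs: "A+B" has only one cell, so a per-group test cannot run on it
obs = pd.DataFrame({"condition": ["ctrl", "ctrl", "ctrl+A", "ctrl+A", "A+B"]})
print(small_conditions(obs))  # ['A+B']
```

Filtering them out before processing could then look like `adata_final = adata_final[~adata_final.obs['condition'].isin(small_conditions(adata_final.obs))].copy()`; whether a higher `min_cells` is advisable is a separate modelling choice.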

Training stuck using GEARS v0.10.0

Hi,
thanks for updating GEARS to v0.10.0, but when I run this version, training seems to get stuck. This is the output from GEARS:

Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:388
Done!
Creating dataloaders....
Done!
Found local copy...
Start Training...
> envs/gears/lib/python3.9/site-packages/gears/model.py(217)forward()
-> return torch.stack(out)
(Pdb)

Bug in new_data_process

There is an easily fixable bug in line 254 of gears/pertdata.py. The line is currently:

self.dataset_processed = self.create_dataset_file()

However, create_dataset_file() doesn't return anything. Judging by its other call site (line 201), the function should just be called on its own; it fills in the self.dataset_processed instance variable itself.

Just to be clear, I believe line 254 should be:

self.create_dataset_file()

Thanks!
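The bug described above is a general Python pitfall: a method that fills an attribute in place and has no `return` statement implicitly returns `None`, so assigning its result back to the same attribute discards the work. A toy illustration, using a hypothetical class rather than the real `PertData`:

```python
class ToyPertData:
    """Hypothetical stand-in that only mimics the pattern in the report."""

    def __init__(self):
        self.dataset_processed = None

    def create_dataset_file(self):
        # Fills the attribute in place; no return statement, so returns None
        self.dataset_processed = {"ctrl": ["graph0"]}


buggy = ToyPertData()
# The right-hand side runs first and fills the attribute, then the
# assignment overwrites it with the method's return value, None
buggy.dataset_processed = buggy.create_dataset_file()

fixed = ToyPertData()
fixed.create_dataset_file()  # attribute keeps the filled-in dict

print(buggy.dataset_processed, fixed.dataset_processed)  # None {'ctrl': ['graph0']}
```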

Minimal example of data loading fails

Hi GEARS maintainers, I'd like to try this out on some data, but I'm having trouble with the ingestion. Could you help me debug the following minimal example? Currently this yields AttributeError: Can only use .cat accessor with a 'category' dtype. Full traceback below.

import anndata
from gears import PertData as GEARSPertData
from gears import GEARS
import pandas as pd
import numpy as np
import scipy.sparse  # scipy.sparse.csr_matrix is used below
example_data = anndata.AnnData(
    X = np.random.random((5,5)), 
    obs = pd.DataFrame({
        "index": [f"o{i}" for i in range(5)],  # f-string, so each cell gets a distinct name
        "cell_type":"all",
        "condition": ["ctrl", "FOXA1+FOXA2", "ctrl+GCM2", "ctrl+HOXA3", "ctrl+PAX9"],
    }),
    var = pd.DataFrame({
        "index":[0,1,2,3,4],
        "gene_name": ["FOXA1", "FOXA2", "GCM2", "HOXA3", "PAX9"],
    })
)
example_data.obs["condition"] = example_data.obs["condition"].astype('category')
example_data.var["gene_name"] = example_data.var["gene_name"].astype('category')
example_data.obs["cell_type"] = example_data.obs["cell_type"].astype('category')
pert_data = GEARSPertData("./gears_input")
example_data.uns["log1p"] = dict()
example_data.uns["log1p"]["base"] = np.exp(1)
example_data.X = scipy.sparse.csr_matrix(example_data.X)
pert_data.new_data_process(dataset_name = 'current', adata = example_data)
pert_data.load(data_path = './gears_input/current')
pert_data.prepare_split(split = 'simulation', seed = 1 )
pert_data.get_dataloader(batch_size = 32, test_batch_size = 128)
model = GEARS(pert_data, device = "cpu")
model.model_initialize(hidden_size = 64)
model.train(epochs = 20)

Traceback

 File "/home/ekernf01/Downloads/gears data loader bug.py", line 29, in <module>
    pert_data.new_data_process(dataset_name = 'current', adata = example_data)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/gears/pertdata.py", line 85, in new_data_process
    self.adata = get_DE_genes(adata)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/gears/data_utils.py", line 61, in get_DE_genes
    rank_genes_groups_by_cov(adata, 
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/gears/data_utils.py", line 37, in rank_genes_groups_by_cov
    sc.tl.rank_genes_groups(
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/scanpy/tools/_rank_genes_groups.py", line 572, in rank_genes_groups
    if reference != 'rest' and reference not in adata.obs[groupby].cat.categories:
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pandas/core/generic.py", line 5989, in __getattr__
    return object.__getattribute__(self, name)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pandas/core/accessor.py", line 224, in __get__
    accessor_obj = self._accessor(obj)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pandas/core/arrays/categorical.py", line 2445, in __init__
    self._validate(data)
  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/pandas/core/arrays/categorical.py", line 2454, in _validate
    raise AttributeError("Can only use .cat accessor with a 'category' dtype")
AttributeError: Can only use .cat accessor with a 'category' dtype
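The failing line is scanpy checking `adata.obs[groupby].cat.categories`, and `groupby` here appears to be `condition_name`, a column GEARS derives itself during preprocessing, so casting the user-facing `condition` and `cell_type` columns to `category` beforehand does not reach it. The underlying pandas behaviour is easy to reproduce in isolation: the `.cat` accessor exists only on categorical columns. The later `get_DE_genes` visible in an earlier traceback in this thread runs `adata.obs = adata.obs.astype('category')` before ranking, so upgrading GEARS may be the simpler fix. The column values below are illustrative only:

```python
import pandas as pd

# Illustrative values; GEARS builds condition_name itself during preprocessing
obs = pd.DataFrame({"condition_name": ["all_ctrl_1", "all_FOXA1+FOXA2_1+1"]})

try:
    obs["condition_name"].cat.categories  # non-categorical dtype: no .cat accessor
except AttributeError as err:
    print("AttributeError:", err)

# After casting to categorical, the membership check scanpy performs works
obs["condition_name"] = obs["condition_name"].astype("category")
print("all_ctrl_1" in obs["condition_name"].cat.categories)  # True
```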

Question: GO-network(s) with hierarchy separation?

Hi,
I noticed that the current GO-network is constructed by using the overlap (Jaccard index) between genes over ALL GO-terms, regardless of GO-hierarchy (biological process, cellular component, molecular function).
This leads to, e.g., isoforms having very high scores, since they naturally overlap in all three categories. However, pathway connections seem to be drowned out, at least in my example:
Even direct partners in the (well-researched and basic) glycolysis¹ pathway, such as HK (e.g. HK1) and PGI, do not make the cutoff value (0.1).
However, you write you tried other networks in place of GO, like "protein-protein interaction network [21], a gene coessentiality network [41] or the gene co-expression network". So I would be interested in your take on this:

Do you think a separation of GO-networks by hierarchy (and maybe some filtering for term size) and the resulting highlighting of direct interaction partners would help the gene-gene transferability of predictions/perturbations?

Thank you!

PS: If you are interested in ~40% faster code (or much faster with joblib.Parallel) for the GO network generation, let me know. I can send a merge request.
¹https://en.wikipedia.org/wiki/Glycolysis
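The similarity described in the question can be written down directly: two genes are linked when the Jaccard index of their GO-term sets (size of the intersection divided by size of the union) reaches the cutoff. The sketch below implements only that rule, not GEARS's own graph-building code (whose other filtering steps are not shown in this thread), and adds the per-namespace variant the question proposes. The gene-to-term mappings are hypothetical toy data:

```python
from itertools import combinations


def jaccard(a, b):
    """Jaccard index of two sets: |intersection| / |union| (0.0 if both empty)."""
    union = a | b
    return len(a & b) / len(union) if union else 0.0


def go_edges(gene2terms, threshold=0.1):
    """Return (gene1, gene2, score) for every gene pair whose GO-term
    Jaccard index is at least threshold."""
    edges = []
    for g1, g2 in combinations(sorted(gene2terms), 2):
        score = jaccard(gene2terms[g1], gene2terms[g2])
        if score >= threshold:
            edges.append((g1, g2, score))
    return edges


def go_edges_by_namespace(gene2terms_by_ns, threshold=0.1):
    """Same rule, applied separately within each GO namespace
    (e.g. biological process, cellular component, molecular function)."""
    return {ns: go_edges(g2t, threshold) for ns, g2t in gene2terms_by_ns.items()}


# Hypothetical toy annotations, split by namespace
toy = {
    "BP": {"G1": {"t1", "t2"}, "G2": {"t2", "t3"}, "G3": {"t9"}},
    "MF": {"G1": {"m1"}, "G2": {"m2"}, "G3": {"m1"}},
}
print(go_edges_by_namespace(toy))
```

The pairwise loop is quadratic in the number of genes, which is why a vectorised or parallel implementation, as offered in the PS, matters at genome scale.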

Perturbations are not in the GO graph

Hello,
I encountered the following issue:
I created my own adata object with 4 perturbed transcription factors, and my perturbations are not in the GO graph. How do I fix this?
Here is the code:
from gears import PertData

pert_data = PertData('./data') # specific saved folder
pert_data.new_data_process(dataset_name = 'men', adata = adata) # specific dataset name and adata object
pert_data.load(data_path = './data/men') # load the processed data, the path is saved folder + dataset_name
pert_data.prepare_split(split = 'simulation', seed = 1) # get data split with seed
pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) # prepare data loader

Found local copy...
Found local copy...
Creating pyg object for each cell in the data...
Creating dataset file...
100%|██████████| 5/5 [00:25<00:00, 5.13s/it]
Done!
Saving new dataset pyg object at ./data/men/data_pyg/cell_graphs.pkl
Done!
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
[]
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:0
Done!

here1


KeyError Traceback (most recent call last)

in <cell line: 7>()
5 pert_data.load(data_path = './data/men') # load the processed data, the path is saved folder + dataset_name
6 pert_data.prepare_split(split = 'simulation', seed = 1) # get data split with seed
----> 7 pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) # prepare data loader

/usr/local/lib/python3.10/dist-packages/gears/pertdata.py in get_dataloader(self, batch_size, test_batch_size)
453 for i in splits:
454 cell_graphs[i] = []
--> 455 for p in self.set2conditions[i]:
456 cell_graphs[i].extend(self.dataset_processed[p])
457

KeyError: 'val'
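Per the traceback, `get_dataloader` iterates over `self.set2conditions[i]` for each split, so this KeyError means the split step produced no `'val'` entry; with only a handful of perturbations and a test composition of all zeros, the simulation split plausibly had nothing to put there. A small check after `prepare_split`, sketched with a hypothetical helper, makes this visible before the dataloader fails:

```python
def missing_split_keys(set2conditions, required=("train", "val", "test")):
    """Return the required split names that are absent or have no conditions.

    set2conditions: mapping from split name to a list of condition strings,
    the structure get_dataloader iterates over in the traceback above.
    """
    return [name for name in required if not set2conditions.get(name)]


# Toy mapping: no "val" key at all and an empty "test" list
print(missing_split_keys({"train": ["ctrl", "ctrl+A"], "test": []}))  # ['val', 'test']
```

Called as `missing_split_keys(pert_data.set2conditions)`, a non-empty result points at the split rather than at the dataloader; for a split without a test fold, pass `required=("train", "val")`.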

Very long import time

The package gears takes too much time to import. (__version__ = '0.0.3')

$ time python -c 'from gears import PertData'

real    1m37.423s
user    0m14.010s
sys     0m5.412s
$ time python -c 'import torch'

real    0m18.923s
user    0m2.117s
sys     0m1.732s
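In these timings, wall-clock (`real`) time far exceeds `user` + `sys` CPU time, which suggests the import spends most of its time waiting (for example on a slow or network filesystem) rather than computing. To see which modules are slow, CPython's `python -X importtime -c 'from gears import PertData'` prints a per-module breakdown to stderr. For a quick in-process measurement, a sketch like the following works; the helper name is illustrative:

```python
import importlib
import sys
import time


def time_import(module_name):
    """Return the seconds spent importing module_name in this process.

    Only the first import is meaningful: a module already in sys.modules
    is served from the cache almost instantly.
    """
    already_loaded = module_name in sys.modules
    start = time.perf_counter()
    importlib.import_module(module_name)
    elapsed = time.perf_counter() - start
    if already_loaded:
        print(f"note: {module_name} was already imported; this timed a cache hit")
    return elapsed


print(f"json: {time_import('json'):.4f} s")
```

Running it once per fresh interpreter, e.g. for `gears` and then for `torch`, separates the package's own cost from that of its heavy dependencies.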

A question about selecting genes

Hi, thanks for your great work. I have a question about DE genes.


It seems that 20 genes are selected during evaluation. How are these genes chosen? Are they based on a statistical test, or on the ground truth? Thanks a lot.
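The tracebacks elsewhere in this thread show GEARS calling `rank_genes_groups_by_cov` against a control group (`control_group='ctrl_1'`) during preprocessing, which suggests the evaluated genes come from a statistical ranking of each perturbation against control cells rather than from the model. That is an inference from the traceback, not something confirmed here. For intuition, here is a minimal numpy sketch of that kind of selection using a Welch t-statistic, comparable to scanpy's `'t-test'` method. The function and toy data are hypothetical:

```python
import numpy as np


def top_de_genes(X_pert, X_ctrl, gene_names, k=20):
    """Rank genes by |Welch t-statistic| between perturbed and control cells.

    X_pert, X_ctrl: dense (cells x genes) arrays with at least 2 cells each.
    gene_names: sequence of length n_genes. Returns the top-k gene names.
    """
    m1, m0 = X_pert.mean(axis=0), X_ctrl.mean(axis=0)
    v1, v0 = X_pert.var(axis=0, ddof=1), X_ctrl.var(axis=0, ddof=1)
    se = np.sqrt(v1 / X_pert.shape[0] + v0 / X_ctrl.shape[0])
    with np.errstate(divide="ignore", invalid="ignore"):
        t = (m1 - m0) / se
    # 0/0 (constant, identical genes) -> 0; +/-inf -> +/-max float, still ranked first
    t = np.nan_to_num(t, nan=0.0)
    order = np.argsort(-np.abs(t), kind="stable")
    return [gene_names[i] for i in order[:k]]


rng = np.random.default_rng(0)
ctrl = rng.normal(size=(50, 5))
pert = rng.normal(size=(50, 5))
pert[:, 3] += 5.0  # gene g3 is shifted by the "perturbation"
print(top_de_genes(pert, ctrl, ["g0", "g1", "g2", "g3", "g4"], k=2))
```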

Issue loading custom data

So, I was trying to load some scPerturb data and transformed it to have the required columns (and more), but I am getting this error:


I wasn't getting this error earlier; back then, even the PyG object generation started. I'm not sure why I am getting it now.

Any help would be appreciated!

Thanks

Problem in loading a subset of Dixit dataset

Hi,

I am having trouble loading a subset of the Dixit dataset and get the following error. I tried to follow the tutorial at https://github.com/snap-stanford/GEARS/blob/master/demo/data_tutorial.ipynb but have not been able to figure out the issue. Any guidance would be highly appreciated!

    304 gears_pert_data.prepare_split(split = 'simulation', seed = 1) # get data split with seed
    305 gears_pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) # prepare data loader
pertdata.py in get_dataloader(self, batch_size, test_batch_size)
    306             for i in splits:
    307                 cell_graphs[i] = []
    308                 for p in self.set2conditions[i]:
    309                     cell_graphs[i].extend(self.dataset_processed[p])
    310 

KeyError: 'val'

Thanks a lot!

Any suggestions for the HPC requirement

I really appreciate this careful project and the code. When I run it, it gets stuck in the preprocessing phase due to lack of RAM. My workstation has 128 GB RAM and 4 RTX 3090 GPUs. Could you suggest the minimum requirements for processing the data, e.g., the Norman dataset?

Thanks!

perturbation coexpression

Hello,

Thanks for developing the method. While using it, I came up with a question about the networks used for the latent embeddings. Correct me if I am wrong, but my understanding is that for each gene, the co-expression graph is based on control cells, whereas the GO graph is used for perturbations. Is a perturbation-based co-expression graph also used for the latent embedding in addition to the GO graph? If not, is that due to poor performance or poor generalization to unseen perturbations?

In addition, is it possible to add embeddings for the composition operation by plugging in a custom network, e.g. a PPI network or a GRN? I'm not sure whether this would improve the predictions.

Thank you!
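The question's premise is that GEARS builds its gene co-expression graph from control cells. A minimal sketch of such a graph shows the moving parts, assuming Pearson correlation with a per-gene top-k cutoff and a minimum |r|; the function name and default thresholds are placeholders, not GEARS's settings. Building the same graph from perturbed cells, as asked above, would only change which rows are passed in:

```python
import numpy as np


def coexpression_edges(X_ctrl, gene_names, top_k=20, min_abs_corr=0.1):
    """Gene co-expression edges from control cells (illustrative sketch).

    X_ctrl: dense (cells x genes) array of control-cell expression.
    For each gene, keep up to top_k partners ranked by absolute Pearson
    correlation (self excluded) whose |r| is at least min_abs_corr.
    Returns a list of (source, target, r) tuples.
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        corr = np.corrcoef(X_ctrl, rowvar=False)  # genes x genes
    corr = np.nan_to_num(corr, nan=0.0)  # constant genes: treat as uncorrelated
    np.fill_diagonal(corr, 0.0)  # no self-edges
    edges = []
    for i, gene in enumerate(gene_names):
        ranked = np.argsort(-np.abs(corr[i]), kind="stable")
        for j in ranked[:top_k]:
            if j != i and abs(corr[i, j]) >= min_abs_corr:
                edges.append((gene, gene_names[j], float(corr[i, j])))
    return edges


# Toy data: genes a and b co-vary, gene c is independent noise
rng = np.random.default_rng(0)
base = rng.normal(size=100)
X = np.column_stack([base, base + 0.1 * rng.normal(size=100), rng.normal(size=100)])
for edge in coexpression_edges(X, ["a", "b", "c"], top_k=1, min_abs_corr=0.5):
    print(edge)
```

A PPI network or GRN, as raised in the follow-up question, would enter the same way: as another edge list over the same gene set.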

Questions about reproduction of Figure2cd

Hi,
Thank you for publishing such an excellent paper! I'm trying to reproduce the results of Figures 2c and 2d. In GEARS_misc/paper/Fig2cd.ipynb, I noticed that color_pal is supplied directly as input. Could you please provide guidance on how to reproduce these data, and how I could get the perturbation results for all genes?
Forgive me for my inexperience, thank you and have a nice day!

Inference of a new gene

I have installed GEARS and trained and saved the model. I am trying to use it for inference, as suggested in the README, using this example:

# predict

gears_model.predict([['FOX1A', 'AHR'], ['FEV']])
gears_model.GI_predict([['FOX1A', 'AHR'], ['FEV', 'AHR']])
I am getting this error:
ValueError: The gene is not in the perturbation graph. Please select from GEARS.gene_list!
After checking GEARS.gene_list, I noticed that FOX1A is not in the list, but the other two are.
The paper mentions that even if a gene was not seen during training, we should still get a prediction for it.
So could you please let me know what I am missing?
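Per the error message, `predict` rejects any gene outside `GEARS.gene_list`. Checking the input first makes the culprit explicit. Note that the HGNC symbol for forkhead box A1 is FOXA1, not FOX1A, so this particular failure may be a typo in the input rather than a limitation of the model. The helper below is hypothetical and only checks list membership; it says nothing about how GEARS handles genes that are in the list but were unseen during training:

```python
def unknown_genes(perturbations, gene_list):
    """Return (sorted) genes in a list of perturbation groups that are
    absent from gene_list, e.g. [["GENE1", "GENE2"], ["GENE3"]]."""
    known = set(gene_list)
    return sorted({g for combo in perturbations for g in combo if g not in known})


# gene_list here is a toy stand-in for GEARS.gene_list
print(unknown_genes([["FOX1A", "AHR"], ["FEV"]], ["FOXA1", "AHR", "FEV"]))  # ['FOX1A']
```

With a trained model this would be `unknown_genes(perts, gears_model.gene_list)`, run before `predict`.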

no_test option for data split yields error

Hi Yusuf et al., congratulations on publishing GEARS -- the paper is excellent and impressively thorough.

Can GEARS use all of the input data for training or validation, leaving nothing for a test set? Running as recommended seems to create a train-val-test split with 25% in the test fold. I am using GEARS in a setting where the test data have already been set aside, and foregoing another 25% of the remaining data could potentially make a big difference to performance. I notice there is an option 'no_test' implemented in data_utils L174, but I get an error when I use it -- full example below. I am using gears version 0.0.4.

Example code:

from gears import PertData, GEARS
dataset_name = 'adamson'
pert_data = PertData('./data', default_pert_graph=False)
pert_data.load(data_name = dataset_name)
pert_data.prepare_split(split = 'no_test', seed = 5) 
pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) 

The error:

  File "/home/ekernf01/mambaforge/envs/ggrn/lib/python3.9/site-packages/gears/pertdata.py", line 308, in get_dataloader
    for p in self.set2conditions[i]:
KeyError: 'val'

How to define pertubation?

Thanks for your great work. I am writing to ask how a perturbation is defined in the GEARS study. Does it mean knockout or knockdown? For example, can GEARS predict the transcriptional profile after perturbing genes A and B simultaneously, where gene A expression decreases by 10% (e.g. 10 to 9) and gene B expression decreases by 20% (20 to 16)?

Questions about the implementation

Thanks for your great contribution!

I have two questions about the implementation.
1,

https://github.com/snap-stanford/GEARS/blob/master/gears/model.py#L146-L151
If the batch size is 2 and len(list(pert_track.values())) == 2, a single perturbation will be fed into the MLP.
I also find not all the training examples in one batch have two perturbations.

2,

https://github.com/snap-stanford/GEARS/blob/master/gears/model.py#L139-L144
As the 'j' in these lines is a tensor, not an integer, this line
pert_track[j] = pert_track[j] + pert_global_emb[pert_index[1][i]]
is never executed during training.
Does that mean the two perturbations of one sample are not added together?
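The second point follows from how Python dictionaries treat `torch.Tensor` keys: a tensor hashes by object identity, and iterating over a tensor yields a new element tensor each time, so a membership test like `j in pert_track` never matches a tensor inserted on an earlier iteration, even when the index values are equal. Converting with `j.item()` gives a plain int key that compares by value. The sketch below avoids a torch dependency by using a hypothetical stand-in class with the same identity hashing:

```python
class IdentityHashed:
    """Hypothetical stand-in for a 0-d torch.Tensor: like torch.Tensor it
    hashes by object identity, and int() recovers the stored index."""

    def __init__(self, value):
        self.value = value

    def __hash__(self):
        return id(self)

    def __int__(self):
        return self.value


# The same sample index (0) arrives twice as two distinct objects,
# just as iterating over a tensor yields a fresh element tensor each time
indices = [IdentityHashed(0), IdentityHashed(0)]

by_object = {}
for j in indices:
    by_object[j] = by_object.get(j, 0) + 1  # lookup never matches the earlier key

by_value = {}
for j in indices:
    key = int(j)  # with real tensors: j.item()
    by_value[key] = by_value.get(key, 0) + 1

print(len(by_object), by_value)  # 2 {0: 2}
```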

ValueError: must produce aggregated value

Hi Yusuf et al., I ran into this error while preparing the data split, and I cannot figure out how to proceed. Is there a way to fix the input data? Thank you for your help and let me know if you need other info -- here are the details.

OS: Ubuntu 20.04
Python: 3.9.15
GEARS: 0.0.4
Traceback:

pert_data.prepare_split(split = 'simulation' )
Creating new splits....
Traceback (most recent call last):
  File "/home/ubuntu/mambaforge/envs/ggrn/lib/python3.9/site-packages/pandas/core/groupby/generic.py", line 260, in aggregate
    return self._python_agg_general(
  File "/home/ubuntu/mambaforge/envs/ggrn/lib/python3.9/site-packages/pandas/core/groupby/groupby.py", line 1085, in _python_agg_general
    result, counts = self.grouper.agg_series(obj, f)
  File "/home/ubuntu/mambaforge/envs/ggrn/lib/python3.9/site-packages/pandas/core/groupby/ops.py", line 644, in agg_series
    return self._aggregate_series_pure_python(obj, func)
  File "/home/ubuntu/mambaforge/envs/ggrn/lib/python3.9/site-packages/pandas/core/groupby/ops.py", line 717, in _aggregate_series_pure_python
    raise ValueError("Function does not reduce")
ValueError: Function does not reduce

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/mambaforge/envs/ggrn/lib/python3.9/site-packages/gears/pertdata.py", line 264, in prepare_split
    set2conditions = dict(adata.obs.groupby('split').agg({'condition':
  File "/home/ubuntu/mambaforge/envs/ggrn/lib/python3.9/site-packages/pandas/core/groupby/generic.py", line 951, in aggregate
    result, how = self._aggregate(func, *args, **kwargs)
  File "/home/ubuntu/mambaforge/envs/ggrn/lib/python3.9/site-packages/pandas/core/base.py", line 416, in _aggregate
    result = _agg(arg, _agg_1dim)
  File "/home/ubuntu/mambaforge/envs/ggrn/lib/python3.9/site-packages/pandas/core/base.py", line 383, in _agg
    result[fname] = func(fname, agg_how)
  File "/home/ubuntu/mambaforge/envs/ggrn/lib/python3.9/site-packages/pandas/core/base.py", line 367, in _agg_1dim
    return colg.aggregate(how)
  File "/home/ubuntu/mambaforge/envs/ggrn/lib/python3.9/site-packages/pandas/core/groupby/generic.py", line 267, in aggregate
    result = self._aggregate_named(func, *args, **kwargs)
  File "/home/ubuntu/mambaforge/envs/ggrn/lib/python3.9/site-packages/pandas/core/groupby/generic.py", line 482, in _aggregate_named
    raise ValueError("Must produce aggregated value")
ValueError: Must produce aggregated value

To reproduce:

# Run in a shell first: wget https://www.dropbox.com/s/yvyty40vvr7j7h8/test.h5ad
import scanpy as sc
import numpy as np
import scipy.sparse  # scipy.sparse.csr_matrix is used below
import shutil
from gears import PertData as GEARSPertData
from gears import GEARS
train = sc.read_h5ad("test.h5ad")
def reformat_perturbation_for_gears(perturbation):
    perturbation = perturbation.split(",")
    perturbation = ["ctrl"] + [p if p in train.var.index else "ctrl" for p in perturbation]
    perturbation = list(set(perturbation)) # avoid ctrl+ctrl
    perturbation = "+".join(perturbation)
    return perturbation


train.var['gene_name'] = train.var.index
train.obs['condition'] = [reformat_perturbation_for_gears(p) for p in train.obs['perturbation']]
train.obs['cell_type'] = "all"
for k in train.obs.columns:
    train.obs[k] = train.obs[k].astype("str") 


train.uns["log1p"] = dict()
train.uns["log1p"]["base"] = np.exp(1)
train.uns["perturbed_and_measured_genes"] = list(train.uns["perturbed_and_measured_genes"])
train.uns["perturbed_but_not_measured_genes"] = list(train.uns["perturbed_but_not_measured_genes"])
train.X = scipy.sparse.csr_matrix(train.X)
# Clean up any previous runs
try:
    shutil.rmtree("./gears_input")
except FileNotFoundError:
    pass


pert_data = GEARSPertData("./gears_input", default_pert_graph = True)
pert_data.new_data_process(dataset_name = 'current', adata = train)
pert_data.load(data_path = './gears_input/current')
pert_data.prepare_split(split = 'simulation', seed = 1 )
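The traceback shows `prepare_split` calling `adata.obs.groupby('split').agg({'condition': ...})`, and pandas raising "Function does not reduce" and then "Must produce aggregated value" because the aggregation did not reduce each group to a single value. Whether the `astype("str")` cast of every `obs` column in the reproduction is what sends pandas down that path is unconfirmed. For reference, the same split-to-conditions mapping can be built without `.agg` by iterating over the groups; this is a sketch of the intent with a hypothetical function name, not a patch to GEARS:

```python
import pandas as pd


def conditions_by_split(obs, split_key="split", cond_key="condition"):
    """Map each split name to its unique condition strings.

    Iterating over the groupby avoids asking .agg() to return a
    non-scalar per group, which pandas can reject with
    "Must produce aggregated value".
    """
    return {
        str(split): group[cond_key].astype(str).unique().tolist()
        for split, group in obs.groupby(split_key, observed=True)
    }


# Toy .obs with a split column, cast to str as in the reproduction above
obs = pd.DataFrame({
    "split": ["train", "train", "val", "test"],
    "condition": ["ctrl", "ctrl+A", "ctrl+B", "A+B"],
}).astype(str)
print(conditions_by_split(obs))
```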
