Comments (1)
I got it working well with another dataset! The problem is that when I work with the dataset above, I get the following error:
#train the scme model
model=scme.train_model(model,max_epochs=10)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:230, in Trace.compute_log_prob(self, site_filter)
229 try:
--> 230 log_p = site["fn"].log_prob(
231 site["value"], *site["args"], **site["kwargs"]
232 )
233 except ValueError as e:
File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/torch/distributions/independent.py:99, in Independent.log_prob(self, value)
98 def log_prob(self, value):
---> 99 log_prob = self.base_dist.log_prob(value)
100 return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)
File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/distributions/zero_inflated.py:71, in ZeroInflatedDistribution.log_prob(self, value)
70 if self._validate_args:
---> 71 self._validate_sample(value)
73 if "gate" in self.__dict__:
File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/torch/distributions/distribution.py:300, in Distribution._validate_sample(self, value)
299 if not valid.all():
--> 300 raise ValueError(
301 "Expected value argument "
302 f"({type(value).__name__} of shape {tuple(value.shape)}) "
303 f"to be within the support ({repr(support)}) "
304 f"of the distribution {repr(self)}, "
305 f"but found invalid values:\n{value}"
306 )
ValueError: Expected value argument (Tensor of shape (256, 2000)) to be within the support (IntegerGreaterThan(lower_bound=0)) of the distribution ZeroInflatedNegativeBinomial(gate_logits: torch.Size([256, 2000])), but found invalid values:
tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 1.1207, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 1.0324, 1.0324, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 1.1781, 1.1781, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]])
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
Cell In[74], line 2
1 #train the scme model
----> 2 model=scme.train_model(model,max_epochs=10)
File ~/Workspace/Multiomics/benchmark/script/./scME/scme.py:136, in train_model(model, max_epochs, batchsize, lr, lr_cla, milestones, save_model, save_dir)
134 yr=yr.to(device)
135 yp=yp.to(device)
--> 136 loss1=svi.step(rna,protein,yr,yp)
137 loss2=svi2.step(rna,protein,yr,yp)
138 losses_ae.append(loss1)
File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
143 # get loss and compute gradients
144 with poutine.trace(param_only=True) as param_capture:
--> 145 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
147 params = set(
148 site["value"].unconstrained() for site in param_capture.trace.nodes.values()
149 )
151 # actually perform gradient steps
152 # torch.optim objects gets instantiated for any params that haven't been seen yet
File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
138 loss = 0.0
139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
141 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
142 model_trace, guide_trace
143 )
144 loss += loss_particle / self.num_particles
File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/infer/elbo.py:237, in ELBO._get_traces(self, model, guide, args, kwargs)
235 else:
236 for i in range(self.num_particles):
--> 237 yield self._get_trace(model, guide, args, kwargs)
File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
52 def _get_trace(self, model, guide, args, kwargs):
53 """
54 Returns a single trace from the guide, and the model that is run
55 against it.
56 """
---> 57 model_trace, guide_trace = get_importance_trace(
58 "flat", self.max_plate_nesting, model, guide, args, kwargs
59 )
60 if is_validation_enabled():
61 check_if_enumerated(guide_trace)
File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/infer/enum.py:75, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
72 guide_trace = prune_subsample_sites(guide_trace)
73 model_trace = prune_subsample_sites(model_trace)
---> 75 model_trace.compute_log_prob()
76 guide_trace.compute_score_parts()
77 if is_validation_enabled():
File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:236, in Trace.compute_log_prob(self, site_filter)
234 _, exc_value, traceback = sys.exc_info()
235 shapes = self.format_shapes(last_site=site["name"])
--> 236 raise ValueError(
237 "Error while computing log_prob at site '{}':\n{}\n{}".format(
238 name, exc_value, shapes
239 )
240 ).with_traceback(traceback) from e
241 site["unscaled_log_prob"] = log_p
242 log_p = scale_and_mask(log_p, site["scale"], site["mask"])
File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/poutine/trace_struct.py:230, in Trace.compute_log_prob(self, site_filter)
228 if "log_prob" not in site:
229 try:
--> 230 log_p = site["fn"].log_prob(
231 site["value"], *site["args"], **site["kwargs"]
232 )
233 except ValueError as e:
234 _, exc_value, traceback = sys.exc_info()
File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/torch/distributions/independent.py:99, in Independent.log_prob(self, value)
98 def log_prob(self, value):
---> 99 log_prob = self.base_dist.log_prob(value)
100 return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)
File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/pyro/distributions/zero_inflated.py:71, in ZeroInflatedDistribution.log_prob(self, value)
69 def log_prob(self, value):
70 if self._validate_args:
---> 71 self._validate_sample(value)
73 if "gate" in self.__dict__:
74 gate, value = broadcast_all(self.gate, value)
File ~/miniconda3/envs/multiomics/lib/python3.9/site-packages/torch/distributions/distribution.py:300, in Distribution._validate_sample(self, value)
298 valid = support.check(value)
299 if not valid.all():
--> 300 raise ValueError(
301 "Expected value argument "
302 f"({type(value).__name__} of shape {tuple(value.shape)}) "
303 f"to be within the support ({repr(support)}) "
304 f"of the distribution {repr(self)}, "
305 f"but found invalid values:\n{value}"
306 )
ValueError: Error while computing log_prob at site 'rna_count':
Expected value argument (Tensor of shape (256, 2000)) to be within the support (IntegerGreaterThan(lower_bound=0)) of the distribution ZeroInflatedNegativeBinomial(gate_logits: torch.Size([256, 2000])), but found invalid values:
tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 1.1207, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 1.0324, 1.0324, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 1.1781, 1.1781, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]])
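This error was fixed by converting the data to integers. The ZeroInflatedNegativeBinomial likelihood only accepts non-negative integer counts, but `X` held non-integer (normalized) values such as 1.1207. The conversion was: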
#create training dataset
rna.X=rna.layers["counts"].astype(int)
protein.X=protein.layers["counts"].astype(int)
# rna.X=rna.layers["counts"]
# protein.X=protein.layers["counts"]
traindataset=scme.AnnDataset(rna,protein,to_onehot=True)
which then caused the new error I described above.
from scme.
Related Issues (1)
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Something interesting about the web. A new door to the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Something interesting about visualization: using data as art.
-
Game
Something interesting about games, to make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from scme.