Comments (12)
Another solution to the problem would be to save the logp as part of a Sample struct, as suggested in #5 (comment).

I might be too sleepy to follow this discussion, but I am not sure how Sample can help solve this problem. I am talking about allowing the user to write threaded code in the model body while accumulating the logp correctly, with no race conditions.

With the Sample struct you wouldn't have to accumulate the logp while sampling; you would just save it together with the sample and the distribution it was sampled from, and accumulate it after stepping through the model. Doesn't this solve the issue?

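A minimal sketch of what that could look like (hypothetical Sample struct, not the actual DynamicPPL API):

using Distributions

# Each ~ line records the value, the distribution it was sampled from,
# and the corresponding logp, instead of accumulating logp on the spot.
struct Sample{T,D<:Distribution}
    value::T
    dist::D
    logp::Float64
end

Sample(value, dist::Distribution) = Sample(value, dist, logpdf(dist, value))

# After stepping through the model, the total logp is a plain reduction:
samples = [Sample(0.3, Normal(0, 1)), Sample(1.2, InverseGamma(2, 3))]
total_logp = sum(s.logp for s in samples)
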
Well, not really, as far as I can tell, because the point is to make the logpdf computation and accumulation of iid variables threaded. So I think it is easiest to let users do this inside the model body, for example with Threads.@threads loop blocks or other threading patterns. It is the user's responsibility to ensure the thread safety of the model body, but it is our responsibility to ensure the thread safety of the ~ lines. Does this make sense?

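To illustrate the race on the accumulator (a minimal sketch, assuming logp were kept in a plain Ref with no synchronization):

using Distributions

logp = Ref(0.0)
data = randn(1000)
Threads.@threads for i in 1:length(data)
    # read-modify-write on shared state: increments from different threads
    # can interleave, so some contributions are silently lost
    logp[] += logpdf(Normal(0, 1), data[i])
end
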
OK, if I understand correctly, you want users to be able to write
@model model(x) = begin
    m ~ Normal(0, 1)
    y = Vector{Float64}(undef, length(x))
    Threads.@threads for i in 1:length(x)
        y[i] ~ InverseGamma(2, 3)
        x[i] ~ Normal(m, y[i])
    end
end
where x would be some vector-valued data. In that case it wouldn't be sufficient to just change how/when logp is accumulated, since pushing samples to the VarInfo object is not thread safe either.

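The push! problem can be seen even without VarInfo (a minimal sketch of concurrent push! to a shared vector):

v = Float64[]
Threads.@threads for i in 1:1000
    # push! reallocates and mutates shared state; concurrent calls can
    # drop elements or corrupt the vector
    push!(v, Float64(i))
end
length(v)  # often less than 1000, if the loop does not error first
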
If you only have observations, i.e., something similar to
@model model(x) = begin
    m ~ Normal(0, 1)
    Threads.@threads for i in 1:length(x)
        x[i] ~ Normal(m, 1)
    end
end
one could currently write something like (untested)
@model model(x) = begin
    m ~ Normal(0, 1)
    dist = Normal(m, 1)
    logps = zeros(Threads.nthreads())
    Threads.@threads for i in 1:length(x)
        logp = DynamicPPL.tilde_observe(
            _context, _sampler, dist, x[i], @varname(x[i]), @vinds(x[i]), _varinfo
        )
        logps[Threads.threadid()] += logp
    end
    acclogp!(_varinfo, sum(logps))
end
by exploiting the internals. I guess your intention was to tackle the second example on the level of VarInfo?

So the advantage of the Sample approach would be that it could also support the first example at the same time: instead of just having separate containers of logp for each thread, one would have separate containers of samples for each thread.

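For instance, reusing the hypothetical Sample struct sketched above, the observe-only example could collect per-thread buffers and merge them afterwards (a sketch, not an actual implementation):

using Distributions

m, x = 0.0, randn(100)
buffers = [Sample[] for _ in 1:Threads.nthreads()]
Threads.@threads for i in 1:length(x)
    # each task appends only to its own thread's buffer, so no locking is needed
    push!(buffers[Threads.threadid()], Sample(x[i], Normal(m, 1)))
end
all_samples = reduce(vcat, buffers)
total_logp = sum(s.logp for s in all_samples)
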
BTW: I noticed a while ago that making logp (or the samples) a vector whose length depends on the number of threads is probably not a viable solution. The problem would be that VarInfo objects would depend on the number of threads used when creating them and could not be reused for rerunning the model with a different number of threads. Maybe a better approach would be to modify the evaluation function to
function evaluator(
    _model::DynamicPPL.Model,
    _varinfo::DynamicPPL.VarInfo,
    _sampler::DynamicPPL.AbstractSampler,
    _context::DynamicPPL.AbstractContext,
)
    _logp = $(contains_at_threads ? zeros(Threads.nthreads()) : Ref(0.0))
    out = innerevaluator(_model, _varinfo, _sampler, _context, _logp)
    acclogp!(_varinfo, sum(_logp))
    return out
end

function innerevaluator(
    _model::Model,
    _varinfo::VarInfo,
    _sampler::AbstractSampler,
    _context::AbstractContext,
    _logp
)
    $unwrapdataexpr
    $mainbody
end
and then, in innerevaluator, add only to _logp: to the thread-specific entry if _logp is an array, and to the common Ref otherwise.

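A minimal sketch of that dispatch (hypothetical helper names, assuming _logp is either a Vector{Float64} or a Ref{Float64}):

# threaded case: each thread adds to its own slot
addlogp!(logp::AbstractVector{<:Real}, x::Real) = (logp[Threads.threadid()] += x; logp)

# non-threaded case: add to the common accumulator
addlogp!(logp::Base.RefValue{<:Real}, x::Real) = (logp[] += x; logp)

# the final reduction in evaluator would need the same split, since a plain
# Ref is not iterable and sum(_logp) would fail for it:
finallogp(logp::AbstractVector{<:Real}) = sum(logp)
finallogp(logp::Base.RefValue{<:Real}) = logp[]
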
The main issue with this proposal is that _varinfo.logp[] will not be updated at any point inside the model, only at the end. This is bad if I want to print the value at different points in the model for debugging.

I guess you could make _logp a reserved name as well, and the user could be told to print that instead of _varinfo.logp[].

And we don't need to check if the model contains @threads. We can have a type parameter or a Val field in Sampler for the number of threads. This can then be used to define an MArray-like object for the logps, keeping false sharing in mind. The reason to avoid checking for @threads is that there are multiple ways to do multi-threading in Julia, and the user may even define their own macro that expands to @threads or @spawn.

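A minimal sketch of such a padded accumulator (hypothetical type; assumes 64-byte cache lines, i.e. eight Float64 values per line):

# one 64-byte column per thread, so concurrent updates from different
# threads land on different cache lines and avoid false sharing
struct PaddedLogps
    vals::Matrix{Float64}   # 8 rows (one cache line) per thread's column
end
PaddedLogps(nthreads::Integer) = PaddedLogps(zeros(8, nthreads))

add!(p::PaddedLogps, x::Real) = (p.vals[1, Threads.threadid()] += x; p)
total(p::PaddedLogps) = sum(@view p.vals[1, :])
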
"I guess your intention was to tackle the second example on the level of VarInfo?"

Yes.

"The problem would be that VarInfo objects would be dependent on the number of threads that were used when creating them and could not be used for rerunning the model with a different number of threads."

I think that's fine in general. We can also resize the logps vector in the VarInfo to match the number of threads at the beginning of the model call or when constructing the Sampler.

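A minimal sketch of that resizing step (assuming a hypothetical logps field on the VarInfo):

# called at the beginning of every model evaluation
function prepare_logps!(vi)
    n = Threads.nthreads()
    length(vi.logps) == n || resize!(vi.logps, n)
    fill!(vi.logps, 0.0)
    return vi
end
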
"I guess you could make _logp a reserved name as well and the user can be told to print that instead of _varinfo.logp[]."

Yes, that's what I would suggest for debugging.

"And we don't need to check if the model contains @threads."

We don't have to do that. I thought it might be a nice optimization (in particular if you want to run the model multiple times), but I'm not so sure anymore.