
Comments (12)

devmotion commented on June 1, 2024

Another solution to the problem would be to save the logp as part of a Sample struct, as suggested in #5 (comment).

from dynamicppl.jl.

mohamed82008 commented on June 1, 2024

I might be too sleepy to follow this discussion but I am not sure how Sample can help solve this problem. I am talking about allowing the user to write threaded code in the model body and accumulating the logp correctly with no race conditions.

devmotion commented on June 1, 2024

With the Sample struct you wouldn't have to accumulate the logp while sampling, you just save it together with the sample and the distribution it was sampled from and accumulate it after stepping through the model. Doesn't this solve the issue?
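For concreteness, such a container might look roughly like this (a sketch only; the struct name and fields are hypothetical, not DynamicPPL's actual API):

```julia
using Distributions

# Hypothetical container: store each sampled value together with the
# distribution it was drawn from; logp is then computed after stepping
# through the model, rather than accumulated during sampling.
struct Sample{T,D}
    value::T
    dist::D
end

# Total log density over all collected samples.
total_logp(samples) = sum(logpdf(s.dist, s.value) for s in samples)
```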

mohamed82008 commented on June 1, 2024

Well, not really, as far as I can tell, because the point is to make the logpdf computation and accumulation of iid variables threaded. So I think it is easiest to let users do this inside the model body, for example with Threads.@threads loop blocks or other threading patterns. It is the user's responsibility to ensure the thread safety of the model body, but it is our responsibility to ensure the thread safety of the ~ lines. Does this make sense?

devmotion commented on June 1, 2024

OK, if I understand correctly, you want users to be able to write

@model model(x) = begin
	m ~ Normal(0, 1)
	y = Vector{Float64}(undef, length(x))
	Threads.@threads for i in 1:length(x)
		y[i] ~ InverseGamma(2, 3)
		x[i] ~ Normal(m, y[i])
	end
end

where x would be some vector-valued data. In that case it wouldn't be sufficient to just change how/when logp is accumulated, since pushing samples to the VarInfo object is also not thread-safe.

If you only have observations, i.e., something similar to

@model model(x) = begin
	m ~ Normal(0, 1)
	Threads.@threads for i in 1:length(x)
		x[i] ~ Normal(m, 1)
	end
end

one could currently write something like (untested)

@model model(x) = begin
	m ~ Normal(0, 1)
	dist = Normal(m, 1)
	logps = zeros(Threads.nthreads())
	Threads.@threads for i in 1:length(x)
		logp = DynamicPPL.tilde_observe(
			_context, _sampler, dist, x[i], @varname(x[i]), @vinds(x[i]), _varinfo
		)
		logps[Threads.threadid()] += logp
	end
	acclogp!(_varinfo, sum(logps))
end

by exploiting the internals. I guess your intention was to tackle the second example on the level of VarInfo?

So the advantage of the Sample approach would be that one could support the first example at the same time: instead of just having separate containers of logp for each thread, one would have separate containers of samples for each thread.
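The per-thread-buffer pattern in the snippet above extends to collected samples as well. A minimal generic sketch, independent of DynamicPPL internals (the function name is my own; note that indexing by Threads.threadid() is only safe when tasks do not migrate between threads, which is what the :static schedule guarantees):

```julia
# Each task writes only to the buffer of the thread it runs on, so the
# loop body is race-free; the buffers are merged once after the loop.
function threaded_collect(f, n)
    buffers = [Float64[] for _ in 1:Threads.nthreads()]
    Threads.@threads :static for i in 1:n
        push!(buffers[Threads.threadid()], f(i))
    end
    return reduce(vcat, buffers)
end
```

The merged result may be in a different order than serial execution, which is fine for an order-independent reduction such as summing logp terms.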

devmotion commented on June 1, 2024

BTW: I noticed a while ago that making logp (or the samples) a vector whose length depends on the number of threads is probably not a viable solution. The problem is that VarInfo objects would then depend on the number of threads in use when they were created, and could not be reused for rerunning the model with a different number of threads. Maybe a better approach would be to modify the evaluation function to

function evaluator(
    _model::DynamicPPL.Model,
    _varinfo::DynamicPPL.VarInfo,
    _sampler::DynamicPPL.AbstractSampler,
    _context::DynamicPPL.AbstractContext,
)
	_logp = $(contains_at_threads ? zeros(Threads.nthreads()) : Ref(0.0))
	out = innerevaluator(_model, _varinfo, _sampler, _context, _logp)
	acclogp!(_varinfo, sum(_logp))
	return out
end

function innerevaluator(
	_model::Model,
	_varinfo::VarInfo,
	_sampler::AbstractSampler,
	_context::AbstractContext,
	_logp
)
	$unwrapdataexpr
	$mainbody
end

and then, inside innerevaluator, accumulate only into _logp: into the thread-specific entry if _logp is an array, and into the shared Ref otherwise.
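That branching could be expressed via dispatch on the type of _logp; a sketch with hypothetical helper names (not part of DynamicPPL's API):

```julia
# Serial evaluation: _logp is a single shared cell.
accumulate_logp!(logp::Base.RefValue{Float64}, x) = (logp[] += x; logp)

# Threaded evaluation: _logp has one slot per thread; each thread touches
# only its own slot, so no synchronization is needed.
accumulate_logp!(logp::Vector{Float64}, x) = (logp[Threads.threadid()] += x; logp)

# Both representations reduce to the same scalar at the end of evaluation.
finallogp(logp::Base.RefValue{Float64}) = logp[]
finallogp(logp::Vector{Float64}) = sum(logp)
```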

mohamed82008 commented on June 1, 2024

The main issue with this proposal is that _varinfo.logp[] will not be updated at any point during the model's execution, only at the end. This is bad if I want to print the value at different points in the model for debugging.

mohamed82008 commented on June 1, 2024

I guess you could make _logp a reserved name as well and the user can be told to print that instead of _varinfo.logp[].

mohamed82008 commented on June 1, 2024

And we don't need to check if the model contains @threads. We can have a type parameter or a Val field in Sampler for the number of threads. This could then be used to define an MArray-like object for the logps, keeping false sharing in mind. The reason to avoid checking for @threads is that there are multiple ways to do multi-threading in Julia, and the user may even define their own macro that expands to @threads or @spawn.
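One way to keep false sharing in mind is to pad each thread's slot out to a full cache line; an illustrative sketch (the type and sizes are my own, not an actual DynamicPPL type):

```julia
# Each thread's accumulator is one column of an 8×nthreads matrix. Julia
# arrays are column-major, so adjacent columns are 8 Float64s = 64 bytes
# apart, i.e. on separate cache lines on typical hardware, so threads
# incrementing neighboring slots do not contend on the same line.
struct PaddedLogp
    slots::Matrix{Float64}
end
PaddedLogp(n::Integer = Threads.nthreads()) = PaddedLogp(zeros(8, n))

# Only row 1 of each column holds data; the remaining rows are padding.
addlogp!(p::PaddedLogp, x) = (p.slots[1, Threads.threadid()] += x; p)
sumlogp(p::PaddedLogp) = sum(@view p.slots[1, :])
```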

mohamed82008 commented on June 1, 2024

I guess your intention was to tackle the second example on the level of VarInfo?

Yes.

mohamed82008 commented on June 1, 2024

The problem would be that VarInfo objects would be dependent on the number of threads that were used when creating them and could not be used for rerunning the model with a different number of threads.

I think that's fine in general. We can also modify the size of the logps vector in the VarInfo to match the number of threads at the beginning of the model call or when constructing the Sampler.
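Resizing the buffer at the beginning of a model call could be as simple as (a sketch; the helper name is hypothetical):

```julia
# Grow or shrink the per-thread logp buffer so it matches the number of
# threads currently available, and reset all slots to zero.
function reset_logps!(logps::Vector{Float64})
    resize!(logps, Threads.nthreads())
    fill!(logps, 0.0)
    return logps
end
```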

devmotion commented on June 1, 2024

I guess you could make _logp a reserved name as well and the user can be told to print that instead of _varinfo.logp[].

Yes, that's what I would suggest for debugging.

And we don't need to check if the model contains @threads

We don't have to do that. I thought it might be a nice optimization (in particular if you want to run the model multiple times) but I'm not so sure anymore.
