pwinskill commented on August 17, 2024

Yes, I was just looking at the same thing. I'll have a chat with Bob to check whether there's a quick fix.

pwinskill commented on August 17, 2024

Ok. It seems to be an RNG issue that causes problems when proposing a new theta value after the hybrid log_likelihood function has been called. Similar issue here.
Working on testing a fix on this branch.
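
To make the suspected mechanism concrete, here is a hypothetical C++ toy model. Every name in it is made up, and it is not how R or drjacoby are implemented. It assumes two things. First, an Rcpp function re-reads R's saved RNG state on entry and writes it back on exit, even if it draws nothing. Second, the outer sampler draws random numbers without saving its own state first. Under those assumptions, each call into the hybrid likelihood rewinds the generator to a stale seed, so the sampler gets the same random number on every iteration:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical toy model of R's RNG bookkeeping (NOT R's actual API):
//   saved_seed plays the role of .Random.seed (the copy R code sees)
//   live_state plays the role of the C-level generator state used by draws
struct ToyRng {
    std::uint64_t live_state = 0;
    std::uint64_t saved_seed = 0;

    // Stand-in for GetRNGstate(): reload the live state from the saved seed.
    void get_state() { live_state = saved_seed; }
    // Stand-in for PutRNGstate(): write the live state back to the saved seed.
    void put_state() { saved_seed = live_state; }
    // One draw: advance a full-period 64-bit LCG and return the new state.
    std::uint64_t draw() {
        live_state = live_state * 6364136223846793005ULL + 1442695040888963407ULL;
        return live_state;
    }
};

// Stand-in for an Rcpp function called from the R likelihood. Even though it
// does nothing, it is assumed to reload the state on entry and save it on
// exit, so it picks up whatever stale value saved_seed still holds.
void call_rcpp_function(ToyRng& rng) {
    rng.get_state();
    rng.put_state();
}

// Outer sampler: one draw per iteration (e.g. to propose a new theta), then
// the hybrid likelihood runs. The sampler never saves its own state first.
std::vector<std::uint64_t> run_unsynced(ToyRng& rng, int iterations) {
    assert(iterations >= 0);
    std::vector<std::uint64_t> draws;
    rng.get_state();  // sampler entry: load the saved seed once
    for (int i = 0; i < iterations; ++i) {
        draws.push_back(rng.draw());
        call_rcpp_function(rng);  // silently rewinds live_state
    }
    return draws;
}
```

In this toy, every entry returned by `run_unsynced` is identical. Each pass through `call_rcpp_function()` rewinds the generator to the seed that was saved before the loop started, whereas two consecutive draws from an untouched `ToyRng` differ.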

jameshay218 commented on August 17, 2024

Maybe helpful to see the C++ function:

```cpp
#include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::export]]
NumericVector likelihood_fast(const NumericVector &expected, const NumericVector &obs, const NumericVector &pars){
    int total_cts = expected.size();
    NumericVector ret(total_cts);
    const double sd = pars["error"];
    const double max_ct = 40.0;
    const double log_const = log(0.5);
    const double den = sd*M_SQRT2;

    for(int i = 0; i < total_cts; ++i){
        if(obs[i] < max_ct && obs[i] >= 0.0){
            // 0 <= obs < max_ct: normal log-density
            ret[i] = -log(sd*sqrt(2*M_PI)) - 0.5*pow((obs[i] - expected[i])/sd, 2);
        } else if(obs[i] >= max_ct) {
            // obs >= max_ct: log P(Y >= max_ct)
            ret[i] = log_const + log(erfc((max_ct - expected[i])/den));
        } else {
            // obs < 0: log P(Y < 0)
            ret[i] = log_const + log(1.0 + erf((0.0 - expected[i])/den));
        }
    }
    return(ret);
}
```
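
Reading the three branches, the function appears to be a censored normal log-likelihood for Ct values. Observations at or above 40 (`max_ct`) seem to be treated as right-censored and negative observations as left-censored. Write $y_i$ for `obs[i]`, $\mu_i$ for `expected[i]`, $\sigma$ for `pars["error"]`, $L = 40$ and $Y \sim \mathcal{N}(\mu_i, \sigma^2)$. Each element of `ret` is then:

```latex
\ell_i =
\begin{cases}
-\log\!\left(\sigma\sqrt{2\pi}\right) - \dfrac{1}{2}\left(\dfrac{y_i - \mu_i}{\sigma}\right)^{2}, & 0 \le y_i < L,\\[1ex]
\log\!\left[\dfrac{1}{2}\,\operatorname{erfc}\!\left(\dfrac{L - \mu_i}{\sigma\sqrt{2}}\right)\right] = \log \Pr(Y \ge L), & y_i \ge L,\\[1ex]
\log\!\left[\dfrac{1}{2}\left(1 + \operatorname{erf}\!\left(\dfrac{0 - \mu_i}{\sigma\sqrt{2}}\right)\right)\right] = \log \Pr(Y < 0), & y_i < 0.
\end{cases}
```

The function returns the vector of $\ell_i$, so the total log-likelihood would be $\sum_i \ell_i$.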

jameshay218 commented on August 17, 2024

Actually, a small update: I tried just adding a dummy Rcpp function call in the likelihood, and this also breaks it:

```cpp
// [[Rcpp::export]]
double ct_likelihood_fast_tmp(double x){
    return(x);
}
```

pwinskill commented on August 17, 2024

Hi @jameshay218,

Interesting!...

Can I check first that, when specifying the C++ likelihood, it uses the same file format as defined by drjacoby::cpp_template()? You'll note that the template has an associated SEXP create_xptr() function that is vital to get everything to play nicely.
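
Roughly, an external-pointer factory like create_xptr() hands back a pointer to a function compiled in the same source file. Presumably this lets the sampler call the likelihood directly instead of going back through R. The plain-C++ analogy below is hypothetical: `LoglikeFn`, `loglike` and `create_fn_ptr` are made-up names, and `std::vector` stands in for the Rcpp types the real template uses. It only illustrates the name-to-pointer lookup:

```cpp
#include <cassert>
#include <cmath>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical signature for a compiled log-likelihood; the real drjacoby
// template uses Rcpp types rather than std::vector.
using LoglikeFn = double (*)(const std::vector<double>& params,
                             const std::vector<double>& data);

// Example likelihood in the same source file: data ~ Normal(params[0], params[1]).
double loglike(const std::vector<double>& params, const std::vector<double>& data) {
    const double kPi = 3.14159265358979323846;
    const double mu = params.at(0);
    const double sd = params.at(1);
    assert(sd > 0.0);
    double total = 0.0;
    for (double x : data) {
        const double z = (x - mu) / sd;
        total += -std::log(sd * std::sqrt(2.0 * kPi)) - 0.5 * z * z;
    }
    return total;
}

// Hypothetical analogue of create_xptr(): look a function up by name and
// return a pointer to it that the caller can invoke directly from then on.
LoglikeFn create_fn_ptr(const std::string& function_name) {
    static const std::map<std::string, LoglikeFn> registry = {
        {"loglike", &loglike},
    };
    const auto it = registry.find(function_name);
    if (it == registry.end()) {
        throw std::invalid_argument("no compiled function named " + function_name);
    }
    return it->second;
}
```

Usage: `LoglikeFn f = create_fn_ptr("loglike");` followed by repeated calls such as `f({3.0, 1.0}, data)`. Only functions compiled into the same file can be registered this way.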

jameshay218 commented on August 17, 2024

Sorry, I realize I was unclear in my first message: this is just called from within the main likelihood function, which is in R, e.g., as if I were substituting the call to dnorm with james_custom_dnorm_cpp. So the majority of the overall likelihood function is in R. Here is a minimal example based on the vignette:

```r
library(drjacoby)
library(Rcpp)

cppFunction("NumericVector dnorm_cpp(const NumericVector &expected, const double &obs, const NumericVector &pars){
    int n_ret = expected.size();
    NumericVector ret(n_ret);
    const double sd = pars[0];
    const double max_ct = 40.0;
    const double log_const = log(0.5);
    const double den = sd*M_SQRT2;

    for(int i = 0; i < n_ret; ++i){
        ret[i] = -log(sd*sqrt(2*M_PI)) - 0.5*pow((obs - expected[i])/sd, 2);
    }
    return(ret);
}")

set.seed(1)

# define true parameter values
mu_true <- 3
sigma_true <- 2

# draw example data
data_list <- list(x = rnorm(10, mean = mu_true, sd = sigma_true))

# define parameters dataframe
df_params <- define_params(name = "mu", min = -10, max = 10,
                           name = "sigma", min = 0, max = Inf)

# define log-likelihood function (all in R)
r_loglike <- function(params, data, misc) {

    # extract parameter values
    mu <- params["mu"]
    sigma <- params["sigma"]

    # calculate log-probability of data
    ret <- sum(dnorm(data$x, mean = mu, sd = sigma, log = TRUE))

    # return
    return(ret)
}

# define log-likelihood function (calls the Rcpp function dnorm_cpp)
r_loglike_cpp <- function(params, data, misc) {

    # extract parameter values
    mu <- params["mu"]
    sigma <- params["sigma"]

    # calculate log-probability of data
    ret <- sum(dnorm_cpp(data$x, mu, sigma))

    # return
    return(ret)
}

# define log-prior function
r_logprior <- function(params, misc) {

    # extract parameter values
    mu <- params["mu"]
    sigma <- params["sigma"]

    # calculate log-prior
    ret <- dunif(mu, min = -10, max = 10, log = TRUE) +
        dlnorm(sigma, meanlog = 0, sdlog = 1.0, log = TRUE)

    # return
    return(ret)
}

## Run example
mcmc <- run_mcmc(data = data_list,
                 df_params = df_params,
                 loglike = r_loglike,
                 logprior = r_logprior,
                 burnin = 1e3,
                 samples = 1e3,
                 pb_markdown = TRUE)
plot_par(mcmc, show = "mu")

## Run example with custom Cpp dnorm
mcmc_cpp <- run_mcmc(data = data_list,
                     df_params = df_params,
                     loglike = r_loglike_cpp,
                     logprior = r_logprior,
                     burnin = 1e3,
                     samples = 1e3,
                     pb_markdown = TRUE)
plot_par(mcmc_cpp, show = "mu")
```

Interestingly, I also get an issue when I use extraDistr::dhnorm in the prior function.

pwinskill commented on August 17, 2024

Thanks James that is helpful.
I don't think it's something we've tried, but it's obviously not happy! I'll make an issue to look into this; at the very least it would be nice if it threw a warning.
All our testing has been with an "all in R" or "all in C++" approach. Using drjacoby::cpp_template() you can split out functions to call from the overall likelihood (so here you could define a loglike_cpp() that calls dnorm_cpp()), but they'd both need to be in the same C++ source file.
As you say, it's interesting that extraDistr::dhnorm works. I expect that if you wrapped your dnorm_cpp() into a proper package and loaded that, it might work too, but that seems a bit of a faff.
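
Following that suggestion, the hybrid example could in principle move entirely into one C++ file, with the likelihood calling a normal-density helper defined alongside it. The sketch below is plain C++ using `std::vector` rather than the Rcpp types and signatures that drjacoby::cpp_template() actually expects. `dnorm_log`, `loglike_cpp` and `logprior_cpp` are hypothetical names. It only shows the layout, mirroring r_loglike_cpp and r_logprior from the example above:

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

const double kPi = 3.14159265358979323846;

// Helper in the same source file: normal log-density (the role dnorm_cpp plays).
double dnorm_log(double x, double mean, double sd) {
    assert(sd > 0.0);
    const double z = (x - mean) / sd;
    return -std::log(sd) - 0.5 * std::log(2.0 * kPi) - 0.5 * z * z;
}

// Hypothetical loglike_cpp(): params = {mu, sigma}, summed over the data.
double loglike_cpp(const std::vector<double>& params, const std::vector<double>& x) {
    const double mu = params.at(0);
    const double sigma = params.at(1);
    double total = 0.0;
    for (double xi : x) total += dnorm_log(xi, mu, sigma);
    return total;
}

// Same prior as r_logprior: Uniform(-10, 10) on mu, LogNormal(0, 1) on sigma.
double logprior_cpp(const std::vector<double>& params) {
    const double mu = params.at(0);
    const double sigma = params.at(1);
    if (mu < -10.0 || mu > 10.0 || sigma <= 0.0) {
        return -std::numeric_limits<double>::infinity();
    }
    const double log_sigma = std::log(sigma);
    const double log_unif = -std::log(20.0);  // dunif(mu, -10, 10, log = TRUE)
    const double log_lnorm = -log_sigma - 0.5 * std::log(2.0 * kPi)
                             - 0.5 * log_sigma * log_sigma;  // dlnorm(sigma, 0, 1, log = TRUE)
    return log_unif + log_lnorm;
}
```

Because the helper and its callers share one translation unit, no call ever crosses back into R during an iteration.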

jameshay218 commented on August 17, 2024

Thanks Pete.

It's also interesting that if I print out the return value of dnorm_cpp in r_loglike_cpp, it shows the correct values. So the function itself is running correctly, which suggests that calling an Rcpp function is interfering with the drjacoby C++ environment (maybe a memory issue, or something odd about jumping C++ -> R -> C++ -> R -> C++ on each iteration). It seems like a totally reasonable limitation if there's no easy fix: all in R or all in C++. I was hoping to hybridize because it is an ODE model, and I wasn't aware of any quick way to put the whole thing in C++ without learning a whole new pipeline.

pwinskill commented on August 17, 2024

Still working on this.
Posting a very simple reprex here for reference:

```r
library(drjacoby)

# Source an empty cpp function
Rcpp::cppFunction("void break_mcmc(){}", verbose = TRUE)

# define log-likelihood function
ll <- function(params, data, misc) {
  # Option to call the empty cpp function
  if(misc$broken){
    break_mcmc()
  }
  # calculate log-probability of data
  ret <- sum(dnorm(data$x, mean = params["mu"], sd = 2, log = TRUE))
  # return
  return(ret)
}

# define log-prior function
lp <- function(params, misc) {
  return(0)
}

# define true parameter values
mu_true <- 3
sigma_true <- 2
# draw example data
set.seed(1234)
data_list <- list(x = rnorm(10, mean = mu_true, sd = sigma_true))
# define parameters dataframe
df_params <- define_params(name = "mu", min = -10, max = 10, init = 3)

# Working
set.seed(1234)
mcmc_cpp <- run_mcmc(data = data_list,
                     df_params = df_params,
                     loglike = ll,
                     logprior = lp,
                     burnin = 500,
                     samples = 500,
                     chains = 1,
                     silent = TRUE,
                     misc = list(broken = FALSE))
# Returns MCMC chain for mu that looks fine:
plot(mcmc_cpp$output$mu, type = "l")

# Not working
set.seed(1234)
mcmc_cpp2 <- run_mcmc(data = data_list,
                      df_params = df_params,
                      loglike = ll,
                      logprior = lp,
                      burnin = 500,
                      samples = 500,
                      chains = 1,
                      silent = TRUE,
                      misc = list(broken = TRUE))
# Returns MCMC chain for mu that does not look fine (either monotonically increases or decreases):
plot(mcmc_cpp2$output$mu, type = "l")
```
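
A toy random-walk Metropolis sampler illustrates why reused random numbers could produce that monotone trace. This is a hypothetical C++ sketch, not drjacoby's actual sampler. The proposal increment and the acceptance uniform come from a supplier function. If the supplier keeps returning the same pair, every accepted move goes in the same direction. Once that identical proposal is rejected, it is rejected forever:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <random>
#include <utility>
#include <vector>

// Toy target: Normal(3, 1) log-density up to a constant (hypothetical choice).
double log_target(double mu) { return -0.5 * (mu - 3.0) * (mu - 3.0); }

// Supplies (proposal increment, acceptance uniform) for one iteration.
using DrawFn = std::function<std::pair<double, double>()>;

// Random-walk Metropolis; records mu after every iteration.
std::vector<double> metropolis(double start, int iterations, const DrawFn& next_draws) {
    assert(iterations >= 0);
    std::vector<double> chain;
    chain.reserve(static_cast<std::size_t>(iterations));
    double mu = start;
    for (int i = 0; i < iterations; ++i) {
        const std::pair<double, double> d = next_draws();
        const double proposal = mu + d.first;
        // Accept with probability min(1, target(proposal) / target(mu)).
        if (std::log(d.second) < log_target(proposal) - log_target(mu)) {
            mu = proposal;
        }
        chain.push_back(mu);
    }
    return chain;
}

// Broken case: the RNG is rewound every iteration, so the same pair recurs.
DrawFn repeated_draws(double step, double u) {
    return [step, u]() { return std::make_pair(step, u); };
}

// Healthy case: fresh draws from a seeded generator on every iteration.
DrawFn fresh_draws(unsigned seed, double proposal_sd) {
    return [gen = std::mt19937(seed), proposal_sd]() mutable {
        std::normal_distribution<double> step_dist(0.0, proposal_sd);
        std::uniform_real_distribution<double> unif_dist(0.0, 1.0);
        const double step = step_dist(gen);
        const double u = unif_dist(gen);
        return std::make_pair(step, u);
    };
}
```

With `repeated_draws(0.5, 0.5)` and a start of -5, the broken chain climbs in steps of 0.5 past the mode at 3. It then stalls at 4.5 once the move to 5 is rejected. A negative step would make it fall instead, matching the "either monotonically increases or decreases" behaviour seen in the reprex.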

pwinskill commented on August 17, 2024

On further reading, I found an almost identical bug in an MCMC application, detailed here.

pwinskill commented on August 17, 2024

The branch mentioned above now seems to work for your simple reprex, @jameshay218. Would it be possible for you to check it on your full, more complex failing case at some point?

```r
remotes::install_github("mrc-ide/drjacoby@bug/rng")
```

pwinskill commented on August 17, 2024

For completeness, adding this link detailing the functions used in this fix: GetRNGstate() and PutRNGstate().
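
Here is one way the two calls might be arranged. This is an assumption for illustration, not a description of the actual patch. The sampler writes its RNG state back to R before calling the R-level likelihood and re-reads it afterwards. In R's C API that would be PutRNGstate() before the callback and GetRNGstate() after it. The self-contained C++ sketch below fakes both with a hypothetical toy state (`ToyRng`, `RngSyncGuard` and `run_synced` are made-up names) so the pattern can be compiled and checked on its own. It wraps the pair in an RAII guard so the re-read also happens if the callback throws:

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical toy stand-ins for R's RNG bookkeeping (NOT R's actual API).
struct ToyRng {
    std::uint64_t live_state = 0;   // advanced by draws (C level)
    std::uint64_t saved_seed = 0;   // what R code sees (like .Random.seed)
    void get_state() { live_state = saved_seed; }  // stand-in for GetRNGstate()
    void put_state() { saved_seed = live_state; }  // stand-in for PutRNGstate()
    std::uint64_t draw() {  // full-period 64-bit LCG step
        live_state = live_state * 6364136223846793005ULL + 1442695040888963407ULL;
        return live_state;
    }
};

// RAII guard: publish the state on construction, reload it on destruction.
class RngSyncGuard {
public:
    explicit RngSyncGuard(ToyRng& rng) : rng_(rng) { rng_.put_state(); }
    ~RngSyncGuard() { rng_.get_state(); }
    RngSyncGuard(const RngSyncGuard&) = delete;
    RngSyncGuard& operator=(const RngSyncGuard&) = delete;

private:
    ToyRng& rng_;
};

// Sampler loop: every call back into "R" happens inside a guard.
std::vector<std::uint64_t> run_synced(ToyRng& rng, int iterations,
                                      const std::function<void(ToyRng&)>& r_callback) {
    assert(iterations >= 0);
    std::vector<std::uint64_t> draws;
    rng.get_state();  // sampler entry
    for (int i = 0; i < iterations; ++i) {
        draws.push_back(rng.draw());
        RngSyncGuard guard(rng);  // state is saved before the callback...
        r_callback(rng);          // e.g. an R likelihood calling an Rcpp function
    }                             // ...and reloaded when the guard goes out of scope
    return draws;
}
```

With the guard in place, a callback that merely reloads and saves the state no longer rewinds the sampler. If the callback consumes random numbers itself, the sampler continues from where the callback left off instead of repeating values.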
