Skip to content

Commit

Permalink
Automatically init model methods, add inc_warmup argument to `uncon…
Browse files Browse the repository at this point in the history
…strain_draws()` method (#985)

* Automatically init model methods when not compiled

* Automatically init model methods, add inc_warmup argument to unconstrain_draws
  • Loading branch information
andrjohns authored May 24, 2024
1 parent 1679aa7 commit a79cc5e
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 105 deletions.
81 changes: 31 additions & 50 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -326,13 +326,12 @@ CmdStanFit$set("public", name = "init", value = init)
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
#' fit_mcmc$init_model_methods()
#' }
#' @seealso [log_prob()], [grad_log_prob()], [constrain_variables()],
#' [unconstrain_variables()], [unconstrain_draws()], [variable_skeleton()],
#' [hessian()]
#'
init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) {
init_model_methods <- function(seed = 1, verbose = FALSE, hessian = FALSE) {
if (os_is_wsl()) {
stop("Additional model methods are not currently available with ",
"WSL CmdStan and will not be compiled",
Expand All @@ -348,11 +347,12 @@ init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) {
"which is still experimental. Please report any compilation ",
"errors that you encounter")
}
message("Compiling additional model methods...")
if (is.null(private$model_methods_env_$model_ptr)) {
expose_model_methods(private$model_methods_env_, verbose, hessian)
}
initialize_model_pointer(private$model_methods_env_, self$data_file(), seed)
if (!("model_ptr_" %in% ls(private$model_methods_env_))) {
initialize_model_pointer(private$model_methods_env_, self$data_file(), seed)
}
invisible(NULL)
}
CmdStanFit$set("public", name = "init_model_methods", value = init_model_methods)
Expand All @@ -372,7 +372,6 @@ CmdStanFit$set("public", name = "init_model_methods", value = init_model_methods
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
#' fit_mcmc$init_model_methods()
#' fit_mcmc$log_prob(unconstrained_variables = c(0.5, 1.2, 1.1, 2.2))
#' }
#'
Expand All @@ -385,10 +384,7 @@ log_prob <- function(unconstrained_variables, jacobian = TRUE, jacobian_adjustme
warning("'jacobian_adjustment' is deprecated. Please use 'jacobian' instead.", call. = FALSE)
jacobian <- jacobian_adjustment
}
if (is.null(private$model_methods_env_$model_ptr)) {
stop("The method has not been compiled, please call `init_model_methods()` first",
call. = FALSE)
}
self$init_model_methods()
if (length(unconstrained_variables) != private$model_methods_env_$num_upars_) {
stop("Model has ", private$model_methods_env_$num_upars_, " unconstrained parameter(s), but ",
length(unconstrained_variables), " were provided!", call. = FALSE)
Expand All @@ -410,7 +406,6 @@ CmdStanFit$set("public", name = "log_prob", value = log_prob)
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
#' fit_mcmc$init_model_methods()
#' fit_mcmc$grad_log_prob(unconstrained_variables = c(0.5, 1.2, 1.1, 2.2))
#' }
#'
Expand All @@ -423,10 +418,7 @@ grad_log_prob <- function(unconstrained_variables, jacobian = TRUE, jacobian_adj
warning("'jacobian_adjustment' is deprecated. Please use 'jacobian' instead.", call. = FALSE)
jacobian <- jacobian_adjustment
}
if (is.null(private$model_methods_env_$model_ptr)) {
stop("The method has not been compiled, please call `init_model_methods()` first",
call. = FALSE)
}
self$init_model_methods()
if (length(unconstrained_variables) != private$model_methods_env_$num_upars_) {
stop("Model has ", private$model_methods_env_$num_upars_, " unconstrained parameter(s), but ",
length(unconstrained_variables), " were provided!", call. = FALSE)
Expand Down Expand Up @@ -461,10 +453,7 @@ hessian <- function(unconstrained_variables, jacobian = TRUE, jacobian_adjustmen
warning("'jacobian_adjustment' is deprecated. Please use 'jacobian' instead.", call. = FALSE)
jacobian <- jacobian_adjustment
}
if (is.null(private$model_methods_env_$model_ptr)) {
stop("The method has not been compiled, please call `init_model_methods()` first",
call. = FALSE)
}
self$init_model_methods()
if (length(unconstrained_variables) != private$model_methods_env_$num_upars_) {
stop("Model has ", private$model_methods_env_$num_upars_, " unconstrained parameter(s), but ",
length(unconstrained_variables), " were provided!", call. = FALSE)
Expand All @@ -487,7 +476,6 @@ CmdStanFit$set("public", name = "hessian", value = hessian)
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
#' fit_mcmc$init_model_methods()
#' fit_mcmc$unconstrain_variables(list(alpha = 0.5, beta = c(0.7, 1.1, 0.2)))
#' }
#'
Expand All @@ -496,10 +484,7 @@ CmdStanFit$set("public", name = "hessian", value = hessian)
#' [hessian()]
#'
unconstrain_variables <- function(variables) {
if (is.null(private$model_methods_env_$model_ptr)) {
stop("The method has not been compiled, please call `init_model_methods()` first",
call. = FALSE)
}
self$init_model_methods()
model_par_names <- self$metadata()$stan_variables[self$metadata()$stan_variables != "lp__"]
prov_par_names <- names(variables)

Expand Down Expand Up @@ -539,11 +524,12 @@ CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_var
#' @param draws A `posterior::draws_*` object.
#' @param format (string) The format of the returned draws. Must be a valid
#' format from the \pkg{posterior} package.
#' @param inc_warmup (logical) Should warmup draws be included? Defaults to
#' `FALSE`.
#'
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
#' fit_mcmc$init_model_methods()
#'
#' # Unconstrain all internal draws
#' unconstrained_internal_draws <- fit_mcmc$unconstrain_draws()
Expand All @@ -560,7 +546,9 @@ CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_var
#' [hessian()]
#'
unconstrain_draws <- function(files = NULL, draws = NULL,
format = getOption("cmdstanr_draws_format", "draws_array")) {
format = getOption("cmdstanr_draws_format", "draws_array"),
inc_warmup = FALSE) {
self$init_model_methods()
if (!(format %in% valid_draws_formats())) {
stop("Invalid draws format requested!", call. = FALSE)
}
Expand All @@ -570,22 +558,25 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
call. = FALSE)
}
if (!is.null(files)) {
read_csv <- read_cmdstan_csv(files = files, format = "draws_matrix")
draws <- read_csv$post_warmup_draws
}
if (!is.null(draws)) {
draws <- maybe_convert_draws_format(draws, "draws_matrix")
}
} else {
if (is.null(private$draws_)) {
if (!length(self$output_files(include_failed = FALSE))) {
stop("Fitting failed. Unable to retrieve the draws.", call. = FALSE)
read_csv <- read_cmdstan_csv(files = files)
if (inc_warmup) {
draws <- posterior::bind_draws(read_csv$warmup_draws,
read_csv$post_warmup_draws,
along = "iteration")
} else {
draws <- read_csv$post_warmup_draws
}
} else if (!is.null(draws)) {
if (inc_warmup) {
message("'inc_warmup' cannot be used with a draws object. Ignoring.")
}
private$read_csv_(format = "draws_df")
}
draws <- maybe_convert_draws_format(private$draws_, "draws_matrix")
} else {
draws <- self$draws(inc_warmup = inc_warmup)
}

draws <- maybe_convert_draws_format(draws, "draws_matrix")

chains <- posterior::nchains(draws)

model_par_names <- self$metadata()$stan_variables[self$metadata()$stan_variables != "lp__"]
Expand Down Expand Up @@ -624,7 +615,6 @@ CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws)
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
#' fit_mcmc$init_model_methods()
#' fit_mcmc$variable_skeleton()
#' }
#'
Expand All @@ -633,11 +623,7 @@ CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws)
#' [hessian()]
#'
variable_skeleton <- function(transformed_parameters = TRUE, generated_quantities = TRUE) {
if (is.null(private$model_methods_env_$model_ptr)) {
stop("The method has not been compiled, please call `init_model_methods()` first",
call. = FALSE)
}

self$init_model_methods()
create_skeleton(private$model_methods_env_$param_metadata_,
self$runset$args$model_variables,
transformed_parameters,
Expand All @@ -662,7 +648,6 @@ CmdStanFit$set("public", name = "variable_skeleton", value = variable_skeleton)
#' @examples
#' \dontrun{
#' fit_mcmc <- cmdstanr_example("logistic", method = "sample", force_recompile = TRUE)
#' fit_mcmc$init_model_methods()
#' fit_mcmc$constrain_variables(unconstrained_variables = c(0.5, 1.2, 1.1, 2.2))
#' }
#'
Expand All @@ -671,12 +656,8 @@ CmdStanFit$set("public", name = "variable_skeleton", value = variable_skeleton)
#' [hessian()]
#'
constrain_variables <- function(unconstrained_variables, transformed_parameters = TRUE,
generated_quantities = TRUE) {
if (is.null(private$model_methods_env_$model_ptr)) {
stop("The method has not been compiled, please call `init_model_methods()` first",
call. = FALSE)
}

generated_quantities = TRUE) {
self$init_model_methods()
skeleton <- self$variable_skeleton(transformed_parameters, generated_quantities)

if (length(unconstrained_variables) != private$model_methods_env_$num_upars_) {
Expand Down
7 changes: 6 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,9 @@ rcpp_source_stan <- function(code, env, verbose = FALSE, ...) {
}

expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
if (rlang::is_interactive()) {
message("Compiling additional model methods...")
}
code <- c(env$hpp_code_,
readLines(system.file("include", "model_methods.cpp",
package = "cmdstanr", mustWork = TRUE)))
Expand Down Expand Up @@ -1034,7 +1037,9 @@ expose_stan_functions <- function(function_env, global = FALSE, verbose = FALSE)
})
}
} else {
message("Compiling standalone functions...")
if (rlang::is_interactive()) {
message("Compiling standalone functions...")
}
compile_functions(function_env, verbose, global)
}
invisible(NULL)
Expand Down
1 change: 0 additions & 1 deletion man/fit-method-constrain_variables.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/fit-method-grad_log_prob.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions man/fit-method-init_model_methods.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/fit-method-log_prob.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions man/fit-method-unconstrain_draws.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/fit-method-unconstrain_variables.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/fit-method-variable_skeleton.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit a79cc5e

Please sign in to comment.