From 0fec18f36fe5de4ed6f3dd212c26a767bf9c630f Mon Sep 17 00:00:00 2001 From: jgabry Date: Sat, 20 Jun 2020 15:39:57 -0400 Subject: [PATCH 1/4] add extractor methods for CmdStanMCMC objects --- NAMESPACE | 4 ++ R/bayesplot-extractors.R | 69 +++++++++++++++++-- R/mcmc-diagnostics-nuts.R | 8 +-- man/bayesplot-extractors.Rd | 18 ++++- tests/testthat/test-extractors.R | 31 ++++++++- tests/testthat/test-mcmc-nuts.R | 4 +- .../testthat/test-mcmc-scatter-and-parcoord.R | 2 +- 7 files changed, 118 insertions(+), 18 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 2f5505cf..5bf2b8bd 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,10 +2,13 @@ S3method("[",neff_ratio) S3method("[",rhat) +S3method(log_posterior,CmdStanMCMC) S3method(log_posterior,stanfit) S3method(log_posterior,stanreg) +S3method(neff_ratio,CmdStanMCMC) S3method(neff_ratio,stanfit) S3method(neff_ratio,stanreg) +S3method(nuts_params,CmdStanMCMC) S3method(nuts_params,list) S3method(nuts_params,stanfit) S3method(nuts_params,stanreg) @@ -15,6 +18,7 @@ S3method(pp_check,default) S3method(print,bayesplot_function_list) S3method(print,bayesplot_grid) S3method(print,bayesplot_scheme) +S3method(rhat,CmdStanMCMC) S3method(rhat,stanfit) S3method(rhat,stanreg) export(abline_01) diff --git a/R/bayesplot-extractors.R b/R/bayesplot-extractors.R index 72128b74..7fd81004 100644 --- a/R/bayesplot-extractors.R +++ b/R/bayesplot-extractors.R @@ -1,9 +1,9 @@ #' Extract quantities needed for plotting from model objects #' #' Generics and methods for extracting quantities needed for plotting from -#' various types of model objects. Currently methods are only provided for -#' stanfit (**rstan**) and stanreg (**rstanarm**) objects, but adding new -#' methods should be relatively straightforward. +#' various types of model objects. Currently methods are provided for stanfit +#' (**rstan**), CmdStanMCMC (**cmdstanr**), and stanreg (**rstanarm**) objects, +#' but adding new methods should be relatively straightforward. #' #' @name bayesplot-extractors #' @param object The object to use. @@ -87,7 +87,8 @@ log_posterior.stanfit <- function(object, inc_warmup = FALSE, ...) { ...) lp <- lapply(lp, as.array) lp <- set_names(reshape2::melt(lp), c("Iteration", "Value", "Chain")) - validate_df_classes(lp, c("integer", "numeric", "integer")) + validate_df_classes(lp[, c("Chain", "Iteration", "Value")], + c("integer", "integer", "numeric")) } #' @rdname bayesplot-extractors @@ -98,11 +99,22 @@ log_posterior.stanreg <- function(object, inc_warmup = FALSE, ...) { log_posterior.stanfit(object$stanfit, inc_warmup = inc_warmup, ...) } +#' @rdname bayesplot-extractors +#' @export +#' @method log_posterior CmdStanMCMC +log_posterior.CmdStanMCMC <- function(object, inc_warmup = FALSE, ...) { + lp <- object$draws("lp__", inc_warmup = inc_warmup) + lp <- reshape2::melt(lp) + lp$variable <- NULL + lp <- dplyr::rename_with(lp, capitalize_first) + validate_df_classes(lp[, c("Chain", "Iteration", "Value")], + c("integer", "integer", "numeric")) +} + #' @rdname bayesplot-extractors #' @export #' @method nuts_params stanfit -#' nuts_params.stanfit <- function(object, pars = NULL, @@ -153,7 +165,23 @@ nuts_params.list <- function(object, pars = NULL, ...) { out <- reshape2::melt(object) out <- set_names(out, c("Iteration", "Parameter", "Value", "Chain")) - validate_df_classes(out, c("integer", "factor", "numeric", "integer")) + validate_df_classes(out[, c("Chain", "Iteration", "Parameter", "Value")], + c("integer", "integer", "factor", "numeric")) +} + +#' @rdname bayesplot-extractors +#' @export +#' @method nuts_params CmdStanMCMC +nuts_params.CmdStanMCMC <- function(object, pars = NULL, ...) { + arr <- object$sampler_diagnostics() + if (!is.null(pars)) { + arr <- arr[,, pars] + } + out <- reshape2::melt(arr) + colnames(out)[colnames(out) == "variable"] <- "parameter" + out <- dplyr::rename_with(out, capitalize_first) + validate_df_classes(out[, c("Chain", "Iteration", "Parameter", "Value")], + c("integer", "integer", "factor", "numeric")) } @@ -188,6 +216,16 @@ rhat.stanreg <- function(object, pars = NULL, regex_pars = NULL, ...) { r[!names(r) %in% c("mean_PPD", "log-posterior")] } +#' @rdname bayesplot-extractors +#' @export +#' @method rhat CmdStanMCMC +rhat.CmdStanMCMC <- function(object, pars = NULL, ...) { + s <- object$summary(pars, rhat = ~ posterior::rhat(.x))[, c("variable", "rhat")] + r <- setNames(s$rhat, s$variable) + r <- validate_rhat(r) + r[!names(r) %in% "lp__"] +} + #' @rdname bayesplot-extractors #' @export @@ -223,6 +261,18 @@ neff_ratio.stanreg <- function(object, pars = NULL, regex_pars = NULL, ...) { ratio[!names(ratio) %in% c("mean_PPD", "log-posterior")] } +#' @rdname bayesplot-extractors +#' @export +#' @method neff_ratio CmdStanMCMC +neff_ratio.CmdStanMCMC <- function(object, pars = NULL, ...) { + s <- object$summary(pars, "n_eff" = "ess_basic")[, c("variable", "n_eff")] + ess <- setNames(s$n_eff, s$variable) + tss <- prod(dim(object$draws())[1:2]) + ratio <- ess / tss + ratio <- validate_neff_ratio(ratio) + ratio[!names(ratio) %in% "lp__"] +} + # internals --------------------------------------------------------------- @@ -245,3 +295,10 @@ validate_df_classes <- function(x, classes = character()) { } x } + +# capitalize first letter in a string only +capitalize_first <- function(name) { + name <- tolower(name) # in case whole string is capitalized + substr(name, 1, 1) <- toupper(substr(name, 1, 1)) + name +} diff --git a/R/mcmc-diagnostics-nuts.R b/R/mcmc-diagnostics-nuts.R index 766826b3..cbaccdfa 100644 --- a/R/mcmc-diagnostics-nuts.R +++ b/R/mcmc-diagnostics-nuts.R @@ -513,8 +513,8 @@ validate_nuts_data_frame <- function(x, lp) { abort("NUTS parameters should be in a data frame.") } - valid_cols <- c("Iteration", "Parameter", "Value", "Chain") - if (!identical(colnames(x), valid_cols)) { + valid_cols <- sort(c("Iteration", "Parameter", "Value", "Chain")) + if (!identical(sort(colnames(x)), valid_cols)) { abort(paste( "NUTS parameter data frame must have columns:", paste(valid_cols, collapse = ", ") @@ -529,8 +529,8 @@ validate_nuts_data_frame <- function(x, lp) { abort("lp should be in a data frame.") } - valid_lp_cols <- c("Iteration", "Value", "Chain") - if (!identical(colnames(lp), valid_lp_cols)) { + valid_lp_cols <- sort(c("Iteration", "Value", "Chain")) + if (!identical(sort(colnames(lp)), valid_lp_cols)) { abort(paste( "lp data frame must have columns:", paste(valid_lp_cols, collapse = ", ") diff --git a/man/bayesplot-extractors.Rd b/man/bayesplot-extractors.Rd index bc96a18c..593a2b0e 100644 --- a/man/bayesplot-extractors.Rd +++ b/man/bayesplot-extractors.Rd @@ -8,13 +8,17 @@ \alias{neff_ratio} \alias{log_posterior.stanfit} \alias{log_posterior.stanreg} +\alias{log_posterior.CmdStanMCMC} \alias{nuts_params.stanfit} \alias{nuts_params.stanreg} \alias{nuts_params.list} +\alias{nuts_params.CmdStanMCMC} \alias{rhat.stanfit} \alias{rhat.stanreg} +\alias{rhat.CmdStanMCMC} \alias{neff_ratio.stanfit} \alias{neff_ratio.stanreg} +\alias{neff_ratio.CmdStanMCMC} \title{Extract quantities needed for plotting from model objects} \usage{ log_posterior(object, ...) @@ -29,19 +33,27 @@ neff_ratio(object, ...) \method{log_posterior}{stanreg}(object, inc_warmup = FALSE, ...) +\method{log_posterior}{CmdStanMCMC}(object, inc_warmup = FALSE, ...) + \method{nuts_params}{stanfit}(object, pars = NULL, inc_warmup = FALSE, ...) \method{nuts_params}{stanreg}(object, pars = NULL, inc_warmup = FALSE, ...) \method{nuts_params}{list}(object, pars = NULL, ...) +\method{nuts_params}{CmdStanMCMC}(object, pars = NULL, ...) + \method{rhat}{stanfit}(object, pars = NULL, ...) \method{rhat}{stanreg}(object, pars = NULL, regex_pars = NULL, ...) +\method{rhat}{CmdStanMCMC}(object, pars = NULL, ...) + \method{neff_ratio}{stanfit}(object, pars = NULL, ...) \method{neff_ratio}{stanreg}(object, pars = NULL, regex_pars = NULL, ...) + +\method{neff_ratio}{CmdStanMCMC}(object, pars = NULL, ...) } \arguments{ \item{object}{The object to use.} @@ -80,9 +92,9 @@ Methods return (named) vectors. } \description{ Generics and methods for extracting quantities needed for plotting from -various types of model objects. Currently methods are only provided for -stanfit (\strong{rstan}) and stanreg (\strong{rstanarm}) objects, but adding new -methods should be relatively straightforward. +various types of model objects. Currently methods are provided for stanfit +(\strong{rstan}), CmdStanMCMC (\strong{cmdstanr}), and stanreg (\strong{rstanarm}) objects, +but adding new methods should be relatively straightforward. } \examples{ \dontrun{ diff --git a/tests/testthat/test-extractors.R b/tests/testthat/test-extractors.R index 302defc1..a57f7b23 100644 --- a/tests/testthat/test-extractors.R +++ b/tests/testthat/test-extractors.R @@ -42,7 +42,7 @@ test_that("all nuts_params methods identical", { test_that("nuts_params.stanreg returns correct structure", { np <- nuts_params(fit) - expect_identical(colnames(np), c("Iteration", "Parameter", "Value", "Chain")) + expect_identical(colnames(np), c("Chain", "Iteration", "Parameter", "Value")) np_names <- paste0(c("accept_stat", "stepsize", "treedepth", "n_leapfrog", "divergent", "energy"), "__") @@ -54,7 +54,7 @@ test_that("nuts_params.stanreg returns correct structure", { test_that("log_posterior.stanreg returns correct structure", { lp <- log_posterior(fit) - expect_identical(colnames(lp), c("Iteration", "Value", "Chain")) + expect_identical(colnames(lp), c("Chain", "Iteration", "Value")) expect_equal(length(unique(lp$Iteration)), floor(ITER / 2)) expect_equal(length(unique(lp$Chain)), CHAINS) }) @@ -100,3 +100,30 @@ test_that("neff_ratio.stanreg returns correct structure", { ans2 <- summary(fit, pars = c("wt", "sigma"))[, "n_eff"] / denom expect_equal(ratio2, ans2, tol = 0.001) }) + +test_that("cmdstanr method work", { + skip_on_cran() + skip_if_not_installed("cmdstanr") + + fit <- cmdstanr::cmdstanr_example("logistic", iter_sampling = 500, chains = 2) + np <- nuts_params(fit) + np_names <- paste0(c("accept_stat", "stepsize", "treedepth", "n_leapfrog", + "divergent", "energy"), "__") + expect_identical(levels(np$Parameter), np_names) + expect_equal(range(np$Iteration), c(1, 500)) + expect_equal(range(np$Chain), c(1, 2)) + expect_true(all(np$Value[np$Parameter == "divergent__"] == 0)) + + lp <- log_posterior(fit) + expect_named(lp, c("Iteration", "Chain", "Value")) + expect_equal(range(np$Chain), c(1, 2)) + expect_equal(range(np$Iteration), c(1, 500)) + + r <- rhat(fit) + expect_named(r, c("alpha", "beta[1]", "beta[2]", "beta[3]")) + expect_true(all(round(r) == 1)) + + ratio <- neff_ratio(fit) + expect_named(ratio, c("alpha", "beta[1]", "beta[2]", "beta[3]")) + expect_true(all(ratio < 1) && all(ratio > 0)) +}) diff --git a/tests/testthat/test-mcmc-nuts.R b/tests/testthat/test-mcmc-nuts.R index 16b4aa14..7b19e8a3 100644 --- a/tests/testthat/test-mcmc-nuts.R +++ b/tests/testthat/test-mcmc-nuts.R @@ -58,7 +58,7 @@ test_that("validate_nuts_data_frame throws errors", { ) expect_error( validate_nuts_data_frame(data.frame(Iteration = 1, apple = 2)), - "NUTS parameter data frame must have columns: Iteration, Parameter, Value, Chain" + "NUTS parameter data frame must have columns: Chain, Iteration, Parameter, Value" ) expect_error( validate_nuts_data_frame(np, as.matrix(lp)), @@ -69,7 +69,7 @@ test_that("validate_nuts_data_frame throws errors", { colnames(lp2)[3] <- "Chains" expect_error( validate_nuts_data_frame(np, lp2), - "lp data frame must have columns: Iteration, Value, Chain" + "lp data frame must have columns: Chain, Iteration, Value" ) lp2 <- subset(lp, Chain %in% 1:2) diff --git a/tests/testthat/test-mcmc-scatter-and-parcoord.R b/tests/testthat/test-mcmc-scatter-and-parcoord.R index af7a3cd4..fe22b6ee 100644 --- a/tests/testthat/test-mcmc-scatter-and-parcoord.R +++ b/tests/testthat/test-mcmc-scatter-and-parcoord.R @@ -311,7 +311,7 @@ test_that("mcmc_parcoord throws correct warnings and errors", { expect_error( mcmc_parcoord(post, np = np[, -1]), - "NUTS parameter data frame must have columns: Iteration, Parameter, Value, Chain", + "NUTS parameter data frame must have columns: Chain, Iteration, Parameter, Value", fixed = TRUE ) From d0ad9c35633fcdbe467622434f3f5f7411b3c5cd Mon Sep 17 00:00:00 2001 From: jgabry Date: Sat, 20 Jun 2020 17:40:22 -0400 Subject: [PATCH 2/4] fix test --- R/bayesplot-extractors.R | 3 ++- tests/testthat/test-extractors.R | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/R/bayesplot-extractors.R b/R/bayesplot-extractors.R index 7fd81004..60067500 100644 --- a/R/bayesplot-extractors.R +++ b/R/bayesplot-extractors.R @@ -220,7 +220,8 @@ rhat.stanreg <- function(object, pars = NULL, regex_pars = NULL, ...) { #' @export #' @method rhat CmdStanMCMC rhat.CmdStanMCMC <- function(object, pars = NULL, ...) { - s <- object$summary(pars, rhat = ~ posterior::rhat(.x))[, c("variable", "rhat")] + .rhat <- getFromNamespace("rhat", "posterior") + s <- object$summary(pars, rhat = .rhat)[, c("variable", "rhat")] r <- setNames(s$rhat, s$variable) r <- validate_rhat(r) r[!names(r) %in% "lp__"] diff --git a/tests/testthat/test-extractors.R b/tests/testthat/test-extractors.R index a57f7b23..022821c8 100644 --- a/tests/testthat/test-extractors.R +++ b/tests/testthat/test-extractors.R @@ -101,7 +101,7 @@ test_that("neff_ratio.stanreg returns correct structure", { expect_equal(ratio2, ans2, tol = 0.001) }) -test_that("cmdstanr method work", { +test_that("cmdstanr methods work", { skip_on_cran() skip_if_not_installed("cmdstanr") @@ -115,7 +115,7 @@ test_that("cmdstanr method work", { expect_true(all(np$Value[np$Parameter == "divergent__"] == 0)) lp <- log_posterior(fit) - expect_named(lp, c("Iteration", "Chain", "Value")) + expect_named(lp, c("Chain", "Iteration", "Value")) expect_equal(range(np$Chain), c(1, 2)) expect_equal(range(np$Iteration), c(1, 500)) @@ -125,5 +125,5 @@ test_that("cmdstanr method work", { ratio <- neff_ratio(fit) expect_named(ratio, c("alpha", "beta[1]", "beta[2]", "beta[3]")) - expect_true(all(ratio < 1) && all(ratio > 0)) + expect_true(all(ratio > 0)) }) From cf379ee5f2305a67d455b3944e46b8f9e592f282 Mon Sep 17 00:00:00 2001 From: jgabry Date: Sat, 20 Jun 2020 18:10:40 -0400 Subject: [PATCH 3/4] Update bayesplot-extractors.R --- R/bayesplot-extractors.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/bayesplot-extractors.R b/R/bayesplot-extractors.R index 60067500..19c6df88 100644 --- a/R/bayesplot-extractors.R +++ b/R/bayesplot-extractors.R @@ -220,7 +220,7 @@ rhat.stanreg <- function(object, pars = NULL, regex_pars = NULL, ...) { #' @export #' @method rhat CmdStanMCMC rhat.CmdStanMCMC <- function(object, pars = NULL, ...) { - .rhat <- getFromNamespace("rhat", "posterior") + .rhat <- utils::getFromNamespace("rhat", "posterior") s <- object$summary(pars, rhat = .rhat)[, c("variable", "rhat")] r <- setNames(s$rhat, s$variable) r <- validate_rhat(r) From 376acccbf5d8127fe3010c9054bbd41b55b38d04 Mon Sep 17 00:00:00 2001 From: jgabry Date: Fri, 7 Aug 2020 09:36:31 -0600 Subject: [PATCH 4/4] Update NEWS.md --- NEWS.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/NEWS.md b/NEWS.md index 4bf6d25a..51c32232 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,6 +8,10 @@ * Items for next release go here --> +* CmdStanMCMC objects (from CmdStanR) can now be used with extractor + functions `nuts_params()`, `log_posterior()`, `rhat()`, and + `neff_ratio()`. (#227) + * Added missing `facet_args` argument to `mcmc_rank_overlay()`. (#221, @hhau)