From 19d2ff7c4ede60fe07cd08d754cd4f6965b91c57 Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 17 Jan 2024 17:28:32 +0200 Subject: [PATCH 01/23] updating pareto functions for weighted rvars --- R/pareto_smooth.R | 109 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 88 insertions(+), 21 deletions(-) diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index 7da95a75..295b2ee2 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -46,21 +46,57 @@ pareto_khat.default <- function(x, #' @rdname pareto_khat #' @export -pareto_khat.rvar <- function(x, ...) { - draws_diags <- summarise_rvar_by_element_with_chains( - x, - pareto_smooth.default, - return_k = TRUE, - smooth_draws = FALSE, - ... - ) - dim(draws_diags) <- dim(draws_diags) %||% length(draws_diags) - margins <- seq_along(dim(draws_diags)) +pareto_khat.rvar <- function(x, verbose = FALSE, ...) { + if (is.null(weights(x))) { + draws_diags <- summarise_rvar_by_element_with_chains( + x, + pareto_smooth.default, + return_k = TRUE, + smooth_draws = FALSE, + verbose = verbose, + ... + ) - diags <- list( - khat = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat) - ) + dim(draws_diags) <- dim(draws_diags) %||% length(draws_diags) + margins <- seq_along(dim(draws_diags)) + + diags <- list( + khat = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat) + ) + } else { + + # take the max of khat for x * weights and khat for weights + + weights_diags <- pareto_khat( + weights(x, log = TRUE), + are_log_weights = TRUE, + ... + ) + + w <- weights(x) + + x <- weight_draws(x, NULL) + product_diags <- summarise_rvar_by_element_with_chains( + x * rvar(w, nchains = nchains(x)), + pareto_khat, + ... + ) + + print(weights_diags) + print(product_diags) + + dim(product_diags) <- dim(product_diags) %||% length(product_diags) + margins <- seq_along(dim(product_diags)) + + diags <- list( + khat = apply(product_diags, margins, + function(x) { + max(x[[1]]$khat, + weights_diags$khat) + }) + ) + } diags } @@ -138,7 +174,7 @@ pareto_diags.default <- function(x, extra_diags = TRUE, verbose = verbose, smooth_draws = FALSE, - are_log_weights = FALSE, + are_log_weights = are_log_weights, ...) return(smoothed$diagnostics) @@ -149,6 +185,8 @@ pareto_diags.default <- function(x, #' @rdname pareto_diags #' @export pareto_diags.rvar <- function(x, ...) { + + if (is.null(weights(x))) { draws_diags <- summarise_rvar_by_element_with_chains( x, pareto_smooth.default, @@ -167,6 +205,35 @@ pareto_diags.rvar <- function(x, ...) { khat_threshold = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat_threshold), convergence_rate = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$convergence_rate) ) + } else { + + # take the max of khat for x * weights and khat for weights + + weights_diags <- pareto_diags( + weights(x, log = TRUE), + are_log_weights = TRUE, + ... + ) + + w <- weights(x) + + x <- weight_draws(x, NULL) + product_diags <- summarise_rvar_by_element_with_chains( + x * rvar(w, nchains = nchains(x)), + pareto_diags, + ... + ) + + dim(product_diags) <- dim(product_diags) %||% length(product_diags) + margins <- seq_along(dim(product_diags)) + + diags <- list( + khat = apply(product_diags, margins, function(x) max(x[[1]]$khat, weights_diags$khat)), + min_ss = apply(product_diags, margins, function(x) max(x[[1]]$min_ss, weights_diags$min_ss)), + khat_threshold = apply(product_diags, margins, function(x) max(x[[1]]$khat_threshold, weights_diags$khat_threshold)), + convergence_rate = apply(product_diags, margins, function(x) min(x[[1]]$convergence_rate, weights_diags$convergence_rate)) + ) + } diags } @@ -279,7 +346,7 @@ pareto_smooth.default <- function(x, if (are_log_weights) { tail <- "right" } - + tail <- match.arg(tail) S <- length(x) @@ -330,7 +397,7 @@ pareto_smooth.default <- function(x, k <- max(left_k, right_k) x <- smoothed$x - + } else { smoothed <- .pareto_smooth_tail( @@ -443,7 +510,7 @@ pareto_convergence_rate.rvar <- function(x, ...) { # shift log values for safe exponentiation x <- x - max(x) } - + tail <- match.arg(tail) S <- length(x) @@ -457,10 +524,10 @@ pareto_convergence_rate.rvar <- function(x, ...) { draws_tail <- ord$x[tail_ids] cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values - + max_tail <- max(draws_tail) min_tail <- min(draws_tail) - + if (ndraws_tail >= 5) { ord <- sort.int(x, index.return = TRUE) if (abs(max_tail - min_tail) < .Machine$double.eps / 100) { @@ -616,7 +683,7 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { msg <- NULL if (!are_weights) { - + if (khat > 1) { msg <- paste0(msg, " Mean does not exist, making empirical mean estimate of the draws not applicable.") } else { @@ -629,7 +696,7 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { } } else { if (khat > khat_threshold || khat > 0.7) { - msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") + msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") } } message("Pareto k-hat = ", round(khat, 2), ".", msg) From cb44b4a487be241233a934bd6aed546ce866065c Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 28 Feb 2024 11:39:33 +0200 Subject: [PATCH 02/23] start on weighted convergence --- DESCRIPTION | 2 +- R/convergence.R | 30 +++++++++++++++++++++++------- R/summarise_draws.R | 2 +- man/posterior-package.Rd | 30 ++++++++++++++++++++++++++++++ 4 files changed, 55 insertions(+), 9 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 6efe8639..ce462430 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -56,5 +56,5 @@ LazyData: false URL: https://mc-stan.org/posterior/, https://discourse.mc-stan.org/ BugReports: https://github.com/stan-dev/posterior/issues Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.0 VignetteBuilder: knitr diff --git a/R/convergence.R b/R/convergence.R index cf895f07..be5b1b1a 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -182,7 +182,7 @@ ess_bulk <- function(x, ...) UseMethod("ess_bulk") #' @rdname ess_bulk #' @export -ess_bulk.default <- function(x, ...) { +ess_bulk.default <- function(x, weights = NULL, ...) { .ess(z_scale(.split_chains(x))) } @@ -319,14 +319,19 @@ ess_mean <- function(x, ...) UseMethod("ess_mean") #' @rdname ess_quantile #' @export -ess_mean.default <- function(x, ...) { - .ess(.split_chains(x)) +ess_mean.default <- function(x, weights = NULL, ...) { + + if (is.null(weights)) { + .ess(.split_chains(x)) + } else { + .ess(.split_chains(x)) * (1 / sum(weights^2)) / (NROW(x) * NCOL(x)) + } } #' @rdname ess_mean #' @export ess_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_mean, ...) + summarise_rvar_by_element_with_chains(x, ess_mean, weights = weights(x), ...) } #' Effective sample size for the standard deviation @@ -449,14 +454,18 @@ mcse_mean <- function(x, ...) UseMethod("mcse_mean") #' @rdname mcse_mean #' @export -mcse_mean.default <- function(x, ...) { - sd(x) / sqrt(ess_mean(x)) +mcse_mean.default <- function(x, weights = NULL, ...) { + if (is.null(weights)) { + sd(x) / sqrt(ess_mean(x)) + } else { + .mcse_mean_weighted(x, weights, ...) + } } #' @rdname mcse_mean #' @export mcse_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, mcse_mean, ...) + summarise_rvar_by_element_with_chains(x, mcse_mean, weights = weights(x), ...) } #' Monte Carlo standard error for the standard deviation @@ -782,6 +791,12 @@ fold_draws <- function(x) { ess } +.mcse_mean_weighted <- function(x, weights, r_eff = 1, ...) { + # Vehtari et al. 2022 equation 6 + weighted_mean <- matrixStats::weightedMean(x, w = weights) + weights^2 %*% (x - c(weighted_mean))^2 / r_eff +} + # should NA be returned by a convergence diagnostic? should_return_NA <- function(x, tol = .Machine$double.eps) { if (anyNA(x) || checkmate::anyInfinite(x)) { @@ -801,3 +816,4 @@ should_return_NA <- function(x, tol = .Machine$double.eps) { # } is_constant(x, tol = tol) } + diff --git a/R/summarise_draws.R b/R/summarise_draws.R index 6f13755a..05acd83b 100644 --- a/R/summarise_draws.R +++ b/R/summarise_draws.R @@ -329,9 +329,9 @@ empty_draws_summary <- function(dimensions = "variable") { create_summary_list <- function(x, v, funs, .args) { draws <- drop_dims_or_classes(x[, , v], dims = 3, reset_class = FALSE) - args <- c(list(draws), .args) v_summary <- named_list(names(funs)) for (m in names(funs)) { + args <- c(list(draws), .args[[m]]) v_summary[[m]] <- do.call(funs[[m]], args) } v_summary diff --git a/man/posterior-package.Rd b/man/posterior-package.Rd index 0f933f84..5396bf16 100644 --- a/man/posterior-package.Rd +++ b/man/posterior-package.Rd @@ -73,3 +73,33 @@ causes a warning can be controlled by this option. } } +\seealso{ +Useful links: +\itemize{ + \item \url{https://mc-stan.org/posterior/} + \item \url{https://discourse.mc-stan.org/} + \item Report bugs at \url{https://github.com/stan-dev/posterior/issues} +} + +} +\author{ +\strong{Maintainer}: Paul-Christian Bürkner \email{paul.buerkner@gmail.com} + +Authors: +\itemize{ + \item Jonah Gabry \email{jsg2201@columbia.edu} + \item Matthew Kay \email{mjskay@northwestern.edu} + \item Aki Vehtari \email{Aki.Vehtari@aalto.fi} +} + +Other contributors: +\itemize{ + \item Måns Magnusson [contributor] + \item Rok Češnovar [contributor] + \item Ben Lambert [contributor] + \item Ozan Adıgüzel [contributor] + \item Jacob Socolar [contributor] + \item Noa Kallioinen [contributor] +} + +} From 165fb4fd35fafdbe678a83138e2ec3365a1e52fc Mon Sep 17 00:00:00 2001 From: n-kall Date: Thu, 7 Mar 2024 11:59:53 +0200 Subject: [PATCH 03/23] improvements to weighted ess, mcse --- R/convergence.R | 108 ++++++++++++++++++++++++++++++++++++---------- R/pareto_smooth.R | 5 --- 2 files changed, 85 insertions(+), 28 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index 270020b5..c3b6e348 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -183,13 +183,18 @@ ess_bulk <- function(x, ...) UseMethod("ess_bulk") #' @rdname ess_bulk #' @export ess_bulk.default <- function(x, weights = NULL, ...) { - .ess(z_scale(.split_chains(x))) + if (is.null(weights)) { + .ess(z_scale(.split_chains(x))) + } else { + .ess_weighted(x, weights, ...) + } } #' @rdname ess_bulk #' @export ess_bulk.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_bulk, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_bulk, weights = weights, ...) } #' Tail effective sample size (tail-ESS) @@ -220,16 +225,17 @@ ess_tail <- function(x, ...) UseMethod("ess_tail") #' @rdname ess_tail #' @export -ess_tail.default <- function(x, ...) { - q05_ess <- ess_quantile(x, 0.05) - q95_ess <- ess_quantile(x, 0.95) +ess_tail.default <- function(x, weights = NULL, ...) { + q05_ess <- ess_quantile(x, 0.05, weights = weights, ...) + q95_ess <- ess_quantile(x, 0.95, weights = weights, ...) min(q05_ess, q95_ess) } #' @rdname ess_tail #' @export ess_tail.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_tail, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_tail, weights = weights, ...) } #' Effective sample sizes for quantiles @@ -258,13 +264,17 @@ ess_quantile <- function(x, probs = c(0.05, 0.95), ...) { #' @rdname ess_quantile #' @export -ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { +ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights = NULL, ...) { probs <- as.numeric(probs) if (any(probs < 0 | probs > 1)) { stop_no_call("'probs' must contain values between 0 and 1.") } names <- as_one_logical(names) - out <- ulapply(probs, .ess_quantile, x = x) + if (is.null(weights)) { + out <- ulapply(probs, .ess_quantile, x = x) + } else { + out <- ulapply(probs, .ess_quantile_weighted, x = x, weights = weights, ...) + } if (names) { names(out) <- paste0("ess_q", probs * 100) } @@ -274,7 +284,8 @@ ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { #' @rdname ess_quantile #' @export ess_quantile.rvar <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { - summarise_rvar_by_element_with_chains(x, ess_quantile, probs, names, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_quantile, probs, weights = weights, names, ...) } #' @rdname ess_quantile @@ -297,6 +308,19 @@ ess_median <- function(x, ...) { .ess(.split_chains(I)) } +.ess_quantile_weighted <- function(x, prob, weights, r_eff = 1) { + if (should_return_NA(x)) { + return(NA_real_) + } + x <- as.matrix(x) + if (prob == 1) { + len <- length(x) + prob <- (len - 0.5) / len + } + I <- x <= weighted_quantile(x, prob, weights) + .ess_weighted(I, weights = weights, r_eff = r_eff) +} + #' Effective sample size for the mean #' #' Compute an effective sample size estimate for a mean (expectation) @@ -324,14 +348,15 @@ ess_mean.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(.split_chains(x)) } else { - .ess(.split_chains(x)) * (1 / sum(weights^2)) / (NROW(x) * NCOL(x)) + .ess_weighted(x, weights, ...) } } #' @rdname ess_mean #' @export ess_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_mean, weights = weights(x), ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_mean, weights = weights, ...) } #' Effective sample size for the standard deviation @@ -358,14 +383,19 @@ ess_sd <- function(x, ...) UseMethod("ess_sd") #' @rdname ess_sd #' @export -ess_sd.default <- function(x, ...) { - .ess(.split_chains(abs(x-mean(x)))) +ess_sd.default <- function(x, weights = NULL, ...) { + if (is.null(weights)) { + .ess(.split_chains(abs(x-mean(x)))) + } else { + .ess_weighted(abs(x - mean(x)), weights = weights, ...) + } } #' @rdname ess_sd #' @export ess_sd.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_sd, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_sd, weights = weights, ...) } #' Monte Carlo standard error for quantiles @@ -394,23 +424,29 @@ mcse_quantile <- function(x, probs = c(0.05, 0.95), ...) { #' @rdname mcse_quantile #' @export -mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { +mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights = NULL, ...) { probs <- as.numeric(probs) if (any(probs < 0 | probs > 1)) { stop_no_call("'probs' must contain values between 0 and 1.") } names <- as_one_logical(names) - out <- ulapply(probs, .mcse_quantile, x = x) + if (is.null(weights)) { + out <- ulapply(probs, .mcse_quantile, x = x) + } else { + out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) + } if (names) { names(out) <- paste0("mcse_q", probs * 100) } + out } #' @rdname mcse_quantile #' @export mcse_quantile.rvar <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { - summarise_rvar_by_element_with_chains(x, mcse_quantile, probs, names, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_quantile, probs, names, weights = weights, ...) } #' @rdname mcse_quantile @@ -431,6 +467,18 @@ mcse_median <- function(x, ...) { as.vector((th2 - th1) / 2) } +.mcse_quantile_weighted <- function(x, prob, weights) { + ess <- ess_quantile(x, prob, weights = weights) + p <- c(0.1586553, 0.8413447) + a <- qbeta(p, ess * prob + 1, ess * (1 - prob) + 1) + ssims <- sort(x) + S <- length(ssims) + th1 <- ssims[max(floor(a[1] * S), 1)] + th2 <- ssims[min(ceiling(a[2] * S), S)] + as.vector((th2 - th1) / 2) +} + + #' Monte Carlo standard error for the mean #' #' Compute the Monte Carlo standard error for the mean (expectation) of a @@ -458,14 +506,15 @@ mcse_mean.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { sd(x) / sqrt(ess_mean(x)) } else { - .mcse_mean_weighted(x, weights, ...) + .mcse_weighted(x, weights, ...) } } #' @rdname mcse_mean #' @export mcse_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, mcse_mean, weights = weights(x), ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_mean, weights = weights, ...) } #' Monte Carlo standard error for the standard deviation @@ -514,7 +563,8 @@ mcse_sd.default <- function(x, ...) { #' @rdname mcse_sd #' @export mcse_sd.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, mcse_sd, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_sd, weights = weights, ...) } #' Compute Quantiles @@ -793,10 +843,22 @@ fold_draws <- function(x) { ess } -.mcse_mean_weighted <- function(x, weights, r_eff = 1, ...) { +.mcse_weighted <- function(x, weights, r_eff = 1, ...) { # Vehtari et al. 2022 equation 6 - weighted_mean <- matrixStats::weightedMean(x, w = weights) - weights^2 %*% (x - c(weighted_mean))^2 / r_eff + + x <- as.numeric(x) + + weighted_mean <- matrixStats::weightedMean(x, w = weights) + + weights^2 %*% (x - c(weighted_mean))^2 / r_eff +} + +.ess_weighted <- function(x, weights, r_eff = 1, ...) { + # Vehtari et al. 2022 equation 7 + weighted_mean <- matrixStats::weightedMean(x, w = weights) + mcse <- .mcse_weighted(x, weights, r_eff, ...) + + mean((x - weighted_mean)^2) / mcse } # should NA be returned by a convergence diagnostic? diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index 6aa9ce0e..60c45cd6 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -66,7 +66,6 @@ pareto_khat.rvar <- function(x, verbose = FALSE, ...) { } else { # take the max of khat for x * weights and khat for weights - weights_diags <- pareto_khat( weights(x, log = TRUE), are_log_weights = TRUE, @@ -82,10 +81,6 @@ pareto_khat.rvar <- function(x, verbose = FALSE, ...) { ... ) - print(weights_diags) - - print(product_diags) - dim(product_diags) <- dim(product_diags) %||% length(product_diags) margins <- seq_along(dim(product_diags)) From 8e1e0dfdab3aa180283d21b21f30543ef373eb89 Mon Sep 17 00:00:00 2001 From: n-kall Date: Mon, 11 Mar 2024 12:38:05 +0200 Subject: [PATCH 04/23] tweak weighted diagnostics --- R/convergence.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/convergence.R b/R/convergence.R index c3b6e348..65e5b0bf 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -850,7 +850,9 @@ fold_draws <- function(x) { weighted_mean <- matrixStats::weightedMean(x, w = weights) - weights^2 %*% (x - c(weighted_mean))^2 / r_eff + out <- weights^2 %*% (x - c(weighted_mean))^2 / r_eff + + out } .ess_weighted <- function(x, weights, r_eff = 1, ...) { From a9ba2b6b5f5ec60b364560355538969d19167caa Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 13 Mar 2024 11:54:03 +0200 Subject: [PATCH 05/23] add r_eff into calculation of weighted ess and mcse --- R/convergence.R | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index 65e5b0bf..7a0faa4a 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -186,7 +186,8 @@ ess_bulk.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(z_scale(.split_chains(x))) } else { - .ess_weighted(x, weights, ...) + r_eff <- .ess(z_scale(.split_chains(x))) / (nrow(x) * ncol(x)) + .ess_weighted(x, weights, r_eff = r_eff, ...) } } @@ -345,10 +346,11 @@ ess_mean <- function(x, ...) UseMethod("ess_mean") #' @export ess_mean.default <- function(x, weights = NULL, ...) { - if (is.null(weights)) { + if (is.null(weights)) { .ess(.split_chains(x)) } else { - .ess_weighted(x, weights, ...) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + .ess_weighted(x, weights, r_eff = r_eff, ...) } } @@ -506,7 +508,8 @@ mcse_mean.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { sd(x) / sqrt(ess_mean(x)) } else { - .mcse_weighted(x, weights, ...) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + .mcse_weighted(x, weights, ...) / r_eff } } @@ -843,23 +846,21 @@ fold_draws <- function(x) { ess } -.mcse_weighted <- function(x, weights, r_eff = 1, ...) { +.mcse_weighted <- function(x, weights, ...) { # Vehtari et al. 2022 equation 6 x <- as.numeric(x) - - weighted_mean <- matrixStats::weightedMean(x, w = weights) - out <- weights^2 %*% (x - c(weighted_mean))^2 / r_eff + weighted_mean <- matrixStats::weightedMean(x, w = weights) - out + (weights^2 %*% (x - c(weighted_mean))^2) } -.ess_weighted <- function(x, weights, r_eff = 1, ...) { +.ess_weighted <- function(x, weights, r_eff, ...) { # Vehtari et al. 2022 equation 7 weighted_mean <- matrixStats::weightedMean(x, w = weights) - mcse <- .mcse_weighted(x, weights, r_eff, ...) - + mcse <- .mcse_weighted(x, weights, ...) / r_eff + mean((x - weighted_mean)^2) / mcse } @@ -882,4 +883,4 @@ should_return_NA <- function(x, tol = .Machine$double.eps) { # } is_constant(x, tol = tol) } - + From ccdfb2fff751bc2c0ffc77bd5a5990f12c2431b0 Mon Sep 17 00:00:00 2001 From: n-kall Date: Fri, 15 Mar 2024 17:46:57 +0200 Subject: [PATCH 06/23] fixes to weighted ess and mcse --- R/convergence.R | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index 7a0faa4a..3e49c682 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -186,7 +186,7 @@ ess_bulk.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(z_scale(.split_chains(x))) } else { - r_eff <- .ess(z_scale(.split_chains(x))) / (nrow(x) * ncol(x)) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) .ess_weighted(x, weights, r_eff = r_eff, ...) } } @@ -274,7 +274,8 @@ ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights if (is.null(weights)) { out <- ulapply(probs, .ess_quantile, x = x) } else { - out <- ulapply(probs, .ess_quantile_weighted, x = x, weights = weights, ...) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + out <- ulapply(probs, .ess_quantile_weighted, x = x, weights = weights, r_eff = r_eff, ...) } if (names) { names(out) <- paste0("ess_q", probs * 100) @@ -309,7 +310,7 @@ ess_median <- function(x, ...) { .ess(.split_chains(I)) } -.ess_quantile_weighted <- function(x, prob, weights, r_eff = 1) { +.ess_quantile_weighted <- function(x, prob, weights, r_eff) { if (should_return_NA(x)) { return(NA_real_) } @@ -318,7 +319,7 @@ ess_median <- function(x, ...) { len <- length(x) prob <- (len - 0.5) / len } - I <- x <= weighted_quantile(x, prob, weights) + I <- x <= quantile(x, prob) .ess_weighted(I, weights = weights, r_eff = r_eff) } @@ -389,7 +390,8 @@ ess_sd.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(.split_chains(abs(x-mean(x)))) } else { - .ess_weighted(abs(x - mean(x)), weights = weights, ...) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + .ess_weighted(abs(x - mean(x)), weights = weights, r_eff = r_eff, ...) } } @@ -435,7 +437,8 @@ mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weight if (is.null(weights)) { out <- ulapply(probs, .mcse_quantile, x = x) } else { - out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) / r_eff } if (names) { names(out) <- paste0("mcse_q", probs * 100) @@ -846,22 +849,21 @@ fold_draws <- function(x) { ess } -.mcse_weighted <- function(x, weights, ...) { +.mcse_weighted <- function(x, weights, r_eff, ...) { # Vehtari et al. 2022 equation 6 x <- as.numeric(x) - weighted_mean <- matrixStats::weightedMean(x, w = weights) - (weights^2 %*% (x - c(weighted_mean))^2) + sqrt(weights^2 %*% (x - c(weighted_mean))^2 / r_eff) } .ess_weighted <- function(x, weights, r_eff, ...) { # Vehtari et al. 2022 equation 7 - weighted_mean <- matrixStats::weightedMean(x, w = weights) - mcse <- .mcse_weighted(x, weights, ...) / r_eff + mcse <- .mcse_weighted(x, weights, r_eff, ...) - mean((x - weighted_mean)^2) / mcse + var <- mean((x - mean(x))^2) + var / mcse^2 } # should NA be returned by a convergence diagnostic? From eed7221cbd1604ad06a30df60f411b858495f9cb Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 19 Mar 2024 17:02:09 +0200 Subject: [PATCH 07/23] use weighted quantile in weighted mcse for quantile --- R/convergence.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/convergence.R b/R/convergence.R index 3e49c682..ccf265c4 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -319,7 +319,7 @@ ess_median <- function(x, ...) { len <- length(x) prob <- (len - 0.5) / len } - I <- x <= quantile(x, prob) + I <- x <= weighted_quantile(x, prob, weights = weights) .ess_weighted(I, weights = weights, r_eff = r_eff) } From 2c70d46e3671cc5cd0ef150c6ab623a32029a104 Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 19 Mar 2024 17:51:41 +0200 Subject: [PATCH 08/23] add tests for weighted convergence measures --- R/convergence.R | 2 +- tests/testthat/test-convergence.R | 58 +++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/R/convergence.R b/R/convergence.R index ccf265c4..2c78912b 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -512,7 +512,7 @@ mcse_mean.default <- function(x, weights = NULL, ...) { sd(x) / sqrt(ess_mean(x)) } else { r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) - .mcse_weighted(x, weights, ...) / r_eff + .mcse_weighted(x, weights, r_eff, ...) } } diff --git a/tests/testthat/test-convergence.R b/tests/testthat/test-convergence.R index 4f9621aa..a471e75b 100644 --- a/tests/testthat/test-convergence.R +++ b/tests/testthat/test-convergence.R @@ -149,3 +149,61 @@ test_that("autocovariance returns correct results", { test_that("NA quantile2 works", { expect_equal(quantile2(NA_real_, c(0.25, 0.75)), c(q25 = NA_real_, q75 = NA_real_)) }) + + +test_that("weighted convergence measures work", { + + # draws from standard normal + x <- cbind( + rnorm(100), + rnorm(100), + rnorm(100), + rnorm(100) + ) + + xr <- rvar(x, with_chains = TRUE) + + # target is normal(0, 0.5) + # here, ess should be higher for mean + # mcse should be lower for mean + w1 <- as.numeric(dnorm(x, sd = 0.5) / dnorm(x)) + w1 <- w1 / sum(w1) + xw1 <- weight_draws(xr, weights = w1) + + expect_true(ess_mean(xw1) > ess_mean(xr)) + expect_true(mcse_mean(xw1) < mcse_mean(xr)) + expect_true(ess_quantile(xw1, probs = 0.05) > ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw1, probs = 0.95) > ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw1, probs = 0.05) < mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw1, probs = 0.95) < mcse_quantile(xr, probs = 0.95)) + + # target is normal(0, 1.2) + # here ess should be lower, and mcse should be higher + w2 <- as.numeric(dnorm(x, sd = 1.2) / dnorm(x)) + w2 <- w2 / sum(w2) + xw2 <- weight_draws(xr, weights = w2) + + expect_true(ess_mean(xw2) < ess_mean(xr)) + expect_true(mcse_mean(xw2) > mcse_mean(xr)) + + expect_true(ess_quantile(xw2, probs = 0.05) < ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw2, probs = 0.95) < ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw2, probs = 0.05) > mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw2, probs = 0.95) > mcse_quantile(xr, probs = 0.95)) + + + # target is normal(1, 1) + # here ess for mean and q95 should be lower, but for q5 it should be higher + w3 <- as.numeric(dnorm(x, mean = 1, sd = 1) / dnorm(x)) + w3 <- w3 / sum(w3) + + xw3 <- weight_draws(xr, weights = w3) + expect_true(ess_mean(xw3) < ess_mean(xr)) + expect_true(mcse_mean(xw3) > mcse_mean(xr)) + + expect_true(ess_quantile(xw3, probs = 0.05) > ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw3, probs = 0.95) < ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw3, probs = 0.05) < mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw3, probs = 0.95) > mcse_quantile(xr, probs = 0.95)) + +}) From 0bfe6e658699ab1ec18aaf826181429f88a16194 Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 17 Jan 2024 17:28:32 +0200 Subject: [PATCH 09/23] updating pareto functions for weighted rvars --- R/pareto_smooth.R | 107 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 87 insertions(+), 20 deletions(-) diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index a0bb8486..6aa9ce0e 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -46,21 +46,57 @@ pareto_khat.default <- function(x, #' @rdname pareto_khat #' @export -pareto_khat.rvar <- function(x, ...) { - draws_diags <- summarise_rvar_by_element_with_chains( - x, - pareto_smooth.default, - return_k = TRUE, - smooth_draws = FALSE, - ... - ) - dim(draws_diags) <- dim(draws_diags) %||% length(draws_diags) - margins <- seq_along(dim(draws_diags)) +pareto_khat.rvar <- function(x, verbose = FALSE, ...) { + if (is.null(weights(x))) { + draws_diags <- summarise_rvar_by_element_with_chains( + x, + pareto_smooth.default, + return_k = TRUE, + smooth_draws = FALSE, + verbose = verbose, + ... + ) - diags <- list( - khat = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat) - ) + dim(draws_diags) <- dim(draws_diags) %||% length(draws_diags) + margins <- seq_along(dim(draws_diags)) + + diags <- list( + khat = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat) + ) + } else { + + # take the max of khat for x * weights and khat for weights + + weights_diags <- pareto_khat( + weights(x, log = TRUE), + are_log_weights = TRUE, + ... + ) + + w <- weights(x) + + x <- weight_draws(x, NULL) + product_diags <- summarise_rvar_by_element_with_chains( + x * rvar(w, nchains = nchains(x)), + pareto_khat, + ... + ) + print(weights_diags) + + print(product_diags) + + dim(product_diags) <- dim(product_diags) %||% length(product_diags) + margins <- seq_along(dim(product_diags)) + + diags <- list( + khat = apply(product_diags, margins, + function(x) { + max(x[[1]]$khat, + weights_diags$khat) + }) + ) + } diags } @@ -149,6 +185,8 @@ pareto_diags.default <- function(x, #' @rdname pareto_diags #' @export pareto_diags.rvar <- function(x, ...) { + + if (is.null(weights(x))) { draws_diags <- summarise_rvar_by_element_with_chains( x, pareto_smooth.default, @@ -167,6 +205,35 @@ pareto_diags.rvar <- function(x, ...) { khat_threshold = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat_threshold), convergence_rate = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$convergence_rate) ) + } else { + + # take the max of khat for x * weights and khat for weights + + weights_diags <- pareto_diags( + weights(x, log = TRUE), + are_log_weights = TRUE, + ... + ) + + w <- weights(x) + + x <- weight_draws(x, NULL) + product_diags <- summarise_rvar_by_element_with_chains( + x * rvar(w, nchains = nchains(x)), + pareto_diags, + ... + ) + + dim(product_diags) <- dim(product_diags) %||% length(product_diags) + margins <- seq_along(dim(product_diags)) + + diags <- list( + khat = apply(product_diags, margins, function(x) max(x[[1]]$khat, weights_diags$khat)), + min_ss = apply(product_diags, margins, function(x) max(x[[1]]$min_ss, weights_diags$min_ss)), + khat_threshold = apply(product_diags, margins, function(x) max(x[[1]]$khat_threshold, weights_diags$khat_threshold)), + convergence_rate = apply(product_diags, margins, function(x) min(x[[1]]$convergence_rate, weights_diags$convergence_rate)) + ) + } diags } @@ -279,7 +346,7 @@ pareto_smooth.default <- function(x, if (are_log_weights) { tail <- "right" } - + tail <- match.arg(tail) S <- length(x) @@ -330,7 +397,7 @@ pareto_smooth.default <- function(x, k <- max(left_k, right_k) x <- smoothed$x - + } else { smoothed <- .pareto_smooth_tail( @@ -444,7 +511,7 @@ pareto_convergence_rate.rvar <- function(x, ...) { # shift log values for safe exponentiation x <- x - max(x) } - + tail <- match.arg(tail) S <- length(x) @@ -458,10 +525,10 @@ pareto_convergence_rate.rvar <- function(x, ...) { draws_tail <- ord$x[tail_ids] cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values - + max_tail <- max(draws_tail) min_tail <- min(draws_tail) - + if (ndraws_tail >= 5) { ord <- sort.int(x, index.return = TRUE) if (abs(max_tail - min_tail) < .Machine$double.eps / 100) { @@ -617,7 +684,7 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { msg <- NULL if (!are_weights) { - + if (khat > 1) { msg <- paste0(msg, " Mean does not exist, making empirical mean estimate of the draws not applicable.") } else { @@ -630,7 +697,7 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { } } else { if (khat > khat_threshold || khat > 0.7) { - msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") + msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") } } message("Pareto k-hat = ", round(khat, 2), ".", msg) From 691d24088d669368fa6ac1b6c0397bcca3469626 Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 28 Feb 2024 11:39:33 +0200 Subject: [PATCH 10/23] start on weighted convergence --- R/convergence.R | 30 +++++++++++++++++++++++------- R/summarise_draws.R | 2 +- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index bbb8b12e..270020b5 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -182,7 +182,7 @@ ess_bulk <- function(x, ...) UseMethod("ess_bulk") #' @rdname ess_bulk #' @export -ess_bulk.default <- function(x, ...) { +ess_bulk.default <- function(x, weights = NULL, ...) { .ess(z_scale(.split_chains(x))) } @@ -319,14 +319,19 @@ ess_mean <- function(x, ...) UseMethod("ess_mean") #' @rdname ess_quantile #' @export -ess_mean.default <- function(x, ...) { - .ess(.split_chains(x)) +ess_mean.default <- function(x, weights = NULL, ...) { + + if (is.null(weights)) { + .ess(.split_chains(x)) + } else { + .ess(.split_chains(x)) * (1 / sum(weights^2)) / (NROW(x) * NCOL(x)) + } } #' @rdname ess_mean #' @export ess_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_mean, ...) + summarise_rvar_by_element_with_chains(x, ess_mean, weights = weights(x), ...) } #' Effective sample size for the standard deviation @@ -449,14 +454,18 @@ mcse_mean <- function(x, ...) UseMethod("mcse_mean") #' @rdname mcse_mean #' @export -mcse_mean.default <- function(x, ...) { - sd(x) / sqrt(ess_mean(x)) +mcse_mean.default <- function(x, weights = NULL, ...) { + if (is.null(weights)) { + sd(x) / sqrt(ess_mean(x)) + } else { + .mcse_mean_weighted(x, weights, ...) + } } #' @rdname mcse_mean #' @export mcse_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, mcse_mean, ...) + summarise_rvar_by_element_with_chains(x, mcse_mean, weights = weights(x), ...) } #' Monte Carlo standard error for the standard deviation @@ -784,6 +793,12 @@ fold_draws <- function(x) { ess } +.mcse_mean_weighted <- function(x, weights, r_eff = 1, ...) { + # Vehtari et al. 2022 equation 6 + weighted_mean <- matrixStats::weightedMean(x, w = weights) + weights^2 %*% (x - c(weighted_mean))^2 / r_eff +} + # should NA be returned by a convergence diagnostic? should_return_NA <- function(x, tol = .Machine$double.eps) { if (anyNA(x) || checkmate::anyInfinite(x)) { @@ -803,3 +818,4 @@ should_return_NA <- function(x, tol = .Machine$double.eps) { # } is_constant(x, tol = tol) } + diff --git a/R/summarise_draws.R b/R/summarise_draws.R index 6f13755a..05acd83b 100644 --- a/R/summarise_draws.R +++ b/R/summarise_draws.R @@ -329,9 +329,9 @@ empty_draws_summary <- function(dimensions = "variable") { create_summary_list <- function(x, v, funs, .args) { draws <- drop_dims_or_classes(x[, , v], dims = 3, reset_class = FALSE) - args <- c(list(draws), .args) v_summary <- named_list(names(funs)) for (m in names(funs)) { + args <- c(list(draws), .args[[m]]) v_summary[[m]] <- do.call(funs[[m]], args) } v_summary From 74b3b71d51f95c0f66bef4eb8f697e1826700940 Mon Sep 17 00:00:00 2001 From: n-kall Date: Thu, 7 Mar 2024 11:59:53 +0200 Subject: [PATCH 11/23] improvements to weighted ess, mcse --- R/convergence.R | 108 ++++++++++++++++++++++++++++++++++++---------- R/pareto_smooth.R | 5 --- 2 files changed, 85 insertions(+), 28 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index 270020b5..c3b6e348 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -183,13 +183,18 @@ ess_bulk <- function(x, ...) UseMethod("ess_bulk") #' @rdname ess_bulk #' @export ess_bulk.default <- function(x, weights = NULL, ...) { - .ess(z_scale(.split_chains(x))) + if (is.null(weights)) { + .ess(z_scale(.split_chains(x))) + } else { + .ess_weighted(x, weights, ...) + } } #' @rdname ess_bulk #' @export ess_bulk.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_bulk, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_bulk, weights = weights, ...) } #' Tail effective sample size (tail-ESS) @@ -220,16 +225,17 @@ ess_tail <- function(x, ...) UseMethod("ess_tail") #' @rdname ess_tail #' @export -ess_tail.default <- function(x, ...) { - q05_ess <- ess_quantile(x, 0.05) - q95_ess <- ess_quantile(x, 0.95) +ess_tail.default <- function(x, weights = NULL, ...) { + q05_ess <- ess_quantile(x, 0.05, weights = weights, ...) + q95_ess <- ess_quantile(x, 0.95, weights = weights, ...) min(q05_ess, q95_ess) } #' @rdname ess_tail #' @export ess_tail.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_tail, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_tail, weights = weights, ...) } #' Effective sample sizes for quantiles @@ -258,13 +264,17 @@ ess_quantile <- function(x, probs = c(0.05, 0.95), ...) { #' @rdname ess_quantile #' @export -ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { +ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights = NULL, ...) { probs <- as.numeric(probs) if (any(probs < 0 | probs > 1)) { stop_no_call("'probs' must contain values between 0 and 1.") } names <- as_one_logical(names) - out <- ulapply(probs, .ess_quantile, x = x) + if (is.null(weights)) { + out <- ulapply(probs, .ess_quantile, x = x) + } else { + out <- ulapply(probs, .ess_quantile_weighted, x = x, weights = weights, ...) + } if (names) { names(out) <- paste0("ess_q", probs * 100) } @@ -274,7 +284,8 @@ ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { #' @rdname ess_quantile #' @export ess_quantile.rvar <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { - summarise_rvar_by_element_with_chains(x, ess_quantile, probs, names, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_quantile, probs, weights = weights, names, ...) } #' @rdname ess_quantile @@ -297,6 +308,19 @@ ess_median <- function(x, ...) { .ess(.split_chains(I)) } +.ess_quantile_weighted <- function(x, prob, weights, r_eff = 1) { + if (should_return_NA(x)) { + return(NA_real_) + } + x <- as.matrix(x) + if (prob == 1) { + len <- length(x) + prob <- (len - 0.5) / len + } + I <- x <= weighted_quantile(x, prob, weights) + .ess_weighted(I, weights = weights, r_eff = r_eff) +} + #' Effective sample size for the mean #' #' Compute an effective sample size estimate for a mean (expectation) @@ -324,14 +348,15 @@ ess_mean.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(.split_chains(x)) } else { - .ess(.split_chains(x)) * (1 / sum(weights^2)) / (NROW(x) * NCOL(x)) + .ess_weighted(x, weights, ...) } } #' @rdname ess_mean #' @export ess_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_mean, weights = weights(x), ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_mean, weights = weights, ...) } #' Effective sample size for the standard deviation @@ -358,14 +383,19 @@ ess_sd <- function(x, ...) UseMethod("ess_sd") #' @rdname ess_sd #' @export -ess_sd.default <- function(x, ...) { - .ess(.split_chains(abs(x-mean(x)))) +ess_sd.default <- function(x, weights = NULL, ...) { + if (is.null(weights)) { + .ess(.split_chains(abs(x-mean(x)))) + } else { + .ess_weighted(abs(x - mean(x)), weights = weights, ...) + } } #' @rdname ess_sd #' @export ess_sd.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_sd, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_sd, weights = weights, ...) } #' Monte Carlo standard error for quantiles @@ -394,23 +424,29 @@ mcse_quantile <- function(x, probs = c(0.05, 0.95), ...) { #' @rdname mcse_quantile #' @export -mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { +mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights = NULL, ...) { probs <- as.numeric(probs) if (any(probs < 0 | probs > 1)) { stop_no_call("'probs' must contain values between 0 and 1.") } names <- as_one_logical(names) - out <- ulapply(probs, .mcse_quantile, x = x) + if (is.null(weights)) { + out <- ulapply(probs, .mcse_quantile, x = x) + } else { + out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) + } if (names) { names(out) <- paste0("mcse_q", probs * 100) } + out } #' @rdname mcse_quantile #' @export mcse_quantile.rvar <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { - summarise_rvar_by_element_with_chains(x, mcse_quantile, probs, names, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_quantile, probs, names, weights = weights, ...) } #' @rdname mcse_quantile @@ -431,6 +467,18 @@ mcse_median <- function(x, ...) { as.vector((th2 - th1) / 2) } +.mcse_quantile_weighted <- function(x, prob, weights) { + ess <- ess_quantile(x, prob, weights = weights) + p <- c(0.1586553, 0.8413447) + a <- qbeta(p, ess * prob + 1, ess * (1 - prob) + 1) + ssims <- sort(x) + S <- length(ssims) + th1 <- ssims[max(floor(a[1] * S), 1)] + th2 <- ssims[min(ceiling(a[2] * S), S)] + as.vector((th2 - th1) / 2) +} + + #' Monte Carlo standard error for the mean #' #' Compute the Monte Carlo standard error for the mean (expectation) of a @@ -458,14 +506,15 @@ mcse_mean.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { sd(x) / sqrt(ess_mean(x)) } else { - .mcse_mean_weighted(x, weights, ...) + .mcse_weighted(x, weights, ...) } } #' @rdname mcse_mean #' @export mcse_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, mcse_mean, weights = weights(x), ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_mean, weights = weights, ...) } #' Monte Carlo standard error for the standard deviation @@ -514,7 +563,8 @@ mcse_sd.default <- function(x, ...) { #' @rdname mcse_sd #' @export mcse_sd.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, mcse_sd, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_sd, weights = weights, ...) } #' Compute Quantiles @@ -793,10 +843,22 @@ fold_draws <- function(x) { ess } -.mcse_mean_weighted <- function(x, weights, r_eff = 1, ...) { +.mcse_weighted <- function(x, weights, r_eff = 1, ...) { # Vehtari et al. 2022 equation 6 - weighted_mean <- matrixStats::weightedMean(x, w = weights) - weights^2 %*% (x - c(weighted_mean))^2 / r_eff + + x <- as.numeric(x) + + weighted_mean <- matrixStats::weightedMean(x, w = weights) + + weights^2 %*% (x - c(weighted_mean))^2 / r_eff +} + +.ess_weighted <- function(x, weights, r_eff = 1, ...) { + # Vehtari et al. 2022 equation 7 + weighted_mean <- matrixStats::weightedMean(x, w = weights) + mcse <- .mcse_weighted(x, weights, r_eff, ...) + + mean((x - weighted_mean)^2) / mcse } # should NA be returned by a convergence diagnostic? diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index 6aa9ce0e..60c45cd6 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -66,7 +66,6 @@ pareto_khat.rvar <- function(x, verbose = FALSE, ...) { } else { # take the max of khat for x * weights and khat for weights - weights_diags <- pareto_khat( weights(x, log = TRUE), are_log_weights = TRUE, @@ -82,10 +81,6 @@ pareto_khat.rvar <- function(x, verbose = FALSE, ...) { ... ) - print(weights_diags) - - print(product_diags) - dim(product_diags) <- dim(product_diags) %||% length(product_diags) margins <- seq_along(dim(product_diags)) From a1b6564e37de4486d2a5547d02b2ca9681ce07af Mon Sep 17 00:00:00 2001 From: n-kall Date: Mon, 11 Mar 2024 12:38:05 +0200 Subject: [PATCH 12/23] tweak weighted diagnostics --- R/convergence.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/convergence.R b/R/convergence.R index c3b6e348..65e5b0bf 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -850,7 +850,9 @@ fold_draws <- function(x) { weighted_mean <- matrixStats::weightedMean(x, w = weights) - weights^2 %*% (x - c(weighted_mean))^2 / r_eff + out <- weights^2 %*% (x - c(weighted_mean))^2 / r_eff + + out } .ess_weighted <- function(x, weights, r_eff = 1, ...) { From 6ccf496b512614186589ebd00b060756108aaecf Mon Sep 17 00:00:00 2001 From: n-kall Date: Wed, 13 Mar 2024 11:54:03 +0200 Subject: [PATCH 13/23] add r_eff into calculation of weighted ess and mcse --- R/convergence.R | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index 65e5b0bf..7a0faa4a 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -186,7 +186,8 @@ ess_bulk.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(z_scale(.split_chains(x))) } else { - .ess_weighted(x, weights, ...) + r_eff <- .ess(z_scale(.split_chains(x))) / (nrow(x) * ncol(x)) + .ess_weighted(x, weights, r_eff = r_eff, ...) } } @@ -345,10 +346,11 @@ ess_mean <- function(x, ...) UseMethod("ess_mean") #' @export ess_mean.default <- function(x, weights = NULL, ...) { - if (is.null(weights)) { + if (is.null(weights)) { .ess(.split_chains(x)) } else { - .ess_weighted(x, weights, ...) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + .ess_weighted(x, weights, r_eff = r_eff, ...) } } @@ -506,7 +508,8 @@ mcse_mean.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { sd(x) / sqrt(ess_mean(x)) } else { - .mcse_weighted(x, weights, ...) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + .mcse_weighted(x, weights, ...) / r_eff } } @@ -843,23 +846,21 @@ fold_draws <- function(x) { ess } -.mcse_weighted <- function(x, weights, r_eff = 1, ...) { +.mcse_weighted <- function(x, weights, ...) { # Vehtari et al. 2022 equation 6 x <- as.numeric(x) - - weighted_mean <- matrixStats::weightedMean(x, w = weights) - out <- weights^2 %*% (x - c(weighted_mean))^2 / r_eff + weighted_mean <- matrixStats::weightedMean(x, w = weights) - out + (weights^2 %*% (x - c(weighted_mean))^2) } -.ess_weighted <- function(x, weights, r_eff = 1, ...) { +.ess_weighted <- function(x, weights, r_eff, ...) { # Vehtari et al. 2022 equation 7 weighted_mean <- matrixStats::weightedMean(x, w = weights) - mcse <- .mcse_weighted(x, weights, r_eff, ...) - + mcse <- .mcse_weighted(x, weights, ...) / r_eff + mean((x - weighted_mean)^2) / mcse } @@ -882,4 +883,4 @@ should_return_NA <- function(x, tol = .Machine$double.eps) { # } is_constant(x, tol = tol) } - + From e47fd921e971db2617572b7f757b456ebc043b0d Mon Sep 17 00:00:00 2001 From: n-kall Date: Fri, 15 Mar 2024 17:46:57 +0200 Subject: [PATCH 14/23] fixes to weighted ess and mcse --- R/convergence.R | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index 7a0faa4a..3e49c682 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -186,7 +186,7 @@ ess_bulk.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(z_scale(.split_chains(x))) } else { - r_eff <- .ess(z_scale(.split_chains(x))) / (nrow(x) * ncol(x)) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) .ess_weighted(x, weights, r_eff = r_eff, ...) } } @@ -274,7 +274,8 @@ ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights if (is.null(weights)) { out <- ulapply(probs, .ess_quantile, x = x) } else { - out <- ulapply(probs, .ess_quantile_weighted, x = x, weights = weights, ...) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + out <- ulapply(probs, .ess_quantile_weighted, x = x, weights = weights, r_eff = r_eff, ...) } if (names) { names(out) <- paste0("ess_q", probs * 100) @@ -309,7 +310,7 @@ ess_median <- function(x, ...) { .ess(.split_chains(I)) } -.ess_quantile_weighted <- function(x, prob, weights, r_eff = 1) { +.ess_quantile_weighted <- function(x, prob, weights, r_eff) { if (should_return_NA(x)) { return(NA_real_) } @@ -318,7 +319,7 @@ ess_median <- function(x, ...) { len <- length(x) prob <- (len - 0.5) / len } - I <- x <= weighted_quantile(x, prob, weights) + I <- x <= quantile(x, prob) .ess_weighted(I, weights = weights, r_eff = r_eff) } @@ -389,7 +390,8 @@ ess_sd.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(.split_chains(abs(x-mean(x)))) } else { - .ess_weighted(abs(x - mean(x)), weights = weights, ...) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + .ess_weighted(abs(x - mean(x)), weights = weights, r_eff = r_eff, ...) } } @@ -435,7 +437,8 @@ mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weight if (is.null(weights)) { out <- ulapply(probs, .mcse_quantile, x = x) } else { - out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) / r_eff } if (names) { names(out) <- paste0("mcse_q", probs * 100) @@ -846,22 +849,21 @@ fold_draws <- function(x) { ess } -.mcse_weighted <- function(x, weights, ...) { +.mcse_weighted <- function(x, weights, r_eff, ...) { # Vehtari et al. 2022 equation 6 x <- as.numeric(x) - weighted_mean <- matrixStats::weightedMean(x, w = weights) - (weights^2 %*% (x - c(weighted_mean))^2) + sqrt(weights^2 %*% (x - c(weighted_mean))^2 / r_eff) } .ess_weighted <- function(x, weights, r_eff, ...) { # Vehtari et al. 2022 equation 7 - weighted_mean <- matrixStats::weightedMean(x, w = weights) - mcse <- .mcse_weighted(x, weights, ...) / r_eff + mcse <- .mcse_weighted(x, weights, r_eff, ...) - mean((x - weighted_mean)^2) / mcse + var <- mean((x - mean(x))^2) + var / mcse^2 } # should NA be returned by a convergence diagnostic? From 1734e3d1e3f54a8565bb321a7035c040779d2954 Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 19 Mar 2024 17:02:09 +0200 Subject: [PATCH 15/23] use weighted quantile in weighted mcse for quantile --- R/convergence.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/convergence.R b/R/convergence.R index 3e49c682..ccf265c4 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -319,7 +319,7 @@ ess_median <- function(x, ...) { len <- length(x) prob <- (len - 0.5) / len } - I <- x <= quantile(x, prob) + I <- x <= weighted_quantile(x, prob, weights = weights) .ess_weighted(I, weights = weights, r_eff = r_eff) } From 222893a7ada08ecea07877168652ad7d984daa8a Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 19 Mar 2024 17:51:41 +0200 Subject: [PATCH 16/23] add tests for weighted convergence measures --- R/convergence.R | 2 +- tests/testthat/test-convergence.R | 58 +++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/R/convergence.R b/R/convergence.R index ccf265c4..2c78912b 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -512,7 +512,7 @@ mcse_mean.default <- function(x, weights = NULL, ...) { sd(x) / sqrt(ess_mean(x)) } else { r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) - .mcse_weighted(x, weights, ...) / r_eff + .mcse_weighted(x, weights, r_eff, ...) } } diff --git a/tests/testthat/test-convergence.R b/tests/testthat/test-convergence.R index 4f9621aa..a471e75b 100644 --- a/tests/testthat/test-convergence.R +++ b/tests/testthat/test-convergence.R @@ -149,3 +149,61 @@ test_that("autocovariance returns correct results", { test_that("NA quantile2 works", { expect_equal(quantile2(NA_real_, c(0.25, 0.75)), c(q25 = NA_real_, q75 = NA_real_)) }) + + +test_that("weighted convergence measures work", { + + # draws from standard normal + x <- cbind( + rnorm(100), + rnorm(100), + rnorm(100), + rnorm(100) + ) + + xr <- rvar(x, with_chains = TRUE) + + # target is normal(0, 0.5) + # here, ess should be higher for mean + # mcse should be lower for mean + w1 <- as.numeric(dnorm(x, sd = 0.5) / dnorm(x)) + w1 <- w1 / sum(w1) + xw1 <- weight_draws(xr, weights = w1) + + expect_true(ess_mean(xw1) > ess_mean(xr)) + expect_true(mcse_mean(xw1) < mcse_mean(xr)) + expect_true(ess_quantile(xw1, probs = 0.05) > ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw1, probs = 0.95) > ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw1, probs = 0.05) < mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw1, probs = 0.95) < mcse_quantile(xr, probs = 0.95)) + + # target is normal(0, 1.2) + # here ess should be lower, and mcse should be higher + w2 <- as.numeric(dnorm(x, sd = 1.2) / dnorm(x)) + w2 <- w2 / sum(w2) + xw2 <- weight_draws(xr, weights = w2) + + expect_true(ess_mean(xw2) < ess_mean(xr)) + expect_true(mcse_mean(xw2) > mcse_mean(xr)) + + expect_true(ess_quantile(xw2, probs = 0.05) < ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw2, probs = 0.95) < ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw2, probs = 0.05) > mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw2, probs = 0.95) > mcse_quantile(xr, probs = 0.95)) + + + # target is normal(1, 1) + # here ess for mean and q95 should be lower, but for q5 it should be higher + w3 <- as.numeric(dnorm(x, mean = 1, sd = 1) / dnorm(x)) + w3 <- w3 / sum(w3) + + xw3 <- weight_draws(xr, weights = w3) + expect_true(ess_mean(xw3) < ess_mean(xr)) + expect_true(mcse_mean(xw3) > mcse_mean(xr)) + + expect_true(ess_quantile(xw3, probs = 0.05) > ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw3, probs = 0.95) < ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw3, probs = 0.05) < mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw3, probs = 0.95) > mcse_quantile(xr, probs = 0.95)) + +}) From a732ad86fcb85f3cdc4d2c327fbdaadd49331fad Mon Sep 17 00:00:00 2001 From: n-kall Date: Fri, 22 Mar 2024 16:45:17 +0200 Subject: [PATCH 17/23] fix r_eff calculations for each quantity --- R/convergence.R | 46 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index 2c78912b..d66376ef 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -82,6 +82,8 @@ rhat_basic.rvar <- function(x, split = TRUE, ...) { #' recommend the improved ESS convergence diagnostics implemented in #' [ess_bulk()] and [ess_tail()]. See Vehtari (2021) for an in-depth #' comparison of different effective sample size estimators. +#' If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -104,18 +106,25 @@ ess_basic <- function(x, ...) UseMethod("ess_basic") #' @rdname ess_basic #' @export -ess_basic.default <- function(x, split = TRUE, ...) { +ess_basic.default <- function(x, split = TRUE, weights = NULL, ...) { split <- as_one_logical(split) if (split) { x <- .split_chains(x) + .ess(x) + } + + if (is.null(weights)) { + r_eff <- .ess(x) / (nrow(x) * ncol(x)) + .ess_weighted(x, weights, r_eff = r_eff, ...) } - .ess(x) } #' @rdname ess_basic #' @export -ess_basic.rvar <- function(x, split = TRUE, ...) { - summarise_rvar_by_element_with_chains(x, ess_basic, split, ...) +ess_basic.rvar <- function(x, split = TRUE, weights = weights, ...) { + + summarise_rvar_by_element_with_chains(x, ess_basic, split, weights = weights, ...) + } #' Rhat convergence diagnostic @@ -162,6 +171,8 @@ rhat.rvar <- function(x, ...) { #' rank normalized values using split chains. For the tail effective sample size #' see [ess_tail()]. See Vehtari (2021) for an in-depth #' comparison of different effective sample size estimators. +#' If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -186,7 +197,7 @@ ess_bulk.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(z_scale(.split_chains(x))) } else { - r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + r_eff <- .ess(z_scale(.split_chains(x))) / (nrow(x) * ncol(x)) .ess_weighted(x, weights, r_eff = r_eff, ...) } } @@ -206,6 +217,8 @@ ess_bulk.rvar <- function(x, ...) { #' sample sizes for 5% and 95% quantiles. For the bulk effective sample #' size see [ess_bulk()]. See Vehtari (2021) for an in-depth #' comparison of different effective sample size estimators. +#' If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -241,8 +254,9 @@ ess_tail.rvar <- function(x, ...) { #' Effective sample sizes for quantiles #' -#' Compute effective sample size estimates for quantile estimates of a single -#' variable. +#' Compute effective sample size estimates for quantile estimates of a +#' single variable. If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -274,8 +288,8 @@ ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights if (is.null(weights)) { out <- ulapply(probs, .ess_quantile, x = x) } else { - r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) - out <- ulapply(probs, .ess_quantile_weighted, x = x, weights = weights, r_eff = r_eff, ...) + r_eff <- ulapply(probs, .ess_quantile, x = x) / (nrow(x) * ncol(x)) + out <- mapply(.ess_quantile_weighted, prob = probs, r_eff = r_eff, MoreArgs = list(x = x, weights = weights)) } if (names) { names(out) <- paste0("ess_q", probs * 100) @@ -367,6 +381,8 @@ ess_mean.rvar <- function(x, ...) { #' Compute an effective sample size estimate for the standard deviation (SD) #' estimate of a single variable. This is defined as the effective sample size #' estimate for the absolute deviation from mean. +#' If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -390,7 +406,7 @@ ess_sd.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(.split_chains(abs(x-mean(x)))) } else { - r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + r_eff <- .ess(.split_chains(abs(x-mean(x)))) / (nrow(x) * ncol(x)) .ess_weighted(abs(x - mean(x)), weights = weights, r_eff = r_eff, ...) } } @@ -405,7 +421,8 @@ ess_sd.rvar <- function(x, ...) { #' Monte Carlo standard error for quantiles #' #' Compute Monte Carlo standard errors for quantile estimates of a -#' single variable. +#' single variable. If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -487,7 +504,8 @@ mcse_median <- function(x, ...) { #' Monte Carlo standard error for the mean #' #' Compute the Monte Carlo standard error for the mean (expectation) of a -#' single variable. +#' single variable. If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -564,6 +582,10 @@ mcse_sd.default <- function(x, ...) { # differentials of the moments can be neglected " varsd <- varvar / Evar / 4 sqrt(varsd) + + + #TODO: add weighted version + } #' @rdname mcse_sd From 6ce80dfd9eea365d6d9b74b17920e5a6bf45603e Mon Sep 17 00:00:00 2001 From: n-kall Date: Fri, 22 Mar 2024 16:59:04 +0200 Subject: [PATCH 18/23] fix weighted mcse for sd --- R/convergence.R | 45 ++++++++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index d10d937d..ea715fab 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -570,26 +570,33 @@ mcse_sd <- function(x, ...) UseMethod("mcse_sd") #' @rdname mcse_sd #' @export -mcse_sd.default <- function(x, ...) { - # var/sd are not a simple expectation of g(X), e.g. variance - # has (X-E[X])^2. The following ESS is based on a relevant quantity - # in the computation and is empirically a good choice. - sims_c <- x - mean(x) - ess <- ess_mean((sims_c)^2) - # Variance of variance estimate by Kenney and Keeping (1951, p. 141), - # which doesn't assume normality of sims. - Evar <- mean(sims_c^2) - varvar <- (mean(sims_c^4) - Evar^2) / ess - # The first order Taylor series approximation of variance of sd. - # Kenney and Keeping (1951, p. 141) write "...since fluctuations of - # any moment are of order N^{-1/2}, squares and higher powers of - # differentials of the moments can be neglected " - varsd <- varvar / Evar / 4 - sqrt(varsd) - - - #TODO: add weighted version +mcse_sd.default <- function(x, weights = NULL, ...) { + + if (is.null(weights)) { + # var/sd are not a simple expectation of g(X), e.g. variance + # has (X-E[X])^2. The following ESS is based on a relevant quantity + # in the computation and is empirically a good choice. + sims_c <- x - mean(x) + ess <- ess_mean((sims_c)^2) + # Variance of variance estimate by Kenney and Keeping (1951, p. 141), + # which doesn't assume normality of sims. + Evar <- mean(sims_c^2) + varvar <- (mean(sims_c^4) - Evar^2) / ess + # The first order Taylor series approximation of variance of sd. + # Kenney and Keeping (1951, p. 141) write "...since fluctuations of + # any moment are of order N^{-1/2}, squares and higher powers of + # differentials of the moments can be neglected " + varsd <- varvar / Evar / 4 + sqrt(varsd) + + } else { + + sims_c <- x - mean(x) + ess <- ess_mean((sims_c)^2) + r_eff <- ess / (nrow(x) * ncol(x)) + .mcse_weighted(sims_c, weights, r_eff, ...) + } } #' @rdname mcse_sd From c0d8b8e236b0a76efe670566ac191172ebced87b Mon Sep 17 00:00:00 2001 From: n-kall Date: Mon, 8 Apr 2024 13:01:56 +0300 Subject: [PATCH 19/23] fixes to pareto smoothing for weighted draws --- R/pareto_smooth.R | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index 60c45cd6..575ee8f5 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -51,8 +51,8 @@ pareto_khat.rvar <- function(x, verbose = FALSE, ...) { draws_diags <- summarise_rvar_by_element_with_chains( x, pareto_smooth.default, - return_k = TRUE, smooth_draws = FALSE, + return_k = TRUE, verbose = verbose, ... ) @@ -74,10 +74,13 @@ pareto_khat.rvar <- function(x, verbose = FALSE, ...) { w <- weights(x) - x <- weight_draws(x, NULL) + xu <- weight_draws(x, NULL) + xu <- xu * rvar(w) + product_diags <- summarise_rvar_by_element_with_chains( - x * rvar(w, nchains = nchains(x)), - pareto_khat, + xu, + pareto_khat.default, + verbose = verbose, ... ) @@ -312,7 +315,7 @@ pareto_smooth.rvar <- function(x, return_k = FALSE, extra_diags = FALSE, ...) { #' @export pareto_smooth.default <- function(x, tail = c("both", "right", "left"), - r_eff = 1, + r_eff = NULL, ndraws_tail = NULL, return_k = FALSE, extra_diags = FALSE, @@ -502,6 +505,8 @@ pareto_convergence_rate.rvar <- function(x, ...) { ... ) { + x <- as.numeric(x) + if (are_log_weights) { # shift log values for safe exponentiation x <- x - max(x) From dbdab633c5c52008dda257484e97e74a28b5808f Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 9 Apr 2024 12:33:52 +0300 Subject: [PATCH 20/23] do not unintentionally merge chains in pareto smoothing --- R/pareto_smooth.R | 2 -- 1 file changed, 2 deletions(-) diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index 575ee8f5..794c7188 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -505,8 +505,6 @@ pareto_convergence_rate.rvar <- function(x, ...) { ... ) { - x <- as.numeric(x) - if (are_log_weights) { # shift log values for safe exponentiation x <- x - max(x) From a6cb3fb6c5388ff8a4e7f1f27348232eccb77671 Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 9 Apr 2024 12:41:23 +0300 Subject: [PATCH 21/23] add test for pareto_khat on weighted rvar --- tests/testthat/test-pareto_smooth.R | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/testthat/test-pareto_smooth.R b/tests/testthat/test-pareto_smooth.R index 89635600..abf540a6 100644 --- a/tests/testthat/test-pareto_smooth.R +++ b/tests/testthat/test-pareto_smooth.R @@ -203,3 +203,29 @@ test_that("pareto_smooth works for log_weights", { expect_true(ps$diagnostics$khat > 0.7) }) + + + +test_that("pareto khat works for weighted rvars", { + + x <- cbind( + rnorm(100), + rnorm(100), + rnorm(100), + rnorm(100) + ) + + xr <- rvar(x, with_chains = TRUE) + + # target is normal(0, 1.2), should have high pareto-khat + w2 <- as.numeric(dnorm(x, sd = 5) / dnorm(x)) + w2 <- w2 / sum(w2) + xw2 <- weight_draws(xr, weights = w2) + + k <- pareto_khat(xw2)$khat + kw <- pareto_khat(w2, are_log_weights = TRUE)$khat + kp <- pareto_khat(draws_of(xw2) * w2)$khat + + expect_true(k > 0.7) + expect_equal(k, max(kw, kp)) +}) From d928bd4134fbee0443f63cae081178757215d9fb Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 9 Apr 2024 12:46:33 +0300 Subject: [PATCH 22/23] updates to weighted mcse for sd --- R/convergence.R | 42 +++++++++++++++++++++---------- tests/testthat/test-convergence.R | 2 +- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index ea715fab..6b4dc5fa 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -324,7 +324,7 @@ ess_median <- function(x, ...) { len <- length(x) prob <- (len - 0.5) / len } - I <- x <= quantile(x, prob) + I <- (x <= quantile(x, prob)) .ess(.split_chains(I)) } @@ -337,7 +337,7 @@ ess_median <- function(x, ...) { len <- length(x) prob <- (len - 0.5) / len } - I <- x <= weighted_quantile(x, prob, weights = weights) + I <- (x <= weighted_quantile(x, prob, weights = weights)) .ess_weighted(I, weights = weights, r_eff = r_eff) } @@ -408,9 +408,9 @@ ess_sd <- function(x, ...) UseMethod("ess_sd") #' @export ess_sd.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { - .ess(.split_chains(abs(x-mean(x)))) + .ess(.split_chains(abs(x - mean(x)))) } else { - r_eff <- .ess(.split_chains(abs(x-mean(x)))) / (nrow(x) * ncol(x)) + r_eff <- .ess(.split_chains(abs(x - mean(x)))) / (nrow(x) * ncol(x)) .ess_weighted(abs(x - mean(x)), weights = weights, r_eff = r_eff, ...) } } @@ -422,6 +422,8 @@ ess_sd.rvar <- function(x, ...) { summarise_rvar_by_element_with_chains(x, ess_sd, weights = weights, ...) } +# TODO: ess_weights + #' Monte Carlo standard error for quantiles #' #' Compute Monte Carlo standard errors for quantile estimates of a @@ -482,6 +484,7 @@ mcse_median <- function(x, ...) { } # MCSE of a single quantile +# TODO: refer to paper .mcse_quantile <- function(x, prob) { ess <- ess_quantile(x, prob) p <- c(0.1586553, 0.8413447) @@ -499,8 +502,8 @@ mcse_median <- function(x, ...) { a <- qbeta(p, ess * prob + 1, ess * (1 - prob) + 1) ssims <- sort(x) S <- length(ssims) - th1 <- ssims[max(floor(a[1] * S), 1)] - th2 <- ssims[min(ceiling(a[2] * S), S)] + th1 <- ssims[max(floor(a[1] * S), 1)] # adjust to account for weights + th2 <- ssims[min(ceiling(a[2] * S), S)] #adjust to account for weights as.vector((th2 - th1) / 2) } @@ -573,7 +576,7 @@ mcse_sd <- function(x, ...) UseMethod("mcse_sd") mcse_sd.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { - + # var/sd are not a simple expectation of g(X), e.g. variance # has (X-E[X])^2. The following ESS is based on a relevant quantity # in the computation and is empirically a good choice. @@ -582,7 +585,8 @@ mcse_sd.default <- function(x, weights = NULL, ...) { # Variance of variance estimate by Kenney and Keeping (1951, p. 141), # which doesn't assume normality of sims. Evar <- mean(sims_c^2) - varvar <- (mean(sims_c^4) - Evar^2) / ess + varvar <- (mean(sims_c^4) - Evar^2) / ess # (Equation 6.20) + # The first order Taylor series approximation of variance of sd. # Kenney and Keeping (1951, p. 141) write "...since fluctuations of # any moment are of order N^{-1/2}, squares and higher powers of @@ -592,10 +596,23 @@ mcse_sd.default <- function(x, weights = NULL, ...) { } else { - sims_c <- x - mean(x) - ess <- ess_mean((sims_c)^2) - r_eff <- ess / (nrow(x) * ncol(x)) - .mcse_weighted(sims_c, weights, r_eff, ...) + # for weights try varvar weighted / varvar unweighted to see relative efficiency of weights + + first_moment_weighted <- weighted.mean(x, w = weights) + + x_centered <- x - first_moment_weighted + second_moment_weighted <- weighted.mean(x_centered^2, w = weights) + fourth_moment_weighted <- weighted.mean(x_centered^4, w = weights) + + r_eff <- .ess(x_centered^2) / (nrow(x) * ncol(x)) + weighted_ess <- .ess_weighted(x_centered^2, weights = weights, r_eff = r_eff) + + # Kenney and Keeping (1951, eq 6.20) + varvar_weighted <- (fourth_moment_weighted - second_moment_weighted^2) / weighted_ess + + # First-order Taylor series approximation + varsd <- varvar_weighted / second_moment_weighted / 4 + sqrt(varsd) } } @@ -918,4 +935,3 @@ should_return_NA <- function(x, tol = .Machine$double.eps) { # } is_constant(x, tol = tol) } - diff --git a/tests/testthat/test-convergence.R b/tests/testthat/test-convergence.R index a471e75b..1e9034a9 100644 --- a/tests/testthat/test-convergence.R +++ b/tests/testthat/test-convergence.R @@ -196,8 +196,8 @@ test_that("weighted convergence measures work", { # here ess for mean and q95 should be lower, but for q5 it should be higher w3 <- as.numeric(dnorm(x, mean = 1, sd = 1) / dnorm(x)) w3 <- w3 / sum(w3) - xw3 <- weight_draws(xr, weights = w3) + expect_true(ess_mean(xw3) < ess_mean(xr)) expect_true(mcse_mean(xw3) > mcse_mean(xr)) From 14c01dea26af8f866c2c75cf3aa74a298181cfeb Mon Sep 17 00:00:00 2001 From: n-kall Date: Tue, 9 Apr 2024 14:29:52 +0300 Subject: [PATCH 23/23] updates to mcse for weighted draws --- R/convergence.R | 64 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 58 insertions(+), 6 deletions(-) diff --git a/R/convergence.R b/R/convergence.R index 6b4dc5fa..0fb6e95f 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -199,6 +199,13 @@ ess_bulk.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(z_scale(.split_chains(x))) } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + r_eff <- .ess(z_scale(.split_chains(x))) / (nrow(x) * ncol(x)) .ess_weighted(x, weights, r_eff = r_eff, ...) } @@ -291,6 +298,12 @@ ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights out <- ulapply(probs, .ess_quantile, x = x) } else { + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + r_eff <- ulapply(probs, .ess_quantile, x = x) / (nrow(x) * ncol(x)) out <- mapply(.ess_quantile_weighted, prob = probs, r_eff = r_eff, MoreArgs = list(x = x, weights = weights)) @@ -368,6 +381,13 @@ ess_mean.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(.split_chains(x)) } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) .ess_weighted(x, weights, r_eff = r_eff, ...) } @@ -410,6 +430,13 @@ ess_sd.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { .ess(.split_chains(abs(x - mean(x)))) } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + r_eff <- .ess(.split_chains(abs(x - mean(x)))) / (nrow(x) * ncol(x)) .ess_weighted(abs(x - mean(x)), weights = weights, r_eff = r_eff, ...) } @@ -460,8 +487,14 @@ mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weight if (is.null(weights)) { out <- ulapply(probs, .mcse_quantile, x = x) } else { - r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) - out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) / r_eff + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + + out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) } if (names) { names(out) <- paste0("mcse_q", probs * 100) @@ -493,6 +526,7 @@ mcse_median <- function(x, ...) { S <- length(ssims) th1 <- ssims[max(floor(a[1] * S), 1)] th2 <- ssims[min(ceiling(a[2] * S), S)] + as.vector((th2 - th1) / 2) } @@ -500,10 +534,15 @@ mcse_median <- function(x, ...) { ess <- ess_quantile(x, prob, weights = weights) p <- c(0.1586553, 0.8413447) a <- qbeta(p, ess * prob + 1, ess * (1 - prob) + 1) - ssims <- sort(x) - S <- length(ssims) - th1 <- ssims[max(floor(a[1] * S), 1)] # adjust to account for weights - th2 <- ssims[min(ceiling(a[2] * S), S)] #adjust to account for weights + x_idx <- order(x) + x_sorted <- x[x_idx] + weights_sorted <- weights[x_idx] + S <- length(x) + + cweights <- cumsum(weights_sorted) + th1 <- x_sorted[max(max(which(cweights < a[1])), 1)] + th2 <- x_sorted[min(min(which(cweights > a[2])), S)] + as.vector((th2 - th1) / 2) } @@ -536,6 +575,13 @@ mcse_mean.default <- function(x, weights = NULL, ...) { if (is.null(weights)) { sd(x) / sqrt(ess_mean(x)) } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) .mcse_weighted(x, weights, r_eff, ...) } @@ -596,6 +642,12 @@ mcse_sd.default <- function(x, weights = NULL, ...) { } else { + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + # for weights try varvar weighted / varvar unweighted to see relative efficiency of weights first_moment_weighted <- weighted.mean(x, w = weights)