diff --git a/R/convergence.R b/R/convergence.R index bbb8b12e..0fb6e95f 100644 --- a/R/convergence.R +++ b/R/convergence.R @@ -82,6 +82,8 @@ rhat_basic.rvar <- function(x, split = TRUE, ...) { #' recommend the improved ESS convergence diagnostics implemented in #' [ess_bulk()] and [ess_tail()]. See Vehtari (2021) for an in-depth #' comparison of different effective sample size estimators. +#' If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -104,18 +106,27 @@ ess_basic <- function(x, ...) UseMethod("ess_basic") #' @rdname ess_basic #' @export -ess_basic.default <- function(x, split = TRUE, ...) { +ess_basic.default <- function(x, split = TRUE, weights = NULL, ...) { split <- as_one_logical(split) if (split) { x <- .split_chains(x) } - .ess(x) + + if (is.null(weights)) { + .ess(x) + } else { + r_eff <- .ess(x) / (nrow(x) * ncol(x)) + .ess_weighted(x, weights, r_eff = r_eff, ...) + } } #' @rdname ess_basic #' @export ess_basic.rvar <- function(x, split = TRUE, ...) { - summarise_rvar_by_element_with_chains(x, ess_basic, split, ...) + + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_basic, split, weights = weights, ...) + } #' Rhat convergence diagnostic @@ -162,6 +173,8 @@ rhat.rvar <- function(x, ...) { #' rank normalized values using split chains. For the tail effective sample size #' see [ess_tail()]. See Vehtari (2021) for an in-depth #' comparison of different effective sample size estimators. +#' If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -182,14 +195,27 @@ ess_bulk <- function(x, ...) UseMethod("ess_bulk") #' @rdname ess_bulk #' @export -ess_bulk.default <- function(x, ...) { - .ess(z_scale(.split_chains(x))) +ess_bulk.default <- function(x, weights = NULL, ...) { + if (is.null(weights)) { + .ess(z_scale(.split_chains(x))) + } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + + r_eff <- .ess(z_scale(.split_chains(x))) / (nrow(x) * ncol(x)) + .ess_weighted(x, weights, r_eff = r_eff, ...) + } } #' @rdname ess_bulk #' @export ess_bulk.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_bulk, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_bulk, weights = weights, ...) } #' Tail effective sample size (tail-ESS) @@ -200,6 +226,8 @@ ess_bulk.rvar <- function(x, ...) { #' sample sizes for 5% and 95% quantiles. For the bulk effective sample #' size see [ess_bulk()]. See Vehtari (2021) for an in-depth #' comparison of different effective sample size estimators. +#' If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -220,22 +248,24 @@ ess_tail <- function(x, ...) UseMethod("ess_tail") #' @rdname ess_tail #' @export -ess_tail.default <- function(x, ...) { - q05_ess <- ess_quantile(x, 0.05) - q95_ess <- ess_quantile(x, 0.95) +ess_tail.default <- function(x, weights = NULL, ...) { + q05_ess <- ess_quantile(x, 0.05, weights = weights, ...) + q95_ess <- ess_quantile(x, 0.95, weights = weights, ...) min(q05_ess, q95_ess) } #' @rdname ess_tail #' @export ess_tail.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_tail, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_tail, weights = weights, ...) } #' Effective sample sizes for quantiles #' -#' Compute effective sample size estimates for quantile estimates of a single -#' variable. +#' Compute effective sample size estimates for quantile estimates of a +#' single variable. If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -258,13 +288,26 @@ ess_quantile <- function(x, probs = c(0.05, 0.95), ...) { #' @rdname ess_quantile #' @export -ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { +ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights = NULL, ...) { probs <- as.numeric(probs) if (any(probs < 0 | probs > 1)) { stop_no_call("'probs' must contain values between 0 and 1.") } names <- as_one_logical(names) - out <- ulapply(probs, .ess_quantile, x = x) + if (is.null(weights)) { + out <- ulapply(probs, .ess_quantile, x = x) + } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + + r_eff <- ulapply(probs, .ess_quantile, x = x) / (nrow(x) * ncol(x)) + out <- mapply(.ess_quantile_weighted, prob = probs, r_eff = r_eff, MoreArgs = list(x = x, weights = weights)) + + } if (names) { names(out) <- paste0("ess_q", probs * 100) } @@ -274,7 +317,8 @@ ess_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { #' @rdname ess_quantile #' @export ess_quantile.rvar <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { - summarise_rvar_by_element_with_chains(x, ess_quantile, probs, names, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_quantile, probs, weights = weights, names, ...) } #' @rdname ess_quantile @@ -293,10 +337,23 @@ ess_median <- function(x, ...) { len <- length(x) prob <- (len - 0.5) / len } - I <- x <= quantile(x, prob) + I <- (x <= quantile(x, prob)) .ess(.split_chains(I)) } +.ess_quantile_weighted <- function(x, prob, weights, r_eff) { + if (should_return_NA(x)) { + return(NA_real_) + } + x <- as.matrix(x) + if (prob == 1) { + len <- length(x) + prob <- (len - 0.5) / len + } + I <- (x <= weighted_quantile(x, prob, weights = weights)) + .ess_weighted(I, weights = weights, r_eff = r_eff) +} + #' Effective sample size for the mean #' #' Compute an effective sample size estimate for a mean (expectation) @@ -319,14 +376,28 @@ ess_mean <- function(x, ...) UseMethod("ess_mean") #' @rdname ess_quantile #' @export -ess_mean.default <- function(x, ...) { - .ess(.split_chains(x)) +ess_mean.default <- function(x, weights = NULL, ...) { + + if (is.null(weights)) { + .ess(.split_chains(x)) + } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + .ess_weighted(x, weights, r_eff = r_eff, ...) + } } #' @rdname ess_mean #' @export ess_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_mean, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_mean, weights = weights, ...) } #' Effective sample size for the standard deviation @@ -334,6 +405,8 @@ ess_mean.rvar <- function(x, ...) { #' Compute an effective sample size estimate for the standard deviation (SD) #' estimate of a single variable. This is defined as the effective sample size #' estimate for the absolute deviation from mean. +#' If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -353,20 +426,36 @@ ess_sd <- function(x, ...) UseMethod("ess_sd") #' @rdname ess_sd #' @export -ess_sd.default <- function(x, ...) { - .ess(.split_chains(abs(x-mean(x)))) +ess_sd.default <- function(x, weights = NULL, ...) { + if (is.null(weights)) { + .ess(.split_chains(abs(x - mean(x)))) + } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + + r_eff <- .ess(.split_chains(abs(x - mean(x)))) / (nrow(x) * ncol(x)) + .ess_weighted(abs(x - mean(x)), weights = weights, r_eff = r_eff, ...) + } } #' @rdname ess_sd #' @export ess_sd.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, ess_sd, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, ess_sd, weights = weights, ...) } +# TODO: ess_weights + #' Monte Carlo standard error for quantiles #' #' Compute Monte Carlo standard errors for quantile estimates of a -#' single variable. +#' single variable. If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -389,23 +478,36 @@ mcse_quantile <- function(x, probs = c(0.05, 0.95), ...) { #' @rdname mcse_quantile #' @export -mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { +mcse_quantile.default <- function(x, probs = c(0.05, 0.95), names = TRUE, weights = NULL, ...) { probs <- as.numeric(probs) if (any(probs < 0 | probs > 1)) { stop_no_call("'probs' must contain values between 0 and 1.") } names <- as_one_logical(names) - out <- ulapply(probs, .mcse_quantile, x = x) + if (is.null(weights)) { + out <- ulapply(probs, .mcse_quantile, x = x) + } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + + out <- ulapply(probs, .mcse_quantile_weighted, x = x, weights = weights) + } if (names) { names(out) <- paste0("mcse_q", probs * 100) } + out } #' @rdname mcse_quantile #' @export mcse_quantile.rvar <- function(x, probs = c(0.05, 0.95), names = TRUE, ...) { - summarise_rvar_by_element_with_chains(x, mcse_quantile, probs, names, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_quantile, probs, names, weights = weights, ...) } #' @rdname mcse_quantile @@ -415,6 +517,7 @@ mcse_median <- function(x, ...) { } # MCSE of a single quantile +# TODO: refer to paper .mcse_quantile <- function(x, prob) { ess <- ess_quantile(x, prob) p <- c(0.1586553, 0.8413447) @@ -423,13 +526,32 @@ mcse_median <- function(x, ...) { S <- length(ssims) th1 <- ssims[max(floor(a[1] * S), 1)] th2 <- ssims[min(ceiling(a[2] * S), S)] + + as.vector((th2 - th1) / 2) +} + +.mcse_quantile_weighted <- function(x, prob, weights) { + ess <- ess_quantile(x, prob, weights = weights) + p <- c(0.1586553, 0.8413447) + a <- qbeta(p, ess * prob + 1, ess * (1 - prob) + 1) + x_idx <- order(x) + x_sorted <- x[x_idx] + weights_sorted <- weights[x_idx] + S <- length(x) + + cweights <- cumsum(weights_sorted) + th1 <- x_sorted[max(max(which(cweights < a[1])), 1)] + th2 <- x_sorted[min(min(which(cweights > a[2])), S)] + as.vector((th2 - th1) / 2) } + #' Monte Carlo standard error for the mean #' #' Compute the Monte Carlo standard error for the mean (expectation) of a -#' single variable. +#' single variable. If computed on a weighted `rvar`, weights will be +#' taken into account. #' #' @family diagnostics #' @template args-conv @@ -449,14 +571,27 @@ mcse_mean <- function(x, ...) UseMethod("mcse_mean") #' @rdname mcse_mean #' @export -mcse_mean.default <- function(x, ...) { - sd(x) / sqrt(ess_mean(x)) +mcse_mean.default <- function(x, weights = NULL, ...) { + if (is.null(weights)) { + sd(x) / sqrt(ess_mean(x)) + } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + + r_eff <- .ess(.split_chains(x)) / (nrow(x) * ncol(x)) + .mcse_weighted(x, weights, r_eff, ...) + } } #' @rdname mcse_mean #' @export mcse_mean.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, mcse_mean, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_mean, weights = weights, ...) } #' Monte Carlo standard error for the standard deviation @@ -484,28 +619,60 @@ mcse_sd <- function(x, ...) UseMethod("mcse_sd") #' @rdname mcse_sd #' @export -mcse_sd.default <- function(x, ...) { - # var/sd are not a simple expectation of g(X), e.g. variance - # has (X-E[X])^2. The following ESS is based on a relevant quantity - # in the computation and is empirically a good choice. - sims_c <- x - mean(x) - ess <- ess_mean((sims_c)^2) - # Variance of variance estimate by Kenney and Keeping (1951, p. 141), - # which doesn't assume normality of sims. - Evar <- mean(sims_c^2) - varvar <- (mean(sims_c^4) - Evar^2) / ess - # The first order Taylor series approximation of variance of sd. - # Kenney and Keeping (1951, p. 141) write "...since fluctuations of - # any moment are of order N^{-1/2}, squares and higher powers of - # differentials of the moments can be neglected " - varsd <- varvar / Evar / 4 - sqrt(varsd) +mcse_sd.default <- function(x, weights = NULL, ...) { + + if (is.null(weights)) { + + # var/sd are not a simple expectation of g(X), e.g. variance + # has (X-E[X])^2. The following ESS is based on a relevant quantity + # in the computation and is empirically a good choice. + sims_c <- x - mean(x) + ess <- ess_mean((sims_c)^2) + # Variance of variance estimate by Kenney and Keeping (1951, p. 141), + # which doesn't assume normality of sims. + Evar <- mean(sims_c^2) + varvar <- (mean(sims_c^4) - Evar^2) / ess # (Equation 6.20) + + # The first order Taylor series approximation of variance of sd. + # Kenney and Keeping (1951, p. 141) write "...since fluctuations of + # any moment are of order N^{-1/2}, squares and higher powers of + # differentials of the moments can be neglected " + varsd <- varvar / Evar / 4 + sqrt(varsd) + + } else { + + # normalise weights + weights <- weights / sum(weights) + + # ensure x has rows and columns + x <- as.matrix(x) + + # for weights try varvar weighted / varvar unweighted to see relative efficiency of weights + + first_moment_weighted <- weighted.mean(x, w = weights) + + x_centered <- x - first_moment_weighted + second_moment_weighted <- weighted.mean(x_centered^2, w = weights) + fourth_moment_weighted <- weighted.mean(x_centered^4, w = weights) + + r_eff <- .ess(x_centered^2) / (nrow(x) * ncol(x)) + weighted_ess <- .ess_weighted(x_centered^2, weights = weights, r_eff = r_eff) + + # Kenney and Keeping (1951, eq 6.20) + varvar_weighted <- (fourth_moment_weighted - second_moment_weighted^2) / weighted_ess + + # First-order Taylor series approximation + varsd <- varvar_weighted / second_moment_weighted / 4 + sqrt(varsd) + } } #' @rdname mcse_sd #' @export mcse_sd.rvar <- function(x, ...) { - summarise_rvar_by_element_with_chains(x, mcse_sd, ...) + weights <- weights(x) + summarise_rvar_by_element_with_chains(x, mcse_sd, weights = weights, ...) } #' Compute Quantiles @@ -784,6 +951,23 @@ fold_draws <- function(x) { ess } +.mcse_weighted <- function(x, weights, r_eff, ...) { + # Vehtari et al. 2022 equation 6 + + x <- as.numeric(x) + weighted_mean <- matrixStats::weightedMean(x, w = weights) + + sqrt(weights^2 %*% (x - c(weighted_mean))^2 / r_eff) +} + +.ess_weighted <- function(x, weights, r_eff, ...) { + # Vehtari et al. 2022 equation 7 + mcse <- .mcse_weighted(x, weights, r_eff, ...) + + var <- mean((x - mean(x))^2) + var / mcse^2 +} + # should NA be returned by a convergence diagnostic? should_return_NA <- function(x, tol = .Machine$double.eps) { if (anyNA(x) || checkmate::anyInfinite(x)) { diff --git a/R/pareto_smooth.R b/R/pareto_smooth.R index a0bb8486..794c7188 100644 --- a/R/pareto_smooth.R +++ b/R/pareto_smooth.R @@ -46,21 +46,55 @@ pareto_khat.default <- function(x, #' @rdname pareto_khat #' @export -pareto_khat.rvar <- function(x, ...) { - draws_diags <- summarise_rvar_by_element_with_chains( - x, - pareto_smooth.default, - return_k = TRUE, - smooth_draws = FALSE, - ... - ) - dim(draws_diags) <- dim(draws_diags) %||% length(draws_diags) - margins <- seq_along(dim(draws_diags)) +pareto_khat.rvar <- function(x, verbose = FALSE, ...) { + if (is.null(weights(x))) { + draws_diags <- summarise_rvar_by_element_with_chains( + x, + pareto_smooth.default, + smooth_draws = FALSE, + return_k = TRUE, + verbose = verbose, + ... + ) - diags <- list( - khat = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat) - ) + dim(draws_diags) <- dim(draws_diags) %||% length(draws_diags) + margins <- seq_along(dim(draws_diags)) + + diags <- list( + khat = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat) + ) + } else { + + # take the max of khat for x * weights and khat for weights + weights_diags <- pareto_khat( + weights(x, log = TRUE), + are_log_weights = TRUE, + ... + ) + + w <- weights(x) + + xu <- weight_draws(x, NULL) + xu <- xu * rvar(w) + + product_diags <- summarise_rvar_by_element_with_chains( + xu, + pareto_khat.default, + verbose = verbose, + ... + ) + + dim(product_diags) <- dim(product_diags) %||% length(product_diags) + margins <- seq_along(dim(product_diags)) + diags <- list( + khat = apply(product_diags, margins, + function(x) { + max(x[[1]]$khat, + weights_diags$khat) + }) + ) + } diags } @@ -149,6 +183,8 @@ pareto_diags.default <- function(x, #' @rdname pareto_diags #' @export pareto_diags.rvar <- function(x, ...) { + + if (is.null(weights(x))) { draws_diags <- summarise_rvar_by_element_with_chains( x, pareto_smooth.default, @@ -167,6 +203,35 @@ pareto_diags.rvar <- function(x, ...) { khat_threshold = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat_threshold), convergence_rate = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$convergence_rate) ) + } else { + + # take the max of khat for x * weights and khat for weights + + weights_diags <- pareto_diags( + weights(x, log = TRUE), + are_log_weights = TRUE, + ... + ) + + w <- weights(x) + + x <- weight_draws(x, NULL) + product_diags <- summarise_rvar_by_element_with_chains( + x * rvar(w, nchains = nchains(x)), + pareto_diags, + ... + ) + + dim(product_diags) <- dim(product_diags) %||% length(product_diags) + margins <- seq_along(dim(product_diags)) + + diags <- list( + khat = apply(product_diags, margins, function(x) max(x[[1]]$khat, weights_diags$khat)), + min_ss = apply(product_diags, margins, function(x) max(x[[1]]$min_ss, weights_diags$min_ss)), + khat_threshold = apply(product_diags, margins, function(x) max(x[[1]]$khat_threshold, weights_diags$khat_threshold)), + convergence_rate = apply(product_diags, margins, function(x) min(x[[1]]$convergence_rate, weights_diags$convergence_rate)) + ) + } diags } @@ -250,7 +315,7 @@ pareto_smooth.rvar <- function(x, return_k = FALSE, extra_diags = FALSE, ...) { #' @export pareto_smooth.default <- function(x, tail = c("both", "right", "left"), - r_eff = 1, + r_eff = NULL, ndraws_tail = NULL, return_k = FALSE, extra_diags = FALSE, @@ -279,7 +344,7 @@ pareto_smooth.default <- function(x, if (are_log_weights) { tail <- "right" } - + tail <- match.arg(tail) S <- length(x) @@ -330,7 +395,7 @@ pareto_smooth.default <- function(x, k <- max(left_k, right_k) x <- smoothed$x - + } else { smoothed <- .pareto_smooth_tail( @@ -444,7 +509,7 @@ pareto_convergence_rate.rvar <- function(x, ...) { # shift log values for safe exponentiation x <- x - max(x) } - + tail <- match.arg(tail) S <- length(x) @@ -458,10 +523,10 @@ pareto_convergence_rate.rvar <- function(x, ...) { draws_tail <- ord$x[tail_ids] cutoff <- ord$x[min(tail_ids) - 1] # largest value smaller than tail values - + max_tail <- max(draws_tail) min_tail <- min(draws_tail) - + if (ndraws_tail >= 5) { ord <- sort.int(x, index.return = TRUE) if (abs(max_tail - min_tail) < .Machine$double.eps / 100) { @@ -617,7 +682,7 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { msg <- NULL if (!are_weights) { - + if (khat > 1) { msg <- paste0(msg, " Mean does not exist, making empirical mean estimate of the draws not applicable.") } else { @@ -630,7 +695,7 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) { } } else { if (khat > khat_threshold || khat > 0.7) { - msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") + msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n") } } message("Pareto k-hat = ", round(khat, 2), ".", msg) diff --git a/R/summarise_draws.R b/R/summarise_draws.R index 6f13755a..05acd83b 100644 --- a/R/summarise_draws.R +++ b/R/summarise_draws.R @@ -329,9 +329,9 @@ empty_draws_summary <- function(dimensions = "variable") { create_summary_list <- function(x, v, funs, .args) { draws <- drop_dims_or_classes(x[, , v], dims = 3, reset_class = FALSE) - args <- c(list(draws), .args) v_summary <- named_list(names(funs)) for (m in names(funs)) { + args <- c(list(draws), .args[[m]]) v_summary[[m]] <- do.call(funs[[m]], args) } v_summary diff --git a/tests/testthat/test-convergence.R b/tests/testthat/test-convergence.R index 4f9621aa..1e9034a9 100644 --- a/tests/testthat/test-convergence.R +++ b/tests/testthat/test-convergence.R @@ -149,3 +149,61 @@ test_that("autocovariance returns correct results", { test_that("NA quantile2 works", { expect_equal(quantile2(NA_real_, c(0.25, 0.75)), c(q25 = NA_real_, q75 = NA_real_)) }) + + +test_that("weighted convergence measures work", { + + # draws from standard normal + x <- cbind( + rnorm(100), + rnorm(100), + rnorm(100), + rnorm(100) + ) + + xr <- rvar(x, with_chains = TRUE) + + # target is normal(0, 0.5) + # here, ess should be higher for mean + # mcse should be lower for mean + w1 <- as.numeric(dnorm(x, sd = 0.5) / dnorm(x)) + w1 <- w1 / sum(w1) + xw1 <- weight_draws(xr, weights = w1) + + expect_true(ess_mean(xw1) > ess_mean(xr)) + expect_true(mcse_mean(xw1) < mcse_mean(xr)) + expect_true(ess_quantile(xw1, probs = 0.05) > ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw1, probs = 0.95) > ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw1, probs = 0.05) < mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw1, probs = 0.95) < mcse_quantile(xr, probs = 0.95)) + + # target is normal(0, 1.2) + # here ess should be lower, and mcse should be higher + w2 <- as.numeric(dnorm(x, sd = 1.2) / dnorm(x)) + w2 <- w2 / sum(w2) + xw2 <- weight_draws(xr, weights = w2) + + expect_true(ess_mean(xw2) < ess_mean(xr)) + expect_true(mcse_mean(xw2) > mcse_mean(xr)) + + expect_true(ess_quantile(xw2, probs = 0.05) < ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw2, probs = 0.95) < ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw2, probs = 0.05) > mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw2, probs = 0.95) > mcse_quantile(xr, probs = 0.95)) + + + # target is normal(1, 1) + # here ess for mean and q95 should be lower, but for q5 it should be higher + w3 <- as.numeric(dnorm(x, mean = 1, sd = 1) / dnorm(x)) + w3 <- w3 / sum(w3) + xw3 <- weight_draws(xr, weights = w3) + + expect_true(ess_mean(xw3) < ess_mean(xr)) + expect_true(mcse_mean(xw3) > mcse_mean(xr)) + + expect_true(ess_quantile(xw3, probs = 0.05) > ess_quantile(xr, probs = 0.05)) + expect_true(ess_quantile(xw3, probs = 0.95) < ess_quantile(xr, probs = 0.95)) + expect_true(mcse_quantile(xw3, probs = 0.05) < mcse_quantile(xr, probs = 0.05)) + expect_true(mcse_quantile(xw3, probs = 0.95) > mcse_quantile(xr, probs = 0.95)) + +}) diff --git a/tests/testthat/test-pareto_smooth.R b/tests/testthat/test-pareto_smooth.R index 89635600..abf540a6 100644 --- a/tests/testthat/test-pareto_smooth.R +++ b/tests/testthat/test-pareto_smooth.R @@ -203,3 +203,29 @@ test_that("pareto_smooth works for log_weights", { expect_true(ps$diagnostics$khat > 0.7) }) + + + +test_that("pareto khat works for weighted rvars", { + + x <- cbind( + rnorm(100), + rnorm(100), + rnorm(100), + rnorm(100) + ) + + xr <- rvar(x, with_chains = TRUE) + + # target is normal(0, 1.2), should have high pareto-khat + w2 <- as.numeric(dnorm(x, sd = 5) / dnorm(x)) + w2 <- w2 / sum(w2) + xw2 <- weight_draws(xr, weights = w2) + + k <- pareto_khat(xw2)$khat + kw <- pareto_khat(w2, are_log_weights = TRUE)$khat + kp <- pareto_khat(draws_of(xw2) * w2)$khat + + expect_true(k > 0.7) + expect_equal(k, max(kw, kp)) +})