diff --git a/R/misc.R b/R/misc.R index 02c8df6a..da26fd0f 100644 --- a/R/misc.R +++ b/R/misc.R @@ -1,4 +1,4 @@ -# initialize a named list + # initialize a named list # @param names names of the elements # @param values optional values of the elements named_list <- function(names, values = NULL) { @@ -209,9 +209,16 @@ escape_all <- function(x) { # numerically stable version of log(sum(exp(x))) log_sum_exp <- function(x) { - max <- max(x) - sum <- sum(exp(x - max)) - max + log(sum) + max <- max(as.numeric(x), warnings = FALSE) + if (max == -Inf) { + res <- 0 + } else if (max == Inf) { + res <- Inf + } else { + sum <- sum(exp(x - max)) + res <- max + log(sum) + } + res } # simple version of destructuring assignment diff --git a/tests/testthat/test-log_sum_exp.R b/tests/testthat/test-log_sum_exp.R new file mode 100644 index 00000000..3e343938 --- /dev/null +++ b/tests/testthat/test-log_sum_exp.R @@ -0,0 +1,21 @@ +# Ensure that log_sum_exp agrees with expected behaviour from log(sum(exp(x))) +test_that("log_sum_exp of for x containing Inf is Inf", { + x <- c(Inf, 1) + expect_equal(log_sum_exp(x), Inf) +}) + +test_that("log_sum_exp of -Inf is -Inf", { + x <- c(-Inf, -Inf) + expect_equal(log_sum_exp(x), -Inf) +}) + +test_that("log_sum_exp of empty input is -Inf", { + x <- numeric(0) + expect_equal(log_sum_exp(x), -Inf) +}) + +test_that("log_sum_exp works", { + set.seed(1) + x <- log(runif(10)) + expect_equal(log_sum_exp(x), log(sum(exp(x)))) +})