From 4a4f2124f2f8f83b6cef98a93e5fbad1c9584921 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Teemu=20S=C3=A4ilynoja?= <33978952+TeemuSailynoja@users.noreply.github.com> Date: Thu, 27 Mar 2025 12:44:36 +0200 Subject: [PATCH 1/2] Update log_sum_exp Add behaviour to the function in the corner cases of infinite values or empty input. --- R/misc.R | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/R/misc.R b/R/misc.R index 02c8df6a..da26fd0f 100644 --- a/R/misc.R +++ b/R/misc.R @@ -1,4 +1,4 @@ -# initialize a named list + # initialize a named list # @param names names of the elements # @param values optional values of the elements named_list <- function(names, values = NULL) { @@ -209,9 +209,16 @@ escape_all <- function(x) { # numerically stable version of log(sum(exp(x))) log_sum_exp <- function(x) { - max <- max(x) - sum <- sum(exp(x - max)) - max + log(sum) + max <- max(as.numeric(x), warnings = FALSE) + if (max == -Inf) { + res <- 0 + } else if (max == Inf) { + res <- Inf + } else { + sum <- sum(exp(x - max)) + res <- max + log(sum) + } + res } # simple version of destructuring assignment From fcec2206b66739adcdd033220b688520397d0d56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Teemu=20S=C3=A4ilynoja?= Date: Thu, 27 Mar 2025 19:04:28 +0200 Subject: [PATCH 2/2] Test log_sum_exp --- tests/testthat/test-log_sum_exp.R | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 tests/testthat/test-log_sum_exp.R diff --git a/tests/testthat/test-log_sum_exp.R b/tests/testthat/test-log_sum_exp.R new file mode 100644 index 00000000..3e343938 --- /dev/null +++ b/tests/testthat/test-log_sum_exp.R @@ -0,0 +1,21 @@ +# Ensure that log_sum_exp agrees with expected behaviour from log(sum(exp(x))) +test_that("log_sum_exp of for x containing Inf is Inf", { + x <- c(Inf, 1) + expect_equal(log_sum_exp(x), Inf) +}) + +test_that("log_sum_exp of -Inf is -Inf", { + x <- c(-Inf, -Inf) + expect_equal(log_sum_exp(x), -Inf) +}) + +test_that("log_sum_exp of empty input is -Inf", { + x <- numeric(0) + expect_equal(log_sum_exp(x), -Inf) +}) + +test_that("log_sum_exp works", { + set.seed(1) + x <- log(runif(10)) + expect_equal(log_sum_exp(x), log(sum(exp(x)))) +})