From 897f6fe591350135243f7c26819109e424d573dd Mon Sep 17 00:00:00 2001 From: jgabry Date: Fri, 8 May 2026 12:24:54 +0200 Subject: [PATCH] `$lp_approx()`, `$mle()` return numeric vectors even if default draws format is not a matrix --- R/fit.R | 7 ++++--- tests/testthat/test-fit-laplace.R | 13 +++++++++++++ tests/testthat/test-fit-mle.R | 11 +++++++++++ tests/testthat/test-fit-vb.R | 13 +++++++++++++ 4 files changed, 41 insertions(+), 3 deletions(-) diff --git a/R/fit.R b/R/fit.R index ead1f0a1..73df4476 100644 --- a/R/fit.R +++ b/R/fit.R @@ -768,7 +768,8 @@ CmdStanFit$set("public", name = "lp", value = lp) # will be used by a subset of fit objects below #' @rdname fit-method-lp lp_approx <- function() { - as.numeric(self$draws()[, "lp_approx__"]) + x <- self$draws(variables = "lp_approx__", format = "draws_matrix") + as.numeric(x[, "lp_approx__"]) } @@ -1973,8 +1974,8 @@ CmdStanMLE <- R6::R6Class( #' } #' mle <- function(variables = NULL) { - x <- self$draws(variables) - x <- x[, colnames(x) != "lp__"] + x <- self$draws(variables = variables, format = "draws_matrix") + x <- x[, colnames(x) != "lp__", drop = FALSE] stats::setNames(as.numeric(x), posterior::variables(x)) } CmdStanMLE$set("public", name = "mle", value = mle) diff --git a/tests/testthat/test-fit-laplace.R b/tests/testthat/test-fit-laplace.R index 54933eef..b2258a58 100644 --- a/tests/testthat/test-fit-laplace.R +++ b/tests/testthat/test-fit-laplace.R @@ -41,6 +41,19 @@ test_that("lp(), lp_approx() methods return vectors (reading csv works)", { expect_equal(length(lg), length(lp)) }) +test_that("lp_approx() ignores non-matrix default draws formats", { + expected <- fit_laplace$draws( + variables = "lp_approx__", + format = "draws_matrix" + ) + expected <- as.numeric(expected[, "lp_approx__"]) + + for (format in c("draws_array", "draws_df")) { + withr::local_options(list(cmdstanr_draws_format = format)) + expect_equal(fit_laplace$lp_approx(), expected) + } +}) + test_that("time() method works after laplace", { run_times <- fit_laplace$time() checkmate::expect_list(run_times, names = "strict", any.missing = FALSE) diff --git a/tests/testthat/test-fit-mle.R b/tests/testthat/test-fit-mle.R index e65b9205..db970a1c 100644 --- a/tests/testthat/test-fit-mle.R +++ b/tests/testthat/test-fit-mle.R @@ -10,6 +10,17 @@ test_that("mle and lp methods work after optimization", { checkmate::expect_numeric(fit_mle$lp(), len = 1) }) +test_that("mle() ignores non-matrix default draws formats", { + expected <- fit_mle$draws(format = "draws_matrix") + expected <- expected[, colnames(expected) != "lp__", drop = FALSE] + expected <- stats::setNames(as.numeric(expected), posterior::variables(expected)) + + for (format in c("draws_array", "draws_df")) { + withr::local_options(list(cmdstanr_draws_format = format)) + expect_equal(fit_mle$mle(), expected) + } +}) + test_that("summary method works after optimization", { x <- fit_mle$summary() expect_s3_class(x, "draws_summary") diff --git a/tests/testthat/test-fit-vb.R b/tests/testthat/test-fit-vb.R index bac2d079..986f38eb 100644 --- a/tests/testthat/test-fit-vb.R +++ b/tests/testthat/test-fit-vb.R @@ -71,6 +71,19 @@ test_that("lp(), lp_approx() methods return vectors (reading csv works)", { expect_equal(length(lg), length(lp)) }) +test_that("lp_approx() ignores non-matrix default draws formats", { + expected <- fit_vb$draws( + variables = "lp_approx__", + format = "draws_matrix" + ) + expected <- as.numeric(expected[, "lp_approx__"]) + + for (format in c("draws_array", "draws_df")) { + withr::local_options(list(cmdstanr_draws_format = format)) + expect_equal(fit_vb$lp_approx(), expected) + } +}) + test_that("vb works with scientific notation args", { x <- fit_vb_sci_not$summary() expect_s3_class(x, "draws_summary")