diff options
author | Ken Kellner <ken@kenkellner.com> | 2022-04-11 07:17:34 -0400 |
---|---|---|
committer | Ken Kellner <ken@kenkellner.com> | 2022-04-11 07:17:34 -0400 |
commit | 9fccb22af6af1899ed1e3f65d5f833de66770b21 (patch) | |
tree | 2264267d23863ba4f3da5ee2316d7b06bc5eeaaf | |
parent | ebdb26cf9cbcce2db23547c3cd1224b13016a21a (diff) |
Make sure kfold tests work
-rw-r--r-- | R/fit.R | 12 | ||||
-rw-r--r-- | R/loglik.R | 15 | ||||
-rw-r--r-- | tests/testthat/test_occu.R | 12 |
3 files changed, 25 insertions, 14 deletions
@@ -45,6 +45,18 @@ remove_placeholders <- function(submodels){ submodels } +empty_loo <- function(){ + out <- list() + class(out) <- "psis_loo" + out +} + +get_loo <- function(object, cores=getOption("mc.cores", 1)){ + loglik <- extract_log_lik(object, merge_chains=FALSE) + r_eff <- loo::relative_eff(exp(loglik), cores=cores) + loo::loo(loglik, r_eff=r_eff, cores=cores) +} + # Function to check if just Stan inputs should be returned # Used by kfold method # Pass return_inputs=TRUE in ... when calling stan_* @@ -118,8 +118,7 @@ setMethod("get_loglik", "ubmsFitOccuTTD", function(object, inps, ...){ #' #' @return A matrix (samples x sites) or array (samples x chains x sites) #' -#' @aliases extract_log_lik,ubmsFit-method extract_log_lik,ubmsFitOccuTTD-method -#' @aliases extract_log_lik,ubmsFitDistsamp-method +#' @aliases extract_log_lik,ubmsFit-method extract_log_lik,ubmsFitDistsamp-method #' @export setGeneric("extract_log_lik", function(object, parameter_name="log_lik", merge_chains=TRUE){ @@ -143,15 +142,3 @@ setMethod("extract_log_lik", "ubmsFitDistsamp", function(object, parameter_name = "log_lik", merge_chains=TRUE){ loo::extract_log_lik(object@stanfit, parameter_name, merge_chains) }) - -get_loo <- function(object, cores=getOption("mc.cores", 1)){ - loglik <- extract_log_lik(object, merge_chains=FALSE) - r_eff <- loo::relative_eff(exp(loglik), cores=cores) - loo::loo(loglik, r_eff=r_eff, cores=cores) -} - -empty_loo <- function(){ - out <- list() - class(out) <- "psis_loo" - out -} diff --git a/tests/testthat/test_occu.R b/tests/testthat/test_occu.R index bc4737a..bc6efbd 100644 --- a/tests/testthat/test_occu.R +++ b/tests/testthat/test_occu.R @@ -181,3 +181,15 @@ test_that("Fitted/residual methods work with ubmsFitOccu",{ expect_is(rp3, "gtable") expect_is(mp, "gg") }) + +test_that("stan_occu kfold method works",{ + +set.seed(123) +kf <- kfold(fit, K=2, quiet=TRUE) +expect_is(kf, "elpd_generic") +expect_equal(kf$estimates[1,1], -37.3802, tol=1e-4) + +expect_error(kfold(fit, K=2, folds=rep(3,10))) +expect_error(kfold(fit, K=2, folds=rep(1,5))) + +}) |