aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKen Kellner <ken@kenkellner.com>2022-04-11 07:17:34 -0400
committerKen Kellner <ken@kenkellner.com>2022-04-11 07:17:34 -0400
commit9fccb22af6af1899ed1e3f65d5f833de66770b21 (patch)
tree2264267d23863ba4f3da5ee2316d7b06bc5eeaaf
parentebdb26cf9cbcce2db23547c3cd1224b13016a21a (diff)
Make sure kfold tests work
-rw-r--r--R/fit.R12
-rw-r--r--R/loglik.R15
-rw-r--r--tests/testthat/test_occu.R12
3 files changed, 25 insertions, 14 deletions
diff --git a/R/fit.R b/R/fit.R
index dedb735..5abfc49 100644
--- a/R/fit.R
+++ b/R/fit.R
@@ -45,6 +45,18 @@ remove_placeholders <- function(submodels){
submodels
}
+empty_loo <- function(){
+ out <- list()
+ class(out) <- "psis_loo"
+ out
+}
+
+get_loo <- function(object, cores=getOption("mc.cores", 1)){
+ loglik <- extract_log_lik(object, merge_chains=FALSE)
+ r_eff <- loo::relative_eff(exp(loglik), cores=cores)
+ loo::loo(loglik, r_eff=r_eff, cores=cores)
+}
+
# Function to check if just Stan inputs should be returned
# Used by kfold method
# Pass return_inputs=TRUE in ... when calling stan_*
diff --git a/R/loglik.R b/R/loglik.R
index 1309d64..b638006 100644
--- a/R/loglik.R
+++ b/R/loglik.R
@@ -118,8 +118,7 @@ setMethod("get_loglik", "ubmsFitOccuTTD", function(object, inps, ...){
#'
#' @return A matrix (samples x sites) or array (samples x chains x sites)
#'
-#' @aliases extract_log_lik,ubmsFit-method extract_log_lik,ubmsFitOccuTTD-method
-#' @aliases extract_log_lik,ubmsFitDistsamp-method
+#' @aliases extract_log_lik,ubmsFit-method extract_log_lik,ubmsFitDistsamp-method
#' @export
setGeneric("extract_log_lik",
function(object, parameter_name="log_lik", merge_chains=TRUE){
@@ -143,15 +142,3 @@ setMethod("extract_log_lik", "ubmsFitDistsamp",
function(object, parameter_name = "log_lik", merge_chains=TRUE){
loo::extract_log_lik(object@stanfit, parameter_name, merge_chains)
})
-
-get_loo <- function(object, cores=getOption("mc.cores", 1)){
- loglik <- extract_log_lik(object, merge_chains=FALSE)
- r_eff <- loo::relative_eff(exp(loglik), cores=cores)
- loo::loo(loglik, r_eff=r_eff, cores=cores)
-}
-
-empty_loo <- function(){
- out <- list()
- class(out) <- "psis_loo"
- out
-}
diff --git a/tests/testthat/test_occu.R b/tests/testthat/test_occu.R
index bc4737a..bc6efbd 100644
--- a/tests/testthat/test_occu.R
+++ b/tests/testthat/test_occu.R
@@ -181,3 +181,15 @@ test_that("Fitted/residual methods work with ubmsFitOccu",{
expect_is(rp3, "gtable")
expect_is(mp, "gg")
})
+
+test_that("stan_occu kfold method works",{
+
+set.seed(123)
+kf <- kfold(fit, K=2, quiet=TRUE)
+expect_is(kf, "elpd_generic")
+expect_equal(kf$estimates[1,1], -37.3802, tol=1e-4)
+
+expect_error(kfold(fit, K=2, folds=rep(3,10)))
+expect_error(kfold(fit, K=2, folds=rep(1,5)))
+
+})