aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKen Kellner <ken@kenkellner.com>2022-04-04 17:15:40 -0400
committerKen Kellner <ken@kenkellner.com>2022-04-04 17:15:40 -0400
commitf0976d0e19e696ae85c11019c527a6a81b4096e4 (patch)
tree3c14e6d9c7699c1ffaa702ad8a9c58b2d9d95883
parent91a5963074ed3c489afc9d4c3f90fb5cccb6daf7 (diff)
Add RN model kfold support and handle missing sites
-rw-r--r--R/RcppExports.R4
-rw-r--r--R/kfold.R26
-rw-r--r--R/missing.R5
-rw-r--r--src/kfold.cpp38
4 files changed, 68 insertions, 5 deletions
diff --git a/R/RcppExports.R b/R/RcppExports.R
index 8a16411..63f325d 100644
--- a/R/RcppExports.R
+++ b/R/RcppExports.R
@@ -17,6 +17,10 @@ get_loglik_pcount <- function(y, M, si, lammat, pmat, K, Kmin) {
.Call(`_ubms_get_loglik_pcount`, y, M, si, lammat, pmat, K, Kmin)
}
+get_loglik_occuRN <- function(y, M, si, lammat, rmat, K, Kmin) {
+ .Call(`_ubms_get_loglik_occuRN`, y, M, si, lammat, rmat, K, Kmin)
+}
+
simz_pcount <- function(y, lam_post, p_post, K, Kmin, kvals) {
.Call(`_ubms_simz_pcount`, y, lam_post, p_post, K, Kmin, kvals)
}
diff --git a/R/kfold.R b/R/kfold.R
index 27d25f1..a5aa504 100644
--- a/R/kfold.R
+++ b/R/kfold.R
@@ -21,13 +21,17 @@ setMethod("kfold", "ubmsFit", function(x, K=10, folds=NULL, quiet=FALSE, ...){
op <- pbapply::pboptions()
if(quiet) pbapply::pboptions(type = "none")
- ll_all <- pbapply::pblapply(1:K, ll_fold, x, folds)
+ ll_unsort <- pbapply::pblapply(1:K, ll_fold, x, folds)
pbapply::pboptions(op)
- ll_all <- do.call("cbind", ll_all)
+ ll_unsort <- do.call("cbind", ll_unsort)
+
+ # Sort sites back to original order
sort_ind <- unlist(lapply(1:K, function(i) which(folds==i)))
- ll_all <- ll_all[,sort_ind,drop=FALSE]
+ ll <- matrix(NA, nrow=nrow(ll_unsort), ncol=ncol(ll_unsort))
+ ll[,sort_ind] <- ll_unsort
+ ll <- ll[,apply(ll, 2, function(x) !any(is.na(x))),drop=FALSE]
- loo::elpd(ll_all)
+ loo::elpd(ll)
})
@@ -43,7 +47,13 @@ ll_fold <- function(i, object, folds){
cl$return_inputs <- TRUE
inps <- eval(cl)
refit@submodels <- inps$submodels
- get_loglik(refit, inps$stan_data)
+ ll <- get_loglik(refit, inps$stan_data)
+
+ # Handle missing sites
+ out <- matrix(NA, nrow=nrow(ll), ncol=sum(folds==i))
+ not_missing <- which(folds==i) %in% which(!removed_sites(object))
+ out[,not_missing] <- ll
+ out
}
setGeneric("get_loglik", function(object, ...) standardGeneric("get_loglik"))
@@ -63,3 +73,9 @@ setMethod("get_loglik", "ubmsFitPcount", function(object, inps, ...){
p <- t(posterior_linpred(object, submodel="det", transform=TRUE))
get_loglik_pcount(inps$y, inps$M, inps$si-1, lam, p, inps$K, inps$Kmin)
})
+
+setMethod("get_loglik", "ubmsFitOccuRN", function(object, inps, ...){
+ lam <- t(posterior_linpred(object, submodel="state", transform=TRUE))
+ p <- t(posterior_linpred(object, submodel="det", transform=TRUE))
+ get_loglik_occuRN(inps$y, inps$M, inps$si-1, lam, p, inps$K, inps$Kmin)
+})
diff --git a/R/missing.R b/R/missing.R
index b5615af..3f4c739 100644
--- a/R/missing.R
+++ b/R/missing.R
@@ -88,3 +88,8 @@ get_row_reps <- function(submodel, response){
if(yl %% n != 0) stop("Invalid model matrix dimensions", call.=FALSE)
yl / n
}
+
+#' @include fit.R
+removed_sites <- function(object){
+ object['state']@missing
+}
diff --git a/src/kfold.cpp b/src/kfold.cpp
index c065819..c9f75f8 100644
--- a/src/kfold.cpp
+++ b/src/kfold.cpp
@@ -83,3 +83,41 @@ arma::mat get_loglik_pcount(arma::vec y, int M, arma::imat si,
return out;
}
+
+// [[Rcpp::export]]
+arma::mat get_loglik_occuRN(arma::vec y, int M, arma::imat si,
+ arma::mat lammat, arma::mat rmat, int K,
+ arma::ivec Kmin){
+
+ int S = lammat.n_cols;
+ mat out(S,M);
+ vec r_s(rmat.n_rows);
+ vec ysub, qsub;
+ int J;
+ double f, g, p, lik;
+
+ for (int s = 0; s < S; s++){
+
+ r_s = rmat.col(s);
+
+ for (int m = 0; m < M; m++){
+ lik = 0.0;
+ ysub = y.subvec(si(m,0), si(m,1));
+ qsub = 1 - r_s.subvec(si(m,0), si(m,1));
+ J = ysub.size();
+ for (int k=Kmin(m); k < (K+1); k++){
+ f = Rf_dpois(k, lammat(m, s), false);
+ g = 0.0;
+ for (int j = 0; j < J; j++){
+ p = 1 - pow(qsub(j), k);
+ g += Rf_dbinom(ysub(j), 1, p, true);
+ }
+ lik += f * exp(g);
+ }
+ out(s,m) = log(lik + DBL_MIN);
+ }
+ }
+
+ return out;
+
+}