diff options
author | Ken Kellner <ken@kenkellner.com> | 2022-04-04 17:15:40 -0400 |
---|---|---|
committer | Ken Kellner <ken@kenkellner.com> | 2022-04-04 17:15:40 -0400 |
commit | f0976d0e19e696ae85c11019c527a6a81b4096e4 (patch) | |
tree | 3c14e6d9c7699c1ffaa702ad8a9c58b2d9d95883 | |
parent | 91a5963074ed3c489afc9d4c3f90fb5cccb6daf7 (diff) |
Add RN model kfold support and handle missing sites
-rw-r--r-- | R/RcppExports.R | 4 | ||||
-rw-r--r-- | R/kfold.R | 26 | ||||
-rw-r--r-- | R/missing.R | 5 | ||||
-rw-r--r-- | src/kfold.cpp | 38 |
4 files changed, 68 insertions, 5 deletions
diff --git a/R/RcppExports.R b/R/RcppExports.R index 8a16411..63f325d 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -17,6 +17,10 @@ get_loglik_pcount <- function(y, M, si, lammat, pmat, K, Kmin) { .Call(`_ubms_get_loglik_pcount`, y, M, si, lammat, pmat, K, Kmin) } +get_loglik_occuRN <- function(y, M, si, lammat, rmat, K, Kmin) { + .Call(`_ubms_get_loglik_occuRN`, y, M, si, lammat, rmat, K, Kmin) +} + simz_pcount <- function(y, lam_post, p_post, K, Kmin, kvals) { .Call(`_ubms_simz_pcount`, y, lam_post, p_post, K, Kmin, kvals) } @@ -21,13 +21,17 @@ setMethod("kfold", "ubmsFit", function(x, K=10, folds=NULL, quiet=FALSE, ...){ op <- pbapply::pboptions() if(quiet) pbapply::pboptions(type = "none") - ll_all <- pbapply::pblapply(1:K, ll_fold, x, folds) + ll_unsort <- pbapply::pblapply(1:K, ll_fold, x, folds) pbapply::pboptions(op) - ll_all <- do.call("cbind", ll_all) + ll_unsort <- do.call("cbind", ll_unsort) + + # Sort sites back to original order sort_ind <- unlist(lapply(1:K, function(i) which(folds==i))) - ll_all <- ll_all[,sort_ind,drop=FALSE] + ll <- matrix(NA, nrow=nrow(ll_unsort), ncol=ncol(ll_unsort)) + ll[,sort_ind] <- ll_unsort + ll <- ll[,apply(ll, 2, function(x) !any(is.na(x))),drop=FALSE] - loo::elpd(ll_all) + loo::elpd(ll) }) @@ -43,7 +47,13 @@ ll_fold <- function(i, object, folds){ cl$return_inputs <- TRUE inps <- eval(cl) refit@submodels <- inps$submodels - get_loglik(refit, inps$stan_data) + ll <- get_loglik(refit, inps$stan_data) + + # Handle missing sites + out <- matrix(NA, nrow=nrow(ll), ncol=sum(folds==i)) + not_missing <- which(folds==i) %in% which(!removed_sites(object)) + out[,not_missing] <- ll + out } setGeneric("get_loglik", function(object, ...) standardGeneric("get_loglik")) @@ -63,3 +73,9 @@ setMethod("get_loglik", "ubmsFitPcount", function(object, inps, ...){ p <- t(posterior_linpred(object, submodel="det", transform=TRUE)) get_loglik_pcount(inps$y, inps$M, inps$si-1, lam, p, inps$K, inps$Kmin) }) + +setMethod("get_loglik", "ubmsFitOccuRN", function(object, inps, ...){ + lam <- t(posterior_linpred(object, submodel="state", transform=TRUE)) + p <- t(posterior_linpred(object, submodel="det", transform=TRUE)) + get_loglik_occuRN(inps$y, inps$M, inps$si-1, lam, p, inps$K, inps$Kmin) +}) diff --git a/R/missing.R b/R/missing.R index b5615af..3f4c739 100644 --- a/R/missing.R +++ b/R/missing.R @@ -88,3 +88,8 @@ get_row_reps <- function(submodel, response){ if(yl %% n != 0) stop("Invalid model matrix dimensions", call.=FALSE) yl / n } + +#' @include fit.R +removed_sites <- function(object){ + object['state']@missing +} diff --git a/src/kfold.cpp b/src/kfold.cpp index c065819..c9f75f8 100644 --- a/src/kfold.cpp +++ b/src/kfold.cpp @@ -83,3 +83,41 @@ arma::mat get_loglik_pcount(arma::vec y, int M, arma::imat si, return out; } + +// [[Rcpp::export]] +arma::mat get_loglik_occuRN(arma::vec y, int M, arma::imat si, + arma::mat lammat, arma::mat rmat, int K, + arma::ivec Kmin){ + + int S = lammat.n_cols; + mat out(S,M); + vec r_s(rmat.n_rows); + vec ysub, qsub; + int J; + double f, g, p, lik; + + for (int s = 0; s < S; s++){ + + r_s = rmat.col(s); + + for (int m = 0; m < M; m++){ + lik = 0.0; + ysub = y.subvec(si(m,0), si(m,1)); + qsub = 1 - r_s.subvec(si(m,0), si(m,1)); + J = ysub.size(); + for (int k=Kmin(m); k < (K+1); k++){ + f = Rf_dpois(k, lammat(m, s), false); + g = 0.0; + for (int j = 0; j < J; j++){ + p = 1 - pow(qsub(j), k); + g += Rf_dbinom(ysub(j), 1, p, true); + } + lik += f * exp(g); + } + out(s,m) = log(lik + DBL_MIN); + } + } + + return out; + +} |