diff options
author | Ken Kellner <ken@kenkellner.com> | 2022-04-07 09:46:08 -0400 |
---|---|---|
committer | Ken Kellner <ken@kenkellner.com> | 2022-04-07 09:46:08 -0400 |
commit | fde76ddf42366bf3c97f425e0f5c572f44198310 (patch) | |
tree | ae544105def8b1f73f1ae510849a2f420419f454 | |
parent | 26c1b71b85887f4b275c4c5b5fb934264f2abc2c (diff) |
colext kfold support and fix some bugs
-rw-r--r-- | R/RcppExports.R | 4 | ||||
-rw-r--r-- | R/kfold.R | 68 | ||||
-rw-r--r-- | inst/stan/colext.stan | 10 | ||||
-rw-r--r-- | src/kfold.cpp | 100 |
4 files changed, 164 insertions, 18 deletions
diff --git a/R/RcppExports.R b/R/RcppExports.R index 92bb1de..504f6d6 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -25,6 +25,10 @@ get_loglik_multinomPois <- function(y, M, si, lammat, pmat, pi_type) { .Call(`_ubms_get_loglik_multinomPois`, y, M, si, lammat, pmat, pi_type) } +get_loglik_colext <- function(y, M, Tsamp, J, si, psicube, phicube, pmat, nd) { + .Call(`_ubms_get_loglik_colext`, y, M, Tsamp, J, si, psicube, phicube, pmat, nd) +} + simz_pcount <- function(y, lam_post, p_post, K, Kmin, kvals) { .Call(`_ubms_simz_pcount`, y, lam_post, p_post, K, Kmin, kvals) } @@ -47,7 +47,7 @@ ll_fold <- function(i, object, folds){ cl$return_inputs <- TRUE inps <- eval(cl) refit@submodels <- inps$submodels - ll <- get_loglik(refit, inps$stan_data) + ll <- get_loglik(refit, inps) # Handle missing sites out <- matrix(NA, nrow=nrow(ll), ncol=sum(folds==i)) @@ -56,6 +56,22 @@ ll_fold <- function(i, object, folds){ out } +# Calculate parameter from X, Z, offset +# Can't use posterior_linpred for this because it doesn't drop NAs +# Inps are output from get_stan_inputs() +calculate_par <- function(object, inps, submodel){ + X <- inps$stan_data[[paste0("X_",submodel)]] + beta <- extract(object, paste0("beta_", submodel))[[1]] + lp <- X %*% t(beta) + if(inps$stan_data[[paste0("has_random_",submodel)]]){ + Z <- Z_matrix(inps$submodels[submodel]) + b <- extract(object, paste0("b_", submodel))[[1]] + lp <- lp + Z %*% t(b) + } + lp <- lp + inps$stan_data[[paste0("offset_",submodel)]] + do.call(inps$submodels[submodel]@link, list(lp)) +} + setGeneric("get_loglik", function(object, ...) standardGeneric("get_loglik")) setMethod("get_loglik", "ubmsFit", function(object, inps, ...){ @@ -63,25 +79,51 @@ setMethod("get_loglik", "ubmsFit", function(object, inps, ...){ }) setMethod("get_loglik", "ubmsFitOccu", function(object, inps, ...){ - psi <- t(posterior_linpred(object, submodel="state", transform=TRUE)) - p <- t(posterior_linpred(object, submodel="det", transform=TRUE)) - get_loglik_occu(inps$y, inps$M, inps$si-1, psi, p, inps$Kmin) + psi <- calculate_par(object, inps, submodel="state") + p <- calculate_par(object, inps, submodel="det") + get_loglik_occu(inps$stan_data$y, inps$stan_data$M, inps$stan_data$si-1, + psi, p, inps$stan_data$Kmin) }) setMethod("get_loglik", "ubmsFitPcount", function(object, inps, ...){ - lam <- t(posterior_linpred(object, submodel="state", transform=TRUE)) - p <- t(posterior_linpred(object, submodel="det", transform=TRUE)) - get_loglik_pcount(inps$y, inps$M, inps$si-1, lam, p, inps$K, inps$Kmin) + lam <- calculate_par(object, inps, submodel="state") + p <- calculate_par(object, inps, submodel="det") + get_loglik_pcount(inps$stan_data$y, inps$stan_data$M, inps$stan_data$si-1, + lam, p, inps$stan_data$K, inps$stan_data$Kmin) }) setMethod("get_loglik", "ubmsFitOccuRN", function(object, inps, ...){ - lam <- t(posterior_linpred(object, submodel="state", transform=TRUE)) - p <- t(posterior_linpred(object, submodel="det", transform=TRUE)) - get_loglik_occuRN(inps$y, inps$M, inps$si-1, lam, p, inps$K, inps$Kmin) + lam <- calculate_par(object, inps, submodel="state") + p <- calculate_par(object, inps, submodel="det") + get_loglik_occuRN(inps$stan_data$y, inps$stan_data$M, inps$stan_data$si-1, + lam, p, inps$stan_data$K, inps$stan_data$Kmin) }) setMethod("get_loglik", "ubmsFitMultinomPois", function(object, inps, ...){ - lam <- t(posterior_linpred(object, submodel="state", transform=TRUE)) - p <- t(posterior_linpred(object, submodel="det", transform=TRUE)) - get_loglik_multinomPois(inps$y, inps$M, inps$si-1, lam, p, inps$y_dist) + lam <- calculate_par(object, inps, submodel="state") + p <- calculate_par(object, inps, submodel="det") + get_loglik_multinomPois(inps$stan_data$y, inps$stan_data$M, inps$stan_data$si-1, + lam, p, inps$stan_data$y_dist) +}) + +setMethod("get_loglik", "ubmsFitColext", function(object, inps, ...){ + psi <- calculate_par(object, inps, submodel="state") + psicube <- array(NA, c(dim(psi), 2)) + psicube[,,1] <- 1 - psi + psicube[,,2] <- psi + psicube <- aperm(psicube, c(1,3,2)) + + col <- calculate_par(object, inps, submodel="col") + ext <- calculate_par(object, inps, submodel="ext") + phicube <- array(NA, c(dim(col), 4)) + phicube[,,1] <- 1 - col + phicube[,,2] <- col + phicube[,,3] <- ext + phicube[,,4] <- 1 - ext + phicube <- aperm(phicube, c(1,3,2)) + pmat <- calculate_par(object, inps, submodel="det") + + get_loglik_colext(inps$stan_data$y, inps$stan_data$M, inps$stan_data$Tsamp-1, + inps$stan_data$J, inps$stan_data$si-1, psicube, + phicube, pmat, 1 - inps$stan_data$Kmin) }) diff --git a/inst/stan/colext.stan b/inst/stan/colext.stan index e1aecf5..8d599da 100644 --- a/inst/stan/colext.stan +++ b/inst/stan/colext.stan @@ -42,15 +42,15 @@ real lp_colext(int[] y, int[] Tsamp, int[] J, row_vector psi, matrix phi_raw, if(T > 1){ for (t in 1:(T-1)){ phi = get_phi(phi_raw, Tsamp[t], Tsamp[t+1]); - end = idx + J[t] - 1; - Dpt = get_pY(y[idx:end], logit_p[idx:end], nd[t]); + end = idx + J[Tsamp[t]] - 1; + Dpt = get_pY(y[idx:end], logit_p[idx:end], nd[Tsamp[t]]); phi_prod *= diag_pre_multiply(Dpt, phi); - idx += J[t]; + idx += J[Tsamp[t]]; } } - end = idx + J[T] - 1; - Dpt = get_pY(y[idx:end], logit_p[idx:end], nd[T]); + end = idx + J[Tsamp[T]] - 1; + Dpt = get_pY(y[idx:end], logit_p[idx:end], nd[Tsamp[T]]); return log(dot_product(psi * phi_prod, Dpt)); } diff --git a/src/kfold.cpp b/src/kfold.cpp index deb4687..16b3b88 100644 --- a/src/kfold.cpp +++ b/src/kfold.cpp @@ -185,3 +185,103 @@ arma::mat get_loglik_multinomPois(arma::vec y, int M, arma::imat si, } return out; } + + +arma::vec get_pY(arma::vec y, arma::vec p, int nd){ + + vec out(2); + int J = y.size(); + out(0) = nd; + out(1) = 1.0; + for (int j = 0; j < J; j++){ + out(1) *= Rf_dbinom(y(j), 1, p(j), false); + } + return out; +} + +arma::mat phi_matrix(arma::rowvec phi_raw){ + mat out(2,2); + out(0,0) = phi_raw(0); + out(0,1) = phi_raw(1); + out(1,0) = phi_raw(2); + out(1,1) = phi_raw(3); + return out; +} + +arma::mat get_phi(arma::mat phi_raw, int Tstart, int Tnext){ + mat phi = eye(2,2); + int delta = Tnext - Tstart; + if(delta == 1){ + return phi_matrix(phi_raw.row(Tstart)); + } + + for (int d = 1; d < (delta + 1); d++){ + phi = phi * phi_matrix(phi_raw.row(Tstart + d - 1)); + } + return phi; +} + + + +// [[Rcpp::export]] +arma::mat get_loglik_colext(arma::vec y, int M, arma::ivec Tsamp, arma::imat J, + arma::imat si, arma::cube psicube, arma::cube phicube, + arma::mat pmat, arma::imat nd){ + + int S = pmat.n_cols; + mat out(S,M); + vec p_s(pmat.n_rows); + mat psi_s(psicube.n_rows, psicube.n_cols); + mat phi_s(phicube.n_rows, phicube.n_cols); + + rowvec psi_m; + mat phi_m; + ivec Tsamp_m; + irowvec J_m, nd_m; + vec y_m, p_m; + int T; + mat phi_prod(2,2); + mat phi(2,2); + vec Dpt(2); + int idx, end; + + for (int s = 0; s < S; s++){ + + p_s = pmat.col(s); + psi_s = psicube.slice(s); + phi_s = phicube.slice(s); + + for (int m = 0; m < M; m++){ + idx = 0; + y_m = y.subvec(si(m,0), si(m,1)); + Tsamp_m = Tsamp.subvec(si(m,2), si(m,3)); + J_m = J.row(m); + + psi_m = psi_s.row(m); + phi_m = phi_s.rows(si(m,4),si(m,5)); + p_m = p_s.subvec(si(m,0), si(m,1)); + nd_m = nd.row(m); + + T = Tsamp_m.size(); + phi_prod = eye(2,2); + + if(T > 1){ + for (int t = 0; t < (T-1); t++){ + phi = get_phi(phi_m, Tsamp_m(t), Tsamp_m(t+1)); + end = idx + J_m(Tsamp_m(t)) - 1; + Dpt = get_pY(y_m.subvec(idx, end), p_m.subvec(idx, end), + nd_m(Tsamp_m(t))); + + phi_prod *= diagmat(Dpt) * phi; + idx += J_m(Tsamp_m(t)); + } + } + + end = idx + J_m(Tsamp_m(T-1)) - 1; + Dpt = get_pY(y_m.subvec(idx,end), p_m.subvec(idx,end), nd_m(Tsamp_m(T-1))); + out(s, m) = log(dot(psi_m * phi_prod, Dpt)); + } + } + + return out; +} |