diff options
author | Ken Kellner <ken@kenkellner.com> | 2022-04-05 14:51:01 -0400 |
---|---|---|
committer | Ken Kellner <ken@kenkellner.com> | 2022-04-05 14:51:01 -0400 |
commit | 26c1b71b85887f4b275c4c5b5fb934264f2abc2c (patch) | |
tree | 3eaa2ac37e431c1d99981fff548eb625910e4cf7 | |
parent | f0976d0e19e696ae85c11019c527a6a81b4096e4 (diff) |
kfold multinomPois support
-rw-r--r-- | R/RcppExports.R | 4 | ||||
-rw-r--r-- | R/kfold.R | 6 | ||||
-rw-r--r-- | src/kfold.cpp | 64 |
3 files changed, 74 insertions, 0 deletions
diff --git a/R/RcppExports.R b/R/RcppExports.R index 63f325d..92bb1de 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -21,6 +21,10 @@ get_loglik_occuRN <- function(y, M, si, lammat, rmat, K, Kmin) { .Call(`_ubms_get_loglik_occuRN`, y, M, si, lammat, rmat, K, Kmin) } +get_loglik_multinomPois <- function(y, M, si, lammat, pmat, pi_type) { + .Call(`_ubms_get_loglik_multinomPois`, y, M, si, lammat, pmat, pi_type) +} + simz_pcount <- function(y, lam_post, p_post, K, Kmin, kvals) { .Call(`_ubms_simz_pcount`, y, lam_post, p_post, K, Kmin, kvals) } @@ -79,3 +79,9 @@ setMethod("get_loglik", "ubmsFitOccuRN", function(object, inps, ...){ p <- t(posterior_linpred(object, submodel="det", transform=TRUE)) get_loglik_occuRN(inps$y, inps$M, inps$si-1, lam, p, inps$K, inps$Kmin) }) + +setMethod("get_loglik", "ubmsFitMultinomPois", function(object, inps, ...){ + lam <- t(posterior_linpred(object, submodel="state", transform=TRUE)) + p <- t(posterior_linpred(object, submodel="det", transform=TRUE)) + get_loglik_multinomPois(inps$y, inps$M, inps$si-1, lam, p, inps$y_dist) +}) diff --git a/src/kfold.cpp b/src/kfold.cpp index c9f75f8..deb4687 100644 --- a/src/kfold.cpp +++ b/src/kfold.cpp @@ -121,3 +121,67 @@ arma::mat get_loglik_occuRN(arma::vec y, int M, arma::imat si, return out; } + +// pi functions +arma::vec pi_removal(arma::vec p){ + int J = p.size(); + vec pi_out(J); + pi_out(0) = p(0); + for (int j = 1; j < J; j++){ + pi_out(j) = pi_out(j-1) / p(j-1) * (1-p(j-1)) * p(j); + } + return pi_out; +} + +arma::vec pi_double(arma::vec p){ + vec pi_out(3); + pi_out(0) = p(0) * (1 - p(1)); + pi_out(1) = p(1) * (1 - p(0)); + pi_out(2) = p(0) * p(1); + return pi_out; +} + +arma::vec pi_fun(int pi_type, arma::vec p, int J){ + vec out(J); + if(pi_type == 0){ + out = pi_double(p); + } else if(pi_type == 1){ + out = pi_removal(p); + } else{ + Rcpp::stop("Invalid pi function"); + } + return out; +} + +// [[Rcpp::export]] +arma::mat get_loglik_multinomPois(arma::vec y, int M, arma::imat si, + arma::mat lammat, arma::mat pmat, int pi_type){ + + int S = lammat.n_cols; + mat out = zeros(S,M); + vec p_s(pmat.n_rows); + vec ysub, psub; + int R = pmat.n_rows / M; + int pstart, pend, J; + vec cp; + + for (int s = 0; s < S; s++){ + + p_s = pmat.col(s); + pstart = 0; + + for (int m = 0; m < M; m++){ + pend = pstart + R - 1; + ysub = y.subvec(si(m,0), si(m,1)); + psub = p_s.subvec(pstart, pend); + J = ysub.size(); + cp = pi_fun(pi_type, psub, J); + + for (int j = 0; j < J; j++){ + out(s, m) += Rf_dpois(ysub(j), lammat(m, s) * cp(j), true); + } + pstart += R; + } + } + return out; +} |