aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKen Kellner <ken@kenkellner.com>2022-04-05 14:51:01 -0400
committerKen Kellner <ken@kenkellner.com>2022-04-05 14:51:01 -0400
commit26c1b71b85887f4b275c4c5b5fb934264f2abc2c (patch)
tree3eaa2ac37e431c1d99981fff548eb625910e4cf7
parentf0976d0e19e696ae85c11019c527a6a81b4096e4 (diff)
kfold multinomPois support
-rw-r--r--R/RcppExports.R4
-rw-r--r--R/kfold.R6
-rw-r--r--src/kfold.cpp64
3 files changed, 74 insertions, 0 deletions
diff --git a/R/RcppExports.R b/R/RcppExports.R
index 63f325d..92bb1de 100644
--- a/R/RcppExports.R
+++ b/R/RcppExports.R
@@ -21,6 +21,10 @@ get_loglik_occuRN <- function(y, M, si, lammat, rmat, K, Kmin) {
.Call(`_ubms_get_loglik_occuRN`, y, M, si, lammat, rmat, K, Kmin)
}
+get_loglik_multinomPois <- function(y, M, si, lammat, pmat, pi_type) {
+ .Call(`_ubms_get_loglik_multinomPois`, y, M, si, lammat, pmat, pi_type)
+}
+
simz_pcount <- function(y, lam_post, p_post, K, Kmin, kvals) {
.Call(`_ubms_simz_pcount`, y, lam_post, p_post, K, Kmin, kvals)
}
diff --git a/R/kfold.R b/R/kfold.R
index a5aa504..c1c3602 100644
--- a/R/kfold.R
+++ b/R/kfold.R
@@ -79,3 +79,9 @@ setMethod("get_loglik", "ubmsFitOccuRN", function(object, inps, ...){
p <- t(posterior_linpred(object, submodel="det", transform=TRUE))
get_loglik_occuRN(inps$y, inps$M, inps$si-1, lam, p, inps$K, inps$Kmin)
})
+
+setMethod("get_loglik", "ubmsFitMultinomPois", function(object, inps, ...){
+ lam <- t(posterior_linpred(object, submodel="state", transform=TRUE))
+ p <- t(posterior_linpred(object, submodel="det", transform=TRUE))
+ get_loglik_multinomPois(inps$y, inps$M, inps$si-1, lam, p, inps$y_dist)
+})
diff --git a/src/kfold.cpp b/src/kfold.cpp
index c9f75f8..deb4687 100644
--- a/src/kfold.cpp
+++ b/src/kfold.cpp
@@ -121,3 +121,67 @@ arma::mat get_loglik_occuRN(arma::vec y, int M, arma::imat si,
return out;
}
+
+// pi functions
+arma::vec pi_removal(arma::vec p){
+ int J = p.size();
+ vec pi_out(J);
+ pi_out(0) = p(0);
+ for (int j = 1; j < J; j++){
+ pi_out(j) = pi_out(j-1) / p(j-1) * (1-p(j-1)) * p(j);
+ }
+ return pi_out;
+}
+
+arma::vec pi_double(arma::vec p){
+ vec pi_out(3);
+ pi_out(0) = p(0) * (1 - p(1));
+ pi_out(1) = p(1) * (1 - p(0));
+ pi_out(2) = p(0) * p(1);
+ return pi_out;
+}
+
+arma::vec pi_fun(int pi_type, arma::vec p, int J){
+ vec out(J);
+ if(pi_type == 0){
+ out = pi_double(p);
+ } else if(pi_type == 1){
+ out = pi_removal(p);
+ } else{
+ Rcpp::stop("Invalid pi function");
+ }
+ return out;
+}
+
+// [[Rcpp::export]]
+arma::mat get_loglik_multinomPois(arma::vec y, int M, arma::imat si,
+ arma::mat lammat, arma::mat pmat, int pi_type){
+
+ int S = lammat.n_cols;
+ mat out = zeros(S,M);
+ vec p_s(pmat.n_rows);
+ vec ysub, psub;
+ int R = pmat.n_rows / M;
+ int pstart, pend, J;
+ vec cp;
+
+ for (int s = 0; s < S; s++){
+
+ p_s = pmat.col(s);
+ pstart = 0;
+
+ for (int m = 0; m < M; m++){
+ pend = pstart + R - 1;
+ ysub = y.subvec(si(m,0), si(m,1));
+ psub = p_s.subvec(pstart, pend);
+ J = ysub.size();
+ cp = pi_fun(pi_type, psub, J);
+
+ for (int j = 0; j < J; j++){
+ out(s, m) += Rf_dpois(ysub(j), lammat(m, s) * cp(j), true);
+ }
+ pstart += R;
+ }
+ }
+ return out;
+}