aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKen Kellner <kenkellner@users.noreply.github.com>2024-03-03 11:08:27 -0500
committerGitHub <noreply@github.com>2024-03-03 11:08:27 -0500
commit8c06592ec90c7c19fc78c408bb30cae3aa5b618a (patch)
tree18bde2bf753622f5fe02d8a09cd7fc86c745ce04
parent85009abba68123ce0348ad3e4b7b1f9b8f7d134a (diff)
parent709d3d0bfd1e6719df04ec49d2f200357b0b063b (diff)
Merge pull request #276 from rbchan/gdistsamp_refactor2
Refactor gdistsamp likelihood to avoid occasional crashes
-rw-r--r--R/gdistsamp.R7
-rw-r--r--src/RcppExports.cpp6
-rw-r--r--src/nll_gdistsamp.cpp34
3 files changed, 31 insertions, 16 deletions
diff --git a/R/gdistsamp.R b/R/gdistsamp.R
index c43076a..f021c90 100644
--- a/R/gdistsamp.R
+++ b/R/gdistsamp.R
@@ -399,8 +399,9 @@ if(engine =="C"){
as.vector(t(out))
}
y_long <- long_format(y)
- kmytC <- kmyt
- kmytC[which(is.na(kmyt))] <- 0
+ # Vectorize these arrays as using arma::subcube sometimes crashes
+ kmytC <- as.vector(aperm(kmyt, c(3,2,1)))
+ lfac.kmytC <- as.vector(aperm(lfac.kmyt, c(3,2,1)))
if(output!='density'){
A <- rep(1, M)
}
@@ -411,7 +412,7 @@ if(engine =="C"){
nll <- function(params){
nll_gdistsamp(params, n_param, y_long, mixture_code, keyfun, survey,
Xlam, Xlam.offset, A, Xphi, Xphi.offset, Xdet, Xdet.offset,
- db, a, t(u), w, k, lfac.k, lfac.kmyt, kmyt, Kmin, threads)
+ db, a, t(u), w, k, lfac.k, lfac.kmytC, kmytC, Kmin, threads)
}
} else {
diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp
index 0fe45a8..455bc52 100644
--- a/src/RcppExports.cpp
+++ b/src/RcppExports.cpp
@@ -141,7 +141,7 @@ BEGIN_RCPP
END_RCPP
}
// nll_gdistsamp
-double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y, int mixture, std::string keyfun, std::string survey, arma::mat Xlam, arma::vec Xlam_offset, arma::vec A, arma::mat Xphi, arma::vec Xphi_offset, arma::mat Xdet, arma::vec Xdet_offset, arma::vec db, arma::mat a, arma::mat u, arma::vec w, arma::vec k, arma::vec lfac_k, arma::cube lfac_kmyt, arma::cube kmyt, arma::uvec Kmin, int threads);
+double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y, int mixture, std::string keyfun, std::string survey, arma::mat Xlam, arma::vec Xlam_offset, arma::vec A, arma::mat Xphi, arma::vec Xphi_offset, arma::mat Xdet, arma::vec Xdet_offset, arma::vec db, arma::mat a, arma::mat u, arma::vec w, arma::vec k, arma::vec lfac_k, arma::vec lfac_kmyt, arma::vec kmyt, arma::uvec Kmin, int threads);
RcppExport SEXP _unmarked_nll_gdistsamp(SEXP betaSEXP, SEXP n_paramSEXP, SEXP ySEXP, SEXP mixtureSEXP, SEXP keyfunSEXP, SEXP surveySEXP, SEXP XlamSEXP, SEXP Xlam_offsetSEXP, SEXP ASEXP, SEXP XphiSEXP, SEXP Xphi_offsetSEXP, SEXP XdetSEXP, SEXP Xdet_offsetSEXP, SEXP dbSEXP, SEXP aSEXP, SEXP uSEXP, SEXP wSEXP, SEXP kSEXP, SEXP lfac_kSEXP, SEXP lfac_kmytSEXP, SEXP kmytSEXP, SEXP KminSEXP, SEXP threadsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
@@ -165,8 +165,8 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< arma::vec >::type w(wSEXP);
Rcpp::traits::input_parameter< arma::vec >::type k(kSEXP);
Rcpp::traits::input_parameter< arma::vec >::type lfac_k(lfac_kSEXP);
- Rcpp::traits::input_parameter< arma::cube >::type lfac_kmyt(lfac_kmytSEXP);
- Rcpp::traits::input_parameter< arma::cube >::type kmyt(kmytSEXP);
+ Rcpp::traits::input_parameter< arma::vec >::type lfac_kmyt(lfac_kmytSEXP);
+ Rcpp::traits::input_parameter< arma::vec >::type kmyt(kmytSEXP);
Rcpp::traits::input_parameter< arma::uvec >::type Kmin(KminSEXP);
Rcpp::traits::input_parameter< int >::type threads(threadsSEXP);
rcpp_result_gen = Rcpp::wrap(nll_gdistsamp(beta, n_param, y, mixture, keyfun, survey, Xlam, Xlam_offset, A, Xphi, Xphi_offset, Xdet, Xdet_offset, db, a, u, w, k, lfac_k, lfac_kmyt, kmyt, Kmin, threads));
diff --git a/src/nll_gdistsamp.cpp b/src/nll_gdistsamp.cpp
index 80dc2d4..2ae9133 100644
--- a/src/nll_gdistsamp.cpp
+++ b/src/nll_gdistsamp.cpp
@@ -17,7 +17,7 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y,
arma::mat Xlam, arma::vec Xlam_offset, arma::vec A, arma::mat Xphi,
arma::vec Xphi_offset, arma::mat Xdet, arma::vec Xdet_offset, arma::vec db,
arma::mat a, arma::mat u, arma::vec w, arma::vec k, arma::vec lfac_k,
- arma::cube lfac_kmyt, arma::cube kmyt, arma::uvec Kmin, int threads){
+ arma::vec lfac_kmyt, arma::vec kmyt, arma::uvec Kmin, int threads){
#ifdef _OPENMP
omp_set_num_threads(threads);
@@ -27,7 +27,8 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y,
int T = Xphi.n_rows / M;
int R = y.size() / M;
unsigned J = R / T;
- int K = k.size() - 1;
+ int lk = k.size();
+ int K = lk - 1;
//Abundance
const vec lambda = exp(Xlam * beta_sub(beta, n_param, 0) + Xlam_offset) % A;
@@ -53,9 +54,18 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y,
int t_ind = i * T;
int y_ind = i * T * J;
+ int k_start = i * T * lk;
vec y_sub(J);
-
+ vec p(J);
+ vec p1(lk);
+ vec p3(J);
+ vec p4(lk);
+ double p5;
+
+ //Some unnecessary calculations here when k < Kmin
+ //These values are ignored later in calculation of site_lp
+ //However hard to avoid without refactoring entirely I think
mat mn = zeros(K+1, T);
for(int t=0; t<T; t++){
int y_stop = y_ind + J - 1;
@@ -64,23 +74,27 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y,
if(not_missing.size() == J){
- vec p1 = lfac_kmyt.subcube(span(i),span(t),span());
- vec p = distprob(keyfun, det_param(t_ind), scale, survey, db,
+ int k_stop = k_start + lk - 1;
+
+ p1 = lfac_kmyt.subvec(k_start, k_stop);
+
+ p = distprob(keyfun, det_param(t_ind), scale, survey, db,
w, a.row(i));
- vec p3 = p % u.col(i) * phi(t_ind);
- //the following line causes a segfault only in R CMD check,
- //when kmyt contains NA values
- vec p4 = kmyt.subcube(span(i),span(t),span());
+ p3 = p % u.col(i) * phi(t_ind);
+
+ p4 = kmyt.subvec(k_start, k_stop);
- double p5 = 1 - sum(p3);
+ p5 = 1 - sum(p3);
mn.col(t) = lfac_k - p1 + sum(y_sub % log(p3)) + p4 * log(p5);
}
t_ind += 1;
y_ind += J;
+ k_start += lk;
}
+ //Note that rows of mn for k < Kmin are skipped here
double site_lp = 0.0;
for (int j=Kmin(i); j<(K+1); j++){
site_lp += N_density(mixture, j, lambda(i), log_alpha) *