diff options
Diffstat (limited to 'src/nll_gdistsamp.cpp')
-rw-r--r-- | src/nll_gdistsamp.cpp | 34 |
1 files changed, 24 insertions, 10 deletions
diff --git a/src/nll_gdistsamp.cpp b/src/nll_gdistsamp.cpp index 80dc2d4..2ae9133 100644 --- a/src/nll_gdistsamp.cpp +++ b/src/nll_gdistsamp.cpp @@ -17,7 +17,7 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y, arma::mat Xlam, arma::vec Xlam_offset, arma::vec A, arma::mat Xphi, arma::vec Xphi_offset, arma::mat Xdet, arma::vec Xdet_offset, arma::vec db, arma::mat a, arma::mat u, arma::vec w, arma::vec k, arma::vec lfac_k, - arma::cube lfac_kmyt, arma::cube kmyt, arma::uvec Kmin, int threads){ + arma::vec lfac_kmyt, arma::vec kmyt, arma::uvec Kmin, int threads){ #ifdef _OPENMP omp_set_num_threads(threads); @@ -27,7 +27,8 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y, int T = Xphi.n_rows / M; int R = y.size() / M; unsigned J = R / T; - int K = k.size() - 1; + int lk = k.size(); + int K = lk - 1; //Abundance const vec lambda = exp(Xlam * beta_sub(beta, n_param, 0) + Xlam_offset) % A; @@ -53,9 +54,18 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y, int t_ind = i * T; int y_ind = i * T * J; + int k_start = i * T * lk; vec y_sub(J); - + vec p(J); + vec p1(lk); + vec p3(J); + vec p4(lk); + double p5; + + //Some unnecessary calculations here when k < Kmin + //These values are ignored later in calculation of site_lp + //However hard to avoid without refactoring entirely I think mat mn = zeros(K+1, T); for(int t=0; t<T; t++){ int y_stop = y_ind + J - 1; @@ -64,23 +74,27 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y, if(not_missing.size() == J){ - vec p1 = lfac_kmyt.subcube(span(i),span(t),span()); - vec p = distprob(keyfun, det_param(t_ind), scale, survey, db, + int k_stop = k_start + lk - 1; + + p1 = lfac_kmyt.subvec(k_start, k_stop); + + p = distprob(keyfun, det_param(t_ind), scale, survey, db, w, a.row(i)); - vec p3 = p % u.col(i) * phi(t_ind); - //the following line causes a segfault only in R CMD check, - //when kmyt contains NA values - vec p4 = kmyt.subcube(span(i),span(t),span()); + p3 = p % u.col(i) * phi(t_ind); + + p4 = kmyt.subvec(k_start, k_stop); - double p5 = 1 - sum(p3); + p5 = 1 - sum(p3); mn.col(t) = lfac_k - p1 + sum(y_sub % log(p3)) + p4 * log(p5); } t_ind += 1; y_ind += J; + k_start += lk; } + //Note that rows of mn for k < Kmin are skipped here double site_lp = 0.0; for (int j=Kmin(i); j<(K+1); j++){ site_lp += N_density(mixture, j, lambda(i), log_alpha) * |