aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKen Kellner <ken@kenkellner.com>2024-03-02 20:55:17 -0500
committerKen Kellner <ken@kenkellner.com>2024-03-02 20:55:17 -0500
commit827508a5ef645e7bd3a0d76fd94acb4d9c3e3db4 (patch)
tree3270b0d758a457428956cb07ab3c8e1847c86b6b
parent85009abba68123ce0348ad3e4b7b1f9b8f7d134a (diff)
Refactor gdistsamp likelihood for stabilitygdistsamp_refactor
-rw-r--r--R/RcppExports.R4
-rw-r--r--R/gdistsamp.R2
-rw-r--r--src/RcppExports.cpp13
-rw-r--r--src/nll_gdistsamp.cpp54
-rw-r--r--tests/testthat/test_gdistsamp.R8
5 files changed, 42 insertions, 39 deletions
diff --git a/R/RcppExports.R b/R/RcppExports.R
index 70b57ac..7aba254 100644
--- a/R/RcppExports.R
+++ b/R/RcppExports.R
@@ -21,8 +21,8 @@ nll_gdistremoval <- function(beta, n_param, yDistance, yRemoval, ysum, mixture,
.Call(`_unmarked_nll_gdistremoval`, beta, n_param, yDistance, yRemoval, ysum, mixture, keyfun, Xlam, A, Xphi, Xrem, Xdist, db, a, u, w, pl, K, Kmin, threads)
}
-nll_gdistsamp <- function(beta, n_param, y, mixture, keyfun, survey, Xlam, Xlam_offset, A, Xphi, Xphi_offset, Xdet, Xdet_offset, db, a, u, w, k, lfac_k, lfac_kmyt, kmyt, Kmin, threads) {
- .Call(`_unmarked_nll_gdistsamp`, beta, n_param, y, mixture, keyfun, survey, Xlam, Xlam_offset, A, Xphi, Xphi_offset, Xdet, Xdet_offset, db, a, u, w, k, lfac_k, lfac_kmyt, kmyt, Kmin, threads)
+nll_gdistsamp <- function(beta, n_param, y, mixture, keyfun, survey, Xlam, Xlam_offset, A, Xphi, Xphi_offset, Xdet, Xdet_offset, db, a, u, w, K, Kmin, threads) {
+ .Call(`_unmarked_nll_gdistsamp`, beta, n_param, y, mixture, keyfun, survey, Xlam, Xlam_offset, A, Xphi, Xphi_offset, Xdet, Xdet_offset, db, a, u, w, K, Kmin, threads)
}
nll_gmultmix <- function(beta, n_param, y, mixture, pi_fun, Xlam, Xlam_offset, Xphi, Xphi_offset, Xdet, Xdet_offset, k, lfac_k, lfac_kmyt, kmyt, Kmin, threads) {
diff --git a/R/gdistsamp.R b/R/gdistsamp.R
index c43076a..5342b8e 100644
--- a/R/gdistsamp.R
+++ b/R/gdistsamp.R
@@ -411,7 +411,7 @@ if(engine =="C"){
nll <- function(params){
nll_gdistsamp(params, n_param, y_long, mixture_code, keyfun, survey,
Xlam, Xlam.offset, A, Xphi, Xphi.offset, Xdet, Xdet.offset,
- db, a, t(u), w, k, lfac.k, lfac.kmyt, kmyt, Kmin, threads)
+ db, a, t(u), w, K, Kmin, threads)
}
} else {
diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp
index 0fe45a8..69c817e 100644
--- a/src/RcppExports.cpp
+++ b/src/RcppExports.cpp
@@ -141,8 +141,8 @@ BEGIN_RCPP
END_RCPP
}
// nll_gdistsamp
-double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y, int mixture, std::string keyfun, std::string survey, arma::mat Xlam, arma::vec Xlam_offset, arma::vec A, arma::mat Xphi, arma::vec Xphi_offset, arma::mat Xdet, arma::vec Xdet_offset, arma::vec db, arma::mat a, arma::mat u, arma::vec w, arma::vec k, arma::vec lfac_k, arma::cube lfac_kmyt, arma::cube kmyt, arma::uvec Kmin, int threads);
-RcppExport SEXP _unmarked_nll_gdistsamp(SEXP betaSEXP, SEXP n_paramSEXP, SEXP ySEXP, SEXP mixtureSEXP, SEXP keyfunSEXP, SEXP surveySEXP, SEXP XlamSEXP, SEXP Xlam_offsetSEXP, SEXP ASEXP, SEXP XphiSEXP, SEXP Xphi_offsetSEXP, SEXP XdetSEXP, SEXP Xdet_offsetSEXP, SEXP dbSEXP, SEXP aSEXP, SEXP uSEXP, SEXP wSEXP, SEXP kSEXP, SEXP lfac_kSEXP, SEXP lfac_kmytSEXP, SEXP kmytSEXP, SEXP KminSEXP, SEXP threadsSEXP) {
+double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y, int mixture, std::string keyfun, std::string survey, arma::mat Xlam, arma::vec Xlam_offset, arma::vec A, arma::mat Xphi, arma::vec Xphi_offset, arma::mat Xdet, arma::vec Xdet_offset, arma::vec db, arma::mat a, arma::mat u, arma::vec w, int K, arma::uvec Kmin, int threads);
+RcppExport SEXP _unmarked_nll_gdistsamp(SEXP betaSEXP, SEXP n_paramSEXP, SEXP ySEXP, SEXP mixtureSEXP, SEXP keyfunSEXP, SEXP surveySEXP, SEXP XlamSEXP, SEXP Xlam_offsetSEXP, SEXP ASEXP, SEXP XphiSEXP, SEXP Xphi_offsetSEXP, SEXP XdetSEXP, SEXP Xdet_offsetSEXP, SEXP dbSEXP, SEXP aSEXP, SEXP uSEXP, SEXP wSEXP, SEXP KSEXP, SEXP KminSEXP, SEXP threadsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -163,13 +163,10 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< arma::mat >::type a(aSEXP);
Rcpp::traits::input_parameter< arma::mat >::type u(uSEXP);
Rcpp::traits::input_parameter< arma::vec >::type w(wSEXP);
- Rcpp::traits::input_parameter< arma::vec >::type k(kSEXP);
- Rcpp::traits::input_parameter< arma::vec >::type lfac_k(lfac_kSEXP);
- Rcpp::traits::input_parameter< arma::cube >::type lfac_kmyt(lfac_kmytSEXP);
- Rcpp::traits::input_parameter< arma::cube >::type kmyt(kmytSEXP);
+ Rcpp::traits::input_parameter< int >::type K(KSEXP);
Rcpp::traits::input_parameter< arma::uvec >::type Kmin(KminSEXP);
Rcpp::traits::input_parameter< int >::type threads(threadsSEXP);
- rcpp_result_gen = Rcpp::wrap(nll_gdistsamp(beta, n_param, y, mixture, keyfun, survey, Xlam, Xlam_offset, A, Xphi, Xphi_offset, Xdet, Xdet_offset, db, a, u, w, k, lfac_k, lfac_kmyt, kmyt, Kmin, threads));
+ rcpp_result_gen = Rcpp::wrap(nll_gdistsamp(beta, n_param, y, mixture, keyfun, survey, Xlam, Xlam_offset, A, Xphi, Xphi_offset, Xdet, Xdet_offset, db, a, u, w, K, Kmin, threads));
return rcpp_result_gen;
END_RCPP
}
@@ -573,7 +570,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_unmarked_nll_distsamp", (DL_FUNC) &_unmarked_nll_distsamp, 10},
{"_unmarked_nll_distsampOpen", (DL_FUNC) &_unmarked_nll_distsampOpen, 42},
{"_unmarked_nll_gdistremoval", (DL_FUNC) &_unmarked_nll_gdistremoval, 20},
- {"_unmarked_nll_gdistsamp", (DL_FUNC) &_unmarked_nll_gdistsamp, 23},
+ {"_unmarked_nll_gdistsamp", (DL_FUNC) &_unmarked_nll_gdistsamp, 20},
{"_unmarked_nll_gmultmix", (DL_FUNC) &_unmarked_nll_gmultmix, 17},
{"_unmarked_nll_gpcount", (DL_FUNC) &_unmarked_nll_gpcount, 15},
{"_unmarked_nll_multinomPois", (DL_FUNC) &_unmarked_nll_multinomPois, 10},
diff --git a/src/nll_gdistsamp.cpp b/src/nll_gdistsamp.cpp
index 80dc2d4..baed007 100644
--- a/src/nll_gdistsamp.cpp
+++ b/src/nll_gdistsamp.cpp
@@ -16,8 +16,7 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y,
int mixture, std::string keyfun, std::string survey,
arma::mat Xlam, arma::vec Xlam_offset, arma::vec A, arma::mat Xphi,
arma::vec Xphi_offset, arma::mat Xdet, arma::vec Xdet_offset, arma::vec db,
- arma::mat a, arma::mat u, arma::vec w, arma::vec k, arma::vec lfac_k,
- arma::cube lfac_kmyt, arma::cube kmyt, arma::uvec Kmin, int threads){
+ arma::mat a, arma::mat u, arma::vec w, int K, arma::uvec Kmin, int threads){
#ifdef _OPENMP
omp_set_num_threads(threads);
@@ -25,9 +24,7 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y,
int M = Xlam.n_rows;
int T = Xphi.n_rows / M;
- int R = y.size() / M;
- unsigned J = R / T;
- int K = k.size() - 1;
+ unsigned J = db.size() - 1;
//Abundance
const vec lambda = exp(Xlam * beta_sub(beta, n_param, 0) + Xlam_offset) % A;
@@ -51,43 +48,52 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y,
#pragma omp parallel for reduction(+: loglik) if(threads > 1)
for (int i=0; i<M; i++){
+ vec f = zeros(K+1); // should this be ones(K+1);
+ for (int k=Kmin(i); k<(K+1); k++){
+ f(k) = N_density(mixture, k, lambda(i), log_alpha);
+ }
+
int t_ind = i * T;
int y_ind = i * T * J;
vec y_sub(J);
+ vec y_all(J+1);
+
+ vec cp(J);
+ vec cp_all(J+1);
+ double ptotal;
+
+ vec g = zeros(K+1);
- mat mn = zeros(K+1, T);
for(int t=0; t<T; t++){
int y_stop = y_ind + J - 1;
y_sub = y.subvec(y_ind, y_stop);
- uvec not_missing = find_finite(y_sub);
+ y_all.subvec(0, (J-1)) = y_sub;
- if(not_missing.size() == J){
+ uvec nm = find_finite(y_sub);
- vec p1 = lfac_kmyt.subcube(span(i),span(t),span());
- vec p = distprob(keyfun, det_param(t_ind), scale, survey, db,
- w, a.row(i));
- vec p3 = p % u.col(i) * phi(t_ind);
- //the following line causes a segfault only in R CMD check,
- //when kmyt contains NA values
- vec p4 = kmyt.subcube(span(i),span(t),span());
+ if((nm.size() == J)){
- double p5 = 1 - sum(p3);
+ cp = distprob(keyfun, det_param(t_ind), scale, survey, db,
+ w, a.row(i));
+ cp = cp % u.col(i) * phi(t_ind);
+ ptotal = sum(cp);
+
+ cp_all.subvec(0, (J-1)) = cp;
+ cp_all(J) = 1 - ptotal;
+
+ for (int k=Kmin(i); k<(K+1); k++){
+ y_all(J) = k - sum(y_sub);
+ g(k) += dmultinom(y_all, cp_all);
+ }
- mn.col(t) = lfac_k - p1 + sum(y_sub % log(p3)) + p4 * log(p5);
}
t_ind += 1;
y_ind += J;
}
- double site_lp = 0.0;
- for (int j=Kmin(i); j<(K+1); j++){
- site_lp += N_density(mixture, j, lambda(i), log_alpha) *
- exp(sum(mn.row(j)));
- }
-
- loglik += log(site_lp + DBL_MIN);
+ loglik += log(sum(f % exp(g)));
}
diff --git a/tests/testthat/test_gdistsamp.R b/tests/testthat/test_gdistsamp.R
index 1b232b2..e385566 100644
--- a/tests/testthat/test_gdistsamp.R
+++ b/tests/testthat/test_gdistsamp.R
@@ -351,7 +351,7 @@ test_that("gdistsamp with exp keyfunction works",{
keyfun="exp",engine="C", control=list(maxit=1))
fm_R <- gdistsamp(~par1, ~par2, ~par3, umf, output="density", se=FALSE,
keyfun="exp",engine="R", control=list(maxit=1))
- expect_equal(fm_C@AIC, fm_R@AIC)
+ expect_equal(coef(fm_C), coef(fm_R))
#fm_R <- gdistsamp(~par1, ~par2, ~par3, umf, output="density", se=FALSE,
# keyfun="exp",engine="R")
@@ -390,7 +390,7 @@ test_that("gdistsamp with exp keyfunction works",{
keyfun="exp",engine="C", se=F,control=list(maxit=1))
fm_R <- gdistsamp(~elevation, ~1, ~chaparral, jayumf, output='density',
keyfun="exp",engine="R", se=F, control=list(maxit=1))
- expect_equal(fm_C@AIC, fm_R@AIC, tol=1e-5)
+ expect_equal(coef(fm_C), coef(fm_R), tol=1e-5)
})
@@ -439,7 +439,7 @@ test_that("gdistsamp with hazard keyfunction works",{
keyfun="hazard",engine="C", control=list(maxit=1))
fm_R <- gdistsamp(~par1, ~par2, ~par3, umf, output="density", se=FALSE,
keyfun="hazard",engine="R", control=list(maxit=1))
- expect_equal(fm_C@AIC, fm_R@AIC, tol=1e-5)
+ expect_equal(coef(fm_C), coef(fm_R), tol=1e-5)
#fm_R <- gdistsamp(~par1, ~par2, ~par3, umf, output="density", se=FALSE,
# keyfun="hazard",engine="R")
@@ -478,7 +478,7 @@ test_that("gdistsamp with hazard keyfunction works",{
keyfun="hazard",engine="C", se=F,control=list(maxit=1))
fm_R <- gdistsamp(~elevation, ~1, ~chaparral, jayumf, output='density',
keyfun="hazard",engine="R", se=F, control=list(maxit=1))
- expect_equal(fm_C@AIC, fm_R@AIC, tol=1e-3)
+ expect_equal(coef(fm_C), coef(fm_R), tol=1e-2) # why are these slightly different?
})
test_that("predict works for gdistsamp",{