From 827508a5ef645e7bd3a0d76fd94acb4d9c3e3db4 Mon Sep 17 00:00:00 2001 From: Ken Kellner Date: Sat, 2 Mar 2024 20:55:17 -0500 Subject: Refactor gdistsamp likelihood for stability --- R/RcppExports.R | 4 +-- R/gdistsamp.R | 2 +- src/RcppExports.cpp | 13 ++++------ src/nll_gdistsamp.cpp | 54 +++++++++++++++++++++++------------------ tests/testthat/test_gdistsamp.R | 8 +++--- 5 files changed, 42 insertions(+), 39 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index 70b57ac..7aba254 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -21,8 +21,8 @@ nll_gdistremoval <- function(beta, n_param, yDistance, yRemoval, ysum, mixture, .Call(`_unmarked_nll_gdistremoval`, beta, n_param, yDistance, yRemoval, ysum, mixture, keyfun, Xlam, A, Xphi, Xrem, Xdist, db, a, u, w, pl, K, Kmin, threads) } -nll_gdistsamp <- function(beta, n_param, y, mixture, keyfun, survey, Xlam, Xlam_offset, A, Xphi, Xphi_offset, Xdet, Xdet_offset, db, a, u, w, k, lfac_k, lfac_kmyt, kmyt, Kmin, threads) { - .Call(`_unmarked_nll_gdistsamp`, beta, n_param, y, mixture, keyfun, survey, Xlam, Xlam_offset, A, Xphi, Xphi_offset, Xdet, Xdet_offset, db, a, u, w, k, lfac_k, lfac_kmyt, kmyt, Kmin, threads) +nll_gdistsamp <- function(beta, n_param, y, mixture, keyfun, survey, Xlam, Xlam_offset, A, Xphi, Xphi_offset, Xdet, Xdet_offset, db, a, u, w, K, Kmin, threads) { + .Call(`_unmarked_nll_gdistsamp`, beta, n_param, y, mixture, keyfun, survey, Xlam, Xlam_offset, A, Xphi, Xphi_offset, Xdet, Xdet_offset, db, a, u, w, K, Kmin, threads) } nll_gmultmix <- function(beta, n_param, y, mixture, pi_fun, Xlam, Xlam_offset, Xphi, Xphi_offset, Xdet, Xdet_offset, k, lfac_k, lfac_kmyt, kmyt, Kmin, threads) { diff --git a/R/gdistsamp.R b/R/gdistsamp.R index c43076a..5342b8e 100644 --- a/R/gdistsamp.R +++ b/R/gdistsamp.R @@ -411,7 +411,7 @@ if(engine =="C"){ nll <- function(params){ nll_gdistsamp(params, n_param, y_long, mixture_code, keyfun, survey, Xlam, Xlam.offset, A, Xphi, Xphi.offset, Xdet, Xdet.offset, - db, a, t(u), w, k, lfac.k, lfac.kmyt, kmyt, Kmin, threads) + db, a, t(u), w, K, Kmin, threads) } } else { diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 0fe45a8..69c817e 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -141,8 +141,8 @@ BEGIN_RCPP END_RCPP } // nll_gdistsamp -double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y, int mixture, std::string keyfun, std::string survey, arma::mat Xlam, arma::vec Xlam_offset, arma::vec A, arma::mat Xphi, arma::vec Xphi_offset, arma::mat Xdet, arma::vec Xdet_offset, arma::vec db, arma::mat a, arma::mat u, arma::vec w, arma::vec k, arma::vec lfac_k, arma::cube lfac_kmyt, arma::cube kmyt, arma::uvec Kmin, int threads); -RcppExport SEXP _unmarked_nll_gdistsamp(SEXP betaSEXP, SEXP n_paramSEXP, SEXP ySEXP, SEXP mixtureSEXP, SEXP keyfunSEXP, SEXP surveySEXP, SEXP XlamSEXP, SEXP Xlam_offsetSEXP, SEXP ASEXP, SEXP XphiSEXP, SEXP Xphi_offsetSEXP, SEXP XdetSEXP, SEXP Xdet_offsetSEXP, SEXP dbSEXP, SEXP aSEXP, SEXP uSEXP, SEXP wSEXP, SEXP kSEXP, SEXP lfac_kSEXP, SEXP lfac_kmytSEXP, SEXP kmytSEXP, SEXP KminSEXP, SEXP threadsSEXP) { +double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y, int mixture, std::string keyfun, std::string survey, arma::mat Xlam, arma::vec Xlam_offset, arma::vec A, arma::mat Xphi, arma::vec Xphi_offset, arma::mat Xdet, arma::vec Xdet_offset, arma::vec db, arma::mat a, arma::mat u, arma::vec w, int K, arma::uvec Kmin, int threads); +RcppExport SEXP _unmarked_nll_gdistsamp(SEXP betaSEXP, SEXP n_paramSEXP, SEXP ySEXP, SEXP mixtureSEXP, SEXP keyfunSEXP, SEXP surveySEXP, SEXP XlamSEXP, SEXP Xlam_offsetSEXP, SEXP ASEXP, SEXP XphiSEXP, SEXP Xphi_offsetSEXP, SEXP XdetSEXP, SEXP Xdet_offsetSEXP, SEXP dbSEXP, SEXP aSEXP, SEXP uSEXP, SEXP wSEXP, SEXP KSEXP, SEXP KminSEXP, SEXP threadsSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -163,13 +163,10 @@ BEGIN_RCPP Rcpp::traits::input_parameter< arma::mat >::type a(aSEXP); Rcpp::traits::input_parameter< arma::mat >::type u(uSEXP); Rcpp::traits::input_parameter< arma::vec >::type w(wSEXP); - Rcpp::traits::input_parameter< arma::vec >::type k(kSEXP); - Rcpp::traits::input_parameter< arma::vec >::type lfac_k(lfac_kSEXP); - Rcpp::traits::input_parameter< arma::cube >::type lfac_kmyt(lfac_kmytSEXP); - Rcpp::traits::input_parameter< arma::cube >::type kmyt(kmytSEXP); + Rcpp::traits::input_parameter< int >::type K(KSEXP); Rcpp::traits::input_parameter< arma::uvec >::type Kmin(KminSEXP); Rcpp::traits::input_parameter< int >::type threads(threadsSEXP); - rcpp_result_gen = Rcpp::wrap(nll_gdistsamp(beta, n_param, y, mixture, keyfun, survey, Xlam, Xlam_offset, A, Xphi, Xphi_offset, Xdet, Xdet_offset, db, a, u, w, k, lfac_k, lfac_kmyt, kmyt, Kmin, threads)); + rcpp_result_gen = Rcpp::wrap(nll_gdistsamp(beta, n_param, y, mixture, keyfun, survey, Xlam, Xlam_offset, A, Xphi, Xphi_offset, Xdet, Xdet_offset, db, a, u, w, K, Kmin, threads)); return rcpp_result_gen; END_RCPP } @@ -573,7 +570,7 @@ static const R_CallMethodDef CallEntries[] = { {"_unmarked_nll_distsamp", (DL_FUNC) &_unmarked_nll_distsamp, 10}, {"_unmarked_nll_distsampOpen", (DL_FUNC) &_unmarked_nll_distsampOpen, 42}, {"_unmarked_nll_gdistremoval", (DL_FUNC) &_unmarked_nll_gdistremoval, 20}, - {"_unmarked_nll_gdistsamp", (DL_FUNC) &_unmarked_nll_gdistsamp, 23}, + {"_unmarked_nll_gdistsamp", (DL_FUNC) &_unmarked_nll_gdistsamp, 20}, {"_unmarked_nll_gmultmix", (DL_FUNC) &_unmarked_nll_gmultmix, 17}, {"_unmarked_nll_gpcount", (DL_FUNC) &_unmarked_nll_gpcount, 15}, {"_unmarked_nll_multinomPois", (DL_FUNC) &_unmarked_nll_multinomPois, 10}, diff --git a/src/nll_gdistsamp.cpp b/src/nll_gdistsamp.cpp index 80dc2d4..baed007 100644 --- a/src/nll_gdistsamp.cpp +++ b/src/nll_gdistsamp.cpp @@ -16,8 +16,7 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y, int mixture, std::string keyfun, std::string survey, arma::mat Xlam, arma::vec Xlam_offset, arma::vec A, arma::mat Xphi, arma::vec Xphi_offset, arma::mat Xdet, arma::vec Xdet_offset, arma::vec db, - arma::mat a, arma::mat u, arma::vec w, arma::vec k, arma::vec lfac_k, - arma::cube lfac_kmyt, arma::cube kmyt, arma::uvec Kmin, int threads){ + arma::mat a, arma::mat u, arma::vec w, int K, arma::uvec Kmin, int threads){ #ifdef _OPENMP omp_set_num_threads(threads); @@ -25,9 +24,7 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y, int M = Xlam.n_rows; int T = Xphi.n_rows / M; - int R = y.size() / M; - unsigned J = R / T; - int K = k.size() - 1; + unsigned J = db.size() - 1; //Abundance const vec lambda = exp(Xlam * beta_sub(beta, n_param, 0) + Xlam_offset) % A; @@ -51,43 +48,52 @@ double nll_gdistsamp(arma::vec beta, arma::uvec n_param, arma::vec y, #pragma omp parallel for reduction(+: loglik) if(threads > 1) for (int i=0; i