aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKen Kellner <ken@kenkellner.com>2023-12-06 08:35:55 -0500
committerKen Kellner <ken@kenkellner.com>2023-12-06 08:35:55 -0500
commit3f58d60478213394ba960813976ffa47640e65b2 (patch)
treeebf62c959f18bf4ee5507cbca29a5923ea53ad2c
parentc8a1c83de33a909cf0dc167ee7b5e880e41610b1 (diff)
Clean up jags and jagsbasic
-rw-r--r--R/jags.R105
-rw-r--r--R/jagsbasic.R86
-rw-r--r--R/process_input.R7
-rw-r--r--R/rjags_tools.R33
4 files changed, 92 insertions, 139 deletions
diff --git a/R/jags.R b/R/jags.R
index 5af4a6c..e6d4313 100644
--- a/R/jags.R
+++ b/R/jags.R
@@ -1,77 +1,36 @@
-
-jagsUI <- jags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.adapt=NULL,n.iter,n.burnin=0,n.thin=1,
- modules=c('glm'),factories=NULL,parallel=FALSE,n.cores=NULL,DIC=TRUE,store.data=FALSE,codaOnly=FALSE,seed=NULL,
- bugs.format=FALSE,verbose=TRUE){
+jagsUI <- jags <- function(data, inits=NULL, parameters.to.save, model.file,
+ n.chains, n.adapt=NULL, n.iter, n.burnin=0, n.thin=1,
+ modules=c('glm'), factories=NULL, parallel=FALSE,
+ n.cores=NULL, DIC=TRUE, store.data=FALSE,
+ codaOnly=FALSE,seed=NULL, bugs.format=FALSE,verbose=TRUE){
- if(!is.null(seed)){
- stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE)
- }
-
# Check input data
- inps_check <- process_input(data=data, params=parameters.to.save, inits=inits,
- n_chains=n.chains, n_adapt=n.adapt, n_iter=n.iter,
- n_burnin=n.burnin, n_thin=n.thin, n_cores=n.cores,
- DIC=DIC, quiet=!verbose, parallel=parallel)
- data <- inps_check$data
- parameters.to.save <- inps_check$params
- inits <- inps_check$inits
- mcmc.info <- inps_check$mcmc.info
- if(parallel) n.cores <- inps_check$mcmc.info$n.cores
-
- #Save start time
- start.time <- Sys.time()
-
- #Stuff to do if parallel=TRUE
- if(parallel && n.chains>1){
-
- par <- run.parallel(data,inits,parameters.to.save,model.file,n.chains,n.adapt,n.iter,n.burnin,n.thin,
- modules=modules,factories=factories,DIC=DIC,verbose=verbose,n.cores=n.cores)
- samples <- par$samples
- m <- par$model
- total.adapt <- par$total.adapt
- sufficient.adapt <- par$sufficient.adapt
- if(any(!sufficient.adapt)&verbose){warning("JAGS reports adaptation was incomplete. Consider increasing n.adapt")}
-
- } else {
-
- #######################
- ##Run rjags functions##
- #######################
-
- #Set modules
- set.modules(modules,DIC)
- set.factories(factories)
-
- rjags.output <- run.model(model.file,data,inits,parameters.to.save,n.chains,n.iter,n.burnin,n.thin,n.adapt,verbose=verbose)
- samples <- rjags.output$samples
- m <- rjags.output$m
- total.adapt <- rjags.output$total.adapt
- sufficient.adapt <- rjags.output$sufficient.adapt
-
- ##########################
- ##End of rjags functions##
- ##########################
-
- }
-
- #Add mcmc info into list
- mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3)
- mcmc.info$n.samples <- coda::niter(samples) * n.chains
- mcmc.info$end.values <- samples[coda::niter(samples),]
- mcmc.info$n.adapt <- total.adapt
- mcmc.info$sufficient.adapt <- sufficient.adapt
+ inps <- process_input(data, parameters.to.save, inits,
+ n.chains, n.adapt, n.iter, n.burnin, n.thin, n.cores,
+ DIC, !verbose, parallel, seed)
+
+ # Run JAGS via rjags
+ rjags_out <- run_rjags(inps$data, inps$inits, inps$params, model.file,
+ inps$mcmc.info, modules, factories, DIC, parallel, !verbose)
- #Reorganize JAGS output to match input parameter order
- samples <- order_samples(samples, parameters.to.save)
-
- #Convert rjags output to jagsUI form
+ #Update mcmc.info list
+ mcmc.info <- inps$mcmc.info
+ iter_final <- coda::niter(rjags_out$samples)
+ mcmc.info$elapsed.mins <- rjags_out$elapsed.mins
+ mcmc.info$n.samples <- iter_final * n.chains
+ mcmc.info$end.values <- rjags_out$samples[iter_final,]
+ mcmc.info$n.adapt <- rjags_out$total.adapt
+ mcmc.info$sufficient.adapt <- rjags_out$sufficient.adapt
+
+ # Reorganize JAGS output to match input parameter order
+ samples <- order_samples(rjags_out$samples, inps$params)
+ # Process output and calculate statistics
output <- process_output(samples, coda_only = codaOnly, DIC, quiet = !verbose)
+ # Fallback if processing output fails
if(is.null(output)){
- output <- list()
- output$samples <- samples
- output$model <- m
- output$n.cores <- n.cores
+ output <- list(samples = samples, model = rjags_out$m)
+ output$n.cores <- mcmc.info$n.cores
class(output) <- 'jagsUIbasic'
return(output)
}
@@ -81,13 +40,13 @@ jagsUI <- jags <- function(data,inits=NULL,parameters.to.save,model.file,n.chain
output$modfile <- model.file
#If user wants to save input data/inits
if(store.data){
- output$inits <- inits
- output$data <- data
+ output$inits <- inps$inits
+ output$data <- inps$data
}
- output$model <- m
- output$parameters <- parameters.to.save
+ output$model <- rjags_out$m
+ output$parameters <- inps$params
output$mcmc.info <- mcmc.info
- output$run.date <- start.time
+ output$run.date <- rjags_out$run.date
output$parallel <- parallel
output$bugs.format <- bugs.format
output$calc.DIC <- DIC
diff --git a/R/jagsbasic.R b/R/jagsbasic.R
index bd7e3f0..d488ee2 100644
--- a/R/jagsbasic.R
+++ b/R/jagsbasic.R
@@ -1,74 +1,30 @@
-jags.basic <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.adapt=NULL,n.iter,n.burnin=0,n.thin=1,
- modules=c('glm'),factories=NULL,parallel=FALSE,n.cores=NULL,DIC=TRUE,seed=NULL,save.model=FALSE,verbose=TRUE){
+jags.basic <- function(data, inits=NULL, parameters.to.save, model.file,
+ n.chains, n.adapt=NULL, n.iter, n.burnin=0, n.thin=1,
+ modules=c('glm'), factories=NULL, parallel=FALSE,
+ n.cores=NULL, DIC=TRUE, seed=NULL, save.model=FALSE, verbose=TRUE){
- if(!is.null(seed)){
- stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE)
- }
-
# Check input data
- inps_check <- process_input(data=data, params=parameters.to.save, inits=inits,
- n_chains=n.chains, n_adapt=n.adapt, n_iter=n.iter,
- n_burnin=n.burnin, n_thin=n.thin, n_cores=n.cores,
- DIC=DIC, quiet=!verbose, parallel=parallel)
- data <- inps_check$data
- parameters.to.save <- inps_check$params
- inits <- inps_check$inits
- mcmc.info <- inps_check$mcmc.info
- if(parallel) n.cores <- inps_check$mcmc.info$n.cores
-
- #Save start time
- start.time <- Sys.time()
-
- #Stuff to do if parallel=TRUE
- if(parallel && n.chains>1){
-
- par <- run.parallel(data,inits,parameters.to.save,model.file,n.chains,n.adapt,n.iter,n.burnin,n.thin,
- modules=modules,factories=factories,DIC=DIC,verbose=verbose,n.cores=n.cores)
- samples <- par$samples
- m <- par$model
- total.adapt <- par$total.adapt
- sufficient.adapt <- par$sufficient.adapt
- if(any(!sufficient.adapt)&verbose){warning("JAGS reports adaptation was incomplete. Consider increasing n.adapt")}
-
- } else {
-
- #######################
- ##Run rjags functions##
- #######################
-
- #Set modules
- set.modules(modules,DIC)
- set.factories(factories)
-
- rjags.output <- run.model(model.file,data,inits,parameters.to.save,n.chains,n.iter,n.burnin,n.thin,n.adapt,
- verbose=verbose)
- samples <- rjags.output$samples
- m <- rjags.output$m
- total.adapt <- rjags.output$total.adapt
- sufficient.adapt <- rjags.output$sufficient.adapt
-
- ##########################
- ##End of rjags functions##
- ##########################
-
- }
+ inps <- process_input(data, parameters.to.save, inits,
+ n.chains, n.adapt, n.iter, n.burnin, n.thin, n.cores,
+ DIC, !verbose, parallel, seed)
- #Get more info about MCMC run
- time <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3)
- if(verbose){cat('MCMC took',time,'minutes.\n')}
+ # Run JAGS via rjags
+ rjags_out <- run_rjags(inps$data, inps$inits, inps$params, model.file,
+ inps$mcmc.info, modules, factories, DIC, parallel, !verbose)
+ # Report time
+ if(verbose) cat('MCMC took', rjags_out$elapsed.mins, 'minutes.\n')
+
+ # Create output object
if(save.model){
- output <- list()
- samples <- order_samples(samples, parameters.to.save)
- output$samples <- samples
- output$model <- m
- output$n.cores <- n.cores
- output$random.seed <- seed
- class(output) <- 'jagsUIbasic'
- } else {output <- samples}
+ samples <- order_samples(rjags_out$samples, inps$params)
+ output <- list(samples = samples, model = rjags_out$m)
+ output$n.cores <- inps$mcmc.info$n.cores
+ class(output) <- 'jagsUIbasic'
+ } else{
+ output <- rjags_out$samples
+ }
-
return(output)
-
}
diff --git a/R/process_input.R b/R/process_input.R
index 39a2558..15991b2 100644
--- a/R/process_input.R
+++ b/R/process_input.R
@@ -1,8 +1,13 @@
# Process input----------------------------------------------------------------
process_input <- function(data, params, inits, n_chains, n_adapt, n_iter, n_burnin,
- n_thin, n_cores, DIC, quiet, parallel){
+ n_thin, n_cores, DIC, quiet, parallel, seed=NULL){
if(!quiet){cat('\nProcessing function input.......','\n')}
+
+ if(!is.null(seed)){
+ stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE)
+ }
+
out <- list(data = check_data(data, quiet),
params = check_params(params, DIC),
inits = check_inits(inits, n_chains),
diff --git a/R/rjags_tools.R b/R/rjags_tools.R
new file mode 100644
index 0000000..c39c592
--- /dev/null
+++ b/R/rjags_tools.R
@@ -0,0 +1,33 @@
+run_rjags <- function(data, inits, params, modfile, mcmc_info,
+ modules, factories, DIC, parallel, quiet){
+
+ #Save start time
+ start.time <- Sys.time()
+
+ mc <- mcmc_info
+
+ # Run parallel
+ if(parallel & mc$n.chains > 1){
+ result <- run.parallel(data, inits, params, modfile,
+ mc$n.chains, mc$n.adapt, mc$n.iter, mc$n.burnin, mc$n.thin,
+ modules=modules, factories=factories, DIC=DIC,
+ verbose=!quiet, n.cores=mc$n.cores)
+ # Move this down if not also handled in runmodel()
+ if(any(!result$sufficient.adapt)){
+ warning("JAGS reports adaptation was incomplete. Consider increasing n.adapt", call.=FALSE)
+ }
+ } else {
+ # Run non-parallel
+ set.modules(modules, DIC)
+ set.factories(factories)
+
+ result <- run.model(modfile, data, inits, params,
+ mc$n.chains, mc$n.iter, mc$n.burnin, mc$n.thin, mc$n.adapt,
+ verbose=!quiet)
+ }
+
+ result$run.date <- start.time
+ result$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3)
+
+ result
+}