diff options
author | Ken Kellner <ken@kenkellner.com> | 2023-12-06 08:35:55 -0500 |
---|---|---|
committer | Ken Kellner <ken@kenkellner.com> | 2023-12-06 08:35:55 -0500 |
commit | 3f58d60478213394ba960813976ffa47640e65b2 (patch) | |
tree | ebf62c959f18bf4ee5507cbca29a5923ea53ad2c | |
parent | c8a1c83de33a909cf0dc167ee7b5e880e41610b1 (diff) |
Clean up jags and jagsbasic
-rw-r--r-- | R/jags.R | 105 | ||||
-rw-r--r-- | R/jagsbasic.R | 86 | ||||
-rw-r--r-- | R/process_input.R | 7 | ||||
-rw-r--r-- | R/rjags_tools.R | 33 |
4 files changed, 92 insertions, 139 deletions
@@ -1,77 +1,36 @@ - -jagsUI <- jags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.adapt=NULL,n.iter,n.burnin=0,n.thin=1, - modules=c('glm'),factories=NULL,parallel=FALSE,n.cores=NULL,DIC=TRUE,store.data=FALSE,codaOnly=FALSE,seed=NULL, - bugs.format=FALSE,verbose=TRUE){ +jagsUI <- jags <- function(data, inits=NULL, parameters.to.save, model.file, + n.chains, n.adapt=NULL, n.iter, n.burnin=0, n.thin=1, + modules=c('glm'), factories=NULL, parallel=FALSE, + n.cores=NULL, DIC=TRUE, store.data=FALSE, + codaOnly=FALSE,seed=NULL, bugs.format=FALSE,verbose=TRUE){ - if(!is.null(seed)){ - stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE) - } - # Check input data - inps_check <- process_input(data=data, params=parameters.to.save, inits=inits, - n_chains=n.chains, n_adapt=n.adapt, n_iter=n.iter, - n_burnin=n.burnin, n_thin=n.thin, n_cores=n.cores, - DIC=DIC, quiet=!verbose, parallel=parallel) - data <- inps_check$data - parameters.to.save <- inps_check$params - inits <- inps_check$inits - mcmc.info <- inps_check$mcmc.info - if(parallel) n.cores <- inps_check$mcmc.info$n.cores - - #Save start time - start.time <- Sys.time() - - #Stuff to do if parallel=TRUE - if(parallel && n.chains>1){ - - par <- run.parallel(data,inits,parameters.to.save,model.file,n.chains,n.adapt,n.iter,n.burnin,n.thin, - modules=modules,factories=factories,DIC=DIC,verbose=verbose,n.cores=n.cores) - samples <- par$samples - m <- par$model - total.adapt <- par$total.adapt - sufficient.adapt <- par$sufficient.adapt - if(any(!sufficient.adapt)&verbose){warning("JAGS reports adaptation was incomplete. Consider increasing n.adapt")} - - } else { - - ####################### - ##Run rjags functions## - ####################### - - #Set modules - set.modules(modules,DIC) - set.factories(factories) - - rjags.output <- run.model(model.file,data,inits,parameters.to.save,n.chains,n.iter,n.burnin,n.thin,n.adapt,verbose=verbose) - samples <- rjags.output$samples - m <- rjags.output$m - total.adapt <- rjags.output$total.adapt - sufficient.adapt <- rjags.output$sufficient.adapt - - ########################## - ##End of rjags functions## - ########################## - - } - - #Add mcmc info into list - mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3) - mcmc.info$n.samples <- coda::niter(samples) * n.chains - mcmc.info$end.values <- samples[coda::niter(samples),] - mcmc.info$n.adapt <- total.adapt - mcmc.info$sufficient.adapt <- sufficient.adapt + inps <- process_input(data, parameters.to.save, inits, + n.chains, n.adapt, n.iter, n.burnin, n.thin, n.cores, + DIC, !verbose, parallel, seed) + + # Run JAGS via rjags + rjags_out <- run_rjags(inps$data, inps$inits, inps$params, model.file, + inps$mcmc.info, modules, factories, DIC, parallel, !verbose) - #Reorganize JAGS output to match input parameter order - samples <- order_samples(samples, parameters.to.save) - - #Convert rjags output to jagsUI form + #Update mcmc.info list + mcmc.info <- inps$mcmc.info + iter_final <- coda::niter(rjags_out$samples) + mcmc.info$elapsed.mins <- rjags_out$elapsed.mins + mcmc.info$n.samples <- iter_final * n.chains + mcmc.info$end.values <- rjags_out$samples[iter_final,] + mcmc.info$n.adapt <- rjags_out$total.adapt + mcmc.info$sufficient.adapt <- rjags_out$sufficient.adapt + + # Reorganize JAGS output to match input parameter order + samples <- order_samples(rjags_out$samples, inps$params) + # Process output and calculate statistics output <- process_output(samples, coda_only = codaOnly, DIC, quiet = !verbose) + # Fallback if processing output fails if(is.null(output)){ - output <- list() - output$samples <- samples - output$model <- m - output$n.cores <- n.cores + output <- list(samples = samples, model = rjags_out$m) + output$n.cores <- mcmc.info$n.cores class(output) <- 'jagsUIbasic' return(output) } @@ -81,13 +40,13 @@ jagsUI <- jags <- function(data,inits=NULL,parameters.to.save,model.file,n.chain output$modfile <- model.file #If user wants to save input data/inits if(store.data){ - output$inits <- inits - output$data <- data + output$inits <- inps$inits + output$data <- inps$data } - output$model <- m - output$parameters <- parameters.to.save + output$model <- rjags_out$m + output$parameters <- inps$params output$mcmc.info <- mcmc.info - output$run.date <- start.time + output$run.date <- rjags_out$run.date output$parallel <- parallel output$bugs.format <- bugs.format output$calc.DIC <- DIC diff --git a/R/jagsbasic.R b/R/jagsbasic.R index bd7e3f0..d488ee2 100644 --- a/R/jagsbasic.R +++ b/R/jagsbasic.R @@ -1,74 +1,30 @@ -jags.basic <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.adapt=NULL,n.iter,n.burnin=0,n.thin=1, - modules=c('glm'),factories=NULL,parallel=FALSE,n.cores=NULL,DIC=TRUE,seed=NULL,save.model=FALSE,verbose=TRUE){ +jags.basic <- function(data, inits=NULL, parameters.to.save, model.file, + n.chains, n.adapt=NULL, n.iter, n.burnin=0, n.thin=1, + modules=c('glm'), factories=NULL, parallel=FALSE, + n.cores=NULL, DIC=TRUE, seed=NULL, save.model=FALSE, verbose=TRUE){ - if(!is.null(seed)){ - stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE) - } - # Check input data - inps_check <- process_input(data=data, params=parameters.to.save, inits=inits, - n_chains=n.chains, n_adapt=n.adapt, n_iter=n.iter, - n_burnin=n.burnin, n_thin=n.thin, n_cores=n.cores, - DIC=DIC, quiet=!verbose, parallel=parallel) - data <- inps_check$data - parameters.to.save <- inps_check$params - inits <- inps_check$inits - mcmc.info <- inps_check$mcmc.info - if(parallel) n.cores <- inps_check$mcmc.info$n.cores - - #Save start time - start.time <- Sys.time() - - #Stuff to do if parallel=TRUE - if(parallel && n.chains>1){ - - par <- run.parallel(data,inits,parameters.to.save,model.file,n.chains,n.adapt,n.iter,n.burnin,n.thin, - modules=modules,factories=factories,DIC=DIC,verbose=verbose,n.cores=n.cores) - samples <- par$samples - m <- par$model - total.adapt <- par$total.adapt - sufficient.adapt <- par$sufficient.adapt - if(any(!sufficient.adapt)&verbose){warning("JAGS reports adaptation was incomplete. Consider increasing n.adapt")} - - } else { - - ####################### - ##Run rjags functions## - ####################### - - #Set modules - set.modules(modules,DIC) - set.factories(factories) - - rjags.output <- run.model(model.file,data,inits,parameters.to.save,n.chains,n.iter,n.burnin,n.thin,n.adapt, - verbose=verbose) - samples <- rjags.output$samples - m <- rjags.output$m - total.adapt <- rjags.output$total.adapt - sufficient.adapt <- rjags.output$sufficient.adapt - - ########################## - ##End of rjags functions## - ########################## - - } + inps <- process_input(data, parameters.to.save, inits, + n.chains, n.adapt, n.iter, n.burnin, n.thin, n.cores, + DIC, !verbose, parallel, seed) - #Get more info about MCMC run - time <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3) - if(verbose){cat('MCMC took',time,'minutes.\n')} + # Run JAGS via rjags + rjags_out <- run_rjags(inps$data, inps$inits, inps$params, model.file, + inps$mcmc.info, modules, factories, DIC, parallel, !verbose) + # Report time + if(verbose) cat('MCMC took', rjags_out$elapsed.mins, 'minutes.\n') + + # Create output object if(save.model){ - output <- list() - samples <- order_samples(samples, parameters.to.save) - output$samples <- samples - output$model <- m - output$n.cores <- n.cores - output$random.seed <- seed - class(output) <- 'jagsUIbasic' - } else {output <- samples} + samples <- order_samples(rjags_out$samples, inps$params) + output <- list(samples = samples, model = rjags_out$m) + output$n.cores <- inps$mcmc.info$n.cores + class(output) <- 'jagsUIbasic' + } else{ + output <- rjags_out$samples + } - return(output) - } diff --git a/R/process_input.R b/R/process_input.R index 39a2558..15991b2 100644 --- a/R/process_input.R +++ b/R/process_input.R @@ -1,8 +1,13 @@ # Process input---------------------------------------------------------------- process_input <- function(data, params, inits, n_chains, n_adapt, n_iter, n_burnin, - n_thin, n_cores, DIC, quiet, parallel){ + n_thin, n_cores, DIC, quiet, parallel, seed=NULL){ if(!quiet){cat('\nProcessing function input.......','\n')} + + if(!is.null(seed)){ + stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE) + } + out <- list(data = check_data(data, quiet), params = check_params(params, DIC), inits = check_inits(inits, n_chains), diff --git a/R/rjags_tools.R b/R/rjags_tools.R new file mode 100644 index 0000000..c39c592 --- /dev/null +++ b/R/rjags_tools.R @@ -0,0 +1,33 @@ +run_rjags <- function(data, inits, params, modfile, mcmc_info, + modules, factories, DIC, parallel, quiet){ + + #Save start time + start.time <- Sys.time() + + mc <- mcmc_info + + # Run parallel + if(parallel & mc$n.chains > 1){ + result <- run.parallel(data, inits, params, modfile, + mc$n.chains, mc$n.adapt, mc$n.iter, mc$n.burnin, mc$n.thin, + modules=modules, factories=factories, DIC=DIC, + verbose=!quiet, n.cores=mc$n.cores) + # Move this down if not also handled in runmodel() + if(any(!result$sufficient.adapt)){ + warning("JAGS reports adaptation was incomplete. Consider increasing n.adapt", call.=FALSE) + } + } else { + # Run non-parallel + set.modules(modules, DIC) + set.factories(factories) + + result <- run.model(modfile, data, inits, params, + mc$n.chains, mc$n.iter, mc$n.burnin, mc$n.thin, mc$n.adapt, + verbose=!quiet) + } + + result$run.date <- start.time + result$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3) + + result +} |