aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKen Kellner <ken@kenkellner.com>2023-12-05 15:29:29 -0500
committerKen Kellner <ken@kenkellner.com>2023-12-05 15:29:29 -0500
commit07a1c53902871fefd08b19bfcbc824b95117e6a0 (patch)
tree3d2a735d86df8a7f3390aaa8547d9ed7129f27f9
parent9edbc8ce1fa292aa51f2418f1c1c0d9f298b5e82 (diff)
Add new input processing to autojags
-rw-r--r--R/autojags.R49
1 files changed, 31 insertions, 18 deletions
diff --git a/R/autojags.R b/R/autojags.R
index 15b0e9d..4c67f43 100644
--- a/R/autojags.R
+++ b/R/autojags.R
@@ -3,14 +3,29 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad
save.all.iter=FALSE,modules=c('glm'),factories=NULL,parallel=FALSE,n.cores=NULL,DIC=TRUE,store.data=FALSE,codaOnly=FALSE,seed=NULL,
bugs.format=FALSE,Rhat.limit=1.1,max.iter=100000,verbose=TRUE){
- #Pass input data and parameter list through error check / processing
- data.check <- process.input(data,parameters.to.save,inits,n.chains,(n.burnin + iter.increment),
- n.burnin,n.thin,n.cores,DIC=DIC,autojags=TRUE,max.iter=max.iter,
- verbose=verbose,parallel=parallel,seed=seed)
- data <- data.check$data
- parameters.to.save <- data.check$params
- inits <- data.check$inits
- if(parallel){n.cores <- data.check$n.cores}
+ if(!is.null(seed)){
+ stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE)
+ }
+ if(n.chains<2) stop('Number of chains must be >1 to calculate Rhat.')
+ if((max.iter < n.burnin) & verbose){
+ old_warn <- options()$warn
+ options(warn=1)
+ warning('Maximum iterations includes burn-in and should be larger than burn-in.')
+ options(warn=old_warn)
+ }
+
+ # Check input data
+ inps_check <- process_input(data=data, params=parameters.to.save, inits=inits,
+ n_chains=n.chains, n_adapt=n.adapt,
+ n_iter=(n.burnin + iter.increment),
+ n_burnin=n.burnin, n_thin=n.thin, n_cores=n.cores,
+ DIC=DIC, quiet=!verbose, parallel=parallel)
+ data <- inps_check$data
+ parameters.to.save <- inps_check$params
+ inits <- inps_check$inits
+ mcmc.info <- inps_check$mcmc.info
+ mcmc.info$end.values <- NULL # this is not saved in autojags for some reason
+ if(parallel) n.cores <- inps_check$mcmc.info$n.cores
#Save start time
start.time <- Sys.time()
@@ -50,10 +65,11 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad
}
#Combine mcmc info into list
- n.samples <- dim(samples[[1]])[1] * n.chains
- mcmc.info <- list(n.chains,n.adapt=total.adapt,sufficient.adapt=sufficient.adapt,n.iter=(n.burnin + iter.increment),n.burnin,n.thin,n.samples,time)
- names(mcmc.info) <- c('n.chains','n.adapt','sufficient.adapt','n.iter','n.burnin','n.thin','n.samples','elapsed.mins')
- if(parallel){mcmc.info$n.cores <- n.cores}
+ mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3)
+ mcmc.info$n.samples <- coda::niter(samples) * n.chains
+ mcmc.info$n.adapt <- total.adapt
+ mcmc.info$sufficient.adapt <- sufficient.adapt
+ mcmc.info$n.iter <- n.burnin + iter.increment
test <- test.Rhat(samples,Rhat.limit,codaOnly,verbose=verbose)
reach.max <- FALSE
@@ -113,7 +129,7 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad
if(!save.all.iter){mcmc.info$n.burnin <- mcmc.info$n.iter}
mcmc.info$n.iter <- mcmc.info$n.iter + iter.increment
- mcmc.info$n.samples <- dim(samples[[1]])[1] * n.chains
+ mcmc.info$n.samples <- coda::niter(samples) * n.chains
mcmc.info$sufficient.adapt <- sufficient.adapt
if(mcmc.info$n.iter>=max.iter){
@@ -123,9 +139,7 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad
}
#Get more info about MCMC run
- end.time <- Sys.time()
- mcmc.info$elapsed.mins <- round(as.numeric(end.time-start.time,units="mins"),digits=3)
- date <- start.time
+ mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3)
#Reorganize JAGS output to match input parameter order
samples <- order_samples(samples, parameters.to.save)
@@ -152,8 +166,7 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad
output$model <- mod
output$parameters <- parameters.to.save
output$mcmc.info <- mcmc.info
- output$run.date <- date
- output$random.seed <- seed
+ output$run.date <- start.time
output$parallel <- parallel
output$bugs.format <- bugs.format
output$calc.DIC <- DIC