diff options
author | Ken Kellner <ken@kenkellner.com> | 2023-12-05 15:29:29 -0500 |
---|---|---|
committer | Ken Kellner <ken@kenkellner.com> | 2023-12-05 15:29:29 -0500 |
commit | 07a1c53902871fefd08b19bfcbc824b95117e6a0 (patch) | |
tree | 3d2a735d86df8a7f3390aaa8547d9ed7129f27f9 | |
parent | 9edbc8ce1fa292aa51f2418f1c1c0d9f298b5e82 (diff) |
Add new input processing to autojags
-rw-r--r-- | R/autojags.R | 49 |
1 files changed, 31 insertions, 18 deletions
diff --git a/R/autojags.R b/R/autojags.R index 15b0e9d..4c67f43 100644 --- a/R/autojags.R +++ b/R/autojags.R @@ -3,14 +3,29 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad save.all.iter=FALSE,modules=c('glm'),factories=NULL,parallel=FALSE,n.cores=NULL,DIC=TRUE,store.data=FALSE,codaOnly=FALSE,seed=NULL, bugs.format=FALSE,Rhat.limit=1.1,max.iter=100000,verbose=TRUE){ - #Pass input data and parameter list through error check / processing - data.check <- process.input(data,parameters.to.save,inits,n.chains,(n.burnin + iter.increment), - n.burnin,n.thin,n.cores,DIC=DIC,autojags=TRUE,max.iter=max.iter, - verbose=verbose,parallel=parallel,seed=seed) - data <- data.check$data - parameters.to.save <- data.check$params - inits <- data.check$inits - if(parallel){n.cores <- data.check$n.cores} + if(!is.null(seed)){ + stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE) + } + if(n.chains<2) stop('Number of chains must be >1 to calculate Rhat.') + if((max.iter < n.burnin) & verbose){ + old_warn <- options()$warn + options(warn=1) + warning('Maximum iterations includes burn-in and should be larger than burn-in.') + options(warn=old_warn) + } + + # Check input data + inps_check <- process_input(data=data, params=parameters.to.save, inits=inits, + n_chains=n.chains, n_adapt=n.adapt, + n_iter=(n.burnin + iter.increment), + n_burnin=n.burnin, n_thin=n.thin, n_cores=n.cores, + DIC=DIC, quiet=!verbose, parallel=parallel) + data <- inps_check$data + parameters.to.save <- inps_check$params + inits <- inps_check$inits + mcmc.info <- inps_check$mcmc.info + mcmc.info$end.values <- NULL # this is not saved in autojags for some reason + if(parallel) n.cores <- inps_check$mcmc.info$n.cores #Save start time start.time <- Sys.time() @@ -50,10 +65,11 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad } #Combine mcmc info into list - n.samples <- dim(samples[[1]])[1] * n.chains - mcmc.info <- list(n.chains,n.adapt=total.adapt,sufficient.adapt=sufficient.adapt,n.iter=(n.burnin + iter.increment),n.burnin,n.thin,n.samples,time) - names(mcmc.info) <- c('n.chains','n.adapt','sufficient.adapt','n.iter','n.burnin','n.thin','n.samples','elapsed.mins') - if(parallel){mcmc.info$n.cores <- n.cores} + mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3) + mcmc.info$n.samples <- coda::niter(samples) * n.chains + mcmc.info$n.adapt <- total.adapt + mcmc.info$sufficient.adapt <- sufficient.adapt + mcmc.info$n.iter <- n.burnin + iter.increment test <- test.Rhat(samples,Rhat.limit,codaOnly,verbose=verbose) reach.max <- FALSE @@ -113,7 +129,7 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad if(!save.all.iter){mcmc.info$n.burnin <- mcmc.info$n.iter} mcmc.info$n.iter <- mcmc.info$n.iter + iter.increment - mcmc.info$n.samples <- dim(samples[[1]])[1] * n.chains + mcmc.info$n.samples <- coda::niter(samples) * n.chains mcmc.info$sufficient.adapt <- sufficient.adapt if(mcmc.info$n.iter>=max.iter){ @@ -123,9 +139,7 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad } #Get more info about MCMC run - end.time <- Sys.time() - mcmc.info$elapsed.mins <- round(as.numeric(end.time-start.time,units="mins"),digits=3) - date <- start.time + mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3) #Reorganize JAGS output to match input parameter order samples <- order_samples(samples, parameters.to.save) @@ -152,8 +166,7 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad output$model <- mod output$parameters <- parameters.to.save output$mcmc.info <- mcmc.info - output$run.date <- date - output$random.seed <- seed + output$run.date <- start.time output$parallel <- parallel output$bugs.format <- bugs.format output$calc.DIC <- DIC |