aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKen Kellner <ken@kenkellner.com>2023-12-07 17:07:27 -0500
committerKen Kellner <ken@kenkellner.com>2023-12-07 17:12:22 -0500
commit5029ea9a65d9b785810b6027dc19fa2c4baf1529 (patch)
treed4b7f8073ba6efc7a4cdd1773a7f0978ffe5ac34
parentd7b75922602170e5240b9cfab12fd8cc11c8d23a (diff)
Clean up autojags function and use run_rjags
-rw-r--r--R/autojags.R174
1 files changed, 66 insertions, 108 deletions
diff --git a/R/autojags.R b/R/autojags.R
index 035f9c9..1cd90b4 100644
--- a/R/autojags.R
+++ b/R/autojags.R
@@ -1,7 +1,10 @@
-autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.adapt=NULL,iter.increment=1000,n.burnin=0,n.thin=1,
- save.all.iter=FALSE,modules=c('glm'),factories=NULL,parallel=FALSE,n.cores=NULL,DIC=TRUE,store.data=FALSE,codaOnly=FALSE,seed=NULL,
- bugs.format=FALSE,Rhat.limit=1.1,max.iter=100000,verbose=TRUE){
+autojags <- function(data, inits=NULL, parameters.to.save, model.file,
+ n.chains, n.adapt=NULL, iter.increment=1000, n.burnin=0, n.thin=1,
+ save.all.iter=FALSE, modules=c('glm'), factories=NULL,
+ parallel=FALSE, n.cores=NULL, DIC=TRUE, store.data=FALSE,
+ codaOnly=FALSE, seed=NULL, bugs.format=FALSE, Rhat.limit=1.1,
+ max.iter=100000, verbose=TRUE){
if(!is.null(seed)){
stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE)
@@ -15,141 +18,103 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad
}
# Check input data
- inps_check <- process_input(data=data, params=parameters.to.save, inits=inits,
+ inps <- process_input(data=data, params=parameters.to.save, inits=inits,
n_chains=n.chains, n_adapt=n.adapt,
n_iter=(n.burnin + iter.increment),
n_burnin=n.burnin, n_thin=n.thin, n_cores=n.cores,
DIC=DIC, quiet=!verbose, parallel=parallel)
- data <- inps_check$data
- parameters.to.save <- inps_check$params
- inits <- inps_check$inits
- mcmc.info <- inps_check$mcmc.info
+ mcmc.info <- inps$mcmc.info
mcmc.info$end.values <- NULL # this is not saved in autojags for some reason
- if(parallel) n.cores <- inps_check$mcmc.info$n.cores
+ #Note if saving all iterations
+ if(save.all.iter&verbose){
+ cat('Note: ALL iterations will be included in final posterior.\n\n')
+ }
+
#Save start time
start.time <- Sys.time()
- #Note if saving all iterations
- if(save.all.iter&verbose){cat('Note: ALL iterations will be included in final posterior.\n\n')}
-
- #Initial model run
-
- #Parallel
-
- if(verbose){cat('Burn-in + Update 1',' (',(n.burnin + iter.increment),')',sep="")}
-
- if(parallel){
-
- par <- run.parallel(data,inits,parameters.to.save,model.file,n.chains,n.adapt,n.iter=(n.burnin + iter.increment),n.burnin,n.thin,
- modules=modules,factories=factories,DIC=DIC,verbose=FALSE,n.cores=n.cores)
- samples <- par$samples
- mod <- par$model
- total.adapt <- par$total.adapt
- sufficient.adapt <- par$sufficient.adapt
-
- } else {
-
- #Not parallel
-
- set.modules(modules,DIC)
- set.factories(factories)
-
- rjags.output <- run.model(model.file,data,inits,parameters.to.save,n.chains,n.iter=(n.burnin + iter.increment),n.burnin,n.thin,n.adapt,
- verbose=FALSE)
- samples <- rjags.output$samples
- mod <- rjags.output$m
- total.adapt <- rjags.output$total.adapt
- sufficient.adapt <- rjags.output$sufficient.adapt
-
+ if(verbose){
+ cat('Burn-in + Update 1',' (',(n.burnin + iter.increment),')',sep="")
}
- #Combine mcmc info into list
- mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3)
+ #Initial model run
+ mcmc.info$n.iter <- n.burnin + iter.increment
+ rjags_out <- run_rjags(inps$data, inps$inits, inps$params, model.file,
+ mcmc.info, modules, factories, DIC, parallel, quiet=TRUE)
+ # Save output
+ samples <- rjags_out$samples
+ mod <- rjags_out$m
+
+ #Update mcmc info
+ mcmc.info$elapsed.mins <- rjags_out$elapsed.mins
mcmc.info$n.samples <- coda::niter(samples) * n.chains
- mcmc.info$n.adapt <- total.adapt
- mcmc.info$sufficient.adapt <- sufficient.adapt
+ mcmc.info$n.adapt <- rjags_out$total.adapt
+ mcmc.info$sufficient.adapt <- rjags_out$sufficient.adapt
mcmc.info$n.iter <- n.burnin + iter.increment
- test <- test.Rhat(samples,Rhat.limit,codaOnly,verbose=verbose)
+ # Tests to see if function should stop
+ large_Rhats <- test.Rhat(samples, Rhat.limit, codaOnly, verbose=verbose)
reach.max <- FALSE
- index = 1
+ index <- 1
if(mcmc.info$n.iter>=max.iter){
reach.max <- TRUE
if(verbose){cat('\nMaximum iterations reached.\n\n')}
}
- while(test==TRUE && reach.max==FALSE){
+ # Continue incremental running
+ while(large_Rhats & !reach.max){
index <- index + 1
- if(verbose){cat('Update ',index,' (',mcmc.info$n.iter + iter.increment,')',sep="")}
+ if(verbose){
+ cat('Update ',index,' (',mcmc.info$n.iter + iter.increment,')',sep="")
+ }
+
+ # MCMC info for just this update
+ mcmc_info_update <- mcmc.info
+ mcmc_info_update$n.adapt <- n.adapt
+ mcmc_info_update$n.iter <- iter.increment
+ mcmc_info_update$n.burnin <- 0
+ rjags_out <- run_rjags(data=NULL, inits=NULL, inps$params, modfile=NULL,
+ mcmc_info_update, modules, factories, DIC, parallel,
+ quiet=TRUE, model.object = mod, update=TRUE)
+ # Save the model object
+ mod <- rjags_out$m
+
+ # Save samples and combine with previous samples if required
if(save.all.iter){
- if(index==2){start.iter <- stats::start(samples)}
- if (index > 1) {
- old.samples <- samples
- }
- }
-
- if(parallel){
-
- par <- run.parallel(data=NULL,inits=NULL,parameters.to.save=parameters.to.save,model.file=NULL,n.chains=n.chains
- ,n.adapt=n.adapt,n.iter=iter.increment,n.burnin=0,n.thin=n.thin,modules=modules,
- factories=factories,DIC=DIC,model.object=mod,update=TRUE,verbose=FALSE,n.cores=n.cores)
-
- if(save.all.iter & index > 1){
- samples <- bind.mcmc(old.samples,par$samples,start=start.iter,n.new.iter=iter.increment)
- } else {samples <- par$samples}
-
- mod <- par$model
- sufficient.adapt <- par$sufficient.adapt
-
- test <- test.Rhat(samples,Rhat.limit,codaOnly)
-
+ samples <- bind.mcmc(samples,rjags_out$samples, start=stats::start(samples),
+ n.new.iter=iter.increment)
} else {
-
- set.modules(modules,DIC)
-
- rjags.output <- run.model(model.file=NULL,data=NULL,inits=NULL,parameters.to.save=parameters.to.save,
- n.chains=n.chains,n.iter=iter.increment,n.burnin=0,n.thin,n.adapt=n.adapt,
- model.object=mod,update=TRUE,verbose=FALSE)
-
- if(save.all.iter & index > 1){
- samples <- bind.mcmc(old.samples,rjags.output$samples,start=start.iter,n.new.iter=iter.increment)
- } else {samples <- rjags.output$samples}
-
- mod <- rjags.output$m
- sufficient.adapt <- rjags.output$sufficient.adapt
-
- test <- test.Rhat(samples,Rhat.limit,codaOnly)
-
-
+ samples <- rjags_out$samples
}
-
- if(!save.all.iter){mcmc.info$n.burnin <- mcmc.info$n.iter}
- mcmc.info$n.iter <- mcmc.info$n.iter + iter.increment
+
+ # Update the total iteration count etc.
+ if(!save.all.iter) mcmc.info$n.burnin <- mcmc.info$n.iter
+ mcmc.info$n.iter <- mcmc.info$n.iter + iter.increment
mcmc.info$n.samples <- coda::niter(samples) * n.chains
- mcmc.info$sufficient.adapt <- sufficient.adapt
-
+ mcmc.info$sufficient.adapt <- rjags_out$sufficient.adapt
+
+ # Test to see if JAGS should continue updating model
+ large_Rhats <- test.Rhat(samples, Rhat.limit, codaOnly)
if(mcmc.info$n.iter>=max.iter){
reach.max <- TRUE
- if(verbose){cat('\nMaximum iterations reached.\n\n')}
+ if(verbose) cat('\nMaximum iterations reached.\n\n')
}
}
- #Get more info about MCMC run
+ #Save final runtime
mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3)
#Reorganize JAGS output to match input parameter order
- samples <- order_samples(samples, parameters.to.save)
+ samples <- order_samples(samples, inps$params)
#Convert rjags output to jagsUI form
output <- process_output(samples, coda_only = codaOnly, DIC, quiet = !verbose)
if(is.null(output)){
- output <- list()
- output$samples <- samples
- output$model <- mod
+ output <- list(samples = samples, model = mod)
output$n.cores <- n.cores
class(output) <- 'jagsUIbasic'
return(output)
@@ -164,7 +129,7 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad
output$data <- data
}
output$model <- mod
- output$parameters <- parameters.to.save
+ output$parameters <- inps$params
output$mcmc.info <- mcmc.info
output$run.date <- start.time
output$parallel <- parallel
@@ -175,17 +140,10 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad
class(output) <- 'jagsUI'
return(output)
-
-
-
-
-
-
-
-
}
-
+
+# Function to test if all Rhats are below some cutoff value--------------------
test.Rhat <- function(samples,cutoff,params.omit,verbose=TRUE){
params <- colnames(samples[[1]])