diff options
author | Ken Kellner <ken@kenkellner.com> | 2023-12-07 17:07:27 -0500 |
---|---|---|
committer | Ken Kellner <ken@kenkellner.com> | 2023-12-07 17:12:22 -0500 |
commit | 5029ea9a65d9b785810b6027dc19fa2c4baf1529 (patch) | |
tree | d4b7f8073ba6efc7a4cdd1773a7f0978ffe5ac34 | |
parent | d7b75922602170e5240b9cfab12fd8cc11c8d23a (diff) |
Clean up autojags function and use run_rjags
-rw-r--r-- | R/autojags.R | 174 |
1 files changed, 66 insertions, 108 deletions
diff --git a/R/autojags.R b/R/autojags.R index 035f9c9..1cd90b4 100644 --- a/R/autojags.R +++ b/R/autojags.R @@ -1,7 +1,10 @@ -autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.adapt=NULL,iter.increment=1000,n.burnin=0,n.thin=1, - save.all.iter=FALSE,modules=c('glm'),factories=NULL,parallel=FALSE,n.cores=NULL,DIC=TRUE,store.data=FALSE,codaOnly=FALSE,seed=NULL, - bugs.format=FALSE,Rhat.limit=1.1,max.iter=100000,verbose=TRUE){ +autojags <- function(data, inits=NULL, parameters.to.save, model.file, + n.chains, n.adapt=NULL, iter.increment=1000, n.burnin=0, n.thin=1, + save.all.iter=FALSE, modules=c('glm'), factories=NULL, + parallel=FALSE, n.cores=NULL, DIC=TRUE, store.data=FALSE, + codaOnly=FALSE, seed=NULL, bugs.format=FALSE, Rhat.limit=1.1, + max.iter=100000, verbose=TRUE){ if(!is.null(seed)){ stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE) @@ -15,141 +18,103 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad } # Check input data - inps_check <- process_input(data=data, params=parameters.to.save, inits=inits, + inps <- process_input(data=data, params=parameters.to.save, inits=inits, n_chains=n.chains, n_adapt=n.adapt, n_iter=(n.burnin + iter.increment), n_burnin=n.burnin, n_thin=n.thin, n_cores=n.cores, DIC=DIC, quiet=!verbose, parallel=parallel) - data <- inps_check$data - parameters.to.save <- inps_check$params - inits <- inps_check$inits - mcmc.info <- inps_check$mcmc.info + mcmc.info <- inps$mcmc.info mcmc.info$end.values <- NULL # this is not saved in autojags for some reason - if(parallel) n.cores <- inps_check$mcmc.info$n.cores + #Note if saving all iterations + if(save.all.iter&verbose){ + cat('Note: ALL iterations will be included in final posterior.\n\n') + } + #Save start time start.time <- Sys.time() - #Note if saving all iterations - if(save.all.iter&verbose){cat('Note: ALL iterations will be included in final posterior.\n\n')} - - #Initial model run - - #Parallel - - if(verbose){cat('Burn-in + Update 1',' (',(n.burnin + iter.increment),')',sep="")} - - if(parallel){ - - par <- run.parallel(data,inits,parameters.to.save,model.file,n.chains,n.adapt,n.iter=(n.burnin + iter.increment),n.burnin,n.thin, - modules=modules,factories=factories,DIC=DIC,verbose=FALSE,n.cores=n.cores) - samples <- par$samples - mod <- par$model - total.adapt <- par$total.adapt - sufficient.adapt <- par$sufficient.adapt - - } else { - - #Not parallel - - set.modules(modules,DIC) - set.factories(factories) - - rjags.output <- run.model(model.file,data,inits,parameters.to.save,n.chains,n.iter=(n.burnin + iter.increment),n.burnin,n.thin,n.adapt, - verbose=FALSE) - samples <- rjags.output$samples - mod <- rjags.output$m - total.adapt <- rjags.output$total.adapt - sufficient.adapt <- rjags.output$sufficient.adapt - + if(verbose){ + cat('Burn-in + Update 1',' (',(n.burnin + iter.increment),')',sep="") } - #Combine mcmc info into list - mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3) + #Initial model run + mcmc.info$n.iter <- n.burnin + iter.increment + rjags_out <- run_rjags(inps$data, inps$inits, inps$params, model.file, + mcmc.info, modules, factories, DIC, parallel, quiet=TRUE) + # Save output + samples <- rjags_out$samples + mod <- rjags_out$m + + #Update mcmc info + mcmc.info$elapsed.mins <- rjags_out$elapsed.mins mcmc.info$n.samples <- coda::niter(samples) * n.chains - mcmc.info$n.adapt <- total.adapt - mcmc.info$sufficient.adapt <- sufficient.adapt + mcmc.info$n.adapt <- rjags_out$total.adapt + mcmc.info$sufficient.adapt <- rjags_out$sufficient.adapt mcmc.info$n.iter <- n.burnin + iter.increment - test <- test.Rhat(samples,Rhat.limit,codaOnly,verbose=verbose) + # Tests to see if function should stop + large_Rhats <- test.Rhat(samples, Rhat.limit, codaOnly, verbose=verbose) reach.max <- FALSE - index = 1 + index <- 1 if(mcmc.info$n.iter>=max.iter){ reach.max <- TRUE if(verbose){cat('\nMaximum iterations reached.\n\n')} } - while(test==TRUE && reach.max==FALSE){ + # Continue incremental running + while(large_Rhats & !reach.max){ index <- index + 1 - if(verbose){cat('Update ',index,' (',mcmc.info$n.iter + iter.increment,')',sep="")} + if(verbose){ + cat('Update ',index,' (',mcmc.info$n.iter + iter.increment,')',sep="") + } + + # MCMC info for just this update + mcmc_info_update <- mcmc.info + mcmc_info_update$n.adapt <- n.adapt + mcmc_info_update$n.iter <- iter.increment + mcmc_info_update$n.burnin <- 0 + rjags_out <- run_rjags(data=NULL, inits=NULL, inps$params, modfile=NULL, + mcmc_info_update, modules, factories, DIC, parallel, + quiet=TRUE, model.object = mod, update=TRUE) + # Save the model object + mod <- rjags_out$m + + # Save samples and combine with previous samples if required if(save.all.iter){ - if(index==2){start.iter <- stats::start(samples)} - if (index > 1) { - old.samples <- samples - } - } - - if(parallel){ - - par <- run.parallel(data=NULL,inits=NULL,parameters.to.save=parameters.to.save,model.file=NULL,n.chains=n.chains - ,n.adapt=n.adapt,n.iter=iter.increment,n.burnin=0,n.thin=n.thin,modules=modules, - factories=factories,DIC=DIC,model.object=mod,update=TRUE,verbose=FALSE,n.cores=n.cores) - - if(save.all.iter & index > 1){ - samples <- bind.mcmc(old.samples,par$samples,start=start.iter,n.new.iter=iter.increment) - } else {samples <- par$samples} - - mod <- par$model - sufficient.adapt <- par$sufficient.adapt - - test <- test.Rhat(samples,Rhat.limit,codaOnly) - + samples <- bind.mcmc(samples,rjags_out$samples, start=stats::start(samples), + n.new.iter=iter.increment) } else { - - set.modules(modules,DIC) - - rjags.output <- run.model(model.file=NULL,data=NULL,inits=NULL,parameters.to.save=parameters.to.save, - n.chains=n.chains,n.iter=iter.increment,n.burnin=0,n.thin,n.adapt=n.adapt, - model.object=mod,update=TRUE,verbose=FALSE) - - if(save.all.iter & index > 1){ - samples <- bind.mcmc(old.samples,rjags.output$samples,start=start.iter,n.new.iter=iter.increment) - } else {samples <- rjags.output$samples} - - mod <- rjags.output$m - sufficient.adapt <- rjags.output$sufficient.adapt - - test <- test.Rhat(samples,Rhat.limit,codaOnly) - - + samples <- rjags_out$samples } - - if(!save.all.iter){mcmc.info$n.burnin <- mcmc.info$n.iter} - mcmc.info$n.iter <- mcmc.info$n.iter + iter.increment + + # Update the total iteration count etc. + if(!save.all.iter) mcmc.info$n.burnin <- mcmc.info$n.iter + mcmc.info$n.iter <- mcmc.info$n.iter + iter.increment mcmc.info$n.samples <- coda::niter(samples) * n.chains - mcmc.info$sufficient.adapt <- sufficient.adapt - + mcmc.info$sufficient.adapt <- rjags_out$sufficient.adapt + + # Test to see if JAGS should continue updating model + large_Rhats <- test.Rhat(samples, Rhat.limit, codaOnly) if(mcmc.info$n.iter>=max.iter){ reach.max <- TRUE - if(verbose){cat('\nMaximum iterations reached.\n\n')} + if(verbose) cat('\nMaximum iterations reached.\n\n') } } - #Get more info about MCMC run + #Save final runtime mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3) #Reorganize JAGS output to match input parameter order - samples <- order_samples(samples, parameters.to.save) + samples <- order_samples(samples, inps$params) #Convert rjags output to jagsUI form output <- process_output(samples, coda_only = codaOnly, DIC, quiet = !verbose) if(is.null(output)){ - output <- list() - output$samples <- samples - output$model <- mod + output <- list(samples = samples, model = mod) output$n.cores <- n.cores class(output) <- 'jagsUIbasic' return(output) @@ -164,7 +129,7 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad output$data <- data } output$model <- mod - output$parameters <- parameters.to.save + output$parameters <- inps$params output$mcmc.info <- mcmc.info output$run.date <- start.time output$parallel <- parallel @@ -175,17 +140,10 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad class(output) <- 'jagsUI' return(output) - - - - - - - - } - + +# Function to test if all Rhats are below some cutoff value-------------------- test.Rhat <- function(samples,cutoff,params.omit,verbose=TRUE){ params <- colnames(samples[[1]]) |