diff options
author | Ken Kellner <ken@kenkellner.com> | 2023-12-06 16:33:59 -0500 |
---|---|---|
committer | Ken Kellner <ken@kenkellner.com> | 2023-12-06 16:33:59 -0500 |
commit | 030c81cf866c40b7bfa9bafbfbb1d9f2c44d73c5 (patch) | |
tree | 2d517a4d91a4d121ae6402106e92d6095541fb25 | |
parent | 4c0c1a37780afbcbb458a279d2fcfad188817dcf (diff) |
Reorganize files
-rw-r--r-- | R/S3_methods.R (renamed from R/print.R) | 28 | ||||
-rw-r--r-- | R/autojags.R | 38 | ||||
-rw-r--r-- | R/bindmcmc.R | 19 | ||||
-rw-r--r-- | R/mcmc_tools.R | 22 | ||||
-rw-r--r-- | R/plot.R | 23 | ||||
-rw-r--r-- | R/plot_tools.R (renamed from R/get_plot_info.R) | 16 | ||||
-rw-r--r-- | R/summary.R | 3 | ||||
-rw-r--r-- | R/testrhat.R | 38 | ||||
-rw-r--r-- | R/update.R | 50 | ||||
-rw-r--r-- | R/updatebasic.R | 47 | ||||
-rw-r--r-- | R/utils.R | 8 |
11 files changed, 147 insertions, 145 deletions
diff --git a/R/print.R b/R/S3_methods.R index bb593ec..6df687a 100644 --- a/R/print.R +++ b/R/S3_methods.R @@ -1,4 +1,32 @@ +# Summary method +summary.jagsUI <- function(object, ...){ + object$summary +} + +#Plot method +plot.jagsUI <- function(x, parameters=NULL, per_plot=4, ask=NULL, ...){ + + if(is.null(ask)) + ask <- grDevices::dev.interactive(orNone = TRUE) + plot_info <- get_plot_info(x, parameters, NULL, ask) + dims <- c(min(length(plot_info$params), per_plot), 2) + if(length(plot_info$params) <= per_plot) + ask <- FALSE + new_par <- list(mfrow = dims, mar = c(4,4,2.5,1), oma=c(0,0,0,0), ask=ask) + + #Handle par() + old_par <- graphics::par(new_par) + on.exit(graphics::par(old_par)) + + + #Make plot + for (i in plot_info$params){ + param_trace(x, i) + param_density(x, i) + } +} +# Print method print.jagsUI <- function(x,digits=3,...){ mc <- x$mcmc.info diff --git a/R/autojags.R b/R/autojags.R index 4c67f43..3b1f809 100644 --- a/R/autojags.R +++ b/R/autojags.R @@ -184,4 +184,42 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad } + + +test.Rhat <- function(samples,cutoff,params.omit,verbose=TRUE){ + + params <- colnames(samples[[1]]) + expand <- sapply(strsplit(params, "\\["), "[", 1) + + gd <- function(hold){ + r <- try(gelman.diag(hold, autoburnin=FALSE)$psrf[1], silent=TRUE) + if(inherits(r, "try-error") || !is.finite(r)) { + r <- NA + } + return(r) + } + + failure <- FALSE + index <- 1 + while (failure==FALSE && index <= length(params)){ + + if(!expand[index]%in%params.omit){ + test <- gd(samples[,index]) + } else {test <- 1} + + if(is.na(test)){test <- 1} + + if(test>cutoff){failure=TRUE + } else {index <- index + 1} + } + if(failure==TRUE&verbose){ + cat('.......Convergence check failed for parameter \'',params[index],'\'\n',sep="") + } + if(failure==FALSE&verbose){ + cat('.......All parameters converged.','\n\n') + } + + return(failure) + +} diff --git a/R/bindmcmc.R b/R/bindmcmc.R deleted file mode 100644 index 451e6a8..0000000 --- a/R/bindmcmc.R +++ /dev/null @@ -1,19 +0,0 @@ - -bind.mcmc <- function(mcmc.list1,mcmc.list2,start,n.new.iter){ - - nchains <- length(mcmc.list1) - - samples <- list() - - for (i in 1:nchains){ - - d <- rbind(mcmc.list1[[i]],mcmc.list2[[i]]) - - samples[[i]] <- mcmc(data=d,start=start,end=(end(mcmc.list1[[i]])+n.new.iter),thin=thin(mcmc.list1[i])) - - } - - return(as.mcmc.list(samples)) - - -}
\ No newline at end of file diff --git a/R/mcmc_tools.R b/R/mcmc_tools.R index 612fda5..1c867a1 100644 --- a/R/mcmc_tools.R +++ b/R/mcmc_tools.R @@ -80,3 +80,25 @@ get_inds <- function(param, params_raw){ has_one_parameter <- function(mcmc_list){ coda::nvar(mcmc_list) == 1 } + + +#------------------------------------------------------------------------------ +# Bind two mcmc.lists together +bind.mcmc <- function(mcmc.list1,mcmc.list2,start,n.new.iter){ + + nchains <- length(mcmc.list1) + + samples <- list() + + for (i in 1:nchains){ + + d <- rbind(mcmc.list1[[i]],mcmc.list2[[i]]) + + samples[[i]] <- mcmc(data=d,start=start,end=(end(mcmc.list1[[i]])+n.new.iter),thin=thin(mcmc.list1[i])) + + } + + return(as.mcmc.list(samples)) + + +} diff --git a/R/plot.R b/R/plot.R deleted file mode 100644 index 17108cf..0000000 --- a/R/plot.R +++ /dev/null @@ -1,23 +0,0 @@ - -#Plot method for jagsUI objects -plot.jagsUI <- function(x, parameters=NULL, per_plot=4, ask=NULL, ...){ - - if(is.null(ask)) - ask <- grDevices::dev.interactive(orNone = TRUE) - plot_info <- get_plot_info(x, parameters, NULL, ask) - dims <- c(min(length(plot_info$params), per_plot), 2) - if(length(plot_info$params) <= per_plot) - ask <- FALSE - new_par <- list(mfrow = dims, mar = c(4,4,2.5,1), oma=c(0,0,0,0), ask=ask) - - #Handle par() - old_par <- graphics::par(new_par) - on.exit(graphics::par(old_par)) - - - #Make plot - for (i in plot_info$params){ - param_trace(x, i) - param_density(x, i) - } -} diff --git a/R/get_plot_info.R b/R/plot_tools.R index ac0ae42..7b58b91 100644 --- a/R/get_plot_info.R +++ b/R/plot_tools.R @@ -1,5 +1,9 @@ +#Check that an object is the right class--------------------------------------- +check_class <- function(output){ + if(!inherits(output, "jagsUI")) stop("Requires jagsUI object") +} -#General function for setting up plots +#General function for setting up plots----------------------------------------- # Called by densityplot, traceplot, and plot.jagsUI # plot.jagsUI only uses the 'params' component in the output, ignores the rest get_plot_info <- function(x, parameters, layout, ask, Rhat_min=NULL){ @@ -50,9 +54,9 @@ get_plot_info <- function(x, parameters, layout, ask, Rhat_min=NULL){ list(params=parameters, new_par=new_par, per_plot=per_plot) } - -has_brackets <- function(x){ - grepl("\\[.*\\]", x) +# Parameter name tools--------------------------------------------------------- +expand_params <- function(params){ + unlist(lapply(params, expand_brackets)) } expand_brackets <- function(x){ @@ -64,6 +68,6 @@ expand_brackets <- function(x){ paste0(pname, "[",rng,"]") } -expand_params <- function(params){ - unlist(lapply(params, expand_brackets)) +has_brackets <- function(x){ + grepl("\\[.*\\]", x) } diff --git a/R/summary.R b/R/summary.R deleted file mode 100644 index 4cdeffd..0000000 --- a/R/summary.R +++ /dev/null @@ -1,3 +0,0 @@ -summary.jagsUI <- function(object, ...){ - object$summary -} diff --git a/R/testrhat.R b/R/testrhat.R deleted file mode 100644 index 5c85448..0000000 --- a/R/testrhat.R +++ /dev/null @@ -1,38 +0,0 @@ - -test.Rhat <- function(samples,cutoff,params.omit,verbose=TRUE){ - - params <- colnames(samples[[1]]) - expand <- sapply(strsplit(params, "\\["), "[", 1) - - gd <- function(hold){ - r <- try(gelman.diag(hold, autoburnin=FALSE)$psrf[1], silent=TRUE) - if(inherits(r, "try-error") || !is.finite(r)) { - r <- NA - } - return(r) - } - - failure <- FALSE - index <- 1 - while (failure==FALSE && index <= length(params)){ - - if(!expand[index]%in%params.omit){ - test <- gd(samples[,index]) - } else {test <- 1} - - if(is.na(test)){test <- 1} - - if(test>cutoff){failure=TRUE - } else {index <- index + 1} - } - - if(failure==TRUE&verbose){ - cat('.......Convergence check failed for parameter \'',params[index],'\'\n',sep="") - } - if(failure==FALSE&verbose){ - cat('.......All parameters converged.','\n\n') - } - - return(failure) - -}
\ No newline at end of file @@ -1,4 +1,4 @@ - +# update method for jagsUI class----------------------------------------------- update.jagsUI <- function(object, parameters.to.save=NULL, n.adapt=NULL, n.iter, n.thin=NULL, modules=c('glm'), factories=NULL, @@ -72,3 +72,51 @@ update.jagsUI <- function(object, parameters.to.save=NULL, return(output) } + +# update method for jagsUIbasic class------------------------------------------ +update.jagsUIbasic <- function(object, parameters.to.save=NULL, + n.adapt=NULL, n.iter, n.thin=NULL, + modules=c('glm'), factories=NULL, + DIC=NULL, verbose=TRUE, ...){ + + # Set up parameters + if(is.null(parameters.to.save)){ + params_long <- colnames(object$samples[[1]]) + parameters.to.save <- unique(sapply(strsplit(params_long, "\\["), "[", 1)) + } + + #Set up DIC monitoring + if(is.null(DIC)){ + DIC <- 'deviance' %in% parameters.to.save + } else { + if(DIC & (!'deviance' %in% parameters.to.save)){ + parameters.to.save <- c(parameters.to.save, 'deviance') + } else if(!DIC & 'deviance' %in% parameters.to.save){ + parameters.to.save <- parameters.to.save[parameters.to.save != 'deviance'] + } + } + + # Set up MCMC info + mcmc.info <- list(n.chains = length(object$samples), n.adapt = n.adapt, + n.iter = n.iter, n.burnin = 0, + n.thin = ifelse(is.null(n.thin), thin(object$samples), n.thin), + n.cores = object$n.cores) + + parallel <- names(object$model[1]) == "cluster1" + + # Run JAGS via rjags + rjags_out <- run_rjags(data=NULL, inits=NULL, parameters.to.save, modfile=NULL, + mcmc.info, modules, factories, DIC, parallel, !verbose, + model.object = object$model, update=TRUE) + + # Report time + if(verbose) cat('MCMC took', rjags_out$elapsed.min, 'minutes.\n') + + # Create output object + output <- list(samples = order_samples(rjags_out$samples, parameters.to.save), + model = rjags_out$m, + n.cores = object$n.cores) + class(output) <- 'jagsUIbasic' + + return(output) +} diff --git a/R/updatebasic.R b/R/updatebasic.R deleted file mode 100644 index 97fd512..0000000 --- a/R/updatebasic.R +++ /dev/null @@ -1,47 +0,0 @@ - -update.jagsUIbasic <- function(object, parameters.to.save=NULL, - n.adapt=NULL, n.iter, n.thin=NULL, - modules=c('glm'), factories=NULL, - DIC=NULL, verbose=TRUE, ...){ - - # Set up parameters - if(is.null(parameters.to.save)){ - params_long <- colnames(object$samples[[1]]) - parameters.to.save <- unique(sapply(strsplit(params_long, "\\["), "[", 1)) - } - - #Set up DIC monitoring - if(is.null(DIC)){ - DIC <- 'deviance' %in% parameters.to.save - } else { - if(DIC & (!'deviance' %in% parameters.to.save)){ - parameters.to.save <- c(parameters.to.save, 'deviance') - } else if(!DIC & 'deviance' %in% parameters.to.save){ - parameters.to.save <- parameters.to.save[parameters.to.save != 'deviance'] - } - } - - # Set up MCMC info - mcmc.info <- list(n.chains = length(object$samples), n.adapt = n.adapt, - n.iter = n.iter, n.burnin = 0, - n.thin = ifelse(is.null(n.thin), thin(object$samples), n.thin), - n.cores = object$n.cores) - - parallel <- names(object$model[1]) == "cluster1" - - # Run JAGS via rjags - rjags_out <- run_rjags(data=NULL, inits=NULL, parameters.to.save, modfile=NULL, - mcmc.info, modules, factories, DIC, parallel, !verbose, - model.object = object$model, update=TRUE) - - # Report time - if(verbose) cat('MCMC took', rjags_out$elapsed.min, 'minutes.\n') - - # Create output object - output <- list(samples = order_samples(rjags_out$samples, parameters.to.save), - model = rjags_out$m, - n.cores = object$n.cores) - class(output) <- 'jagsUIbasic' - - return(output) -} diff --git a/R/utils.R b/R/utils.R deleted file mode 100644 index 9458c63..0000000 --- a/R/utils.R +++ /dev/null @@ -1,8 +0,0 @@ - - -#--- from process_output --------------------------------------------------------------------------- -#Check that an object is the right class -check_class <- function(output){ - if(!inherits(output, "jagsUI")) stop("Requires jagsUI object") -} -#------------------------------------------------------------------------------ |