aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKen Kellner <ken@kenkellner.com>2023-12-06 16:33:59 -0500
committerKen Kellner <ken@kenkellner.com>2023-12-06 16:33:59 -0500
commit030c81cf866c40b7bfa9bafbfbb1d9f2c44d73c5 (patch)
tree2d517a4d91a4d121ae6402106e92d6095541fb25
parent4c0c1a37780afbcbb458a279d2fcfad188817dcf (diff)
Reorganize files
-rw-r--r--R/S3_methods.R (renamed from R/print.R)28
-rw-r--r--R/autojags.R38
-rw-r--r--R/bindmcmc.R19
-rw-r--r--R/mcmc_tools.R22
-rw-r--r--R/plot.R23
-rw-r--r--R/plot_tools.R (renamed from R/get_plot_info.R)16
-rw-r--r--R/summary.R3
-rw-r--r--R/testrhat.R38
-rw-r--r--R/update.R50
-rw-r--r--R/updatebasic.R47
-rw-r--r--R/utils.R8
11 files changed, 147 insertions, 145 deletions
diff --git a/R/print.R b/R/S3_methods.R
index bb593ec..6df687a 100644
--- a/R/print.R
+++ b/R/S3_methods.R
@@ -1,4 +1,32 @@
+# Summary method
+summary.jagsUI <- function(object, ...){
+ object$summary
+}
+
+#Plot method
+plot.jagsUI <- function(x, parameters=NULL, per_plot=4, ask=NULL, ...){
+
+ if(is.null(ask))
+ ask <- grDevices::dev.interactive(orNone = TRUE)
+ plot_info <- get_plot_info(x, parameters, NULL, ask)
+ dims <- c(min(length(plot_info$params), per_plot), 2)
+ if(length(plot_info$params) <= per_plot)
+ ask <- FALSE
+ new_par <- list(mfrow = dims, mar = c(4,4,2.5,1), oma=c(0,0,0,0), ask=ask)
+
+ #Handle par()
+ old_par <- graphics::par(new_par)
+ on.exit(graphics::par(old_par))
+
+
+ #Make plot
+ for (i in plot_info$params){
+ param_trace(x, i)
+ param_density(x, i)
+ }
+}
+# Print method
print.jagsUI <- function(x,digits=3,...){
mc <- x$mcmc.info
diff --git a/R/autojags.R b/R/autojags.R
index 4c67f43..3b1f809 100644
--- a/R/autojags.R
+++ b/R/autojags.R
@@ -184,4 +184,42 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad
}
+
+
+test.Rhat <- function(samples,cutoff,params.omit,verbose=TRUE){
+
+ params <- colnames(samples[[1]])
+ expand <- sapply(strsplit(params, "\\["), "[", 1)
+
+ gd <- function(hold){
+ r <- try(gelman.diag(hold, autoburnin=FALSE)$psrf[1], silent=TRUE)
+ if(inherits(r, "try-error") || !is.finite(r)) {
+ r <- NA
+ }
+ return(r)
+ }
+
+ failure <- FALSE
+ index <- 1
+ while (failure==FALSE && index <= length(params)){
+
+ if(!expand[index]%in%params.omit){
+ test <- gd(samples[,index])
+ } else {test <- 1}
+
+ if(is.na(test)){test <- 1}
+
+ if(test>cutoff){failure=TRUE
+ } else {index <- index + 1}
+ }
+ if(failure==TRUE&verbose){
+ cat('.......Convergence check failed for parameter \'',params[index],'\'\n',sep="")
+ }
+ if(failure==FALSE&verbose){
+ cat('.......All parameters converged.','\n\n')
+ }
+
+ return(failure)
+
+}
diff --git a/R/bindmcmc.R b/R/bindmcmc.R
deleted file mode 100644
index 451e6a8..0000000
--- a/R/bindmcmc.R
+++ /dev/null
@@ -1,19 +0,0 @@
-
-bind.mcmc <- function(mcmc.list1,mcmc.list2,start,n.new.iter){
-
- nchains <- length(mcmc.list1)
-
- samples <- list()
-
- for (i in 1:nchains){
-
- d <- rbind(mcmc.list1[[i]],mcmc.list2[[i]])
-
- samples[[i]] <- mcmc(data=d,start=start,end=(end(mcmc.list1[[i]])+n.new.iter),thin=thin(mcmc.list1[i]))
-
- }
-
- return(as.mcmc.list(samples))
-
-
-} \ No newline at end of file
diff --git a/R/mcmc_tools.R b/R/mcmc_tools.R
index 612fda5..1c867a1 100644
--- a/R/mcmc_tools.R
+++ b/R/mcmc_tools.R
@@ -80,3 +80,25 @@ get_inds <- function(param, params_raw){
has_one_parameter <- function(mcmc_list){
coda::nvar(mcmc_list) == 1
}
+
+
+#------------------------------------------------------------------------------
+# Bind two mcmc.lists together
+bind.mcmc <- function(mcmc.list1,mcmc.list2,start,n.new.iter){
+
+ nchains <- length(mcmc.list1)
+
+ samples <- list()
+
+ for (i in 1:nchains){
+
+ d <- rbind(mcmc.list1[[i]],mcmc.list2[[i]])
+
+ samples[[i]] <- mcmc(data=d,start=start,end=(end(mcmc.list1[[i]])+n.new.iter),thin=thin(mcmc.list1[i]))
+
+ }
+
+ return(as.mcmc.list(samples))
+
+
+}
diff --git a/R/plot.R b/R/plot.R
deleted file mode 100644
index 17108cf..0000000
--- a/R/plot.R
+++ /dev/null
@@ -1,23 +0,0 @@
-
-#Plot method for jagsUI objects
-plot.jagsUI <- function(x, parameters=NULL, per_plot=4, ask=NULL, ...){
-
- if(is.null(ask))
- ask <- grDevices::dev.interactive(orNone = TRUE)
- plot_info <- get_plot_info(x, parameters, NULL, ask)
- dims <- c(min(length(plot_info$params), per_plot), 2)
- if(length(plot_info$params) <= per_plot)
- ask <- FALSE
- new_par <- list(mfrow = dims, mar = c(4,4,2.5,1), oma=c(0,0,0,0), ask=ask)
-
- #Handle par()
- old_par <- graphics::par(new_par)
- on.exit(graphics::par(old_par))
-
-
- #Make plot
- for (i in plot_info$params){
- param_trace(x, i)
- param_density(x, i)
- }
-}
diff --git a/R/get_plot_info.R b/R/plot_tools.R
index ac0ae42..7b58b91 100644
--- a/R/get_plot_info.R
+++ b/R/plot_tools.R
@@ -1,5 +1,9 @@
+#Check that an object is the right class---------------------------------------
+check_class <- function(output){
+ if(!inherits(output, "jagsUI")) stop("Requires jagsUI object")
+}
-#General function for setting up plots
+#General function for setting up plots-----------------------------------------
# Called by densityplot, traceplot, and plot.jagsUI
# plot.jagsUI only uses the 'params' component in the output, ignores the rest
get_plot_info <- function(x, parameters, layout, ask, Rhat_min=NULL){
@@ -50,9 +54,9 @@ get_plot_info <- function(x, parameters, layout, ask, Rhat_min=NULL){
list(params=parameters, new_par=new_par, per_plot=per_plot)
}
-
-has_brackets <- function(x){
- grepl("\\[.*\\]", x)
+# Parameter name tools---------------------------------------------------------
+expand_params <- function(params){
+ unlist(lapply(params, expand_brackets))
}
expand_brackets <- function(x){
@@ -64,6 +68,6 @@ expand_brackets <- function(x){
paste0(pname, "[",rng,"]")
}
-expand_params <- function(params){
- unlist(lapply(params, expand_brackets))
+has_brackets <- function(x){
+ grepl("\\[.*\\]", x)
}
diff --git a/R/summary.R b/R/summary.R
deleted file mode 100644
index 4cdeffd..0000000
--- a/R/summary.R
+++ /dev/null
@@ -1,3 +0,0 @@
-summary.jagsUI <- function(object, ...){
- object$summary
-}
diff --git a/R/testrhat.R b/R/testrhat.R
deleted file mode 100644
index 5c85448..0000000
--- a/R/testrhat.R
+++ /dev/null
@@ -1,38 +0,0 @@
-
-test.Rhat <- function(samples,cutoff,params.omit,verbose=TRUE){
-
- params <- colnames(samples[[1]])
- expand <- sapply(strsplit(params, "\\["), "[", 1)
-
- gd <- function(hold){
- r <- try(gelman.diag(hold, autoburnin=FALSE)$psrf[1], silent=TRUE)
- if(inherits(r, "try-error") || !is.finite(r)) {
- r <- NA
- }
- return(r)
- }
-
- failure <- FALSE
- index <- 1
- while (failure==FALSE && index <= length(params)){
-
- if(!expand[index]%in%params.omit){
- test <- gd(samples[,index])
- } else {test <- 1}
-
- if(is.na(test)){test <- 1}
-
- if(test>cutoff){failure=TRUE
- } else {index <- index + 1}
- }
-
- if(failure==TRUE&verbose){
- cat('.......Convergence check failed for parameter \'',params[index],'\'\n',sep="")
- }
- if(failure==FALSE&verbose){
- cat('.......All parameters converged.','\n\n')
- }
-
- return(failure)
-
-} \ No newline at end of file
diff --git a/R/update.R b/R/update.R
index b731851..048a0c3 100644
--- a/R/update.R
+++ b/R/update.R
@@ -1,4 +1,4 @@
-
+# update method for jagsUI class-----------------------------------------------
update.jagsUI <- function(object, parameters.to.save=NULL,
n.adapt=NULL, n.iter, n.thin=NULL,
modules=c('glm'), factories=NULL,
@@ -72,3 +72,51 @@ update.jagsUI <- function(object, parameters.to.save=NULL,
return(output)
}
+
+# update method for jagsUIbasic class------------------------------------------
+update.jagsUIbasic <- function(object, parameters.to.save=NULL,
+ n.adapt=NULL, n.iter, n.thin=NULL,
+ modules=c('glm'), factories=NULL,
+ DIC=NULL, verbose=TRUE, ...){
+
+ # Set up parameters
+ if(is.null(parameters.to.save)){
+ params_long <- colnames(object$samples[[1]])
+ parameters.to.save <- unique(sapply(strsplit(params_long, "\\["), "[", 1))
+ }
+
+ #Set up DIC monitoring
+ if(is.null(DIC)){
+ DIC <- 'deviance' %in% parameters.to.save
+ } else {
+ if(DIC & (!'deviance' %in% parameters.to.save)){
+ parameters.to.save <- c(parameters.to.save, 'deviance')
+ } else if(!DIC & 'deviance' %in% parameters.to.save){
+ parameters.to.save <- parameters.to.save[parameters.to.save != 'deviance']
+ }
+ }
+
+ # Set up MCMC info
+ mcmc.info <- list(n.chains = length(object$samples), n.adapt = n.adapt,
+ n.iter = n.iter, n.burnin = 0,
+ n.thin = ifelse(is.null(n.thin), thin(object$samples), n.thin),
+ n.cores = object$n.cores)
+
+ parallel <- names(object$model[1]) == "cluster1"
+
+ # Run JAGS via rjags
+ rjags_out <- run_rjags(data=NULL, inits=NULL, parameters.to.save, modfile=NULL,
+ mcmc.info, modules, factories, DIC, parallel, !verbose,
+ model.object = object$model, update=TRUE)
+
+ # Report time
+ if(verbose) cat('MCMC took', rjags_out$elapsed.min, 'minutes.\n')
+
+ # Create output object
+ output <- list(samples = order_samples(rjags_out$samples, parameters.to.save),
+ model = rjags_out$m,
+ n.cores = object$n.cores)
+ class(output) <- 'jagsUIbasic'
+
+ return(output)
+}
diff --git a/R/updatebasic.R b/R/updatebasic.R
deleted file mode 100644
index 97fd512..0000000
--- a/R/updatebasic.R
+++ /dev/null
@@ -1,47 +0,0 @@
-
-update.jagsUIbasic <- function(object, parameters.to.save=NULL,
- n.adapt=NULL, n.iter, n.thin=NULL,
- modules=c('glm'), factories=NULL,
- DIC=NULL, verbose=TRUE, ...){
-
- # Set up parameters
- if(is.null(parameters.to.save)){
- params_long <- colnames(object$samples[[1]])
- parameters.to.save <- unique(sapply(strsplit(params_long, "\\["), "[", 1))
- }
-
- #Set up DIC monitoring
- if(is.null(DIC)){
- DIC <- 'deviance' %in% parameters.to.save
- } else {
- if(DIC & (!'deviance' %in% parameters.to.save)){
- parameters.to.save <- c(parameters.to.save, 'deviance')
- } else if(!DIC & 'deviance' %in% parameters.to.save){
- parameters.to.save <- parameters.to.save[parameters.to.save != 'deviance']
- }
- }
-
- # Set up MCMC info
- mcmc.info <- list(n.chains = length(object$samples), n.adapt = n.adapt,
- n.iter = n.iter, n.burnin = 0,
- n.thin = ifelse(is.null(n.thin), thin(object$samples), n.thin),
- n.cores = object$n.cores)
-
- parallel <- names(object$model[1]) == "cluster1"
-
- # Run JAGS via rjags
- rjags_out <- run_rjags(data=NULL, inits=NULL, parameters.to.save, modfile=NULL,
- mcmc.info, modules, factories, DIC, parallel, !verbose,
- model.object = object$model, update=TRUE)
-
- # Report time
- if(verbose) cat('MCMC took', rjags_out$elapsed.min, 'minutes.\n')
-
- # Create output object
- output <- list(samples = order_samples(rjags_out$samples, parameters.to.save),
- model = rjags_out$m,
- n.cores = object$n.cores)
- class(output) <- 'jagsUIbasic'
-
- return(output)
-}
diff --git a/R/utils.R b/R/utils.R
deleted file mode 100644
index 9458c63..0000000
--- a/R/utils.R
+++ /dev/null
@@ -1,8 +0,0 @@
-
-
-#--- from process_output ---------------------------------------------------------------------------
-#Check that an object is the right class
-check_class <- function(output){
- if(!inherits(output, "jagsUI")) stop("Requires jagsUI object")
-}
-#------------------------------------------------------------------------------