aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKen Kellner <kenkellner@users.noreply.github.com>2023-12-05 16:37:40 -0500
committerGitHub <noreply@github.com>2023-12-05 16:37:40 -0500
commit45804c4563f7079284811c0f2a10a9ea5abf4641 (patch)
tree195fee3a6dcfa8c66d6a69abfa18f36fa88cb8a4
parent263dc822dfaa45e0f6629e363c2cd913a57386de (diff)
Refactor input checking and add tests.
Remove support for setting random seed via function argument, and for character vectors instead of lists as input data.
-rw-r--r--.Rbuildignore1
-rw-r--r--DESCRIPTION4
-rw-r--r--Makefile21
-rw-r--r--R/autojags.R49
-rw-r--r--R/datacheck.R36
-rw-r--r--R/geninits.R86
-rw-r--r--R/jags.R43
-rw-r--r--R/jagsbasic.R24
-rw-r--r--R/process_input.R155
-rw-r--r--R/processinput.R110
-rw-r--r--inst/tinytest/test_input_processing.R193
-rw-r--r--inst/tinytest/test_jags.R16
-rw-r--r--inst/tinytest/test_jagsbasic.R5
13 files changed, 461 insertions, 282 deletions
diff --git a/.Rbuildignore b/.Rbuildignore
index e9ca541..8315df6 100644
--- a/.Rbuildignore
+++ b/.Rbuildignore
@@ -12,3 +12,4 @@ README.md
env.R
.travis.yml
^\.github$
+Makefile
diff --git a/DESCRIPTION b/DESCRIPTION
index 8fe4d68..697cf71 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
Package: jagsUI
-Version: 1.5.3.9000
-Date: 2023-12-03
+Version: 1.5.3.9001
+Date: 2023-12-05
Title: A Wrapper Around 'rjags' to Streamline 'JAGS' Analyses
Authors@R: c(
person("Ken", "Kellner", email="contact@kenkellner.com", role=c("cre","aut")),
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..e30a6c4
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,21 @@
+NAME = $(shell grep 'Package:' DESCRIPTION | cut -d ' ' -f2)
+VER = $(shell grep 'Version:' DESCRIPTION | cut -d ' ' -f2)
+
+install:
+ R CMD INSTALL .
+
+build:
+ cd ..; R CMD build $(NAME)
+
+check:
+ make build
+ cd ..; R CMD check $(NAME)_$(VER).tar.gz
+
+test:
+ make install
+ Rscript -e "tinytest::test_package('jagsUI')"
+
+coverage:
+ make install
+ Rscript -e 'covr::report(file="/tmp/jagsUI-report.html")'
+ firefox /tmp/jagsUI-report.html
diff --git a/R/autojags.R b/R/autojags.R
index 15b0e9d..4c67f43 100644
--- a/R/autojags.R
+++ b/R/autojags.R
@@ -3,14 +3,29 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad
save.all.iter=FALSE,modules=c('glm'),factories=NULL,parallel=FALSE,n.cores=NULL,DIC=TRUE,store.data=FALSE,codaOnly=FALSE,seed=NULL,
bugs.format=FALSE,Rhat.limit=1.1,max.iter=100000,verbose=TRUE){
- #Pass input data and parameter list through error check / processing
- data.check <- process.input(data,parameters.to.save,inits,n.chains,(n.burnin + iter.increment),
- n.burnin,n.thin,n.cores,DIC=DIC,autojags=TRUE,max.iter=max.iter,
- verbose=verbose,parallel=parallel,seed=seed)
- data <- data.check$data
- parameters.to.save <- data.check$params
- inits <- data.check$inits
- if(parallel){n.cores <- data.check$n.cores}
+ if(!is.null(seed)){
+ stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE)
+ }
+ if(n.chains<2) stop('Number of chains must be >1 to calculate Rhat.')
+ if((max.iter < n.burnin) & verbose){
+ old_warn <- options()$warn
+ options(warn=1)
+ warning('Maximum iterations includes burn-in and should be larger than burn-in.')
+ options(warn=old_warn)
+ }
+
+ # Check input data
+ inps_check <- process_input(data=data, params=parameters.to.save, inits=inits,
+ n_chains=n.chains, n_adapt=n.adapt,
+ n_iter=(n.burnin + iter.increment),
+ n_burnin=n.burnin, n_thin=n.thin, n_cores=n.cores,
+ DIC=DIC, quiet=!verbose, parallel=parallel)
+ data <- inps_check$data
+ parameters.to.save <- inps_check$params
+ inits <- inps_check$inits
+ mcmc.info <- inps_check$mcmc.info
+ mcmc.info$end.values <- NULL # this is not saved in autojags for some reason
+ if(parallel) n.cores <- inps_check$mcmc.info$n.cores
#Save start time
start.time <- Sys.time()
@@ -50,10 +65,11 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad
}
#Combine mcmc info into list
- n.samples <- dim(samples[[1]])[1] * n.chains
- mcmc.info <- list(n.chains,n.adapt=total.adapt,sufficient.adapt=sufficient.adapt,n.iter=(n.burnin + iter.increment),n.burnin,n.thin,n.samples,time)
- names(mcmc.info) <- c('n.chains','n.adapt','sufficient.adapt','n.iter','n.burnin','n.thin','n.samples','elapsed.mins')
- if(parallel){mcmc.info$n.cores <- n.cores}
+ mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3)
+ mcmc.info$n.samples <- coda::niter(samples) * n.chains
+ mcmc.info$n.adapt <- total.adapt
+ mcmc.info$sufficient.adapt <- sufficient.adapt
+ mcmc.info$n.iter <- n.burnin + iter.increment
test <- test.Rhat(samples,Rhat.limit,codaOnly,verbose=verbose)
reach.max <- FALSE
@@ -113,7 +129,7 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad
if(!save.all.iter){mcmc.info$n.burnin <- mcmc.info$n.iter}
mcmc.info$n.iter <- mcmc.info$n.iter + iter.increment
- mcmc.info$n.samples <- dim(samples[[1]])[1] * n.chains
+ mcmc.info$n.samples <- coda::niter(samples) * n.chains
mcmc.info$sufficient.adapt <- sufficient.adapt
if(mcmc.info$n.iter>=max.iter){
@@ -123,9 +139,7 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad
}
#Get more info about MCMC run
- end.time <- Sys.time()
- mcmc.info$elapsed.mins <- round(as.numeric(end.time-start.time,units="mins"),digits=3)
- date <- start.time
+ mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3)
#Reorganize JAGS output to match input parameter order
samples <- order_samples(samples, parameters.to.save)
@@ -152,8 +166,7 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad
output$model <- mod
output$parameters <- parameters.to.save
output$mcmc.info <- mcmc.info
- output$run.date <- date
- output$random.seed <- seed
+ output$run.date <- start.time
output$parallel <- parallel
output$bugs.format <- bugs.format
output$calc.DIC <- DIC
diff --git a/R/datacheck.R b/R/datacheck.R
deleted file mode 100644
index 7c40473..0000000
--- a/R/datacheck.R
+++ /dev/null
@@ -1,36 +0,0 @@
-
-data.check <- function(x,name,verbose=TRUE){
-
- test = FALSE
-
- if(is.data.frame(x)){
- if(!is.null(dim(x))){
- if(verbose){cat('\nConverting data frame \'',name,'\' to matrix.\n',sep="")}
- x = as.matrix(x)
- } else {
- if(verbose){cat('\nConverting data frame',name,'to vector.\n')}
- x = as.vector(x)}
- }
-
-
- if (is.numeric(x)&&is.matrix(x)){
- #if(1%in%dim(x)){
- # if(verbose){cat('\nConverting 1-column matrix \'',name,'\' to vector\n',sep="")}
- # x = as.vector(x)
- #}
- test = TRUE
- }
-
- if(is.numeric(x)&&is.array(x)&&!test){
- test = TRUE
- }
-
- if (is.numeric(x)&&is.vector(x)&&!test){
- test = TRUE
- }
-
- if(test){
- return(x)
- } else{return('error')}
-
-}
diff --git a/R/geninits.R b/R/geninits.R
deleted file mode 100644
index e62ba0f..0000000
--- a/R/geninits.R
+++ /dev/null
@@ -1,86 +0,0 @@
-
-gen.inits <- function(inits,n.chains,seed,parallel){
-
- if(!is.null(seed)){
- warning("The 'seed' argument will be deprecated in the next version. You can set it yourself with set.seed() instead.")
- #Save old seed if it exists
- if(exists('.Random.seed')){
- old.seed <- .Random.seed
- }
- #Generate seed for each chain
- set.seed(seed)
-
- }
-
- #Error check and run init function if necessary
- if(is.list(inits)){
- if(length(inits)!=n.chains){stop('Length of initial values list != number of chains')}
- init.values <- inits
- } else if(is.function(inits)){
- init.values <- list()
- for (i in 1:n.chains){
- init.values[[i]] <- inits()
- }
- } else if(is.null(inits)){
- init.values <- NULL
-
- } else {stop('Invalid initial values. Must be a function or a list with length=n.chains')}
-
- #Add random seed info if specified
- if(!is.null(seed)){
-
- init.rand <- floor(runif(n.chains,1,100000))
-
- #Restore old seed if it exists
- if(exists('old.seed')){
- assign(".Random.seed", old.seed, pos=1)
- }
-
- #Add random seeds to inits
- if(is.null(inits)){
- init.values <- vector("list",length=n.chains)
- for(i in 1:n.chains){
- init.values[[i]]$.RNG.name="base::Mersenne-Twister"
- init.values[[i]]$.RNG.seed=init.rand[i]
- }
-
- } else if(is.list(init.values)){
- for(i in 1:n.chains){
- init.values[[i]]$.RNG.name="base::Mersenne-Twister"
- init.values[[i]]$.RNG.seed=init.rand[i]
- }
-
- } else if (is.function(inits)){
- for (i in 1:n.chains){
- init.values[[i]]$.RNG.name="base::Mersenne-Twister"
- init.values[[i]]$.RNG.seed=init.rand[i]
- }
-
- }
-
-
- #If seed is not set
- } else {
-
- other.RNG <- all(c(".RNG.name",".RNG.seed")%in%names(init.values[[1]]))
-
- needs.RNG <- is.null(init.values)|!other.RNG
-
- #If parallel and no custom RNG has been set, add one. Otherwise all chains will start with same seed.
- # if(needs.RNG&parallel){
- if(needs.RNG){
-
- init.rand <- floor(runif(n.chains,1,100000))
-
- if(is.null(init.values)){init.values <- vector("list",length=n.chains)}
-
- for(i in 1:n.chains){
- init.values[[i]]$.RNG.name="base::Mersenne-Twister"
- init.values[[i]]$.RNG.seed=init.rand[i]
- }
-
- }
- }
-
- return(init.values)
-}
diff --git a/R/jags.R b/R/jags.R
index 9523575..5af4a6c 100644
--- a/R/jags.R
+++ b/R/jags.R
@@ -4,13 +4,20 @@ jagsUI <- jags <- function(data,inits=NULL,parameters.to.save,model.file,n.chain
modules=c('glm'),factories=NULL,parallel=FALSE,n.cores=NULL,DIC=TRUE,store.data=FALSE,codaOnly=FALSE,seed=NULL,
bugs.format=FALSE,verbose=TRUE){
- #Pass input data and parameter list through error check / processing
- data.check <- process.input(data,parameters.to.save,inits,n.chains,n.iter,n.burnin,n.thin,n.cores,DIC=DIC,
- verbose=verbose,parallel=parallel,seed=seed)
- data <- data.check$data
- parameters.to.save <- data.check$params
- inits <- data.check$inits
- if(parallel){n.cores <- data.check$n.cores}
+ if(!is.null(seed)){
+ stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE)
+ }
+
+ # Check input data
+ inps_check <- process_input(data=data, params=parameters.to.save, inits=inits,
+ n_chains=n.chains, n_adapt=n.adapt, n_iter=n.iter,
+ n_burnin=n.burnin, n_thin=n.thin, n_cores=n.cores,
+ DIC=DIC, quiet=!verbose, parallel=parallel)
+ data <- inps_check$data
+ parameters.to.save <- inps_check$params
+ inits <- inps_check$inits
+ mcmc.info <- inps_check$mcmc.info
+ if(parallel) n.cores <- inps_check$mcmc.info$n.cores
#Save start time
start.time <- Sys.time()
@@ -48,18 +55,13 @@ jagsUI <- jags <- function(data,inits=NULL,parameters.to.save,model.file,n.chain
}
- #Get more info about MCMC run
- end.time <- Sys.time()
- time <- round(as.numeric(end.time-start.time,units="mins"),digits=3)
- date <- start.time
-
- #Combine mcmc info into list
- n.samples <- dim(samples[[1]])[1] * n.chains
- end.values <- samples[(n.samples/n.chains),]
- mcmc.info <- list(n.chains,n.adapt=total.adapt,sufficient.adapt,n.iter,n.burnin,n.thin,n.samples,end.values,time)
- names(mcmc.info) <- c('n.chains','n.adapt','sufficient.adapt','n.iter','n.burnin','n.thin','n.samples','end.values','elapsed.mins')
- if(parallel){mcmc.info$n.cores <- n.cores}
-
+ #Add mcmc info into list
+ mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3)
+ mcmc.info$n.samples <- coda::niter(samples) * n.chains
+ mcmc.info$end.values <- samples[coda::niter(samples),]
+ mcmc.info$n.adapt <- total.adapt
+ mcmc.info$sufficient.adapt <- sufficient.adapt
+
#Reorganize JAGS output to match input parameter order
samples <- order_samples(samples, parameters.to.save)
@@ -85,8 +87,7 @@ jagsUI <- jags <- function(data,inits=NULL,parameters.to.save,model.file,n.chain
output$model <- m
output$parameters <- parameters.to.save
output$mcmc.info <- mcmc.info
- output$run.date <- date
- output$random.seed <- seed
+ output$run.date <- start.time
output$parallel <- parallel
output$bugs.format <- bugs.format
output$calc.DIC <- DIC
diff --git a/R/jagsbasic.R b/R/jagsbasic.R
index 17d465a..bd7e3f0 100644
--- a/R/jagsbasic.R
+++ b/R/jagsbasic.R
@@ -2,13 +2,20 @@
jags.basic <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.adapt=NULL,n.iter,n.burnin=0,n.thin=1,
modules=c('glm'),factories=NULL,parallel=FALSE,n.cores=NULL,DIC=TRUE,seed=NULL,save.model=FALSE,verbose=TRUE){
- #Pass input data and parameter list through error check / processing
- data.check <- process.input(data,parameters.to.save,inits,n.chains,n.iter,n.burnin,n.thin,n.cores,DIC=DIC,
- verbose=verbose,parallel=parallel,seed=seed)
- data <- data.check$data
- parameters.to.save <- data.check$params
- inits <- data.check$inits
- if(parallel){n.cores <- data.check$n.cores}
+ if(!is.null(seed)){
+ stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE)
+ }
+
+ # Check input data
+ inps_check <- process_input(data=data, params=parameters.to.save, inits=inits,
+ n_chains=n.chains, n_adapt=n.adapt, n_iter=n.iter,
+ n_burnin=n.burnin, n_thin=n.thin, n_cores=n.cores,
+ DIC=DIC, quiet=!verbose, parallel=parallel)
+ data <- inps_check$data
+ parameters.to.save <- inps_check$params
+ inits <- inps_check$inits
+ mcmc.info <- inps_check$mcmc.info
+ if(parallel) n.cores <- inps_check$mcmc.info$n.cores
#Save start time
start.time <- Sys.time()
@@ -48,8 +55,7 @@ jags.basic <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.
}
#Get more info about MCMC run
- end.time <- Sys.time()
- time <- round(as.numeric(end.time-start.time,units="mins"),digits=3)
+ time <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3)
if(verbose){cat('MCMC took',time,'minutes.\n')}
if(save.model){
diff --git a/R/process_input.R b/R/process_input.R
new file mode 100644
index 0000000..39a2558
--- /dev/null
+++ b/R/process_input.R
@@ -0,0 +1,155 @@
+# Process input----------------------------------------------------------------
+process_input <- function(data, params, inits, n_chains, n_adapt, n_iter, n_burnin,
+ n_thin, n_cores, DIC, quiet, parallel){
+
+ if(!quiet){cat('\nProcessing function input.......','\n')}
+ out <- list(data = check_data(data, quiet),
+ params = check_params(params, DIC),
+ inits = check_inits(inits, n_chains),
+ mcmc.info = check_mcmc_info(n_chains, n_adapt, n_iter, n_burnin,
+ n_thin, n_cores, parallel)
+ )
+ if(!quiet){cat('\nDone.','\n','\n')}
+ out
+}
+
+
+# Check data list--------------------------------------------------------------
+check_data <- function(inp_data, quiet){
+ # Check data is a list
+ if(!is.list(inp_data)){
+ stop("Input data must be a named list", call.=FALSE)
+ }
+ # Check list is named
+ nms <- names(inp_data)
+ if(is.null(nms) || any(nms == "")){
+ stop("All elements of the input data list must be named", call.=FALSE)
+ }
+
+ # Check individual data elements
+ out <- lapply(1:length(inp_data), function(i){
+ check_data_element(inp_data[[i]], nms[i], quiet)
+ })
+ names(out) <- nms
+ out
+}
+
+
+# Check individual data elements-----------------------------------------------
+check_data_element <- function(x, name, quiet){
+ # Stop if element is a factor
+ if(is.factor(x)){
+ stop("Element", name, "is a factor. Convert it to numeric.", call.=FALSE)
+ }
+ # Try to convert data frame to matrix/vector
+ if(is.data.frame(x)){
+ # Old versions attempted to convert a 1-column data frame to vector
+ # but did it incorrectly (always converted to matrix).
+ # This behavior was preserved here.
+ x <- as.matrix(x)
+ if(!is.numeric(x)){
+ stop("Could not convert data.frame ", name, " to numeric matrix", call.=FALSE)
+ }
+ if(!quiet) cat("\nConverted data.frame", name, "to matrix\n")
+ }
+ # Final check if element is numeric
+ if(!is.numeric(x)){
+ stop("Element ", name, " is not numeric", call.=FALSE)
+ }
+ x
+}
+
+
+# Check parameter vector-------------------------------------------------------
+check_params <- function(params, DIC){
+ if(!(is.vector(params) & is.character(params))){
+ stop("parameters.to.save must be a character vector", call.=FALSE)
+ }
+ if(DIC & (! "deviance" %in% params)){
+ params <- c(params, "deviance")
+ }
+ params
+}
+
+
+# Check mcmc settings----------------------------------------------------------
+#names(mcmc.info) <- c('n.chains','n.adapt','sufficient.adapt','n.iter','n.burnin','n.thin','n.samples','end.values','elapsed.mins')
+
+check_mcmc_info <- function(n_chains, n_adapt, n_iter, n_burnin,
+ n_thin, n_cores, parallel){
+
+ if(n_iter <= n_burnin){
+ stop("Number of iterations must be larger than burn-in", call.=FALSE)
+ }
+ # Removed warnings about small numbers of iterations and uneven iterations
+
+ n_cores = check_cores(n_cores, n_chains, parallel)
+
+ # Create list structure and save available elements
+ out <- list(n.chains = n_chains, n.adapt = n_adapt, sufficient.adapt = NA,
+ n.iter = n_iter, n.burnin = n_burnin, n.thin = n_thin,
+ n.samples = NA, end.values = NA, elapsed.mins = NA)
+ if(parallel) out$n.cores <- n_cores
+ out
+}
+
+
+# Check number of cores--------------------------------------------------------
+check_cores <- function(n_cores, n_chains, parallel){
+ if(!parallel) return(NULL)
+ max_cores <- parallel::detectCores()
+ # Send this warning right away
+ old_warn <- options()$warn
+ options(warn=1)
+ if(!is.null(n_cores)){
+ if(n_cores > max_cores){
+ n_cores <- max_cores
+ warning("n.cores > max available cores. n.cores set to ", max_cores,
+ call.=FALSE)
+ }
+ } else {
+ if(!is.na(max_cores)){
+ n_cores <- min(max_cores, n_chains)
+ } else {
+ warning("Couldn't detect number of cores. Setting n.cores = n.chains",
+ call.=FALSE)
+ n_cores <- n_chains
+ }
+ }
+ options(warn=old_warn)
+ n_cores
+}
+
+
+# Check initial values---------------------------------------------------------
+check_inits <- function(inits, n_chains){
+ if(is.list(inits)){
+ if(length(inits) != n_chains){
+ stop("inits list must have length equal to the number of chains",
+ call.=FALSE)
+ }
+ } else if(is.function(inits)){
+ inits <- lapply(1:n_chains, function(x) inits())
+ if(!is.list(inits[[1]])){
+ stop("inits function must return list", call.=FALSE)
+ }
+ } else if(is.null(inits)){
+ inits <- vector("list", n_chains)
+ } else {
+ stop("inits must be a list or a function that returns a list", call.=FALSE)
+ }
+
+ # Setup seeds in each chain, for reproducibility
+ # Check if they already exist
+ has_RNG <- all(c(".RNG.name",".RNG.seed") %in% names(inits[[1]]))
+ # If not add them
+ if(!has_RNG){
+ # Generate random seeds for each chain
+ chain_seeds <- floor(runif(n_chains, 1, 100000))
+ for (i in 1:n_chains){
+ inits[[i]]$.RNG.name="base::Mersenne-Twister"
+ inits[[i]]$.RNG.seed=chain_seeds[i]
+ }
+ }
+ inits
+}
diff --git a/R/processinput.R b/R/processinput.R
deleted file mode 100644
index c208b07..0000000
--- a/R/processinput.R
+++ /dev/null
@@ -1,110 +0,0 @@
-
-process.input = function(x,y,inits,n.chains,n.iter,n.burnin,n.thin,n.cores,DIC=FALSE,autojags=FALSE,max.iter=NULL,
- verbose=TRUE,parallel=FALSE,seed=NULL){
- if(verbose){cat('\nProcessing function input.......','\n')}
-
- #Quality control
- if(n.iter<=n.burnin){
- stop('Number of iterations must be larger than burn-in.\n')
- }
-
- if(parallel){
- #Set number of clusters
- p <- detectCores()
- if(is.null(n.cores)){
- if(is.na(p)){
- p <- n.chains
- if(verbose){
- options(warn=1)
- warning('Could not detect number of cores on the machine. Defaulting to cores used = number of chains.')
- options(warn=0,error=NULL)
- }
- }
- n.cores <- min(p,n.chains)
- } else {
- if(n.cores>p){
- if(verbose){
- options(warn=1)
- warning(paste('You have specified more cores (',n.cores,') than the available number of cores on this machine (',p,').\nReducing n.cores to max of ',p,'.',sep=""))
- options(warn=0,error=NULL)
- }
- n.cores <- p
- }
- }
- }
-
- if(autojags){
- if(n.chains<2){stop('Number of chains must be >1 to calculate Rhat.')}
- if(max.iter<n.burnin&verbose){
- options(warn=1)
- warning('Maximum iterations includes burn-in and should be larger than burn-in.')
- options(warn=0,error=NULL)
- }
- }
-
- if(n.thin>1&&(n.iter-n.burnin)<10&&verbose){
- options(warn=1)
- warning('The number of iterations is very low; jagsUI may crash. Recommend reducing n.thin to 1 and/or increasing n.iter.')
- options(warn=0,error=NULL)
- }
-
- final.chain.length <- (n.iter - n.burnin) / n.thin
- even.length <- floor(final.chain.length) == final.chain.length
- if(!even.length&verbose){
- options(warn=1)
- warning('Number of iterations saved after thinning is not an integer; JAGS will round it up.')
- options(warn=0,error=NULL)
- }
-
- #Check if supplied parameter vector is the right format
- if((is.character(y)&is.vector(y))){
- } else{stop('The parameters to save must be a vector containing only character strings.\n')}
-
- #If DIC requested, add deviance to parameters (if not already there)
- if(DIC&&(!'deviance'%in%y)){
- params <- c(y,"deviance")
- } else {params <- y}
-
- #Check if supplied data object is the proper format
- if(is.list(x)||(is.character(x)&is.vector(x))){
- } else{stop('Input data must be a list of data objects OR a vector of data object names (as strings)\n')}
-
- if(is.list(x)&&all(sapply(x,is.character))){
- warning("Suppling a list of character strings to the data argument will be deprecated in the future")
- x = unlist(x)
- }
-
- if((is.list(x)&&is.null(names(x)))||(is.list(x)&&any(names(x)==""))){
- stop('At least one of the elements in your data list does not have a name\n')
- }
-
- #Convert a supplied vector of characters to a list of data objects
- if((is.character(x)&is.vector(x))){
- warning("Suppling a character vector to the data argument will be deprecated in the future")
- temp = lapply(x,get,envir = parent.frame(2))
- names(temp) = x
- x = temp
- }
-
- #Check each component of data object for issues and fix if possible
- for (i in 1:length(x)){
-
- if(is.factor(x[[i]])){
-
- stop('\nElement \'',names(x[i]) ,'\' in the data list is a factor.','\n','Convert it to a series of dummy/indicator variables or a numeric vector as appropriate.\n')
-
- }
-
- process <- data.check(x[[i]],name = names(x[i]),verbose=verbose)
- if(!is.na(process[1])&&process[1]=="error"){stop('\nElement \'',names(x[i]) ,'\' in the data list cannot be coerced to one of the','\n','allowed formats (numeric scalar, vector, matrix, or array)\n')
- } else{x[[i]] <- process}
-
- }
-
- #Get initial values
- init.vals <- gen.inits(inits,n.chains,seed,parallel)
-
- if(verbose){cat('\nDone.','\n','\n')}
- return(list(data=x,params=params,inits=init.vals,n.cores=n.cores))
-
-}
diff --git a/inst/tinytest/test_input_processing.R b/inst/tinytest/test_input_processing.R
new file mode 100644
index 0000000..18affd0
--- /dev/null
+++ b/inst/tinytest/test_input_processing.R
@@ -0,0 +1,193 @@
+process_input <- jagsUI:::process_input
+check_inits <- jagsUI:::check_inits
+
+# Overall structure------------------------------------------------------------
+data1 <- list(a=1, b=c(1,2), c=matrix(rnorm(4), 2,2),
+ d=array(rnorm(8), c(2,2,2)), e=c(NA, 1))
+test <- process_input(data1, params="a", NULL, 2, 1, 100, 50, 2,
+ NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE)
+expect_inherits(test, "list")
+expect_equal(names(test), c("data", "params", "inits", "mcmc.info"))
+
+# Data processing--------------------------------------------------------------
+# Stuff that gets passed through unchanged
+data1 <- list(a=1, b=c(1,2), c=matrix(rnorm(4), 2,2),
+ d=array(rnorm(8), c(2,2,2)), e=c(NA, 1))
+test <- process_input(data1, params="a", NULL, 2, 1, 100, 50, 2,
+ NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE)
+expect_identical(test$data, data1)
+
+# Data frame handling
+data2 <- list(a=data.frame(v1=c(1,2)), b=data.frame(v1=c(0,1), v2=c(2,3)))
+co <- capture.output(
+test <- process_input(data2, params="a", NULL, 2, 1, 100, 50, 2,
+ NULL, DIC=TRUE, quiet=FALSE, parallel=FALSE)
+)
+ref_msg <- c("", "Processing function input....... ", "",
+ "Converted data.frame a to matrix","",
+ "Converted data.frame b to matrix", "", "Done. ", " ")
+expect_equal(co, ref_msg)
+expect_equivalent(test$data, list(a=matrix(c(1,2), ncol=1),
+ b=matrix(c(0:3), ncol=2)))
+
+# non-numeric data frame errors
+data2$v3 <- data.frame(v1=c("a","b"))
+expect_error(process_input(data2, params="a", NULL, 2, 1, 100, 50, 2,
+ NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE))
+
+# Factor in data
+data3 <- list(a=1, b=factor(c("1","2")))
+expect_error(process_input(data3, params="a", NULL, 2, 1, 100, 50, 2,
+ NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE))
+
+# Character in data
+data4 <- list(a=1, b=c("a","b"))
+expect_error(process_input(data4, params="a", NULL, 2, 1, 100, 50, 2,
+ NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE))
+
+# Vector with attributes is allowed
+vec <- c(1,2)
+attr(vec, "test") <- "test"
+expect_false(is.vector(vec))
+data5 <- list(vec=vec)
+test <- process_input(data5, params="a", NULL, 2, 1, 100, 50, 2,
+ NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE)
+expect_equal(data5, test$data)
+
+# One-column matrix is not converted to vector
+data6 <- list(a=matrix(c(1,2), ncol=1))
+test <- process_input(data6, params="a", NULL, 2, 1, 100, 50, 2,
+ NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE)
+expect_equal(test$data, data6)
+
+# Non-list as input errors
+t1 <- 1; t2 <- 2
+data7 <- c("t1", "t2")
+expect_error(process_input(data7, params="a", NULL, 2, 1, 100, 50, 2,
+ NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE))
+
+# List without names as input errors
+data8 <- list(1, 2)
+expect_error(process_input(data8, params="a", NULL, 2, 1, 100, 50, 2,
+ NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE))
+
+
+# Parameter vector processing--------------------------------------------------
+dat <- list(a=1, b=2)
+pars1 <- c("a", "b")
+pars2 <- c("deviance","a", "b")
+
+# DIC = FALSE
+test <- process_input(dat, params=pars1, NULL, 2, 1, 100, 50, 2,
+ NULL, DIC=FALSE, quiet=TRUE, parallel=FALSE)
+expect_equal(pars1, test$params)
+
+# DIC = TRUE
+test <- process_input(dat, params=pars1, NULL, 2, 1, 100, 50, 2,
+ NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE)
+expect_equal(c(pars1, "deviance"), test$params)
+
+# Deviance already in vector
+test <- process_input(dat, params=pars2, NULL, 2, 1, 100, 50, 2,
+ NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE)
+expect_equal(pars2, test$params)
+
+# Incorrect format
+expect_error(process_input(dat, params=c(1,2), NULL, 2, 1, 100, 50, 2,
+ NULL, DIC=FALSE, quiet=TRUE, parallel=FALSE))
+
+
+# MCMC info processing---------------------------------------------------------
+dat <- list(a=1, b=2)
+pars1 <- c("a", "b")
+# n.iter/n.burnin mismatch
+expect_error(process_input(dat, params=pars1, NULL, 2, 1, n_iter=100, n_burnin=100, 2,
+ NULL, DIC=FALSE, quiet=TRUE, parallel=FALSE))
+expect_error(process_input(dat, params=pars1, NULL, 2, 1, n_iter=100, n_burnin=150, 2,
+ NULL, DIC=FALSE, quiet=TRUE, parallel=FALSE))
+
+# n.cores
+# when parallel=FALSE
+test <- process_input(dat, params=pars1, NULL, 2, 1, n_iter=100, n_burnin=50, 2,
+ n_cores=NULL, DIC=FALSE, quiet=TRUE, parallel=FALSE)
+expect_true(is.null(test$mcmc.info$n.cores))
+
+# when parallel=TRUE defaults to min of nchains and ncores
+test <- process_input(dat, params=pars1, NULL, 2, 1, n_iter=100, n_burnin=50, 2,
+ n_cores=NULL, DIC=FALSE, quiet=TRUE, parallel=TRUE)
+expect_equal(test$mcmc.info$n.cores, 2)
+
+avail_cores <- parallel::detectCores()
+if(avail_cores > 1){
+ try_cores <- avail_cores + 1
+ n_chain <- try_cores
+ expect_warning(nul <- capture.output(
+ test <- process_input(dat, params=pars1, NULL, n_chains=n_chain, 1,
+ n_iter=100, n_burnin=50, 2,
+ n_cores=try_cores, DIC=FALSE, quiet=TRUE, parallel=TRUE)
+ ))
+ expect_equal(test$mcmc.info$n.cores, avail_cores)
+}
+
+# Initial value processing-----------------------------------------------------
+inits1 <- NULL
+inits2 <- list(list(a=1, b=2), list(a=3, b=4))
+inits3 <- list(a=1, b=2)
+inits4 <- function() list(a=1, b=2)
+inits5 <- function() list()
+inits6 <- function() 1
+inits7 <- 1
+
+# No inits provided
+set.seed(123)
+test <- check_inits(inits1, n_chains=2)
+ref <- list(list(.RNG.name = "base::Mersenne-Twister", .RNG.seed = 28758),
+ list(.RNG.name = "base::Mersenne-Twister", .RNG.seed = 78830))
+expect_identical(test, ref)
+
+# A list of lists
+set.seed(123)
+test <- check_inits(inits2, n_chains=2)
+ref <- list(list(a = 1, b = 2, .RNG.name = "base::Mersenne-Twister",
+ .RNG.seed = 28758), list(a = 3, b = 4, .RNG.name = "base::Mersenne-Twister",
+ .RNG.seed = 78830))
+expect_identical(test, ref)
+# Wrong number of list elements for number of chains
+expect_error(check_inits(inits2, n_chains=3))
+
+# A single list
+expect_error(check_inits(inits3, n_chains=1))
+
+# A function
+set.seed(123)
+test <- check_inits(inits4, n_chains=2)
+ref <- list(list(a = 1, b = 2, .RNG.name = "base::Mersenne-Twister",
+ .RNG.seed = 28758), list(a = 1, b = 2, .RNG.name = "base::Mersenne-Twister",
+ .RNG.seed = 78830))
+expect_identical(test, ref)
+
+# An empty list
+set.seed(123)
+test <- check_inits(inits5, n_chains=2)
+ref <- list(list(.RNG.name = "base::Mersenne-Twister", .RNG.seed = 28758),
+ list(.RNG.name = "base::Mersenne-Twister", .RNG.seed = 78830))
+expect_identical(test, ref)
+
+# Function but doesn't return list
+expect_error(check_inits(inits6, n_chains=2))
+
+# A number
+expect_error(check_inits(inits7, n_chains=2))
+
+# Check exact match when inits is a function with random numbers
+set.seed(123)
+inits_fun <- function() list(a = rnorm(1), b=rnorm(1))
+#inits_ref <- jagsUI:::gen.inits(inits_fun, seed=NULL, 2)
+inits_ref <- list(list(a = -0.560475646552213, b = -0.23017748948328,
+ .RNG.name = "base::Mersenne-Twister",
+ .RNG.seed = 55143), list(a = 1.55870831414912, b = 0.070508391424576,
+ .RNG.name = "base::Mersenne-Twister", .RNG.seed = 45662))
+
+set.seed(123)
+test <- check_inits(inits_fun, 2)
+expect_equal(inits_ref, test)
diff --git a/inst/tinytest/test_jags.R b/inst/tinytest/test_jags.R
index 1b30505..8295258 100644
--- a/inst/tinytest/test_jags.R
+++ b/inst/tinytest/test_jags.R
@@ -134,6 +134,18 @@ out <- jags(data = data, inits = inits, parameters.to.save = params,
codaOnly = params)
expect_equal(nrow(out$summary), 0)
+# Saved data and inits---------------------------------------------------------
+set.seed(123)
+run_inits <- jagsUI:::check_inits(inits, 3)
+
+set.seed(123)
+out <- jags(data = data, inits = inits,
+ parameters.to.save = c("alpha","beta"),
+ model.file = modfile, n.chains = 3, n.adapt = 100, n.iter = 100,
+ n.burnin = 50, n.thin = 1, verbose=FALSE, store.data=TRUE)
+expect_identical(out$data, data)
+expect_identical(out$inits, run_inits)
+
# Check recovery after process_output errors-----------------------------------
# Setting DIC to -999 forces process_output to error for testing
expect_message(out <- jags(data = data, inits = inits,
@@ -153,6 +165,10 @@ expect_true(all(is.na(out$summary[,"Rhat"])))
expect_true(all(is.na(out$summary[,"n.eff"])))
expect_true(all(out$summary["alpha",3:7] == out$summary["alpha",3]))
+# Error when user tries to set seed--------------------------------------------
+expect_error(jags(data = data, inits = inits, parameters.to.save = params,
+ model.file = modfile, n.chains = 1, n.adapt = 100, n.iter = 100,
+ n.burnin = 50, n.thin = 1, DIC = FALSE, verbose=FALSE, seed=123))
# Single parameter slice-------------------------------------------------------
set.seed(123)
diff --git a/inst/tinytest/test_jagsbasic.R b/inst/tinytest/test_jagsbasic.R
index 1e03fc4..927bd53 100644
--- a/inst/tinytest/test_jagsbasic.R
+++ b/inst/tinytest/test_jagsbasic.R
@@ -48,3 +48,8 @@ ref <- readRDS('jagsbasic_ref_update.Rds')
expect_identical(names(out2), names(ref))
out2$model <- ref$model
expect_equal(out2, ref)
+
+# Error if seed is set
+expect_error(jags.basic(data = data, inits = inits, parameters.to.save = params,
+ model.file = modfile, n.chains = 3, n.adapt = 100, n.iter = 100,
+ n.burnin = 50, n.thin = 2, verbose=FALSE, save.model=TRUE, seed=123))