diff options
author | Ken Kellner <kenkellner@users.noreply.github.com> | 2023-12-05 16:37:40 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-05 16:37:40 -0500 |
commit | 45804c4563f7079284811c0f2a10a9ea5abf4641 (patch) | |
tree | 195fee3a6dcfa8c66d6a69abfa18f36fa88cb8a4 | |
parent | 263dc822dfaa45e0f6629e363c2cd913a57386de (diff) |
Refactor input checking and add tests.
Remove support for setting random seed via function argument, and for character vectors instead of lists as input data.
-rw-r--r-- | .Rbuildignore | 1 | ||||
-rw-r--r-- | DESCRIPTION | 4 | ||||
-rw-r--r-- | Makefile | 21 | ||||
-rw-r--r-- | R/autojags.R | 49 | ||||
-rw-r--r-- | R/datacheck.R | 36 | ||||
-rw-r--r-- | R/geninits.R | 86 | ||||
-rw-r--r-- | R/jags.R | 43 | ||||
-rw-r--r-- | R/jagsbasic.R | 24 | ||||
-rw-r--r-- | R/process_input.R | 155 | ||||
-rw-r--r-- | R/processinput.R | 110 | ||||
-rw-r--r-- | inst/tinytest/test_input_processing.R | 193 | ||||
-rw-r--r-- | inst/tinytest/test_jags.R | 16 | ||||
-rw-r--r-- | inst/tinytest/test_jagsbasic.R | 5 |
13 files changed, 461 insertions, 282 deletions
diff --git a/.Rbuildignore b/.Rbuildignore index e9ca541..8315df6 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -12,3 +12,4 @@ README.md env.R .travis.yml ^\.github$ +Makefile diff --git a/DESCRIPTION b/DESCRIPTION index 8fe4d68..697cf71 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: jagsUI -Version: 1.5.3.9000 -Date: 2023-12-03 +Version: 1.5.3.9001 +Date: 2023-12-05 Title: A Wrapper Around 'rjags' to Streamline 'JAGS' Analyses Authors@R: c( person("Ken", "Kellner", email="contact@kenkellner.com", role=c("cre","aut")), diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..e30a6c4 --- /dev/null +++ b/Makefile @@ -0,0 +1,21 @@ +NAME = $(shell grep 'Package:' DESCRIPTION | cut -d ' ' -f2) +VER = $(shell grep 'Version:' DESCRIPTION | cut -d ' ' -f2) + +install: + R CMD INSTALL . + +build: + cd ..; R CMD build $(NAME) + +check: + make build + cd ..; R CMD check $(NAME)_$(VER).tar.gz + +test: + make install + Rscript -e "tinytest::test_package('jagsUI')" + +coverage: + make install + Rscript -e 'covr::report(file="/tmp/jagsUI-report.html")' + firefox /tmp/jagsUI-report.html diff --git a/R/autojags.R b/R/autojags.R index 15b0e9d..4c67f43 100644 --- a/R/autojags.R +++ b/R/autojags.R @@ -3,14 +3,29 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad save.all.iter=FALSE,modules=c('glm'),factories=NULL,parallel=FALSE,n.cores=NULL,DIC=TRUE,store.data=FALSE,codaOnly=FALSE,seed=NULL, bugs.format=FALSE,Rhat.limit=1.1,max.iter=100000,verbose=TRUE){ - #Pass input data and parameter list through error check / processing - data.check <- process.input(data,parameters.to.save,inits,n.chains,(n.burnin + iter.increment), - n.burnin,n.thin,n.cores,DIC=DIC,autojags=TRUE,max.iter=max.iter, - verbose=verbose,parallel=parallel,seed=seed) - data <- data.check$data - parameters.to.save <- data.check$params - inits <- data.check$inits - if(parallel){n.cores <- data.check$n.cores} + if(!is.null(seed)){ + stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE) + } + if(n.chains<2) stop('Number of chains must be >1 to calculate Rhat.') + if((max.iter < n.burnin) & verbose){ + old_warn <- options()$warn + options(warn=1) + warning('Maximum iterations includes burn-in and should be larger than burn-in.') + options(warn=old_warn) + } + + # Check input data + inps_check <- process_input(data=data, params=parameters.to.save, inits=inits, + n_chains=n.chains, n_adapt=n.adapt, + n_iter=(n.burnin + iter.increment), + n_burnin=n.burnin, n_thin=n.thin, n_cores=n.cores, + DIC=DIC, quiet=!verbose, parallel=parallel) + data <- inps_check$data + parameters.to.save <- inps_check$params + inits <- inps_check$inits + mcmc.info <- inps_check$mcmc.info + mcmc.info$end.values <- NULL # this is not saved in autojags for some reason + if(parallel) n.cores <- inps_check$mcmc.info$n.cores #Save start time start.time <- Sys.time() @@ -50,10 +65,11 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad } #Combine mcmc info into list - n.samples <- dim(samples[[1]])[1] * n.chains - mcmc.info <- list(n.chains,n.adapt=total.adapt,sufficient.adapt=sufficient.adapt,n.iter=(n.burnin + iter.increment),n.burnin,n.thin,n.samples,time) - names(mcmc.info) <- c('n.chains','n.adapt','sufficient.adapt','n.iter','n.burnin','n.thin','n.samples','elapsed.mins') - if(parallel){mcmc.info$n.cores <- n.cores} + mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3) + mcmc.info$n.samples <- coda::niter(samples) * n.chains + mcmc.info$n.adapt <- total.adapt + mcmc.info$sufficient.adapt <- sufficient.adapt + mcmc.info$n.iter <- n.burnin + iter.increment test <- test.Rhat(samples,Rhat.limit,codaOnly,verbose=verbose) reach.max <- FALSE @@ -113,7 +129,7 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad if(!save.all.iter){mcmc.info$n.burnin <- mcmc.info$n.iter} mcmc.info$n.iter <- mcmc.info$n.iter + iter.increment - mcmc.info$n.samples <- dim(samples[[1]])[1] * n.chains + mcmc.info$n.samples <- coda::niter(samples) * n.chains mcmc.info$sufficient.adapt <- sufficient.adapt if(mcmc.info$n.iter>=max.iter){ @@ -123,9 +139,7 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad } #Get more info about MCMC run - end.time <- Sys.time() - mcmc.info$elapsed.mins <- round(as.numeric(end.time-start.time,units="mins"),digits=3) - date <- start.time + mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3) #Reorganize JAGS output to match input parameter order samples <- order_samples(samples, parameters.to.save) @@ -152,8 +166,7 @@ autojags <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.ad output$model <- mod output$parameters <- parameters.to.save output$mcmc.info <- mcmc.info - output$run.date <- date - output$random.seed <- seed + output$run.date <- start.time output$parallel <- parallel output$bugs.format <- bugs.format output$calc.DIC <- DIC diff --git a/R/datacheck.R b/R/datacheck.R deleted file mode 100644 index 7c40473..0000000 --- a/R/datacheck.R +++ /dev/null @@ -1,36 +0,0 @@ - -data.check <- function(x,name,verbose=TRUE){ - - test = FALSE - - if(is.data.frame(x)){ - if(!is.null(dim(x))){ - if(verbose){cat('\nConverting data frame \'',name,'\' to matrix.\n',sep="")} - x = as.matrix(x) - } else { - if(verbose){cat('\nConverting data frame',name,'to vector.\n')} - x = as.vector(x)} - } - - - if (is.numeric(x)&&is.matrix(x)){ - #if(1%in%dim(x)){ - # if(verbose){cat('\nConverting 1-column matrix \'',name,'\' to vector\n',sep="")} - # x = as.vector(x) - #} - test = TRUE - } - - if(is.numeric(x)&&is.array(x)&&!test){ - test = TRUE - } - - if (is.numeric(x)&&is.vector(x)&&!test){ - test = TRUE - } - - if(test){ - return(x) - } else{return('error')} - -} diff --git a/R/geninits.R b/R/geninits.R deleted file mode 100644 index e62ba0f..0000000 --- a/R/geninits.R +++ /dev/null @@ -1,86 +0,0 @@ - -gen.inits <- function(inits,n.chains,seed,parallel){ - - if(!is.null(seed)){ - warning("The 'seed' argument will be deprecated in the next version. You can set it yourself with set.seed() instead.") - #Save old seed if it exists - if(exists('.Random.seed')){ - old.seed <- .Random.seed - } - #Generate seed for each chain - set.seed(seed) - - } - - #Error check and run init function if necessary - if(is.list(inits)){ - if(length(inits)!=n.chains){stop('Length of initial values list != number of chains')} - init.values <- inits - } else if(is.function(inits)){ - init.values <- list() - for (i in 1:n.chains){ - init.values[[i]] <- inits() - } - } else if(is.null(inits)){ - init.values <- NULL - - } else {stop('Invalid initial values. Must be a function or a list with length=n.chains')} - - #Add random seed info if specified - if(!is.null(seed)){ - - init.rand <- floor(runif(n.chains,1,100000)) - - #Restore old seed if it exists - if(exists('old.seed')){ - assign(".Random.seed", old.seed, pos=1) - } - - #Add random seeds to inits - if(is.null(inits)){ - init.values <- vector("list",length=n.chains) - for(i in 1:n.chains){ - init.values[[i]]$.RNG.name="base::Mersenne-Twister" - init.values[[i]]$.RNG.seed=init.rand[i] - } - - } else if(is.list(init.values)){ - for(i in 1:n.chains){ - init.values[[i]]$.RNG.name="base::Mersenne-Twister" - init.values[[i]]$.RNG.seed=init.rand[i] - } - - } else if (is.function(inits)){ - for (i in 1:n.chains){ - init.values[[i]]$.RNG.name="base::Mersenne-Twister" - init.values[[i]]$.RNG.seed=init.rand[i] - } - - } - - - #If seed is not set - } else { - - other.RNG <- all(c(".RNG.name",".RNG.seed")%in%names(init.values[[1]])) - - needs.RNG <- is.null(init.values)|!other.RNG - - #If parallel and no custom RNG has been set, add one. Otherwise all chains will start with same seed. - # if(needs.RNG¶llel){ - if(needs.RNG){ - - init.rand <- floor(runif(n.chains,1,100000)) - - if(is.null(init.values)){init.values <- vector("list",length=n.chains)} - - for(i in 1:n.chains){ - init.values[[i]]$.RNG.name="base::Mersenne-Twister" - init.values[[i]]$.RNG.seed=init.rand[i] - } - - } - } - - return(init.values) -} @@ -4,13 +4,20 @@ jagsUI <- jags <- function(data,inits=NULL,parameters.to.save,model.file,n.chain modules=c('glm'),factories=NULL,parallel=FALSE,n.cores=NULL,DIC=TRUE,store.data=FALSE,codaOnly=FALSE,seed=NULL, bugs.format=FALSE,verbose=TRUE){ - #Pass input data and parameter list through error check / processing - data.check <- process.input(data,parameters.to.save,inits,n.chains,n.iter,n.burnin,n.thin,n.cores,DIC=DIC, - verbose=verbose,parallel=parallel,seed=seed) - data <- data.check$data - parameters.to.save <- data.check$params - inits <- data.check$inits - if(parallel){n.cores <- data.check$n.cores} + if(!is.null(seed)){ + stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE) + } + + # Check input data + inps_check <- process_input(data=data, params=parameters.to.save, inits=inits, + n_chains=n.chains, n_adapt=n.adapt, n_iter=n.iter, + n_burnin=n.burnin, n_thin=n.thin, n_cores=n.cores, + DIC=DIC, quiet=!verbose, parallel=parallel) + data <- inps_check$data + parameters.to.save <- inps_check$params + inits <- inps_check$inits + mcmc.info <- inps_check$mcmc.info + if(parallel) n.cores <- inps_check$mcmc.info$n.cores #Save start time start.time <- Sys.time() @@ -48,18 +55,13 @@ jagsUI <- jags <- function(data,inits=NULL,parameters.to.save,model.file,n.chain } - #Get more info about MCMC run - end.time <- Sys.time() - time <- round(as.numeric(end.time-start.time,units="mins"),digits=3) - date <- start.time - - #Combine mcmc info into list - n.samples <- dim(samples[[1]])[1] * n.chains - end.values <- samples[(n.samples/n.chains),] - mcmc.info <- list(n.chains,n.adapt=total.adapt,sufficient.adapt,n.iter,n.burnin,n.thin,n.samples,end.values,time) - names(mcmc.info) <- c('n.chains','n.adapt','sufficient.adapt','n.iter','n.burnin','n.thin','n.samples','end.values','elapsed.mins') - if(parallel){mcmc.info$n.cores <- n.cores} - + #Add mcmc info into list + mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3) + mcmc.info$n.samples <- coda::niter(samples) * n.chains + mcmc.info$end.values <- samples[coda::niter(samples),] + mcmc.info$n.adapt <- total.adapt + mcmc.info$sufficient.adapt <- sufficient.adapt + #Reorganize JAGS output to match input parameter order samples <- order_samples(samples, parameters.to.save) @@ -85,8 +87,7 @@ jagsUI <- jags <- function(data,inits=NULL,parameters.to.save,model.file,n.chain output$model <- m output$parameters <- parameters.to.save output$mcmc.info <- mcmc.info - output$run.date <- date - output$random.seed <- seed + output$run.date <- start.time output$parallel <- parallel output$bugs.format <- bugs.format output$calc.DIC <- DIC diff --git a/R/jagsbasic.R b/R/jagsbasic.R index 17d465a..bd7e3f0 100644 --- a/R/jagsbasic.R +++ b/R/jagsbasic.R @@ -2,13 +2,20 @@ jags.basic <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n.adapt=NULL,n.iter,n.burnin=0,n.thin=1, modules=c('glm'),factories=NULL,parallel=FALSE,n.cores=NULL,DIC=TRUE,seed=NULL,save.model=FALSE,verbose=TRUE){ - #Pass input data and parameter list through error check / processing - data.check <- process.input(data,parameters.to.save,inits,n.chains,n.iter,n.burnin,n.thin,n.cores,DIC=DIC, - verbose=verbose,parallel=parallel,seed=seed) - data <- data.check$data - parameters.to.save <- data.check$params - inits <- data.check$inits - if(parallel){n.cores <- data.check$n.cores} + if(!is.null(seed)){ + stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE) + } + + # Check input data + inps_check <- process_input(data=data, params=parameters.to.save, inits=inits, + n_chains=n.chains, n_adapt=n.adapt, n_iter=n.iter, + n_burnin=n.burnin, n_thin=n.thin, n_cores=n.cores, + DIC=DIC, quiet=!verbose, parallel=parallel) + data <- inps_check$data + parameters.to.save <- inps_check$params + inits <- inps_check$inits + mcmc.info <- inps_check$mcmc.info + if(parallel) n.cores <- inps_check$mcmc.info$n.cores #Save start time start.time <- Sys.time() @@ -48,8 +55,7 @@ jags.basic <- function(data,inits=NULL,parameters.to.save,model.file,n.chains,n. } #Get more info about MCMC run - end.time <- Sys.time() - time <- round(as.numeric(end.time-start.time,units="mins"),digits=3) + time <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3) if(verbose){cat('MCMC took',time,'minutes.\n')} if(save.model){ diff --git a/R/process_input.R b/R/process_input.R new file mode 100644 index 0000000..39a2558 --- /dev/null +++ b/R/process_input.R @@ -0,0 +1,155 @@ +# Process input---------------------------------------------------------------- +process_input <- function(data, params, inits, n_chains, n_adapt, n_iter, n_burnin, + n_thin, n_cores, DIC, quiet, parallel){ + + if(!quiet){cat('\nProcessing function input.......','\n')} + out <- list(data = check_data(data, quiet), + params = check_params(params, DIC), + inits = check_inits(inits, n_chains), + mcmc.info = check_mcmc_info(n_chains, n_adapt, n_iter, n_burnin, + n_thin, n_cores, parallel) + ) + if(!quiet){cat('\nDone.','\n','\n')} + out +} + + +# Check data list-------------------------------------------------------------- +check_data <- function(inp_data, quiet){ + # Check data is a list + if(!is.list(inp_data)){ + stop("Input data must be a named list", call.=FALSE) + } + # Check list is named + nms <- names(inp_data) + if(is.null(nms) || any(nms == "")){ + stop("All elements of the input data list must be named", call.=FALSE) + } + + # Check individual data elements + out <- lapply(1:length(inp_data), function(i){ + check_data_element(inp_data[[i]], nms[i], quiet) + }) + names(out) <- nms + out +} + + +# Check individual data elements----------------------------------------------- +check_data_element <- function(x, name, quiet){ + # Stop if element is a factor + if(is.factor(x)){ + stop("Element", name, "is a factor. Convert it to numeric.", call.=FALSE) + } + # Try to convert data frame to matrix/vector + if(is.data.frame(x)){ + # Old versions attempted to convert a 1-column data frame to vector + # but did it incorrectly (always converted to matrix). + # This behavior was preserved here. + x <- as.matrix(x) + if(!is.numeric(x)){ + stop("Could not convert data.frame ", name, " to numeric matrix", call.=FALSE) + } + if(!quiet) cat("\nConverted data.frame", name, "to matrix\n") + } + # Final check if element is numeric + if(!is.numeric(x)){ + stop("Element ", name, " is not numeric", call.=FALSE) + } + x +} + + +# Check parameter vector------------------------------------------------------- +check_params <- function(params, DIC){ + if(!(is.vector(params) & is.character(params))){ + stop("parameters.to.save must be a character vector", call.=FALSE) + } + if(DIC & (! "deviance" %in% params)){ + params <- c(params, "deviance") + } + params +} + + +# Check mcmc settings---------------------------------------------------------- +#names(mcmc.info) <- c('n.chains','n.adapt','sufficient.adapt','n.iter','n.burnin','n.thin','n.samples','end.values','elapsed.mins') + +check_mcmc_info <- function(n_chains, n_adapt, n_iter, n_burnin, + n_thin, n_cores, parallel){ + + if(n_iter <= n_burnin){ + stop("Number of iterations must be larger than burn-in", call.=FALSE) + } + # Removed warnings about small numbers of iterations and uneven iterations + + n_cores = check_cores(n_cores, n_chains, parallel) + + # Create list structure and save available elements + out <- list(n.chains = n_chains, n.adapt = n_adapt, sufficient.adapt = NA, + n.iter = n_iter, n.burnin = n_burnin, n.thin = n_thin, + n.samples = NA, end.values = NA, elapsed.mins = NA) + if(parallel) out$n.cores <- n_cores + out +} + + +# Check number of cores-------------------------------------------------------- +check_cores <- function(n_cores, n_chains, parallel){ + if(!parallel) return(NULL) + max_cores <- parallel::detectCores() + # Send this warning right away + old_warn <- options()$warn + options(warn=1) + if(!is.null(n_cores)){ + if(n_cores > max_cores){ + n_cores <- max_cores + warning("n.cores > max available cores. n.cores set to ", max_cores, + call.=FALSE) + } + } else { + if(!is.na(max_cores)){ + n_cores <- min(max_cores, n_chains) + } else { + warning("Couldn't detect number of cores. Setting n.cores = n.chains", + call.=FALSE) + n_cores <- n_chains + } + } + options(warn=old_warn) + n_cores +} + + +# Check initial values--------------------------------------------------------- +check_inits <- function(inits, n_chains){ + if(is.list(inits)){ + if(length(inits) != n_chains){ + stop("inits list must have length equal to the number of chains", + call.=FALSE) + } + } else if(is.function(inits)){ + inits <- lapply(1:n_chains, function(x) inits()) + if(!is.list(inits[[1]])){ + stop("inits function must return list", call.=FALSE) + } + } else if(is.null(inits)){ + inits <- vector("list", n_chains) + } else { + stop("inits must be a list or a function that returns a list", call.=FALSE) + } + + # Setup seeds in each chain, for reproducibility + # Check if they already exist + has_RNG <- all(c(".RNG.name",".RNG.seed") %in% names(inits[[1]])) + # If not add them + if(!has_RNG){ + # Generate random seeds for each chain + chain_seeds <- floor(runif(n_chains, 1, 100000)) + for (i in 1:n_chains){ + inits[[i]]$.RNG.name="base::Mersenne-Twister" + inits[[i]]$.RNG.seed=chain_seeds[i] + } + } + inits +} diff --git a/R/processinput.R b/R/processinput.R deleted file mode 100644 index c208b07..0000000 --- a/R/processinput.R +++ /dev/null @@ -1,110 +0,0 @@ - -process.input = function(x,y,inits,n.chains,n.iter,n.burnin,n.thin,n.cores,DIC=FALSE,autojags=FALSE,max.iter=NULL, - verbose=TRUE,parallel=FALSE,seed=NULL){ - if(verbose){cat('\nProcessing function input.......','\n')} - - #Quality control - if(n.iter<=n.burnin){ - stop('Number of iterations must be larger than burn-in.\n') - } - - if(parallel){ - #Set number of clusters - p <- detectCores() - if(is.null(n.cores)){ - if(is.na(p)){ - p <- n.chains - if(verbose){ - options(warn=1) - warning('Could not detect number of cores on the machine. Defaulting to cores used = number of chains.') - options(warn=0,error=NULL) - } - } - n.cores <- min(p,n.chains) - } else { - if(n.cores>p){ - if(verbose){ - options(warn=1) - warning(paste('You have specified more cores (',n.cores,') than the available number of cores on this machine (',p,').\nReducing n.cores to max of ',p,'.',sep="")) - options(warn=0,error=NULL) - } - n.cores <- p - } - } - } - - if(autojags){ - if(n.chains<2){stop('Number of chains must be >1 to calculate Rhat.')} - if(max.iter<n.burnin&verbose){ - options(warn=1) - warning('Maximum iterations includes burn-in and should be larger than burn-in.') - options(warn=0,error=NULL) - } - } - - if(n.thin>1&&(n.iter-n.burnin)<10&&verbose){ - options(warn=1) - warning('The number of iterations is very low; jagsUI may crash. Recommend reducing n.thin to 1 and/or increasing n.iter.') - options(warn=0,error=NULL) - } - - final.chain.length <- (n.iter - n.burnin) / n.thin - even.length <- floor(final.chain.length) == final.chain.length - if(!even.length&verbose){ - options(warn=1) - warning('Number of iterations saved after thinning is not an integer; JAGS will round it up.') - options(warn=0,error=NULL) - } - - #Check if supplied parameter vector is the right format - if((is.character(y)&is.vector(y))){ - } else{stop('The parameters to save must be a vector containing only character strings.\n')} - - #If DIC requested, add deviance to parameters (if not already there) - if(DIC&&(!'deviance'%in%y)){ - params <- c(y,"deviance") - } else {params <- y} - - #Check if supplied data object is the proper format - if(is.list(x)||(is.character(x)&is.vector(x))){ - } else{stop('Input data must be a list of data objects OR a vector of data object names (as strings)\n')} - - if(is.list(x)&&all(sapply(x,is.character))){ - warning("Suppling a list of character strings to the data argument will be deprecated in the future") - x = unlist(x) - } - - if((is.list(x)&&is.null(names(x)))||(is.list(x)&&any(names(x)==""))){ - stop('At least one of the elements in your data list does not have a name\n') - } - - #Convert a supplied vector of characters to a list of data objects - if((is.character(x)&is.vector(x))){ - warning("Suppling a character vector to the data argument will be deprecated in the future") - temp = lapply(x,get,envir = parent.frame(2)) - names(temp) = x - x = temp - } - - #Check each component of data object for issues and fix if possible - for (i in 1:length(x)){ - - if(is.factor(x[[i]])){ - - stop('\nElement \'',names(x[i]) ,'\' in the data list is a factor.','\n','Convert it to a series of dummy/indicator variables or a numeric vector as appropriate.\n') - - } - - process <- data.check(x[[i]],name = names(x[i]),verbose=verbose) - if(!is.na(process[1])&&process[1]=="error"){stop('\nElement \'',names(x[i]) ,'\' in the data list cannot be coerced to one of the','\n','allowed formats (numeric scalar, vector, matrix, or array)\n') - } else{x[[i]] <- process} - - } - - #Get initial values - init.vals <- gen.inits(inits,n.chains,seed,parallel) - - if(verbose){cat('\nDone.','\n','\n')} - return(list(data=x,params=params,inits=init.vals,n.cores=n.cores)) - -} diff --git a/inst/tinytest/test_input_processing.R b/inst/tinytest/test_input_processing.R new file mode 100644 index 0000000..18affd0 --- /dev/null +++ b/inst/tinytest/test_input_processing.R @@ -0,0 +1,193 @@ +process_input <- jagsUI:::process_input +check_inits <- jagsUI:::check_inits + +# Overall structure------------------------------------------------------------ +data1 <- list(a=1, b=c(1,2), c=matrix(rnorm(4), 2,2), + d=array(rnorm(8), c(2,2,2)), e=c(NA, 1)) +test <- process_input(data1, params="a", NULL, 2, 1, 100, 50, 2, + NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE) +expect_inherits(test, "list") +expect_equal(names(test), c("data", "params", "inits", "mcmc.info")) + +# Data processing-------------------------------------------------------------- +# Stuff that gets passed through unchanged +data1 <- list(a=1, b=c(1,2), c=matrix(rnorm(4), 2,2), + d=array(rnorm(8), c(2,2,2)), e=c(NA, 1)) +test <- process_input(data1, params="a", NULL, 2, 1, 100, 50, 2, + NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE) +expect_identical(test$data, data1) + +# Data frame handling +data2 <- list(a=data.frame(v1=c(1,2)), b=data.frame(v1=c(0,1), v2=c(2,3))) +co <- capture.output( +test <- process_input(data2, params="a", NULL, 2, 1, 100, 50, 2, + NULL, DIC=TRUE, quiet=FALSE, parallel=FALSE) +) +ref_msg <- c("", "Processing function input....... ", "", + "Converted data.frame a to matrix","", + "Converted data.frame b to matrix", "", "Done. ", " ") +expect_equal(co, ref_msg) +expect_equivalent(test$data, list(a=matrix(c(1,2), ncol=1), + b=matrix(c(0:3), ncol=2))) + +# non-numeric data frame errors +data2$v3 <- data.frame(v1=c("a","b")) +expect_error(process_input(data2, params="a", NULL, 2, 1, 100, 50, 2, + NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE)) + +# Factor in data +data3 <- list(a=1, b=factor(c("1","2"))) +expect_error(process_input(data3, params="a", NULL, 2, 1, 100, 50, 2, + NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE)) + +# Character in data +data4 <- list(a=1, b=c("a","b")) +expect_error(process_input(data4, params="a", NULL, 2, 1, 100, 50, 2, + NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE)) + +# Vector with attributes is allowed +vec <- c(1,2) +attr(vec, "test") <- "test" +expect_false(is.vector(vec)) +data5 <- list(vec=vec) +test <- process_input(data5, params="a", NULL, 2, 1, 100, 50, 2, + NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE) +expect_equal(data5, test$data) + +# One-column matrix is not converted to vector +data6 <- list(a=matrix(c(1,2), ncol=1)) +test <- process_input(data6, params="a", NULL, 2, 1, 100, 50, 2, + NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE) +expect_equal(test$data, data6) + +# Non-list as input errors +t1 <- 1; t2 <- 2 +data7 <- c("t1", "t2") +expect_error(process_input(data7, params="a", NULL, 2, 1, 100, 50, 2, + NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE)) + +# List without names as input errors +data8 <- list(1, 2) +expect_error(process_input(data8, params="a", NULL, 2, 1, 100, 50, 2, + NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE)) + + +# Parameter vector processing-------------------------------------------------- +dat <- list(a=1, b=2) +pars1 <- c("a", "b") +pars2 <- c("deviance","a", "b") + +# DIC = FALSE +test <- process_input(dat, params=pars1, NULL, 2, 1, 100, 50, 2, + NULL, DIC=FALSE, quiet=TRUE, parallel=FALSE) +expect_equal(pars1, test$params) + +# DIC = TRUE +test <- process_input(dat, params=pars1, NULL, 2, 1, 100, 50, 2, + NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE) +expect_equal(c(pars1, "deviance"), test$params) + +# Deviance already in vector +test <- process_input(dat, params=pars2, NULL, 2, 1, 100, 50, 2, + NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE) +expect_equal(pars2, test$params) + +# Incorrect format +expect_error(process_input(dat, params=c(1,2), NULL, 2, 1, 100, 50, 2, + NULL, DIC=FALSE, quiet=TRUE, parallel=FALSE)) + + +# MCMC info processing--------------------------------------------------------- +dat <- list(a=1, b=2) +pars1 <- c("a", "b") +# n.iter/n.burnin mismatch +expect_error(process_input(dat, params=pars1, NULL, 2, 1, n_iter=100, n_burnin=100, 2, + NULL, DIC=FALSE, quiet=TRUE, parallel=FALSE)) +expect_error(process_input(dat, params=pars1, NULL, 2, 1, n_iter=100, n_burnin=150, 2, + NULL, DIC=FALSE, quiet=TRUE, parallel=FALSE)) + +# n.cores +# when parallel=FALSE +test <- process_input(dat, params=pars1, NULL, 2, 1, n_iter=100, n_burnin=50, 2, + n_cores=NULL, DIC=FALSE, quiet=TRUE, parallel=FALSE) +expect_true(is.null(test$mcmc.info$n.cores)) + +# when parallel=TRUE defaults to min of nchains and ncores +test <- process_input(dat, params=pars1, NULL, 2, 1, n_iter=100, n_burnin=50, 2, + n_cores=NULL, DIC=FALSE, quiet=TRUE, parallel=TRUE) +expect_equal(test$mcmc.info$n.cores, 2) + +avail_cores <- parallel::detectCores() +if(avail_cores > 1){ + try_cores <- avail_cores + 1 + n_chain <- try_cores + expect_warning(nul <- capture.output( + test <- process_input(dat, params=pars1, NULL, n_chains=n_chain, 1, + n_iter=100, n_burnin=50, 2, + n_cores=try_cores, DIC=FALSE, quiet=TRUE, parallel=TRUE) + )) + expect_equal(test$mcmc.info$n.cores, avail_cores) +} + +# Initial value processing----------------------------------------------------- +inits1 <- NULL +inits2 <- list(list(a=1, b=2), list(a=3, b=4)) +inits3 <- list(a=1, b=2) +inits4 <- function() list(a=1, b=2) +inits5 <- function() list() +inits6 <- function() 1 +inits7 <- 1 + +# No inits provided +set.seed(123) +test <- check_inits(inits1, n_chains=2) +ref <- list(list(.RNG.name = "base::Mersenne-Twister", .RNG.seed = 28758), + list(.RNG.name = "base::Mersenne-Twister", .RNG.seed = 78830)) +expect_identical(test, ref) + +# A list of lists +set.seed(123) +test <- check_inits(inits2, n_chains=2) +ref <- list(list(a = 1, b = 2, .RNG.name = "base::Mersenne-Twister", + .RNG.seed = 28758), list(a = 3, b = 4, .RNG.name = "base::Mersenne-Twister", + .RNG.seed = 78830)) +expect_identical(test, ref) +# Wrong number of list elements for number of chains +expect_error(check_inits(inits2, n_chains=3)) + +# A single list +expect_error(check_inits(inits3, n_chains=1)) + +# A function +set.seed(123) +test <- check_inits(inits4, n_chains=2) +ref <- list(list(a = 1, b = 2, .RNG.name = "base::Mersenne-Twister", + .RNG.seed = 28758), list(a = 1, b = 2, .RNG.name = "base::Mersenne-Twister", + .RNG.seed = 78830)) +expect_identical(test, ref) + +# An empty list +set.seed(123) +test <- check_inits(inits5, n_chains=2) +ref <- list(list(.RNG.name = "base::Mersenne-Twister", .RNG.seed = 28758), + list(.RNG.name = "base::Mersenne-Twister", .RNG.seed = 78830)) +expect_identical(test, ref) + +# Function but doesn't return list +expect_error(check_inits(inits6, n_chains=2)) + +# A number +expect_error(check_inits(inits7, n_chains=2)) + +# Check exact match when inits is a function with random numbers +set.seed(123) +inits_fun <- function() list(a = rnorm(1), b=rnorm(1)) +#inits_ref <- jagsUI:::gen.inits(inits_fun, seed=NULL, 2) +inits_ref <- list(list(a = -0.560475646552213, b = -0.23017748948328, + .RNG.name = "base::Mersenne-Twister", + .RNG.seed = 55143), list(a = 1.55870831414912, b = 0.070508391424576, + .RNG.name = "base::Mersenne-Twister", .RNG.seed = 45662)) + +set.seed(123) +test <- check_inits(inits_fun, 2) +expect_equal(inits_ref, test) diff --git a/inst/tinytest/test_jags.R b/inst/tinytest/test_jags.R index 1b30505..8295258 100644 --- a/inst/tinytest/test_jags.R +++ b/inst/tinytest/test_jags.R @@ -134,6 +134,18 @@ out <- jags(data = data, inits = inits, parameters.to.save = params, codaOnly = params) expect_equal(nrow(out$summary), 0) +# Saved data and inits--------------------------------------------------------- +set.seed(123) +run_inits <- jagsUI:::check_inits(inits, 3) + +set.seed(123) +out <- jags(data = data, inits = inits, + parameters.to.save = c("alpha","beta"), + model.file = modfile, n.chains = 3, n.adapt = 100, n.iter = 100, + n.burnin = 50, n.thin = 1, verbose=FALSE, store.data=TRUE) +expect_identical(out$data, data) +expect_identical(out$inits, run_inits) + # Check recovery after process_output errors----------------------------------- # Setting DIC to -999 forces process_output to error for testing expect_message(out <- jags(data = data, inits = inits, @@ -153,6 +165,10 @@ expect_true(all(is.na(out$summary[,"Rhat"]))) expect_true(all(is.na(out$summary[,"n.eff"]))) expect_true(all(out$summary["alpha",3:7] == out$summary["alpha",3])) +# Error when user tries to set seed-------------------------------------------- +expect_error(jags(data = data, inits = inits, parameters.to.save = params, + model.file = modfile, n.chains = 1, n.adapt = 100, n.iter = 100, + n.burnin = 50, n.thin = 1, DIC = FALSE, verbose=FALSE, seed=123)) # Single parameter slice------------------------------------------------------- set.seed(123) diff --git a/inst/tinytest/test_jagsbasic.R b/inst/tinytest/test_jagsbasic.R index 1e03fc4..927bd53 100644 --- a/inst/tinytest/test_jagsbasic.R +++ b/inst/tinytest/test_jagsbasic.R @@ -48,3 +48,8 @@ ref <- readRDS('jagsbasic_ref_update.Rds') expect_identical(names(out2), names(ref)) out2$model <- ref$model expect_equal(out2, ref) + +# Error if seed is set +expect_error(jags.basic(data = data, inits = inits, parameters.to.save = params, + model.file = modfile, n.chains = 3, n.adapt = 100, n.iter = 100, + n.burnin = 50, n.thin = 2, verbose=FALSE, save.model=TRUE, seed=123)) |