diff options
author | Ken Kellner <ken@kenkellner.com> | 2023-12-05 15:03:01 -0500 |
---|---|---|
committer | Ken Kellner <ken@kenkellner.com> | 2023-12-05 15:07:52 -0500 |
commit | 0ea3e15296e68a51531e9dca142dbe9a1a376b3e (patch) | |
tree | 1971e51cd5ed9b71074c263ea51814cfdb5bd7b0 | |
parent | e06644ef54978939f7b56839860ac6af57bf4935 (diff) |
Use new input processing with jags
-rw-r--r-- | R/jags.R | 43 | ||||
-rw-r--r-- | R/process_input.R | 16 | ||||
-rw-r--r-- | inst/tinytest/test_input_processing.R | 21 |
3 files changed, 48 insertions, 32 deletions
@@ -4,13 +4,20 @@ jagsUI <- jags <- function(data,inits=NULL,parameters.to.save,model.file,n.chain modules=c('glm'),factories=NULL,parallel=FALSE,n.cores=NULL,DIC=TRUE,store.data=FALSE,codaOnly=FALSE,seed=NULL, bugs.format=FALSE,verbose=TRUE){ - #Pass input data and parameter list through error check / processing - data.check <- process.input(data,parameters.to.save,inits,n.chains,n.iter,n.burnin,n.thin,n.cores,DIC=DIC, - verbose=verbose,parallel=parallel,seed=seed) - data <- data.check$data - parameters.to.save <- data.check$params - inits <- data.check$inits - if(parallel){n.cores <- data.check$n.cores} + if(!is.null(seed)){ + stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE) + } + + # Check input data + inps_check <- process_input(data=data, params=parameters.to.save, inits=inits, + n_chains=n.chains, n_adapt=n.adapt, n_iter=n.iter, + n_burnin=n.burnin, n_thin=n.thin, n_cores=n.cores, + DIC=DIC, quiet=!verbose, parallel=parallel) + data <- inps_check$data + parameters.to.save <- inps_check$params + inits <- inps_check$inits + mcmc.info <- inps_check$mcmc.info + if(parallel) n.cores <- inps_check$mcmc.info$n.cores #Save start time start.time <- Sys.time() @@ -48,18 +55,13 @@ jagsUI <- jags <- function(data,inits=NULL,parameters.to.save,model.file,n.chain } - #Get more info about MCMC run - end.time <- Sys.time() - time <- round(as.numeric(end.time-start.time,units="mins"),digits=3) - date <- start.time - - #Combine mcmc info into list - n.samples <- dim(samples[[1]])[1] * n.chains - end.values <- samples[(n.samples/n.chains),] - mcmc.info <- list(n.chains,n.adapt=total.adapt,sufficient.adapt,n.iter,n.burnin,n.thin,n.samples,end.values,time) - names(mcmc.info) <- c('n.chains','n.adapt','sufficient.adapt','n.iter','n.burnin','n.thin','n.samples','end.values','elapsed.mins') - if(parallel){mcmc.info$n.cores <- n.cores} - + #Add mcmc info into list + mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3) + mcmc.info$n.samples <- coda::niter(samples) * n.chains + mcmc.info$end.values <- samples[coda::niter(samples),] + mcmc.info$n.adapt <- total.adapt + mcmc.info$sufficient.adapt <- sufficient.adapt + #Reorganize JAGS output to match input parameter order samples <- order_samples(samples, parameters.to.save) @@ -85,8 +87,7 @@ jagsUI <- jags <- function(data,inits=NULL,parameters.to.save,model.file,n.chain output$model <- m output$parameters <- parameters.to.save output$mcmc.info <- mcmc.info - output$run.date <- date - output$random.seed <- seed + output$run.date <- start.time output$parallel <- parallel output$bugs.format <- bugs.format output$calc.DIC <- DIC diff --git a/R/process_input.R b/R/process_input.R index c3f219d..39a2558 100644 --- a/R/process_input.R +++ b/R/process_input.R @@ -5,10 +5,9 @@ process_input <- function(data, params, inits, n_chains, n_adapt, n_iter, n_burn if(!quiet){cat('\nProcessing function input.......','\n')} out <- list(data = check_data(data, quiet), params = check_params(params, DIC), - n.cores = check_cores(n_cores, n_chains, parallel), inits = check_inits(inits, n_chains), mcmc.info = check_mcmc_info(n_chains, n_adapt, n_iter, n_burnin, - n_thin) + n_thin, n_cores, parallel) ) if(!quiet){cat('\nDone.','\n','\n')} out @@ -19,7 +18,7 @@ process_input <- function(data, params, inits, n_chains, n_adapt, n_iter, n_burn check_data <- function(inp_data, quiet){ # Check data is a list if(!is.list(inp_data)){ - stop("Input data must be a list", call.=FALSE) + stop("Input data must be a named list", call.=FALSE) } # Check list is named nms <- names(inp_data) @@ -84,10 +83,14 @@ check_mcmc_info <- function(n_chains, n_adapt, n_iter, n_burnin, } # Removed warnings about small numbers of iterations and uneven iterations + n_cores = check_cores(n_cores, n_chains, parallel) + # Create list structure and save available elements - list(n.chains = n_chains, n.adapt = n_adapt, sufficient_adapt = NA, + out <- list(n.chains = n_chains, n.adapt = n_adapt, sufficient.adapt = NA, n.iter = n_iter, n.burnin = n_burnin, n.thin = n_thin, n.samples = NA, end.values = NA, elapsed.mins = NA) + if(parallel) out$n.cores <- n_cores + out } @@ -126,11 +129,10 @@ check_inits <- function(inits, n_chains){ call.=FALSE) } } else if(is.function(inits)){ - test <- inits() - if(!is.list(test)){ + inits <- lapply(1:n_chains, function(x) inits()) + if(!is.list(inits[[1]])){ stop("inits function must return list", call.=FALSE) } - inits <- lapply(1:n_chains, function(x) inits()) } else if(is.null(inits)){ inits <- vector("list", n_chains) } else { diff --git a/inst/tinytest/test_input_processing.R b/inst/tinytest/test_input_processing.R index 28f479a..18affd0 100644 --- a/inst/tinytest/test_input_processing.R +++ b/inst/tinytest/test_input_processing.R @@ -7,7 +7,7 @@ data1 <- list(a=1, b=c(1,2), c=matrix(rnorm(4), 2,2), test <- process_input(data1, params="a", NULL, 2, 1, 100, 50, 2, NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE) expect_inherits(test, "list") -expect_equal(names(test), c("data", "params", "n.cores", "inits", "mcmc.info")) +expect_equal(names(test), c("data", "params", "inits", "mcmc.info")) # Data processing-------------------------------------------------------------- # Stuff that gets passed through unchanged @@ -110,12 +110,12 @@ expect_error(process_input(dat, params=pars1, NULL, 2, 1, n_iter=100, n_burnin=1 # when parallel=FALSE test <- process_input(dat, params=pars1, NULL, 2, 1, n_iter=100, n_burnin=50, 2, n_cores=NULL, DIC=FALSE, quiet=TRUE, parallel=FALSE) -expect_true(is.null(test$n.cores)) +expect_true(is.null(test$mcmc.info$n.cores)) # when parallel=TRUE defaults to min of nchains and ncores test <- process_input(dat, params=pars1, NULL, 2, 1, n_iter=100, n_burnin=50, 2, n_cores=NULL, DIC=FALSE, quiet=TRUE, parallel=TRUE) -expect_equal(test$n.cores, 2) +expect_equal(test$mcmc.info$n.cores, 2) avail_cores <- parallel::detectCores() if(avail_cores > 1){ @@ -126,7 +126,7 @@ if(avail_cores > 1){ n_iter=100, n_burnin=50, 2, n_cores=try_cores, DIC=FALSE, quiet=TRUE, parallel=TRUE) )) - expect_equal(test$n.cores, avail_cores) + expect_equal(test$mcmc.info$n.cores, avail_cores) } # Initial value processing----------------------------------------------------- @@ -178,3 +178,16 @@ expect_error(check_inits(inits6, n_chains=2)) # A number expect_error(check_inits(inits7, n_chains=2)) + +# Check exact match when inits is a function with random numbers +set.seed(123) +inits_fun <- function() list(a = rnorm(1), b=rnorm(1)) +#inits_ref <- jagsUI:::gen.inits(inits_fun, seed=NULL, 2) +inits_ref <- list(list(a = -0.560475646552213, b = -0.23017748948328, + .RNG.name = "base::Mersenne-Twister", + .RNG.seed = 55143), list(a = 1.55870831414912, b = 0.070508391424576, + .RNG.name = "base::Mersenne-Twister", .RNG.seed = 45662)) + +set.seed(123) +test <- check_inits(inits_fun, 2) +expect_equal(inits_ref, test) |