aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKen Kellner <ken@kenkellner.com>2023-12-05 15:03:01 -0500
committerKen Kellner <ken@kenkellner.com>2023-12-05 15:07:52 -0500
commit0ea3e15296e68a51531e9dca142dbe9a1a376b3e (patch)
tree1971e51cd5ed9b71074c263ea51814cfdb5bd7b0
parente06644ef54978939f7b56839860ac6af57bf4935 (diff)
Use new input processing with jags
-rw-r--r--R/jags.R43
-rw-r--r--R/process_input.R16
-rw-r--r--inst/tinytest/test_input_processing.R21
3 files changed, 48 insertions, 32 deletions
diff --git a/R/jags.R b/R/jags.R
index 9523575..5af4a6c 100644
--- a/R/jags.R
+++ b/R/jags.R
@@ -4,13 +4,20 @@ jagsUI <- jags <- function(data,inits=NULL,parameters.to.save,model.file,n.chain
modules=c('glm'),factories=NULL,parallel=FALSE,n.cores=NULL,DIC=TRUE,store.data=FALSE,codaOnly=FALSE,seed=NULL,
bugs.format=FALSE,verbose=TRUE){
- #Pass input data and parameter list through error check / processing
- data.check <- process.input(data,parameters.to.save,inits,n.chains,n.iter,n.burnin,n.thin,n.cores,DIC=DIC,
- verbose=verbose,parallel=parallel,seed=seed)
- data <- data.check$data
- parameters.to.save <- data.check$params
- inits <- data.check$inits
- if(parallel){n.cores <- data.check$n.cores}
+ if(!is.null(seed)){
+ stop("The seed argument is no longer supported, use set.seed() instead", call.=FALSE)
+ }
+
+ # Check input data
+ inps_check <- process_input(data=data, params=parameters.to.save, inits=inits,
+ n_chains=n.chains, n_adapt=n.adapt, n_iter=n.iter,
+ n_burnin=n.burnin, n_thin=n.thin, n_cores=n.cores,
+ DIC=DIC, quiet=!verbose, parallel=parallel)
+ data <- inps_check$data
+ parameters.to.save <- inps_check$params
+ inits <- inps_check$inits
+ mcmc.info <- inps_check$mcmc.info
+ if(parallel) n.cores <- inps_check$mcmc.info$n.cores
#Save start time
start.time <- Sys.time()
@@ -48,18 +55,13 @@ jagsUI <- jags <- function(data,inits=NULL,parameters.to.save,model.file,n.chain
}
- #Get more info about MCMC run
- end.time <- Sys.time()
- time <- round(as.numeric(end.time-start.time,units="mins"),digits=3)
- date <- start.time
-
- #Combine mcmc info into list
- n.samples <- dim(samples[[1]])[1] * n.chains
- end.values <- samples[(n.samples/n.chains),]
- mcmc.info <- list(n.chains,n.adapt=total.adapt,sufficient.adapt,n.iter,n.burnin,n.thin,n.samples,end.values,time)
- names(mcmc.info) <- c('n.chains','n.adapt','sufficient.adapt','n.iter','n.burnin','n.thin','n.samples','end.values','elapsed.mins')
- if(parallel){mcmc.info$n.cores <- n.cores}
-
+ #Add mcmc info into list
+ mcmc.info$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3)
+ mcmc.info$n.samples <- coda::niter(samples) * n.chains
+ mcmc.info$end.values <- samples[coda::niter(samples),]
+ mcmc.info$n.adapt <- total.adapt
+ mcmc.info$sufficient.adapt <- sufficient.adapt
+
#Reorganize JAGS output to match input parameter order
samples <- order_samples(samples, parameters.to.save)
@@ -85,8 +87,7 @@ jagsUI <- jags <- function(data,inits=NULL,parameters.to.save,model.file,n.chain
output$model <- m
output$parameters <- parameters.to.save
output$mcmc.info <- mcmc.info
- output$run.date <- date
- output$random.seed <- seed
+ output$run.date <- start.time
output$parallel <- parallel
output$bugs.format <- bugs.format
output$calc.DIC <- DIC
diff --git a/R/process_input.R b/R/process_input.R
index c3f219d..39a2558 100644
--- a/R/process_input.R
+++ b/R/process_input.R
@@ -5,10 +5,9 @@ process_input <- function(data, params, inits, n_chains, n_adapt, n_iter, n_burn
if(!quiet){cat('\nProcessing function input.......','\n')}
out <- list(data = check_data(data, quiet),
params = check_params(params, DIC),
- n.cores = check_cores(n_cores, n_chains, parallel),
inits = check_inits(inits, n_chains),
mcmc.info = check_mcmc_info(n_chains, n_adapt, n_iter, n_burnin,
- n_thin)
+ n_thin, n_cores, parallel)
)
if(!quiet){cat('\nDone.','\n','\n')}
out
@@ -19,7 +18,7 @@ process_input <- function(data, params, inits, n_chains, n_adapt, n_iter, n_burn
check_data <- function(inp_data, quiet){
# Check data is a list
if(!is.list(inp_data)){
- stop("Input data must be a list", call.=FALSE)
+ stop("Input data must be a named list", call.=FALSE)
}
# Check list is named
nms <- names(inp_data)
@@ -84,10 +83,14 @@ check_mcmc_info <- function(n_chains, n_adapt, n_iter, n_burnin,
}
# Removed warnings about small numbers of iterations and uneven iterations
+ n_cores = check_cores(n_cores, n_chains, parallel)
+
# Create list structure and save available elements
- list(n.chains = n_chains, n.adapt = n_adapt, sufficient_adapt = NA,
+ out <- list(n.chains = n_chains, n.adapt = n_adapt, sufficient.adapt = NA,
n.iter = n_iter, n.burnin = n_burnin, n.thin = n_thin,
n.samples = NA, end.values = NA, elapsed.mins = NA)
+ if(parallel) out$n.cores <- n_cores
+ out
}
@@ -126,11 +129,10 @@ check_inits <- function(inits, n_chains){
call.=FALSE)
}
} else if(is.function(inits)){
- test <- inits()
- if(!is.list(test)){
+ inits <- lapply(1:n_chains, function(x) inits())
+ if(!is.list(inits[[1]])){
stop("inits function must return list", call.=FALSE)
}
- inits <- lapply(1:n_chains, function(x) inits())
} else if(is.null(inits)){
inits <- vector("list", n_chains)
} else {
diff --git a/inst/tinytest/test_input_processing.R b/inst/tinytest/test_input_processing.R
index 28f479a..18affd0 100644
--- a/inst/tinytest/test_input_processing.R
+++ b/inst/tinytest/test_input_processing.R
@@ -7,7 +7,7 @@ data1 <- list(a=1, b=c(1,2), c=matrix(rnorm(4), 2,2),
test <- process_input(data1, params="a", NULL, 2, 1, 100, 50, 2,
NULL, DIC=TRUE, quiet=TRUE, parallel=FALSE)
expect_inherits(test, "list")
-expect_equal(names(test), c("data", "params", "n.cores", "inits", "mcmc.info"))
+expect_equal(names(test), c("data", "params", "inits", "mcmc.info"))
# Data processing--------------------------------------------------------------
# Stuff that gets passed through unchanged
@@ -110,12 +110,12 @@ expect_error(process_input(dat, params=pars1, NULL, 2, 1, n_iter=100, n_burnin=1
# when parallel=FALSE
test <- process_input(dat, params=pars1, NULL, 2, 1, n_iter=100, n_burnin=50, 2,
n_cores=NULL, DIC=FALSE, quiet=TRUE, parallel=FALSE)
-expect_true(is.null(test$n.cores))
+expect_true(is.null(test$mcmc.info$n.cores))
# when parallel=TRUE defaults to min of nchains and ncores
test <- process_input(dat, params=pars1, NULL, 2, 1, n_iter=100, n_burnin=50, 2,
n_cores=NULL, DIC=FALSE, quiet=TRUE, parallel=TRUE)
-expect_equal(test$n.cores, 2)
+expect_equal(test$mcmc.info$n.cores, 2)
avail_cores <- parallel::detectCores()
if(avail_cores > 1){
@@ -126,7 +126,7 @@ if(avail_cores > 1){
n_iter=100, n_burnin=50, 2,
n_cores=try_cores, DIC=FALSE, quiet=TRUE, parallel=TRUE)
))
- expect_equal(test$n.cores, avail_cores)
+ expect_equal(test$mcmc.info$n.cores, avail_cores)
}
# Initial value processing-----------------------------------------------------
@@ -178,3 +178,16 @@ expect_error(check_inits(inits6, n_chains=2))
# A number
expect_error(check_inits(inits7, n_chains=2))
+
+# Check exact match when inits is a function with random numbers
+set.seed(123)
+inits_fun <- function() list(a = rnorm(1), b=rnorm(1))
+#inits_ref <- jagsUI:::gen.inits(inits_fun, seed=NULL, 2)
+inits_ref <- list(list(a = -0.560475646552213, b = -0.23017748948328,
+ .RNG.name = "base::Mersenne-Twister",
+ .RNG.seed = 55143), list(a = 1.55870831414912, b = 0.070508391424576,
+ .RNG.name = "base::Mersenne-Twister", .RNG.seed = 45662))
+
+set.seed(123)
+test <- check_inits(inits_fun, 2)
+expect_equal(inits_ref, test)