diff options
author | Ken Kellner <ken@kenkellner.com> | 2023-12-06 15:27:27 -0500 |
---|---|---|
committer | Ken Kellner <ken@kenkellner.com> | 2023-12-06 15:27:27 -0500 |
commit | 41f58cca65a83bfe9d1d95b843060b7b236c9b55 (patch) | |
tree | a75ff2702e33b5f29cdfbbe63456ac65561ba222 | |
parent | 3f58d60478213394ba960813976ffa47640e65b2 (diff) |
update now uses new rjags interface, reorganize
-rw-r--r-- | R/rjags_interface.R | 276 | ||||
-rw-r--r-- | R/rjags_tools.R | 33 | ||||
-rw-r--r-- | R/runmodel.R | 105 | ||||
-rw-r--r-- | R/runparallel.R | 83 | ||||
-rw-r--r-- | R/setfactories.R | 20 | ||||
-rw-r--r-- | R/setmodules.R | 24 | ||||
-rw-r--r-- | R/update.R | 97 | ||||
-rw-r--r-- | R/updatebasic.R | 86 | ||||
-rw-r--r-- | inst/tinytest/jagsbasic_ref_saved.Rds | bin | 30553 -> 30572 bytes | |||
-rw-r--r-- | inst/tinytest/test_jagsbasic.R | 20 | ||||
-rw-r--r-- | inst/tinytest/test_update.R | 18 |
11 files changed, 381 insertions, 381 deletions
diff --git a/R/rjags_interface.R b/R/rjags_interface.R new file mode 100644 index 0000000..02a0ec9 --- /dev/null +++ b/R/rjags_interface.R @@ -0,0 +1,276 @@ +# Top-level function to run analysis via rjags--------------------------------- +run_rjags <- function(data, inits, params, modfile, mcmc_info, + modules, factories, DIC, parallel, quiet, + model.object=NULL, update=FALSE){ + + #Save start time + start.time <- Sys.time() + + mc <- mcmc_info + + # Run parallel + if(parallel & mc$n.chains > 1){ + result <- run.parallel(data, inits, params, modfile, + mc$n.chains, mc$n.adapt, mc$n.iter, mc$n.burnin, mc$n.thin, + modules=modules, factories=factories, DIC=DIC, + model.object=model.object, update=update, + verbose=!quiet, n.cores=mc$n.cores) + # Move this down if not also handled in runmodel() + if(any(!result$sufficient.adapt)){ + warning("JAGS reports adaptation was incomplete. Consider increasing n.adapt", call.=FALSE) + } + } else { + # Run non-parallel + set.modules(modules, DIC) + set.factories(factories) + + result <- run.model(modfile, data, inits, params, + mc$n.chains, mc$n.iter, mc$n.burnin, mc$n.thin, mc$n.adapt, + verbose=!quiet, model.object=model.object, update=update) + } + + result$run.date <- start.time + result$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3) + + result +} + +# Setup and run rjags---------------------------------------------------------- +run.model <- function(model.file=NULL,data=NULL,inits=NULL,parameters.to.save,n.chains=NULL, + n.iter,n.burnin,n.thin,n.adapt,verbose=TRUE,model.object=NULL, + update=FALSE,parallel=FALSE,na.rm=TRUE){ + +if(verbose){pb="text"} else {pb="none"} + +if(update){ + #Recompile model + m <- model.object + if(verbose | parallel==TRUE){ + m$recompile() + } else {null <- capture.output( + m$recompile() + )} + +} else { + #Compile model + if(verbose | parallel==TRUE){ + m <- jags.model(file=model.file,data=data,inits=inits,n.chains=n.chains,n.adapt=0) + } else { + null <- capture.output( + m <- jags.model(file=model.file,data=data,inits=inits,n.chains=n.chains,n.adapt=0,quiet=TRUE) + ) + } +} + +#Adaptive phase using adapt() +total.adapt <- 0 + +if(!is.null(n.adapt)){ + if(n.adapt>0){ + if(verbose){ + cat('Adaptive phase,',n.adapt,'iterations x',n.chains,'chains','\n') + cat('If no progress bar appears JAGS has decided not to adapt','\n','\n') + sufficient.adapt <- adapt(object=m,n.iter=n.adapt,progress.bar=pb,end.adaptation=TRUE) + } else { + null <- capture.output( + sufficient.adapt <- adapt(object=m,n.iter=n.adapt,progress.bar=pb,end.adaptation=TRUE) + )} + total.adapt <- total.adapt + n.adapt + } else{ + if(verbose){cat('No adaptive period specified','\n','\n')} + #If no adaptation period specified: + #Force JAGS to not adapt (you have to allow it to adapt at least 1 iteration) + if(!update){ + if(verbose){ + sufficient.adapt <- adapt(object=m,n.iter=1,end.adaptation=TRUE) + } else { + null <- capture.output( + sufficient.adapt <- adapt(object=m,n.iter=1,end.adaptation=TRUE) + )} + } + total.adapt <- 0 + } +} else { + + maxloops <- 100 + n.adapt.iter <- 100 + + for (i in 1:maxloops){ + if(verbose){cat('Adaptive phase.....','\n')} + sufficient.adapt <- adapt(object=m,n.iter=n.adapt.iter,progress.bar='none') + total.adapt <- total.adapt + n.adapt.iter + if(i==maxloops){ + if(verbose){warning(paste("Reached max of",maxloops*n.adapt.iter,"adaption iterations; set n.adapt to > 10000"))} + null <- adapt(object=m,n.iter=1,end.adaptation = TRUE) + break + } + if(sufficient.adapt){ + null <- adapt(object=m,n.iter=1,end.adaptation = TRUE) + if(verbose){cat('Adaptive phase complete','\n','\n')} + break + } + } + +} +if(!sufficient.adapt&total.adapt!=0&verbose){warning("JAGS reports adaptation was incomplete. Consider increasing n.adapt")} + +#Burn-in phase using update() +if(n.burnin>0){ + if(verbose){ + cat('\n','Burn-in phase,',n.burnin,'iterations x',n.chains,'chains','\n','\n') + update(object=m,n.iter=n.burnin,progress.bar=pb) + cat('\n') + } else { + null <- capture.output( + update(object=m,n.iter=n.burnin,progress.bar=pb) + )} +} else if(verbose){cat('No burn-in specified','\n','\n')} + +#Sample from posterior using coda.samples() +if(verbose){ + cat('Sampling from joint posterior,',(n.iter-n.burnin),'iterations x',n.chains,'chains','\n','\n') + samples <- coda.samples(model=m,variable.names=parameters.to.save,n.iter=(n.iter-n.burnin),thin=n.thin, + na.rm=na.rm, progress.bar=pb) + cat('\n') +} else { + null <- capture.output( + samples <- coda.samples(model=m,variable.names=parameters.to.save,n.iter=(n.iter-n.burnin),thin=n.thin, + na.rm=na.rm, progress.bar=pb) + )} + +return(list(m=m,samples=samples,total.adapt=total.adapt,sufficient.adapt=sufficient.adapt)) +} + +# Setup parallel and run rjags------------------------------------------------- +run.parallel <- function(data=NULL,inits=NULL,parameters.to.save,model.file=NULL,n.chains,n.adapt,n.iter,n.burnin,n.thin, + modules,factories,DIC,model.object=NULL,update=FALSE,verbose=TRUE,n.cores=NULL) { + +#Save current library paths +current.libpaths <- .libPaths() + +#Set up clusters +cl = makeCluster(n.cores) +on.exit(stopCluster(cl)) +clusterExport(cl = cl, ls(), envir = environment()) +clusterEvalQ(cl,.libPaths(current.libpaths)) + +if(verbose){ +cat('Beginning parallel processing using',n.cores,'cores. Console output will be suppressed.\n')} + +#Function called in each core +jags.clust <- function(i){ + +#Load modules +set.modules(modules,DIC) +set.factories(factories) + +if(update){ + #Recompile model + cluster.mod <- model.object[[i]] + + #Run model + rjags.output <- run.model(model.file=NULL,data=NULL,inits=NULL,parameters.to.save,n.chains=1,n.iter,n.burnin=0,n.thin,n.adapt, + verbose=FALSE,model.object=cluster.mod,update=TRUE,parallel=TRUE, na.rm=FALSE) + +} else { + + #Set initial values for cluster + cluster.inits <- inits[[i]] + + #Run model + + rjags.output <- run.model(model.file,data,inits=cluster.inits,parameters.to.save,n.chains=1,n.iter, + n.burnin,n.thin,n.adapt,verbose=FALSE,parallel=TRUE, na.rm=FALSE) + +} + +return(list(samp=rjags.output$samples[[1]],mod=rjags.output$m,total.adapt=rjags.output$total.adapt,sufficient.adapt=rjags.output$sufficient.adapt)) + +} + +#Do analysis +par <- clusterApply(cl=cl,x=1:n.chains,fun=jags.clust) + +#Create empty lists +out <- samples <- model <- list() +total.adapt <- sufficient.adapt <- vector(length=n.chains) + +#Save samples and model objects from each cluster +total.adapt <- sapply(par, function(x) x[[3]]) +starts <- sapply(par, function(x) stats::start(x$samp)) +ends <- sapply(par, function(x) stats::end(x$samp)) +nsamp <- sapply(par, function(x) coda::niter(x$samp)) +thins <- sapply(par, function(x) coda::thin(x$samp)) +#mc_start <- max(total.adapt) + n.burnin + n.thin +stopifnot(all(nsamp == nsamp[1])) +stopifnot(all(thins == n.thin)) +for (i in 1:n.chains){ + samples[[i]] <- coda::mcmc(par[[i]]$samp,start=max(starts),thin=n.thin) + model[[i]] <- par[[i]]$m + sufficient.adapt[i] <- par[[i]]$sufficient.adapt +} +out$samples <- as.mcmc.list(samples) +# Remove columns with all NA +try({ + all_na <- apply(as.matrix(out$samples),2, function(x) all(is.na(x))) + out$samples <- out$samples[,!all_na,drop=FALSE] +}) +out$model <- model +out$total.adapt <- total.adapt +out$sufficient.adapt <- sufficient.adapt +names(out$model) <- sapply(1:length(out$model),function(i){paste('cluster',i,sep="")}) + +if(verbose){ +cat('\nParallel processing completed.\n\n') +} + +return(out) + +} + + +# Set factories---------------------------------------------------------------- +set.factories <- function(factories){ + + if(!is.null(factories)){ + for (i in 1:length(factories)){ + + split <- strsplit(factories[i],'\\s')[[1]] + + #Check if requested factory is available + faclist <- as.character(list.factories(split[2])[,1]) + if(split[1]%in%faclist){ + + null <- set.factory(split[1],split[2],split[3]) + + } else{stop(paste('Requested factory',split[1],'is not available. Check that appropriate modules are loaded.'))} + + } + } + +} + +# Set modules------------------------------------------------------------------ +set.modules <- function(modules,DIC){ + + #Load/unload appropriate modules (besides dic) + called.set <- c('basemod','bugs',modules) + current.set <- list.modules() + + load.set <- called.set[!called.set%in%current.set] + unload.set <- current.set[!current.set%in%called.set] + + if(length(load.set)>0){ + for (i in 1:length(load.set)){ + load.module(load.set[i],quiet=TRUE) + } + } + if(length(unload.set)>0){ + for (i in 1:length(unload.set)){ + unload.module(unload.set[i],quiet=TRUE) + } + } + if(DIC){ + load.module("dic",quiet=TRUE) + } +} diff --git a/R/rjags_tools.R b/R/rjags_tools.R deleted file mode 100644 index c39c592..0000000 --- a/R/rjags_tools.R +++ /dev/null @@ -1,33 +0,0 @@ -run_rjags <- function(data, inits, params, modfile, mcmc_info, - modules, factories, DIC, parallel, quiet){ - - #Save start time - start.time <- Sys.time() - - mc <- mcmc_info - - # Run parallel - if(parallel & mc$n.chains > 1){ - result <- run.parallel(data, inits, params, modfile, - mc$n.chains, mc$n.adapt, mc$n.iter, mc$n.burnin, mc$n.thin, - modules=modules, factories=factories, DIC=DIC, - verbose=!quiet, n.cores=mc$n.cores) - # Move this down if not also handled in runmodel() - if(any(!result$sufficient.adapt)){ - warning("JAGS reports adaptation was incomplete. Consider increasing n.adapt", call.=FALSE) - } - } else { - # Run non-parallel - set.modules(modules, DIC) - set.factories(factories) - - result <- run.model(modfile, data, inits, params, - mc$n.chains, mc$n.iter, mc$n.burnin, mc$n.thin, mc$n.adapt, - verbose=!quiet) - } - - result$run.date <- start.time - result$elapsed.mins <- round(as.numeric(Sys.time()-start.time,units="mins"),digits=3) - - result -} diff --git a/R/runmodel.R b/R/runmodel.R deleted file mode 100644 index ad5c276..0000000 --- a/R/runmodel.R +++ /dev/null @@ -1,105 +0,0 @@ - -run.model <- function(model.file=NULL,data=NULL,inits=NULL,parameters.to.save,n.chains=NULL, - n.iter,n.burnin,n.thin,n.adapt,verbose=TRUE,model.object=NULL, - update=FALSE,parallel=FALSE,na.rm=TRUE){ - -if(verbose){pb="text"} else {pb="none"} - -if(update){ - #Recompile model - m <- model.object - if(verbose | parallel==TRUE){ - m$recompile() - } else {null <- capture.output( - m$recompile() - )} - -} else { - #Compile model - if(verbose | parallel==TRUE){ - m <- jags.model(file=model.file,data=data,inits=inits,n.chains=n.chains,n.adapt=0) - } else { - null <- capture.output( - m <- jags.model(file=model.file,data=data,inits=inits,n.chains=n.chains,n.adapt=0,quiet=TRUE) - ) - } -} - -#Adaptive phase using adapt() -total.adapt <- 0 - -if(!is.null(n.adapt)){ - if(n.adapt>0){ - if(verbose){ - cat('Adaptive phase,',n.adapt,'iterations x',n.chains,'chains','\n') - cat('If no progress bar appears JAGS has decided not to adapt','\n','\n') - sufficient.adapt <- adapt(object=m,n.iter=n.adapt,progress.bar=pb,end.adaptation=TRUE) - } else { - null <- capture.output( - sufficient.adapt <- adapt(object=m,n.iter=n.adapt,progress.bar=pb,end.adaptation=TRUE) - )} - total.adapt <- total.adapt + n.adapt - } else{ - if(verbose){cat('No adaptive period specified','\n','\n')} - #If no adaptation period specified: - #Force JAGS to not adapt (you have to allow it to adapt at least 1 iteration) - if(!update){ - if(verbose){ - sufficient.adapt <- adapt(object=m,n.iter=1,end.adaptation=TRUE) - } else { - null <- capture.output( - sufficient.adapt <- adapt(object=m,n.iter=1,end.adaptation=TRUE) - )} - } - total.adapt <- 0 - } -} else { - - maxloops <- 100 - n.adapt.iter <- 100 - - for (i in 1:maxloops){ - if(verbose){cat('Adaptive phase.....','\n')} - sufficient.adapt <- adapt(object=m,n.iter=n.adapt.iter,progress.bar='none') - total.adapt <- total.adapt + n.adapt.iter - if(i==maxloops){ - if(verbose){warning(paste("Reached max of",maxloops*n.adapt.iter,"adaption iterations; set n.adapt to > 10000"))} - null <- adapt(object=m,n.iter=1,end.adaptation = TRUE) - break - } - if(sufficient.adapt){ - null <- adapt(object=m,n.iter=1,end.adaptation = TRUE) - if(verbose){cat('Adaptive phase complete','\n','\n')} - break - } - } - -} -if(!sufficient.adapt&total.adapt!=0&verbose){warning("JAGS reports adaptation was incomplete. Consider increasing n.adapt")} - -#Burn-in phase using update() -if(n.burnin>0){ - if(verbose){ - cat('\n','Burn-in phase,',n.burnin,'iterations x',n.chains,'chains','\n','\n') - update(object=m,n.iter=n.burnin,progress.bar=pb) - cat('\n') - } else { - null <- capture.output( - update(object=m,n.iter=n.burnin,progress.bar=pb) - )} -} else if(verbose){cat('No burn-in specified','\n','\n')} - -#Sample from posterior using coda.samples() -if(verbose){ - cat('Sampling from joint posterior,',(n.iter-n.burnin),'iterations x',n.chains,'chains','\n','\n') - samples <- coda.samples(model=m,variable.names=parameters.to.save,n.iter=(n.iter-n.burnin),thin=n.thin, - na.rm=na.rm, progress.bar=pb) - cat('\n') -} else { - null <- capture.output( - samples <- coda.samples(model=m,variable.names=parameters.to.save,n.iter=(n.iter-n.burnin),thin=n.thin, - na.rm=na.rm, progress.bar=pb) - )} - -return(list(m=m,samples=samples,total.adapt=total.adapt,sufficient.adapt=sufficient.adapt)) -} diff --git a/R/runparallel.R b/R/runparallel.R deleted file mode 100644 index b60b196..0000000 --- a/R/runparallel.R +++ /dev/null @@ -1,83 +0,0 @@ - -run.parallel <- function(data=NULL,inits=NULL,parameters.to.save,model.file=NULL,n.chains,n.adapt,n.iter,n.burnin,n.thin, - modules,factories,DIC,model.object=NULL,update=FALSE,verbose=TRUE,n.cores=NULL) { - -#Save current library paths -current.libpaths <- .libPaths() - -#Set up clusters -cl = makeCluster(n.cores) -on.exit(stopCluster(cl)) -clusterExport(cl = cl, ls(), envir = environment()) -clusterEvalQ(cl,.libPaths(current.libpaths)) - -if(verbose){ -cat('Beginning parallel processing using',n.cores,'cores. Console output will be suppressed.\n')} - -#Function called in each core -jags.clust <- function(i){ - -#Load modules -set.modules(modules,DIC) -set.factories(factories) - -if(update){ - #Recompile model - cluster.mod <- model.object[[i]] - - #Run model - rjags.output <- run.model(model.file=NULL,data=NULL,inits=NULL,parameters.to.save,n.chains=1,n.iter,n.burnin=0,n.thin,n.adapt, - verbose=FALSE,model.object=cluster.mod,update=TRUE,parallel=TRUE, na.rm=FALSE) - -} else { - - #Set initial values for cluster - cluster.inits <- inits[[i]] - - #Run model - - rjags.output <- run.model(model.file,data,inits=cluster.inits,parameters.to.save,n.chains=1,n.iter, - n.burnin,n.thin,n.adapt,verbose=FALSE,parallel=TRUE, na.rm=FALSE) - -} - -return(list(samp=rjags.output$samples[[1]],mod=rjags.output$m,total.adapt=rjags.output$total.adapt,sufficient.adapt=rjags.output$sufficient.adapt)) - -} - -#Do analysis -par <- clusterApply(cl=cl,x=1:n.chains,fun=jags.clust) - -#Create empty lists -out <- samples <- model <- list() -total.adapt <- sufficient.adapt <- vector(length=n.chains) - -#Save samples and model objects from each cluster -total.adapt <- sapply(par, function(x) x[[3]]) -mc_start <- max(total.adapt) + n.burnin + n.thin -for (i in 1:n.chains){ - samples[[i]] <- coda::mcmc(par[[i]][[1]],start=mc_start,thin=n.thin) - model[[i]] <- par[[i]][[2]] - sufficient.adapt[i] <- par[[i]][[4]] -} -out$samples <- as.mcmc.list(samples) -# Remove columns with all NA -try({ - all_na <- apply(as.matrix(out$samples),2, function(x) all(is.na(x))) - out$samples <- out$samples[,!all_na,drop=FALSE] -}) -out$model <- model -out$total.adapt <- total.adapt -out$sufficient.adapt <- sufficient.adapt -names(out$model) <- sapply(1:length(out$model),function(i){paste('cluster',i,sep="")}) - -if(verbose){ -cat('\nParallel processing completed.\n\n') -} - -return(out) - -} - - - diff --git a/R/setfactories.R b/R/setfactories.R deleted file mode 100644 index 26e0eb8..0000000 --- a/R/setfactories.R +++ /dev/null @@ -1,20 +0,0 @@ - -set.factories <- function(factories){ - - if(!is.null(factories)){ - for (i in 1:length(factories)){ - - split <- strsplit(factories[i],'\\s')[[1]] - - #Check if requested factory is available - faclist <- as.character(list.factories(split[2])[,1]) - if(split[1]%in%faclist){ - - null <- set.factory(split[1],split[2],split[3]) - - } else{stop(paste('Requested factory',split[1],'is not available. Check that appropriate modules are loaded.'))} - - } - } - -}
\ No newline at end of file diff --git a/R/setmodules.R b/R/setmodules.R deleted file mode 100644 index 06e6257..0000000 --- a/R/setmodules.R +++ /dev/null @@ -1,24 +0,0 @@ - -set.modules <- function(modules,DIC){ - - #Load/unload appropriate modules (besides dic) - called.set <- c('basemod','bugs',modules) - current.set <- list.modules() - - load.set <- called.set[!called.set%in%current.set] - unload.set <- current.set[!current.set%in%called.set] - - if(length(load.set)>0){ - for (i in 1:length(load.set)){ - load.module(load.set[i],quiet=TRUE) - } - } - if(length(unload.set)>0){ - for (i in 1:length(unload.set)){ - unload.module(unload.set[i],quiet=TRUE) - } - } - if(DIC){ - load.module("dic",quiet=TRUE) - } -}
\ No newline at end of file @@ -1,99 +1,74 @@ -update.jagsUI <- function(object, parameters.to.save=NULL, n.adapt=NULL, n.iter, n.thin=NULL, +update.jagsUI <- function(object, parameters.to.save=NULL, + n.adapt=NULL, n.iter, n.thin=NULL, modules=c('glm'), factories=NULL, DIC=NULL,codaOnly=FALSE, verbose=TRUE, ...){ - mod <- object$model - #Get list of parameters to save - if(is.null(parameters.to.save)){parameters <- object$parameters - } else {parameters <- parameters.to.save} + if(is.null(parameters.to.save)) parameters.to.save <- object$parameters #Set up DIC monitoring - if(is.null(DIC)){ - DIC <- object$calc.DIC + if(is.null(DIC)) DIC <- object$calc.DIC + if(DIC & (!'deviance' %in% parameters.to.save)){ + parameters.to.save <- c(parameters.to.save, 'deviance') + } else if(!DIC & 'deviance' %in% parameters.to.save){ + parameters.to.save <- parameters.to.save[parameters.to.save != 'deviance'] } - if(DIC&!'deviance'%in%parameters){parameters <- c(parameters,'deviance') - } else if(!DIC&'deviance'%in%parameters){parameters <- parameters[parameters!='deviance']} - - #Get thin rate - if(is.null(n.thin)){n.thin <- object$mcmc.info$n.thin} - - start.time <- Sys.time() - - if(object$parallel){ + # Update mcmc info + mcmc.info <- object$mcmc.info + mcmc.info$n.iter <- n.iter + mcmc.info$n.burnin <- 0 + mcmc.info$n.adapt <- n.adapt + if(!is.null(n.thin)) mcmc.info$n.thin <- n.thin + + # Run JAGS via rjags + rjags_out <- run_rjags(data=NULL, inits=NULL, parameters.to.save, modfile=NULL, + mcmc.info, modules, factories, DIC, object$parallel, !verbose, + model.object = object$model, update=TRUE) - par <- run.parallel(data=NULL,inits=NULL,parameters.to.save=parameters,model.file=NULL,n.chains=object$mcmc.info$n.chains - ,n.adapt=n.adapt,n.iter=n.iter,n.burnin=0,n.thin=n.thin,modules=modules,factories=factories, - DIC=DIC,model.object=mod,update=TRUE,verbose=verbose,n.cores=object$mcmc.info$n.cores) - samples <- par$samples - m <- par$model - - } else { - - #Set modules - set.modules(modules,DIC) - set.factories(factories) - - rjags.output <- run.model(model.file=NULL,data=NULL,inits=NULL,parameters.to.save=parameters, - n.chains=object$mcmc.info$n.chains,n.iter,n.burnin=0,n.thin,n.adapt, - model.object=mod,update=TRUE,verbose=verbose) - samples <- rjags.output$samples - m <- rjags.output$m - - } - - end.time <- Sys.time() - time <- round(as.numeric(end.time-start.time,units="mins"),digits=3) - date <- start.time - #Reorganize JAGS output to match input parameter order - samples <- order_samples(samples, parameters) - + samples <- order_samples(rjags_out$samples, parameters.to.save) #Run process output output <- process_output(samples, coda_only = codaOnly, DIC, quiet = !verbose) + # Fallback if output processing fails if(is.null(output)){ - output <- list() - output$samples <- samples - output$model <- m + output <- list(samples = samples, model = rjags_out$m) output$n.cores <- object$mcmc.info$n.cores class(output) <- 'jagsUIbasic' return(output) } #Save other information to output object - output$samples <- samples - - output$modfile <- object$modfile - + output$samples <- samples + output$modfile <- object$modfile #If user wants to save input data/inits if(!is.null(object$inits)){ output$inits <- object$inits output$data <- object$data } - - output$parameters <- parameters - output$model <- m + output$parameters <- parameters.to.save + output$model <- rjags_out$m output$mcmc.info <- object$mcmc.info output$mcmc.info$n.burnin <- object$mcmc.info$n.iter output$mcmc.info$n.iter <- n.iter + output$mcmc.info$n.burnin - output$mcmc.info$n.thin <- n.thin - output$mcmc.info$n.samples <- (output$mcmc.info$n.iter-output$mcmc.info$n.burnin) / n.thin * output$mcmc.info$n.chains - output$mcmc.info$elapsed.mins <- time - output$run.date <- date - output$random.seed <- object$random.seed + output$mcmc.info$n.thin <- mcmc.info$n.thin + output$mcmc.info$n.samples <- coda::niter(samples) * output$mcmc.info$n.chains + output$mcmc.info$elapsed.mins <- rjags_out$elapsed.mins + output$run.date <- rjags_out$run.date output$parallel <- object$parallel output$bugs.format <- object$bugs.format output$calc.DIC <- DIC #Keep a record of how many times model has been updated - if(is.null(object$update.count)){output$update.count <- 1 - } else {output$update.count <- object$update.count + 1} + if(is.null(object$update.count)){ + output$update.count <- 1 + } else { + output$update.count <- object$update.count + 1 + } #Classify final output object class(output) <- 'jagsUI' - return(output) - + return(output) } diff --git a/R/updatebasic.R b/R/updatebasic.R index 0a15d17..97fd512 100644 --- a/R/updatebasic.R +++ b/R/updatebasic.R @@ -1,61 +1,47 @@ -update.jagsUIbasic <- function(object, parameters.to.save=NULL, n.adapt=NULL, n.iter, n.thin=NULL, - modules=c('glm'), factories=NULL, DIC=NULL, verbose=TRUE, ...){ +update.jagsUIbasic <- function(object, parameters.to.save=NULL, + n.adapt=NULL, n.iter, n.thin=NULL, + modules=c('glm'), factories=NULL, + DIC=NULL, verbose=TRUE, ...){ - mod <- object$model - n.chains <- length(object$samples) - n.cores <- object$n.cores - + # Set up parameters if(is.null(parameters.to.save)){ - params.temp <- colnames(object$samples[[1]]) - parameters <- unique(sapply(strsplit(params.temp, "\\["), "[", 1)) - } else {parameters <- parameters.to.save} + params_long <- colnames(object$samples[[1]]) + parameters.to.save <- unique(sapply(strsplit(params_long, "\\["), "[", 1)) + } #Set up DIC monitoring if(is.null(DIC)){ - if('deviance'%in%parameters){ - DIC=TRUE - } else {DIC=FALSE} - } else{ - if(DIC&!'deviance'%in%parameters){parameters <- c(parameters,'deviance') - } else if(!DIC&'deviance'%in%parameters){parameters <- parameters[parameters!='deviance']} + DIC <- 'deviance' %in% parameters.to.save + } else { + if(DIC & (!'deviance' %in% parameters.to.save)){ + parameters.to.save <- c(parameters.to.save, 'deviance') + } else if(!DIC & 'deviance' %in% parameters.to.save){ + parameters.to.save <- parameters.to.save[parameters.to.save != 'deviance'] + } } + + # Set up MCMC info + mcmc.info <- list(n.chains = length(object$samples), n.adapt = n.adapt, + n.iter = n.iter, n.burnin = 0, + n.thin = ifelse(is.null(n.thin), thin(object$samples), n.thin), + n.cores = object$n.cores) + + parallel <- names(object$model[1]) == "cluster1" + + # Run JAGS via rjags + rjags_out <- run_rjags(data=NULL, inits=NULL, parameters.to.save, modfile=NULL, + mcmc.info, modules, factories, DIC, parallel, !verbose, + model.object = object$model, update=TRUE) + + # Report time + if(verbose) cat('MCMC took', rjags_out$elapsed.min, 'minutes.\n') - if(is.null(n.thin)){n.thin <- thin(object$samples)} - - start.time <- Sys.time() - - if(names(object$model[1])=='cluster1'){ - - par <- run.parallel(data=NULL,inits=NULL,parameters.to.save=parameters,model.file=NULL,n.chains=n.chains - ,n.adapt=n.adapt,n.iter=n.iter,n.burnin=0,n.thin=n.thin,modules=modules,factories=factories, - DIC=DIC,model.object=mod,update=TRUE,verbose=verbose,n.cores=n.cores) - samples <- par$samples - m <- par$model - - } else { - - #Set modules - set.modules(modules,DIC) - set.factories(factories) - - rjags.output <- run.model(model.file=NULL,data=NULL,inits=NULL,parameters.to.save=parameters, - n.chains=object$mcmc.info$n.chains,n.iter,n.burnin=0,n.thin,n.adapt, - model.object=mod,update=TRUE,verbose=verbose) - samples <- rjags.output$samples - m <- rjags.output$m - } - - samples <- order_samples(samples, parameters) - - end.time <- Sys.time() - time <- round(as.numeric(end.time-start.time,units="mins"),digits=3) - if(verbose){cat('MCMC took',time,'minutes.\n')} - - output <- list(samples=samples,model=m,n.cores=n.cores) - + # Create output object + output <- list(samples = order_samples(rjags_out$samples, parameters.to.save), + model = rjags_out$m, + n.cores = object$n.cores) class(output) <- 'jagsUIbasic' - return(output) - + return(output) } diff --git a/inst/tinytest/jagsbasic_ref_saved.Rds b/inst/tinytest/jagsbasic_ref_saved.Rds Binary files differindex 1aa7e55..f8ca287 100644 --- a/inst/tinytest/jagsbasic_ref_saved.Rds +++ b/inst/tinytest/jagsbasic_ref_saved.Rds diff --git a/inst/tinytest/test_jagsbasic.R b/inst/tinytest/test_jagsbasic.R index c646aa0..b4b33a6 100644 --- a/inst/tinytest/test_jagsbasic.R +++ b/inst/tinytest/test_jagsbasic.R @@ -30,6 +30,7 @@ ref <- readRDS("jagsbasic_reference_fit.Rds") expect_identical(out, ref) # Saved model and reordered parameter names------------------------------------ +set.seed(123) params <- c('beta', 'alpha', 'sigma', 'mu') out <- jags.basic(data = data, inits = inits, parameters.to.save = params, model.file = modfile, n.chains = 3, n.adapt = 100, n.iter = 100, @@ -43,7 +44,6 @@ expect_identical(out, ref) # Update----------------------------------------------------------------------- out2 <- update(out, n.iter=100, n.thin = 2, verbose=FALSE) expect_equal(nrow(out2$samples[[1]]), 50) - ref <- readRDS('jagsbasic_ref_update.Rds') expect_identical(names(out2), names(ref)) out2$model <- ref$model @@ -58,12 +58,22 @@ expect_error(jags.basic(data = data, inits = inits, parameters.to.save = params, at_home <- identical( Sys.getenv("AT_HOME"), "TRUE" ) if(parallel::detectCores() > 1 & at_home){ set.seed(123) + params <- c('beta', 'alpha', 'sigma', 'mu') out <- jags.basic(data = data, inits = inits, parameters.to.save = params, - model.file = modfile, n.chains = 3, n.adapt = 100, n.iter = 100, - n.burnin = 50, n.thin = 2, verbose=FALSE, parallel=TRUE) + model.file = modfile, n.chains = 3, n.adapt = 100, n.iter = 100, + n.burnin = 50, n.thin = 2, verbose=FALSE, save.model=TRUE, parallel=TRUE) + ref <- readRDS("jagsbasic_ref_saved.Rds") + + out$n.cores <- NULL + expect_identical(names(out), names(ref)) + out$model <- ref$model + expect_identical(out, ref) - ref <- readRDS("jagsbasic_reference_fit.Rds") - expect_identical(out[-c(17,18,20:22)], ref[-c(17,18,20:22)]) + out2 <- update(out, n.iter=100, n.thin = 2, verbose=FALSE) + ref <- readRDS('jagsbasic_ref_update.Rds') + expect_identical(names(out2), names(ref)) + out2$model <- ref$model + expect_equal(out2, ref) } # Verbose--------------------------------------------------------------------- diff --git a/inst/tinytest/test_update.R b/inst/tinytest/test_update.R index 60b05c3..42202be 100644 --- a/inst/tinytest/test_update.R +++ b/inst/tinytest/test_update.R @@ -67,3 +67,21 @@ expect_message(out2 <- update(out, n.iter=100, n.thin=2, verbose=FALSE, expect_inherits(out2, "jagsUIbasic") expect_equal(coda::varnames(out2$samples), c("alpha","deviance")) expect_equal(names(out2), c("samples", "model")) + +# Parallel--------------------------------------------------------------------- +at_home <- identical( Sys.getenv("AT_HOME"), "TRUE" ) +if(parallel::detectCores() > 1 & at_home){ + set.seed(123) + out <- jags(data = data, inits = inits, parameters.to.save = params, + model.file = modfile, n.chains = 3, n.adapt = 100, n.iter = 100, + n.burnin = 50, n.thin = 2, verbose=FALSE, parallel=TRUE) + + out2 <- update(out, n.iter=100, n.thin=2, verbose=FALSE) + ref <- readRDS("update_ref.Rds") + ref$parallel <- TRUE + out2$mcmc.info$n.cores <- NULL + ref$mcmc.info$sufficient.adapt <- out2$mcmc.info$sufficient.adapt + ref$mcmc.info$n.adapt <- out2$mcmc.info$n.adapt + out2$mcmc.info$elapsed.mins <- ref$mcmc.inf$elapsed.mins + expect_identical(out2[-c(17,19,21)], ref[-c(17,19,21)]) +} |