diff options
author | Ken Kellner <ken@kenkellner.com> | 2023-12-02 18:52:15 -0500 |
---|---|---|
committer | Ken Kellner <ken@kenkellner.com> | 2023-12-02 18:52:15 -0500 |
commit | ef86be79b30cba60f3c9ba13c91f1a0ef94290ab (patch) | |
tree | 35d4a8d0b27322d3792160d9d74efe64f494d9e3 | |
parent | 79ea4216c321ee9d6c2bde3e91c5dfc89edd916d (diff) |
Update mcmc tools
-rw-r--r-- | R/densityplot.R | 2 | ||||
-rw-r--r-- | R/mcmc_tools.R | 105 | ||||
-rw-r--r-- | R/ppcheck.R | 4 | ||||
-rw-r--r-- | R/traceplot.R | 2 | ||||
-rw-r--r-- | R/whiskerplot.R | 2 | ||||
-rw-r--r-- | inst/tinytest/coda_samples.Rds | bin | 0 -> 19301 bytes | |||
-rw-r--r-- | inst/tinytest/one_sample.Rds | bin | 0 -> 307 bytes | |||
-rw-r--r-- | inst/tinytest/test_mcmc_tools.R | 99 | ||||
-rw-r--r-- | tests/tinytest.R | 4 |
9 files changed, 167 insertions, 51 deletions
diff --git a/R/densityplot.R b/R/densityplot.R index 0c6f987..b350f0a 100644 --- a/R/densityplot.R +++ b/R/densityplot.R @@ -23,7 +23,7 @@ densityplot <- function(x, parameters=NULL, layout=NULL, ask=NULL){ param_density <- function(x, parameter, m_labels=FALSE){ #Get samples - vals <- mcmc_to_mat(x$samples, parameter) + vals <- mcmc_to_mat(x$samples[,parameter]) # Get bandwidth, one value for all chains bw <- mean(apply(vals, 2, stats::bw.nrd0)) diff --git a/R/mcmc_tools.R b/R/mcmc_tools.R index 205631d..b8e48a5 100644 --- a/R/mcmc_tools.R +++ b/R/mcmc_tools.R @@ -1,40 +1,22 @@ -#Functions for manipulating and extracting info from mcmc.list-class objects -#from package rjags/coda +#------------------------------------------------------------------------------ +#Get names of parameters from an mcmc.list +#If simplify=T, also drop brackets/indices +param_names <- function(mcmc_list, simplify=FALSE){ + out <- coda::varnames(mcmc_list) + if(simplify) out <- strip_params(out, unique=TRUE) + out +} -# This is a subset of the functions in mcmc_tools in devel version 1.5.1.9024 -###------------------------------------------------------------------------------ -#Remove brackets and indices from parameter names in mcmc.list +#------------------------------------------------------------------------------ strip_params <- function(params_raw, unique=FALSE){ - params_strip <- sapply(strsplit(params_raw,'[', fixed=T),'[',1) + params_strip <- sapply(strsplit(params_raw,'[', fixed=TRUE),'[',1) if(unique) return( unique(params_strip) ) params_strip } -#------------------------------------------------------------------------------ -###------------------------------------------------------------------------------ -#Identify which columns in mcmc.list object correspond to a given -#parameter name (useful for non-scalar parameters) -which_params <- function(param, params_raw){ - params_strip <- strip_params(params_raw) - if( ! param %in% params_strip ){ - return(NULL) - } - which(params_strip == param) -} -#------------------------------------------------------------------------------ -###------------------------------------------------------------------------------ -#Get names of parameters from an mcmc.list -#If simplify=T, also drop brackets/indices -param_names <- function(mcmc_list, simplify=FALSE){ - raw <- colnames(mcmc_list[[1]]) - if(!simplify) return(raw) - strip_params(raw, unique=T) -} #------------------------------------------------------------------------------ - -###------------------------------------------------------------------------------ #Match parameter name to scalar or array versions of parameter name match_params <- function(params, params_raw){ unlist(lapply(params, function(x){ @@ -43,28 +25,59 @@ match_params <- function(params, params_raw){ params_raw[which_params(x, params_raw)] })) } + + #------------------------------------------------------------------------------ +#Reorder output samples from coda to match input parameter order +order_samples <- function(samples, params){ + tryCatch({ + matched <- match_params(params, param_names(samples)) + samples[,matched,drop=FALSE] + }, error = function(e){ + message(paste0("Caught error re-ordering samples:\n",e,"\n")) + samples + }) +} -###------------------------------------------------------------------------------ -#Subset cols of mcmc.list (simple version of [.mcmc.list method) -select_cols <- function(mcmc_list, col_inds){ - out <- lapply(1:length(mcmc_list), FUN=function(x){ - mcmc_element <- mcmc_list[[x]][,col_inds,drop=FALSE] - attr(mcmc_element,'mcpar') <- attr(mcmc_list[[x]], 'mcpar') - class(mcmc_element) <- 'mcmc' - mcmc_element - }) - class(out) <- 'mcmc.list' - out + +#------------------------------------------------------------------------------ +#Identify which columns in mcmc.list object correspond to a given +#parameter name (useful for non-scalar parameters) +which_params <- function(param, params_raw){ + params_strip <- strip_params(params_raw) + if( ! param %in% params_strip ){ + return(NULL) + } + which(params_strip == param) } + + #------------------------------------------------------------------------------ +#Remove parameters from list of params +subset_params <- function(samples, exclude=NULL){ + all_params <- param_names(samples) + if(is.null(exclude)) return(all_params) + params_strip <- strip_params(all_params) + ind <- unlist(sapply(exclude, which_params, all_params)) + all_params[-ind] +} + -###------------------------------------------------------------------------------ -#Convert one parameter in mcmc.list to matrix, n_iter * n_chains -mcmc_to_mat <- function(samples, param){ - psamples <- select_cols(samples, param) - n_chain <- length(samples) - n_iter <- nrow(samples[[1]]) - matrix(unlist(psamples), nrow=n_iter, ncol=n_chain) +#------------------------------------------------------------------------------ +mcmc_to_mat <- function(mcmc_list){ + stopifnot(coda::nvar(mcmc_list) == 1) + matrix(unlist(mcmc_list), + nrow=coda::niter(mcmc_list), ncol=coda::nchain(mcmc_list)) } + #------------------------------------------------------------------------------ +#Extract index values inside brackets from a non-scalar parameter +#param is the "base" name of the parameter and params_raw is a vector of +#strings that contain brackets +get_inds <- function(param, params_raw){ + inds_raw <- sub(paste(param,'[',sep=''),'', params_raw,fixed=T) + inds_raw <- sub(']','', inds_raw, fixed=T) + inds_raw <- strsplit(inds_raw,',',fixed=T) + inds <- as.integer(unlist(inds_raw)) + matrix(inds, byrow=T, ncol=length(inds_raw[[1]])) +} diff --git a/R/ppcheck.R b/R/ppcheck.R index 8229b7b..aae4408 100644 --- a/R/ppcheck.R +++ b/R/ppcheck.R @@ -11,8 +11,8 @@ pp.check <- function(x, observed, simulated, stop("Simulated parameter not found in output") } - obs <- c(mcmc_to_mat(x$samples, observed)) - sim <- c(mcmc_to_mat(x$samples, simulated)) + obs <- c(mcmc_to_mat(x$samples[,observed])) + sim <- c(mcmc_to_mat(x$samples[,simulated])) bpval <- mean(sim > obs) plotrange <- range(obs, sim) diff --git a/R/traceplot.R b/R/traceplot.R index 4a7c8bd..755d3aa 100644 --- a/R/traceplot.R +++ b/R/traceplot.R @@ -25,7 +25,7 @@ traceplot <- function(x, parameters=NULL, Rhat_min=NULL, param_trace <- function(x, parameter, m_labels=FALSE){ #Get samples and Rhat values - vals <- mcmc_to_mat(x$samples, parameter) + vals <- mcmc_to_mat(x$samples[, parameter]) Rhat <- sprintf("%.3f",round(x$summary[parameter, 'Rhat'],3)) #Draw plot diff --git a/R/whiskerplot.R b/R/whiskerplot.R index 5b2bda9..8b09bf1 100644 --- a/R/whiskerplot.R +++ b/R/whiskerplot.R @@ -20,7 +20,7 @@ whiskerplot <- function(x,parameters,quantiles=c(0.025,0.975), #Calculate means and CIs post_stats <- sapply(parameters, function(i){ - sims <- mcmc_to_mat(x$samples, i) + sims <- mcmc_to_mat(x$samples[, i]) c(mean(sims,na.rm=TRUE), stats::quantile(sims,na.rm=TRUE,quantiles)) }) diff --git a/inst/tinytest/coda_samples.Rds b/inst/tinytest/coda_samples.Rds Binary files differnew file mode 100644 index 0000000..db70638 --- /dev/null +++ b/inst/tinytest/coda_samples.Rds diff --git a/inst/tinytest/one_sample.Rds b/inst/tinytest/one_sample.Rds Binary files differnew file mode 100644 index 0000000..5274d8c --- /dev/null +++ b/inst/tinytest/one_sample.Rds diff --git a/inst/tinytest/test_mcmc_tools.R b/inst/tinytest/test_mcmc_tools.R new file mode 100644 index 0000000..a85afe7 --- /dev/null +++ b/inst/tinytest/test_mcmc_tools.R @@ -0,0 +1,99 @@ +# test that param_names returns correct names--------------------------------- +param_names <- jagsUI:::param_names +samples <- readRDS('coda_samples.Rds') +expect_equal(param_names(samples), + c("alpha", "beta", "sigma", "mu[1]", "mu[2]", "mu[3]", "mu[4]", + "mu[5]", "mu[6]", "mu[7]", "mu[8]", "mu[9]", "mu[10]", "mu[11]", + "mu[12]", "mu[13]", "mu[14]", "mu[15]", "mu[16]", "kappa[1,1,1]", + "kappa[2,1,1]", "kappa[1,2,1]", "kappa[2,2,1]", "kappa[1,1,2]", + "kappa[2,1,2]", "kappa[1,2,2]", "kappa[2,2,2]", "deviance")) +expect_equal(param_names(samples,simplify=T), + c('alpha','beta','sigma','mu','kappa','deviance')) + + +# test that strip_params removes brackets and indices-------------------------- +strip_params <- jagsUI:::strip_params +params_raw <- c('alpha','beta[1]','beta[2]','gamma[1,2]','kappa[1,2,3]') +expect_equal(strip_params(params_raw), + c('alpha','beta','beta','gamma','kappa')) +expect_equal(strip_params(params_raw,unique=T), + c('alpha','beta','gamma','kappa')) + + +# test that match_param identifies correct set of params----------------------- +match_params <- jagsUI:::match_params +params_raw <- c('alpha','beta[1]','beta[2]','gamma[1,1]','gamma[3,1]') +expect_equal(match_params('alpha', params_raw),'alpha') +expect_equal(match_params('beta', params_raw), c('beta[1]','beta[2]')) +expect_equal(match_params('gamma[1,1]', params_raw), 'gamma[1,1]') +expect_true(is.null(match_params('fake',params_raw))) +expect_equal(match_params(c('alpha','beta'),params_raw), + c('alpha','beta[1]','beta[2]')) +expect_equal(match_params(c('alpha','fake','beta'),params_raw), + c('alpha','beta[1]','beta[2]')) + +# test that order_samples works correctly-------------------------------------- +order_samples <- jagsUI:::order_samples +samples <- readRDS('coda_samples.Rds') +new_order <- c('beta','mu','alpha') +out <- order_samples(samples, new_order) +expect_equal(class(out), 'mcmc.list') +expect_equal(length(out),length(samples)) +expect_equal(lapply(out,class),lapply(samples,class)) +expect_equal(param_names(out),c('beta',paste0('mu[',1:16,']'),'alpha')) +expect_equal(dim(out[[1]]), c(30,18)) +expect_equal(as.numeric(out[[1]][1,1:2]), + c(0.03690717, 59.78175), tol=1e-4) +expect_equal(order_samples(samples, 'beta'), + order_samples(samples, c('beta','fake'))) +expect_message(order_samples('fake','beta')) +expect_message(test <- order_samples('fake','beta')) +expect_equal(test, 'fake') +one_param <- samples[, 'alpha',drop=FALSE] +expect_equal(order_samples(one_param,'alpha'),one_param) +expect_equal(dim(order_samples(one_param, 'beta')[[1]]),c(30,0)) + + +# test that which_params gets param col indices-------------------------------- +which_params <- jagsUI:::which_params +params_raw <- c('alpha','beta[1]','beta[2]','gamma[1,1]','gamma[3,1]') +expect_equal(which_params('alpha',params_raw),1) +expect_equal(which_params('beta',params_raw),c(2,3)) +expect_equal(which_params('gamma',params_raw),c(4,5)) +expect_null(which_params('kappa',params_raw)) + + +# test that subset_params drops correct params from list----------------------- +subset_params <- jagsUI:::subset_params +samples <- readRDS('coda_samples.Rds') +expect_equal(subset_params(samples), param_names(samples)) +expect_equal(subset_params(samples, 'beta'), param_names(samples)[-2]) +expect_equal(subset_params(samples, c('mu','kappa')), + c('alpha','beta','sigma','deviance')) +expect_equal(subset_params(samples, param_names(samples, simplify=TRUE)), + character(0)) + +# test that mcmc_to_mat converts properly-------------------------------------- +mcmc_to_mat <- jagsUI:::mcmc_to_mat +samples <- readRDS('coda_samples.Rds') +mat <- mcmc_to_mat(samples[, 'alpha']) +expect_true(inherits(mat, 'matrix')) +expect_equal(dim(mat),c(nrow(samples[[1]]),length(samples))) +expect_equal(mat[,1],as.numeric(samples[[1]][,'alpha'])) +one_sample <- readRDS('one_sample.Rds') +mat <- mcmc_to_mat(one_sample[, 'alpha']) +expect_equal(dim(mat), c(1,3)) + +# test that get_inds extracts indices----------------------------------------- +get_inds <- jagsUI:::get_inds +params_raw <- c('beta[1]','beta[2]') +expect_equal(get_inds('beta',params_raw),matrix(c(1,2))) +params_raw <- c('gamma[1,1]','gamma[2,1]','gamma[1,3]') +expect_equal(get_inds('gamma',params_raw), + matrix(c(1,1,2,1,1,3),ncol=2,byrow=T)) +params_raw <- c('kappa[1,1,1]','kappa[2,1,1]','kappa[1,3,1]') +expect_equal(get_inds('kappa',params_raw), + matrix(c(1,1,1,2,1,1,1,3,1),ncol=3,byrow=T)) +params_raw <- 'alpha' +expect_warning(test <- get_inds('alpha',params_raw)[1,1]) +expect_true(is.na(test)) diff --git a/tests/tinytest.R b/tests/tinytest.R new file mode 100644 index 0000000..6c6da60 --- /dev/null +++ b/tests/tinytest.R @@ -0,0 +1,4 @@ +if ( requireNamespace("tinytest", quietly=TRUE) ){ + tinytest::test_package("jagsUI") +} + |