diff options
author | Ken Kellner <ken@kenkellner.com> | 2021-09-13 17:06:50 -0400 |
---|---|---|
committer | Ken Kellner <ken@kenkellner.com> | 2021-09-13 17:06:50 -0400 |
commit | 3b6274c8ca6a0351bcf97a641f1f819be8e09fac (patch) | |
tree | 9117956d60ec0cababbbf56e006787ecc880151f | |
parent | fd041c9628001d58f4256336f8eeed5899ddf69b (diff) |
More changes to plot_marginal: add draws argument, return ggplot when only one covariate, fix factor problem
-rw-r--r-- | R/plot_marginal.R | 41 |
1 files changed, 26 insertions, 15 deletions
diff --git a/R/plot_marginal.R b/R/plot_marginal.R index fedd0b6..bd2ea2e 100644 --- a/R/plot_marginal.R +++ b/R/plot_marginal.R @@ -12,6 +12,8 @@ setGeneric("plot_marginal", function(object, ...) standardGeneric("plot_marginal #' @param submodel Submodel to get plots for, for example \code{"det"} #' @param covariate Plot a specific covariate; provide the name as a string #' @param level Probability mass to include in the uncertainty interval +#' @param draws Number of draws from the posterior to use when generating the +#' plots. If fewer than \code{draws} are available, they are all used #' @param smooth Smoother span (f) value used for LOWESS smoothing of the #' upper and lower credible interval bounds for a continuous covariate. #' No smoothing is done if \code{NULL}. A reasonable value to try is 0.05. @@ -19,7 +21,8 @@ setGeneric("plot_marginal", function(object, ...) standardGeneric("plot_marginal #' with caution #' @param ... Currently ignored #' -#' @return A \code{ggplot}, potentially with multiple panels, one per covariate +#' @return A \code{ggplot} if a single covariate is plotted, or an object +#' of class \code{grob} if there are multiple covariates/panels #' #' @aliases plot_marginal #' @include fit.R @@ -27,46 +30,53 @@ setGeneric("plot_marginal", function(object, ...) standardGeneric("plot_marginal #' @importFrom ggplot2 geom_errorbar #' @export setMethod("plot_marginal", "ubmsFit", function(object, submodel, covariate=NULL, - level=0.95, smooth=NULL, ...){ + level=0.95, draws=1000, smooth=NULL, ...){ sm <- object[submodel] if(!is.null(covariate)){ - plots <- list(marginal_covariate_plot(object, submodel, covariate, level, smooth)) + plots <- list(marginal_covariate_plot(object, submodel, covariate, level, + draws, smooth)) } else { vars <- all.vars(remove_offset(lme4::nobars(sm@formula))) if(length(vars) == 0) stop("No covariates in this submodel") plots <- lapply(vars, function(x){ - marginal_covariate_plot(object, submodel, x, level, smooth) + marginal_covariate_plot(object, submodel, x, level, + draws, smooth) }) } - dims <- c(1,1) nplots <- length(plots) - if(nplots > 1){ - dims <- c(ceiling(sqrt(nplots)), round(sqrt(nplots))) + if(nplots == 1){ + out_plot <- plots[[1]] + + theme(axis.title.y=ggplot2::element_text(angle=90)) + + ggplot2::labs(y = sm@name) + return(out_plot) } + dims <- c(ceiling(sqrt(nplots)), round(sqrt(nplots))) gridExtra::grid.arrange(grobs=plots, nrow=dims[1], ncol=dims[2], left=grid::textGrob(sm@name, rot=90, vjust=0.5, gp=grid::gpar(fontsize=14))) }) -marginal_covariate_plot <- function(object, submodel, covariate, level=0.95, smooth=NULL){ +marginal_covariate_plot <- function(object, submodel, covariate, level=0.95, + draws, smooth=NULL){ sm <- object[submodel] quant <- c((1-level)/2, level+(1-level)/2) vars <- all.vars(lme4::nobars(sm@formula)) stopifnot(covariate %in% vars) is_factor <- inherits(sm@data[[covariate]], "factor") if(is_factor){ - return(marg_factor_plot(object, submodel, covariate, quant)) + return(marg_factor_plot(object, submodel, covariate, quant, draws)) } - marg_numeric_plot(object, submodel, covariate, quant, smooth) + marg_numeric_plot(object, submodel, covariate, quant, draws, smooth) } -marg_numeric_plot <- function(object, submodel, covariate, quant, smooth=NULL){ +marg_numeric_plot <- function(object, submodel, covariate, quant, + draws, smooth=NULL){ sm <- object[submodel] - samples <- get_samples(object, 1000) + samples <- get_samples(object, draws) newdata <- get_baseline_df(sm)[rep(1, 1000),,drop=FALSE] var_range <- range(sm@data[[covariate]], na.rm=TRUE) newdata[[covariate]] <- seq(var_range[1], var_range[2], length.out=1000) @@ -88,12 +98,13 @@ marg_numeric_plot <- function(object, submodel, covariate, quant, smooth=NULL){ theme(axis.title.y=element_blank()) } -marg_factor_plot <- function(object, submodel, covariate, quant){ +marg_factor_plot <- function(object, submodel, covariate, quant, draws){ sm <- object[submodel] - samples <- get_samples(object, 1000) + samples <- get_samples(object, draws) nlev <- nlevels(sm@data[[covariate]]) newdata <- get_baseline_df(sm)[rep(1, nlev),,drop=FALSE] - newdata[[covariate]] <- levels(sm@data[[covariate]]) + newdata[[covariate]] <- factor(levels(sm@data[[covariate]]), + levels=levels(sm@data[[covariate]])) plot_df <- get_margplot_data(object, submodel, covariate, quant, samples, newdata) |