aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKen Kellner <ken@kenkellner.com>2021-09-13 17:06:50 -0400
committerKen Kellner <ken@kenkellner.com>2021-09-13 17:06:50 -0400
commit3b6274c8ca6a0351bcf97a641f1f819be8e09fac (patch)
tree9117956d60ec0cababbbf56e006787ecc880151f
parentfd041c9628001d58f4256336f8eeed5899ddf69b (diff)
More changes to plot_marginal: add draws argument, return ggplot when only one covariate, fix factor problem
-rw-r--r--R/plot_marginal.R41
1 files changed, 26 insertions, 15 deletions
diff --git a/R/plot_marginal.R b/R/plot_marginal.R
index fedd0b6..bd2ea2e 100644
--- a/R/plot_marginal.R
+++ b/R/plot_marginal.R
@@ -12,6 +12,8 @@ setGeneric("plot_marginal", function(object, ...) standardGeneric("plot_marginal
#' @param submodel Submodel to get plots for, for example \code{"det"}
#' @param covariate Plot a specific covariate; provide the name as a string
#' @param level Probability mass to include in the uncertainty interval
+#' @param draws Number of draws from the posterior to use when generating the
+#' plots. If fewer than \code{draws} are available, they are all used
#' @param smooth Smoother span (f) value used for LOWESS smoothing of the
#' upper and lower credible interval bounds for a continuous covariate.
#' No smoothing is done if \code{NULL}. A reasonable value to try is 0.05.
@@ -19,7 +21,8 @@ setGeneric("plot_marginal", function(object, ...) standardGeneric("plot_marginal
#' with caution
#' @param ... Currently ignored
#'
-#' @return A \code{ggplot}, potentially with multiple panels, one per covariate
+#' @return A \code{ggplot} if a single covariate is plotted, or an object
+#' of class \code{grob} if there are multiple covariates/panels
#'
#' @aliases plot_marginal
#' @include fit.R
@@ -27,46 +30,53 @@ setGeneric("plot_marginal", function(object, ...) standardGeneric("plot_marginal
#' @importFrom ggplot2 geom_errorbar
#' @export
setMethod("plot_marginal", "ubmsFit", function(object, submodel, covariate=NULL,
- level=0.95, smooth=NULL, ...){
+ level=0.95, draws=1000, smooth=NULL, ...){
sm <- object[submodel]
if(!is.null(covariate)){
- plots <- list(marginal_covariate_plot(object, submodel, covariate, level, smooth))
+ plots <- list(marginal_covariate_plot(object, submodel, covariate, level,
+ draws, smooth))
} else {
vars <- all.vars(remove_offset(lme4::nobars(sm@formula)))
if(length(vars) == 0) stop("No covariates in this submodel")
plots <- lapply(vars, function(x){
- marginal_covariate_plot(object, submodel, x, level, smooth)
+ marginal_covariate_plot(object, submodel, x, level,
+ draws, smooth)
})
}
- dims <- c(1,1)
nplots <- length(plots)
- if(nplots > 1){
- dims <- c(ceiling(sqrt(nplots)), round(sqrt(nplots)))
+ if(nplots == 1){
+ out_plot <- plots[[1]] +
+ theme(axis.title.y=ggplot2::element_text(angle=90)) +
+ ggplot2::labs(y = sm@name)
+ return(out_plot)
}
+ dims <- c(ceiling(sqrt(nplots)), round(sqrt(nplots)))
gridExtra::grid.arrange(grobs=plots, nrow=dims[1], ncol=dims[2],
left=grid::textGrob(sm@name, rot=90, vjust=0.5,
gp=grid::gpar(fontsize=14)))
})
-marginal_covariate_plot <- function(object, submodel, covariate, level=0.95, smooth=NULL){
+marginal_covariate_plot <- function(object, submodel, covariate, level=0.95,
+ draws, smooth=NULL){
sm <- object[submodel]
quant <- c((1-level)/2, level+(1-level)/2)
vars <- all.vars(lme4::nobars(sm@formula))
stopifnot(covariate %in% vars)
is_factor <- inherits(sm@data[[covariate]], "factor")
if(is_factor){
- return(marg_factor_plot(object, submodel, covariate, quant))
+ return(marg_factor_plot(object, submodel, covariate, quant, draws))
}
- marg_numeric_plot(object, submodel, covariate, quant, smooth)
+ marg_numeric_plot(object, submodel, covariate, quant, draws, smooth)
}
-marg_numeric_plot <- function(object, submodel, covariate, quant, smooth=NULL){
+marg_numeric_plot <- function(object, submodel, covariate, quant,
+ draws, smooth=NULL){
sm <- object[submodel]
- samples <- get_samples(object, 1000)
+ samples <- get_samples(object, draws)
newdata <- get_baseline_df(sm)[rep(1, 1000),,drop=FALSE]
var_range <- range(sm@data[[covariate]], na.rm=TRUE)
newdata[[covariate]] <- seq(var_range[1], var_range[2], length.out=1000)
@@ -88,12 +98,13 @@ marg_numeric_plot <- function(object, submodel, covariate, quant, smooth=NULL){
theme(axis.title.y=element_blank())
}
-marg_factor_plot <- function(object, submodel, covariate, quant){
+marg_factor_plot <- function(object, submodel, covariate, quant, draws){
sm <- object[submodel]
- samples <- get_samples(object, 1000)
+ samples <- get_samples(object, draws)
nlev <- nlevels(sm@data[[covariate]])
newdata <- get_baseline_df(sm)[rep(1, nlev),,drop=FALSE]
- newdata[[covariate]] <- levels(sm@data[[covariate]])
+ newdata[[covariate]] <- factor(levels(sm@data[[covariate]]),
+ levels=levels(sm@data[[covariate]]))
plot_df <- get_margplot_data(object, submodel, covariate, quant,
samples, newdata)