aboutsummaryrefslogtreecommitdiff
path: root/R/traceplot.R
blob: e029d182ba03d4d748f605de08f77c18f11432dd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Draw traceplots for a series of parameters.
#
# x:          model object; validated by check_class() before plotting.
# parameters: parameter names to plot; passed on to get_plot_info()
#             (what NULL selects is decided there -- confirm in that file).
# Rhat_min:   passed on to get_plot_info(); presumably filters parameters
#             by their Rhat value -- confirm in get_plot_info().
# per_plot:   passed on to get_plot_info(); presumably panels per page.
# ask:        passed on to get_plot_info(); when NULL it defaults to
#             grDevices::dev.interactive(orNone = TRUE).
#
# Returns NULL invisibly; the function is called for its plotting side effect.
traceplot <- function(x, parameters = NULL, Rhat_min = NULL,
                      per_plot = 9, ask = NULL) {

  # Check input class and get basic plot settings
  check_class(x)
  if (is.null(ask)) {
    ask <- grDevices::dev.interactive(orNone = TRUE)
  }
  plot_info <- get_plot_info(x, parameters, per_plot, ask, Rhat_min)

  # Apply the layout settings and always restore the caller's par(),
  # even if a panel fails part-way through.
  old_par <- graphics::par(plot_info$new_par)
  on.exit(graphics::par(old_par), add = TRUE)

  # seq_len() rather than 1:n: with no parameters selected, 1:0 would
  # iterate over i = 1, 0 and call param_trace() with an NA name.
  params <- plot_info$params
  n <- length(params)
  for (i in seq_len(n)) {
    # Outer-margin axis labels are written when i is a multiple of
    # per_plot (presumably the last panel on a page) and on the final panel.
    m_labels <- (i %% plot_info$per_plot == 0) || (i == n)
    param_trace(x, params[i], m_labels = m_labels)
  }

  invisible(NULL)
}

# Draw one traceplot panel for a single parameter, with one line per
# column of the sample matrix (one per chain), annotated with its Rhat.
#
# x:         model object; x$samples is handed to mcmc_to_mat(), and
#            x$summary must have a row named `parameter` and an 'Rhat' column.
# parameter: name of the parameter to plot.
# m_labels:  if TRUE, also write "Iteration" / "Value" in the outer margins.
#
# Returns NULL invisibly; the function is called for its plotting side effect.
param_trace <- function(x, parameter, m_labels = FALSE) {

  # Samples are presumably an iterations x chains matrix -- confirm in
  # mcmc_to_mat(). Rhat is shown with exactly three decimals.
  vals <- mcmc_to_mat(x$samples, parameter)
  Rhat <- sprintf("%.3f", round(x$summary[parameter, 'Rhat'], 3))

  # matplot() draws every column in one call, so a single chain needs no
  # special case (unlike a plot() + lines() loop starting at column 2).
  cols <- grDevices::rainbow(ncol(vals))
  graphics::matplot(seq_len(nrow(vals)), vals, type = 'l', lty = 1,
                    col = cols, xlab = 'Iterations', ylab = 'Value',
                    main = paste('Trace of', parameter))

  # Add Rhat value
  graphics::legend('bottomright', legend = bquote(hat(R) == .(Rhat)),
                   bty = 'o', bg = 'white', cex = 1.2)

  # Add margin labels if requested by the caller
  if (m_labels) {
    graphics::mtext("Iteration", side = 1, line = 1.5, outer = TRUE)
    graphics::mtext("Value", side = 2, line = 1.5, outer = TRUE)
  }

  invisible(NULL)
}

# get_plot_info(), the general helper for setting up plots, now lives in its own file