This function creates ggplot2::ggplot() figures of the training VLB and the validation IWAE for a list of explain() objects with approach = "vaeac". See setup_approach() for more information about the vaeac approach. By default, the function returns two figures. In the first figure, each object in explanation_list gets its own facet, while in the second figure, each criterion gets its own facet and all objects are plotted together.

vaeac_plot_eval_crit(
  explanation_list,
  plot_from_nth_epoch = 1,
  plot_every_nth_epoch = 1,
  criteria = c("VLB", "IWAE"),
  plot_type = c("method", "criterion"),
  facet_wrap_scales = "fixed",
  facet_wrap_ncol = NULL
)

Arguments

explanation_list

A list of explain() objects applied to the same data and model, all of which must use the vaeac approach. If the entries in the list are named, the function uses these names. Otherwise, it defaults to the approach names (with an integer suffix for duplicates) of the explanation objects in explanation_list.

plot_from_nth_epoch

Integer. Plot the results only from the nth epoch onward. The values in the first epochs can be large in absolute value and make the rest of the plot difficult to interpret.

plot_every_nth_epoch

Integer. Plot only every nth epoch. Useful for illustrating the overall trend, as the values can fluctuate and oscillate considerably between epochs.

criteria

Character vector. The possible options are "VLB", "IWAE", "IWAE_running". Default is the first two.

plot_type

Character vector. The possible options are "method" and "criterion". Default is to plot both.

facet_wrap_scales

String. Should the scales be fixed ("fixed", the default), free ("free"), or free in one dimension ("free_x", "free_y")?

facet_wrap_ncol

Integer. Number of columns in the facet wrap.

Value

Either a single ggplot2::ggplot() object or a list of ggplot2::ggplot() objects based on the plot_type parameter.
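
Because the return value is either one plot or a list of plots, downstream code may want to normalise it before looping over the figures. The following is a minimal sketch of one way to do that; as_plot_list() is a hypothetical helper that is not part of shapr, and the two plots built from mtcars are stand-ins so that the sketch runs without training a vaeac model.

```r
library(ggplot2)

# Hypothetical helper (not part of shapr): always return a list of plots.
# Test for the "ggplot" class first, since a plot object should not be
# mistaken for a list of plots.
as_plot_list <- function(x) {
  if (inherits(x, "ggplot")) list(x) else x
}

# Stand-in plots, used only to keep this sketch self-contained
p1 <- ggplot(mtcars, aes(wt, mpg)) + geom_line()
p2 <- ggplot(mtcars, aes(hp, mpg)) + geom_line()

single <- as_plot_list(p1)          # mimics a call with a single plot_type
both <- as_plot_list(list(p1, p2))  # mimics a call with both plot types

length(single)
length(both)
```

With the result in list form, the same loop (for example, one that saves each figure with ggplot2::ggsave()) can be used regardless of the plot_type that was requested.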

Details

See Olsen et al. (2022) or the blog post for a summary of the VLB and IWAE.
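
The plot_from_nth_epoch and plot_every_nth_epoch arguments can be thought of as row filters on the per-epoch training log. The sketch below illustrates that idea on made-up numbers. The data frame log_dt and the exact filtering rule (keep epochs at or after the starting epoch whose number is a multiple of n) are assumptions for illustration and may differ from what the function does internally.

```r
# Hypothetical training log: one criterion value per epoch (made-up numbers)
log_dt <- data.frame(
  epoch = 1:10,
  value = -c(50, 20, 12, 9, 8, 7.5, 7.2, 7.1, 7.0, 6.9)
)

plot_from_nth_epoch <- 2
plot_every_nth_epoch <- 2

# Assumed rule, for illustration only: drop the early epochs,
# then keep every nth epoch among the remaining ones
keep <- log_dt$epoch >= plot_from_nth_epoch &
  log_dt$epoch %% plot_every_nth_epoch == 0
plot_dt <- log_dt[keep, ]

plot_dt$epoch
```

In this toy log, dropping the first epoch removes the value that is by far the largest in absolute terms, which is the situation plot_from_nth_epoch is meant to handle.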

Author

Lars Henry Berge Olsen

Examples

if (FALSE) {
library(xgboost)
library(data.table)
library(shapr)

data("airquality")
data <- data.table::as.data.table(airquality)
data <- data[complete.cases(data), ]

x_var <- c("Solar.R", "Wind", "Temp", "Month")
y_var <- "Ozone"

ind_x_explain <- 1:6
x_train <- data[-ind_x_explain, ..x_var]
y_train <- data[-ind_x_explain, get(y_var)]
x_explain <- data[ind_x_explain, ..x_var]

# Fitting a basic xgboost model to the training data
model <- xgboost(data = as.matrix(x_train), label = y_train, nrounds = 100, verbose = FALSE)

# Specifying the phi_0, i.e. the expected prediction without any features
p0 <- mean(y_train)

# Train vaeac with and without paired sampling
explanation_paired <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  approach = "vaeac",
  prediction_zero = p0,
  n_samples = 1, # As we are only interested in the training of the vaeac
  vaeac.epochs = 10, # Should be higher in applications.
  vaeac.n_vaeacs_initialize = 1,
  vaeac.width = 16,
  vaeac.depth = 2,
  vaeac.extra_parameters = list(vaeac.paired_sampling = TRUE)
)

explanation_regular <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  approach = "vaeac",
  prediction_zero = p0,
  n_samples = 1, # As we are only interested in the training of the vaeac
  vaeac.epochs = 10, # Should be higher in applications.
  vaeac.width = 16,
  vaeac.depth = 2,
  vaeac.n_vaeacs_initialize = 1,
  vaeac.extra_parameters = list(vaeac.paired_sampling = FALSE)
)

# Collect the explanation objects in a named list
explanation_list <- list(
  "Regular sampling" = explanation_regular,
  "Paired sampling" = explanation_paired
)

# Call the function with the named list, will use the provided names
vaeac_plot_eval_crit(explanation_list = explanation_list)

# The function also works if we have only one method,
# but then one should only look at the method plot.
vaeac_plot_eval_crit(
  explanation_list = explanation_list[2],
  plot_type = "method"
)

# Can alter the plot
vaeac_plot_eval_crit(
  explanation_list = explanation_list,
  plot_from_nth_epoch = 2,
  plot_every_nth_epoch = 2,
  facet_wrap_scales = "free"
)

# If we only want the VLB
vaeac_plot_eval_crit(
  explanation_list = explanation_list,
  criteria = "VLB",
  plot_type = "criterion"
)

# If we only want the criterion version
tmp_fig_criterion <-
  vaeac_plot_eval_crit(explanation_list = explanation_list, plot_type = "criterion")

# Since tmp_fig_criterion is a ggplot2 object, we can alter it
# by, e.g., adding points or smooths with se bands
tmp_fig_criterion + ggplot2::geom_point(shape = "circle", size = 1, ggplot2::aes(col = Method))
tmp_fig_criterion$layers[[1]] <- NULL
tmp_fig_criterion + ggplot2::geom_smooth(method = "loess", formula = y ~ x, se = TRUE) +
  ggplot2::scale_color_brewer(palette = "Set1") +
  ggplot2::theme_minimal()
}