vaeac
modelsR/approach_vaeac.R
plot_vaeac_eval_crit.Rd
This function makes (ggplot2::ggplot()
) figures of the training VLB and the validation IWAE for a list
of explain()
objects with approach = "vaeac"
. See setup_approach()
for more information about the
vaeac
approach. Two figures are returned by the function. In the figure, each object in explanation_list
gets
its own facet, while in the second figure, we plot the criteria in each facet for all objects.
A list of explain()
objects applied to the same data, model, and
vaeac
must be the used approach. If the entries in the list is named, then the function use
these names. Otherwise, it defaults to the approach names (with integer suffix for duplicates)
for the explanation objects in explanation_list
.
Integer. If we are only plot the results form the nth epoch and so forth. The first epochs can be large in absolute value and make the rest of the plot difficult to interpret.
Integer. If we are only to plot every nth epoch. Usefully to illustrate the overall trend, as there can be a lot of fluctuation and oscillation in the values between each epoch.
Character vector. The possible options are "VLB", "IWAE", "IWAE_running". Default is the first two.
Character vector. The possible options are "method" and "criterion". Default is to plot both.
String. Should the scales be fixed ("fixed
", the default),
free ("free
"), or free in one dimension ("free_x
", "free_y
").
Integer. Number of columns in the facet wrap.
Either a single ggplot2::ggplot()
object or a list of ggplot2::ggplot()
objects based on the
plot_type
parameter.
See Olsen et al. (2022) or the blog post for a summary of the VLB and IWAE.
if (FALSE) { # \dontrun{
library(xgboost)
library(data.table)
library(shapr)
data("airquality")
data <- data.table::as.data.table(airquality)
data <- data[complete.cases(data), ]
x_var <- c("Solar.R", "Wind", "Temp", "Month")
y_var <- "Ozone"
ind_x_explain <- 1:6
x_train <- data[-ind_x_explain, ..x_var]
y_train <- data[-ind_x_explain, get(y_var)]
x_explain <- data[ind_x_explain, ..x_var]
# Fitting a basic xgboost model to the training data
model <- xgboost(data = as.matrix(x_train), label = y_train, nround = 100, verbose = FALSE)
# Specifying the phi_0, i.e. the expected prediction without any features
p0 <- mean(y_train)
# Train vaeac with and without paired sampling
explanation_paired <- explain(
model = model,
x_explain = x_explain,
x_train = x_train,
approach = approach,
phi0 = p0,
n_MC_samples = 1, # As we are only interested in the training of the vaeac
vaeac.epochs = 10, # Should be higher in applications.
vaeac.n_vaeacs_initialize = 1,
vaeac.width = 16,
vaeac.depth = 2,
vaeac.extra_parameters = list(vaeac.paired_sampling = TRUE)
)
explanation_regular <- explain(
model = model,
x_explain = x_explain,
x_train = x_train,
approach = approach,
phi0 = p0,
n_MC_samples = 1, # As we are only interested in the training of the vaeac
vaeac.epochs = 10, # Should be higher in applications.
vaeac.width = 16,
vaeac.depth = 2,
vaeac.n_vaeacs_initialize = 1,
vaeac.extra_parameters = list(vaeac.paired_sampling = FALSE)
)
# Collect the explanation objects in an named list
explanation_list <- list(
"Regular sampling" = explanation_regular,
"Paired sampling" = explanation_paired
)
# Call the function with the named list, will use the provided names
plot_vaeac_eval_crit(explanation_list = explanation_list)
# The function also works if we have only one method,
# but then one should only look at the method plot.
plot_vaeac_eval_crit(
explanation_list = explanation_list[2],
plot_type = "method"
)
# Can alter the plot
plot_vaeac_eval_crit(
explanation_list = explanation_list,
plot_from_nth_epoch = 2,
plot_every_nth_epoch = 2,
facet_wrap_scales = "free"
)
# If we only want the VLB
plot_vaeac_eval_crit(
explanation_list = explanation_list,
criteria = "VLB",
plot_type = "criterion"
)
# If we want only want the criterion version
tmp_fig_criterion <-
plot_vaeac_eval_crit(explanation_list = explanation_list, plot_type = "criterion")
# Since tmp_fig_criterion is a ggplot2 object, we can alter it
# by, e.g,. adding points or smooths with se bands
tmp_fig_criterion + ggplot2::geom_point(shape = "circle", size = 1, ggplot2::aes(col = Method))
tmp_fig_criterion$layers[[1]] <- NULL
tmp_fig_criterion + ggplot2::geom_smooth(method = "loess", formula = y ~ x, se = TRUE) +
ggplot2::scale_color_brewer(palette = "Set1") +
ggplot2::theme_minimal()
} # }