---
title: "Explaining nestedcv models with Shapley values"
author: "Myles Lewis"
output:
  html_document:
    fig_width: 6
vignette: >
  %\VignetteIndexEntry{Explaining nestedcv models with Shapley values}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include = FALSE}
knitr::knit_hooks$set(pngquant = knitr::hook_pngquant)
knitr::opts_chunk$set(
  collapse = TRUE,
  warning = FALSE
)
```

`nestedcv` provides two methods for understanding fitted models. The simplest of
these is to plot variable importance. The newer method is to calculate Shapley
values for each predictor.

### Variable importance & stability

For regression models such as glmnet, variable importance is represented by the
model coefficients, ordered by absolute value from largest to smallest. In
addition, the outer folds of nested CV allow us to show the variance of model
coefficients across each outer fold plus the final model, and hence see how
stable the model is. We can also overlay how often predictors are selected in
each model to give a sense of the stability of predictor selection.
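To make the idea concrete, the following base R sketch summarises a small,
made-up coefficient matrix. The object names (`toy_coefs`, `toy_stability`) and
all values are hypothetical, and this is not the code used by `var_stability()`;
it only illustrates the kind of summary involved: the mean and spread of each
coefficient across models, and how often each predictor is selected (non-zero).

```{r}
# Hypothetical coefficients: rows = predictors, columns = outer folds + final
# model. A value of 0 means the predictor was not selected in that model.
toy_coefs <- cbind(
  fold1 = c(lstat = -3.1, rm = 2.8, age = 0.0, tax = -0.4),
  fold2 = c(lstat = -2.9, rm = 3.0, age = 0.2, tax = 0.0),
  fold3 = c(lstat = -3.3, rm = 2.6, age = 0.0, tax = -0.5),
  final = c(lstat = -3.0, rm = 2.9, age = 0.0, tax = -0.3)
)

toy_stability <- data.frame(
  mean = rowMeans(toy_coefs),            # average coefficient across models
  sd = apply(toy_coefs, 1, sd),          # spread across models
  frequency = rowSums(toy_coefs != 0)    # number of models selecting the predictor
)

# Order predictors by absolute mean coefficient, largest first
toy_stability <- toy_stability[order(-abs(toy_stability$mean)), ]
toy_stability
```

A predictor with a large mean, small spread and high selection frequency is one
the model relies on consistently across folds.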

In the example below using the Boston housing dataset, a glmnet regression model
is fitted and the variable importance for predictors is shown based on the
coefficients in the final model and tuned models from 10 outer folds. `min_1se`
is set to 1, which is the equivalent of specifying `s = "lambda.1se"` with
`glmnet`, to encourage a more sparse model.

```{r}
library(nestedcv)
library(mlbench)  # Boston housing dataset

data(BostonHousing2)
dat <- BostonHousing2
y <- dat$cmedv
x <- subset(dat, select = -c(cmedv, medv, town, chas))

# Fit a glmnet model using nested CV
set.seed(1, "L'Ecuyer-CMRG")
fit <- nestcv.glmnet(y, x, family = "gaussian",
                     min_1se = 1, alphaSet = 1, cv.cores = 2)
vs <- var_stability(fit)
vs
```

Variable stability can be plotted using `plot_var_stability()`.

```{r, fig.dim = c(10, 5)}
p1 <- plot_var_stability(fit)

# overlay directionality using colour
p2 <- plot_var_stability(fit, final = FALSE, direction = 1)

# or show directionality with the sign of the variable importance
# plot_var_stability(fit, final = FALSE, percent = FALSE)

ggpubr::ggarrange(p1, p2, ncol=2)
```

By default only the predictors chosen in the final model are shown. If the
argument `final` is set to `FALSE`, all predictors are shown. This helps to show
how often each one is selected, which is useful when pushing sparsity in models.
Original coefficients can be shown instead of being scaled as a percentage by
setting `percent = FALSE`.

The frequency with which each variable is selected in the outer folds as well as
the final model is shown as bubble size. This is helpful for sparse models or
with filters, to determine how often variables end up in the model in each fold.

We can overlay directionality using either colour (`direction = 1`) or the sign
of the variable importance to splay the plot (`direction = 2`).

For glmnet, the direction of effect is taken directly from the sign of model
coefficients. For `caret` models, direction of effect is not readily available,
so as a substitute, the directionality of each predictor is determined by the
function `var_direction()` using the sign of a *t*-test for binary
classification or the sign of regression coefficient for continuous outcomes
(not available for multiclass caret models). To better understand the
relationship and direction of effect of each predictor within the final model,
we recommend using SHAP values.
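The two substitute measures of direction can be sketched in base R as follows.
This is an illustration on simulated data, not the source of `var_direction()`;
the object names are hypothetical, and the choice of which class acts as the
reference level is an assumption made here for the example.

```{r}
set.seed(42)
toy_n <- 100
toy_pred <- rnorm(toy_n)

# Binary outcome: the predictor tends to be higher in class "b"
toy_class <- factor(ifelse(toy_pred + rnorm(toy_n) > 0, "b", "a"))
# t.test() with a formula compares the first level against the second
# ("a" minus "b"), so the sign is flipped to express direction relative to "b"
toy_t <- t.test(toy_pred ~ toy_class)
dir_binary <- -sign(unname(toy_t$statistic))

# Continuous outcome: simulated with a negative relationship to the predictor
toy_outcome <- -2 * toy_pred + rnorm(toy_n)
dir_continuous <- sign(unname(coef(lm(toy_outcome ~ toy_pred))["toy_pred"]))

c(binary = dir_binary, continuous = dir_continuous)
```

Both measures look at each predictor on its own, so they can disagree with the
direction of effect inside a multivariable model.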

Alternatively a barplot is available using `barplot_var_stability()`.

The `caret` package allows variable importance to be calculated for other model
types. This is model dependent. For example, with random forest, variable
importance is usually calculated as the mean decrease in Gini impurity each time
a variable is chosen in a tree. With caret models, you may have to load the
appropriate package for that model to calculate the variable importance, e.g.
this is necessary for GBM models with the `gbm` package.

To change the colour scheme for `direction = 1`, override the fill scale using
`scale_fill_manual()`.

```{r, eval=FALSE}
# change bubble colour scheme
p2 + scale_fill_manual(values=c("orange", "green3"))
```

### Ranking variables

Variable importance can be displayed by showing how variables are ranked by
variable importance across the outer CV folds.

```{r, fig.dim = c(10, 5)}
# beeswarm plot of variable ranks 
p1 <- plot_var_ranks(fit)

# histogram of variable ranks
p2 <- hist_var_ranks(fit)

ggpubr::ggarrange(p1, p2, ncol=2)
```

All of the above plotting functions can also be applied to repeated nested CV
objects generated using `repeatcv()`, provided the argument `keep = TRUE` is
set.
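As a rough illustration of what ranking across folds means, the base R sketch
below ranks predictors within each fold by absolute coefficient. The matrix
`fold_coefs` and its values are made up for the example, and the handling of
ties and unselected predictors inside `plot_var_ranks()` may differ.

```{r}
# Hypothetical coefficients: rows = predictors, columns = outer CV folds
fold_coefs <- cbind(
  fold1 = c(lstat = -3.1, rm = 2.8, ptratio = -0.9, tax = -0.4),
  fold2 = c(lstat = -2.7, rm = 3.0, ptratio = -0.2, tax = -0.6),
  fold3 = c(lstat = -3.3, rm = 2.6, ptratio = -1.1, tax = -0.5)
)

# Rank within each fold: 1 = largest absolute coefficient
fold_ranks <- apply(-abs(fold_coefs), 2, rank, ties.method = "min")
fold_ranks

# Mean rank across folds gives an overall ordering
sort(rowMeans(fold_ranks))
```

Predictors whose rank barely changes between folds are ranked stably, whereas
predictors that swap places (here `lstat` and `rm` in the second fold) are less
clearly separated.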

### Explainable AI with Shapley values

The original implementation of [shap](https://github.com/shap/shap) by
Scott Lundberg is a Python package. We suggest using the R package
[fastshap](https://cran.r-project.org/package=fastshap) for examining `nestedcv`
models since it works with both classification and regression, and with any
model type (regression models such as glmnet or tree-based models such as random
forest, GBM or xgboost).

In the example below using the same glmnet regression model, the variable
importance for predictors is measured using Shapley values. The function
`explain()` from the `fastshap` package needs a wrapper function for prediction
using the model. `nestedcv` provides `pred_nestcv_glmnet` which is a wrapper
function for binary classification or regression with `nestcv.glmnet` fitted
models.

```{r}
library(fastshap)

# Generate SHAP values using fastshap::explain
# Only using 5 repeats here for speed, but recommend higher values of nsim
sh <- explain(fit, X=x, pred_wrapper = pred_nestcv_glmnet, nsim = 5)

# Plot overall variable importance
plot_shap_bar(sh, x)
```

`plot_shap_bar()` ranks predictors in terms of variable importance calculated as
mean absolute SHAP value. The plot also overlays colour for the direction of the
main effect of each variable on the model, based on correlating the value of
each variable against the SHAP value to see if the overall correlation for that
variable is positive or negative. For regression models such as glmnet this
corresponds to the sign of each coefficient. For more complex models with
interactions the direction of effect may be variable and non-linear.
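The two quantities described above can be sketched in base R with a toy
example. The data and object names are hypothetical, and the SHAP matrix is
written down directly rather than estimated: for a linear model with independent
features, the Shapley value of a feature is its coefficient multiplied by the
feature's deviation from its mean. This is not the internal code of
`plot_shap_bar()`.

```{r}
set.seed(1)
toy_X <- data.frame(a = rnorm(50), b = rnorm(50))

# Shapley values for the toy linear model f(x) = 2 * a - 0.5 * b
toy_shap <- cbind(
  a = 2 * (toy_X$a - mean(toy_X$a)),
  b = -0.5 * (toy_X$b - mean(toy_X$b))
)

# Variable importance: mean absolute SHAP value per predictor
shap_importance <- colMeans(abs(toy_shap))

# Direction: sign of the correlation between each predictor and its SHAP values
shap_direction <- mapply(function(value, shap) sign(cor(value, shap)),
                         toy_X, as.data.frame(toy_shap))

data.frame(importance = shap_importance, direction = shap_direction)
```

Here the direction recovers the sign of each coefficient, as expected for a
purely linear model.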

Note that these SHAP plots only show the final fitted model (on the whole data),
whereas the variable stability plots examine models across the outer CV folds as
well as the final model.

Take care if `x` contains a very large number of predictors (see
[Troubleshooting](#troubleshooting)).

`nestedcv` also provides a quick plotting function `plot_shap_beeswarm` for
generating ggplot2 beeswarm plots similar to those made by the original Python
`shap` package.

```{r}
# Plot beeswarm plot
plot_shap_beeswarm(sh, x, size = 1)
```

`size` can be set to control the size of points. `cex` controls the amount of
overlap of the beeswarm. `scheme` controls the colour scheme as a vector of 3
colours. Alternatively try the [shapviz](https://cran.r-project.org/package=shapviz) 
package.

The process for a caret model fitted using `nestedcv` is similar. Use the
`pred_train` wrapper when calling `explain()`.

```{r, eval=FALSE}
# Only 3 outer folds to speed up process
fit <- nestcv.train(y, x,
                    method = "gbm",
                    n_outer_folds = 3, cv.cores = 2)

# Only using 5 repeats here for speed, but recommend higher values of nsim
sh <- explain(fit, X=x, pred_wrapper = pred_train, nsim = 5)
plot_shap_beeswarm(sh, x, size = 1)
```

For multinomial classification, a wrapper is needed for each class. For
`nestcv.glmnet` models, we provide a function factory which acts as a wrapper
for each class: `pred_nestcv_glmnet_class(1)`, `pred_nestcv_glmnet_class(2)`
etc.

For `nestcv.train()` models, similarly we provide `pred_train_class(cl)`,
where `cl` specifies which class is being predicted.
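The function factory pattern itself can be illustrated in base R. The sketch
below uses a hypothetical stand-in for a fitted multiclass model (`toy_model`,
`toy_predict_probs()`) and a hypothetical factory `make_class_wrapper()`; none
of these are `nestedcv` functions. The factory fixes the class index and returns
a function with the `(object, newdata)` signature that `explain()` requires of a
`pred_wrapper`.

```{r}
# Hypothetical stand-in for a fitted multiclass model: a list holding a weight
# matrix with one column per class (not a real nestedcv object)
set.seed(3)
toy_model <- list(
  weights = matrix(rnorm(4 * 3), nrow = 4,
                   dimnames = list(NULL, levels(iris$Species)))
)

# Returns an n x 3 matrix of class probabilities via a softmax of linear scores
toy_predict_probs <- function(object, newdata) {
  scores <- as.matrix(newdata) %*% object$weights
  e <- exp(scores - apply(scores, 1, max))  # subtract row maxima for stability
  e / rowSums(e)
}

# Function factory: returns a prediction wrapper for class `cl`
make_class_wrapper <- function(cl) {
  force(cl)
  function(object, newdata) unname(toy_predict_probs(object, newdata)[, cl])
}

toy_wrappers <- lapply(1:3, make_class_wrapper)
toy_probs <- sapply(toy_wrappers, function(f) f(toy_model, iris[, 1:4]))
range(rowSums(toy_probs))  # the three class probabilities sum to 1 in each row
```

Each wrapper returns a single numeric vector, so SHAP values are computed one
class at a time.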

As a toy example, we show the iris dataset which has 3 classes.

```{r, fig.width = 9, fig.height = 3.5}
library(ggplot2)
data("iris")
dat <- iris
y <- dat$Species
x <- dat[, 1:4]

# Only 3 outer folds to speed up process
fit <- nestcv.glmnet(y, x, family = "multinomial", n_outer_folds = 3, alphaSet = 0.6)

# SHAP values for each of the 3 classes
sh1 <- explain(fit, X=x, pred_wrapper = pred_nestcv_glmnet_class(1), nsim = 5)
sh2 <- explain(fit, X=x, pred_wrapper = pred_nestcv_glmnet_class(2), nsim = 5)
sh3 <- explain(fit, X=x, pred_wrapper = pred_nestcv_glmnet_class(3), nsim = 5)

s1 <- plot_shap_bar(sh1, x, sort = FALSE) +
  ggtitle("Setosa")
s2 <- plot_shap_bar(sh2, x, sort = FALSE) +
  ggtitle("Versicolor")
s3 <- plot_shap_bar(sh3, x, sort = FALSE) +
  ggtitle("Virginica")

ggpubr::ggarrange(s1, s2, s3, ncol=3, legend = "bottom", common.legend = TRUE)
```

Alternatively, use beeswarm plots.

```{r, fig.width = 9.5, fig.height = 3.5}
s1 <- plot_shap_beeswarm(sh1, x, sort = FALSE, cex = 0.7) +
  ggtitle("Setosa")
s2 <- plot_shap_beeswarm(sh2, x, sort = FALSE, cex = 0.7) +
  ggtitle("Versicolor")
s3 <- plot_shap_beeswarm(sh3, x, sort = FALSE, cex = 0.7) +
  ggtitle("Virginica")

ggpubr::ggarrange(s1, s2, s3, ncol=3, legend = "right", common.legend = TRUE)
```

### One-hot encoding for categorical predictors

Binary factors generally should not have any handling problems as most functions
in `nestedcv` convert them to 0 and 1, and directionality can be interpreted.
However, multi-level factors with 3 or more levels are not clearly interpretable
with SHAP values. We recommend these are one-hot encoded with dummy variables
using the function `one_hot()`. This function accepts dataframes and can encode
multiple factors and character columns, producing a matrix as a result.
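To show what one-hot encoding produces, here is a base R sketch using
`model.matrix()` on a made-up data frame (`toy_df` is hypothetical). It is only
an illustration of the encoding: `one_hot()` is the recommended route and its
output, column naming and handling of binary factors may differ.

```{r}
toy_df <- data.frame(
  size = c(1.2, 3.4, 2.2, 0.7),
  colour = factor(c("red", "green", "blue", "green"))
)

# Dropping the intercept (`- 1`) gives the factor one indicator column per level
toy_onehot <- model.matrix(~ . - 1, data = toy_df)
colnames(toy_onehot)
toy_onehot
```

Note that with several factors in the formula, `model.matrix()` expands only the
first factor to a full set of indicator columns, which is one reason to prefer a
dedicated function.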

### Troubleshooting

`explain()` runs progressively more slowly as the number of input predictors
increases. If you have a very large number of predictors in `x`, it is much
better to pass only the subset of predictors in the final model. These are
stored in fitted `nestedcv` objects in list element `xsub`. Alternatively, you
can subset `x` using `final_vars`. For example:

```{r eval=FALSE}
sh <- explain(fit, X = fit$xsub, pred_wrapper = pred_nestcv_glmnet, nsim = 5)
plot_shap_bar(sh, fit$xsub)

sh <- explain(fit, X = x[, fit$final_vars], pred_wrapper = pred_nestcv_glmnet, nsim = 5)
plot_shap_bar(sh, x[, fit$final_vars])
```

### Citation

If you use this package, please cite as:

Lewis MJ, Spiliopoulou A, Goldmann K, Pitzalis C, McKeigue P, Barnes MR (2023).
nestedcv: an R package for fast implementation of nested cross-validation with
embedded feature selection designed for transcriptomics and high dimensional
data. *Bioinformatics Advances*. https://doi.org/10.1093/bioadv/vbad048

### References

Brandon Greenwell. [fastshap](https://cran.r-project.org/package=fastshap): Fast
Approximate Shapley Values

Lundberg SM & Lee SI (2017). A Unified Approach to Interpreting Model
Predictions. *Advances in Neural Information Processing Systems* 30; 4768–4777.
http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf

Lundberg SM et al (2020). From Local Explanations to Global Understanding with
Explainable AI for Trees. *Nat Mach Intell* 2(1); 56–67.
https://arxiv.org/abs/1905.04610
