---
title: "Generalized Cross-Validation with Origami"
author: "Jeremy Coyle & Nima Hejazi"
date: "`r Sys.Date()`"
output: rmarkdown::html_vignette
bibliography: refs.bib
vignette: >
  %\VignetteIndexEntry{Generalized Cross-Validation with Origami}
  %\VignetteEngine{knitr::rmarkdown}
  \usepackage[utf8]{inputenc}
---

## Introduction

Cross-validation is an essential tool for evaluating how any given data analytic
procedure extends from a sample to the target population from which the sample
is derived. It has seen widespread application in all facets of statistics,
perhaps most notably statistical machine learning.  When used for model
selection, cross-validation has powerful optimality properties
[@vaart2006oracle; @vdl2007super].

<!--
and for CV-TMLE based parameter estimation, elimination of difficult-to-assess
empirical process conditions from CV-TMLE step [@Zheng:2010ua].
-->

Cross-validation works by partitioning a sample into complementary subsets,
applying a particular data analytic (statistical) routine on a subset (the
"training" set), and evaluating the routine of choice on the complementary
subset (the "testing" set). This procedure is repeated across multiple
partitions of the data. A variety of different partitioning schemes exist, such
as V-fold cross-validation and bootstrap cross-validation, many of which are
supported by `origami`. The `origami` package provides a suite of tools that
generalize the application of cross-validation to arbitrary data analytic
procedures. The use of `origami` is best illustrated by example.

---

## Cross-validation with linear regression

We'll start by examining a fairly simple data set:
```{r load_data}
data(mtcars)
head(mtcars)
```

One might be interested in examining how the efficiency of a car, as measured by
miles-per-gallon (mpg), is explained by various technical aspects of the car,
with data across a variety of different models of cars. Linear regression is
perhaps the simplest statistical procedure that could be used to make such
deductions. Let's try it out:
```{r linear_mod}
lm_mod <- lm(mpg ~ ., data = mtcars)
summary(lm_mod)
```

We can assess how well the model fits the data by comparing the predictions of
the linear model to the true outcomes observed in the data set. A standard
summary of this comparison is the mean squared error, which we can compute from
the residuals of the `lm` model object like so:

```{r get_naive_error}
err <- mean(resid(lm_mod)^2)
```

The mean squared error is `r err`. There is an important problem that arises
when we assess the model in this way -- that is, we have trained our linear
regression model on the full data set and assessed the error on the full data
set, using up all of our data. We, of course, are generally not interested in
how well the model explains variation in the observed data; rather, we are
interested in how the explanation provided by the model generalizes to a target
population from which the sample is presumably derived. Having used all of our
available data, we cannot honestly evaluate how well the model fits (and thus
explains) variation at the population level.

To resolve this issue, cross-validation allows for a particular procedure (e.g.,
linear regression) to be implemented over subsets of the data, evaluating how
well the procedure fits on a testing ("validation") set, thereby providing an
honest evaluation of the error.
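
As a minimal illustration of this idea, the sketch below holds out a fixed
subset of `mtcars` by hand, fits the model on the remaining rows, and computes
the squared error on the held-out rows only. The particular split (every fourth
row) and the object names are arbitrary choices made for this illustration; none
of this uses `origami`.

```{r holdout_sketch}
# hold out every fourth row for validation (an arbitrary, deterministic split)
holdout_idx <- seq(4, nrow(mtcars), by = 4)
train_split <- mtcars[-holdout_idx, ]
valid_split <- mtcars[holdout_idx, ]

# fit on the training rows only, then predict the held-out rows
holdout_mod <- lm(mpg ~ ., data = train_split)
holdout_preds <- predict(holdout_mod, newdata = valid_split)

# mean squared error on rows the model has not seen
holdout_mse <- mean((holdout_preds - valid_split$mpg)^2)
holdout_mse
```

A single split like this evaluates the model on only part of the data; V-fold
cross-validation repeats the idea over several splits so that every observation
is held out exactly once.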

We can easily add cross-validation to our linear regression procedure using
`origami`. First, let us define a new function to perform linear regression on a
specific partition of the data (called a "fold"):

```{r define_fun_cv_lm}
cv_lm <- function(fold, data, reg_form) {
  # get name and index of outcome variable from regression formula
  out_var <- as.character(unlist(str_split(reg_form, " "))[1])
  out_var_ind <- as.numeric(which(colnames(data) == out_var))

  # split up data into training and validation sets
  train_data <- training(data)
  valid_data <- validation(data)

  # fit linear model on training set and predict on validation set
  mod <- lm(as.formula(reg_form), data = train_data)
  preds <- predict(mod, newdata = valid_data)

  # capture results to be returned as output
  out <- list(
    coef = data.frame(t(coef(mod))),
    SE = ((preds - valid_data[, out_var_ind])^2)
  )
  return(out)
}
```

Our `cv_lm` function is rather simple: we merely split the available data into
training and validation sets, using the eponymous functions provided in
`origami`, fit the linear model on the training set, and evaluate the model on
the validation set. This is a simple example of what `origami` considers to be
`cv_fun`s -- functions for using cross-validation to perform a particular
routine over an input data set. Having defined such a function, we can simply
generate a set of partitions using `origami`'s `make_folds` function, and apply
our `cv_lm` function over the resultant `folds` object. Below, we replicate the
resubstitution estimate of the error -- we did this "by hand" above -- using
the functions `make_folds` and `cv_lm`.

```{r load_pkgs}
library(origami)
library(stringr) # used in defining the cv_lm function above
```

```{r cv_lm_resub}
# resubstitution estimate
resub <- make_folds(mtcars, fold_fun = folds_resubstitution)[[1]]
resub_results <- cv_lm(fold = resub, data = mtcars, reg_form = "mpg ~ .")
mean(resub_results$SE)
```

This (very nearly) matches the estimate of the error that we obtained above.

We can more honestly evaluate the error by _V-fold cross-validation_, which
partitions the data into __v subsets__, fitting the model on $v - 1$ of the
subsets and evaluating on the subset that was held out for testing. This is
repeated such that each subset is used for testing. We can easily apply our
`cv_lm` function using `origami`'s `cross_validate` (n.b., by default this
performs 10-fold cross-validation):

```{r cv_lm_cross_validate}
# cross-validated estimate
folds <- make_folds(mtcars)
cvlm_results <- cross_validate(
  cv_fun = cv_lm,
  folds = folds,
  data = mtcars,
  reg_form = "mpg ~ ."
)
mean(cvlm_results$SE)
```

Having performed 10-fold cross-validation, we quickly notice that our previous
estimate of the model error (by resubstitution) was quite optimistic. The honest
estimate of the error is several times larger.

---

## General workflow

Generally, `cross_validate` usage will mirror the workflow in the above example.
First, the user must define folds and a function that operates on each fold.
Once these are passed to `cross_validate`, it will map the fold function
across the folds, and combine the results in a reasonable way. More details on
each step of this process will be given below.

### Define folds

The `folds` object passed to `cross_validate` is a list of folds. Such lists can
be generated using the `make_folds` function. Each `fold` consists of a list
with a `training` index vector, a `validation` index vector, and a `fold_index`
(its order in the list of folds). This function supports a variety of
cross-validation schemes including _v-fold_ and _bootstrap_ cross-validation as
well as time series methods like _"Rolling Window"_. It can balance across
levels of a variable (`strata_ids`), and it can also keep all observations from
the same independent unit together (`cluster_ids`). See the documentation of the
`make_folds` function for details about supported cross-validation schemes and
arguments.
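
To make the structure of a `folds` object concrete, here is a small sketch that
builds a five-fold partition of `mtcars` and inspects its first fold with the
accessor functions `training`, `validation`, and `fold_index`. The object names
are illustrative, and the choice of `V = 5` is an arbitrary assumption made to
keep the output short.

```{r inspect_folds_sketch}
library(origami)

# five-fold partition of the rows of mtcars; V is passed on to the fold function
demo_folds <- make_folds(mtcars, V = 5)
length(demo_folds)

# inspect the first fold through the accessor functions
first_fold <- demo_folds[[1]]
fold_index(fold = first_fold)
validation(fold = first_fold)

# in V-fold cross-validation, the training and validation indices of a fold
# are disjoint and together cover every row
first_train_idx <- training(fold = first_fold)
first_valid_idx <- validation(fold = first_fold)
length(intersect(first_train_idx, first_valid_idx))
setequal(c(first_train_idx, first_valid_idx), seq_len(nrow(mtcars)))
```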

### Define fold function

The `cv_fun` argument to `cross_validate` is a function that will perform some
operation on each fold. The first argument to this function must be `fold`,
which will receive an individual fold object to operate on. Additional arguments
can be passed to `cv_fun` using the `...` argument to `cross_validate`. Within
this function, the convenience functions `training`, `validation` and
`fold_index` can return the various components of a fold object. They do this by
retrieving the `fold` object from their calling environment. Alternatively, the
fold can be specified directly. If `training` or `validation` is passed an
object, it will index into it in a sensible way. For instance, if it is a
vector, it will index the vector directly. If it is a `data.frame` or `matrix`,
it will index rows.
This allows the user to easily partition data into training and validation sets.
This fold function must return a named list of results containing whatever
fold-specific outputs are generated.
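
The indexing behavior described above can be sketched outside of a `cv_fun` by
passing a fold explicitly through the `fold` argument. The `toy` data set and
the object names below are made up for this illustration.

```{r fold_accessors_sketch}
library(origami)

# a small toy data set and the first fold of a five-fold partition of it
toy <- data.frame(id = 1:10, value = (1:10)^2)
toy_fold <- make_folds(toy, V = 5)[[1]]

# a vector is indexed directly
validation(toy$value, fold = toy_fold)

# a data.frame is indexed by rows
validation(toy, fold = toy_fold)

# training and validation rows together account for the full data set
nrow(training(toy, fold = toy_fold)) + nrow(validation(toy, fold = toy_fold))
```

Inside a `cv_fun`, the `fold` argument can be omitted, since the accessors look
up `fold` in the calling environment, as `cv_lm` does above.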

### Apply `cross_validate`

After defining folds, `cross_validate` can be used to map the `cv_fun` across
the `folds` using `future_lapply`. This means that it can be easily parallelized
by specifying a parallelization scheme (i.e., a `plan`). See the [`future`
package](https://github.com/futureverse/future) for more details.
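
As a rough sketch of what this could look like, the chunk below registers a
`multisession` plan with two workers before calling `cross_validate`, then
restores sequential processing. The helper `cv_intercept`, the number of
workers, and the choice of `V = 4` are assumptions made for illustration; the
chunk is not evaluated when this vignette is built, so that no background
processes are started.

```{r parallel_sketch, eval=FALSE}
library(origami)
library(future)

# a deliberately tiny cv_fun: predict mpg by the training-set mean
cv_intercept <- function(fold, data) {
  train_data <- training(data)
  valid_data <- validation(data)
  list(SE = (valid_data$mpg - mean(train_data$mpg))^2)
}

# evaluate folds in two background R sessions
plan(multisession, workers = 2)
par_results <- cross_validate(
  cv_fun = cv_intercept,
  folds = make_folds(mtcars, V = 4),
  data = mtcars
)

# return to sequential processing
plan(sequential)
mean(par_results$SE)
```

For a problem this small, the overhead of starting workers will likely outweigh
any gain; parallelization pays off when each fold is expensive to fit.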

The application of `cross_validate` generates a list of results. As described
above, each call to `cv_fun` itself returns a list of results, with different
elements for each type of result we care about. The main loop generates a list
of these individual lists of results (a sort of "meta-list"). This "meta-list"
is then inverted such that there is one element per result type (this too is a
list of the results for each fold). By default, `combine_results` is used to
combine these results type lists.

For instance, in the above `mtcars` example, the `coef` results type list
contains one `data.frame` from each fold. These are `rbind`ed together to form
one `data.frame` containing the coefficients from all folds in different rows. How
results are combined is determined automatically by examining the data types
of the results from the first fold. This can be modified by specifying a list of
arguments to `.combine_control`. See the help for `combine_results` for more
details. In most cases, the defaults should suffice.
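
The default combining behavior can be seen in a small sketch. The `cv_mean`
function below is a hypothetical `cv_fun` written for this illustration: it
returns a one-row `data.frame` and a numeric vector from each fold, so the
combined output should hold one row per fold and one squared error per
validation observation.

```{r combine_sketch}
library(origami)

# hypothetical cv_fun returning two result types per fold
cv_mean <- function(fold, data) {
  train_data <- training(data)
  valid_data <- validation(data)
  train_mean <- mean(train_data$mpg)
  list(
    fold_mean = data.frame(fold = fold_index(), mean_mpg = train_mean),
    SE = (valid_data$mpg - train_mean)^2
  )
}

combine_demo <- cross_validate(
  cv_fun = cv_mean,
  folds = make_folds(mtcars, V = 4),
  data = mtcars
)

# data.frames are stacked by row: one row per fold
nrow(combine_demo$fold_mean)

# numeric vectors are concatenated: one entry per validation observation
length(combine_demo$SE)
```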

---

## Cross-validation with random forests

To examine `origami` further, let us return to our example analysis using the
`mtcars` data set. Here, we will write a new `cv_fun` type object. As an
example, we will use L. Breiman's `randomForest`:

```{r cv_fun_randomForest}
cv_rf <- function(fold, data, reg_form) {
  # get name and index of outcome variable from regression formula
  out_var <- as.character(unlist(str_split(reg_form, " "))[1])
  out_var_ind <- as.numeric(which(colnames(data) == out_var))

  # define training and validation sets based on input object of class "folds"
  train_data <- training(data)
  valid_data <- validation(data)

  # fit Random Forest regression on training set and predict on holdout set
  mod <- randomForest(formula = as.formula(reg_form), data = train_data)
  preds <- predict(mod, newdata = valid_data)

  # define output object to be returned as list (for flexibility)
  out <- list(
    coef = data.frame(mod$coefs),
    SE = ((preds - valid_data[, out_var_ind])^2)
  )
  return(out)
}
```

Above, in writing our `cv_rf` function to cross-validate `randomForest`, we used
our previous function `cv_lm` as an example. For now, individual `cv_fun`s must
be written by hand; however, in future releases, a wrapper may be available to
support auto-generating `cv_fun`s to be used with `origami`.

Below, we use `cross_validate` to apply our new `cv_rf` function over the `folds`
object generated by `make_folds`.

```{r}
library(randomForest)
folds <- make_folds(mtcars)
cvrf_results <- cross_validate(
  cv_fun = cv_rf,
  folds = folds,
  data = mtcars,
  reg_form = "mpg ~ ."
)
mean(cvrf_results$SE)
```

Using 10-fold cross-validation (the default), we obtain an honest estimate of
the prediction error of random forests. From this, we gather that the use of
`origami`'s `cross_validate` procedure can be generalized to arbitrary estimation
techniques, given availability of an appropriate `cv_fun` function.

---

## Cross-validation with dependence: time series

Cross-validation can also be used for forecast model selection in a time series
setting. Here, the partitioning scheme mirrors the application of the
forecasting model: we train the model on past observations (either all
available or a recent subset), and then use the model to forecast (predict) the
next few observations. Consider the `AirPassengers` dataset, a monthly time
series of passenger air traffic in thousands of people.

```{r}
data(AirPassengers)
print(AirPassengers)
```

Suppose we want to pick between two forecasting models, `stl` and `arima` (the
details of these models are not important for this example). We can do that by
evaluating their forecasting performance.

```{r}
library(forecast)
folds <- make_folds(
  AirPassengers,
  fold_fun = folds_rolling_origin,
  first_window = 36,
  validation_size = 24
)
fold <- folds[[1]]

# function to calculate cross-validated squared error
cv_forecasts <- function(fold, data) {
  train_data <- training(data)
  valid_data <- validation(data)
  valid_size <- length(valid_data)

  train_ts <- ts(log10(train_data), frequency = 12)

  # borrowed from AirPassengers help
  arima_fit <- arima(
    train_ts,
    c(0, 1, 1),
    seasonal = list(order = c(0, 1, 1), period = 12)
  )
  raw_arima_pred <- predict(arima_fit, n.ahead = valid_size)
  arima_pred <- 10^raw_arima_pred$pred
  arima_MSE <- mean((arima_pred - valid_data)^2)

  # stl model
  stl_fit <- stlm(train_ts, s.window = 12)
  raw_stl_pred <- forecast(stl_fit, h = valid_size)
  stl_pred <- 10^raw_stl_pred$mean
  stl_MSE <- mean((stl_pred - valid_data)^2)

  out <- list(
    mse = data.frame(fold = fold_index(), arima = arima_MSE, stl = stl_MSE)
  )
  return(out)
}

mses <- cross_validate(
  cv_fun = cv_forecasts,
  folds = folds,
  data = AirPassengers
)$mse
colMeans(mses[, c("arima", "stl")])
```
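
To see how the rolling-origin scheme used above respects the time ordering, the
following sketch rebuilds the same folds and prints the index ranges of the
training and validation sets for the first two folds. The object names are
illustrative; the `first_window` and `validation_size` values are copied from
the example above.

```{r rolling_origin_sketch}
library(origami)
data(AirPassengers)

ro_folds <- make_folds(
  AirPassengers,
  fold_fun = folds_rolling_origin,
  first_window = 36,
  validation_size = 24
)
length(ro_folds)

# index ranges of the training and validation sets in the first two folds
for (ro_fold in ro_folds[1:2]) {
  print(range(training(fold = ro_fold)))
  print(range(validation(fold = ro_fold)))
}

# check that every validation window lies strictly after its training window
ordered_ok <- vapply(
  ro_folds,
  function(f) max(training(fold = f)) < min(validation(fold = f)),
  logical(1)
)
all(ordered_ok)
```

Because each model is fit only on observations that precede its validation
window, the resulting error estimates reflect genuine out-of-sample forecasts.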

---

## Session Information

```{r sessionInfo, echo=FALSE}
sessionInfo()
```

---

## References
