## ----load_data----------------------------------------------------------------
# load the built-in mtcars data set and preview its first six rows
data("mtcars")
head(mtcars, n = 6L)

## ----linear_mod---------------------------------------------------------------
# regress mpg on every other column of mtcars and show the fit summary
lm_mod <- lm(formula = mpg ~ ., data = mtcars)
summary(object = lm_mod)

## ----get_naive_error----------------------------------------------------------
# naive (in-sample) error: mean of the squared residuals of the full-data fit
err <- mean(residuals(object = lm_mod)^2)

## ----define_fun_cv_lm---------------------------------------------------------
# Fit a linear model on the training split of one origami fold and score it on
# the validation split.
#
# Args:
#   fold:     an origami fold (an element of the list returned by make_folds()).
#   data:     data frame holding the outcome and the predictors.
#   reg_form: model formula as a string, e.g. "mpg ~ ." or "mpg~.".
#
# Returns a list with
#   coef: one-row data frame of the fitted coefficients;
#   SE:   squared prediction errors for the validation observations.
#
# Errors if the left-hand side of `reg_form` is not a single column of `data`.
cv_lm <- function(fold, data, reg_form) {
  reg_formula <- as.formula(reg_form)

  # The outcome is the left-hand side of the parsed formula; reading it from
  # the formula (not by splitting the string on spaces) also accepts "mpg~.".
  lhs <- if (length(reg_formula) == 3L) reg_formula[[2L]] else NULL
  if (!is.name(lhs) || !as.character(lhs) %in% colnames(data)) {
    stop("the left-hand side of `reg_form` must name a column of `data`",
      call. = FALSE
    )
  }
  out_var <- as.character(lhs)

  # pass `fold` explicitly instead of letting training()/validation() find it
  train_data <- training(data, fold)
  valid_data <- validation(data, fold)

  # fit on the training set, predict on the validation set
  mod <- lm(reg_formula, data = train_data)
  preds <- predict(mod, newdata = valid_data)

  list(
    coef = data.frame(t(coef(mod))),
    SE = (preds - valid_data[[out_var]])^2
  )
}

## ----load_pkgs----------------------------------------------------------------
library(origami)
library(stringr) # used in defining the cv_lm function above

## ----cv_lm_resub--------------------------------------------------------------
# resubstitution estimate: make_folds() returns a list of folds; the
# resubstitution scheme is scored through its first (and only used) fold
resub <- make_folds(mtcars, fold_fun = folds_resubstitution)[[1L]]
resub_results <- cv_lm(resub, mtcars, "mpg ~ .")
mean(resub_results[["SE"]])

## ----cv_lm_cross_valdate------------------------------------------------------
# cross-validated estimate
# Fix the RNG seed so that any random assignment of observations to folds by
# make_folds() -- and therefore the printed estimate -- is reproducible.
set.seed(6357)
folds <- make_folds(mtcars)
cvlm_results <- cross_validate(
  cv_fun = cv_lm,
  folds = folds,
  data = mtcars,
  reg_form = "mpg ~ ."
)
mean(cvlm_results$SE)

## ----cv_fun_randomForest------------------------------------------------------
# Fit a random forest regression on the training split of one origami fold and
# score it on the validation split.
#
# Args:
#   fold:     an origami fold (an element of the list returned by make_folds()).
#   data:     data frame holding the outcome and the predictors.
#   reg_form: model formula as a string, e.g. "mpg ~ ." or "mpg~.".
#
# Returns a list with
#   coef: data frame built from `mod$coefs` (see note below);
#   SE:   squared prediction errors for the validation observations.
#
# Errors if the left-hand side of `reg_form` is not a single column of `data`.
cv_rf <- function(fold, data, reg_form) {
  reg_formula <- as.formula(reg_form)

  # The outcome is the left-hand side of the parsed formula; reading it from
  # the formula (not by splitting the string on spaces) also accepts "mpg~.".
  lhs <- if (length(reg_formula) == 3L) reg_formula[[2L]] else NULL
  if (!is.name(lhs) || !as.character(lhs) %in% colnames(data)) {
    stop("the left-hand side of `reg_form` must name a column of `data`",
      call. = FALSE
    )
  }
  out_var <- as.character(lhs)

  # pass `fold` explicitly instead of letting training()/validation() find it
  train_data <- training(data, fold)
  valid_data <- validation(data, fold)

  # fit on the training set, predict on the holdout set
  mod <- randomForest(formula = reg_formula, data = train_data)
  preds <- predict(mod, newdata = valid_data)

  # NOTE(review): a random forest has no regression coefficients; `mod$coefs`
  # is presumably NULL here (bias-correction coefficients only), giving a
  # zero-column data frame. Kept so the output mirrors cv_lm()'s structure.
  list(
    coef = data.frame(mod$coefs),
    SE = (preds - valid_data[[out_var]])^2
  )
}

## -----------------------------------------------------------------------------
library(randomForest)

# Reuse the `folds` already built for the linear model so both learners are
# scored on identical validation sets and their CV MSEs are directly comparable.
# randomForest() draws bootstrap samples, so fix the seed for reproducibility.
set.seed(4816)
cvrf_results <- cross_validate(
  cv_fun = cv_rf,
  folds = folds,
  data = mtcars,
  reg_form = "mpg ~ ."
)
mean(cvrf_results$SE)

## -----------------------------------------------------------------------------
# load the built-in AirPassengers series and print it
data("AirPassengers")
print(x = AirPassengers)

## -----------------------------------------------------------------------------
library(forecast)

# rolling-origin folds over the series
# (first_window = 36, validation_size = 24)
folds <- make_folds(
  AirPassengers,
  fold_fun = folds_rolling_origin,
  first_window = 36,
  validation_size = 24
)
# NOTE(review): `fold` is not used below; presumably kept as a handle for
# stepping through cv_forecasts() interactively
fold <- folds[[1]]

# Cross-validation function for the forecasting comparison. For one fold, it
# fits a seasonal ARIMA and an stlm() model to the log10 training series,
# forecasts the validation horizon, and scores each model by its mean squared
# error on the original (back-transformed) scale.
#
# Args:
#   fold: an origami fold (an element of the list returned by make_folds()).
#   data: the series being cross-validated.
#
# Returns a list holding `mse`, a one-row data frame with the fold index and
# the ARIMA and STL mean squared errors.
cv_forecasts <- function(fold, data) {
  # pass `fold` explicitly instead of letting training()/validation() find it
  train_data <- training(data, fold)
  valid_data <- validation(data, fold)
  horizon <- length(valid_data)

  # model on the log10 scale with a 12-period seasonal cycle
  train_ts <- ts(log10(train_data), frequency = 12)

  # seasonal ARIMA(0,1,1)(0,1,1)[12], borrowed from the AirPassengers help
  arima_fit <- arima(
    train_ts,
    order = c(0, 1, 1),
    seasonal = list(order = c(0, 1, 1), period = 12)
  )
  raw_arima_pred <- predict(arima_fit, n.ahead = horizon)
  arima_pred <- 10^raw_arima_pred$pred # back to the original scale
  arima_mse <- mean((arima_pred - valid_data)^2)

  # STL-based model
  stl_fit <- stlm(train_ts, s.window = 12)
  raw_stl_pred <- forecast(stl_fit, h = horizon)
  stl_pred <- 10^raw_stl_pred$mean # back to the original scale
  stl_mse <- mean((stl_pred - valid_data)^2)

  list(
    mse = data.frame(fold = fold_index(fold), arima = arima_mse, stl = stl_mse)
  )
}

# collect the per-fold MSEs from every rolling-origin fold, then average each
# model's MSE across folds
mses <- cross_validate(
  cv_fun = cv_forecasts,
  folds = folds,
  data = AirPassengers
)$mse
colMeans(mses[, c("arima", "stl")])

## ----sessionInfo, echo=FALSE--------------------------------------------------
# record the R version and attached packages used to produce these results
utils::sessionInfo()

