---
title: "Illustration with Nonlinear Effects"
package: multimedia
output: 
  rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Illustration with Nonlinear Effects}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
# Chunk defaults applied to every code block rendered in this vignette.
knitr::opts_chunk$set(
    collapse = TRUE,
    comment = "#>",
    warning = FALSE,
    message = FALSE,
    cache = FALSE,
    fig.align = "center",
    out.width = 600,
    dev.args = list(bg = "transparent"),
    crop = NULL
)

# Use multimedia's `ansi_aware_handler` as the knitr output hook.
knitr::knit_hooks$set(output = multimedia::ansi_aware_handler)
suppressPackageStartupMessages(library(ggplot2))

# A single palette shared by the discrete colour and fill scales.
vignette_palette <- c(
    "#9491D9", "#F24405", "#3F8C61", "#8C2E62", "#F2B705", "#11A0D9"
)
options(
    ggplot2.discrete.colour = vignette_palette,
    ggplot2.discrete.fill = vignette_palette,
    ggplot2.continuous.colour = function(...) {
        scale_color_distiller(palette = "Spectral", ...)
    },
    ggplot2.continuous.fill = function(...) {
        scale_fill_distiller(palette = "Spectral", ...)
    },
    crayon.enabled = TRUE
)

# Default theme: classic layout, transparent backgrounds, no grid lines, and
# the legend placed below the panel.
clear_rect <- element_rect(fill = "transparent")
th <- theme_classic() +
    theme(
        panel.background = clear_rect,
        strip.background = clear_rect,
        plot.background = element_rect(fill = "transparent", color = NA),
        legend.background = clear_rect,
        legend.box.background = clear_rect,
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        legend.position = "bottom"
    )
theme_set(th)
```

This vignette gives a simplified example similar to the "Quick Start with Random
Data," except in this simulated dataset, there are true direct and indirect
effects. The data are low dimensional, which also makes it easy to visualize our
model's behavior and the effect of various alterations. Before discussing data
and models, let's load libraries and define helper functions that will be used
for reorganizing the data later in this vignette. 

```{r setup}
library(glue)
library(tidyverse)
library(multimedia)
set.seed(20240111)

#' Scatterplot of samples drawn from a fitted model
#'
#' Reshapes the samples with `pivot_samples()` and plots `value` against
#' `mediator`, coloured by treatment, with one facet column per outcome.
#'
#' @param fit Fitted model passed through to `pivot_samples()`.
#' @param profile Treatment profile passed through to `pivot_samples()`.
#' @return A ggplot object.
plot_exper <- function(fit, profile) {
    samples_long <- pivot_samples(fit, profile)
    ggplot(samples_long) +
        geom_point(aes(mediator, value, col = factor(treatment))) +
        facet_grid(. ~ outcome) +
        theme(legend.position = "bottom")
}

#' Sample from a fitted model and reshape the result to long form
#'
#' Calls `sample()` on `fit` under `profile`, binds the sampled outcomes, the
#' sampled mediators, and the treatment column side by side, then stacks every
#' column whose name starts with "outcome" into name/value pairs.
#'
#' NOTE(review): the treatment column comes from the global `t_outcome`, not
#' from `profile`, so `t_outcome` must exist (and match `profile`) before this
#' function is called.
#'
#' @param fit Fitted model to sample from.
#' @param profile Treatment profile handed to `sample()`.
#' @return A long-form data frame with an `outcome` name column and a `value`
#'   column.
pivot_samples <- function(fit, profile) {
    sampled <- sample(fit, profile = profile)
    wide <- bind_cols(outcomes(sampled), mediators(sampled), t_outcome)
    pivot_longer(wide, starts_with("outcome"), names_to = "outcome")
}
```

We will consider a mediation analysis with two outcomes and a single mediator.
The toy dataset below is included in the package -- you can vary the number of
samples and true direct effect size by passing in arguments.

The block below calls `demo_spline()` to generate the treatment, mediator, and
outcomes. We use random spline functions for the direct effect.

```{r prepare_data}
# Build the demo dataset with `demo_spline()`'s default arguments.
xy_data <- demo_spline()

# Start the named list of long-form datasets with the original data, stacking
# every column that starts with "outcome" into name/value pairs.
xy_long <- list()
xy_long[["true"]] <- pivot_longer(
    xy_data,
    starts_with("outcome"),
    names_to = "outcome"
)
```

The block below estimates a random forest model on these randomly generated data
(after first converting them into an object of class `mediation_data`). We can
vary `ranger` parameters using the arguments to `rf_model()`.

```{r fit_model}
# Wrap the data frame in a `mediation_data` object, passing the columns that
# start with "outcome" together with the "treatment" and "mediator" columns.
exper <- mediation_data(
    xy_data,
    starts_with("outcome"),
    "treatment",
    "mediator"
)

# Specify the model with a random forest outcome estimator (100 trees), then
# estimate it on the same data.
rf_spec <- multimedia(exper, outcome_estimator = rf_model(num.trees = 1e2))
fit <- estimate(rf_spec, exper)
```

Let's look at the estimated effects.

```{r effect_estimates}
# Summarize the estimated direct effects.
effect_summary(direct_effect(fit))

# Summarize the estimated overall indirect effects.
effect_summary(indirect_overall(fit))
```

Mediation analysis depends on untestable assumptions on the absence of
confounders that influence both the mediators and outcome. For this reason, we
can run a sensitivity analysis: How strong would the confounder have to be
before we would observe a flipped indirect effect estimate?

```{r sensitivity_analysis}
# Mediator/outcome index pairs (mediator 1 with outcomes 1 and 2) for the
# sensitivity call below.
confound_ix <- expand.grid(mediator = 1, outcome = 1:2)
# The direct call is left commented out and a saved table is read from a URL
# instead -- presumably the stored output of that call; verify before reuse.
# sensitivity_curve <- sensitivity(fit, exper, confound_ix)
sensitivity_curve <- read_csv("https://go.wisc.edu/0xcyr1")
plot_sensitivity(sensitivity_curve)
```

```{r}
# 3 x 3 perturbation matrix, filled row by row; the only nonzero entries are
# the value 3 at positions (1, 2) and (2, 1).
perturb <- matrix(
    c(
        0, 3, 0,
        3, 0, 0,
        0, 0, 0
    ),
    nrow = 3, byrow = TRUE
)
# As above, the direct call is commented out and a saved table is read from a
# URL instead -- presumably the stored output of that call; verify before reuse.
# sensitivity_curve <- sensitivity_perturb(fit, exper, perturb)
sensitivity_curve <- read_csv("https://go.wisc.edu/75mz1b")
# Plot the curve against the `nu` column rather than the default x variable.
plot_sensitivity(sensitivity_curve, x_var = "nu")
```

If we want to look at samples at new treatment assignments, we have to first
construct a new treatment profile. We can accomplish this using the
`setup_profile()` function.

```{r}
# Treatment 0 for the first half of the samples and treatment 1 for the second
# half; the same assignment is used for the mediator and the outcome models.
half_n <- nrow(exper) / 2
t_mediator <- tibble(treatment = factor(rep(0:1, each = half_n)))
t_outcome <- t_mediator

# Build the profile and add samples from the full fitted model to the list of
# long-form datasets.
profile <- setup_profile(fit, t_mediator, t_outcome)
xy_long[["fitted"]] <- pivot_samples(fit, profile)
```

Let's estimate three altered models. The first two nullify the effects in the
mediation and outcome models, respectively. The third assumes a linear, rather
than random forest, outcome model.

```{r alteration}
# Re-estimate with the "T->M" effect nullified.
altered_m <- estimate(nullify(fit, "T->M"), exper)

# Re-estimate with the "T->Y" effect nullified.
altered_ty <- estimate(nullify(fit, "T->Y"), exper)

# Same data, but with a linear outcome model in place of the random forest.
lm_spec <- multimedia(exper, outcome_estimator = lm_model())
fit_lm <- estimate(lm_spec, exper)
```

We can sample from these models and organize the results into one long
data.frame. The function `pivot_samples` is defined at the top of the script.

```{r reshape_altered}
# Lookup from the internal list names to the labels shown in the figure. The
# order of the entries also fixes the order of the factor levels.
fit_labels <- c(
    true = "Original Data",
    fitted = "RF (Full)",
    altered_m = "RF (T-\\->M)",
    altered_ty = "RF (T-\\->Y)",
    linear = "Linear Model"
)
pretty_labels <- unname(fit_labels)

# Sample from each altered model, appending to the list in a fixed order.
xy_long[["altered_m"]] <- pivot_samples(altered_m, profile)
xy_long[["altered_ty"]] <- pivot_samples(altered_ty, profile)
xy_long[["linear"]] <- pivot_samples(fit_lm, profile)

# Stack everything into one data frame, keeping the list name as `fit_type`,
# then swap the internal names for display labels.
xy_long <- bind_rows(xy_long, .id = "fit_type") |>
    mutate(
        fit_type = factor(unname(fit_labels[fit_type]), levels = pretty_labels),
        outcome = case_when(
            outcome == "outcome_1" ~ "Y[1]",
            outcome == "outcome_2" ~ "Y[2]"
        )
    )
```

The figure below compares the original data and model with the various nullified
versions, using the combined `xy_long` dataset constructed above.

```{r visualize_long, fig.width = 8, fig.height = 3}
# Plot a random 10% of the rows: one facet row per outcome and one facet
# column per data source / model.
xy_subset <- sample_frac(xy_long, size = 0.1)

ggplot(xy_subset, aes(x = mediator, colour = treatment)) +
    geom_point(aes(y = value), size = 0.4, alpha = 0.9) +
    # The rug layer only has an x aesthetic, so it marks mediator values.
    geom_rug(alpha = 0.99, linewidth = 0.1) +
    facet_grid(outcome ~ fit_type) +
    labs(x = expression("Mediator M"), y = "Outcome", colour = "Treatment") +
    theme(
        strip.text = element_text(size = 11),
        axis.title = element_text(size = 14),
        legend.title = element_text(size = 12)
    )
```

We can calculate bootstrap confidence intervals for the direct and overall
indirect effects using the block below.

```{r bootstraps}
# Effect estimators passed to the (commented-out) bootstrap call below.
fs <- list(direct_effect = direct_effect, indirect_overall = indirect_overall)
# The direct call is left commented out and saved results are read from a URL
# instead -- presumably the stored output of that call; verify before reuse.
# bootstraps <- bootstrap(fit, exper, fs = fs)
# `readRDS()` does not destroy a connection it is handed, so keep a handle on
# it and close it explicitly, even if reading fails.
bootstrap_con <- url("https://go.wisc.edu/977l04")
bootstraps <- tryCatch(
    readRDS(bootstrap_con),
    finally = close(bootstrap_con)
)

# Bootstrap distribution of the direct effect, one panel per outcome.
ggplot(bootstraps$direct_effect) +
    geom_histogram(aes(direct_effect)) +
    facet_wrap(~outcome)

# Bootstrap distribution of the overall indirect effect, one panel per outcome.
ggplot(bootstraps$indirect_overall) +
    geom_histogram(aes(indirect_effect)) +
    facet_wrap(~outcome)
```

```{r}
# Print the R version and attached packages used to build this vignette.
sessionInfo()
```
