---
title: "Tidymodels Workflow with Functional Keras Models (Multi-Input)"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Tidymodels Workflow with Functional Keras Models (Multi-Input)}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---



## Introduction

This vignette demonstrates a complete `tidymodels` workflow for a regression task using a Keras functional model defined with `kerasnip`. We will use the Ames Housing dataset to predict house prices. A key feature of this example is the use of a multi-input Keras model, where numerical and categorical features are processed through separate input branches.

`kerasnip` lets you define complex Keras architectures, including those with multiple inputs, and use them like any other `parsnip` model: they can be combined with a recipe in a workflow and tuned with `tune` or `finetune`.

## Setup

First, we load the necessary packages.


``` r
library(kerasnip)
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────────────────────────────────────────────── tidymodels 1.5.0 ──
#> ✔ broom        1.0.12     ✔ recipes      1.3.2 
#> ✔ dials        1.4.3      ✔ rsample      1.3.2 
#> ✔ dplyr        1.2.1      ✔ tailor       0.1.0 
#> ✔ ggplot2      4.0.3      ✔ tidyr        1.3.2 
#> ✔ infer        1.1.0      ✔ tune         2.1.0 
#> ✔ modeldata    1.5.1      ✔ workflows    1.3.0 
#> ✔ parsnip      1.5.0      ✔ workflowsets 1.1.1 
#> ✔ purrr        1.2.2      ✔ yardstick    1.4.0
#> ── Conflicts ───────────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard() masks scales::discard()
#> ✖ dplyr::filter()  masks stats::filter()
#> ✖ dplyr::lag()     masks stats::lag()
#> ✖ recipes::step()  masks stats::step()
library(keras3)
#> 
#> Attaching package: 'keras3'
#> The following object is masked from 'package:yardstick':
#> 
#>     get_weights
#> The following object is masked from 'package:infer':
#> 
#>     generate
library(dplyr) # Data manipulation (also attached by tidymodels)
library(ggplot2) # Plotting (also attached by tidymodels)
library(future) # Parallel backends (not used below; see the note in "Tune Model")
#> 
#> Attaching package: 'future'
#> The following object is masked from 'package:keras3':
#> 
#>     %<-%
library(finetune) # For racing
```

## Data Preparation

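We'll use the Ames Housing dataset from the `modeldata` package, keeping the sale price together with a handful of size, age, location, building-type, and condition predictors and dropping rows with missing values. We then split the data into training and testing sets, stratified by `Sale_Price`, and create five cross-validation folds from the training set for tuning.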


``` r
# Select relevant columns and remove rows with missing values
ames_df <- ames |>
  select(
    Sale_Price,
    Gr_Liv_Area,
    Year_Built,
    Neighborhood,
    Bldg_Type,
    Overall_Cond,
    Total_Bsmt_SF,
    contains("SF")
  ) |>
  na.omit()

# Split data into training and testing sets
set.seed(123)
ames_split <- initial_split(ames_df, prop = 0.8, strata = Sale_Price)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)

# Create cross-validation folds for tuning
ames_folds <- vfold_cv(ames_train, v = 5, strata = Sale_Price)
```
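Because `Sale_Price` is numeric, `strata = Sale_Price` asks rsample to bin the outcome (into quartiles by default) and sample within each bin, so the training and testing sets cover the same price range. The base-R sketch below illustrates that idea on simulated, log-normally distributed prices. It is a simplified stand-in under these assumptions, not rsample's implementation (which, for example, also pools very small strata).

``` r
# Simplified illustration of stratified sampling on a numeric outcome.
# The prices are simulated toy data, not the Ames data.
set.seed(123)
price <- round(rlnorm(1000, meanlog = log(160000), sdlog = 0.4))

# Bin the outcome into quartiles, as rsample does by default for numeric strata
bins <- cut(
  price,
  breaks = quantile(price, probs = seq(0, 1, by = 0.25)),
  include.lowest = TRUE
)

# Draw 80% of the rows within each bin
train_idx <- unlist(lapply(split(seq_along(price), bins), function(idx) {
  idx[sample.int(length(idx), size = floor(0.8 * length(idx)))]
}), use.names = FALSE)
test_idx <- setdiff(seq_along(price), train_idx)

# Each quartile is represented in the same proportion in both sets
round(prop.table(table(bins[train_idx])), 3)
round(prop.table(table(bins[test_idx])), 3)
```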

## Recipe for Preprocessing

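We will create a `recipes` object to preprocess our data. This recipe will:

*   Predict `Sale_Price` using all other variables.
*   Normalize all numerical predictors.
*   Create dummy variables for the categorical predictors (`Neighborhood`, `Bldg_Type`, and `Overall_Cond`).
*   Collapse each group of predictors into a single matrix column using `step_collapse()`.

The collapse steps are crucial for the multi-input Keras model: the `kerasnip` functional API expects a list of matrices for multiple inputs, where each matrix corresponds to a distinct input layer. Each `new_col` name matches the name of one of the input blocks defined in the next section. The order of the steps also matters: the numerical predictors are collapsed before any dummy variables are created, so `all_numeric_predictors()` does not pick up the dummy columns.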


``` r
ames_recipe <- recipe(Sale_Price ~ ., data = ames_train) |>
  step_normalize(all_numeric_predictors()) |>
  step_collapse(all_numeric_predictors(), new_col = "numerical_input") |>
  step_dummy(Neighborhood) |>
  step_collapse(starts_with("Neighborhood"), new_col = "neighborhood_input") |>
  step_dummy(Bldg_Type) |>
  step_collapse(starts_with("Bldg_Type"), new_col = "bldg_input") |>
  step_dummy(Overall_Cond) |>
  step_collapse(starts_with("Overall_Cond"), new_col = "condition_input")
```
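To make the resulting data layout concrete, the base-R sketch below mimics the recipe on a tiny made-up data frame. It standardizes the numeric columns, dummy-encodes a factor with one column per non-reference level (the default behaviour of `step_dummy()`), and stores each group as a single matrix column. The values and level names are hypothetical, and the sketch only imitates the shape of the output; it is not how `step_collapse()` is implemented.

``` r
# Toy data: two numeric predictors and one factor (hypothetical values)
toy <- data.frame(
  Gr_Liv_Area = c(1500, 2100, 900, 1750),
  Year_Built  = c(1995, 2005, 1960, 1980),
  Bldg_Type   = factor(c("OneFam", "Duplex", "OneFam", "Twnhs"))
)

# 1. Normalize the numeric predictors (center and scale, like step_normalize())
num_mat <- scale(as.matrix(toy[, c("Gr_Liv_Area", "Year_Built")]))

# 2. Dummy-encode the factor, dropping the reference level (like step_dummy())
bldg_mat <- model.matrix(~ Bldg_Type, data = toy)[, -1, drop = FALSE]

# 3. Keep each group as ONE matrix column, so a single data frame column
#    holds everything a given Keras input layer needs
baked <- data.frame(
  numerical_input = I(num_mat),
  bldg_input      = I(bldg_mat)
)

ncol(baked)                 # two columns, one per input layer
dim(baked$numerical_input)  # one row per house, one column per numeric predictor
dim(baked$bldg_input)       # one column per non-reference building type
```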

## Define Keras Functional Model with `kerasnip`

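Now, we define our Keras functional model using `kerasnip`'s layer blocks. This model will have four distinct input layers: one for numerical features and three for categorical features. Each branch passes through its own dense layer, and the four branches are then concatenated before the final output layer. `inp_spec()` connects each block to the block(s) that feed it; for the concatenation block, the named vector pairs each argument of `concatenate_features()` with the block that supplies it.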


``` r
# Define layer blocks for multi-input functional model

# Input blocks for numerical and categorical features
input_numerical <- function(input_shape) {
  layer_input(shape = input_shape, name = "numerical_input")
}

input_neighborhood <- function(input_shape) {
  layer_input(shape = input_shape, name = "neighborhood_input")
}

input_bldg <- function(input_shape) {
  layer_input(shape = input_shape, name = "bldg_input")
}

input_condition <- function(input_shape) {
  layer_input(shape = input_shape, name = "condition_input")
}

# Processing blocks for each input type
dense_numerical <- function(tensor, units = 32, activation = "relu") {
  tensor |>
    layer_dense(units = units, activation = activation)
}

dense_categorical <- function(tensor, units = 16, activation = "relu") {
  tensor |>
    layer_dense(units = units, activation = activation)
}

# Concatenation block
concatenate_features <- function(numeric, neighborhood, bldg, condition) {
  layer_concatenate(list(numeric, neighborhood, bldg, condition))
}

# Output block for regression
output_regression <- function(tensor) {
  layer_dense(tensor, units = 1, name = "output")
}

# Create the kerasnip model specification function
create_keras_functional_spec(
  model_name = "ames_functional_mlp",
  layer_blocks = list(
    numerical_input = input_numerical,
    neighborhood_input = input_neighborhood,
    bldg_input = input_bldg,
    condition_input = input_condition,
    processed_numerical = inp_spec(dense_numerical, "numerical_input"),
    processed_neighborhood = inp_spec(dense_categorical, "neighborhood_input"),
    processed_bldg = inp_spec(dense_categorical, "bldg_input"),
    processed_condition = inp_spec(dense_categorical, "condition_input"),
    combined_features = inp_spec(
      concatenate_features,
      c(
        numeric = "processed_numerical",
        neighborhood = "processed_neighborhood",
        bldg = "processed_bldg",
        condition = "processed_condition"
      )
    ),
    output = inp_spec(output_regression, "combined_features")
  ),
  mode = "regression"
)
```
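The `layer_blocks` list defines a small directed acyclic graph: the four input blocks have no upstream dependencies, each processing block reads from one input, `combined_features` reads from all four processing blocks, and `output` reads from `combined_features`. The base-R sketch below records that wiring by hand in a hypothetical `deps` list (not a `kerasnip` object) and checks that every block depends only on blocks listed before it, a simple way to see that the graph has no cycles and can be built from the inputs forward.

``` r
# Hand-written copy of the wiring in `layer_blocks` (block -> upstream blocks)
deps <- list(
  numerical_input        = character(0),
  neighborhood_input     = character(0),
  bldg_input             = character(0),
  condition_input        = character(0),
  processed_numerical    = "numerical_input",
  processed_neighborhood = "neighborhood_input",
  processed_bldg         = "bldg_input",
  processed_condition    = "condition_input",
  combined_features      = c("processed_numerical", "processed_neighborhood",
                             "processed_bldg", "processed_condition"),
  output                 = "combined_features"
)

# Blocks without upstream dependencies are the model's input layers
input_blocks <- names(deps)[lengths(deps) == 0]

# Check that each block depends only on blocks defined earlier in the list
block_order <- names(deps)
in_order <- vapply(seq_along(deps), function(i) {
  all(deps[[i]] %in% block_order[seq_len(i - 1)])
}, logical(1))

input_blocks
all(in_order)
```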

## Model Specification

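`create_keras_functional_spec()` created a new model specification function, `ames_functional_mlp()`, which we now call to set some hyperparameters to `tune()`. Arguments for a block's hyperparameters are prefixed with the block name (e.g., `processed_numerical_units` sets `units` in the `processed_numerical` block). In the printed specification, arguments shown as `structure(list(), class = "rlang_zap")` have not been set explicitly.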


``` r
# Define the tunable model specification
functional_mlp_spec <- ames_functional_mlp(
  # Tunable parameters for numerical branch
  processed_numerical_units = tune(),
  # Tunable parameters for categorical branch
  processed_neighborhood_units = tune(),
  processed_bldg_units = tune(),
  processed_condition_units = tune(),
  # Fixed compilation and fitting parameters
  compile_loss = "mean_squared_error",
  compile_optimizer = "adam",
  compile_metrics = c("mean_absolute_error"),
  fit_epochs = 50,
  fit_batch_size = 32,
  fit_validation_split = 0.2,
  fit_callbacks = list(
    callback_early_stopping(monitor = "val_loss", patience = 5)
  )
) |>
  set_engine("keras")

print(functional_mlp_spec)
#> ames functional mlp Model Specification (regression)
#> 
#> Main Arguments:
#>   num_numerical_input = structure(list(), class = "rlang_zap")
#>   num_neighborhood_input = structure(list(), class = "rlang_zap")
#>   num_bldg_input = structure(list(), class = "rlang_zap")
#>   num_condition_input = structure(list(), class = "rlang_zap")
#>   num_processed_numerical = structure(list(), class = "rlang_zap")
#>   num_processed_neighborhood = structure(list(), class = "rlang_zap")
#>   num_processed_bldg = structure(list(), class = "rlang_zap")
#>   num_processed_condition = structure(list(), class = "rlang_zap")
#>   num_combined_features = structure(list(), class = "rlang_zap")
#>   num_output = structure(list(), class = "rlang_zap")
#>   processed_numerical_units = tune()
#>   processed_numerical_activation = structure(list(), class = "rlang_zap")
#>   processed_neighborhood_units = tune()
#>   processed_neighborhood_activation = structure(list(), class = "rlang_zap")
#>   processed_bldg_units = tune()
#>   processed_bldg_activation = structure(list(), class = "rlang_zap")
#>   processed_condition_units = tune()
#>   processed_condition_activation = structure(list(), class = "rlang_zap")
#>   learn_rate = structure(list(), class = "rlang_zap")
#>   fit_batch_size = 32
#>   fit_epochs = 50
#>   fit_callbacks = list(callback_early_stopping(monitor = "val_loss", patience = 5))
#>   fit_validation_split = 0.2
#>   fit_validation_data = structure(list(), class = "rlang_zap")
#>   fit_shuffle = structure(list(), class = "rlang_zap")
#>   fit_class_weight = structure(list(), class = "rlang_zap")
#>   fit_sample_weight = structure(list(), class = "rlang_zap")
#>   fit_initial_epoch = structure(list(), class = "rlang_zap")
#>   fit_steps_per_epoch = structure(list(), class = "rlang_zap")
#>   fit_validation_steps = structure(list(), class = "rlang_zap")
#>   fit_validation_batch_size = structure(list(), class = "rlang_zap")
#>   fit_validation_freq = structure(list(), class = "rlang_zap")
#>   fit_verbose = structure(list(), class = "rlang_zap")
#>   fit_view_metrics = structure(list(), class = "rlang_zap")
#>   compile_optimizer = adam
#>   compile_loss = mean_squared_error
#>   compile_metrics = c("mean_absolute_error")
#>   compile_loss_weights = structure(list(), class = "rlang_zap")
#>   compile_weighted_metrics = structure(list(), class = "rlang_zap")
#>   compile_run_eagerly = structure(list(), class = "rlang_zap")
#>   compile_steps_per_execution = structure(list(), class = "rlang_zap")
#>   compile_jit_compile = structure(list(), class = "rlang_zap")
#>   compile_auto_scale_loss = structure(list(), class = "rlang_zap")
#> 
#> Computational engine: keras
```
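The printed specification follows a simple naming pattern: for each processing block, every argument after the incoming tensor becomes a model argument named `<block>_<argument>`, and each block also gets a `num_<block>` argument. The base-R sketch below reproduces the names visible in the output above for the four dense blocks. It illustrates the pattern only and is not `kerasnip`'s own code; the block functions are redefined with the same signatures but empty bodies, since only their arguments matter here.

``` r
# Same signatures as the block functions above (bodies omitted on purpose)
dense_numerical <- function(tensor, units = 32, activation = "relu") NULL
dense_categorical <- function(tensor, units = 16, activation = "relu") NULL

processing_blocks <- list(
  processed_numerical    = dense_numerical,
  processed_neighborhood = dense_categorical,
  processed_bldg         = dense_categorical,
  processed_condition    = dense_categorical
)

# Drop the first argument (the incoming tensor) and prefix the rest
# with the block name
arg_names <- unlist(lapply(names(processing_blocks), function(block) {
  hyper <- names(formals(processing_blocks[[block]]))[-1]
  paste(block, hyper, sep = "_")
}))

arg_names
paste0("num_", names(processing_blocks))
```

These generated names are the ones to pass to `tune()` in the specification and to `update()` when adjusting parameter ranges.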

## Create Workflow

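A `workflow` combines the recipe and the model specification, so the preprocessing steps are re-estimated on the analysis set of every resample during tuning and applied consistently at prediction time.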


``` r
ames_wf <- workflow() |>
  add_recipe(ames_recipe) |>
  add_model(functional_mlp_spec)

print(ames_wf)
#> ══ Workflow ════════════════════════════════════════════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: ames_functional_mlp()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────────────────────────────────────────────
#> 8 Recipe Steps
#> 
#> • step_normalize()
#> • step_collapse()
#> • step_dummy()
#> • step_collapse()
#> • step_dummy()
#> • step_collapse()
#> • step_dummy()
#> • step_collapse()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────────────────────────────────────────────
#> ames functional mlp Model Specification (regression)
#> 
#> Main Arguments:
#>   num_numerical_input = structure(list(), class = "rlang_zap")
#>   num_neighborhood_input = structure(list(), class = "rlang_zap")
#>   num_bldg_input = structure(list(), class = "rlang_zap")
#>   num_condition_input = structure(list(), class = "rlang_zap")
#>   num_processed_numerical = structure(list(), class = "rlang_zap")
#>   num_processed_neighborhood = structure(list(), class = "rlang_zap")
#>   num_processed_bldg = structure(list(), class = "rlang_zap")
#>   num_processed_condition = structure(list(), class = "rlang_zap")
#>   num_combined_features = structure(list(), class = "rlang_zap")
#>   num_output = structure(list(), class = "rlang_zap")
#>   processed_numerical_units = tune()
#>   processed_numerical_activation = structure(list(), class = "rlang_zap")
#>   processed_neighborhood_units = tune()
#>   processed_neighborhood_activation = structure(list(), class = "rlang_zap")
#>   processed_bldg_units = tune()
#>   processed_bldg_activation = structure(list(), class = "rlang_zap")
#>   processed_condition_units = tune()
#>   processed_condition_activation = structure(list(), class = "rlang_zap")
#>   learn_rate = structure(list(), class = "rlang_zap")
#>   fit_batch_size = 32
#>   fit_epochs = 50
#>   fit_callbacks = list(callback_early_stopping(monitor = "val_loss", patience = 5))
#>   fit_validation_split = 0.2
#>   fit_validation_data = structure(list(), class = "rlang_zap")
#>   fit_shuffle = structure(list(), class = "rlang_zap")
#>   fit_class_weight = structure(list(), class = "rlang_zap")
#>   fit_sample_weight = structure(list(), class = "rlang_zap")
#>   fit_initial_epoch = structure(list(), class = "rlang_zap")
#>   fit_steps_per_epoch = structure(list(), class = "rlang_zap")
#>   fit_validation_steps = structure(list(), class = "rlang_zap")
#>   fit_validation_batch_size = structure(list(), class = "rlang_zap")
#>   fit_validation_freq = structure(list(), class = "rlang_zap")
#>   fit_verbose = structure(list(), class = "rlang_zap")
#>   fit_view_metrics = structure(list(), class = "rlang_zap")
#>   compile_optimizer = adam
#>   compile_loss = mean_squared_error
#>   compile_metrics = c("mean_absolute_error")
#>   compile_loss_weights = structure(list(), class = "rlang_zap")
#>   compile_weighted_metrics = structure(list(), class = "rlang_zap")
#>   compile_run_eagerly = structure(list(), class = "rlang_zap")
#>   compile_steps_per_execution = structure(list(), class = "rlang_zap")
#>   compile_jit_compile = structure(list(), class = "rlang_zap")
#>   compile_auto_scale_loss = structure(list(), class = "rlang_zap")
#> 
#> Computational engine: keras
```

## Define Tuning Grid

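We create a regular grid over the four `units` hyperparameters, using `hidden_units()` to set a range of 32 to 128 units for the numerical branch and 16 to 64 units for each categorical branch.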


``` r
# Define the tuning grid
params <- extract_parameter_set_dials(ames_wf) |>
  update(
    processed_numerical_units = hidden_units(range = c(32, 128)),
    processed_neighborhood_units = hidden_units(range = c(16, 64)),
    processed_bldg_units = hidden_units(range = c(16, 64)),
    processed_condition_units = hidden_units(range = c(16, 64))
  )
functional_mlp_grid <- grid_regular(params, levels = 3)

print(functional_mlp_grid)
#> # A tibble: 81 × 4
#>    processed_numerical_units processed_neighborhood_units processed_bldg_units processed_condition_units
#>                        <int>                        <int>                <int>                     <int>
#>  1                        32                           16                   16                        16
#>  2                        80                           16                   16                        16
#>  3                       128                           16                   16                        16
#>  4                        32                           40                   16                        16
#>  5                        80                           40                   16                        16
#>  6                       128                           40                   16                        16
#>  7                        32                           64                   16                        16
#>  8                        80                           64                   16                        16
#>  9                       128                           64                   16                        16
#> 10                        32                           16                   40                        16
#> # ℹ 71 more rows
```
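`grid_regular()` with `levels = 3` picks three evenly spaced values for each parameter (32, 80, and 128 for the numerical branch; 16, 40, and 64 for each categorical branch) and crosses them, giving 3^4 = 81 candidates. The base-R sketch below rebuilds the same grid with `expand.grid()`, which, like the printed grid, varies the first column fastest. It is a quick sanity check, not how `dials` constructs grids internally.

``` r
n_levels <- 3
numerical_units   <- seq(32, 128, length.out = n_levels)  # 32, 80, 128
categorical_units <- seq(16, 64, length.out = n_levels)   # 16, 40, 64

grid <- expand.grid(
  processed_numerical_units    = numerical_units,
  processed_neighborhood_units = categorical_units,
  processed_bldg_units         = categorical_units,
  processed_condition_units    = categorical_units
)

nrow(grid)  # 3^4 = 81
head(grid, 4)
```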

## Tune Model

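Now, we use `tune_race_anova()` from the `finetune` package to evaluate the grid with 5-fold cross-validation, eliminating poorly performing candidates early instead of fitting every candidate on every fold.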


``` r
# Note: Parallel processing with `plan(multisession)` is currently not working
# with Keras models due to backend conflicts

set.seed(123)
ames_tune_results <- tune_race_anova(
  ames_wf,
  resamples = ames_folds,
  grid = functional_mlp_grid,
  metrics = metric_set(rmse, mae, rsq),
  control = control_race(save_pred = TRUE, save_workflow = TRUE)
)
```
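After a burn-in period (three resamples by default in `control_race()`), `tune_race_anova()` fits a repeated-measures ANOVA model to the fold-level metric and drops candidates whose results are significantly worse than the current best, so only promising candidates go on to the remaining folds. The base-R sketch below mimics one such elimination step using simulated RMSE values for four hypothetical candidates and a plain `lm()` fit with a fold effect. It is a simplified illustration under these assumptions, not `finetune`'s implementation.

``` r
# Simulated fold-level RMSE for four hypothetical candidates after a
# three-fold burn-in (toy numbers, not results from this vignette)
set.seed(42)
cand_means <- c(cand_1 = 50000, cand_2 = 51000, cand_3 = 70000, cand_4 = 90000)
fold_shift <- c(Fold1 = 0, Fold2 = 2000, Fold3 = -1500)  # some folds are harder
res <- expand.grid(
  config = names(cand_means),
  fold = names(fold_shift),
  stringsAsFactors = FALSE
)
res$rmse <- cand_means[res$config] + fold_shift[res$fold] +
  rnorm(nrow(res), sd = 1000)

# Use the current best candidate (lowest mean RMSE) as the reference level
best <- names(which.min(tapply(res$rmse, res$config, mean)))
res$config <- relevel(factor(res$config), ref = best)

# Candidate effect plus a fold (block) effect, as in a repeated-measures design
fit <- lm(rmse ~ config + fold, data = res)

# One-sided interval for how much worse than the best each candidate is;
# a candidate is dropped when the whole interval lies above zero
alpha <- 0.05
coefs <- summary(fit)$coefficients
diffs <- coefs[grepl("^config", rownames(coefs)), , drop = FALSE]
lower <- diffs[, "Estimate"] -
  qt(1 - alpha, df = fit$df.residual) * diffs[, "Std. Error"]
eliminated <- sub("^config", "", rownames(diffs)[lower > 0])
survivors <- setdiff(levels(res$config), eliminated)

survivors
```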

## Inspect Tuning Results

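We can inspect the tuning results to see which hyperparameter combinations performed best. Although we ask for `n = 5`, only two rows are returned: the results of a race report only the candidates that survived to be evaluated on every resample.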


``` r
# Show the best performing models based on RMSE
show_best(ames_tune_results, metric = "rmse", n = 5)
#> # A tibble: 2 × 10
#>   processed_numerical_units processed_neighborho…¹ processed_bldg_units processed_condition_…² .metric .estimator   mean
#>                       <int>                  <int>                <int>                  <int> <chr>   <chr>       <dbl>
#> 1                       128                     64                   64                     64 rmse    standard   53524.
#> 2                       128                     64                   40                     64 rmse    standard   54215.
#> # ℹ abbreviated names: ¹​processed_neighborhood_units, ²​processed_condition_units
#> # ℹ 3 more variables: n <int>, std_err <dbl>, .config <chr>

# Autoplot the results
# Currently does not work due to a label issue: autoplot(ames_tune_results)

# Select the best hyperparameters
best_functional_mlp_params <- select_best(ames_tune_results, metric = "rmse")
print(best_functional_mlp_params)
#> # A tibble: 1 × 5
#>   processed_numerical_units processed_neighborhood_units processed_bldg_units processed_condition_units .config         
#>                       <int>                        <int>                <int>                     <int> <chr>           
#> 1                       128                           64                   64                        64 pre0_mod81_post0
```

## Finalize Workflow and Fit Model

Once we have the best hyperparameters, we finalize the workflow and fit the model on the entire training dataset.


``` r
# Finalize the workflow with the best hyperparameters
final_ames_wf <- finalize_workflow(ames_wf, best_functional_mlp_params)

# Fit the final model on the full training data
final_ames_fit <- fit(final_ames_wf, data = ames_train)

print(final_ames_fit)
#> ══ Workflow [trained] ══════════════════════════════════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: ames_functional_mlp()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────────────────────────────────────────────
#> 8 Recipe Steps
#> 
#> • step_normalize()
#> • step_collapse()
#> • step_dummy()
#> • step_collapse()
#> • step_dummy()
#> • step_collapse()
#> • step_dummy()
#> • step_collapse()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────────────────────────────────────────────
#> $fit
#> Model: "functional_262"
#> ┌───────────────────────────────────┬──────────────────────────────┬───────────────────┬───────────────────────────────
#> │ Layer (type)                      │ Output Shape                 │           Param # │ Connected to                  
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ numerical_input (InputLayer)      │ (None, 1, 10)                │                 0 │ -                             
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ neighborhood_input (InputLayer)   │ (None, 1, 28)                │                 0 │ -                             
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ bldg_input (InputLayer)           │ (None, 1, 4)                 │                 0 │ -                             
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ condition_input (InputLayer)      │ (None, 1, 9)                 │                 0 │ -                             
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ dense_1033 (Dense)                │ (None, 1, 128)               │             1,408 │ numerical_input[0][0]         
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ dense_1034 (Dense)                │ (None, 1, 64)                │             1,856 │ neighborhood_input[0][0]      
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ dense_1035 (Dense)                │ (None, 1, 64)                │               320 │ bldg_input[0][0]              
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ dense_1036 (Dense)                │ (None, 1, 64)                │               640 │ condition_input[0][0]         
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ concatenate_258 (Concatenate)     │ (None, 1, 320)               │                 0 │ dense_1033[0][0],             
#> │                                   │                              │                   │ dense_1034[0][0],             
#> │                                   │                              │                   │ dense_1035[0][0],             
#> │                                   │                              │                   │ dense_1036[0][0]              
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ output (Dense)                    │ (None, 1, 1)                 │               321 │ concatenate_258[0][0]         
#> └───────────────────────────────────┴──────────────────────────────┴───────────────────┴───────────────────────────────
#>  Total params: 13,637 (53.27 KB)
#>  Trainable params: 4,545 (17.75 KB)
#>  Non-trainable params: 0 (0.00 B)
#>  Optimizer params: 9,092 (35.52 KB)
#> 
#> $keras_bytes
#>     [1] 50 4b 03 04 14 00 00 00 00 00 00 00 21 00 39 22 4e 35 40 00 00 00 40 00 00 00 0d 00 00 00 6d 65 74 61 64 61 74
#>    [38] 61 2e 6a 73 6f 6e 7b 22 6b 65 72 61 73 5f 76 65 72 73 69 6f 6e 22 3a 20 22 33 2e 31 34 2e 30 22 2c 20 22 64 61
#>    [75] 74 65 5f 73 61 76 65 64 22 3a 20 22 32 30 32 36 2d 30 35 2d 30 31 40 31 32 3a 35 38 3a 31 31 22 7d 50 4b 03 04
#>   [112] 14 00 00 00 00 00 00 00 21 00 c3 b3 02 c8 41 20 00 00 41 20 00 00 0b 00 00 00 63 6f 6e 66 69 67 2e 6a 73 6f 6e
#>   [149] 7b 22 6d 6f 64 75 6c 65 22 3a 20 22 6b 65 72 61 73 2e 73 72 63 2e 6d 6f 64 65 6c 73 2e 66 75 6e 63 74 69 6f 6e
#>   [186] 61 6c 22 2c 20 22 63 6c 61 73 73 5f 6e 61 6d 65 22 3a 20 22 46 75 6e 63 74 69 6f 6e 61 6c 22 2c 20 22 63 6f 6e
#>   [223] 66 69 67 22 3a 20 7b 22 6e 61 6d 65 22 3a 20 22 66 75 6e 63 74 69 6f 6e 61 6c 5f 32 36 32 22 2c 20 22 74 72 61
#>   [260] 69 6e 61 62 6c 65 22 3a 20 74 72 75 65 2c 20 22 6c 61 79 65 72 73 22 3a 20 5b 7b 22 6d 6f 64 75 6c 65 22 3a 20
#>   [297] 22 6b 65 72 61 73 2e 6c 61 79 65 72 73 22 2c 20 22 63 6c 61 73 73 5f 6e 61 6d 65 22 3a 20 22 49 6e 70 75 74 4c
#>   [334] 61 79 65 72 22 2c 20 22 63 6f 6e 66 69 67 22 3a 20 7b 22 62 61 74 63 68 5f 73 68 61 70 65 22 3a 20 5b 6e 75 6c
#>   [371] 6c 2c 20 31 2c 20 31 30 5d 2c 20 22 64 74 79 70 65 22 3a 20 22 66 6c 6f 61 74 33 32 22 2c 20 22 73 70 61 72 73
#>   [408] 65 22 3a 20 66 61 6c 73 65 2c 20 22 72 61 67 67 65 64 22 3a 20 66 61 6c 73 65 2c 20 22 6e 61 6d 65 22 3a 20 22
#>   [445] 6e 75 6d 65 72 69 63 61 6c 5f 69 6e 70 75 74 22 2c 20 22 6f 70 74 69 6f 6e 61 6c 22 3a 20 66 61 6c 73 65 7d 2c
#>   [482] 20 22 72 65 67 69 73 74 65 72 65 64 5f 6e 61 6d 65 22 3a 20 6e 75 6c 6c 2c 20 22 6e 61 6d 65 22 3a 20 22 6e 75
#>   [519] 6d 65 72 69 63 61 6c 5f 69 6e 70 75 74 22 2c 20 22 69 6e 62 6f 75 6e 64 5f 6e 6f 64 65 73 22 3a 20 5b 5d 7d 2c
#>   [556] 20 7b 22 6d 6f 64 75 6c 65 22 3a 20 22 6b 65 72 61 73 2e 6c 61 79 65 72 73 22 2c 20 22 63 6c 61 73 73 5f 6e 61
#> 
#> ...
#> and 2790 more lines.
```

### Inspect Final Model

You can extract the underlying Keras model and its training history for further inspection.


``` r
# Extract the Keras model summary
final_ames_fit |>
  extract_fit_parsnip() |>
  extract_keras_model() |>
  summary()
#> Model: "functional_262"
#> ┌───────────────────────────────────┬──────────────────────────────┬───────────────────┬───────────────────────────────
#> │ Layer (type)                      │ Output Shape                 │           Param # │ Connected to                  
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ numerical_input (InputLayer)      │ (None, 1, 10)                │                 0 │ -                             
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ neighborhood_input (InputLayer)   │ (None, 1, 28)                │                 0 │ -                             
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ bldg_input (InputLayer)           │ (None, 1, 4)                 │                 0 │ -                             
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ condition_input (InputLayer)      │ (None, 1, 9)                 │                 0 │ -                             
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ dense_1033 (Dense)                │ (None, 1, 128)               │             1,408 │ numerical_input[0][0]         
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ dense_1034 (Dense)                │ (None, 1, 64)                │             1,856 │ neighborhood_input[0][0]      
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ dense_1035 (Dense)                │ (None, 1, 64)                │               320 │ bldg_input[0][0]              
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ dense_1036 (Dense)                │ (None, 1, 64)                │               640 │ condition_input[0][0]         
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ concatenate_258 (Concatenate)     │ (None, 1, 320)               │                 0 │ dense_1033[0][0],             
#> │                                   │                              │                   │ dense_1034[0][0],             
#> │                                   │                              │                   │ dense_1035[0][0],             
#> │                                   │                              │                   │ dense_1036[0][0]              
#> ├───────────────────────────────────┼──────────────────────────────┼───────────────────┼───────────────────────────────
#> │ output (Dense)                    │ (None, 1, 1)                 │               321 │ concatenate_258[0][0]         
#> └───────────────────────────────────┴──────────────────────────────┴───────────────────┴───────────────────────────────
#>  Total params: 13,637 (53.27 KB)
#>  Trainable params: 4,545 (17.75 KB)
#>  Non-trainable params: 0 (0.00 B)
#>  Optimizer params: 9,092 (35.52 KB)
```
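The parameter counts in the summary follow from the usual formula for a dense layer, (inputs + 1) × units, where the extra 1 is the bias. The short R calculation below reproduces the trainable total from the input widths in the summary (10 numeric columns; 28, 4, and 9 dummy columns) and the selected unit counts.

``` r
# Dense layer parameters: one weight per input per unit, plus one bias per unit
dense_params <- function(n_in, units) (n_in + 1) * units

branch_params <- c(
  numerical    = dense_params(10, 128),  # 1,408
  neighborhood = dense_params(28, 64),   # 1,856
  bldg         = dense_params(4, 64),    #   320
  condition    = dense_params(9, 64)     #   640
)

concat_width <- 128 + 64 + 64 + 64               # 320 after layer_concatenate()
output_params <- dense_params(concat_width, 1)   # 321

sum(branch_params) + output_params  # 4545, matching "Trainable params" above
```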


``` r
# Plot the Keras model
final_ames_fit |>
  extract_fit_parsnip() |>
  extract_keras_model() |>
  plot(show_shapes = TRUE)
```

![Model](images/model_plot_shapes_wf.png){fig-alt="A picture showing the model shape"}


``` r
# Plot the training history
final_ames_fit |>
  extract_fit_parsnip() |>
  extract_keras_history() |>
  plot()
```

![Training history of the final Keras model by epoch](figure/inspect-final-keras-model-history-1.png)

## Make Predictions and Evaluate

Finally, we will make predictions on the test set and evaluate the model's performance.


``` r
# Make predictions on the test set
ames_test_pred <- predict(final_ames_fit, new_data = ames_test)
#> 19/19 - 0s - 10ms/step

# Combine predictions with actuals
ames_results <- tibble::tibble(
  Sale_Price = ames_test$Sale_Price,
  .pred = ames_test_pred$.pred
)

print(head(ames_results))
#> # A tibble: 6 × 2
#>   Sale_Price   .pred
#>        <int>   <dbl>
#> 1     189900 193909.
#> 2     195500 195484.
#> 3     236500 234049.
#> 4     212000 217096.
#> 5     210000 241706.
#> 6     142000 126019.

# Evaluate performance using yardstick metrics
ames_metrics <- metric_set(rmse, mae, rsq)
metrics_results <- ames_metrics(
  ames_results,
  truth = Sale_Price,
  estimate = .pred
)

print(metrics_results)
#> # A tibble: 3 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 rmse    standard   44687.   
#> 2 mae     standard   27910.   
#> 3 rsq     standard       0.767
```
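For reference, the three metrics are simple functions of the residuals. The base-R sketch below computes them for the six test-set rows printed above (predictions rounded as printed). Note that yardstick's `rsq()` is the squared Pearson correlation between truth and prediction, not the traditional 1 - SSE/SST, which yardstick provides as `rsq_trad()`. Values from six rows will of course differ from the full test-set metrics.

``` r
# The six rows shown by print(head(ames_results)), predictions rounded
truth    <- c(189900, 195500, 236500, 212000, 210000, 142000)
estimate <- c(193909, 195484, 234049, 217096, 241706, 126019)

# Distinct names avoid masking yardstick's rmse(), mae(), and rsq()
rmse_val <- sqrt(mean((truth - estimate)^2))
mae_val  <- mean(abs(truth - estimate))
rsq_val  <- cor(truth, estimate)^2  # correlation-based, as in yardstick::rsq()

c(rmse = rmse_val, mae = mae_val, rsq = rsq_val)
```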

## Saving and Reloading Your Model

`kerasnip` serializes the complete Keras model (its configuration as well as its weights) to raw bytes at fit time and stores them in the fitted object as the `keras_bytes` element. This means a plain `saveRDS()`/`readRDS()` round trip **works out of the box**: the underlying Keras model is restored automatically the first time `predict()` is called on the reloaded object.


``` r
# Save the FINAL fitted workflow
saveRDS(final_ames_fit, "ames_model.rds")

# Reload: no extra steps needed
final_ames_fit_loaded <- readRDS("ames_model.rds")

# Make predictions again to confirm it works
predict(final_ames_fit_loaded, new_data = ames_test) |> head()
#> 19/19 - 0s - 11ms/step
#> # A tibble: 6 × 1
#>     .pred
#>     <dbl>
#> 1 193909.
#> 2 195484.
#> 3 234049.
#> 4 217096.
#> 5 241706.
#> 6 126019.
```

If you need a fully self-contained bundle suitable for deployment with `vetiver` or other MLOps tools, use `bundle::bundle()` instead:


``` r
library(bundle)

# Save as a portable bundle
bundled <- bundle(final_ames_fit)
saveRDS(bundled, "ames_model_bundle.rds")

# Later, in a fresh R session: load the packages and unbundle
library(kerasnip)
library(bundle)
final_ames_fit_loaded <- unbundle(readRDS("ames_model_bundle.rds"))

predict(final_ames_fit_loaded, new_data = ames_test) |> head()
#> 19/19 - 0s - 9ms/step
#> # A tibble: 6 × 1
#>     .pred
#>     <dbl>
#> 1 193909.
#> 2 195484.
#> 3 234049.
#> 4 217096.
#> 5 241706.
#> 6 126019.
```

See `vignette("saving_and_reloading")` for a detailed comparison of both approaches.
