## ----include = FALSE----------------------------------------------------------
# Vignette-wide knitr chunk defaults: enable `collapse` and prefix every line
# of printed output with "#>".
knitr::opts_chunk$set(comment = "#>", collapse = TRUE)

## ----setup--------------------------------------------------------------------
library(luz)
library(torch)

## ----eval = FALSE-------------------------------------------------------------
# net <- nn_module(
#   "Net",
#   initialize = function(num_class) {
#     self$conv1 <- nn_conv2d(1, 32, 3, 1)
#     self$conv2 <- nn_conv2d(32, 64, 3, 1)
#     self$dropout1 <- nn_dropout2d(0.25)
#     self$dropout2 <- nn_dropout2d(0.5)
#     self$fc1 <- nn_linear(9216, 128)
#     self$fc2 <- nn_linear(128, num_class)
#   },
#   forward = function(x) {
#     x <- self$conv1(x)
#     x <- nnf_relu(x)
#     x <- self$conv2(x)
#     x <- nnf_relu(x)
#     x <- nnf_max_pool2d(x, 2)
#     x <- self$dropout1(x)
#     x <- torch_flatten(x, start_dim = 2)
#     x <- self$fc1(x)
#     x <- nnf_relu(x)
#     x <- self$dropout2(x)
#     x <- self$fc2(x)
#     x
#   }
# )

## ----eval = FALSE-------------------------------------------------------------
# fitted <- net %>%
#   setup(
#     loss = nn_cross_entropy_loss(),
#     optimizer = optim_adam,
#     metrics = list(
#       luz_metric_accuracy()
#     )
#   ) %>%
#   set_hparams(num_class = 10) %>%
#   set_opt_hparams(lr = 0.003) %>%
#   fit(train_dl, epochs = 10, valid_data = test_dl)

## ----eval = FALSE-------------------------------------------------------------
# predictions <- predict(fitted, test_dl)

## ----eval = FALSE-------------------------------------------------------------
# # -> Initialize objects: model, optimizers.
# # -> Select fitting device.
# # -> Move data, model, optimizers to the selected device.
# # -> Start training
# for (epoch in 1:epochs) {
#   # -> Training procedure
#   for (batch in train_dl) {
#     # -> Calculate model `forward` method.
#     # -> Calculate the loss
#     # -> Update weights
#     # -> Update metrics and tracking loss
#   }
#   # -> Validation procedure
#   for (batch in valid_dl) {
#     # -> Calculate model `forward` method.
#     # -> Calculate the loss
#     # -> Update metrics and tracking loss
#   }
# }
# # -> End training

## ----eval=FALSE---------------------------------------------------------------
# fitted <- net %>%
#   setup(
#     ...,
#     metrics = list(
#       luz_metric_accuracy()
#     )
#   ) %>%
#   fit(...)

## ----eval = FALSE-------------------------------------------------------------
# luz_metric_accuracy <- luz_metric(
#   # An abbreviation to be shown in progress bars, or
#   # when printing progress
#   abbrev = "Acc",
#   # Initial setup for the metric. Metrics are initialized
#   # every epoch, for both training and validation
#   initialize = function() {
#     self$correct <- 0
#     self$total <- 0
#   },
#   # Runs at every training or validation step and updates
#   # the internal state. The update function takes `preds`
#   # and `target` as parameters.
#   update = function(preds, target) {
#     pred <- torch::torch_argmax(preds, dim = 2)
#     self$correct <- self$correct + (pred == target)$
#       to(dtype = torch::torch_float())$
#       sum()$
#       item()
#     self$total <- self$total + pred$numel()
#   },
#   # Use the internal state to query the metric value
#   compute = function() {
#     self$correct/self$total
#   }
# )

## ----include=FALSE, eval = torch::torch_is_installed()------------------------
# Attach luz for the evaluated chunks below, then pin torch's RNG seed before
# any tensors or model weights are created.
library(luz)
torch::torch_manual_seed(seed = 1)
# Return an `nn_module` generator for a single linear layer.
#
# The generated module is constructed with:
#   input_size  - size(s) of the input; their product is the layer's in-features.
#   output_size - size(s) of the output; their product is the layer's
#                 out-features, and the value is kept to reshape the result.
get_model <- function() {
  torch::nn_module(
    initialize = function(input_size, output_size) {
      n_in <- prod(input_size)
      n_out <- prod(output_size)
      self$fc <- torch::nn_linear(n_in, n_out)
      self$output_size <- output_size
    },
    forward = function(x) {
      # Flatten from dimension 2 onward, then apply the linear layer.
      flat <- torch::torch_flatten(x, start_dim = 2)
      out <- self$fc(flat)
      # Reshape to (first dimension of `x`, output_size).
      out$view(c(x$shape[1], self$output_size))
    }
  )
}

# Configure the module generator step by step: MSE loss, Adam optimizer and
# three regression metrics, followed by module and optimizer hyperparameters.
model <- get_model()
model <- setup(
  model,
  loss = torch::nn_mse_loss(),
  optimizer = torch::optim_adam,
  metrics = list(luz_metric_mae(), luz_metric_mse(), luz_metric_rmse())
)
model <- set_hparams(model, input_size = 10, output_size = 1)
model <- set_opt_hparams(model, lr = 0.001)

# Two random tensors (100 x 10 and 100 x 1) used as the data for both fitting
# and evaluation -- presumably (input, target); confirm against luz's `fit()`.
x <- list(
  torch::torch_randn(100, 10),
  torch::torch_randn(100, 1)
)

# Fit for a single epoch without progress output, in unshuffled batches of 2.
fitted <- fit(
  model,
  x,
  epochs = 1,
  verbose = FALSE,
  dataloader_options = list(batch_size = 2, shuffle = FALSE)
)

# Evaluate the fitted model on the same data; the result is printed in a
# later chunk.
evaluation <- evaluate(fitted, data = x)

## ----eval = FALSE-------------------------------------------------------------
# evaluation <- fitted %>% evaluate(data = valid_dl)
# metrics <- get_metrics(evaluation)
# print(evaluation)

## ----echo=FALSE, eval=torch::torch_is_installed()-----------------------------
# Turn off cli's unicode output before printing -- presumably so the rendered
# report uses plain ASCII; confirm against the cli options documentation.
options(cli.unicode = FALSE)
# Extract the metrics from the evaluation result, then print the evaluation.
metrics <- evaluation %>% get_metrics()
print(evaluation)

## ----eval = FALSE-------------------------------------------------------------
# print_callback <- luz_callback(
#   name = "print_callback",
#   initialize = function(message) {
#     self$message <- message
#   },
#   on_train_batch_end = function() {
#     cat("Iteration ", ctx$iter, "\n")
#   },
#   on_epoch_end = function() {
#     cat(self$message, "\n")
#   }
# )

## ----eval = FALSE-------------------------------------------------------------
# fitted <- net %>%
#   setup(...) %>%
#   fit(..., callbacks = list(
#     print_callback(message = "Done!")
#   ))

