---
title: "Extending lolR with Arbitrary Classification Algorithms"
author: "Eric Bridgeford"
date: "`r Sys.Date()`"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{extend_classification}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

# Writing New Classification Algorithms

The `lolR` package makes it easy for users to write their own classification algorithms for cross-validation.

# Writing a Classification Method

For example, consider the following classification algorithm built into the package, the `nearestCentroid` classifier:

```
#' Nearest Centroid Classifier Training
#'
#' A function that trains a classifier based on the nearest centroid.
#' @param X \code{[n, d]} the data with \code{n} samples in \code{d} dimensions.
#' @param Y \code{[n]} the labels of the \code{n} samples.
#' @param ... optional args.
#' @return A list of class \code{nearestCentroid}, with the following attributes:
#' \item{centroids}{\code{[K, d]} the centroids of each class with \code{K}  classes in \code{d} dimensions.}
#' \item{ylabs}{\code{[K]} the ylabels for each of the \code{K} unique classes, ordered.}
#' \item{priors}{\code{[K]} the priors for each of the \code{K} classes.}
#' @author Eric Bridgeford
#'
#' @examples
#' library(lolR)
#' data <- lol.sims.rtrunk(n=200, d=30)  # 200 examples of 30 dimensions
#' X <- data$X; Y <- data$Y
#' model <- lol.classify.nearestCentroid(X, Y)
#' @export
lol.classify.nearestCentroid <- function(X, Y, ...) {
  # class data
  classdat <- lol.utils.info(X, Y)
  priors <- classdat$priors; centroids <- t(classdat$centroids)
  K <- classdat$K; ylabs <- classdat$ylabs
  model <-  list(centroids=centroids, ylabs=ylabs, priors=priors)
  return(structure(model, class="nearestCentroid"))
}
```

As we can see in the above segment, the function `lol.classify.nearestCentroid` returns a list of parameters for the `nearestCentroid` model. To use much of the `lol` functionality, researchers can write a classification method following the spec below:

```
Inputs:
keyworded arguments for:
- X: a [n, d] data matrix with n samples in d dimensions.
- Y: a [n] vector of class labels for each sample.
- <additional-arguments>: additional arguments and hyperparameters your algorithm requires.
Outputs:
a list containing the following:
- <param1>: the first parameter of your model required for prediction.
- <param2>: the second parameter of your model required for prediction.
- ...: additional outputs you might need.
```

For example, my classifier has the following inputs and outputs:

```
Inputs:
keyworded arguments for:
- X: a [n, d] data matrix with n samples in d dimensions.
- Y: a [n] vector of class labels for each sample.
Outputs:
a list containing the following:
- centroids: [K, d] the centroids for each of the K classes.
- ylabs: [K] the label names associated with each of the K classes.
- priors: [K] the priors for each of the K classes.
```

Note that the inputs MUST be named `X, Y`.

Your classifier can then be trained as follows:

```
# given X, Y your data matrix and class labels as above
model <- lol.classify.nearestCentroid(X, Y)
```
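
As a further illustration of the spec, the following is a minimal sketch of a user-defined classifier that takes one additional hyperparameter. The function name `my.classify.knn`, the class name `knnSketch`, and the argument `n.neighbors` are hypothetical and not part of `lolR`; training here only validates and stores the data that a later prediction step would need.

```r
# Hypothetical user-defined classifier (not part of lolR): k-nearest neighbors.
# The first two arguments are named X and Y, as the spec requires.
my.classify.knn <- function(X, Y, n.neighbors=3, ...) {
  X <- as.matrix(X)
  if (nrow(X) != length(Y)) {
    stop("X must have one row per label in Y.")
  }
  if (n.neighbors < 1 || n.neighbors > nrow(X)) {
    stop("n.neighbors must be between 1 and the number of samples.")
  }
  # everything the prediction step will need goes into the returned list
  model <- list(X=X, Y=Y, n.neighbors=n.neighbors)
  return(structure(model, class="knnSketch"))
}

# toy data: two well-separated classes in 2 dimensions
set.seed(1)
X <- rbind(matrix(rnorm(20, mean=0), ncol=2), matrix(rnorm(20, mean=5), ncol=2))
Y <- rep(c("a", "b"), each=10)
model <- my.classify.knn(X, Y, n.neighbors=5)
```

An extra hyperparameter such as `n.neighbors` is the kind of value that would be supplied through a keyworded list of additional arguments rather than hard-coded in the function body.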

# Writing a Prediction Method

To use `lol.xval.eval`, your classification technique must be compatible with the `S3` method `stats::predict`. Below is an example of the prediction method for the `nearestCentroid` classifier shown above:

```
#' Nearest Centroid Classifier Prediction
#'
#' A function that predicts the class of points based on the nearest centroid
#' @param object An object of class \code{nearestCentroid}, with the following attributes:
#' \itemize{
#' \item{centroids}{\code{[K, d]} the centroids of each class with \code{K} classes in \code{d} dimensions.}
#' \item{ylabs}{\code{[K]} the ylabels for each of the \code{K} unique classes, ordered.}
#' \item{priors}{\code{[K]} the priors for each of the \code{K} classes.}
#' }
#' @param X \code{[n, d]} the data to classify with \code{n} samples in \code{d} dimensions.
#' @param ... optional args.
#' @return Yhat \code{[n]} the predicted class of each of the \code{n} data points in \code{X}.
#' @author Eric Bridgeford
#'
#' @examples
#' library(lolR)
#' data <- lol.sims.rtrunk(n=200, d=30)  # 200 examples of 30 dimensions
#' X <- data$X; Y <- data$Y
#' model <- lol.classify.nearestCentroid(X, Y)
#' Yh <- predict(model, X)
#' @export
predict.nearestCentroid <- function(object, X, ...) {
  K <- length(object$ylabs); n <-  dim(X)[1]
  dists <- array(0, dim=c(n, K))
  for (i in 1:n) {
    for (j in 1:K) {
      dists[i, j] <- sqrt(sum((X[i,] - object$centroids[j,])^2))
    }
  }
  Yass <- apply(dists, c(1), which.min)
  Yhat <- object$ylabs[Yass]
  return(Yhat)
}
```

As we can see, the `predict.nearestCentroid` method takes as arguments an `object` input and a data matrix `X` of points to classify. To be compatible with `lol.xval.eval`, your method should obey the following spec:

```
Inputs:
- object: a list containing the parameters required by your model for prediction. This is required by stats::predict.
- X: a [n, d] data matrix with n samples in d dimensions to predict.
Outputs:
A list containing the following (At least one of <your-output#> should be labels for the n samples):
- <your-output1>: the first output of your classification technique.
- <your-output2>: the second output of your classification technique.
- ... additional outputs you may want.
OR
- <your-prediction-labels>: [n] the prediction labels for each of the n samples.
```

For example, my prediction method follows this API:

```
Inputs:
- object: a list containing the parameters required by your model for prediction. This is required by stats::predict.
- X: a [n, d] data matrix with n samples in d dimensions to predict.
Outputs:
- Yhat: [n] the prediction labels for each of the n samples.
```

At least one of the outputs of your prediction method should contain the prediction labels. In my above example, I simply return the labels themselves, but you may want to return a list where one of the outputs is the prediction labels.
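
To illustrate the list-returning case, below is a minimal sketch of a prediction method whose output is a list with the labels stored under one element. The class name `centroidSketch` and the element names `class` and `dists` are assumptions chosen for illustration, and the model object is built by hand so that the example stands alone.

```r
# Hypothetical prediction method (not part of lolR) that returns a list;
# the predicted labels live in the element named "class".
predict.centroidSketch <- function(object, X, ...) {
  X <- as.matrix(X)
  K <- length(object$ylabs); n <- nrow(X)
  dists <- array(0, dim=c(n, K))
  for (j in 1:K) {
    # Euclidean distance from every row of X to the j-th centroid
    diffs <- sweep(X, 2, object$centroids[j,], "-")
    dists[, j] <- sqrt(rowSums(diffs^2))
  }
  Yhat <- object$ylabs[apply(dists, 1, which.min)]
  return(list(class=Yhat, dists=dists))
}

# a hand-built model with two centroids in 2 dimensions
model <- structure(list(centroids=rbind(c(0, 0), c(5, 5)), ylabs=c("a", "b")),
                   class="centroidSketch")
Xnew <- rbind(c(0.5, -0.2), c(4.8, 5.1), c(1, 1))
pred <- predict(model, Xnew)
pred$class  # the labels are extracted by name from the returned list
```

With a prediction method of this shape, the name of the label element (here `"class"`) is what identifies the labels within the returned list.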

# Using your Classification Technique for Embedding Evaluation

If your algorithm follows the above spec, you can easily use it with `lol.xval.eval` to evaluate the classification accuracy of your embedding algorithm. With the algorithm and its prediction method sourced, configure the classifier as follows:

```
classifier = <your-classifier>
# if your classifier requires additional hyperparameters, a keyworded list
# containing additional arguments to train your classifier
classifier.opts = list(<additional-arguments>)
# if your classifier requires no additional hyperparameters
classifier.opts = NaN
# if your classifier prediction returns a list containing the class labels
classifier.return = <return-labels-argname-string>
# if your classifier prediction returns only the class labels
classifier.return = NaN
```

For example, my algorithm can be set up as follows:

```
classifier = lol.classify.nearestCentroid
classifier.opts = NaN  # my classifier takes no additional arguments, so NaN
classifier.return = NaN  # my classifier returns only the prediction labels, so NaN
```

The algorithm can then be incorporated as a classification technique to evaluate prediction accuracy after performing an embedding:

```
# given data X, Y as above
xval.out <- lol.xval.eval(X, Y, alg=<your-algorithm>, alg.opts=<your-algorithm-opts>,
                          alg.return = <your-algorithm-embedding-matrix>,
                          classifier=classifier, classifier.opts=classifier.opts,
                          classifier.return=classifier.return, k=<k>)
```
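
To make the roles of `classifier`, `classifier.opts`, and `classifier.return` concrete, here is a rough sketch of a k-fold evaluation loop that consumes them. This is not the implementation of `lol.xval.eval`: the helper `xval.sketch` and the stand-in classifier `toy.classify` are hypothetical, the embedding step is omitted entirely, and the handling of the `NaN` placeholders is an assumption made for illustration.

```r
# Hypothetical helper (not lolR's implementation): k-fold misclassification
# rate for a classifier that follows the spec in this vignette.
xval.sketch <- function(X, Y, classifier, classifier.opts=NaN,
                        classifier.return=NaN, k=5) {
  n <- nrow(X)
  # randomly assign each sample to one of k folds
  folds <- sample(rep(1:k, length.out=n))
  # is.list distinguishes an options list from the NaN placeholder
  opts <- if (is.list(classifier.opts)) classifier.opts else list()
  errs <- sapply(1:k, function(f) {
    test <- folds == f
    # train on the other folds, passing X and Y by name plus any extra options
    args <- c(list(X=X[!test,, drop=FALSE], Y=Y[!test]), opts)
    model <- do.call(classifier, args)
    pred <- predict(model, X[test,, drop=FALSE])
    # a list-returning predict needs its label element extracted by name
    Yhat <- if (is.character(classifier.return)) pred[[classifier.return]] else pred
    mean(Yhat != Y[test])
  })
  return(mean(errs))
}

# a compact stand-in classifier and prediction method, for demonstration only
toy.classify <- function(X, Y, ...) {
  ylabs <- sort(unique(Y))
  centroids <- t(sapply(ylabs, function(y) colMeans(X[Y == y,, drop=FALSE])))
  return(structure(list(centroids=centroids, ylabs=ylabs), class="toySketch"))
}
predict.toySketch <- function(object, X, ...) {
  # squared distances from each sample (rows) to each centroid (columns)
  d <- apply(object$centroids, 1, function(ctr) rowSums(sweep(X, 2, ctr, "-")^2))
  d <- matrix(d, nrow=nrow(X))
  return(object$ylabs[apply(d, 1, which.min)])
}

# toy data: two classes in 2 dimensions
set.seed(1)
X <- rbind(matrix(rnorm(40, mean=0), ncol=2), matrix(rnorm(40, mean=6), ncol=2))
Y <- rep(c("a", "b"), each=20)
err <- xval.sketch(X, Y, classifier=toy.classify, classifier.opts=NaN,
                   classifier.return=NaN, k=5)
```

The sketch shows why the training inputs are named `X` and `Y` and why a label element name is only needed when the prediction method returns a list.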

See the tutorial vignette `extend_embedding` for details on how to specify `alg`, `alg.opts`, and `alg.return` for a custom embedding algorithm.

Note, for instance, that the `randomForest` package includes the `randomForest` classification technique, which is compatible with this spec:

```
require(randomForest)
classifier=randomForest
# use the randomForest classifier with the similarity matrix argument
classifier.opts = list(prox=TRUE)
classifier.return = NaN  # predict.randomForest returns only the prediction labels
xval.out <- lol.xval.eval(X, Y, alg=<your-algorithm>, alg.opts=<your-algorithm-opts>,
                          alg.return = <your-algorithm-embedding-matrix>,
                          classifier=classifier, classifier.opts=classifier.opts,
                          classifier.return=classifier.return, k=<k>)
```

Now, you should be able to use your user-defined classification technique, or external classification techniques implementing the `S3` method `stats::predict`, with the `lol` package.
