## ----setup, echo = FALSE, results = "hide", message = FALSE, warning = FALSE----
## Reproducibility setup. RNGversion() is called before set.seed() so that
## the seed is interpreted under R 3.5.2's random number generation (its
## warning about the old sampler is suppressed); narrow printed output and
## load partykit.
suppressWarnings(RNGversion("3.5.2"))
options(width = 70)
library("partykit")
set.seed(290875)

## ----Titanic--------------------------------------------------------
data("Titanic", package = "datasets")
## Expand the 4-way contingency table (one row per cell plus a Freq count)
## into one row per passenger, keeping only the four factor columns.
ttnc <- as.data.frame(Titanic)
ttnc <- ttnc[rep(seq_len(nrow(ttnc)), times = ttnc$Freq),
             c("Class", "Sex", "Age", "Survived")]
## the tree code below refers to the sex variable as "Gender"
names(ttnc)[names(ttnc) == "Sex"] <- "Gender"

## ----rpart----------------------------------------------------------
library("rpart")
## model = TRUE stores the model frame in the rpart fit
(rp <- rpart(Survived ~ ., data = ttnc, model = TRUE))

## ----rpart-party----------------------------------------------------
## coerce the rpart tree to partykit's "party" representation
(party_rp <- as.party(rp))

## ----rpart-plot-orig, fig.width = 10, fig.height = 6----------------
## rpart's own display of the tree, for comparison
plot(rp)
text(rp)

## ----rpart-plot, fig.width = 10, fig.height = 6---------------------
## plot() of the converted party object
plot(party_rp)

## ----rpart-pred-----------------------------------------------------
## both representations are expected to give the same class
## probabilities; attributes (e.g. row names) are not compared
all.equal(predict(rp), predict(party_rp, type = "prob"), 
  check.attributes = FALSE)

## ----rpart-fitted---------------------------------------------------
## per-observation fitted information of the party object (presumably the
## terminal node and the observed response -- see the str() output)
str(fitted(party_rp))

## ----rpart-prob-----------------------------------------------------
## cross-tabulate all columns of fitted() and normalize each row; assuming
## the columns are (node, response) this is the class distribution within
## each terminal node
prop.table(do.call("table", fitted(party_rp)), 1)

## ----J48------------------------------------------------------------
## Fit a J48 tree from RWeka when that optional package can be loaded;
## otherwise fall back to an rpart tree so the remaining chunks still run.
## requireNamespace() reports availability without attaching the package
## (require() would attach it and only warn on failure).
if (requireNamespace("RWeka", quietly = TRUE)) {
  j48 <- RWeka::J48(Survived ~ ., data = ttnc)
} else {
  j48 <- rpart(Survived ~ ., data = ttnc)
}
print(j48)

## ----J48-party------------------------------------------------------
## coerce the fitted tree to partykit's "party" representation
(party_j48 <- as.party(j48))

## ----J48-plot, fig.width = 15, fig.height = 9-----------------------
plot(party_j48)

## ----J48-pred-------------------------------------------------------
## both representations are expected to give the same class probabilities
all.equal(predict(j48, type = "prob"), predict(party_j48, type = "prob"),
  check.attributes = FALSE)

## ----mytree-1, echo = TRUE------------------------------------------
findsplit <- function(response, data, weights, alpha = 0.01) {

  ## Find the best binary split of a node for a factor response.
  ##
  ## response: name of the response column in `data`.
  ## data:     data frame of factors (predictors plus response).
  ## weights:  integer case weights; observations are replicated this many
  ##           times (weight 0 excludes an observation from the node).
  ## alpha:    significance level for the multiplicity-adjusted p-value.
  ##
  ## The predictor with the smallest chi-squared-test p-value is selected;
  ## if its adjusted p-value, 1 - (1 - p)^k with k the number of testable
  ## predictors, exceeds alpha, NULL is returned. Otherwise the binary
  ## grouping of its levels with the smallest p-value is returned as a
  ## partysplit whose info holds the adjusted p-values of all predictors.

  ## extract response values from data; factor() drops levels that do not
  ## occur in this node
  y <- factor(rep(data[[response]], weights))

  ## a node with a single response class (or no observations) cannot be
  ## split; without this check table(y, x) is 1 x k and chisq.test() would
  ## silently run a goodness-of-fit test of x against uniformity instead
  if (nlevels(y) < 2L) return(NULL)

  ## log p-value of the chi-squared test of y vs. x (NA if x is constant)
  mychisqtest <- function(x) {
    x <- factor(x)
    if (nlevels(x) < 2L) return(NA)
    ct <- suppressWarnings(chisq.test(table(y, x), correct = FALSE))
    pchisq(ct$statistic, ct$parameter, log = TRUE, lower.tail = FALSE)
  }
  xselect <- which(names(data) != response)
  logp <- vapply(xselect,
                 function(i) mychisqtest(rep(data[[i]], weights)),
                 numeric(1))
  names(logp) <- names(data)[xselect]

  ## adjusted p-value of the best variable small enough?
  ntest <- sum(!is.na(logp))
  if (ntest == 0L) return(NULL)
  adjp <- 1 - (1 - exp(logp))^ntest
  minp <- 1 - (1 - exp(min(logp, na.rm = TRUE)))^ntest
  if (minp > alpha) return(NULL)

  ## for the selected variable, search for the split minimizing the p-value
  xselect <- xselect[which.min(logp)]
  x <- rep(data[[xselect]], weights)

  ## set up all possible splits into two kid nodes: subsets of the observed
  ## levels (sizes 1 .. nlev - 2) that are sent to one kid
  lev <- levels(x[drop = TRUE])
  if (length(lev) == 2L) {
    splitpoint <- lev[1L]
  } else {
    comb <- do.call("c", lapply(seq_len(length(lev) - 2L),
      function(size) combn(lev, size, simplify = FALSE)))
    xlogp <- vapply(comb, function(q) mychisqtest(x %in% q), numeric(1))
    splitpoint <- comb[[which.min(xlogp)]]
  }

  ## split into two groups (setting groups that do not occur to NA)
  xlev <- levels(data[[xselect]])
  splitindex <- !(xlev %in% splitpoint)
  splitindex[!(xlev %in% lev)] <- NA_integer_
  splitindex <- splitindex - min(splitindex, na.rm = TRUE) + 1L

  ## return split as partysplit object
  partysplit(varid = as.integer(xselect),
    index = splitindex,
    info = list(p.value = adjp))
}

## ----mytree-2, echo = TRUE------------------------------------------
growtree <- function(id = 1L, response, data, weights, minbucket = 30) {

  ## Recursively grow a tree, returning a partynode.
  ##
  ## id:        id of the node being grown; kids are numbered depth-first.
  ## response:  name of the response column in `data`.
  ## data:      data frame of factors passed on to findsplit().
  ## weights:   case weights of the observations in this node (0 = absent).
  ## minbucket: nodes whose total weight is below this are not split
  ##            (default 30); passed on to every recursive call.

  ## too few (weighted) observations: stop here
  if (sum(weights) < minbucket) return(partynode(id = id))

  ## find best split; no split found, stop here
  sp <- findsplit(response, data, weights)
  if (is.null(sp)) return(partynode(id = id))

  ## kid node number of every observation under this split
  kidids <- kidids_split(sp, data = data)

  ## set up all daughter nodes
  kids <- vector(mode = "list", length = max(kidids, na.rm = TRUE))
  for (kidid in seq_along(kids)) {
    ## select observations for current node; rows with an NA kid id belong
    ## to levels absent from this node and already have weight 0
    w <- weights
    w[!(kidids %in% kidid)] <- 0
    ## ids are assigned depth-first: the next id follows the largest id
    ## used in the previous kid's subtree
    myid <- if (kidid > 1L) max(nodeids(kids[[kidid - 1L]])) else id
    ## start recursion on this daughter node with the same minbucket
    kids[[kidid]] <- growtree(id = as.integer(myid + 1), response, data, w,
                              minbucket = minbucket)
  }

  ## return nodes
  partynode(id = as.integer(id), split = sp, kids = kids,
    info = list(p.value = min(info_split(sp)$p.value, na.rm = TRUE)))
}

## ----mytree-3, echo = TRUE------------------------------------------
mytree <- function(formula, data, weights = NULL) {

  ## Fit a chi-squared-test-based classification tree and return it as a
  ## constparty object.
  ##
  ## formula: response ~ predictors; all variables must be factors.
  ## data:    data frame containing the model variables.
  ## weights: optional integer case weights, either for all rows of `data`
  ##          or for its complete cases only.

  ## names of the variables: response first, then the predictors
  vars <- all.vars(formula)
  response <- vars[1L]
  ## keep only the model variables, response comes last
  data <- data[, c(vars[-1L], response), drop = FALSE]
  ## data is factors only
  stopifnot(all(vapply(data, is.factor, logical(1))))

  ## drop observations with missing values in the model variables; weights
  ## given for every row of the input are subset alongside
  cc <- complete.cases(data)
  if (!is.null(weights) && length(weights) == length(cc)) {
    weights <- weights[cc]
  }
  data <- data[cc, , drop = FALSE]

  if (is.null(weights)) weights <- rep(1L, nrow(data))
  ## weights are case weights, i.e., integers
  stopifnot(length(weights) == nrow(data),
            all(abs(weights - floor(weights)) < .Machine$double.eps))

  ## grow tree
  nodes <- growtree(id = 1L, response, data, weights)

  ## compute terminal node number for each observation
  fitted <- fitted_node(nodes, data = data)
  ## return rich constparty object
  ret <- party(nodes, data = data,
    fitted = data.frame("(fitted)" = fitted,
                        "(response)" = data[[response]],
                        "(weights)" = weights,
                        check.names = FALSE),
    terms = terms(formula))
  as.constparty(ret)
}

## ----mytree-4, echo = TRUE------------------------------------------
## fit the hand-written tree on the expanded Titanic data
(myttnc <- mytree(Survived ~ Class + Age + Gender, data = ttnc))

## ----mytree-5, echo = FALSE, fig.height=8.5, fig.width=14-----------
plot(myttnc)

## ----mytree-pval, echo = TRUE---------------------------------------
## ids of all inner (non-terminal) nodes, and the p.value stored in the
## info of each of those nodes
nid <- nodeids(myttnc)
iid <- nid[!(nid %in% nodeids(myttnc, terminal = TRUE))]
(pval <- unlist(nodeapply(myttnc, ids = iid,
  FUN = function(n) info_node(n)$p.value)))

## ----mytree-nodeprune-----------------------------------------------
## prune the tree at the inner nodes whose p-value exceeds 1e-5
myttnc2 <- nodeprune(myttnc, ids = iid[pval > 1e-5])

## ----mytree-nodeprune-plot, echo = FALSE, fig.height=6, fig.width=10----
plot(myttnc2)

## ----mytree-glm, echo = TRUE----------------------------------------
## log-likelihood of a main-effects logistic regression, for comparison
logLik(glm(Survived ~ Class + Age + Gender, data = ttnc, 
           family = binomial()))

## ----mytree-bs, echo = TRUE-----------------------------------------
## 25 bootstrap samples drawn as multinomial counts with equal
## probabilities; each column gives how often every observation was drawn
bs <- rmultinom(25, nrow(ttnc), rep(1, nrow(ttnc)) / nrow(ttnc))

## ----mytree-ll, echo = TRUE-----------------------------------------
bloglik <- function(prob, weights, y = ttnc$Survived, event = "Yes") {
    ## Weighted Bernoulli log-likelihood of observed outcomes.
    ##
    ## prob:    matrix of predicted class probabilities with a column
    ##          named `event`.
    ## weights: per-observation weights (e.g. out-of-bag indicators).
    ## y:       observed outcomes; defaults to the Titanic response
    ##          `ttnc$Survived`, so existing two-argument calls are unchanged.
    ## event:   level of `y` counted as success (default "Yes").
    sum(weights * dbinom(y == event, size = 1,
                         prob = prob[, event], log = TRUE))
}

## ----mytree-bsll, echo = TRUE---------------------------------------
f <- function(w) {
    ## Out-of-bag log-likelihood for one bootstrap sample: grow the tree
    ## with `w` as case weights and evaluate it on the observations that
    ## were not drawn (w == 0).
    fit_bs <- mytree(Survived ~ Class + Age + Gender, data = ttnc,
                     weights = w)
    prob_bs <- predict(fit_bs, newdata = ttnc, type = "prob")
    out_of_bag <- as.numeric(w == 0)
    bloglik(prob_bs, out_of_bag)
}
## out-of-bag log-likelihood for each bootstrap sample (column of bs)
apply(bs, 2, f)

## ----mytree-node, echo = TRUE---------------------------------------
## all combinations of the predictor levels, used as new data below
nttnc <- expand.grid(Class = levels(ttnc$Class),
  Gender = levels(ttnc$Gender), Age = levels(ttnc$Age))
nttnc

## ----mytree-prob, echo = TRUE---------------------------------------
## predicted terminal node, predicted class, and class probabilities
predict(myttnc, newdata = nttnc, type = "node")
predict(myttnc, newdata = nttnc, type = "response")
predict(myttnc, newdata = nttnc, type = "prob")

## ----mytree-FUN, echo = TRUE----------------------------------------
## custom node summary: FUN is presumably called with the response y and
## weights w of each terminal node; here it ranks the class counts
predict(myttnc, newdata = nttnc, FUN = function(y, w)
  rank(table(rep(y, w))))

