## A simple classification tree with rpart

## Read the data: X is the training set (the tree is fitted on it) and Y is
## the held-out test set. Both are expected to contain a `label` column and
## attribute columns x1, x2, ... (the names used in the formula below).
## NOTE(review): this relies on read.table() picking up those column names
## from the files themselves — confirm the files carry a header.
X <- read.table(file = "http://www.cse.chalmers.se/~chrdimi/downloads/fouille/geneTraining.txt")
Y <- read.table(file = "http://www.cse.chalmers.se/~chrdimi/downloads/fouille/geneTesting.txt")

## Offline alternative: download both files and read the local copies.
#X <- read.table("geneTraining.txt")
#Y <- read.table("geneTesting.txt")

## How many attribute columns (x1, ..., x<nAttributes>) the tree may split
## on. With this many attributes the space of possible trees is enormous,
## which is why the tree is grown greedily rather than searched exhaustively.
nAttributes <- 1000

## rpart needs a formula naming every predictor explicitly, so build
## label ~ x1 + x2 + ... + x<nAttributes> from the generated column names.
## The surrounding parentheses make the formula print when it is assigned.
xnam <- sprintf("x%d", seq_len(nAttributes))
(fmla <- reformulate(termlabels = xnam, response = "label"))

## Load the recursive-partitioning (tree) library.
library("rpart")

## Settings for the tree fitting. They are optional; here they are chosen to
## make rpart behave more like ID3.
## - split = "information": choose splits by information gain (entropy)
##   instead of rpart's default Gini index.
fit.params <- list(split = "information")
## - minsplit = 2: a node with as few as 2 observations may still be split
##   (rpart's default is 20). Increase it to get a simpler tree.
## - cp = 0.01: rpart's default complexity parameter, written out because it,
##   not minsplit, is what usually stops the tree from growing: a split that
##   does not improve the overall fit by at least a factor of cp is not
##   attempted. Set cp = 0 to keep splitting as long as possible, which is
##   closer to ID3's "split until pure" behaviour.
fit.ctrl <- rpart.control(minsplit = 2, cp = 0.01)

## Fit a classification tree.
fit.tree <- rpart(fmla, data = X, method = "class", parms = fit.params,
                  control = fit.ctrl)

## plot.rpart() stops with an error when the fit is only a root node (no
## splits), so check first instead of aborting the rest of the script.
if (nrow(fit.tree$frame) > 1L) {
  plot(fit.tree)  # draw the tree structure
  text(fit.tree)  # add split conditions and leaf labels
} else {
  message("The fitted tree has no splits; nothing to plot.")
}

## Classification accuracy ----

## Fraction of examples whose predicted class equals the true label.
##
## predicted: predicted class labels, e.g. the factor returned by
##            predict(<rpart fit>, newdata, type = "class").
## actual:    true labels (numeric, character or factor), same length.
## Returns a single number in [0, 1], or NA when there are no examples.
## Both sides are compared as character so that predicted factor levels
## such as "0"/"1" match a numeric label column holding 0/1.
classification_accuracy <- function(predicted, actual) {
  stopifnot(length(predicted) == length(actual))
  if (length(actual) == 0L) {
    return(NA_real_)
  }
  mean(as.character(predicted) == as.character(actual))
}

## predict() must be called with type = "class" to get predicted labels.
## For a classification tree the default is a matrix of class probabilities,
## and a misspelled argument name (e.g. `tree =`) is silently swallowed by
## `...`, so the type has to be spelled exactly.
print("Classification accuracy in training:")
fit.train <- predict(fit.tree, newdata = X, type = "class")
acc.train <- classification_accuracy(fit.train, X$label)
print(acc.train)


print("Classification accuracy in testing:")
fit.test <- predict(fit.tree, newdata = Y, type = "class")
acc.test <- classification_accuracy(fit.test, Y$label)
print(acc.test)


## Complexity-parameter table ----
## printcp() lists, for each candidate cp value, the number of splits, the
## training error relative to the root node ("rel error") and its
## cross-validated estimate ("xerror", with standard error "xstd").
## The cross-validation is run by rpart() itself while fitting (controlled
## by rpart.control()'s `xval`, 10 folds by default), so nothing extra is
## computed here.
print("Cross validation")
printcp(x = fit.tree)

## To inspect the cross-validated predictions behind "xerror" directly, use
## xpred.rpart(fit.tree); its help page has a (regression) example that
## recomputes the cross-validated error by hand and compares it with
## printcp().