Classification and Regression Tree in R
We compare Classification And Regression Trees (CART) and Support Vector Machines (SVM) on the kyphosis dataset (shipped with the R package rpart). The dataset contains 81 rows and 4 columns, representing data on children who have had corrective spinal surgery. The 4 columns are:
Kyphosis -- a factor with levels absent and present, indicating whether a kyphosis (a type of spinal deformation) was present after the operation.
Age -- age in months.
Number -- the number of vertebrae involved.
Start -- the number of the first (topmost) vertebra operated on.
We use Age, Number and Start to predict Kyphosis, and compare CART and SVM using the misclassification rate as the measure of error. For CART, we prune the decision trees at different values of the complexity parameter cp; for SVM, we select the tuning parameters cost and gamma.
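All code below assumes the following setup. The packages are the ones the write-up itself uses (rpart ships the kyphosis data frame; e1071 provides svm() and tune()); the seed is our addition, purely so the random splits are reproducible.
library(rpart)   ## rpart(), prune(), and the kyphosis dataset
library(e1071)   ## svm() and tune()
set.seed(1)      ## illustrative seed for reproducible train/test splits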
CART
## single random train/test split (80% training)
n <- nrow(kyphosis)
index <- sample(n, 0.8 * n)
train <- kyphosis[index, ]
test <- kyphosis[-index, ]
## fit a classification tree
fit <- rpart(Kyphosis ~ Age + Number + Start,
             method = "class", data = train)
## predict() returns class probabilities; column 1 is P(Kyphosis == "absent")
pred <- predict(fit, test[, -1])
error <- mean((pred[, 1] > 0.5) != (test[, 1] == "absent"))
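To see where the errors fall, the thresholded predictions can be cross-tabulated against the truth. This is a small optional sketch, not part of the original comparison; the 0.5 threshold matches the error computation above.
## confusion table for the single split above
pred_class <- ifelse(pred[, 1] > 0.5, "absent", "present")
table(predicted = pred_class, actual = test$Kyphosis)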
## average classification error over 1000 random splits, for different pruning parameters cp
mean(replicate(1000, select_tree())) ## 21.0%
mean(replicate(1000, select_tree(cp = 0.9))) ## 21.1%
mean(replicate(1000, select_tree(cp = 0.8))) ## 21.2%
mean(replicate(1000, select_tree(cp = 0.7))) ## 20.5%
mean(replicate(1000, select_tree(cp = 0.6))) ## 21.2%
mean(replicate(1000, select_tree(cp = 0.5))) ## 21.2%
mean(replicate(1000, select_tree(cp = 0.4))) ## 21.1%
mean(replicate(1000, select_tree(cp = 0.3))) ## 23.1%
mean(replicate(1000, select_tree(cp = 0.2))) ## 24.1%
mean(replicate(1000, select_tree(cp = 0.1))) ## 24.0%
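Scanning cp by hand like this works, but note that rpart also records a cross-validated error for every subtree while fitting. A sketch of the standard way to pick cp from that table, using only rpart's own cptable:
## pick the cp minimizing the cross-validated error (xerror)
fit <- rpart(Kyphosis ~ Age + Number + Start, method = "class", data = kyphosis)
printcp(fit)  ## cross-validation results for each candidate subtree
best_cp <- fit$cptable[which.min(fit$cptable[, "xerror"]), "CP"]
pruned <- prune(fit, cp = best_cp)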
## SVM: average classification error with tuned parameters
mean(replicate(1000, select_svm(cost = 10, gamma = 0.2))) ## 18.3%
## grid search for the tuning parameters cost and gamma
svm_tune <- tune(svm, train.x = kyphosis[, -1], train.y = kyphosis[, 1],
                 kernel = "radial", ranges = list(cost = 10^(-1:2), gamma = seq(0.1, 2, 0.1)))
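Once the grid search finishes, the selected parameters can be read off the returned tune object via e1071's accessors:
summary(svm_tune)          ## error for every (cost, gamma) combination
svm_tune$best.parameters   ## the winning cost/gamma pair
svm_tune$best.performance  ## its estimated error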
Conclusion: the best misclassification error achieved by CART is 20.5% (at cp = 0.7), compared to 18.3% by SVM, so SVM performs slightly better than CART on this dataset.
## a function that fits and prunes a decision tree on a random train/test split
## and returns the test misclassification rate
select_tree <- function(tp = 0.8, cp = 1) {
  n <- nrow(kyphosis)
  ## random split: proportion tp goes into the training set
  index <- sample(n, tp * n)
  train <- kyphosis[index, ]
  test <- kyphosis[-index, ]
  fit <- rpart(Kyphosis ~ Age + Number + Start,
               method = "class", data = train)
  ## prune the tree at the given complexity parameter
  fit1 <- prune(fit, cp = cp)
  ## column 1 of the probability matrix is P(Kyphosis == "absent")
  pred <- predict(fit1, test[, -1])
  error <- mean((pred[, 1] > 0.5) != (test[, 1] == "absent"))
  error
}
## the analogous function for SVM; extra arguments (e.g. cost, gamma)
## are passed on to svm()
select_svm <- function(tp = 0.8, ...) {
  n <- nrow(kyphosis)
  index <- sample(n, tp * n)
  train <- kyphosis[index, ]
  test <- kyphosis[-index, ]
  fit <- svm(Kyphosis ~ Age + Number + Start, data = train, ...)
  pred <- predict(fit, test[, -1])
  error <- mean(pred != test[, 1])
  error
}
Comparing Naive Bayes and SVM on the iris dataset
Both naive Bayes and support vector machines (SVM) can be used for classification. We compare their performance on the classical iris dataset.
Accuracy
We conduct 1000 simulations. In each simulation we divide the data into a training set and a testing set, fit both models on the training set, and make predictions on the testing set, from which the classification error is calculated. After all 1000 simulations, the mean and standard deviation of the classification errors are reported.
> nb_svm(iris, B = 1000)
     naiveBayes        svm
mean 0.04688000 0.04222000
std  0.02563831 0.02465927
From the simulation we find SVM produces a 4.2% classification error on average, slightly less than the 4.7% of naive Bayes. With 1000 replications the standard error of each mean is about sd/sqrt(1000) ≈ 0.0008, so the gap, while small, is well beyond simulation noise.
Computation time
library(e1071)          ## naiveBayes() and svm()
library(microbenchmark)
> microbenchmark(naiveBayes(iris[, 1:4], iris[, 5]), svm(iris[, 1:4], iris[, 5]), times = 1000)
Unit: milliseconds
                                expr      min       lq     mean   median       uq      max neval
 naiveBayes(iris[, 1:4], iris[, 5]) 1.302364 1.432913 1.554624 1.492442 1.575168 16.33110  1000
        svm(iris[, 1:4], iris[, 5]) 3.105816 3.317584 3.557402 3.433982 3.572570 22.91197  1000
From this benchmark of 1000 runs, we find naiveBayes needs only about 44% of the time needed by SVM (1.55 ms versus 3.56 ms on average).
Therefore, SVM is more accurate than naive Bayes on iris, but needs more computation time. Note that this result is obtained only for the iris dataset, which contains 150 observations on 4 predictors. For larger datasets we expect naiveBayes to be even faster relative to SVM, since naive Bayes training is linear in the number of observations while SVM training typically scales superlinearly.
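One quick way to check that expectation is to benchmark both fitters on a larger simulated dataset. The data below are synthetic and purely illustrative; they are not part of the original study.
## illustrative benchmark on simulated data
library(e1071)
library(microbenchmark)
set.seed(1)
n <- 5000
X <- data.frame(matrix(rnorm(n * 4), ncol = 4))
y <- factor(sample(c("a", "b", "c"), n, replace = TRUE))
microbenchmark(naiveBayes(X, y), svm(X, y), times = 10)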
Appendix: R code
## 1. load the library
library(e1071)
## 2. divide the data into a training set and a testing set
n <- nrow(iris)
index <- sample(n, n * 2 / 3)
training <- iris[index, ]
testing <- iris[-index, ]
## 3. naive Bayes model
mod1 <- naiveBayes(training[, 1:4], training[, 5])
pred1 <- predict(mod1, testing[, 1:4])
err1 <- mean(pred1 != testing[, 5])
## 4. SVM
mod2 <- svm(training[, 1:4], training[, 5])
pred2 <- predict(mod2, testing[, 1:4])
err2 <- mean(pred2 != testing[, 5])
## 5. Monte Carlo simulation
## assumes the predictors are in columns 1:4 and the class label in column 5
nb_svm <- function(data, B = 100, training.p = 2/3) {
  n <- nrow(data)
  err1 <- err2 <- numeric(B)
  for (i in 1:B) {
    ## random train/test split
    index <- sample(n, n * training.p)
    training <- data[index, ]
    testing <- data[-index, ]
    ## naive Bayes
    mod1 <- naiveBayes(training[, 1:4], training[, 5])
    pred1 <- predict(mod1, testing[, 1:4])
    err1[i] <- mean(pred1 != testing[, 5])
    ## SVM
    mod2 <- svm(training[, 1:4], training[, 5])
    pred2 <- predict(mod2, testing[, 1:4])
    err2[i] <- mean(pred2 != testing[, 5])
  }
  ## compare the error distributions of the two methods
  plot(density(err1))
  lines(density(err2), col = "red")
  legend("topright", legend = c("naiveBayes", "svm"),
         lty = c(1, 1), col = c("black", "red"))
  res1 <- c(mean(err1), sd(err1))
  res2 <- c(mean(err2), sd(err2))
  res <- data.frame(naiveBayes = res1, svm = res2)
  rownames(res) <- c("mean", "std")
  res
}