[Machine Learning] K-Nearest Neighbor Analysis in R (1)
To understand our sample data set and some basic concepts in machine learning, please visit this web site.
[Limitation of K-Nearest Neighbor Analysis]
The most notable limitation of KNN is that it cannot handle categorical variables. Think about a column containing "male" and "female," like the one in our member roster. We can't apply KNN to this type of data because the algorithm works with distances, and there is no natural distance between categorical values. Categorical values come in many forms: male vs. female, child vs. adult, loyal vs. non-loyal customers, core vs. non-core values. Open your company's database, and you'll see a lot of categorical values in there.
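To make the distance problem concrete, here is a minimal base-R sketch. It assumes a hypothetical nominal column, region (values "North", "South", "West"), which is not part of the original data. Euclidean distance, which KNN relies on, is meaningful for numbers such as age and income; if we force categories into integers, the "distance" depends entirely on the arbitrary codes we picked.

```r
# Euclidean distance, the measure KNN uses to find "nearest" neighbors
euclidean <- function(a, b) sqrt(sum((a - b)^2))

# Numeric features (age, income): the distance has a real meaning
d_numeric <- euclidean(c(20, 1000), c(15, 100))

# Hypothetical nominal feature: there is no natural order among regions.
# Two equally valid integer codings of the same three categories:
coding_a <- c(North = 1, South = 2, West = 3)
coding_b <- c(North = 1, West = 2, South = 3)

# "Distance" between a North customer and a South customer under each coding
d_a <- unname(abs(coding_a["North"] - coding_a["South"]))  # 1 under coding A
d_b <- unname(abs(coding_b["North"] - coding_b["South"]))  # 2 under coding B

# Same two customers, different distance: the number is an artifact of the coding
cat("numeric distance:", d_numeric, "\n")
cat("coding A:", d_a, " coding B:", d_b, "\n")
```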
To overcome this obstacle, we can use a classification tree as an alternative.
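A tree, by contrast, can split on a categorical column directly, without computing any distance. The sketch below uses rpart (the package used later in this post) on a tiny hypothetical data set; the gender and buys columns and their values are made up for illustration, and minsplit and cp are lowered only because the toy data is so small.

```r
library(rpart)

# Hypothetical toy data: one categorical predictor, one categorical label
shop <- data.frame(
  gender = factor(rep(c("Male", "Female"), each = 10)),
  buys   = factor(rep(c("no", "yes"), each = 10))
)

# method = "class" grows a classification tree; the factor is split on directly
fit <- rpart(buys ~ gender, data = shop, method = "class",
             control = rpart.control(minsplit = 2, cp = 0))
print(fit)

# Predict for a new (hypothetical) female customer
new_customer <- data.frame(gender = factor("Female", levels = levels(shop$gender)))
pred <- predict(fit, new_customer, type = "class")
pred
```

No encoding or distance is involved: the tree simply asks "is gender Female?" at the split.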
[Machine learning mimics human logical thinking]
Humans have "intuition." The goal of machine learning is to give machines a similar "intuitive power." In fact, the K-nearest neighbor algorithm mimics one kind of human intuition: comparing things and finding the ones that look similar. A classification tree (a.k.a. CART) mimics another: human decision criteria. Let me first explain the idea with a simple example, and then we'll go through the actual code step by step.
[How Classification Tree works]
The basic principle is to use a "decision tree."
For example, take the table below. I want to predict who is a "student" based on age and income.
Age | Income | Occupation
20 | $1,000 | Student
50 | $50,000 | Non-student
15 | $100 | Student
34 | $30,000 | Non-student
Let's put aside machine learning theory for now and rely on common sense first. Students are generally younger than 25, so let's turn that into a simple decision rule:
if (age > 25) output = non-student
else output = student
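The first rule can be written as a small R function and applied to the table above. The fifth row, a 19-year-old earning $1,000,000, is a hypothetical young entrepreneur added only to show where this rule breaks.

```r
# The four rows from the table, plus one hypothetical young entrepreneur
people <- data.frame(
  age    = c(20, 50, 15, 34, 19),
  income = c(1000, 50000, 100, 30000, 1000000),
  actual = c("Student", "Non-student", "Student", "Non-student", "Non-student"),
  stringsAsFactors = FALSE
)

# Rule 1: anyone older than 25 is a non-student, everyone else is a student
rule1 <- function(age) ifelse(age > 25, "Non-student", "Student")

people$rule1 <- rule1(people$age)
people
# The last row (the entrepreneur) is misclassified as "Student"
```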
However, Mark Zuckerberg founded Facebook before he turned 25, so young entrepreneurs should be excluded from the student class. Common sense also says that an income of more than $10,000 is unlikely to come from a part-time job. Let's add that logic:
if (age > 25) output = non-student
else {
  if (income > $10,000) output = non-student
  else output = student
}
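Here is the refined rule as a nested ifelse() in R, a minimal sketch applied to the same rows (including the hypothetical entrepreneur, who is not in the original table). The first test is on age and the second, applied only to the younger group, is on income, which is exactly the shape of a two-level decision tree.

```r
people <- data.frame(
  age    = c(20, 50, 15, 34, 19),
  income = c(1000, 50000, 100, 30000, 1000000),
  actual = c("Student", "Non-student", "Student", "Non-student", "Non-student"),
  stringsAsFactors = FALSE
)

# Rule 2: split on age first, then split the younger group on income
rule2 <- function(age, income) {
  ifelse(age > 25, "Non-student",
         ifelse(income > 10000, "Non-student", "Student"))
}

people$rule2 <- rule2(people$age, people$income)
people
all(people$rule2 == people$actual)  # every row, including the entrepreneur, is now correct
```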
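This is how a classification tree works, and you will see the same kind of decision criteria in the fitted tree later. Now, it's time to look at the code.
Under the hood, the algorithm chooses each split by measuring how "impure" the resulting groups are, using a measure such as information entropy. (rpart, which we use below, defaults to the Gini index for classification and can use information entropy via parms = list(split = "information").) Information theory was founded by Claude Shannon, and it underpins the communication technology in your cell phone. As this is beyond our scope, I'll skip a detailed explanation of information entropy. If you're interested, search for "information entropy" on the internet; it is explained on many web sites and YouTube channels.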
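For the curious, here is a short base-R sketch of the idea. The entropy of a set of labels is H = -sum(p_i * log2(p_i)), where p_i is the share of class i, and a split is scored by its information gain: the parent's entropy minus the size-weighted entropy of the two child groups. The code computes this by hand for the "age > 25" split on the four-row table, plus the Gini impurity 1 - sum(p_i^2) that rpart uses by default. The helper names are hypothetical.

```r
# Shannon entropy (in bits) of a vector of class labels
entropy <- function(labels) {
  p <- table(labels) / length(labels)  # only observed classes, so no log2(0)
  -sum(p * log2(p))
}

# Gini impurity: 1 minus the sum of squared class proportions
gini <- function(labels) {
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}

# Information gain of splitting `labels` by the logical vector `goes_left`
information_gain <- function(labels, goes_left) {
  left  <- labels[goes_left]
  right <- labels[!goes_left]
  weighted <- length(left) / length(labels) * entropy(left) +
              length(right) / length(labels) * entropy(right)
  entropy(labels) - weighted
}

age   <- c(20, 50, 15, 34)
label <- c("Student", "Non-student", "Student", "Non-student")

entropy(label)                     # 1 bit: two students, two non-students
gini(label)                        # 0.5
information_gain(label, age > 25)  # 1: both child groups are pure
```

A tree-growing algorithm repeats this scoring for every candidate split and keeps the best one.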
[Code]
# Classification Tree with rpart
library("rpart")
library("gmodels")  # provides CrossTable() for the confusion matrix
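normalize <- function(x) {
  # You could normalize (z-score) the data. However, since a classification tree doesn't use distances, normalization is of little use (this function is not called below).
  mean_x <- mean(x)
  # population standard deviation: rescales sd() to divide by n instead of n - 1
  stdev_x <- sd(x) * sqrt((length(x) - 1) / length(x))
  num <- x - mean_x
  denom <- stdev_x
  return(num / denom)
}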
# UCI offers great sample data sets. Let's use this one!
iris <- read.csv(url("http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"), header = FALSE)
names(iris) <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width", "Species")
# Since R 4.0, read.csv() keeps text columns as character; make the label a factor
iris$Species <- factor(iris$Species)
set.seed(1234)
ind <- sample(2, nrow(iris), replace=TRUE, prob=c(0.7, 0.3))
# Unlike KNN, keep the label (column 5) in the training data: the tree learns from it through the formula below
iris.training <- iris[ind==1, 1:5]
# However, just like with KNN, separate the label from the test data
iris.test <- iris[ind==2, 1:4]
iris.testLabels <- iris[ind==2, 5]
# grow tree
# formula: y ~ x1 + x2 + x3 + x4
# y: what we want to predict (the label)
# x1, x2, x3, x4: what we know (the features)
fit <- rpart(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width,
method="class", data=iris.training)
# plot tree
plot(fit, uniform=TRUE, main="Classification Tree for Iris")
text(fit, use.n=TRUE, all=TRUE, cex=.8)
# predict the species of the test set
iris_pred <- predict(fit, iris.test, type = "class")
#Confusion Matrix
CrossTable(x = iris.testLabels, y = iris_pred, prop.chisq=FALSE)
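A side note on the normalize() comment above: you can check the claim that normalization doesn't matter for a tree by fitting the same tree on raw and on standardized features and comparing the predictions. This sketch assumes R's built-in datasets::iris instead of the UCI download, and it standardizes with (x - mean(x)) / sd(x), which differs from normalize() only by a constant factor and therefore preserves the ordering of the values.

```r
library(rpart)

iris_raw <- datasets::iris  # built-in copy, no download needed
features <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width")

# Standardize every feature: (x - mean) / sd
iris_scaled <- iris_raw
iris_scaled[features] <- lapply(iris_raw[features], function(x) (x - mean(x)) / sd(x))

fit_raw    <- rpart(Species ~ ., data = iris_raw,    method = "class")
fit_scaled <- rpart(Species ~ ., data = iris_scaled, method = "class")

pred_raw    <- predict(fit_raw,    iris_raw,    type = "class")
pred_scaled <- predict(fit_scaled, iris_scaled, type = "class")

# Split quality depends only on the ordering of values, so the trees agree
all(pred_raw == pred_scaled)
```

The split thresholds printed by the two trees differ (they are on different scales), but each one separates the same observations.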
[Output]
<Confusion Matrix>
In this case, 36 of the 38 test flowers were classified correctly, so the accuracy is 36/38 = 94.74%.
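Accuracy is the number of correctly classified test rows (the diagonal of the confusion matrix) divided by the total number of test rows. If you'd rather not depend on gmodels, here is a minimal base-R sketch; the accuracy() helper and the short label vectors are hypothetical placeholders, not the results of the run above.

```r
# Accuracy = sum of the confusion matrix diagonal / total number of rows
accuracy <- function(truth, predicted) {
  # use the same levels in the same order so the diagonal lines up
  lv <- union(as.character(truth), as.character(predicted))
  cm <- table(factor(truth, levels = lv), factor(predicted, levels = lv))
  sum(diag(cm)) / sum(cm)
}

# Hypothetical labels: 3 of the 4 predictions are correct
truth <- c("Iris-setosa", "Iris-versicolor", "Iris-virginica", "Iris-virginica")
pred  <- c("Iris-setosa", "Iris-versicolor", "Iris-versicolor", "Iris-virginica")
accuracy(truth, pred)  # 0.75

# The figure quoted above, as a percentage
round(100 * 36 / 38, 2)  # 94.74
```

With the objects from the code above, accuracy(iris.testLabels, iris_pred) would compute the test accuracy directly.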