knn() function
This toy example is taken from Machine Learning with R by Brett Lantz. Suppose we have four things with the following characteristics:
things <- data.frame(ingredient = c("grape", "green bean", "nuts", "orange"),
                     sweetness = c(8, 3, 3, 7),
                     crunchiness = c(5, 7, 6, 3),
                     class = c("fruit", "vegetable", "protein", "fruit"))
things
## ingredient sweetness crunchiness class
## 1 grape 8 5 fruit
## 2 green bean 3 7 vegetable
## 3 nuts 3 6 protein
## 4 orange 7 3 fruit
Two of the things are fruit, one is a vegetable, and one is a protein. Suppose we want to classify a tomato into one of these classes. Is a tomato a fruit, a vegetable, or a protein? The tomato has the following characteristics:
unknown <- data.frame(ingredient = "tomato",
                      sweetness = 6,
                      crunchiness = 4,
                      class = "unknown")
unknown
## ingredient sweetness crunchiness class
## 1 tomato 6 4 unknown
We now have two data frames: the things data frame, where items are already classified (the training data set), and the unknown data frame, where we don’t know the class of the item (the test data set). The kNN algorithm will classify the tomato by calculating the distance between the tomato and each item in the training data set. The item that is closest to the tomato will determine how the tomato is classified: if the closest item is a fruit, the tomato will be classified as a fruit; if the closest item is a vegetable, the tomato will be classified as a vegetable; and so on. The distance between items is measured as the Euclidean distance in terms of sweetness and crunchiness, i.e., the square root of the sum of the squared differences in the two characteristics. Let’s visualize our data by plotting the two characteristics (sweetness and crunchiness) of our items.
library(dplyr)
library(descr)
library(ggplot2)
ggplot(bind_rows(things, unknown)) +
  geom_point(aes(x = sweetness, y = crunchiness, color = class), size = 10) +
  geom_label(aes(x = sweetness, y = crunchiness, label = ingredient), hjust = 0, nudge_x = 0.25) +
  xlim(2, 9) + ylim(3, 8)
knn() function
We will use the knn() function from the class package to do all the distance calculations and classification. The function has three arguments: the training data frame, the test data frame, and the vector that gives the class of each item in the training data. Note that the training and test data frames need to contain only the characteristics we use for classification - no other variables. The knn() function returns a vector of predictions, one for each item in the test data set. In our case there is only one item in the test data set.
library(class) # contains the knn() function
pred <- knn(select(things, sweetness, crunchiness),
            select(unknown, sweetness, crunchiness), things$class, k = 1)
pred
## [1] fruit
## Levels: fruit protein vegetable
OK, so knn says that the tomato is a fruit. In terms of sweetness and crunchiness, the tomato is closest to the orange, and since the orange is a fruit, we say the tomato is a fruit as well.
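We can check this result by computing the Euclidean distances ourselves. The following is a minimal sketch (not part of the original lesson) that reuses the things data frame and dplyr:
# distance from the tomato (sweetness 6, crunchiness 4) to each training item
things %>%
  mutate(dist_to_tomato = sqrt((sweetness - 6)^2 + (crunchiness - 4)^2)) %>%
  arrange(dist_to_tomato)
The orange comes out closest at a distance of sqrt(2), about 1.41, which matches the knn() classification.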
Let’s classify another item, say a carrot, by adding it to the tomato in our test data set.
unknown <- data.frame(ingredient = c("tomato", "carrot"),
                      sweetness = c(6, 4),
                      crunchiness = c(4, 9),
                      class = c("unknown", "unknown"))
unknown
## ingredient sweetness crunchiness class
## 1 tomato 6 4 unknown
## 2 carrot 4 9 unknown
pred <- knn(select(things, sweetness, crunchiness),
            select(unknown, sweetness, crunchiness), things$class, k = 1)
pred
## [1] fruit vegetable
## Levels: fruit protein vegetable
Since we have two items in the test data set, kNN returned two predictions, one for the tomato and one for the carrot. The carrot is classified as a vegetable because its nearest neighbor in terms of sweetness and crunchiness is the green bean.
Sometimes we don’t want to classify an item based on just the single nearest neighbor, but rather on several neighbors, assigning the class that holds the majority among the k nearest neighbors. Let’s set k equal to 4 and see what happens.
pred <- knn(select(things, sweetness, crunchiness),
            select(unknown, sweetness, crunchiness), things$class, k = 4)
pred
## [1] fruit fruit
## Levels: fruit protein vegetable
Both the tomato and the carrot were classified as fruit. This is because the algorithm took the four nearest neighbors and chose the majority class among them. Since the four nearest neighbors make up the entire training data set, and two out of the four items are fruit, knn classified both the tomato and the carrot as fruit.
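To see the vote, we can simply count the classes in the training data (a quick illustration, not part of the original lesson):
# with k = 4 the neighborhood is the whole training set,
# so the majority vote reduces to the overall class counts
table(things$class)
Fruit appears twice while protein and vegetable appear once each, so fruit wins the vote no matter what the test item looks like.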
IN-CLASS EXERCISE: 1. Load in the nhl2 data set (the one that includes cumulative wins).
library(readr) # let's use read_csv from the readr package (it parses dates nicely)
nhl2 <- read_csv("https://www.dropbox.com/s/87kpl1f2yq6qexx/nhl%20with%20cum%20wins.csv?raw=1")
nhl_train <- nhl2 %>% filter(Date < as.Date("2016-04-09"))  # training set: games before April 9
nhl_test <- nhl2 %>% filter(Date >= as.Date("2016-04-09"))  # test set: games on or after April 9
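As a quick sanity check (a small addition, assuming the file loads as expected), we can confirm the size of the test set:
nrow(nhl_test) # the 17 games we want to predict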
What are your predictions for those 17 games in the test data set? Use the knn algorithm to predict the outcomes of the games in the test data, with cumwins_home and cumwins_visit as predictors. Also, set the seed to 364 so that it generates the same random numbers for all of us; knn uses a random draw if there is a tie in evaluating the nearest neighbors.
ggplot(nhl2) + geom_point(aes(x = cumwins_home, y = cumwins_visit, color = as.factor(win)), alpha = 0.1)
set.seed(364)
win_pred <- knn(select(nhl_train, cumwins_home, cumwins_visit),
                select(nhl_test, cumwins_home, cumwins_visit), nhl_train$win, k = 3)
win_pred
## [1] 1 1 0 1 1 1 0 0 1 0 1 0 1 0 0 1 0
## Levels: 0 1
Let’s compare our predictions, stored in the win_pred vector, to what actually happened.
crosstab(nhl_test$win, win_pred, prop.t = TRUE, plot=FALSE)
## Cell Contents
## |-------------------------|
## | Count |
## | Total Percent |
## |-------------------------|
##
## =====================================
## win_pred
## nhl_test$win 0 1 Total
## -------------------------------------
## 0 5 3 8
## 29.4% 17.6%
## -------------------------------------
## 1 3 6 9
## 17.6% 35.3%
## -------------------------------------
## Total 8 9 17
## =====================================
We see that of the 8 games in which the home team lost (win = 0), we correctly predicted 5 of those losses. Of the 9 games that the home team won, we correctly predicted 6 of those wins. Overall, we predicted 11 out of 17 games correctly, giving us an accuracy of about 65%.
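The overall accuracy can also be computed directly (a minimal sketch, assuming win is coded 0/1, as in the output above):
# share of test games where the prediction matches the actual outcome;
# knn() returns a factor, and the comparison coerces both sides to character
mean(win_pred == nhl_test$win)
This gives 11/17, or about 0.647, matching the diagonal of the cross-tab.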