Lab 7: Nearest Neighbors Algorithm, Toy Examples

Learning objectives

  • train and test data
  • using the knn() function
  • making k-NN predictions
  • evaluating predictions

1. Basic idea, test and train data

This toy example is taken from Machine Learning with R by Brett Lantz. Suppose we have four things with the following characteristics:

things <- data.frame(ingredient =  c("grape", "green bean" , "nuts" , "orange"),
                     sweetness = c(8,3,3,7),
                     crunchiness = c(5,7,6,3),
                     class = c("fruit", "vegetable", "protein", "fruit"))
things
##   ingredient sweetness crunchiness     class
## 1      grape         8           5     fruit
## 2 green bean         3           7 vegetable
## 3       nuts         3           6   protein
## 4     orange         7           3     fruit

Two of the things are fruit, one is a vegetable, and one is a protein. Suppose we want to classify a tomato into one of these classes. Is a tomato a fruit, a vegetable, or a protein? The tomato has the following characteristics:

unknown <- data.frame(ingredient =  "tomato",
                     sweetness = 6,
                     crunchiness = 4,
                     class="unknown")
unknown
##   ingredient sweetness crunchiness   class
## 1     tomato         6           4 unknown

We now have two data frames: things, where the items are already classified (the training data set), and unknown, where we don’t know the class of the item (the test data set). The kNN algorithm classifies the tomato by calculating the distance between the tomato and each item in the training data set. With a single neighbor (k = 1), the item closest to the tomato determines its class: if the closest item is a fruit, the tomato is classified as a fruit; if the closest item is a vegetable, the tomato is classified as a vegetable; and so on. The distance between items is the Euclidean distance in terms of sweetness and crunchiness. Let’s visualize our data by plotting these two characteristics for each item.

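library(dplyr)    # bind_rows(), select(), filter()
library(descr)    # crosstab(), used in Section 5
library(ggplot2)  # plotting
# Plot the four training items and the tomato in the sweetness-crunchiness plane
ggplot(bind_rows(things, unknown)) + 
  geom_point(aes(x = sweetness, y = crunchiness, color = class), size = 10) +
  geom_label(aes(x = sweetness, y = crunchiness, label = ingredient), hjust = 0, nudge_x = 0.25) +
  xlim(2, 9) + ylim(3, 8)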
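The plot suggests the tomato sits closest to the orange. As a check, the distances can be computed by hand. The base R sketch below retypes the training values (so it runs on its own) and applies the Euclidean distance formula from the tomato to each training item; the names tomato, dist_to_tomato and nearest are illustrative only.

```r
# Training items (same values as the things data frame above)
things <- data.frame(ingredient  = c("grape", "green bean", "nuts", "orange"),
                     sweetness   = c(8, 3, 3, 7),
                     crunchiness = c(5, 7, 6, 3),
                     class       = c("fruit", "vegetable", "protein", "fruit"),
                     stringsAsFactors = FALSE)
tomato <- c(sweetness = 6, crunchiness = 4)  # illustrative name for the test item

# Euclidean distance: sqrt((x1 - x2)^2 + (y1 - y2)^2)
dist_to_tomato <- sqrt((things$sweetness - tomato[["sweetness"]])^2 +
                       (things$crunchiness - tomato[["crunchiness"]])^2)
names(dist_to_tomato) <- things$ingredient
round(dist_to_tomato, 3)
# grape 2.236, green bean 4.243, nuts 3.606, orange 1.414

# The nearest training item and its class (what k = 1 uses)
nearest <- which.min(dist_to_tomato)
things$class[nearest]  # "fruit" (the orange)
```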

2. Using the knn() function

We will use the knn() function from the class package to do all the distance calculations and classification. Its three main arguments are the training data frame, the test data frame, and a vector giving the class of each item in the training data; the number of neighbors is set with k. Note that the training and test data frames must contain only the characteristics used for classification, with no other variables. knn() returns a factor with one prediction for each item in the test data set. In our case, there is only one item in the test data set.

library(class) # contains the knn() function
# k = 1: classify each test item by its single nearest training item
pred <- knn(select(things, sweetness, crunchiness),
            select(unknown, sweetness, crunchiness), things$class, k = 1)
pred
## [1] fruit
## Levels: fruit protein vegetable

OK, so knn() classifies the tomato as a fruit. In terms of sweetness and crunchiness, the tomato is closest to the orange, and since the orange is a fruit, the tomato is classified as a fruit as well.
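To make the three main arguments of knn() concrete, here is a minimal hand-rolled sketch of the same idea in base R. knn_sketch() is a hypothetical helper written for illustration, not part of the class package. It takes training features, test features and training classes, and returns a factor of predictions, with the majority vote among the k nearest items and tied votes broken at random. It simplifies one detail: it keeps exactly k neighbors in distance order, whereas class::knn() also includes any training items tied with the kth nearest one.

```r
# knn_sketch() is a HYPOTHETICAL helper for illustration, not part of the class package.
# train, test: data frames (or matrices) holding only the numeric features
# cl: class of each training row; k: number of neighbors
knn_sketch <- function(train, test, cl, k = 1) {
  train <- as.matrix(train)
  test  <- as.matrix(test)
  cl    <- factor(cl)
  preds <- character(nrow(test))
  for (i in seq_len(nrow(test))) {
    # Euclidean distance from test row i to every training row
    d <- sqrt(rowSums(sweep(train, 2, test[i, ])^2))
    # k closest rows; order() breaks distance ties by row position (a simplification)
    nn <- order(d)[seq_len(k)]
    votes <- table(cl[nn])                        # class counts among the neighbors
    winners <- names(votes)[votes == max(votes)]
    # majority vote; a tied vote is broken at random
    preds[i] <- if (length(winners) == 1) winners else sample(winners, 1)
  }
  factor(preds, levels = levels(cl))
}

# Usage with the fruit/vegetable/protein data from above
train <- data.frame(sweetness = c(8, 3, 3, 7), crunchiness = c(5, 7, 6, 3))
cl    <- c("fruit", "vegetable", "protein", "fruit")
test  <- data.frame(sweetness = 6, crunchiness = 4)   # the tomato
knn_sketch(train, test, cl, k = 1)
# [1] fruit
# Levels: fruit protein vegetable
```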

3. Making predictions

Let’s classify another item, say a carrot, by adding it to the test data set alongside the tomato.

unknown <- data.frame(ingredient =  c("tomato", "carrot"),
                     sweetness = c(6,4),
                     crunchiness = c(4,9),
                     class=c("unknown", "unknown"))
unknown
##   ingredient sweetness crunchiness   class
## 1     tomato         6           4 unknown
## 2     carrot         4           9 unknown
pred <- knn(select(things, sweetness, crunchiness), 
            select(unknown,sweetness, crunchiness), things$class, k=1)
pred
## [1] fruit     vegetable
## Levels: fruit protein vegetable

Since we have two items in the test data set, knn() returned two predictions, one for the tomato and one for the carrot. The carrot is classified as a vegetable: its nearest neighbor is the green bean.
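The carrot result can be checked the same way. This base R sketch retypes the training and test values (d, idx and nearest are illustrative names) and builds a matrix of Euclidean distances with one row per test item and one column per training item, then picks the closest training item in each row, which is what k = 1 does.

```r
# Training and test features typed in again (same values as above)
train <- data.frame(sweetness = c(8, 3, 3, 7), crunchiness = c(5, 7, 6, 3),
                    row.names = c("grape", "green bean", "nuts", "orange"))
cl   <- c("fruit", "vegetable", "protein", "fruit")
test <- data.frame(sweetness = c(6, 4), crunchiness = c(4, 9),
                   row.names = c("tomato", "carrot"))

# Distance matrix: rows = test items, columns = training items
d <- sapply(seq_len(nrow(train)), function(j)
  sqrt((test$sweetness - train$sweetness[j])^2 +
       (test$crunchiness - train$crunchiness[j])^2))
dimnames(d) <- list(rownames(test), rownames(train))
round(d, 3)
# carrot row: grape 5.657, green bean 2.236, nuts 3.162, orange 6.708

# k = 1: each test item takes the class of its closest training item
idx <- apply(d, 1, which.min)
nearest <- colnames(d)[idx]   # "orange" for the tomato, "green bean" for the carrot
cl[idx]                       # "fruit" "vegetable"
```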

4. Using k greater than one

Sometimes we don’t want to classify an item based on just its single nearest neighbor, but on several neighbors: the item is assigned the majority class among its k nearest neighbors. Let’s set k equal to 4 and see what happens.

pred <- knn(select(things, sweetness, crunchiness), 
            select(unknown,sweetness, crunchiness), things$class, k=4)
pred
## [1] fruit fruit
## Levels: fruit protein vegetable

Both the tomato and the carrot were classified as fruit. With k = 4, the algorithm takes the four nearest neighbors and assigns the majority class among them. Since the four nearest neighbors make up the entire training data set, and two of the four items are fruit, knn() classifies both the tomato and the carrot as fruit; in fact, any test item would be classified as fruit here.
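With k equal to the number of training items, the vote no longer depends on where the test item sits: it is just the class counts of the whole training set. A minimal base R sketch of that tally (votes is an illustrative name):

```r
cl <- c("fruit", "vegetable", "protein", "fruit")  # classes of the four training items

# With k = 4 every training item is a neighbor, so the vote is the same
# for any test item: the class counts of the whole training set
votes <- table(cl)
votes                            # fruit 2, protein 1, vegetable 1
names(votes)[which.max(votes)]   # "fruit"
```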

IN-CLASS EXERCISE:

  1. Load in the nhl2 data set (the one that includes cumulative wins).

library(readr) # read_csv() from the readr package parses dates nicely
nhl2 <- read_csv("https://www.dropbox.com/s/87kpl1f2yq6qexx/nhl%20with%20cum%20wins.csv?raw=1")

  2. Create training data with the games before 2016-04-09 and test data with the games on and after 2016-04-09.

nhl_train <- nhl2 %>% filter(Date < as.Date("2016-04-09"))
nhl_test <- nhl2 %>% filter(Date >= as.Date("2016-04-09"))

  3. What are your predictions for those 17 games in the test data set?

  4. Use the knn algorithm to predict the outcomes of the games in the test data, with cumwins_home and cumwins_visit as predictors. Also, set the seed to 364 so that we all get the same random numbers: knn() breaks ties at random, so without a fixed seed the predictions could differ from run to run.

# alpha is set outside aes() so it is a fixed transparency, not a mapped variable
ggplot(nhl2) + geom_point(aes(x = cumwins_home, y = cumwins_visit, color = as.factor(win)), alpha = 0.1)

set.seed(364)
win_pred <- knn(select(nhl_train, cumwins_home, cumwins_visit), 
            select(nhl_test, cumwins_home, cumwins_visit), nhl_train$win, k=3)
win_pred
##  [1] 1 1 0 1 1 1 0 0 1 0 1 0 1 0 0 1 0
## Levels: 0 1
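Why does the seed matter? According to the class package documentation, knn() breaks tied votes at random, and any training points tied with the kth-nearest distance are all included in the vote. With whole-number predictors such as cumulative wins, such ties are easy to hit. The base R sketch below uses made-up points (hypothetical, not rows of nhl2) to show how k = 3 can turn into four equally distant neighbors with a 2-2 split; test_game, neighbors and winners are illustrative names.

```r
# HYPOTHETICAL toy data (not rows of nhl2): whole-number features like cumulative wins
train <- data.frame(cumwins_home  = c(10, 12, 11, 11, 15),
                    cumwins_visit = c(11, 11, 10, 12, 20),
                    win           = factor(c(1, 0, 1, 0, 1)))
test_game <- c(cumwins_home = 11, cumwins_visit = 11)

# Euclidean distance from the test game to each training game
d <- sqrt((train$cumwins_home  - test_game[["cumwins_home"]])^2 +
          (train$cumwins_visit - test_game[["cumwins_visit"]])^2)
d                                # four games are at distance exactly 1

# Keep every game tied with the 3rd-nearest distance, then count the votes
k <- 3
neighbors <- which(d <= sort(d)[k])
votes <- table(train$win[neighbors])
votes                            # a 2-2 split between "0" and "1"

# A tied vote has to be broken at random, so the seed decides the outcome
set.seed(364)
winners <- names(votes)[votes == max(votes)]
sample(winners, 1)
```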

5. Evaluating predictions

Let’s compare our predictions, stored in the win_pred vector, with what actually happened in the test games.

crosstab(nhl_test$win, win_pred, prop.t = TRUE, plot=FALSE)
##    Cell Contents 
## |-------------------------|
## |                   Count | 
## |           Total Percent | 
## |-------------------------|
## 
## =====================================
##                 win_pred
## nhl_test$win        0       1   Total
## -------------------------------------
## 0                  5       3       8 
##                 29.4%   17.6%        
## -------------------------------------
## 1                  3       6       9 
##                 17.6%   35.3%        
## -------------------------------------
## Total              8       9      17 
## =====================================

Of the 8 games in which the home team lost (win = 0), we correctly predicted 5; of the 9 games the home team won, we correctly predicted 6. Overall, we correctly predicted 11 of the 17 games (5 + 6), an accuracy of about 65% (11/17 ≈ 0.647).
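The percentages above can be reproduced from the crosstab counts. The sketch below types those counts into a base R matrix (conf, loss_recall and win_recall are illustrative names) and computes overall accuracy plus the share of actual losses and actual wins that were predicted correctly.

```r
# Counts copied from the crosstab above: rows = actual win, columns = predicted
conf <- matrix(c(5, 3,
                 3, 6),
               nrow = 2, byrow = TRUE,
               dimnames = list(actual = c("0", "1"), predicted = c("0", "1")))

accuracy    <- sum(diag(conf)) / sum(conf)            # (5 + 6) / 17
loss_recall <- conf["0", "0"] / sum(conf["0", ])      # 5 / 8
win_recall  <- conf["1", "1"] / sum(conf["1", ])      # 6 / 9
round(c(accuracy = accuracy, loss_recall = loss_recall, win_recall = win_recall), 3)
# accuracy 0.647, loss_recall 0.625, win_recall 0.667
```

With the real objects in memory, mean(as.character(win_pred) == as.character(nhl_test$win)) should give the same 11/17.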