11.1. K-means Clustering

So far, we have covered a number of supervised learning algorithms (e.g., decision trees, random forests), in which labelled outcomes are given. In contrast, unsupervised learning works with unlabeled data to discover patterns, for example in clustering or association problems, and K-means clustering is one of the simplest and most popular unsupervised learning algorithms.

From Wikipedia:

Given a set of observations \((\boldsymbol{x}_1, \boldsymbol{x}_2, \dots, \boldsymbol{x}_n)\), where each observation is a d-dimensional real vector, k-means clustering aims to partition the n observations into k (≤ n) sets \(\boldsymbol{S} = \{S_1, S_2, \dots, S_k\}\) so as to minimize the within-cluster sum of squares (WCSS) (i.e. variance). Formally, the objective is to find:

\[\begin{equation*} \underset{S}{\mathrm{argmin}} \sum_{i=1}^{k}\sum_{\boldsymbol{x}\in S_i} \left\|\boldsymbol{x}-\boldsymbol{\mu}_i \right\|^2 \end{equation*}\]

where \(\boldsymbol{\mu}_i\) is the mean of points in \(S_i\).
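To make the objective concrete, the WCSS of a given partition can be computed in a few lines of NumPy (a minimal sketch; the function name wcss and its signature are our own):

import numpy as np

def wcss(X, labels, centroids):
    # within-cluster sum of squares: squared distance of each point
    # to the mean of the cluster it is assigned to
    return sum(np.sum((X[labels == i] - mu) ** 2)
               for i, mu in enumerate(centroids))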

11.1.1. Lloyd’s algorithm

  1. Initialize \(\boldsymbol{\mu}_i\) randomly

  2. Alternate:

    • assignment: \(S_i \leftarrow \{\boldsymbol{x}_j : \left\|\boldsymbol{x}_j-\boldsymbol{\mu}_i \right\|^2 \le \left\|\boldsymbol{x}_j-\boldsymbol{\mu}_l \right\|^2 \text{ for all } l\}\), i.e., assign each point to its nearest centroid

    • update: \(\boldsymbol{\mu}_i \leftarrow \frac{1}{|S_i|}\sum_{\boldsymbol{x}_j\in S_i}{\boldsymbol{x}_j}\), i.e., move each centroid to the mean of its assigned points

  3. Terminate when the maximum number of iterations is reached or when the assignments no longer change

Source: https://en.wikipedia.org/wiki/K-means_clustering#Gaussian_mixture_model
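To make these steps concrete, here is a minimal NumPy sketch of Lloyd’s algorithm (the function name lloyd is our own, and the sketch ignores corner cases such as empty clusters):

import numpy as np

def lloyd(X, k, max_iter=100, seed=0):
    rng = np.random.default_rng(seed)
    # 1. initialize the centroids with k distinct random observations
    mu = X[rng.choice(len(X), size=k, replace=False)]
    labels = np.zeros(len(X), dtype=int)
    for it in range(max_iter):
        # 2a. assignment: attach each point to its nearest centroid
        dists = np.linalg.norm(X[:, None, :] - mu[None, :, :], axis=2)
        new_labels = dists.argmin(axis=1)
        # 3. terminate when the assignments no longer change
        if it > 0 and np.array_equal(new_labels, labels):
            break
        labels = new_labels
        # 2b. update: move each centroid to the mean of its points
        mu = np.array([X[labels == i].mean(axis=0) for i in range(k)])
    return labels, mu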

11.1.2. Implementation

Like many other popular algorithms, K-means has a ready-to-use implementation in the scikit-learn package.

11.1.3. Iris data

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
iris = datasets.load_iris()
X = iris.data[:, :2]  # we only take the first two features (sepal length and sepal width)
y = iris.target

We first plot the observations by their true species.

fig, ax = plt.subplots()
scatter = ax.scatter(X[:,0], X[:,1], c=y)
ax.legend(*scatter.legend_elements(),
          loc="upper left", title="Species")
plt.show()
[Figure: the iris observations plotted on the first two features, colored by true species]

11.1.3.1. Training K-means

Now let’s pretend we do not know the true species of each data point, and try to recover the grouping from the two measured features alone.

from sklearn.cluster import KMeans

Kmean = KMeans(n_clusters=3)
Kmean.fit(X)
KMeans(n_clusters=3)

Several parameters of the algorithm can be tuned (e.g., init, n_init, max_iter); see https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html
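For instance, one might fix the initialization scheme, the number of random restarts, and the random seed to make results reproducible (the parameter values below are illustrative, not recommendations):

from sklearn.cluster import KMeans

# 10 random restarts, at most 300 iterations each, fixed seed
kmeans = KMeans(n_clusters=3, init="k-means++", n_init=10,
                max_iter=300, random_state=0)
kmeans.fit(X)
print(kmeans.inertia_)  # final within-cluster sum of squares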

11.1.3.2. Results

The fitted centroids can be obtained from Kmean.cluster_centers_.

print(Kmean.cluster_centers_)

fig = plt.figure()

plt.scatter(Kmean.cluster_centers_[:, 0], Kmean.cluster_centers_[:, 1],
            c="black", s=200, marker="s")
plt.show()
[[5.006      3.428     ]
 [6.81276596 3.07446809]
 [5.77358491 2.69245283]]
[Figure: the three fitted centroids plotted as black squares]

The cluster labels assigned during training can be obtained from Kmean.labels_, and we can now compare them to the true values.

fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(10, 5),
                        constrained_layout=True)

axs[0].scatter(X[:, 0], X[:, 1], c=y, alpha=0.5)
axs[0].set_title("Plot by True Labels")

axs[1].scatter(Kmean.cluster_centers_[:, 0], Kmean.cluster_centers_[:, 1],
               marker="s", c="black", s=400, alpha=1)  # plot fitted centroids
axs[1].scatter(X[:, 0], X[:, 1], c=Kmean.labels_, alpha=0.5)
axs[1].set_title("Plot by Predicted Labels")

plt.show()
[Figure: observations colored by true labels (left) and by predicted labels with centroids (right)]

The clustering looks reasonable, yet there are several “misclassified” data points, mainly in the two clusters that overlap and are not clearly separable. This is to be expected, since K-means only minimizes within-cluster variance (squared Euclidean distances) and never sees the true labels. Note also that the cluster indices returned by K-means are arbitrary, so the colors in the two panels need not match.
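One way to quantify this agreement without worrying about how the cluster indices are numbered is the adjusted Rand index from scikit-learn (a quick sketch; the exact score depends on the fitted model):

from sklearn.metrics import adjusted_rand_score

# compares two labelings; invariant to permutations of cluster indices
print(adjusted_rand_score(y, Kmean.labels_))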

11.1.3.3. Prediction

If we have some new data, we can use Kmean.predict to predict which cluster each point belongs to.

sample_test = np.array([[3, 4], [7, 4]])

Kmean.predict(sample_test)
array([0, 1], dtype=int32)
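Under the hood, predict simply assigns each new point to the nearest fitted centroid, which we can verify by hand (a sketch using the arrays defined above):

# distance from each test point to each fitted centroid
dists = np.linalg.norm(sample_test[:, None, :]
                       - Kmean.cluster_centers_[None, :, :], axis=2)
print(dists.argmin(axis=1))  # should agree with Kmean.predict(sample_test)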

11.1.4. Comment

K-means is a powerful unsupervised clustering algorithm that is easy to understand and implement, but it also has some disadvantages.