Julia Community

Martin Roa Villescas


Naïve k-means

This is the first of a series of posts on k-means clustering. In this post, we will focus on k-means clustering in its simplest form, commonly referred to as naïve k-means. In future posts, we will gradually improve this version by better understanding its inner workings and looking at it from a different angle.

The goal of k-means clustering is to partition a set of $N$ observations $\{ \mathbf{x}_1, \ldots, \mathbf{x}_N \}$ into $K$ clusters. An observation $\mathbf{x}_n$ corresponds to a point in a multidimensional space. Intuitively, each cluster $k$ can be thought of as a group of points that lie close to each other around its mean $\bm{\mu}_k$. More precisely, k-means clustering looks for an assignment of points to clusters $\{ k_n \}$ that minimizes the sum of the squared Euclidean distances of each point $\mathbf{x}_n$ to the mean of its cluster, $\sum_{n=1}^{N} \lVert \mathbf{x}_n - \bm{\mu}_{k_n} \rVert^2$. We can do this through an iterative procedure involving two steps:

  1. Assign each point $\mathbf{x}_n$ to its nearest cluster $k_n$, that is, the cluster whose mean $\bm{\mu}_k$ is closest:
     $$k_n = \arg\min_k \lVert \mathbf{x}_n - \bm{\mu}_k \rVert^2$$
  2. Update the cluster means $\{ \bm{\mu}_k \}$, setting each one to the average of the points assigned to it:
     $$\bm{\mu}_k = \frac{\sum_{\mathbf{x}_n \in k} \mathbf{x}_n}{\sum_{\mathbf{x}_n \in k} 1}$$
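To make the two steps concrete, here is a minimal sketch of a single iteration in plain Julia. It deliberately avoids Distances.jl so that it runs with only the Statistics standard library. The function name kmeans_step is hypothetical (it is not part of the code used later in this post), and the sketch assumes observations are stored as the columns of a d×N matrix and the means as the columns of a d×K matrix. It also keeps the previous mean when a cluster receives no points, which is one possible convention among several.

```julia
using Statistics  # mean

# One iteration of naïve k-means (hypothetical helper, for illustration only).
# X: d×N matrix with one observation per column; M: d×K matrix of current means.
function kmeans_step(X::AbstractMatrix, M::AbstractMatrix)
    N, K = size(X, 2), size(M, 2)
    # Step 1: kₙ = argmin_k ‖xₙ - μₖ‖² for every observation n
    assign = [argmin([sum(abs2, X[:, n] .- M[:, k]) for k in 1:K]) for n in 1:N]
    # Step 2: μₖ = average of the observations assigned to cluster k
    M_new = float.(M)
    for k in 1:K
        members = X[:, assign .== k]
        if !isempty(members)          # an empty cluster keeps its previous mean
            M_new[:, k] = vec(mean(members; dims=2))
        end
    end
    return assign, M_new
end

# Toy 1-D example: four points and two initial means
X = [1.0 2.0 10.0 11.0]
a1, M1 = kmeans_step(X, [1.0 2.0])   # a1 == [1, 2, 2, 2], M1 ≈ [1.0 23/3]
a2, M2 = kmeans_step(X, M1)          # a2 == [1, 1, 2, 2], M2 ≈ [1.5 10.5]
```

In the first iteration the point 2 sits on the second initial mean, so the second cluster absorbs 2, 10 and 11 and its mean moves to about 7.67. In the second iteration the point 2 is closer to the first mean, and the means settle at 1.5 and 10.5.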

This two-step procedure is repeated until convergence, typically until the assignments $\{ k_n \}$ stop changing from one iteration to the next. Before we can start iterating, however, we need to choose initial values for the means $\bm{\mu}_k$. One common approach is to set them to a random subset of $K$ observations. For the moment, we assume that the value of $K$ is given.
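Putting the two steps together with the initialization and a stopping rule gives the minimal sketch below. It assumes that "convergence" means the assignments did not change between two iterations, and it caps the loop at a hypothetical maxiter. The name naive_kmeans and its keyword arguments are illustrative and do not come from any package. The initial means are K observations drawn without replacement with randperm from the Random standard library, which matches the random subset described above.

```julia
using Random      # randperm, AbstractRNG, default_rng, MersenneTwister
using Statistics  # mean

# Naïve k-means with random initialization and an "assignments unchanged" stopping rule
# (hypothetical function, for illustration only).
# X: d×N matrix with one observation per column. Returns (assignments, means, iterations).
function naive_kmeans(X::AbstractMatrix, K::Integer; maxiter::Integer=100,
                      rng::AbstractRNG=Random.default_rng())
    N = size(X, 2)
    # Initialization: K observations drawn at random without replacement
    M = float.(X[:, randperm(rng, N)[1:K]])
    assign = zeros(Int, N)                  # no point is assigned yet
    for iter in 1:maxiter
        # Step 1: assign each observation to its nearest mean
        new_assign = [argmin([sum(abs2, X[:, n] .- M[:, k]) for k in 1:K]) for n in 1:N]
        # Stop once the assignments no longer change (so the means would not change either)
        if new_assign == assign
            return assign, M, iter
        end
        assign = new_assign
        # Step 2: recompute each mean; an empty cluster keeps its previous mean
        for k in 1:K
            members = X[:, assign .== k]
            if !isempty(members)
                M[:, k] = vec(mean(members; dims=2))
            end
        end
    end
    return assign, M, maxiter
end

# Two well-separated groups: {1, 2} and {10, 11}
X = [1.0 2.0 10.0 11.0]
labels, means, iters = naive_kmeans(X, 2; rng=MersenneTwister(1))
```

On this toy data every possible initialization ends with 1 and 2 in one cluster and 10 and 11 in the other, with means 1.5 and 10.5; only the cluster labels depend on the random draw. On real data, different initializations can lead to different final partitions.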

Let us now look at an implementation using the Julia language.

We start by importing the necessary packages: Plots.jl to visualize the result, RDatasets.jl to load the Iris flower data set, Distances.jl for its squared Euclidean distance implementation, and Statistics.jl for its mean implementation. We then load the first two features of the data set (sepal length and sepal width) into a $2 \times N$ matrix with one observation per column. We use only two features so that the clusters are easy to plot; k-means clustering itself works with data of any number of dimensions. Finally, we choose an arbitrary value for $K$.

```julia
using Plots, RDatasets, Distances, Statistics
pyplot(leg=false, ms=6, border=true, fontfamily="Calibri",
       title="Naïve k-means")             # plot defaults
iris = dataset("datasets", "iris");       # load dataset
𝐱 = collect(Matrix(iris[:, 1:2])');      # select features
K = 3                                     # number of clusters
```

Let's define a function that plots the data points $\{ \mathbf{x}_n \}$ and the means $\{ \bm{\mu}_k \}$, using the marker color to represent the assignment of points to clusters $\{ k_n \}$.

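```julia
# Scatter the observations colored by cluster and overlay the means as crosses
function plot_clusters(𝐱, 𝛍, kₙ)
  plt = scatter(𝐱[1,:], 𝐱[2,:], c=kₙ)                        # points, colored by assigned cluster
  scatter!(𝛍[1,:], 𝛍[2,:], m=(:xcross,10), c=1:size(𝛍,2))   # means, one color per cluster
  return plt
end
```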

Last but not least, an elegant but naïve k-means implementation:

```julia
using Random                              # for randperm

# Initialize the means with K observations sampled without replacement
𝛍 = 𝐱[:, randperm(size(𝐱, 2))[1:K]]

anim = @animate for _ in 1:15

  # Find the squared distance from each sample to each mean (N×K matrix)
  d = pairwise(SqEuclidean(), 𝐱, 𝛍, dims=2)

  # 1. Assign each point to the cluster with the nearest mean (N×1 matrix)
  kₙ = findmin(d, dims=2)[2] |> x -> map(i -> i.I[2], x)

  # 2. Update the cluster means
  for k in 1:K
    𝐱ₙₖ = 𝐱[:, dropdims((kₙ .== k), dims=2)]   # points assigned to cluster k
    𝛍[:, k] = mean(𝐱ₙₖ, dims=2)
  end

  plot_clusters(𝐱, 𝛍, kₙ)
end

gif(anim, joinpath(@__DIR__, "anim.gif"), fps=2)
```
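The assignment line in the loop is the least obvious part of the implementation. The standalone snippet below applies the same expression to a small made-up distance matrix: findmin with dims=2 returns, for each row, the minimum and its position as a CartesianIndex, and .I[2] extracts the column, that is, the cluster index. The result is an N×1 matrix rather than a vector, which is why the update step calls dropdims. Base's argmin broadcast over eachrow gives the same indices as a plain vector; that variant is our suggestion, not what the code above uses.

```julia
# Toy distance matrix: 2 points (rows) × 3 means (columns), values made up for illustration
d = [4.0 1.0 9.0;
     0.5 2.0 3.0]

_, idx = findmin(d, dims=2)              # idx: 2×1 matrix of CartesianIndex{2}
k_findmin = map(i -> i.I[2], idx)        # 2×1 matrix holding the column of each row minimum
k_argmin = argmin.(eachrow(d))           # Vector{Int} with the same column indices

println(vec(k_findmin) == k_argmin)      # prints true: both hold [2, 1]
```

With the vector form, the mask kₙ .== k is already one-dimensional, so the dropdims call would no longer be needed.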

[Animation of the naïve k-means iterations on the Iris data]

K-means clustering is one of the simplest and most popular unsupervised machine learning algorithms for discovering underlying patterns by grouping similar data. In this post, we studied the mechanics of this method in its simplest form. However, a number of questions remain unanswered. For example, how did we arrive at the equations describing the two-step iterative procedure? Is it guaranteed to converge? What are the shortcomings of k-means clustering as implemented here? Can we give this problem an interpretation that provides a deeper understanding of what is actually happening in the iterative procedure? If so, can such an interpretation provide us with tools to modify and improve the mechanism presented here? Stay tuned.

Discussion (3)

Ashok Kumar

I read your other articles also. You have a nice conversational style. Are you considering putting the complete code of each article in a repo?

Martin Roa Villescas (Author)

Hey @akvsak. Thanks for the kind words! Sure, I could do that, here you go: github.com/mroavi/julia-forem

Ashok Kumar

thanks. I followed your repo. I will use it to study your articles. 👏