Grouping similar objects, using K-means clustering

Apr 29, 2023 · 6 mins read

Sometimes when programming you have a set of “things” (objects, pixels, positions, etc.) that you need to group together based on their similarities, such as position, colour or other properties. K-means clustering can be a nice way of solving this problem.

I’ll use image segmentation as an example use case in this blog post.

I’ve uploaded the full implementation (C++) here.

The problem

We want to “simplify” a detailed image by reducing the number of distinct colours used. So for example: the sky should use only one or two shades of blue - maybe the clouds will be white. And the grass should be green, without too much variation. The result will be an image that looks much simpler, and might resemble a simple draft or drawing.

This is the input image:

And this is what we want to get:

This is actually a clustering problem. We want to identify groups of “similar” pixels, and assign the same colour to all of them. So we can use K-means clustering to solve this problem.

The algorithm

  1. For each pixel, create a 5 dimensional point containing RGB colour and XY position (R,G,B,X,Y).
  2. Create N clusters at random positions in this 5D space.
  3. For each point (pixel): Find nearest cluster.
  4. For each cluster: Calculate new centre as average position of all points within cluster.
  5. Repeat from 3 many times (until it “converges”).

Explanation and implementation

The image consists of a number of pixels. These pixels have:

  • A position: a 2D vector (X,Y)
  • A colour: a 3D vector (R,G,B)

We will need to first identify pixels that are similar. We define “similar” as having a similar colour and position. For each pixel we can then create a 5-dimensional vector representing its position and colour (we combine these two vectors into one):

class Point
{
public:
    // Colour (r, g, b) and position (x, y) combined into one 5D vector.
    float r, g, b, x, y;
    Point(float r, float g, float b, float x, float y)
    : r(r), g(g), b(b), x(x), y(y)
    {
    }
};

To find out how similar two pixels (a and b) are, we can simply measure the length of the difference vector between them:

const float distance = (b - a).length();

where length is the Euclidean length of the vector:

inline float length() const
{
    return std::sqrt(r*r + g*g + b*b + x*x + y*y);
}
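The snippets in this post subtract, add and divide points, but those operators are not shown. A minimal sketch of what they could look like, assuming plain component-wise free functions and a simplified stand-in struct for Point (the full implementation may well do this differently):

```cpp
#include <cmath>

// Simplified stand-in for the Point class above, written as a plain
// struct so that this example compiles on its own.
struct Point
{
    float r, g, b, x, y;

    // Euclidean length of the 5D vector.
    float length() const
    {
        return std::sqrt(r*r + g*g + b*b + x*x + y*y);
    }
};

// Component-wise difference: (b - a) is the vector pointing from a to b.
inline Point operator-(const Point& lhs, const Point& rhs)
{
    return Point{lhs.r - rhs.r, lhs.g - rhs.g, lhs.b - rhs.b,
                 lhs.x - rhs.x, lhs.y - rhs.y};
}

// Component-wise sum, used when accumulating points into a cluster centre.
inline Point operator+(const Point& lhs, const Point& rhs)
{
    return Point{lhs.r + rhs.r, lhs.g + rhs.g, lhs.b + rhs.b,
                 lhs.x + rhs.x, lhs.y + rhs.y};
}

// Division by a scalar, used to turn an accumulated sum into an average.
inline Point operator/(const Point& p, float divisor)
{
    return Point{p.r / divisor, p.g / divisor, p.b / divisor,
                 p.x / divisor, p.y / divisor};
}
```

With these in place, two points that differ by 3 in x and 4 in y are a distance of 5 apart, as you would expect from the Pythagorean theorem.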

Pixels that have similar colour and are positioned near to each other should then be grouped together.

Step 1: Creating the initial clusters

void createInitialClusters()
{
    const unsigned int clusterCount = 8;
    for (unsigned int i = 0; i < clusterCount; ++i)
    {
        Cluster cluster;
        cluster.numPoints = 0;
        cluster.centre.x = static_cast <float> (rand()) / static_cast <float> (RAND_MAX);
        cluster.centre.y = static_cast <float> (rand()) / static_cast <float> (RAND_MAX);
        cluster.centre.r = static_cast <float> (rand()) / static_cast <float> (RAND_MAX);
        cluster.centre.g = static_cast <float> (rand()) / static_cast <float> (RAND_MAX);
        cluster.centre.b = static_cast <float> (rand()) / static_cast <float> (RAND_MAX);
        clusters.push_back(cluster);
    }
}

Here we simply calculate a random 5D-vector (X,Y,R,G,B), and assign it as the centre of a new cluster.
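Every component of these random centres lies between 0 and 1, so the points presumably need to be scaled to the same range. Otherwise most centres would start far away from the data. That step is not shown in this post; a minimal sketch, assuming a tightly packed 8-bit RGB buffer and a hypothetical helper named createPoints:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified stand-in for the Point class used in this post.
struct Point
{
    float r, g, b, x, y;
};

// Hypothetical helper: converts a tightly packed 8-bit RGB buffer
// (width * height * 3 bytes, stored row by row) into 5D points where
// every component is scaled to the range [0, 1].
std::vector<Point> createPoints(const std::vector<std::uint8_t>& rgb,
                                std::size_t width, std::size_t height)
{
    std::vector<Point> points;
    points.reserve(width * height);

    // Avoid dividing by zero for images that are one pixel wide or tall.
    const float maxX = width > 1 ? static_cast<float>(width - 1) : 1.0f;
    const float maxY = height > 1 ? static_cast<float>(height - 1) : 1.0f;

    for (std::size_t y = 0; y < height; ++y)
    {
        for (std::size_t x = 0; x < width; ++x)
        {
            const std::size_t i = (y * width + x) * 3;
            points.push_back(Point{
                rgb[i] / 255.0f,
                rgb[i + 1] / 255.0f,
                rgb[i + 2] / 255.0f,
                static_cast<float>(x) / maxX,
                static_cast<float>(y) / maxY});
        }
    }
    return points;
}
```

The relative scale of the axes also decides how much each one counts. Multiplying x and y by a factor below 1, for example, would make colour matter more than position when measuring similarity.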

Step 2: Finding the nearest cluster

void findNearestClusters()
{
    for (size_t iPoint = 0; iPoint < points.size(); ++iPoint)
    {
        float nearestDistance = std::numeric_limits<float>::infinity();
        unsigned int nearestIndex = 0;
        const Point& point = points[iPoint];
        for (unsigned int iCluster = 0; iCluster < clusters.size(); ++iCluster)
        {
            const Cluster& cluster = clusters[iCluster];
            const float distance = (cluster.centre - point).length();
            if (distance < nearestDistance)
            {
                nearestIndex = iCluster;
                nearestDistance = distance;
            }
        }
        pointClusterIds[iPoint] = nearestIndex;
    }
}

For each point, we iterate over all clusters and find the distance to the centre point of each of them. We then assign the point to the nearest one.
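A possible optimisation, which the code above does not use: only the ordering of the distances matters here, and the square root does not change that ordering, so comparing squared distances gives the same nearest cluster. A sketch, with squaredDistance and findNearestCentre as hypothetical names:

```cpp
#include <cstddef>
#include <limits>
#include <vector>

// Simplified stand-in for the Point class used in this post.
struct Point
{
    float r, g, b, x, y;
};

// Squared Euclidean distance between two 5D points (no square root).
inline float squaredDistance(const Point& a, const Point& b)
{
    const float dr = a.r - b.r;
    const float dg = a.g - b.g;
    const float db = a.b - b.b;
    const float dx = a.x - b.x;
    const float dy = a.y - b.y;
    return dr*dr + dg*dg + db*db + dx*dx + dy*dy;
}

// Returns the index of the centre nearest to the point.
// Returns 0 if centres is empty, so callers should pass at least one.
std::size_t findNearestCentre(const Point& point,
                              const std::vector<Point>& centres)
{
    std::size_t nearestIndex = 0;
    float nearestDistance = std::numeric_limits<float>::infinity();
    for (std::size_t i = 0; i < centres.size(); ++i)
    {
        const float distance = squaredDistance(point, centres[i]);
        if (distance < nearestDistance)
        {
            nearestIndex = i;
            nearestDistance = distance;
        }
    }
    return nearestIndex;
}
```

Whether this is worth doing depends on the image size and cluster count; for small inputs the difference is unlikely to be noticeable.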

Step 3: Recalculating the cluster centres

void recalculateClusterCentres()
{
    // Clear clusters
    for (auto& cluster : clusters)
    {
        cluster.centre = Point(0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
        cluster.numPoints = 0;
    }
    // Calculate new cluster centres as average of points within cluster.
    for (size_t iPoint = 0; iPoint < points.size(); ++iPoint)
    {
        const Point& point = points[iPoint];
        const unsigned int clusterIndex = pointClusterIds[iPoint];
        Cluster& cluster = clusters[clusterIndex];
        cluster.centre = cluster.centre + point;
        cluster.numPoints++;
    }
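    for (auto& cluster : clusters)
    {
        // Skip empty clusters to avoid dividing by zero; their centre
        // stays at the zeroed value set above.
        if (cluster.numPoints > 0)
        {
            cluster.centre = cluster.centre / static_cast<float>(cluster.numPoints);
        }
    }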
}

Here we calculate the average position of all points inside each cluster. We do it by accumulating the positions in the cluster’s centre variable, and then dividing by the total number of points (numPoints) inside the cluster. With many points these accumulated sums grow large, and a single-precision float starts to lose precision. In that case you would need to either use double precision for the accumulation, or calculate the average cluster centre incrementally.
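The incremental average can be sketched as follows. Instead of summing everything and dividing at the end, each new point nudges the current mean by (point - mean) / n, so the stored values never leave the range of the data. The addPoint helper and the stand-in structs are assumptions made for this example:

```cpp
// Simplified stand-ins for the Point and Cluster types used in this post.
struct Point
{
    float r, g, b, x, y;
};

struct Cluster
{
    Point centre{0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
    unsigned int numPoints = 0;
};

// Folds one point into the cluster's running mean:
//   mean_n = mean_(n-1) + (point - mean_(n-1)) / n
void addPoint(Cluster& cluster, const Point& point)
{
    cluster.numPoints++;
    const float n = static_cast<float>(cluster.numPoints);
    cluster.centre.r += (point.r - cluster.centre.r) / n;
    cluster.centre.g += (point.g - cluster.centre.g) / n;
    cluster.centre.b += (point.b - cluster.centre.b) / n;
    cluster.centre.x += (point.x - cluster.centre.x) / n;
    cluster.centre.y += (point.y - cluster.centre.y) / n;
}
```

This costs one division per component per point rather than one per cluster, so it trades some speed for keeping the accumulator small.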

Repeat

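As we repeat this process over and over again, the cluster centres will move slower and slower. When they finally stop moving, the algorithm is said to “converge”. At that point we have found a stable set of clusters, though not necessarily the best possible one, since the result depends on where the initial clusters were placed.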

If performance is important it is of course possible to stop after a fixed number of iterations. You could still expect acceptable results. How many iterations are needed for good results depends on the data, and on the randomised initial clusters.
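Both stopping rules can be combined: stop when no centre moves more than a small threshold, or when an iteration limit is hit, whichever comes first. A sketch on 1D data to keep it short (the stopping logic would be the same in 5D); runKMeans, maxIterations and epsilon are hypothetical names:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Runs K-means on 1D values until no centre moves more than epsilon,
// or until maxIterations is reached. Returns the number of iterations run.
std::size_t runKMeans(const std::vector<float>& values,
                      std::vector<float>& centres,
                      std::size_t maxIterations,
                      float epsilon)
{
    for (std::size_t iteration = 1; iteration <= maxIterations; ++iteration)
    {
        std::vector<double> sums(centres.size(), 0.0);
        std::vector<std::size_t> counts(centres.size(), 0);

        // Assign each value to its nearest centre and accumulate it.
        for (const float value : values)
        {
            std::size_t nearest = 0;
            for (std::size_t c = 1; c < centres.size(); ++c)
            {
                if (std::fabs(value - centres[c]) < std::fabs(value - centres[nearest]))
                {
                    nearest = c;
                }
            }
            sums[nearest] += value;
            counts[nearest]++;
        }

        // Move the centres and record the largest movement.
        float largestMove = 0.0f;
        for (std::size_t c = 0; c < centres.size(); ++c)
        {
            if (counts[c] == 0)
            {
                continue; // Empty cluster: leave its centre where it is.
            }
            const float updated =
                static_cast<float>(sums[c] / static_cast<double>(counts[c]));
            largestMove = std::max(largestMove, std::fabs(updated - centres[c]));
            centres[c] = updated;
        }

        if (largestMove <= epsilon)
        {
            return iteration; // Converged.
        }
    }
    return maxIterations; // Stopped at the iteration limit.
}
```

For the values 1, 2, 3, 10, 11, 12 and initial centres 0 and 5, the centres settle at 2 and 11 on the third iteration.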

A possible improvement

K-means++ is an improved version of this algorithm, where instead of using random initial clusters, we do some extra work to ensure that the initial clusters are better. This usually makes the algorithm converge in fewer iterations, and gives better results if we stop early.

The general idea of K-means++ is to pick N cluster centres that are evenly spaced out. That is, the initial clusters should ideally not be too similar.

In this version we:

  1. Pick the first cluster centre at random.
  2. For each of the remaining N-1 clusters: Pick one point that lies far away from all centres picked so far, and use it as cluster centre.
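In the standard formulation of K-means++, "far away" is made random: each point is picked with probability proportional to its squared distance from the nearest centre chosen so far. A sketch of that seeding step, assuming the first centre is drawn from the points themselves; pickInitialCentres and squaredDistance are hypothetical names, and the linked implementations may differ:

```cpp
#include <algorithm>
#include <cstddef>
#include <limits>
#include <random>
#include <vector>

// Simplified stand-in for the Point class used in this post.
struct Point
{
    float r, g, b, x, y;
};

inline float squaredDistance(const Point& a, const Point& b)
{
    const float dr = a.r - b.r, dg = a.g - b.g, db = a.b - b.b;
    const float dx = a.x - b.x, dy = a.y - b.y;
    return dr*dr + dg*dg + db*db + dx*dx + dy*dy;
}

// Picks up to `count` initial centres from `points`. May return fewer
// if there are not enough distinct points to choose from.
std::vector<Point> pickInitialCentres(const std::vector<Point>& points,
                                      std::size_t count,
                                      std::mt19937& rng)
{
    std::vector<Point> centres;
    if (points.empty() || count == 0)
    {
        return centres;
    }

    // First centre: a uniformly random point.
    std::uniform_int_distribution<std::size_t> uniform(0, points.size() - 1);
    centres.push_back(points[uniform(rng)]);

    // weights[i] = squared distance from point i to its nearest centre.
    std::vector<double> weights(points.size(),
                                std::numeric_limits<double>::max());
    while (centres.size() < count)
    {
        double total = 0.0;
        for (std::size_t i = 0; i < points.size(); ++i)
        {
            const double d = squaredDistance(points[i], centres.back());
            weights[i] = std::min(weights[i], d);
            total += weights[i];
        }
        if (total <= 0.0)
        {
            break; // Every point coincides with an existing centre.
        }
        // Points far from all centres are proportionally more likely.
        std::discrete_distribution<std::size_t> pick(weights.begin(),
                                                     weights.end());
        centres.push_back(points[pick(rng)]);
    }
    return centres;
}
```

Always taking the single farthest point instead of sampling would also spread the centres out, but it tends to favour outliers, which is one reason the weighted random choice is commonly used.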

See implementations here.

Other use cases

I recently used K-means++ at work. We had some 3D scenes containing geometry that would sometimes be sparsely distributed, with great distances between the pieces. To find a good position for the camera, where we could be sure there was some geometry to look at, I used K-means++ to group the geometry into clusters based on location.

Clustering algorithms can also be used for matchmaking and other problems where you need to group objects based on some properties.

Acknowledgements

This project uses free art released by Ghibli here: https://www.ghibli.jp/info/013344/
