K-Means at Scale: Why It Falls Apart and What's Next

K-means is the algorithm everyone learns in their intro ML course and then uses incorrectly for the rest of their career. It's deceptively simple: pick K cluster centers, assign each point to its nearest center, move the centers to the mean of their assigned points, repeat. Three lines of pseudocode. The problem is that this simplicity hides a series of landmines that explode when you apply it to real data at scale.

At 10,000 points, k-means finishes in milliseconds and nobody worries about implementation details. At 10 million points, choice of initialization method changes runtime by 100x. At 1 billion points, the standard algorithm doesn't fit in memory and you need fundamentally different approaches. A recent paper on Flash-KMeans tackles exactly this — achieving exact k-means results with dramatically reduced memory and faster convergence. Let's dig into what makes k-means hard at scale and how modern variants solve it.

What Standard K-Means Actually Does

Lloyd's algorithm — what people mean when they say 'k-means' — has two steps per iteration. The assignment step: for each data point, compute its distance to all K centroids and assign it to the nearest one. The update step: recompute each centroid as the mean of all points assigned to it.

import numpy as np

def kmeans_lloyd(X, K, max_iter=100):
    n, d = X.shape
    # Random initialization (bad, but we'll fix this later)
    centroids = X[np.random.choice(n, K, replace=False)]
    for _ in range(max_iter):
        # Assignment: O(n * K * d) — this is the bottleneck
        distances = np.linalg.norm(X[:, None] - centroids[None, :], axis=2)
        labels = np.argmin(distances, axis=1)
        # Update: O(n * d). Note: a cluster that loses all its points
        # makes mean() return NaN; production code reseeds empty clusters.
        new_centroids = np.array([
            X[labels == k].mean(axis=0) for k in range(K)
        ])
        if np.allclose(centroids, new_centroids):
            break
        centroids = new_centroids
    return labels, centroids

The assignment step costs O(n × K × d) per iteration, where n is the number of points, K is the number of clusters, and d is the dimensionality. For n = 1 billion, K = 1000, d = 128 (a realistic embedding clustering scenario), that's 128 trillion floating-point operations per iteration. Even at a teraflop of compute, that's 128 seconds per iteration, and k-means typically needs 10-50 iterations to converge.
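The arithmetic above checks out directly:

```python
# Back-of-the-envelope cost of one assignment step at billion scale
n, K, d = 1_000_000_000, 1_000, 128

# One distance per (point, centroid) pair costs ~d multiply-adds,
# so roughly n * K * d floating-point operations per iteration
flops_per_iter = n * K * d
print(flops_per_iter)  # 128_000_000_000_000 — 128 trillion

# At a sustained teraflop (1e12 FLOP/s):
seconds_per_iter = flops_per_iter / 1e12
print(seconds_per_iter)  # 128.0
```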

The Initialization Problem

Before k-means runs a single iteration, it has to choose initial centroid positions. This choice matters far more than most people realize — bad initialization can lead to convergence on a solution that's arbitrarily worse than optimal.

Random initialization (picking K random data points as initial centroids) is simple but unreliable. If two initial centroids happen to land in the same cluster, one true cluster will go unrepresented. The algorithm will converge, but to a suboptimal solution. With K = 100, a collision is all but certain: assuming 100 equal-sized clusters, the probability that 100 uniform random picks land in 100 distinct clusters is 100!/100^100, on the order of 10^-43.
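This is a birthday-problem calculation; under the simplifying assumption of K equal-sized clusters and uniform random picks, it takes three lines to verify:

```python
import math

def p_no_collision(K):
    """Probability that K uniform random picks from K equal-sized
    clusters land in K distinct clusters: K! / K^K, computed in
    log space to avoid overflow."""
    return math.exp(math.lgamma(K + 1) - K * math.log(K))

print(p_no_collision(10))   # ~0.00036 — already unlikely at K = 10
print(p_no_collision(100))  # ~1e-43 — essentially impossible
```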

K-means++ (Arthur and Vassilvitskii, 2007) is the standard fix. It picks the first centroid uniformly at random, then picks each subsequent centroid with probability proportional to its squared distance from the nearest existing centroid. Points far from any existing centroid are more likely to be chosen, spreading the initial centroids across the data. This comes with an O(log K) approximation guarantee: the expected cost of the resulting clustering is within an O(log K) factor of the optimal cost.

def kmeans_plus_plus_init(X, K):
    n, d = X.shape
    centroids = [X[np.random.randint(n)]]
    for _ in range(1, K):
        # Distance from each point to nearest existing centroid
        dists = np.min([
            np.sum((X - c) ** 2, axis=1) for c in centroids
        ], axis=0)
        # Sample proportional to distance squared
        probs = dists / dists.sum()
        next_idx = np.random.choice(n, p=probs)
        centroids.append(X[next_idx])
    return np.array(centroids)

The catch: k-means++ initialization itself is O(n × K × d) — it requires K passes through the entire dataset. For large K and large n, initialization takes longer than several iterations of Lloyd's algorithm. Scalable variants like k-means|| (Bahmani et al., 2012) reduce the number of passes by over-sampling candidates in parallel, then consolidating them into K centroids.
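The over-sampling idea can be sketched in a few lines. This is a simplified illustration, not the paper's algorithm: the choices of l = 2K and 5 rounds are rule-of-thumb assumptions, and the final recluster step (which in k-means|| runs weighted k-means++ on the candidates) is replaced here by simply keeping the K heaviest candidates.

```python
import numpy as np

def kmeans_parallel_init(X, K, l=None, rounds=5, rng=None):
    """Sketch of k-means|| over-sampling (after Bahmani et al., 2012).
    Each round admits ~l candidates in parallel, with probability
    proportional to their distance-squared cost, so a handful of
    passes replaces k-means++'s K sequential passes."""
    rng = np.random.default_rng(rng)
    n, d = X.shape
    l = l or 2 * K                                 # assumption: l = 2K
    centers = X[rng.integers(n)][None, :]          # first center: uniform
    for _ in range(rounds):
        # Cost of each point = squared distance to nearest current center
        d2 = np.min(((X[:, None, :] - centers[None, :, :]) ** 2).sum(-1), axis=1)
        # Each point joins the candidate set independently, w.p. l * cost / total
        chosen = rng.random(n) < np.minimum(1.0, l * d2 / d2.sum())
        centers = np.vstack([centers, X[chosen]])
    # Weight candidates by how many points they serve; a real
    # implementation would recluster the weighted candidates to K
    # with k-means++ — keeping the K heaviest is a crude stand-in.
    d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(-1)
    weights = np.bincount(np.argmin(d2, axis=1), minlength=len(centers))
    return centers[np.argsort(weights)[-K:]]
```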

The Memory Wall

Standard k-means needs the entire dataset in memory simultaneously. For the assignment step, you need to access every data point and compare it to every centroid. With 1 billion 128-dimensional float32 vectors, the data alone requires 512 GB — more than most single machines have.

The naive solution is mini-batch k-means: sample a random subset of points each iteration, update centroids based on the sample. This works but converges more slowly and to a slightly different (usually slightly worse) solution than exact k-means. For many applications, the quality loss is acceptable. For others — quantization codebooks in vector search, for example — the quality difference matters.
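A minimal sketch of the mini-batch update makes the trade-off concrete. This is a batched approximation of Sculley's (2010) per-sample rule, where each centroid's learning rate shrinks as 1 over the number of points it has absorbed; batch size and iteration count are illustrative assumptions.

```python
import numpy as np

def minibatch_kmeans(X, K, batch_size=1024, n_iter=100, rng=None):
    """Sketch of mini-batch k-means. Each iteration samples a batch,
    assigns it, and nudges centroids toward the batch means with a
    per-centroid learning rate, so well-populated centroids move less
    and the updates stabilize over time."""
    rng = np.random.default_rng(rng)
    n, d = X.shape
    centroids = X[rng.choice(n, K, replace=False)].copy()
    counts = np.zeros(K)
    for _ in range(n_iter):
        batch = X[rng.choice(n, batch_size)]
        # Assignment on the batch only: O(batch_size * K * d), not O(n * K * d)
        d2 = ((batch[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
        labels = np.argmin(d2, axis=1)
        for k in np.unique(labels):
            members = batch[labels == k]
            counts[k] += len(members)
            eta = len(members) / counts[k]  # decaying per-centroid rate
            centroids[k] += eta * (members.mean(axis=0) - centroids[k])
    return centroids
```

Each iteration touches only the batch, which is why it converges to a noisier solution than a full Lloyd pass: centroids chase sample means rather than true cluster means.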

Flash-KMeans takes a different approach. Instead of processing the full dataset in memory, it uses a streaming pass with intelligent caching. The key insight: most points don't change cluster assignments between iterations. If a point is firmly in cluster 7 (much closer to centroid 7 than to any other centroid), checking all K distances is wasted work. By maintaining distance bounds, Flash-KMeans can skip the full distance computation for most points in most iterations.

Triangle Inequality Optimizations

Elkan's algorithm (2003) uses the triangle inequality to skip unnecessary distance calculations. The triangle inequality implies that the distance from point P to centroid A is at least d(A, B) − d(P, B) for any other centroid B. So if P is currently assigned to centroid B and you know the centroid-to-centroid distance d(A, B), you can sometimes prove A is too far without computing d(P, A): whenever d(A, B) ≥ 2 × d(P, B), it follows that d(P, A) ≥ d(A, B) − d(P, B) ≥ d(P, B), so A cannot beat the current assignment. The centroid-to-centroid distances cost only O(K² × d) to recompute each iteration, negligible next to the assignment step.
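A small self-contained demo counts how many point-centroid distance computations the test d(A, B) ≥ 2 × d(P, B) lets us skip. (For illustration the demo computes all distances up front; a real implementation maintains cached bounds precisely so it never has to.)

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((5000, 16))   # points
C = rng.standard_normal((50, 16))     # pretend these are current centroids

# Current assignment and distance to the assigned centroid
d_all = np.linalg.norm(X[:, None, :] - C[None, :, :], axis=2)
assigned = np.argmin(d_all, axis=1)
d_assigned = d_all[np.arange(len(X)), assigned]

# Centroid-to-centroid distances: O(K^2 * d), computed once per iteration
cc = np.linalg.norm(C[:, None, :] - C[None, :, :], axis=2)

# Elkan's test: if d(A, B) >= 2 * d(P, B) for P's assigned centroid B,
# then d(P, A) >= d(A, B) - d(P, B) >= d(P, B), so A can never win
# and the actual distance d(P, A) need not be computed
skippable = cc[assigned] >= 2 * d_assigned[:, None]
print(f"{skippable.mean():.0%} of point-centroid distances provably skippable")
```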

In practice, this eliminates 80-95% of distance calculations after the first few iterations, when most points are already near their correct centroid. The remaining distance calculations are for points near cluster boundaries — the only ones that might actually change assignment.

The trade-off: Elkan's algorithm requires O(n × K) additional memory to store the distance bounds (lower bounds from each point to each centroid, plus upper bounds to the assigned centroid). For large K, this memory cost can be substantial. Hamerly's algorithm reduces this to O(n) memory by maintaining only a single lower bound per point, at the cost of fewer pruned calculations.

GPU Acceleration

K-means is embarrassingly parallel in the assignment step: every point's distance computation is independent. This makes it a natural fit for GPU acceleration. NVIDIA's cuML library and Facebook's FAISS both include GPU k-means implementations that achieve 10-50x speedup over CPU implementations for large datasets.

The wrinkle is that GPU memory is limited. An A100 has 80 GB of memory — enough for about 150 million 128-dimensional vectors. Larger datasets require either multi-GPU partitioning or a streaming approach where data is loaded in chunks, processed on the GPU, and results aggregated on the CPU.
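The streaming pattern can be sketched on the CPU. The function below is a hypothetical chunked Lloyd step: it produces exactly the same centroids as a full-batch update while only ever materializing one chunk's distance matrix, which is the same load-assign-aggregate shape a multi-chunk GPU pipeline uses (chunk_size here is an assumption, tuned against device memory in practice).

```python
import numpy as np

def chunked_kmeans_step(X, centroids, chunk_size=100_000):
    """One exact Lloyd iteration via streaming: accumulate per-cluster
    sums and counts chunk by chunk, then divide at the end."""
    K, d = centroids.shape
    sums = np.zeros((K, d))
    counts = np.zeros(K)
    for start in range(0, len(X), chunk_size):
        chunk = X[start:start + chunk_size]
        # In a GPU pipeline, this chunk would be copied to the device here
        d2 = ((chunk[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
        labels = np.argmin(d2, axis=1)
        # Aggregate partial sums; the full n-by-K distance matrix never exists
        np.add.at(sums, labels, chunk)
        counts += np.bincount(labels, minlength=K)
    new_centroids = centroids.copy()
    nonempty = counts > 0  # leave empty clusters where they were
    new_centroids[nonempty] = sums[nonempty] / counts[nonempty, None]
    return new_centroids
```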

# Using FAISS for GPU-accelerated k-means
import faiss
import numpy as np
# 10 million 128-dimensional vectors
n, d, K = 10_000_000, 128, 1000
X = np.random.randn(n, d).astype('float32')
# CPU k-means (for comparison)
kmeans_cpu = faiss.Kmeans(d, K, niter=20, verbose=True)
kmeans_cpu.train(X)  # ~120 seconds
# GPU k-means (single GPU)
kmeans_gpu = faiss.Kmeans(d, K, niter=20, verbose=True, gpu=True)
kmeans_gpu.train(X)  # ~8 seconds — 15x faster

When K-Means Is the Wrong Choice

Before optimizing your k-means implementation, consider whether k-means is the right algorithm. It makes strong assumptions that don't hold for many real datasets.

  • Spherical clusters. K-means assumes clusters are roughly spherical and equally sized. If your clusters are elongated, irregularly shaped, or have vastly different sizes, k-means will split large clusters and merge small ones. Gaussian Mixture Models (GMMs) handle elliptical clusters. DBSCAN handles arbitrary shapes.
  • Known K. You have to specify the number of clusters in advance. If you don't know K, you need to run k-means multiple times with different values and use a metric (silhouette score, elbow method) to pick the best one. This multiplies the total compute by the number of K values you try.
  • Euclidean distance. K-means uses Euclidean distance by default. For text embeddings, cosine similarity is usually more appropriate. You can work around this by L2-normalizing your vectors (which makes Euclidean distance equivalent to cosine distance), but it's easy to forget.
  • Outlier sensitivity. A single outlier far from any cluster will pull its assigned centroid toward it. Robust variants like k-medoids (which restricts each center to be an actual data point) or k-medians (which uses the coordinate-wise median instead of the mean) handle outliers better but are more expensive.
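The normalization trick from the Euclidean-distance point above is worth verifying once: for unit vectors, ||a − b||² = 2 − 2·cos(a, b), so nearest-centroid by Euclidean distance and by cosine similarity agree.

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.standard_normal(128), rng.standard_normal(128)

# L2-normalize both vectors
a /= np.linalg.norm(a)
b /= np.linalg.norm(b)

cos = a @ b
sq_euclidean = ((a - b) ** 2).sum()

# For unit vectors: ||a - b||^2 = 2 - 2 * cos(a, b), so ranking by
# Euclidean distance is the same as ranking by cosine similarity
print(np.isclose(sq_euclidean, 2 - 2 * cos))  # True
```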

Practical Advice

Having used k-means on datasets from thousands to billions of points, here's what I've learned matters most:

  1. Always use k-means++ initialization. Random initialization is never worth the risk. The difference in final cluster quality is often 10-30%, and k-means++ adds negligible overhead for small to medium K.
  2. Run multiple times. K-means finds a local optimum, not a global one. Run it 5-10 times with different random seeds and take the best result (lowest total within-cluster distance). This is cheap insurance against bad runs.
  3. Normalize your features. If one feature has range [0, 1000000] and another has range [0, 1], the high-range feature will dominate the distance calculation. Standardize (subtract mean, divide by standard deviation) or min-max normalize before clustering.
  4. Use FAISS for large-scale clustering. If you have more than a million points, scikit-learn's k-means will be slow. FAISS's implementation is heavily optimized and supports GPU acceleration out of the box.
  5. Consider approximate methods for exploratory work. Mini-batch k-means is 10x faster than exact k-means and gives results that are usually good enough for exploration. Use exact k-means for production clustering where quality matters.
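Points 2 and 3 above combine into a short recipe. The Lloyd loop below is a deliberately tiny stand-in (the seed count and toy feature scales are assumptions for illustration); in practice a library implementation would fill that role.

```python
import numpy as np

def lloyd(X, K, seed, n_iter=50):
    """Tiny Lloyd's loop used only to illustrate restarts (a sketch)."""
    rng = np.random.default_rng(seed)
    C = X[rng.choice(len(X), K, replace=False)]
    for _ in range(n_iter):
        labels = np.argmin(((X[:, None] - C[None]) ** 2).sum(-1), axis=1)
        C = np.array([X[labels == k].mean(0) if np.any(labels == k) else C[k]
                      for k in range(K)])
    labels = np.argmin(((X[:, None] - C[None]) ** 2).sum(-1), axis=1)
    return C, ((X - C[labels]) ** 2).sum()   # centroids, inertia

# Toy data where one feature's raw scale would dominate the distances
X = np.random.default_rng(0).standard_normal((1000, 5)) * [1, 10, 100, 1, 1]

# 1. Standardize features: subtract mean, divide by standard deviation
X = (X - X.mean(axis=0)) / X.std(axis=0)

# 2. Run several seeds, keep the lowest total within-cluster distance
centroids, inertia = min((lloyd(X, K=8, seed=s) for s in range(5)),
                         key=lambda t: t[1])
```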

K-means is one of those algorithms that's easy to use, hard to use well, and worth understanding deeply. The gap between a naive implementation and an optimized one — in both speed and result quality — is enormous. At small scale, none of this matters. At the scale where it does matter, understanding initialization, memory management, distance pruning, and GPU acceleration is the difference between clustering that takes hours and clustering that takes minutes.