Flash-KMeans: Fast and Memory-Efficient Exact K-Means
Summary (Overview)
- Problem Identification: Standard GPU implementations of k-means are bottlenecked not by computational complexity but by low-level system constraints: 1) massive I/O from materializing the distance matrix in High Bandwidth Memory (HBM), and 2) severe atomic write contention during centroid updates.
- Core Innovations: Introduces two kernel-level optimizations:
  - FlashAssign: Fuses distance computation with an online argmin operator, completely bypassing the explicit materialization of the distance matrix.
  - Sort-Inverse Update: Transforms high-contention atomic scatters into regular, segment-level localized reductions by sorting assignments and constructing an inverse mapping.
- System Co-Design: Implements chunked stream overlap for out-of-core execution and a cache-aware compile heuristic to eliminate expensive auto-tuning, ensuring practical deployability.
- Performance Gains: Achieves up to 17.9× end-to-end speedup over the best baseline, with kernel-level speedups of 21.2× (FlashAssign) and 6.3× (Sort-Inverse Update). Outperforms cuML by 33× and FAISS by >200×.
- Scalability: Scales to one billion points in out-of-core settings (6.3× faster than fastkmeans at N=10⁹, and up to 10.5× faster at N=400M), and reduces configuration tuning overhead by up to 175× with negligible (<0.3%) performance loss.
Introduction and Theoretical Foundation
K-means is a classical clustering algorithm traditionally used in offline data processing pipelines. Modern AI workloads (e.g., vector quantization, sparse routing in LLMs, token permutation in generative video models) have shifted its role to a high-frequency, online primitive invoked during training and inference. This new paradigm demands low latency, high throughput, and scalability on GPUs.
Existing algorithmic optimizations (e.g., using triangle inequalities, sampling) focus on reducing FLOPs but often fail to translate into wall-clock speedups on modern hardware. The primary bottlenecks are now memory bandwidth and data movement, not raw computation. Standard GPU implementations suffer from:
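- I/O-bound Assignment: Explicitly materializing the full distance matrix $D \in \mathbb{R}^{N \times K}$, where $D_{ij} = \lVert x_i - c_j \rVert_2^2$, incurs massive HBM traffic, because $D$ must be written to HBM and then read back for the argmin.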
- Atomic Write Contention: Scatter-style updates in the centroid stage cause many threads to atomically update the same "hot" cluster buffers, leading to serialization.
- System-Level Constraints: Large datasets exceeding VRAM require chunked processing with PCIe overhead, and dynamic input shapes trigger costly recompilation and auto-tuning.
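To make the I/O argument concrete, the snippet below does the byte arithmetic for a materialized N × K distance matrix at the large-N/large-K shape used in the results (1M, 64K, 512). It is a back-of-the-envelope sketch under two assumptions made here: fp32 storage, and "1M" meaning 10^6 points. Under those assumptions the distance matrix alone needs about 262 GB. That is more than the 141 GB of HBM on an H200, which fits the PyTorch OOM noted in the results. It is also 128× the size of the input points.

```python
# Hypothetical helper: bytes needed to hold D (N x K), X (N x d) and C (K x d).
# Assumption: fp32 storage, 4 bytes per element.
BYTES_PER_ELEM = 4


def footprint_bytes(n: int, k: int, d: int, elem_bytes: int = BYTES_PER_ELEM) -> dict:
    return {
        "D": n * k * elem_bytes,  # one distance per (point, centroid) pair
        "X": n * d * elem_bytes,  # input points
        "C": k * d * elem_bytes,  # centroids
    }


sizes = footprint_bytes(n=1_000_000, k=65_536, d=512)
for name, nbytes in sizes.items():
    print(f"{name}: {nbytes / 1e9:.1f} GB")  # D: 262.1 GB, X: 2.0 GB, C: 0.1 GB
print("D is", sizes["D"] // sizes["X"], "x larger than X")  # ratio K / d = 128
```

The ratio is simply K/d. The distance matrix dominates whenever there are more centroids than feature dimensions, and it is also written once and read back during the assignment.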
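Flash-KMeans addresses these bottlenecks without altering the mathematical formulation of Lloyd's algorithm, which minimizes the within-cluster sum of squared distances
$$\min_{C,\,a} \; \sum_{i=1}^{N} \lVert x_i - c_{a_i} \rVert_2^2 \qquad (1)$$
through alternating assignment (Eq. 2) and update (Eq. 3) stages:
$$a_i \leftarrow \arg\min_{k \in \{1,\dots,K\}} \lVert x_i - c_k \rVert_2^2 \qquad (2)$$
$$c_k \leftarrow \frac{1}{\lvert \{ i : a_i = k \} \rvert} \sum_{i : a_i = k} x_i \qquad (3)$$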
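As a reference for what "exact" means here, the sketch below runs plain Lloyd iterations in pure Python. The assignment stage (Eq. 2) computes every point-to-centroid distance, and the update stage (Eq. 3) averages each cluster. The sketch is deliberately naive: it uses 0-based cluster indices and breaks ties toward the smallest index. It keeps the previous centroid for an empty cluster, which is an assumption made here because the text does not specify that case. The fused kernels in the Methodology section are meant to reproduce the assignments and centroids this loop computes.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def sq_dist(x: Sequence[float], c: Sequence[float]) -> float:
    """Squared Euclidean distance ||x - c||_2^2."""
    return sum((xi - ci) ** 2 for xi, ci in zip(x, c))


def assign(X: Matrix, C: Matrix) -> List[int]:
    """Assignment stage (Eq. 2): nearest centroid per point, ties -> smallest index."""
    return [min(range(len(C)), key=lambda k: sq_dist(x, C[k])) for x in X]


def update(X: Matrix, a: List[int], C_prev: Matrix) -> Matrix:
    """Update stage (Eq. 3): mean of the points assigned to each cluster.
    Assumption: an empty cluster keeps its previous centroid."""
    K, d = len(C_prev), len(C_prev[0])
    sums = [[0.0] * d for _ in range(K)]
    counts = [0] * K
    for x, k in zip(X, a):
        counts[k] += 1
        for j in range(d):
            sums[k][j] += x[j]
    return [
        [v / counts[k] for v in sums[k]] if counts[k] else list(C_prev[k])
        for k in range(K)
    ]


def objective(X: Matrix, C: Matrix, a: List[int]) -> float:
    """Sum of squared distances from each point to its assigned centroid."""
    return sum(sq_dist(x, C[k]) for x, k in zip(X, a))


def lloyd(X: Matrix, C: Matrix, iters: int = 10) -> Tuple[Matrix, List[int]]:
    """Alternate assignment and update for a fixed number of iterations."""
    for _ in range(iters):
        C = update(X, assign(X, C), C)
    return C, assign(X, C)


X = [[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]]
C, a = lloyd(X, [[0.0, 0.0], [10.0, 10.0]], iters=1)
print(a, C, objective(X, C, a))  # [0, 0, 1, 1] [[0.0, 0.5], [10.0, 10.5]] 1.0
```

The `assign` step is the one that, on a GPU, naturally turns into an N × K distance matrix.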
Methodology
Flash-KMeans restructures the execution dataflow around hardware constraints with three key components.
4.1 FlashAssign: Materialization-Free Assignment via Online Argmin
This kernel fuses distance computation with the argmin reduction so that the distance matrix $D$ is never created in HBM.
- Online Argmin: For each point $x_i$, the kernel keeps a running minimum distance $m_i$ and the corresponding centroid index $a_i$ in registers. It processes centroids in tiles, computes the local distances, updates $(m_i, a_i)$ with the tile-local minimum, and repeats across all tiles to obtain the global argmin.
- Tiling & Prefetch: Employs two-dimensional tiling over points and centroids with double-buffering and asynchronous prefetch to hide memory latency.
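- I/O Complexity: Reduces dominant HBM traffic from $O(NK)$, incurred by writing and re-reading the $N \times K$ distance matrix, to traffic that scales only with streaming $X$ and $C$ through on-chip tiles and writing the $N$ final assignments.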
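The online argmin for a single point can be sketched on the CPU as follows. Centroid tiles are streamed through the loop, and only a scalar running minimum and its index are kept, so the full distance row is never stored. The helper name `online_argmin` and the tile parameter are hypothetical illustrations, not the GPU kernel. The strict `<` comparison keeps the earliest index on ties, so the result matches a full argmin for any tile size.

```python
from typing import Sequence, Tuple


def online_argmin(x: Sequence[float], C: Sequence[Sequence[float]], tile: int) -> Tuple[float, int]:
    """Stream centroids in tiles of size `tile`, keeping only (m, a) as state."""
    m, a = float("inf"), -1                       # running minimum and its index
    for start in range(0, len(C), tile):
        C_tile = C[start:start + tile]            # the only centroids "on chip"
        local = [sum((xi - ci) ** 2 for xi, ci in zip(x, c)) for c in C_tile]
        j = min(range(len(local)), key=local.__getitem__)  # tile-local argmin
        if local[j] < m:                          # strict '<' keeps the earliest index on ties
            m, a = local[j], start + j
    return m, a


C = [[0.0, 0.0], [5.0, 5.0], [1.0, 1.0], [9.0, 0.0], [1.0, 0.0]]
print(online_argmin([1.2, 0.1], C, tile=2))  # nearest centroid index is 4
```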
Algorithm 2: FlashAssign (materialization-free assignment)
Input: Data X ∈ R^{N × d}, centroids C ∈ R^{K × d}, point tile size B_N, centroid tile size B_K
Output: Assignments a ∈ {1, ..., K}^N
Precompute norms ∥x_i∥_2^2;
foreach point tile X_tile of size B_N in parallel do
    Initialize on-chip running states: m ← +∞, a ← -1;
    Prefetch the first centroid tile C_tile^{(0)} from HBM into on-chip buffer;
    for t ← 0 to ⌈K/B_K⌉ - 1 do
        if t + 1 < ⌈K/B_K⌉ then
            Prefetch C_tile^{(t+1)} into the alternate buffer;
        Compute local distances between X_tile and C_tile^{(t)} on chip;
        Compute tile-local minima and indices (m̃, ã) for each point in X_tile;
        Update running states using online argmin:
            m ← min(m, m̃), and update a with the corresponding index;
        Swap buffers;
    Write final assignments a for X_tile to HBM;
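A pure-Python rendering of Algorithm 2 follows. It uses the hypothetical name `flash_assign`, with `B_N` and `B_K` as the tile sizes. It reproduces the control flow: point tiles, streamed centroid tiles, tile-local minima and the online merge. It cannot model prefetching or double buffering, which are hardware concerns. The sketch also precomputes the centroid norms ∥c_k∥_2^2 alongside ∥x_i∥_2^2, so that each distance can be formed as ∥x∥² − 2x·c + ∥c∥². That expansion is an assumption made here about how the on-chip distance is computed.

```python
from typing import List, Sequence


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(ui * vi for ui, vi in zip(u, v))


def flash_assign(X: List[List[float]], C: List[List[float]], B_N: int = 4, B_K: int = 3) -> List[int]:
    """CPU sketch of Algorithm 2: no N x K distance matrix is ever built."""
    N, K = len(X), len(C)
    x_norm = [dot(x, x) for x in X]   # ||x_i||^2, precomputed as in Algorithm 2
    c_norm = [dot(c, c) for c in C]   # ||c_k||^2 (assumption: also precomputed)
    assignments = [-1] * N            # stands in for the output written to HBM
    for p0 in range(0, N, B_N):       # each point tile would be one CTA on a GPU
        rows = range(p0, min(p0 + B_N, N))
        m = [float("inf")] * len(rows)  # running minima ("registers")
        a = [-1] * len(rows)            # running argmins
        for k0 in range(0, K, B_K):     # stream centroid tiles (no real prefetch here)
            tile = range(k0, min(k0 + B_K, K))
            for r, i in enumerate(rows):
                # Tile-local (m~, a~): ties resolve to the smaller centroid index.
                m_t, a_t = min((x_norm[i] - 2.0 * dot(X[i], C[k]) + c_norm[k], k) for k in tile)
                if m_t < m[r]:          # online argmin merge
                    m[r], a[r] = m_t, a_t
        for r, i in enumerate(rows):
            assignments[i] = a[r]
    return assignments


X = [[0.0, 0.0], [9.0, 9.0], [1.0, 2.0], [8.0, 7.0], [5.0, 5.0]]
C = [[0.0, 1.0], [9.0, 8.0]]
print(flash_assign(X, C, B_N=2, B_K=1))  # [0, 1, 0, 1, 1]
```

With integer-valued coordinates the expanded distance is exact, so the result equals a brute-force argmin. With real-valued data, rounding in the expansion can reorder near-ties. The ∥x_i∥² term does not change the argmin; it only makes m a true squared distance.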
4.2 Sort-Inverse Update: Low-Contention Centroid Aggregation
This kernel transforms irregular scatters into regular gathers.
- Explicit Inverse Mapping: First, argsort the assignment vector $a$ to obtain sorted_idx. This groups points by cluster ID without physically permuting the data matrix $X$.
- Segment-Level Aggregation: Each thread block (CTA) processes a contiguous chunk of the sorted sequence. It identifies segment boundaries (runs of identical cluster IDs), gathers the corresponding features from $X$ using sorted_idx, and accumulates partial sums and counts on chip.
- Reduced Atomics: Atomic adds are issued once per segment (per cluster per chunk) instead of once per point. With chunk size $B_N$, the number of atomic operations drops from $O(Nd)$ to $O((K + \lceil N/B_N \rceil)\,d)$, because the total number of segments is at most $\lceil N/B_N \rceil + K - 1$.
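The atomic-count reduction can be checked by counting directly. In the hypothetical helper below, a per-point scatter issues d + 1 atomic adds per point: d for the sum and 1 for the count. The sort-inverse scheme instead issues d + 1 per segment, where a segment is a run of one cluster ID inside a chunk of B_N sorted assignments. The sorted array has at most K − 1 ID changes, and each chunk opens one segment of its own. The total number of segments is therefore at most ⌈N/B_N⌉ + K − 1.

```python
from typing import List, Tuple


def count_atomics(a_sorted: List[int], B_N: int, d: int) -> Tuple[int, int, int]:
    """Return (scatter_atomics, sort_inverse_atomics, segments) for sorted assignments."""
    N = len(a_sorted)
    scatter = N * (d + 1)              # every point: d sum atomics + 1 count atomic
    segments = 0
    for l in range(0, N, B_N):
        chunk = a_sorted[l:l + B_N]
        # A chunk opens one segment, plus one more at every cluster-ID change inside it.
        segments += 1 + sum(1 for j in range(1, len(chunk)) if chunk[j] != chunk[j - 1])
    return scatter, segments * (d + 1), segments


# 4096 points spread evenly over 8 clusters, already sorted by cluster ID.
a_sorted = [k for k in range(8) for _ in range(512)]
print(count_atomics(a_sorted, B_N=256, d=128))  # (528384, 2064, 16)
```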
Algorithm 3: Sort-Inverse Update (low-atomic centroid update)
Input: Points X ∈ R^{N × d} (original order), assignments a ∈ {1, ..., K}^N, chunk size B_N
Output: Centroid sums s ∈ R^{K × d}, counts n ∈ R^K, updated centroids C
Sort by cluster id: compute sorted_idx ← argsort(a);
Construct sorted cluster ids a_sorted[j] ← a[sorted_idx[j]];
Initialize s ← 0, n ← 0;
for l ← 0 to N - 1 step B_N do
    r ← min(l + B_N, N);
    Load a_sorted[l:r] and sorted_idx[l:r];
    Identify contiguous segments of identical cluster ids in a_sorted[l:r];
    for each segment (u, v, k) in a_sorted[l:r] do
        Gather token features from original order using indices sorted_idx[u:v];
        Accumulate local partial sum Δs_k and local partial count Δn_k on chip;
        Atomic merge (once per segment): atomic_add(s_k, Δs_k) and atomic_add(n_k, Δn_k);
for k ← 1 to K do
    c_k ← s_k / n_k;
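Algorithm 3 can be sketched sequentially in Python as below; all names are hypothetical. The chunk loop stands in for independent CTAs, and the per-segment merge stands in for the atomic_add calls, which are ordinary additions in a single Python thread. Algorithm 3 divides s_k / n_k unconditionally. This sketch instead leaves an empty cluster at its previous centroid, or at zeros if none is given, which is an assumption made here.

```python
from typing import List, Optional, Tuple


def sort_inverse_update(
    X: List[List[float]], a: List[int], K: int, B_N: int = 4,
    C_prev: Optional[List[List[float]]] = None,
) -> Tuple[List[List[float]], List[int], List[List[float]], int]:
    """Sequential sketch of Algorithm 3; returns (s, n, C, number_of_merges)."""
    N, d = len(X), len(X[0])
    sorted_idx = sorted(range(N), key=a.__getitem__)  # argsort(a); X is not permuted
    a_sorted = [a[i] for i in sorted_idx]
    s = [[0.0] * d for _ in range(K)]
    n = [0] * K
    merges = 0
    for l in range(0, N, B_N):                 # each chunk would be one CTA
        r = min(l + B_N, N)
        u = l
        while u < r:                           # walk segments (u, v, k) inside the chunk
            k = a_sorted[u]
            v = u
            while v < r and a_sorted[v] == k:
                v += 1
            ds = [0.0] * d                     # on-chip partial sum for this segment
            for i in sorted_idx[u:v]:          # gather rows from the original order
                for j in range(d):
                    ds[j] += X[i][j]
            for j in range(d):                 # stands in for atomic_add(s_k, ds)
                s[k][j] += ds[j]
            n[k] += v - u                      # stands in for atomic_add(n_k, dn)
            merges += 1
            u = v
    # Assumption: empty clusters keep C_prev[k] (or zeros); Algorithm 3 divides unconditionally.
    C = [
        [val / n[k] for val in s[k]] if n[k] else (list(C_prev[k]) if C_prev else [0.0] * d)
        for k in range(K)
    ]
    return s, n, C, merges


X = [[float(i), float(2 * i)] for i in range(10)]
a = [2, 0, 1, 0, 2, 2, 1, 0, 3, 1]
s, n, C, merges = sort_inverse_update(X, a, K=5, B_N=4)
print(n, merges)  # [3, 3, 3, 1, 0] 6
```

Any chunk size gives the same sums and counts; only the number of merges changes.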
4.3 Efficient Algorithm-System Co-Design
- Chunked Stream Overlap: For out-of-core execution, data is partitioned into chunks. Asynchronous CUDA streams overlap PCIe host-to-device transfers with computation using a double-buffer pattern.
- Cache-Aware Compile Heuristic: Instead of exhaustive auto-tuning for dynamic shapes, a heuristic selects kernel configurations (e.g., tile sizes) based on hardware cache sizes (L1, L2) and problem shape, drastically reducing compilation time.
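The chunked overlap pattern can be imitated on the CPU with a producer thread and a two-slot queue. The producer stands in for the host-to-device copy stream and the consumer for the compute stream, so at most two chunks are in flight, like a double buffer. This is an analogy under stated assumptions (hypothetical names, Python threads instead of CUDA streams), not the actual implementation. In CPython the GIL also limits real overlap for pure-Python compute. Per-chunk partial sums and counts are simply added, so one Lloyd step over streamed chunks yields the same sums and counts as an in-core step.

```python
import queue
import threading
from typing import List, Tuple


def producer(X: List[List[float]], chunk: int, q: "queue.Queue") -> None:
    """Stands in for the host-to-device copy stream."""
    for l in range(0, len(X), chunk):
        q.put(X[l:l + chunk])      # blocks while both slots are occupied
    q.put(None)                    # sentinel: no more chunks


def out_of_core_step(X, C, chunk: int = 3) -> Tuple[List[List[float]], List[int]]:
    """One Lloyd step over X streamed in chunks; returns per-cluster sums and counts."""
    K, d = len(C), len(C[0])
    sums = [[0.0] * d for _ in range(K)]
    counts = [0] * K
    q: "queue.Queue" = queue.Queue(maxsize=2)  # two in-flight chunks ~ double buffer
    t = threading.Thread(target=producer, args=(X, chunk, q), daemon=True)
    t.start()
    while (part := q.get()) is not None:       # "compute stream"
        for x in part:
            k = min(range(K), key=lambda kk: sum((xi - ci) ** 2 for xi, ci in zip(x, C[kk])))
            counts[k] += 1
            for j in range(d):
                sums[k][j] += x[j]
    t.join()
    return sums, counts


X = [[float(i % 5), float(i // 5)] for i in range(17)]
C = [[0.0, 0.0], [4.0, 3.0]]
print(out_of_core_step(X, C, chunk=4) == out_of_core_step(X, C, chunk=len(X)))  # True
```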
Empirical Validation / Results
Evaluated on an NVIDIA H200 GPU against fast_pytorch_kmeans, fastkmeans, NVIDIA cuML, and FAISS.
5.2 Efficiency Evaluation
End-to-End Speedup: Flash-KMeans achieves consistent speedups across all workload regimes (large N/large K, large N/small K, small N/small K, batched). It outperforms the best baseline (fastkmeans) by up to 17.9×.
Table: Representative End-to-End Speedups
| Workload Regime | Example Configuration (N, K, D) | Speedup vs. Best Baseline |
|---|---|---|
| Large N, Large K | 1M, 64K, 512 | >5.4× |
| Large N, Small K | 8M, 1024, 128 | 17.9× |
| Small N, Small K (Batched) | 16K, 8K, 512 (batch size B = 32) | Consistent acceleration |

Note: Standard PyTorch fails (OOM) in large-K regimes due to distance matrix materialization.
Kernel-Level Breakdown:
- FlashAssign: Achieves up to 21.2× speedup (e.g., reducing latency from 122.5 ms to 5.8 ms for N=1M, K=8192).
- Sort-Inverse Update: Achieves up to 6.3× speedup for centroid update (e.g., for B=1, N=33M, K=4096).
5.3 Algorithm-System Co-Design Evaluation
Large-Scale Out-of-Core Processing:
- Scales to 1 billion points (N=10⁹, K=32768, D=128), completing an iteration in 41.4 s vs. 261.8 s for fastkmeans (6.3× speedup).
- For N=400M, K=16384, achieves a 10.5× speedup (8.4 s vs. 88.4 s).
Fast Time-to-First-Run:
- Cache-aware heuristic reduces configuration search/compilation time by up to 175× (e.g., from >325 seconds to <2.5 seconds for N=8M, K=65536).
- The heuristic-selected configuration matches the performance of exhaustively tuned kernels within a <0.3% margin, providing near-optimal performance immediately.
Theoretical and Practical Implications
- Theoretical: Demonstrates that for classical algorithms deployed on modern hardware, I/O complexity and synchronization overheads can dominate theoretical FLOPs. Algorithmic innovations must be coupled with hardware-aware implementation.
- Practical: Flash-KMeans enables k-means as a viable online primitive in production AI systems:
  - Exact Results: Provides mathematically exact Lloyd's algorithm, not an approximation.
  - Deployability: Solves key deployment challenges: memory scalability (out-of-core support), dynamic shape handling (fast compilation), and consistent speedups across diverse workloads.
  - Broad Impact: Accelerates applications in vector quantization, sparse routing for LLMs, KV cache compression, and generative video models.
Conclusion
Flash-KMeans re-architects the k-means algorithm for modern GPU systems by addressing fundamental memory and synchronization bottlenecks. Through the FlashAssign and Sort-Inverse Update kernels, coupled with system-level co-design for scalability and deployability, it delivers order-of-magnitude speedups while maintaining mathematical exactness. The work underscores the critical importance of algorithm-system co-design in unlocking the full potential of classical algorithms within contemporary AI infrastructure.