Summary (Overview)

  • New Sparse Attention Mechanism: MiniMax Sparse Attention (MSA) is a blockwise sparse attention mechanism built on Grouped Query Attention (GQA), using a lightweight Index Branch to independently select top-(k) key-value blocks for each GQA group, enabling group-specific sparse retrieval with block-level execution.
  • Significant Efficiency Gains: On a 109B-parameter MoE model, MSA reduces per-token attention FLOPs by (28.4\times) at 1M context and achieves measured wall-clock speedups of (14.2\times) (prefill) and (7.6\times) (decoding) on H800 GPUs.
  • Co-designed GPU Kernels: MSA is paired with custom GPU kernels featuring exp-free Top-K selection, KV-outer sparse attention, and fused KL loss computation to translate theoretical sparsity into practical speedups.
  • Near-Lossless Quality: MSA matches the quality of full GQA attention on a comprehensive suite of language, reasoning, multimodal, and agentic benchmarks when trained from scratch or continued from a pretrained dense checkpoint.
  • Practical Deployment: The inference kernel is open-sourced, and a production-grade multimodal model (MiniMax-M3) using MSA has been publicly released.

Introduction and Theoretical Foundation

Large language models (LLMs) are evolving from short interactions to long-horizon agentic workflows (code reasoning, web navigation, tool orchestration) that require attending to hundreds of thousands or millions of tokens. The quadratic cost (O(N^2)) of softmax attention makes this infeasible at scale.

The paper frames the problem around Grouped Query Attention (GQA) with (H_q) query heads and (H_{kv}) key-value heads (GQA ratio (G = H_q/H_{kv})). The standard causal softmax attention for query position (t) and head (h) is:

[ \boldsymbol{o}_t^{(h)} = \sum_{i \le t} \alpha_{t,i}^{(h)} \boldsymbol{v}_i^{(h)}, \quad \alpha_{t,i}^{(h)} = \frac{\exp\left( \langle \boldsymbol{q}_t^{(h)}, \boldsymbol{k}_i^{(h)} \rangle / \sqrt{d_h} \right)}{\sum_{j \le t} \exp\left( \langle \boldsymbol{q}_t^{(h)}, \boldsymbol{k}_j^{(h)} \rangle / \sqrt{d_h} \right)} \tag{1} ]

with cost (\Theta(2 H_q N^2 d_h)) FLOPs. Sparse attention factors this into an indexer and a restricted attention computation:

[ \mathcal{I}_i = \text{Index}_\phi(\boldsymbol{q}_i, \boldsymbol{K}_{\le i}), \quad \boldsymbol{o}_i = \text{Attn}(\boldsymbol{q}_i, \boldsymbol{K}[\mathcal{I}_i], \boldsymbol{V}[\mathcal{I}_i]) \tag{2} ]

MSA operates at block granularity (block size (B_k)) and shares the selected index set (\mathcal{I}_i^{(r)}) among all (G) query heads within a GQA group (r). The key intuition is that per-group independent selection combined with block-level retrieval provides a good trade-off between expressiveness and computational regularity.

Methodology

Architecture

MSA has two branches (Figure 1 in paper):

Index Branch: A lightweight module with one index query head per GQA group and a single shared index key head:

[ \boldsymbol{Q}_{\text{idx}} = \boldsymbol{X}\boldsymbol{W}_q^{\text{idx}} \in \mathbb{R}^{N \times H_{kv} \times d_{\text{idx}}}, \quad \boldsymbol{K}_{\text{idx}} = \boldsymbol{X}\boldsymbol{W}_k^{\text{idx}} \in \mathbb{R}^{N \times 1 \times d_{\text{idx}}} \tag{5} ]

For query position (i) and group (r), token-level scores are aggregated to block-level via max-pooling:

[ S_{\text{idx}, i,j}^{(r)} = \frac{(\boldsymbol{Q}_{\text{idx}})_i^{(r)} (\boldsymbol{K}_{\text{idx}})_j^{\top}}{\sqrt{d_{\text{idx}}}}, \quad M_{\text{idx}, i,b}^{(r)} = \max_{\substack{j \in \mathcal{B}_b \\ j \le i}} S_{\text{idx}, i,j}^{(r)} \tag{6} ]

The top-(k) block indices are selected:

[ \mathcal{I}_i^{(r)} = \text{TopK}_{b \in \{1,\dots,B\}}(M_{\text{idx}, i,\cdot}^{(r)}, k) \tag{7} ]

The local block containing position (i) is always included regardless of score.
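
The selection in Equations 6 and 7 can be sketched for a single query position and a single GQA group as follows. This is a minimal pure-Python illustration, not the released kernel: the function name `select_blocks` and the toy data are hypothetical, and it assumes that the forced local block occupies one of the (k) slots, which the text above does not specify.

```python
import math


def select_blocks(q_idx, k_idx, i, block_size, k):
    """Return the sorted block indices chosen for query position i.

    q_idx: index query vector of one GQA group at position i.
    k_idx: list of shared index key vectors, one per token position.
    """
    d_idx = len(q_idx)
    num_visible = i // block_size + 1  # blocks holding at least one key j <= i
    block_max = []
    for b in range(num_visible):
        start = b * block_size
        stop = min((b + 1) * block_size, i + 1)  # causal mask: keep j <= i
        scores = [
            sum(qc * kc for qc, kc in zip(q_idx, k_idx[j])) / math.sqrt(d_idx)
            for j in range(start, stop)
        ]
        block_max.append(max(scores))  # block score = max token score (Eq. 6)
    local = i // block_size
    # Rank the non-local blocks by score, then force-include the local block.
    # Assumption: the local block counts toward the budget of k blocks.
    others = sorted(
        (b for b in range(num_visible) if b != local),
        key=lambda b: block_max[b],
        reverse=True,
    )
    return sorted([local] + others[: k - 1])


# Toy example: 8 tokens, block size 2, 1-dimensional index vectors.
keys = [[0.1], [0.2], [0.9], [0.0], [0.3], [0.1], [0.5], [0.4]]
print(select_blocks([1.0], keys, i=7, block_size=2, k=2))  # -> [1, 3]
```

Block 3 is kept because it contains the query, and block 1 wins the remaining slot because its maximum token score (0.9) is the highest among the other blocks.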

Main Branch: Performs standard softmax attention restricted to the selected blocks:

[ \boldsymbol{O}_i^{(h)} = \text{softmax}\left( \frac{\boldsymbol{Q}_i^{(h)} (\boldsymbol{K}^{(r)}[\mathcal{I}_i^{(r)}])^{\top}}{\sqrt{d_h}} \right) \boldsymbol{V}^{(r)}[\mathcal{I}_i^{(r)}] \tag{8} ]
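
As a rough illustration of Equation 8, the sketch below computes the output of one query head at one position over the tokens of the selected blocks. The function name `sparse_attention` and the toy tensors are hypothetical, and plain Python lists stand in for the real tensors.

```python
import math


def sparse_attention(q, keys, values, i, selected_blocks, block_size):
    """Softmax attention for one query head at position i over selected blocks."""
    d_h = len(q)
    # Expand block indices to token positions, keeping only causal ones (j <= i).
    tokens = [
        j
        for b in selected_blocks
        for j in range(b * block_size, (b + 1) * block_size)
        if j <= i
    ]
    logits = [
        sum(qc * kc for qc, kc in zip(q, keys[j])) / math.sqrt(d_h) for j in tokens
    ]
    m = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in logits]
    total = sum(weights)
    d_v = len(values[0])
    return [
        sum(w * values[j][c] for w, j in zip(weights, tokens)) / total
        for c in range(d_v)
    ]


keys = [[1.0], [0.0], [2.0], [0.0]]
values = [[1.0], [2.0], [3.0], [4.0]]
# Selecting every block reproduces dense causal attention for this query.
out_all = sparse_attention([1.0], keys, values, i=3, selected_blocks=[0, 1], block_size=2)
# A zero query gives uniform weights, so the output is the mean of tokens 2 and 3.
out_one = sparse_attention([0.0], keys, values, i=3, selected_blocks=[1], block_size=2)
print(out_one)  # -> [3.5]
```

In the full mechanism every query head in group (r) would call this with the same `selected_blocks`, since the index set is shared within the group.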

Computational Complexity:

  • GQA: (F_{\text{GQA}}(N) = 2 H_q d_h N^2)
  • MSA: (F_{\text{MSA}}(N) = \underbrace{H_{kv} d_{\text{idx}} N^2}_{\text{Index Branch}} + \underbrace{4 H_q d_h N k B_k}_{\text{Main Branch}}) (Equation 12)

With (k B_k \ll N) and (H_{kv} d_{\text{idx}} \ll H_q d_h), the Main Branch cost is linear in (N), and the remaining quadratic Index Branch term carries a much smaller constant than dense GQA.
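
To make the two formulas concrete, the following sketch evaluates them with the head counts, (d_h), (k), and (B_k) from the experimental setup reported in this summary. Two inputs are assumptions made only for illustration: the index dimension `d_idx = 128`, which this summary does not state, and the reading of "1M" as (2^{20}) tokens.

```python
def gqa_flops(n, h_q, d_h):
    """Dense causal GQA attention FLOPs: 2 * H_q * d_h * N^2."""
    return 2 * h_q * d_h * n**2


def msa_flops(n, h_q, h_kv, d_h, d_idx, k, b_k):
    """MSA FLOPs: quadratic Index Branch term plus linear Main Branch term."""
    index_branch = h_kv * d_idx * n**2
    main_branch = 4 * h_q * d_h * n * k * b_k
    return index_branch + main_branch


# d_idx = 128 and N = 2**20 are assumptions, not values taken from the paper.
N = 2**20
dense = gqa_flops(N, h_q=64, d_h=128)
sparse = msa_flops(N, h_q=64, h_kv=4, d_h=128, d_idx=128, k=16, b_k=128)
ratio = dense / sparse
print(f"FLOPs reduction at N = 2^20: {ratio:.1f}x")  # -> 28.4x
```

Under these assumptions the formulas give a reduction of about 28.4×, in line with the figure reported for 1M context. A different `d_idx` would change the result, because the Index Branch term dominates the MSA cost at this length.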

Training

Because Top-K selection is non-differentiable, the Index Branch is trained with a KL divergence loss that aligns its distribution with the group-averaged Main Branch distribution over the selected tokens:

[ P_{\text{idx}, i,j}^{(r)} = \frac{\exp(S_{\text{idx}, i,j}^{(r)})}{\sum_{u \in \mathcal{I}_{i,\text{tok}}^{(r)}} \exp(S_{\text{idx}, i,u}^{(r)})}, \quad P_{i,j}^{(r)} = \frac{1}{G} \sum_{\ell \in H_r} \frac{\exp(S_{i,j}^{(\ell)})}{\sum_{u \in \mathcal{I}_{i,\text{tok}}^{(r)}} \exp(S_{i,u}^{(\ell)})} \tag{9} ]

[ \mathcal{L}_{\text{KL}} = \frac{1}{N H_{kv}} \sum_{i=1}^N \sum_{r=1}^{H_{kv}} D_{\text{KL}}\left( \text{stopgrad}(P_{i,\cdot}^{(r)}) \parallel P_{\text{idx}, i,\cdot}^{(r)} \right) \tag{10} ]
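
The inner term of Equation 10, for one query position and one GQA group, can be sketched as below. The helper names `softmax` and `kl_alignment` are hypothetical. Plain Python has no gradients, so the stop-gradient is not represented here; in an autodiff framework the target distribution would be detached before the loss is computed.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def kl_alignment(index_scores, head_scores):
    """KL(P_main || P_idx) for one query position and one GQA group.

    index_scores: Index Branch scores over the selected tokens.
    head_scores: one list of Main Branch scores per query head in the group,
        each over the same selected tokens.
    """
    p_idx = softmax(index_scores)
    per_head = [softmax(s) for s in head_scores]
    g = len(per_head)
    # Target distribution: average of the per-head attention distributions (Eq. 9).
    p_main = [sum(p[j] for p in per_head) / g for j in range(len(p_idx))]
    return sum(p * math.log(p / q) for p, q in zip(p_main, p_idx) if p > 0.0)


# The loss vanishes when the indexer reproduces the averaged attention weights.
matched = kl_alignment([1.0, 2.0], [[1.0, 2.0], [1.0, 2.0]])
# A uniform indexer is penalized when the heads concentrate on the second token.
mismatched = kl_alignment([0.0, 0.0], [[0.0, 2.0]])
print(matched, mismatched > 0.0)
```

Averaging this quantity over all positions (i) and groups (r) gives the full loss.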

Three stabilization mechanisms are employed:

  1. Gradient Detach: The Index Branch input (\boldsymbol{X}) is detached (stop-gradient), so the KL loss updates only (\boldsymbol{W}_q^{\text{idx}}) and (\boldsymbol{W}_k^{\text{idx}}).
  2. Indexer Warmup: Training proceeds in two stages: the model first runs full attention to warm up the indexer, then switches to sparse attention.
  3. Local Block: Always includes the block containing the query token.

The complete layer forward pass is summarized in Algorithm 1:

Q, K, V ← XW_q, XW_k, XW_v
Q_idx, K_idx ← stopgrad(X)W_q^idx, stopgrad(X)W_k^idx
M_idx ← BlockMaxPool(Q_idx, K_idx, B_k)
I ← TopK(M_idx, k)   // local block forced included
O ← TopKAttn(Q, K, V, I)
output ← OW_o
L_KL ← KLdiv(Q_idx, K_idx, stopgrad(Q), stopgrad(K), I)
return output, L_KL

Kernel Design

Exp-free Top-K selection: Softmax is monotonic, so ranking raw scores selects the same blocks as ranking normalized probabilities, and the exponentials can be skipped. The kernel implements a per-thread register top-(k) heuristic with (k=16) and (B_k=128), using a min-heap in shared memory with deferred writes and a shuffle merge.
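
The order-preservation argument can be checked on a small example. The sketch below uses Python's `heapq.nlargest` purely for illustration; it is not the register-level GPU heuristic described above, and the scores are made up.

```python
import heapq
import math


def top_k_indices(values, k):
    """Indices of the k largest entries, ordered from largest to smallest."""
    return heapq.nlargest(k, range(len(values)), key=values.__getitem__)


scores = [0.3, -1.2, 2.5, 0.9, 1.7]  # hypothetical raw block scores
shift = max(scores)
exps = [math.exp(s - shift) for s in scores]
denom = sum(exps)
probs = [e / denom for e in exps]

raw_top = top_k_indices(scores, 2)
prob_top = top_k_indices(probs, 2)
print(raw_top, prob_top)  # -> [2, 4] [2, 4]
```

Both rankings agree, so the selection step never needs to evaluate `exp`.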

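Seq. Len. (N) | Blocks (B) | (k) | torch (μs) | TileLang (μs) | Ours (μs) | vs. torch | vs. TileLang
128K          | 1024       | 16  | 3970       | 2864          | 779       | 5.1×      | 3.7×
128K          | 2048       | 32  | 5378       | 3630          | 1991      | 2.7×      | 1.8×
512K          | 4096       | 16  | 33810      | 17779         | 7880      | 4.3×      | 2.3×
512K          | 8192       | 32  | 57659      | 26100         | 21326     | 2.7×      | 1.2×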

Table 1 | Top-(k) latency (μs) for fp32 inputs. Our kernel is fastest across all configurations, with largest gains at (k=16).

KV-outer sparse attention: The kernel iterates over KV blocks in the outer loop and, for each block, gathers the queries that selected it. Arithmetic intensity is then proportional to (\frac{2}{3}B_k) rather than to (G) as in a Q-outer loop; with (B_k=128) and (G=16) the proportionality factor is roughly 85 versus 16. Key techniques:

  • Pre-scheduled tile chunking: Hot tiles (selected by many queries) are split into chunks to balance load.
  • Query concatenation: Packs (\lceil 128/G \rceil) query positions per tile to fill 128×128 MMA operations.
  • Two-phase forward: Separate kernels for partial attention and combine (with max-LSE renormalization) to handle the split structure without atomics.
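
The KV-outer loop and the two-phase forward can be sketched together in plain Python. This is an illustrative CPU model only: the function name `kv_outer_attention` is hypothetical, values are scalars to keep the code short, and tile chunking and query concatenation are omitted. Phase 1 visits each KV block once and emits a partial result per (block, query) pair; phase 2 merges the partials of each query by rescaling with their log-sum-exp (LSE).

```python
import math
from collections import defaultdict


def kv_outer_attention(logits, values, selected, block_size):
    """Sparse causal attention computed block by block, then merged per query.

    logits[i][j]: scaled score of query i against key j.
    values[j]: scalar value for key j.
    selected[i]: block indices chosen by query i.
    """
    # Invert the selection so each KV block knows which queries picked it.
    queries_for_block = defaultdict(list)
    for i, blocks in enumerate(selected):
        for b in blocks:
            queries_for_block[b].append(i)

    # Phase 1: outer loop over KV blocks, one partial result per (block, query).
    partials = defaultdict(list)
    for b, queries in queries_for_block.items():
        for i in queries:
            toks = [j for j in range(b * block_size, (b + 1) * block_size) if j <= i]
            if not toks:
                continue  # block lies entirely in the future of query i
            m = max(logits[i][j] for j in toks)
            weights = [math.exp(logits[i][j] - m) for j in toks]
            total = sum(weights)
            out = sum(w * values[j] for w, j in zip(weights, toks)) / total
            partials[i].append((out, m + math.log(total)))

    # Phase 2: combine partials, rescaling each by exp(lse - max_lse).
    outputs = {}
    for i, parts in partials.items():
        max_lse = max(lse for _, lse in parts)
        scales = [math.exp(lse - max_lse) for _, lse in parts]
        outputs[i] = sum(s * o for s, (o, _) in zip(scales, parts)) / sum(scales)
    return outputs


logits = [
    [0.2, 0.0, 0.0, 0.0],
    [0.5, 1.0, 0.0, 0.0],
    [0.1, 0.3, 2.0, 0.0],
    [1.5, 0.2, 0.4, 0.9],
]
values = [1.0, 2.0, 3.0, 4.0]
selected = [[0], [0], [0, 1], [0, 1]]
result = kv_outer_attention(logits, values, selected, block_size=2)
```

Because each partial carries its own LSE, the merge reproduces a single softmax over all selected tokens without any atomic updates to a shared accumulator.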

Sparse KL Loss: Fuses LSE computation into the main forward pass and uses persistent grid dynamic load balancing in the backward pass.

Empirical Validation / Results

Setup

  • Model: 109B-parameter MoE (41 layers, 64 query heads, 4 KV heads, (d_h=128), RoPE dim 64, 128 routed experts, top-4)
  • MSA configuration: (B_k=128), (k=16) (budget 2,048 tokens per query)
  • Training: 3T total tokens. Two variants:
    • MSA-PT: Trained from scratch with sparse attention after 40B token warmup.
    • MSA-CPT: Continued pretraining from a 2.6T full-attention checkpoint, with 40B warmup + 360B sparse training.

Training Dynamics

Figure 2 shows that the LM loss and gradient-norm curves for MSA-PT are nearly indistinguishable from those of full attention over 3T tokens. Figure 3 shows that MSA-CPT quickly reduces the KL loss during warmup and then maintains low KL, high block recall ((\sim)0.85), and high score recall ((\sim)0.90) during sparse continued pretraining.

Downstream Evaluation

Table 2 (excerpted summary): MSA-PT and MSA-CPT remain competitive with full attention. MSA-PT excels on math (GSM8K 77.7 vs 76.2, MathVista 46.8 vs 43.8), image (VisualWebBench 68.4 vs 55.6), video (EgoSchema 37.6 vs 29.6), and retrieval (RULER-8K 84.2 vs 79.8). MSA-CPT preserves full-attention behavior on text/code benchmarks. Agent perplexities (TAU2, AgentCompany, HLE, SWE) are essentially identical across all three models.

Long-Context Extension

Table 3: After 140B tokens of long-context training, MSA-CPT achieves 45.93 vs 46.53 overall on HELMET-128K (-0.60) and 72.12 vs 72.00 on RULER-128K (+0.12), showing near-lossless long-context capability despite attending to only 2,048 tokens per query.

Efficiency

Figure 4 shows the FLOPs reduction and measured speedups. At 1M tokens:

  • FLOPs reduction: (28.4\times)
  • Prefill speedup: (14.2\times)
  • Decoding speedup: (7.6\times)

Speedups increase with context length as dense attention's cost grows quadratically.
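
In FLOPs terms, this trend follows from the ratio of the two cost expressions given under Computational Complexity. The derivation below is a sketch from those formulas only; measured wall-clock speedups also depend on the kernels and hardware.

```latex
\frac{F_{\text{GQA}}(N)}{F_{\text{MSA}}(N)}
  = \frac{2 H_q d_h N^2}{H_{kv} d_{\text{idx}} N^2 + 4 H_q d_h N k B_k}
  = \frac{2 H_q d_h N}{H_{kv} d_{\text{idx}} N + 4 H_q d_h k B_k}
  \;\xrightarrow{\; N \to \infty \;}\;
  \frac{2 H_q d_h}{H_{kv} d_{\text{idx}}}
```

The ratio has the form (aN/(bN+c)) with positive constants, so it increases monotonically with (N) and saturates once the quadratic Index Branch term dominates the MSA cost.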

Theoretical and Practical Implications

  • Principle of Minimality: MSA demonstrates that a simple, streamlined design (one index query head per GQA group, a shared index key head, block-level max-pooling scoring, and a single KL loss) can match full attention quality with large compute savings. The ablation studies in the appendix confirm that additional components (e.g., a separate value head for the indexer, multiple index heads) are unnecessary.

  • GQA Compatibility: MSA composes naturally with the GQA backbone used by most current frontier LLMs. This makes the recipe directly transferable: the only architectural change is the addition of two small projection matrices.

  • Practical Scalability: The (28.4\times) compute reduction at 1M tokens directly addresses the binding deployment constraint for long-context applications (code repositories, agentic workflows, persistent memory). The custom kernel design ensures that sparsity translates to wall-clock speedups on current hardware.

  • Training Paradigm: MSA supports both native sparse pretraining (MSA-PT) and near-lossless conversion from pretrained dense checkpoints (MSA-CPT), offering flexibility for different deployment scenarios. The two-stage warmup schedule provides a clean conversion phase.

Conclusion

MSA introduces a minimal, scalable blockwise sparse attention mechanism for GQA-based LLMs. By attaching a lightweight Index Branch that performs per-group independent top-(k) block selection, the Main Branch computes exact softmax attention over only a fixed budget of tokens (2,048 per query). The Index Branch is trained via a KL alignment loss with detached gradients, ensuring clean separation from the backbone.

Key results:

  • Matches full GQA attention quality on a comprehensive 109B-parameter multimodal MoE model.
  • Reduces per-token attention FLOPs by (28.4\times) at 1M context.
  • Achieves (14.2\times) prefill and (7.6\times) decoding wall-clock speedups on H800 GPUs.

Future directions:

  • Closing the long-context retrieval gap via longer sparse training, larger inference-time budget, or richer indexer scoring.
  • Extending beyond pretraining to RL post-training and agentic deployment, where long-context cost is the dominant constraint.

Availability:

  • The inference kernel and a production model (MiniMax-M3) are publicly released for community adoption.
