Summary
- GQE introduces mixture-of-experts (MoE) into grouped-query attention (GQA): For each GQA group, a router selects the top-(k) query-head experts per token, while all key-value (KV) heads remain dense and are always computed.
- Compute reduction without quality loss: At 250M parameters and a 30B-token budget, GQE matches the all-active GQA baseline in downstream accuracy on HellaSwag, ARC-Easy, and PIQA while activating only half of the routed query-head experts (9 out of 16 total query-attention computations when including the always-on shared head).
- Significant throughput improvement for long contexts: Measured prefill speedups of roughly (1.7\times) to (1.8\times) over the GQA baseline at sequence lengths from 4k to 1024k tokens.
- Critical design components: The router must receive a proper learning signal through a renormalized weighted-sum slot and be stabilized by an always-on shared attention head; without these, sparse routing degrades accuracy.
Introduction and Theoretical Foundation
Standard transformer self-attention scales quadratically with sequence length and uniformly activates all attention heads for every token. This is inefficient because tokens vary in information content: content-bearing words may need specialized heads, while low-information tokens like punctuation do not. The paper asks whether the conditional computation idea behind mixture-of-experts (MoE) – typically applied to MLP blocks – can be transferred to the attention block.
The authors build on Grouped-Query Attention (GQA) [6], which reduces KV-cache memory by sharing keys and values across groups of query heads. However, GQA still evaluates all query heads for every token. The proposed method, Grouped Query Experts (GQE), retains the dense KV path of GQA but makes the query-head computation conditional: within each fixed GQA group, a router selects (k) out of (M) query-head experts for each token. This preserves the memory efficiency of GQA while reducing active query-head computation. The motivation is that if different tokens need different attention heads, the model should not spend the same compute on every token.
Methodology
Setup
Let (X \in \mathbb{R}^{n \times d}) be a sequence of token representations. In GQA, the (N) query heads are partitioned into (G) groups, each group sharing one KV head. GQE treats the (M = N/G) query heads within each group as experts (\{E_{g,1}, \dots, E_{g,M}\}) that all attend against the same KV head of group (g).
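As a small illustration of the grouping, the sketch below maps each query head to its GQA group. It assumes contiguous, equal-size groups, which the summary does not state; the helper name `group_of_head` is hypothetical.

```python
def group_of_head(head_index, num_query_heads, num_groups):
    """Map a query head to its GQA group.

    Assumption: heads are split into contiguous groups of equal size,
    so each group holds M = N / G query-head experts.
    """
    experts_per_group = num_query_heads // num_groups  # M = N / G
    return head_index // experts_per_group

# Main setting from the summary: N = 16 query heads, G = 8 KV heads, so M = 2.
groups = [group_of_head(h, num_query_heads=16, num_groups=8) for h in range(16)]
```

Under this assumption, heads 0 and 1 share the KV head of group 0, heads 2 and 3 that of group 1, and so on.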
Within-Group Routing
For each token representation (x_i) and group (g), a router produces scores over the (M) experts, followed by softmax:
[ p_{i,g} = \text{softmax}(r_g(x_i)) \in \mathbb{R}^M ]
Then the top-(k) experts are selected per group:
[ \mathcal{K}_g(x_i) = \text{TopK}_m(p_{i,g,m}, k) ]
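The routing step can be sketched in plain Python for a single token and a single group. This is a minimal illustration rather than the authors' implementation: the names `softmax` and `route_top_k` are hypothetical, and the logits stand in for (r_g(x_i)), whose parameterization the summary does not specify.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of router logits."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def route_top_k(logits, k):
    """Return (probs, selected) for one token and one GQA group.

    probs    : softmax probabilities over the M experts of the group
    selected : sorted indices of the k experts with the highest probability
    """
    probs = softmax(logits)
    # Sort expert indices by descending probability; ties keep index order.
    order = sorted(range(len(probs)), key=lambda m: -probs[m])
    return probs, sorted(order[:k])

# Hypothetical logits for a group with M = 2 experts and k = 1.
probs, selected = route_top_k([0.3, 1.2], k=1)
```

Here the second expert has the larger logit, so it is the one selected for this token.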
Output Construction with a Shared Head
The selected expert outputs are concatenated as ordinary slots (hard routing, no weighting):
[ O_i = \{ o_{i,g,m} : g \in \{1,\dots,G\},\ m \in \mathcal{K}_g(x_i) \} ]
To give the router a differentiable learning signal, a renormalized weighted sum of the selected expert outputs is also computed:
[ \bar{o}_i = \sum_{g=1}^G \sum_{m \in \mathcal{K}_g(x_i)} w_{i,g,m} \, o_{i,g,m}, \quad w_{i,g,m} = \frac{p_{i,g,m}}{\sum_{g'=1}^G \sum_{m' \in \mathcal{K}_{g'}(x_i)} p_{i,g',m'}} ]
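The renormalization can be made concrete with a toy example. The sketch below is an illustration under stated assumptions, not the paper's code: the function names are hypothetical, and expert outputs are represented as short Python lists rather than tensors.

```python
def renormalized_weights(probs, selected):
    """Renormalize router probabilities over all selected experts of one token.

    probs    : list of G lists, probs[g][m] = p_{i,g,m}
    selected : list of G lists, selected[g] = expert indices in K_g(x_i)
    Returns a dict {(g, m): w_{i,g,m}} whose values sum to 1.
    """
    denom = sum(probs[g][m] for g, chosen in enumerate(selected) for m in chosen)
    return {
        (g, m): probs[g][m] / denom
        for g, chosen in enumerate(selected)
        for m in chosen
    }

def weighted_sum_slot(outputs, weights):
    """Combine the selected expert outputs o_{i,g,m} into one slot."""
    dim = len(next(iter(outputs.values())))
    slot = [0.0] * dim
    for key, w in weights.items():
        for j, value in enumerate(outputs[key]):
            slot[j] += w * value
    return slot

# Toy example: G = 2 groups, M = 2 experts each, k = 1.
probs = [[0.8, 0.2], [0.4, 0.6]]
selected = [[0], [1]]
weights = renormalized_weights(probs, selected)
outputs = {(0, 0): [1.0, 0.0], (1, 1): [0.0, 1.0]}
o_bar = weighted_sum_slot(outputs, weights)
```

Because the weights are normalized across all groups jointly, the selected probabilities 0.8 and 0.6 become 0.8/1.4 and 0.6/1.4, and the router probabilities enter the output through a differentiable path.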
An always-on shared attention head (s(x_i)) is also computed for every token, providing a stable pathway that does not depend on the routing decision.
The final output concatenates the hard-selected expert slots, the weighted-sum slot, and the shared head slot before the output projection (W^O):
[ y_i = \text{Concat}( O_i, \bar{o}_i, s(x_i) ) W^O ]
For the main 16-query-head / 8-KV-head setting with (k=1), this yields 10 slots (8 hard, 1 weighted-sum, 1 shared) compared to 16 slots in the dense baseline, so the input dimension of (W^O) shrinks accordingly.
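The slot arithmetic can be checked with a short sketch. It assumes every slot, including the weighted-sum and shared-head slots, has the same width as a query head; the `head_dim` value is hypothetical, since the summary only says the per-head dimension is fixed.

```python
def gqe_slot_count(num_groups, k):
    """Slots fed to W^O: k hard slots per group, plus weighted-sum and shared slots."""
    return k * num_groups + 2

def out_proj_rows(num_slots, head_dim):
    """Input dimension of W^O, assuming every slot has width head_dim."""
    return num_slots * head_dim

gqe_slots = gqe_slot_count(num_groups=8, k=1)  # 8 hard + 1 weighted-sum + 1 shared
dense_slots = 16                               # one slot per query head in dense GQA
head_dim = 64                                  # hypothetical value, not from the summary

gqe_rows = out_proj_rows(gqe_slots, head_dim)
dense_rows = out_proj_rows(dense_slots, head_dim)
```

With these assumed numbers, (W^O) takes 640 input features in GQE against 1024 in the dense baseline, which is why the two models are not exactly parameter-matched.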
Routing Auxiliary Loss
A standard load-balancing auxiliary loss is applied to prevent router collapse, encouraging balanced selection across experts within each group.
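The summary does not give the exact form of the auxiliary loss. One common choice, shown here purely as an assumption, is the Switch-Transformer-style loss (M \sum_m f_m P_m), applied within a single group; the function name and the toy inputs are hypothetical.

```python
def load_balance_loss(probs, selected):
    """Switch-style balance loss for one group: M * sum_m f_m * P_m.

    Assumption: this form is a common default, not confirmed by the source.
    probs    : list over tokens of per-expert probabilities (length M each)
    selected : list over tokens of selected expert indices (length k each)
    f_m = fraction of routing assignments sent to expert m
    P_m = mean router probability assigned to expert m
    """
    num_tokens = len(probs)
    num_experts = len(probs[0])
    total_assignments = sum(len(chosen) for chosen in selected)
    loss = 0.0
    for m in range(num_experts):
        f_m = sum(chosen.count(m) for chosen in selected) / total_assignments
        p_m = sum(p[m] for p in probs) / num_tokens
        loss += f_m * p_m
    return num_experts * loss

# Two tokens, M = 2 experts, k = 1.
balanced = load_balance_loss([[0.5, 0.5], [0.5, 0.5]], [[0], [1]])
collapsed = load_balance_loss([[0.9, 0.1], [0.9, 0.1]], [[0], [0]])
```

In this form the loss is smallest when assignments and probabilities are spread evenly, and it grows when the router sends every token to the same expert.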
Compute Profile
With (k) active experts per group and (G) groups, exactly (kG) routed query experts are active per token. Including the always-on shared head, the total active fraction relative to the (N) routed expert pool is:
[ \frac{kG + 1}{N} ]
For the main setting ((N=16, G=8, k=1)), this is (9/16 \approx 56\%).
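The active fraction is simple enough to compute exactly. The helper below is a hypothetical illustration of the formula ((kG + 1)/N) using Python's `fractions` module.

```python
from fractions import Fraction

def active_fraction(num_query_heads, num_groups, k):
    """(k*G + 1) / N: routed experts per token plus the shared head, over N."""
    return Fraction(k * num_groups + 1, num_query_heads)

# Main setting: N = 16, G = 8, k = 1.
main = active_fraction(num_query_heads=16, num_groups=8, k=1)
main_percent = float(main) * 100
```

For the main setting this gives 9/16, or 56.25%.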
Empirical Validation / Results
Experimental Setup
All models are trained on a fixed 30B-token sample from FineWeb-Edu at the 250M-parameter scale. A GQA baseline with 8 KV heads (16 query heads) is compared with GQE variants that add routing. The per-head dimension is fixed. Training hyperparameters are given in Table 1 (see paper). ZClip is used to mitigate loss spikes.
Accuracy
Table 2 from the paper shows the main accuracy results:
| Variant | HellaSwag | ARC-E | PIQA | Average |
|---|---|---|---|---|
| GQA baseline (all 16 heads active) | 41.31 | 61.36 | 64.90 | 55.86 |
| Weighted concat, no renormalized slot | 40.16 | 60.52 | 64.85 | 55.18 |
| Hard concat only | 40.66 | 60.56 | 65.07 | 55.43 |
| GQE (renorm. scoring + shared head) | 41.01 | 62.41 | 64.69 | 56.04 |
The full GQE variant (renormalized scoring plus shared head) marginally exceeds the all-active GQA baseline on average (56.04 vs. 55.86) while activating 8 of 16 routed query heads (9 of 16 total with the shared head). The gain comes from ARC-E; on HellaSwag and PIQA it trails the baseline slightly. Variants without renormalized scoring or without the shared head fall below the baseline average, supporting the need for both design choices.
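The Average column can be reproduced from the three task scores in the table. The sketch below only re-derives those averages and the gap to the baseline; the short dictionary keys are labels chosen here for brevity.

```python
# Scores copied from the table above: (HellaSwag, ARC-E, PIQA).
scores = {
    "GQA baseline": (41.31, 61.36, 64.90),
    "Weighted concat": (40.16, 60.52, 64.85),
    "Hard concat only": (40.66, 60.56, 65.07),
    "GQE": (41.01, 62.41, 64.69),
}

# Unweighted mean over the three tasks, rounded to two decimals.
averages = {name: round(sum(v) / len(v), 2) for name, v in scores.items()}

# Gap of each variant to the dense baseline, in accuracy points.
gaps = {
    name: round(avg - averages["GQA baseline"], 2)
    for name, avg in averages.items()
}
```

The recomputed averages agree with the table, and the full GQE variant sits 0.18 points above the baseline while the two ablated variants sit below it.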
Throughput
Figure 1 (described in the paper) shows measured prefill speedup (GQA baseline latency / GQE latency) across sequence lengths. At 2k tokens, speedup is modest ((\approx 1.15\times)). From 4k tokens onward, speedup stabilizes between (1.67\times) and (1.80\times), reflecting the increasing benefit as query-side attention work dominates fixed routing overhead.
Theoretical and Practical Implications
- Efficiency: GQE demonstrates that self-attention can be made conditionally sparse by routing tokens to only a subset of query-head experts within GQA groups, while retaining the memory savings of GQA's KV-cache. The measured speedup rises from about (1.15\times) at 2k tokens to roughly (1.7\times) to (1.8\times) from 4k tokens onward, making it relevant for long-context transformers.
- Quality: The method matches dense GQA quality when properly designed. The ablation studies reveal that simply applying sparse routing without a differentiable learning signal (renormalized weighted-sum) and without a stable shared head leads to degradation. This provides architectural guidelines for future conditional attention designs.
- Parameter count: GQE is not exactly parameter-matched to GQA because the output projection (W^O) has fewer input slots. The authors acknowledge this as a caveat; the comparison controls for training budget, data, KV layout, and per-head dimension, but not for the output projection size.
- Scalability: Current results are at 250M parameters and 30B tokens; scaling to larger models and longer training budgets is left for future work.
Conclusion
GQE applies mixture-of-experts to grouped-query attention by making query-head computation conditional: within each GQA group, a router selects the top-(k) query-head experts per token, while all KV heads remain dense. The method matches the all-active GQA baseline in downstream accuracy while activating half of the routed query-head experts (9 of 16 total query-attention computations), achieving up to (1.8\times) speedup for long contexts. The success depends critically on a renormalized weighted-sum slot to provide a differentiable signal to the router and an always-on shared head for stability. Future work should validate at larger scales and compare with other efficient attention architectures.