Summary

  • GQE introduces mixture-of-experts (MoE) into grouped-query attention (GQA): for each GQA group, a router selects the top-\(k\) query-head experts per token, while all key-value (KV) heads remain dense and are always computed.
  • Compute reduction without quality loss: At 250M parameters and a 30B-token budget, GQE matches the all-active GQA baseline in average downstream accuracy across HellaSwag, ARC-Easy, and PIQA while activating only half of the routed query-head experts (9 of 16 query-attention computations when the always-on shared head is included).
  • Faster long-context prefill: Measured prefill speedups over the GQA baseline are roughly \(1.7\times\) to \(1.8\times\) at sequence lengths from 4k to 1024k tokens.
  • Critical design components: The router needs a differentiable learning signal, supplied by a renormalized weighted-sum slot, and the model needs an always-on shared attention head for stability; without these, sparse routing degrades accuracy.

Introduction and Theoretical Foundation

Standard transformer self-attention scales quadratically with sequence length and activates every attention head for every token. The authors argue that this is wasteful because tokens vary in information content: content-bearing words may need specialized heads, while low-information tokens such as punctuation may not. The paper asks whether the conditional computation idea behind mixture-of-experts (MoE), typically applied to MLP blocks, can be transferred to the attention block.

The authors build on Grouped-Query Attention (GQA) [6], which reduces KV-cache memory by sharing keys and values across groups of query heads. GQA, however, still evaluates every query head for every token. The proposed method, Grouped Query Experts (GQE), keeps the dense KV path of GQA but makes the query-head computation conditional: within each fixed GQA group, a router selects \(k\) of the \(M\) query-head experts for each token. This preserves the memory efficiency of GQA while reducing active query-head computation. The motivation is that if different tokens need different attention heads, the model should not spend the same compute on every token.

Methodology

Setup

Let \(X \in \mathbb{R}^{n \times d}\) be a sequence of \(n\) token representations of width \(d\). In GQA, the \(N\) query heads are partitioned into \(G\) groups, each group sharing one KV head. GQE treats the \(M = N/G\) query heads within each group as experts \(\{E_{g,1}, \dots, E_{g,M}\}\) that all attend to the same KV head of group \(g\).
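To make this indexing concrete, the sketch below maps a flat query-head index to its (group, expert) pair. It assumes the common contiguous GQA layout, in which query heads \(gM\) through \(gM + M - 1\) share KV head \(g\); the paper's actual head ordering is not stated here, and the helper name `head_to_expert` is hypothetical.

```python
def head_to_expert(h: int, num_query_heads: int, num_groups: int) -> tuple[int, int]:
    """Map flat query-head index h to (group g, expert m).

    Assumes a contiguous GQA layout: heads g*M .. g*M + M - 1 share KV head g.
    """
    if num_query_heads % num_groups != 0:
        raise ValueError("N must be divisible by G")
    experts_per_group = num_query_heads // num_groups  # M = N / G
    if not 0 <= h < num_query_heads:
        raise ValueError("head index out of range")
    return h // experts_per_group, h % experts_per_group


# Main setting from the text: N = 16 query heads, G = 8 KV heads, so M = 2.
layout = [head_to_expert(h, 16, 8) for h in range(16)]
print(layout[:4])  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```

Under this layout each KV head serves exactly \(M\) experts, and routing only chooses among those \(M\).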

Within-Group Routing

For each token representation \(x_i\) and group \(g\), a per-group router \(r_g\) produces scores over the \(M\) experts, which are normalized with a softmax:

\[ p_{i,g} = \text{softmax}\big(r_g(x_i)\big) \in \mathbb{R}^M \]

The top-\(k\) experts are then selected independently within each group:

\[ \mathcal{K}_g(x_i) = \text{TopK}_m\big(p_{i,g,m},\, k\big) \]
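A minimal pure-Python sketch of these two routing steps for one token follows. It assumes the router logits \(r_g(x_i)\) for every group have already been computed (the router's internal form, for example a linear layer, is not specified here). `route_token` is a hypothetical name, and ties in the top-\(k\) are broken toward the lower expert index.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over one group's router logits."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def route_token(router_logits: list[list[float]], k: int):
    """Return per-group probabilities p_{i,g} and selected expert sets K_g(x_i).

    router_logits[g][m] is the (assumed, precomputed) router score for expert m
    in group g. Selection happens independently inside each group.
    """
    probs, selected = [], []
    for logits in router_logits:
        p = softmax(logits)
        # Top-k by probability; sorted() is stable, so ties keep the lower index.
        top = sorted(range(len(p)), key=lambda m: -p[m])[:k]
        probs.append(p)
        selected.append(sorted(top))
    return probs, selected


# Toy example: G = 2 groups, M = 2 experts per group, k = 1.
probs, selected = route_token([[2.0, 0.5], [-1.0, 1.0]], k=1)
print(selected)  # [[0], [1]]
```

Because selection is per group, every group contributes exactly \(k\) experts regardless of how confident its router is.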

Output Construction with a Shared Head

The selected expert outputs are concatenated as ordinary slots (hard routing, no weighting):

\[ O_i = \{\, o_{i,g,m} : g \in \{1,\dots,G\},\ m \in \mathcal{K}_g(x_i) \,\} \]
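Each hard slot is one query-attention computation of a selected expert against its group's dense KV head. The sketch below computes \(o_{i,g,m}\) for one token, and only for the selected experts, using plain single-head scaled dot-product attention. It is an illustration under stated assumptions: queries, keys, and values are already projected; causal masking is assumed to be handled by passing only the visible positions; and the helper names are hypothetical.

```python
import math


def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e * v[j] for e, v in zip(exps, values)) / total for j in range(dim)]


def hard_slots(queries, selected, group_keys, group_values):
    """Compute o_{i,g,m} only for selected experts m in K_g(x_i).

    queries[g][m]   : query vector of expert E_{g,m} for token i
    group_keys[g]   : keys of the shared KV head of group g (dense, always computed)
    group_values[g] : values of that KV head
    Unselected experts are skipped entirely; this is where compute is saved.
    """
    return {
        (g, m): attend(queries[g][m], group_keys[g], group_values[g])
        for g, ks in enumerate(selected)
        for m in ks
    }


# Toy example: G = 2, M = 2, k = 1, head dimension 2, two visible positions.
kv = [[1.0, 0.0], [0.0, 1.0]]
queries = [[[0.0, 0.0], [5.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]
out = hard_slots(queries, selected=[[1], [0]], group_keys=[kv, kv], group_values=[kv, kv])
print(sorted(out))  # [(0, 1), (1, 0)]
```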

To give the router a differentiable learning signal, GQE also computes a renormalized weighted sum of the selected expert outputs:

\[ \bar{o}_i = \sum_{g=1}^{G} \sum_{m \in \mathcal{K}_g(x_i)} w_{i,g,m}\, o_{i,g,m}, \quad w_{i,g,m} = \frac{p_{i,g,m}}{\sum_{g'=1}^{G} \sum_{m' \in \mathcal{K}_{g'}(x_i)} p_{i,g',m'}} \]
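The normalization runs over the selected experts of all \(G\) groups jointly, so one token's weights sum to one, and the slot only recombines outputs that were already computed for the hard slots. The sketch below computes \(w_{i,g,m}\) and \(\bar{o}_i\) for a single token in pure Python, with each head output as a list of floats; `weighted_sum_slot` is a hypothetical helper, not the authors' code. In a trainable model the probabilities come from the router's softmax, which is what lets gradients reach the router through this slot.

```python
def weighted_sum_slot(probs, selected, outputs):
    """Compute o_bar_i = sum_g sum_{m in K_g} w_{i,g,m} * o_{i,g,m}.

    probs[g][m]     : router probability p_{i,g,m}
    selected[g]     : selected expert indices K_g(x_i)
    outputs[(g, m)] : head output o_{i,g,m}, present only for selected experts
    Weights are renormalized across the selected experts of ALL groups.
    """
    denom = sum(probs[g][m] for g, ks in enumerate(selected) for m in ks)
    dim = len(next(iter(outputs.values())))
    o_bar = [0.0] * dim
    weights = {}
    for g, ks in enumerate(selected):
        for m in ks:
            w = probs[g][m] / denom
            weights[(g, m)] = w
            for j, value in enumerate(outputs[(g, m)]):
                o_bar[j] += w * value
    return weights, o_bar


# Toy example: G = 2, M = 2, k = 1, head dimension 2.
probs = [[0.8, 0.2], [0.4, 0.6]]
selected = [[0], [1]]
outputs = {(0, 0): [1.0, 0.0], (1, 1): [0.0, 1.0]}
weights, o_bar = weighted_sum_slot(probs, selected, outputs)
print({key: round(w, 4) for key, w in weights.items()})  # {(0, 0): 0.5714, (1, 1): 0.4286}
```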

An always-on shared attention head \(s(x_i)\) is also computed for every token, providing a stable pathway whose activation does not depend on the router's decisions.

The final output concatenates the hard-selected expert slots, the weighted-sum slot, and the shared-head slot before the output projection \(W^O\):

\[ y_i = \text{Concat}\big( O_i, \bar{o}_i, s(x_i) \big)\, W^O \]

For the main setting with 16 query heads and 8 KV heads (so \(M = 2\) experts per group) and \(k=1\), this yields 10 slots (8 hard, 1 weighted-sum, 1 shared) compared with 16 slots in the dense baseline, and \(W^O\) is reshaped to accept 10 head-sized input slots instead of 16.
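The slot arithmetic can be checked with a small sketch that concatenates the hard slots, the weighted-sum slot, and the shared-head slot for one token and applies an output projection. The head dimension, model width, slot contents, and weight values are hypothetical toy numbers, and the helper names are illustrative; only the slot counts (\(kG + 2\) for GQE versus \(N\) for dense GQA) follow from the text.

```python
def assemble_slots(hard_slots, o_bar, shared):
    """Concat(O_i, o_bar_i, s(x_i)) for one token, as one flat vector."""
    flat = []
    for slot in hard_slots:  # k*G hard-selected expert outputs, unweighted
        flat.extend(slot)
    flat.extend(o_bar)       # renormalized weighted-sum slot
    flat.extend(shared)      # always-on shared head
    return flat


def project(flat, w_o):
    """Multiply a row vector by W^O, given as rows of shape (len(flat), d_model)."""
    if len(flat) != len(w_o):
        raise ValueError("W^O must have one row per concatenated feature")
    d_model = len(w_o[0])
    return [sum(flat[r] * w_o[r][c] for r in range(len(flat))) for c in range(d_model)]


N, G, k, d_head, d_model = 16, 8, 1, 2, 3  # d_head and d_model are toy values
hard = [[1.0, 0.0] for _ in range(k * G)]  # placeholder expert outputs
flat = assemble_slots(hard, o_bar=[0.5, 0.5], shared=[0.0, 1.0])
num_slots = len(flat) // d_head
print(num_slots, "slots vs", N, "for dense GQA")  # 10 slots vs 16 for dense GQA

w_o = [[0.1] * d_model for _ in range(num_slots * d_head)]  # (10 * d_head) x d_model
y = project(flat, w_o)
```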

Routing Auxiliary Loss

A standard load-balancing auxiliary loss is applied to prevent router collapse, encouraging balanced selection across experts within each group.
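The text does not give the exact formula. As one possibility, the sketch below assumes a Switch-Transformer-style loss, \(M \sum_m f_m P_m\), where \(f_m\) is the fraction of routing assignments sent to expert \(m\) and \(P_m\) is its mean router probability, computed per group and then averaged over groups. The per-group averaging, the missing scaling coefficient, and the helper names are all assumptions, not the paper's specification.

```python
def group_balance_loss(probs, selected, num_experts):
    """Switch-style balance loss for ONE group over a batch of tokens.

    probs[t][m] : router probability of expert m for token t
    selected[t] : experts chosen for token t (top-k)
    Returns M * sum_m f_m * P_m, which equals 1.0 when f and P are both uniform.
    """
    num_tokens = len(probs)
    num_assignments = sum(len(s) for s in selected)
    f = [0.0] * num_experts  # fraction of assignments sent to each expert
    for s in selected:
        for m in s:
            f[m] += 1.0 / num_assignments
    mean_p = [sum(p[m] for p in probs) / num_tokens for m in range(num_experts)]
    return num_experts * sum(fm * pm for fm, pm in zip(f, mean_p))


def balance_loss(per_group_probs, per_group_selected, num_experts):
    """Average the per-group losses over groups (assumed aggregation)."""
    losses = [group_balance_loss(p, s, num_experts)
              for p, s in zip(per_group_probs, per_group_selected)]
    return sum(losses) / len(losses)


# Balanced group: each expert chosen once, uniform probabilities.
balanced = group_balance_loss([[0.5, 0.5], [0.5, 0.5]], [[0], [1]], num_experts=2)
# Collapsed group: both tokens routed to expert 0 with high probability.
collapsed = group_balance_loss([[0.9, 0.1], [0.9, 0.1]], [[0], [0]], num_experts=2)
print(balanced, collapsed)  # 1.0 1.8
```

The collapsed group scores higher, so minimizing this term pushes each group's router away from always picking the same expert.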

Compute Profile

With \(k\) active experts per group and \(G\) groups, exactly \(kG\) routed query experts are active per token. Including the always-on shared head, the total active fraction relative to the pool of \(N\) routed experts is:

\[ \frac{kG + 1}{N} \]

For the main setting (\(N=16, G=8, k=1\)), this is \(9/16 \approx 56\%\).
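A small helper makes this count exact; the function name is hypothetical. It also makes explicit that \(k\) cannot exceed \(M = N/G\), and that setting \(k = M\) gives \((N + 1)/N\), i.e. slightly more query-attention work than dense GQA, because the shared head is added on top of all routed experts.

```python
from fractions import Fraction


def active_fraction(n_query_heads: int, n_groups: int, k: int) -> Fraction:
    """(kG + 1) / N: routed experts plus the always-on shared head, relative to N."""
    experts_per_group = n_query_heads // n_groups  # M = N / G
    if not 1 <= k <= experts_per_group:
        raise ValueError("k must lie in [1, M]")
    return Fraction(k * n_groups + 1, n_query_heads)


main = active_fraction(16, 8, 1)
print(main, float(main))  # 9/16 0.5625
```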

Empirical Validation / Results

Experimental Setup

All models are trained on a fixed 30B-token sample from FineWeb-Edu at the 250M-parameter scale. A GQA baseline with 8 KV heads and 16 query heads is compared with GQE variants that add routing, with the per-head dimension held fixed across variants. Training hyperparameters are given in Table 1 (see paper). ZClip is used to mitigate loss spikes.

Accuracy

Table 2 from the paper shows the main downstream accuracy results (%); the Average column is the unweighted mean of the three benchmarks:

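Variant                               | HellaSwag | ARC-E | PIQA  | Average
--------------------------------------|-----------|-------|-------|--------
GQA baseline (all 16 heads active)    | 41.31     | 61.36 | 64.90 | 55.86
Weighted concat, no renormalized slot | 40.16     | 60.52 | 64.85 | 55.18
Hard concat only                      | 40.66     | 60.56 | 65.07 | 55.43
GQE (renorm. scoring + shared head)   | 41.01     | 62.41 | 64.69 | 56.04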

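The full GQE variant matches the all-active GQA baseline and marginally exceeds it on average (56.04 vs. 55.86) while activating 8 of 16 routed query heads (9 of 16 including the shared head); it scores slightly lower on HellaSwag and PIQA and higher on ARC-E. The two ablated variants, weighted concatenation without the renormalized slot and hard concatenation only, both fall below the baseline on average (55.18 and 55.43), which the authors take as evidence that the renormalized scoring and the shared head are necessary. The margins are small, under one point of average accuracy.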

Throughput

Figure 1 (described in the paper) shows measured prefill speedup, defined as GQA baseline latency divided by GQE latency, across sequence lengths. At 2k tokens the speedup is modest (\(\approx 1.15\times\)). From 4k tokens onward it stabilizes between \(1.67\times\) and \(1.80\times\), consistent with query-side attention work, which grows quadratically with sequence length, increasingly dominating the per-token routing overhead.
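As a rough reference point, and under the simplifying assumption (not a claim from the paper) that long-context prefill time is proportional to the number of active query-attention computations, the main setting would give

\[ \text{speedup} \approx \frac{N}{kG + 1} = \frac{16}{8 \cdot 1 + 1} = \frac{16}{9} \approx 1.78 \]

The measured \(1.67\times\) to \(1.80\times\) range sits close to this figure. The idealized count ignores everything outside query-side attention, such as the dense KV path, the routing itself, and the smaller \(W^O\), so it should be read only as an order-of-magnitude check.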

Theoretical and Practical Implications

  • Efficiency: GQE shows that self-attention can be made conditionally sparse by routing each token to only a subset of query-head experts within each GQA group, while keeping GQA's KV-cache memory savings. The speedup rises from about \(1.15\times\) at 2k tokens to a stable \(1.67\times\) to \(1.80\times\) from 4k tokens onward, which makes the method most relevant for long-context transformers.
  • Quality: At this scale, the method matches dense GQA quality on average when properly designed. The ablations indicate that sparse routing without a differentiable learning signal for the router (the renormalized weighted-sum slot) and without a stable shared head degrades accuracy, which gives concrete design guidelines for future conditional attention architectures.
  • Parameter count: GQE is not exactly parameter-matched to GQA because its output projection \(W^O\) takes fewer input slots (10 instead of 16). The authors acknowledge this caveat: the comparison controls for training budget, data, KV layout, and per-head dimension, but not for the size of the output projection.
  • Scalability: Current results are at 250M parameters and 30B tokens; scaling to larger models and longer training budgets is left for future work.
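The parameter-count caveat can be sized with simple arithmetic. With a fixed per-head dimension \(d_h\) and model width \(d\), the output projection holds (slots) \(\cdot\, d_h \cdot d\) weights, so GQE's 10 slots remove \(6\, d_h d\) weights per attention layer relative to the 16-slot baseline. The sizes below (\(d_h = 64\), \(d = 1024\)) are hypothetical placeholders, not the paper's configuration; biases and router parameters are not counted.

```python
def w_o_params(num_slots: int, d_head: int, d_model: int) -> int:
    """Weights in an output projection mapping num_slots * d_head -> d_model (no bias)."""
    return num_slots * d_head * d_model


D_HEAD, D_MODEL = 64, 1024  # hypothetical sizes, not taken from the paper
dense = w_o_params(16, D_HEAD, D_MODEL)  # GQA baseline: 16 query-head slots
gqe = w_o_params(10, D_HEAD, D_MODEL)    # GQE: 8 hard + 1 weighted-sum + 1 shared
print(dense - gqe)  # 393216 fewer W^O weights per attention layer
```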

Conclusion

GQE applies mixture-of-experts to grouped-query attention by making query-head computation conditional: within each GQA group, a router selects the top-\(k\) query-head experts per token, while all KV heads remain dense. The method matches the all-active GQA baseline in average downstream accuracy while activating half of the routed query-head experts (9 of 16 query-attention computations including the shared head), and achieves up to \(1.8\times\) prefill speedup for long contexts. Its success depends critically on a renormalized weighted-sum slot, which gives the router a differentiable signal, and on an always-on shared head for stability. Future work should validate the method at larger scales and compare it with other efficient attention architectures.
