Summary of "Mixture-of-Depths Attention"
Summary (Overview)
- Core Contribution: Introduces Mixture-of-Depths Attention (MoDA), an attention mechanism for Transformers in which each query attends to both the standard sequence-level Key-Value (KV) pairs and depth KV pairs from preceding layers at the same token position. This targets information dilution in deep LLMs, where informative features formed in shallow layers fade as they pass through deeper layers.
- Key Innovation: Unifies sequence and depth retrieval within a single softmax operation, enabling dynamic, data-dependent mixing of information across both dimensions. This provides a more expressive and efficient alternative to fixed cross-layer connections like residual or dense connections.
- Hardware-Efficient Implementation: Proposes a fused, hardware-aware kernel with a chunk-aware depth-KV layout and group-aware indexing. This implementation achieves 97.3% of FlashAttention-2's efficiency at a 64K sequence length, making MoDA practical for long-context training.
- Empirical Results: Demonstrates consistent performance improvements. On 1.5B-parameter models trained on 400B tokens, MoDA lowers average validation perplexity by 0.2 across 10 validation domains and raises average downstream accuracy by 2.11 points across 10 tasks, at a 3.7% FLOPs overhead.
- Architectural Insight: Finds that combining MoDA with post-norm yields better performance than with pre-norm, and that the primary benefit comes from incorporating depth KV from Feed-Forward Network (FFN) layers, not just attention layers.
Introduction and Theoretical Foundation
Scaling model depth is a crucial but challenging dimension for improving Large Language Models (LLMs). While deeper networks offer richer hierarchical computation, standard Transformer architectures suffer from signal degradation or information dilution: as information passes through many residual layers, salient features formed in early layers can be gradually diluted, making them harder for later layers to recover.
Existing solutions have trade-offs:
- Depth Residual (standard ResNet-style): Simple and stable for optimization but compresses all depth history into a single hidden state, leading to information loss.
- Depth Dense (DenseNet-style): Preserves all intermediate states via concatenation, which mitigates dilution but incurs parameter and compute costs that grow quadratically with depth, prohibitive at LLM scale.
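To make the quadratic growth concrete, the following sketch counts projection parameters under a simplified model that is an assumption made here, not the paper's exact accounting: in the residual case every layer holds one square projection of width `d_model`, while in the dense case layer `l` reads the concatenation of all `l` earlier states (the embedding plus `l - 1` layer outputs) and projects it back to `d_model`. The function names are hypothetical.

```python
def residual_params(num_layers: int, d_model: int) -> int:
    """One d_model x d_model projection per layer: linear in depth."""
    return num_layers * d_model * d_model


def dense_params(num_layers: int, d_model: int) -> int:
    """Layer l (1-indexed) projects a concatenation of l states, each of
    width d_model, back to d_model, so it needs l * d_model^2 weights."""
    return sum(l * d_model * d_model for l in range(1, num_layers + 1))


if __name__ == "__main__":
    for depth in (12, 24, 48):
        ratio = dense_params(depth, 1024) / residual_params(depth, 1024)
        print(f"depth={depth:3d}  dense/residual parameter ratio = {ratio:.1f}")
```

Under this toy model the ratio is (L + 1) / 2 for L layers, so the total cost of dense connections grows quadratically with depth while the residual cost grows linearly.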
The paper posits that dynamic, data-dependent mixing (the principle behind attention) is more effective for preserving historical information than fixed aggregation patterns. This motivates extending attention from the sequence dimension to the depth dimension.
MoDA is proposed as a unified mechanism that allows each layer to adaptively "read" useful states from all preceding layers via attention, combining them with the current sequence context in a single operation. It aims to retain the expressivity of dense connections while maintaining hardware efficiency.
Methodology
Mathematical Formulation
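The baseline is standard Grouped Query Attention (GQA). The input hidden states are linearly projected into queries, keys, and values, with several query heads sharing each key-value head.
The attention output for each query head is the softmax-weighted sum of the values of its shared key-value head.
A fixed mapping assigns query heads to shared key-value heads, and a causal mask prevents attention to future positions.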
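The symbols of the original formulation did not survive in this summary, so the block below restates standard GQA in notation chosen here for illustration: $T$ tokens, model width $D$, head dimension $d$, $H_q$ query heads, $H_k$ key-value heads, and heads indexed from 0. It may differ from the paper's notation.

```latex
\begin{aligned}
Q &= X W^{Q}, \quad K = X W^{K}, \quad V = X W^{V}, \qquad X \in \mathbb{R}^{T \times D},\\
W^{Q} &\in \mathbb{R}^{D \times H_q d}, \quad W^{K}, W^{V} \in \mathbb{R}^{D \times H_k d}, \quad G = H_q / H_k,\\
O_h &= \operatorname{softmax}\!\left(\frac{Q_h K_{g(h)}^{\top}}{\sqrt{d}} + M\right) V_{g(h)}, \qquad g(h) = \lfloor h / G \rfloor,\\
M_{ij} &= \begin{cases} 0 & j \le i,\\ -\infty & j > i. \end{cases}
\end{aligned}
```

Here $Q_h$ is the slice of $Q$ belonging to query head $h$, and $K_{g(h)}$, $V_{g(h)}$ are the slices of the key-value head it shares with the other $G - 1$ query heads in its group.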
MoDA Operation: At each layer, the query attends to a concatenated KV set containing:
- Sequence KV: The standard keys and values of the current layer, computed over the sequence.
- Depth KV: Historical key-value pairs from all preceding layers at the same token position.
The attention scores are computed over the concatenated dimension and normalized via a unified softmax. The "write" step appends the current layer's projected KV pair to the depth stream for use by subsequent layers. For FFN layers, a lightweight KV projection is applied to the FFN's input to generate its depth KV contribution.
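The read step described above can be sketched in NumPy for a single head. This is a minimal illustration under assumptions, not the paper's implementation: the function name `moda_attention`, the `[T, L, d]` storage of depth KV (one row of `L` slots per token), and the shared `1/sqrt(d)` scaling for both score types are choices made here.

```python
import numpy as np


def moda_attention(q, k_seq, v_seq, k_depth, v_depth):
    """Single-head MoDA read step (illustrative).

    q, k_seq, v_seq : [T, d]     current layer's queries and sequence KV
    k_depth, v_depth: [T, L, d]  depth KV written by L earlier layers,
                                 stored per token position
    Returns the output [T, d] and the attention weights [T, T + L].
    """
    T, d = q.shape
    scale = 1.0 / np.sqrt(d)

    # Sequence scores with a causal mask: query t sees keys 0..t.
    seq_scores = (q @ k_seq.T) * scale
    causal = np.tril(np.ones((T, T), dtype=bool))
    seq_scores = np.where(causal, seq_scores, -np.inf)

    # Depth scores: query t sees only the depth slots of token t.
    depth_scores = np.einsum("td,tld->tl", q, k_depth) * scale

    # One softmax over the concatenated [sequence | depth] axis.
    scores = np.concatenate([seq_scores, depth_scores], axis=1)
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)

    seq_part = weights[:, :T] @ v_seq
    depth_part = np.einsum("tl,tld->td", weights[:, T:], v_depth)
    return seq_part + depth_part, weights


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    T, L, d = 5, 3, 8
    q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
    k_depth = rng.standard_normal((T, L, d))
    v_depth = rng.standard_normal((T, L, d))
    out, w = moda_attention(q, k, v, k_depth, v_depth)
    print(out.shape)             # (5, 8)
    print(w[:, T:].sum(axis=1))  # share of attention mass each query puts on depth KV
```

Because sequence and depth scores share one normalization, the two sources compete for the same probability mass per query. The write step, in which the layer appends its own projected key-value pair along the `L` axis for later layers, is omitted here.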
Complexity Analysis
A key advantage of MoDA is its favorable computational profile compared to dense connections.
| Method | Data-Dependent? | Unified Softmax? | Parameter Complexity | Prefilling FLOPs |
|---|---|---|---|---|
| Depth Dense | ✘ | ✘ | | |
| Depth Attention | ✔ | ✘ | | |
| MoDA (Ours) | ✔ | ✔ | | |
MoDA maintains the linear-in-width FLOPs scaling of Depth Attention but is more parameter-efficient because it reuses the sequence attention's query projection, requiring only grouped depth key/value projections.
Hardware-Efficient Kernel Implementation
A naive PyTorch implementation of MoDA suffers from non-contiguous memory access. The paper introduces a fused kernel with three key optimizations (Algorithm 1):
- Flash-Compatible Depth-KV Layout: Flattens the depth cache along a single axis, enabling contiguous block reads.
- Chunk-Aware Depth-KV Layout: Groups queries into fixed-size chunks. Each chunk accesses only its local depth-KV region, which reduces the effective depth span scanned and improves depth utilization.
- Group-Aware Indexing: Exploits the GQA structure, in which adjacent query rows share the same base-time index. The same depth-KV blocks can therefore be reused within a group, further improving utilization.
The kernel fuses sequence and depth attention in one forward pass with shared online-softmax states, minimizing HBM traffic. The depth-attention mask ensures that each query attends only to the depth slots of its own token position.
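A NumPy sketch of these two ideas follows. It is not the paper's Triton kernel: the token-major flattening (the slots of token `t` occupy flat indices `t*L` to `t*L + L - 1`), the helper names, and the processing of each source as one whole block are assumptions made here for illustration. It shows that folding the sequence block and the depth block into one running online-softmax state gives the same result as a single softmax over the concatenated scores.

```python
import numpy as np


def depth_mask(num_tokens, slots_per_token):
    """Boolean [T, T*L] mask over a flattened, token-major depth cache:
    query t may read only flat slots t*L .. t*L + L - 1."""
    owner = np.arange(num_tokens * slots_per_token) // slots_per_token
    return owner[None, :] == np.arange(num_tokens)[:, None]


def fold_block(state, scores, values):
    """Online-softmax update: fold one block of (masked) scores and values
    into the running (row max, denominator, unnormalized output)."""
    row_max, denom, acc = state
    new_max = np.maximum(row_max, scores.max(axis=1))
    rescale = np.exp(row_max - new_max)        # shrink what was accumulated so far
    probs = np.exp(scores - new_max[:, None])  # masked (-inf) entries become 0
    denom = denom * rescale + probs.sum(axis=1)
    acc = acc * rescale[:, None] + probs @ values
    return new_max, denom, acc


def fused_attention(q, k_seq, v_seq, k_flat, v_flat, slots_per_token):
    """Sequence and depth attention sharing one online-softmax state."""
    T, d = q.shape
    scale = 1.0 / np.sqrt(d)
    causal = np.tril(np.ones((T, T), dtype=bool))
    seq_scores = np.where(causal, (q @ k_seq.T) * scale, -np.inf)
    dep_scores = np.where(depth_mask(T, slots_per_token), (q @ k_flat.T) * scale, -np.inf)

    state = (np.full(T, -np.inf), np.zeros(T), np.zeros((T, d)))
    state = fold_block(state, seq_scores, v_seq)   # sequence block
    state = fold_block(state, dep_scores, v_flat)  # depth block, same state
    _, denom, acc = state
    return acc / denom[:, None]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    T, L, d = 6, 4, 8
    q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
    k_flat = rng.standard_normal((T * L, d))
    v_flat = rng.standard_normal((T * L, d))
    print(fused_attention(q, k, v, k_flat, v_flat, L).shape)  # (6, 8)
```

A real kernel would tile both blocks and skip fully masked tiles instead of materializing a dense `[T, T*L]` score matrix; the dense mask here only makes the access rule explicit.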
Empirical Validation / Results
Main Results (Model Scaling)
Models were trained on 400B tokens following the OLMo2 recipe.
Table 4: Downstream Task Performance (Accuracy %)
| Model | PIQA | HellaSwag | WinoGrande | OpenBookQA | BoolQA | SciQ | ARC-E | ARC-C | COPA | MMLU | Average |
|---|---|---|---|---|---|---|---|---|---|---|---|
| OLMo2 (700M) | 73.72 | 58.77 | 55.33 | 35.60 | 56.24 | 89.50 | 66.84 | 33.44 | 77.00 | 24.69 | 57.11 |
| MoDA (700M) | 73.39 | 59.19 | 60.22 | 37.20 | 59.33 | 89.60 | 67.37 | 34.78 | 82.00 | 25.61 | 58.87 |
| OLMo2 (1.5B) | 76.55 | 65.86 | 63.22 | 38.80 | 63.61 | 90.60 | 72.98 | 42.47 | 81.00 | 27.73 | 62.28 |
| MoDA (1.5B) | 76.82 | 66.24 | 65.59 | 41.60 | 67.34 | 92.10 | 72.81 | 46.82 | 85.00 | 29.59 | 64.39 |
- MoDA provides stable gains across model scales (+1.76 avg at 700M, +2.11 avg at 1.5B).
- Improvements are broad-based across commonsense, reasoning, and knowledge tasks: MoDA improves 9 of 10 tasks at each scale, the exceptions being PIQA at 700M and ARC-E at 1.5B.
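The averages and gains quoted for Table 4 can be re-derived from the per-task scores. The short script below copies the table values verbatim; the helper name `compare` is hypothetical.

```python
TASKS = ["PIQA", "HellaSwag", "WinoGrande", "OpenBookQA", "BoolQA",
         "SciQ", "ARC-E", "ARC-C", "COPA", "MMLU"]

SCORES = {
    "OLMo2 (700M)": [73.72, 58.77, 55.33, 35.60, 56.24, 89.50, 66.84, 33.44, 77.00, 24.69],
    "MoDA (700M)":  [73.39, 59.19, 60.22, 37.20, 59.33, 89.60, 67.37, 34.78, 82.00, 25.61],
    "OLMo2 (1.5B)": [76.55, 65.86, 63.22, 38.80, 63.61, 90.60, 72.98, 42.47, 81.00, 27.73],
    "MoDA (1.5B)":  [76.82, 66.24, 65.59, 41.60, 67.34, 92.10, 72.81, 46.82, 85.00, 29.59],
}


def compare(baseline, candidate):
    """Return (average gain in accuracy points, tasks where the candidate is worse)."""
    base, cand = SCORES[baseline], SCORES[candidate]
    avg_gain = round(sum(cand) / len(cand) - sum(base) / len(base), 2)
    regressions = [t for t, b, c in zip(TASKS, base, cand) if c < b]
    return avg_gain, regressions


if __name__ == "__main__":
    for size in ("700M", "1.5B"):
        gain, worse = compare(f"OLMo2 ({size})", f"MoDA ({size})")
        print(f"{size}: average gain {gain:+.2f} points, regressions: {worse}")
```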
Table 5: Per-Domain Validation Perplexity (Lower is Better)
| Model | C4 | ICE | m2d2-s2orc | Pile | Wiki-text | Books | CC | peS2o | | Stack | Average |
|---|---|---|---|---|---|---|---|---|---|---|---|
| OLMo2 (700M) | 18.32 | 17.43 | 24.37 | 9.53 | 12.26 | 16.78 | 20.53 | 9.17 | 23.84 | 3.93 | 15.61 |
| MoDA (700M) | 18.29 | 17.24 | 23.64 | 9.48 | 12.06 | 16.58 | 20.52 | 9.14 | 23.75 | 3.90 | 15.46 |
| OLMo2 (1.5B) | 16.16 | 15.37 | 21.10 | 8.45 | 10.41 | 14.19 | 18.13 | 8.19 | 21.21 | 3.57 | 13.67 |
| MoDA (1.5B) | 15.97 | 15.08 | 20.92 | 8.33 | 10.16 | 13.95 | 17.88 | 8.09 | 20.85 | 3.52 | 13.47 |
- MoDA consistently lowers perplexity across all ten validation domains at both model sizes.
Ablation Studies
Table 3: MoDA Variants (700M Model)
The study compares different design choices for incorporating depth information:
- Baseline (Vanilla Attention): Only sequence KV.
- + Depth KV: Reuses preceding attention layer's KV as depth KV. Provides significant gains (+1.17 avg downstream) with only 0.12% extra FLOPs.
- + Extra FFN KV Proj.: Adds lightweight KV projections for FFN layers. Yields the best accuracy-efficiency trade-off, improving over the baseline by +1.76 avg downstream.
- + Extra Attn KV Proj.: Adds separate depth KV projections for attention layers. Offers only marginal gains (+0.10) for substantial parameter/FLOPs overhead, indicating saturation.
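Key Finding: Depth information from FFN layers matters. Adding FFN KV projections raises the average downstream gain from +1.17 to +1.76, whereas extra attention KV projections add little. The default MoDA variant uses Sequence KV + Depth KV + Extra FFN KV Projection.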
Kernel Efficiency
Table 2: Kernel Efficiency vs. FlashAttention-2
The hardware-aware MoDA kernel achieves high efficiency, especially for long sequences and large GQA group sizes.
| Scaling Dimension | FA2-triton (ms) | MoDA-triton (ms) | Depth Util. | Extra Time |
|---|---|---|---|---|
| | 7.970 | 10.750 | 12.50% | +25.86% |
| | 1831.668 | 1883.026 | 12.50% | +2.73% |
| | 28.982 | 39.741 | 3.12% | +27.07% |
| | 467.107 | 480.767 | 50.00% | +2.84% |
- The relative overhead of MoDA decreases as sequence length increases, dropping to ~2.7% at a 64K sequence length.
- Larger group sizes improve depth utilization and reduce overhead.
- Increasing model depth linearly increases MoDA's runtime, as expected.
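The Extra Time column appears to be measured relative to MoDA's own runtime, that is (MoDA - FA2) / MoDA, rather than relative to FlashAttention-2. This reading is inferred from the numbers in Table 2, not stated in the source. The check below reproduces the column and the 97.3% relative efficiency quoted in the overview; the function names are hypothetical.

```python
# (FA2-triton ms, MoDA-triton ms), in the row order of Table 2.
ROWS = [
    (7.970, 10.750),
    (1831.668, 1883.026),
    (28.982, 39.741),
    (467.107, 480.767),
]


def extra_time_pct(fa2_ms, moda_ms):
    """Share of MoDA's runtime that exceeds the FA2 baseline, in percent."""
    return 100.0 * (moda_ms - fa2_ms) / moda_ms


def relative_efficiency_pct(fa2_ms, moda_ms):
    """FA2 runtime as a percentage of MoDA runtime (100% means no overhead)."""
    return 100.0 * fa2_ms / moda_ms


if __name__ == "__main__":
    for fa2, moda in ROWS:
        print(f"FA2 {fa2:9.3f} ms  MoDA {moda:9.3f} ms  "
              f"extra {extra_time_pct(fa2, moda):5.2f}%  "
              f"efficiency {relative_efficiency_pct(fa2, moda):5.1f}%")
```

Measured against FlashAttention-2 instead, the first row would read about +34.9% rather than +25.86%, so the choice of denominator matters most for the short-sequence rows.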
Table 7: Kernel Optimization Ablation
| Optimization Steps | Time (ms) | Speedup vs. Naive |
|---|---|---|
| (1) Naive PyTorch | 2128.900 | 1x |
| (2) + Flash-Compatible | 13.102 | ~162x |
| (3) + Chunk-Aware | 6.286 | ~339x |
| (4) + Group-Aware | 1.460 | ~1458x |
The combined optimizations yield a ~1458x speedup over the naive implementation.
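The speedup column of Table 7 follows directly from the timings. The script below recomputes it and also derives the incremental speedup contributed by each step, which the table does not list; the helper name `speedups` is hypothetical.

```python
TIMES_MS = [
    ("Naive PyTorch", 2128.900),
    ("+ Flash-Compatible", 13.102),
    ("+ Chunk-Aware", 6.286),
    ("+ Group-Aware", 1.460),
]


def speedups(times):
    """Return (step, speedup vs. the first row, speedup vs. the previous row)."""
    naive = times[0][1]
    previous = naive
    rows = []
    for name, ms in times:
        rows.append((name, naive / ms, previous / ms))
        previous = ms
    return rows


if __name__ == "__main__":
    for name, cumulative, step in speedups(TIMES_MS):
        print(f"{name:20s} cumulative {cumulative:8.1f}x   step {step:6.1f}x")
```

By this arithmetic the flash-compatible layout accounts for most of the gain (about 162x), with the chunk-aware and group-aware steps adding roughly 2.1x and 4.3x on top.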
Additional Analyses
- Layer-Number Analysis: MoDA improves validation loss for both shallow (24-layer) and deep (48-layer) models. Gains are more pronounced with post-norm than pre-norm in deeper models.
- Attention Visualization: Heatmaps (Figure 5) show that models allocate non-trivial and persistent attention mass to the depth-KV block, confirming the mechanism's active use. The attention pattern also appears to differ from typical "attention sink" behavior, distributing probability more broadly.
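For reference, the two normalization placements compared in the layer-number analysis differ only in where the normalization sits relative to the residual addition. The sketch below uses a parameter-free layer norm and a hypothetical stand-in `sublayer`; it illustrates the standard definitions and is not the paper's code.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Parameter-free layer norm over the last axis."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def pre_norm_block(x, sublayer):
    """Pre-norm: normalize the sublayer input; the residual stream stays unnormalized."""
    return x + sublayer(layer_norm(x))


def post_norm_block(x, sublayer):
    """Post-norm: normalize after the residual addition, so each block output is normalized."""
    return layer_norm(x + sublayer(x))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    weight = rng.standard_normal((16, 16)) / 4.0

    def sublayer(h):
        return h @ weight  # stand-in for an attention or FFN sublayer

    x_pre = x_post = rng.standard_normal((3, 16))
    for _ in range(8):
        x_pre = pre_norm_block(x_pre, sublayer)
        x_post = post_norm_block(x_post, sublayer)
    print("pre-norm  mean row variance:", x_pre.var(axis=-1).mean())
    print("post-norm mean row variance:", x_post.var(axis=-1).mean())
```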
Theoretical and Practical Implications
- Theoretical: MoDA provides a principled, data-dependent mechanism for mitigating information dilution in deep networks, bridging the gap between the stability of residual connections and the expressivity of dense connections. The unified softmax over sequence and depth creates a coherent representation space for cross-layer information retrieval.
- Practical: The method offers a cost-effective way to improve LLM performance. With only a ~3.7% FLOPs overhead, it delivers consistent gains in perplexity and downstream task accuracy. The provided hardware-efficient kernel makes it feasible for integration into large-scale training pipelines.
- Architectural Guidance: The finding that post-norm works better with MoDA than pre-norm offers a new perspective on normalization choices in depth-scaled Transformers. The result that FFN depth KV is crucial highlights the importance of non-attention features in the depth stream.
Conclusion
Mixture-of-Depths Attention (MoDA) is an effective and efficient architectural innovation for scaling Transformer depth. By enabling dynamic, attentive retrieval of information from all preceding layers, it alleviates the information dilution problem in deep LLMs. The proposed hardware-aware implementation ensures the approach is practical for long-context training. Empirical results across multiple scales and tasks confirm its robustness and performance benefits. MoDA establishes depth-aware attention as a promising primitive for future LLM architecture design.
Future Directions include scaling MoDA for industrial training via advanced CUDA engineering and investigating bounded depth-KV slot caching to mitigate memory bottlenecks when scaling to extreme depths. The mechanism is architecture-agnostic and could benefit multimodal and other Transformer-based models.