# Mixture-of-Depths Attention

> MoDA unifies sequence and depth retrieval in a single softmax, boosting LLM performance with minimal FLOPs overhead by dynamically mixing information across layers.

- **Source:** [arXiv](https://arxiv.org/abs/2603.15619)
- **Published:** 2026-03-18
- **Permalink:** https://picx.dev/p/7k53tv
- **Whiteboard:** https://picx.dev/p/7k53tv/image

## Summary

## Summary (Overview)
*   **Core Contribution**: Introduces **Mixture-of-Depths Attention (MoDA)**, a novel attention mechanism for Transformers that allows each query to attend to both standard sequence-level Key-Value (KV) pairs and **depth KV pairs** from preceding layers at the same token position. This addresses the **information dilution problem** in deep LLMs where informative features from shallow layers get diluted in deeper layers.
*   **Key Innovation**: Unifies sequence and depth retrieval within a **single softmax operation**, enabling dynamic, data-dependent mixing of information across both dimensions. This provides a more expressive and efficient alternative to fixed cross-layer connections like residual or dense connections.
*   **Hardware-Efficient Implementation**: Proposes a **fused, hardware-aware kernel** with a chunk-aware depth-KV layout and group-aware indexing. This implementation achieves **97.3% of FlashAttention-2's efficiency** at a 64K sequence length, making MoDA practical for long-context training.
*   **Empirical Results**: Demonstrates consistent performance improvements. On 1.5B-parameter models trained on 400B tokens, MoDA lowers average validation perplexity by **0.2** across 10 validation domains and raises average downstream accuracy by **2.11 percentage points** across 10 benchmarks, with a negligible **3.7% FLOPs overhead**.
*   **Architectural Insight**: Finds that combining MoDA with **post-norm** yields better performance than with pre-norm, and that the primary benefit comes from incorporating depth KV from **Feed-Forward Network (FFN) layers**, not just attention layers.

## Introduction and Theoretical Foundation
Scaling model depth is a crucial but challenging dimension for improving Large Language Models (LLMs). While deeper networks offer richer hierarchical computation, standard Transformer architectures suffer from **signal degradation** or **information dilution**: as information passes through many residual layers, salient features formed in early layers can be gradually diluted, making them harder for later layers to recover.

Existing solutions have trade-offs:
*   **Depth Residual** (standard ResNet-style): Simple and stable for optimization but compresses all depth history into a single hidden state, leading to information loss.
*   **Depth Dense** (DenseNet-style): Preserves all intermediate states via concatenation, mitigating dilution but incurring parameter and compute costs that grow quadratically with depth ($O(L^2D^2)$ parameters) and become prohibitive at LLM scale.

The paper posits that **dynamic, data-dependent mixing** (the principle behind attention) is more effective for preserving historical information than fixed aggregation patterns. This motivates extending attention from the sequence dimension to the **depth dimension**.

**MoDA** is proposed as a unified mechanism that allows each layer to adaptively "read" useful states from all preceding layers via attention, combining them with the current sequence context in a single operation. It aims to retain the expressivity of dense connections while maintaining hardware efficiency.

## Methodology

### Mathematical Formulation
The standard Grouped Query Attention (GQA) is defined as:
Given input $X \in \mathbb{R}^{T \times D}$, it is projected into Queries ($Q$), Keys ($K$), and Values ($V$):
$$
Q = XW_Q, \quad K = XW_K, \quad V = XW_V
$$
where $W_Q \in \mathbb{R}^{D \times (H_q d)}$, $W_K, W_V \in \mathbb{R}^{D \times (H_k d)}$, $H_q = G H_k$, and $D = H_q d$.
The attention output concatenates the per-head results:
$$
\text{Attention}(Q, K, V) = \text{Concat}_{h=1}^{H_q} \left( \text{softmax}\left( \frac{Q_h K_{\phi(h)}^T}{\sqrt{d}} + M \right) V_{\phi(h)} \right)
$$
where $\phi(h) = \lceil h/G \rceil$ maps query heads to shared key-value heads, and $M$ is the causal mask.
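To make the head mapping concrete, here is a minimal pure-Python sketch of the causal GQA equation above. It assumes 1-indexed heads (so $\phi(h) = \lceil h/G \rceil$ holds as written), uses nested lists instead of tensors, and omits the projections $W_Q, W_K, W_V$; the helper names `phi` and `gqa_attention` are illustrative, not from the paper.

```python
import math


def phi(h: int, G: int) -> int:
    """Map 1-indexed query head h to its 1-indexed shared KV head: ceil(h / G)."""
    return math.ceil(h / G)


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def gqa_attention(Q, K, V, G):
    """Causal GQA over already-projected heads.

    Q: H_q heads, each a T x d list of rows.
    K, V: H_k = H_q / G heads, each T x d.
    Returns T rows, each the concatenation of the H_q head outputs.
    """
    H_q, T, d = len(Q), len(Q[0]), len(Q[0][0])
    out = [[] for _ in range(T)]
    for h in range(1, H_q + 1):
        k, v = K[phi(h, G) - 1], V[phi(h, G) - 1]
        for t in range(T):
            # Causal mask M: position t only scores keys 0..t.
            p = softmax([dot(Q[h - 1][t], k[s]) / math.sqrt(d) for s in range(t + 1)])
            out[t].extend(sum(p[s] * v[s][j] for s in range(t + 1)) for j in range(d))
    return out


print([phi(h, 4) for h in range(1, 9)])  # [1, 1, 1, 1, 2, 2, 2, 2]
```

With $G = 4$, query heads 1 to 4 share KV head 1 and heads 5 to 8 share KV head 2, which is where GQA's K/V memory saving comes from.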

**MoDA Operation**: At layer $l$, the query $Q_{l-1}$ attends to a concatenated KV set containing:
1.  **Sequence KV**: The standard keys and values from the current sequence: $K_l^{seq}, V_l^{seq} \in \mathbb{R}^{T \times (H_k d)}$.
2.  **Depth KV**: Historical key-value pairs from all preceding layers $\{K_i, V_i\}_{i=0}^{l-1}$ for the same token position, where each $K_i, V_i \in \mathbb{R}^{T \times (H_k d)}$.

The attention scores are computed over the concatenated dimension and normalized via a **unified softmax**. The "write" step appends the current layer's projected KV pair $(K_l, V_l)$ to the depth stream for use by subsequent layers. For FFN layers, a lightweight KV projection is applied to the FFN's input to generate its depth KV contribution.
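The read and write steps can be sketched for a single head and a single query position. This is a minimal illustration under stated assumptions, not the paper's kernel: one head (no GQA grouping), the same $1/\sqrt{d}$ scaling applied to sequence and depth logits, and a per-token Python list as the depth cache; `moda_read` and `moda_write` are hypothetical names. What it shows is that sequence slots and depth slots compete inside one softmax, so their weights sum to 1 jointly rather than separately.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def moda_read(q, seq_K, seq_V, depth_K, depth_V, t):
    """Unified-softmax read for one query (one head) at token position t.

    seq_K, seq_V: T x d keys/values of the current layer (sequence KV).
    depth_K[t], depth_V[t]: keys/values written by preceding layers at the
    SAME position t (depth KV). Returns (output vector, attention weights).
    """
    d = len(q)
    keys = seq_K[: t + 1] + depth_K[t]  # causal sequence slots, then depth slots
    vals = seq_V[: t + 1] + depth_V[t]
    # ONE softmax over the concatenated sequence + depth slots (assumed shared scaling).
    p = softmax([dot(q, k) / math.sqrt(d) for k in keys])
    out = [sum(w * v[j] for w, v in zip(p, vals)) for j in range(d)]
    return out, p


def moda_write(depth_K, depth_V, K_l, V_l):
    """Append this layer's per-token (K_l, V_l) rows to the depth stream."""
    for t in range(len(K_l)):
        depth_K[t].append(K_l[t])
        depth_V[t].append(V_l[t])


# Usage: an earlier layer writes its KV, then a later layer reads sequence + depth slots.
T = 3
depth_K, depth_V = [[] for _ in range(T)], [[] for _ in range(T)]
seq_K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
seq_V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
moda_write(depth_K, depth_V, [[0.0, 0.0] for _ in range(T)], [[5.0, 5.0] for _ in range(T)])
out, p = moda_read([1.0, 0.0], seq_K, seq_V, depth_K, depth_V, t=2)
print(len(p), round(sum(p), 6))  # 4 slots (3 sequence + 1 depth), weights sum to 1.0
```

With an empty depth cache the read reduces to ordinary causal attention, which is one way to see that MoDA strictly extends the standard layer.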

### Complexity Analysis
A key advantage of MoDA is its favorable computational profile compared to dense connections.

| Method | Data-Dependent? | Unified Softmax? | Parameter Complexity | Prefilling FLOPs |
| :--- | :--- | :--- | :--- | :--- |
| **Depth Dense** | ✘ | ✘ | $O(L^2 D^2)$ | $O(T L^2 D^2)$ |
| **Depth Attention** | ✔ | ✘ | $O(L D^2)$ | $O(T L^2 D)$ |
| **MoDA (Ours)** | ✔ | ✔ | $O(L D^2 / G)$ | $O(T L^2 D)$ |

MoDA maintains the linear-in-width FLOPs scaling of Depth Attention ($O(T L^2 D)$) but is more parameter-efficient ($O(L D^2 / G)$) because it reuses the sequence attention's query projection, requiring only grouped depth key/value projections.
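A back-of-the-envelope derivation recovers MoDA's two table entries from the GQA definitions above. It assumes one depth $W_K, W_V$ pair of GQA width per layer and $O(D)$ work to score one depth slot across all heads; it is a consistency check, not the paper's own derivation.

$$
\begin{aligned}
H_k d &= \frac{H_q d}{G} = \frac{D}{G}
\;\Rightarrow\;
|W_K| + |W_V| = 2 \cdot D \cdot \frac{D}{G} = \frac{2D^2}{G} \text{ per layer}, \\
\sum_{l=1}^{L} \frac{2D^2}{G} &= \frac{2LD^2}{G} = O\!\left(\frac{LD^2}{G}\right), \\
\text{depth-attention FLOPs} &\approx \sum_{l=1}^{L} T \cdot l \cdot O(D) = O\!\left(T \cdot \frac{L(L+1)}{2} \cdot D\right) = O(TL^2D).
\end{aligned}
$$

The quadratic-in-$L$ FLOPs come from layer $l$ scoring $l$ depth slots per token, while parameters stay linear in $L$ because each layer adds only its own projection.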

### Hardware-Efficient Kernel Implementation
A naive PyTorch implementation of MoDA suffers from non-contiguous memory access. The paper introduces a fused kernel with three key optimizations (Algorithm 1):
1.  **Flash-Compatible Depth-KV Layout**: Flattens the depth cache along a single axis of length $T \times L$, enabling contiguous block reads.
2.  **Chunk-Aware Depth-KV Layout**: Groups queries into chunks of size $C$. Each chunk only accesses a local depth-KV region of size $C \times L$, reducing the effective depth span scanned and improving **depth utilization** from $1/T$ to $1/C$.
3.  **Group-Aware Indexing**: Leverages the GQA structure where $G$ adjacent query rows share the same base-time index $t_{base}(i_q) = \lfloor i_q / G \rfloor$. This allows reusing the same depth KV blocks within a group, further improving utilization to $G/C$.
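The utilization figures follow from which slice of the flattened $T \times L$ cache a chunk must load. A minimal sketch, assuming depth slot $j_d = t \cdot L + \ell$ for token $t$ and layer $\ell$, query row $i_q$ belonging to token $\lfloor i_q/G \rfloor$ (group-aware packing; $G = 1$ means one row per token), and chunks aligned to multiples of $C$ with $C$ divisible by $G$; `depth_region` and `depth_utilization` are illustrative helpers, not the paper's kernel code.

```python
def depth_region(chunk_start: int, C: int, G: int, L: int) -> tuple:
    """Half-open slot range [lo, hi) of the flattened depth cache that a chunk of
    C query rows starting at row chunk_start must scan.

    Assumes query row i_q maps to token i_q // G and depth slot j_d = t * L + layer.
    """
    t_first = chunk_start // G
    t_last = (chunk_start + C - 1) // G
    return t_first * L, (t_last + 1) * L


def depth_utilization(C: int, G: int, L: int) -> float:
    """Fraction of scanned slots one query row actually uses (its own L slots)."""
    lo, hi = depth_region(0, C, G, L)
    return L / (hi - lo)


T, C, L = 4096, 64, 64
naive = L / (T * L)                      # scan the whole T x L cache: 1/T
chunk_only = depth_utilization(C, 1, L)  # chunk-aware, one row per token: 1/C
grouped = depth_utilization(C, 8, L)     # + group-aware with G = 8: G/C
print(naive, chunk_only, grouped)
```

With $C = 64$ (an assumption here), $G = 8$ gives $8/64 = 12.5\%$, the same value reported in the kernel-efficiency table later in this summary.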

The kernel fuses sequence and depth attention in one forward pass with shared online-softmax states, minimizing HBM traffic. The mask for depth attention ensures a query only attends to its corresponding depth slots:
$$
\text{mask}(i_q, j_d) = \mathbb{1}\left[ \left\lfloor \frac{i_q}{G} \right\rfloor = \left\lfloor \frac{j_d}{L} \right\rfloor \right]
$$
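The mask can be written directly as a predicate on flattened indices. This small sketch, assuming the same flattened layouts (query row $i_q$ for token $\lfloor i_q/G \rfloor$, depth slot $j_d$ for token $\lfloor j_d/L \rfloor$), materializes it for toy sizes and checks that every query row sees exactly its own token's $L$ depth slots. A real kernel would presumably evaluate the predicate per tile and set masked logits to $-\infty$ before the shared softmax rather than build the matrix; `depth_mask` and `mask_matrix` are hypothetical names.

```python
def depth_mask(i_q: int, j_d: int, G: int, L: int) -> bool:
    """mask(i_q, j_d) = 1[floor(i_q / G) == floor(j_d / L)]."""
    return i_q // G == j_d // L


def mask_matrix(T: int, G: int, L: int) -> list:
    """Full (T*G) x (T*L) boolean depth mask; only sensible for toy sizes."""
    return [[depth_mask(i, j, G, L) for j in range(T * L)] for i in range(T * G)]


M = mask_matrix(T=3, G=2, L=4)
for i, row in enumerate(M):
    allowed = [j for j, ok in enumerate(row) if ok]
    print(f"query row {i} (token {i // 2}): depth slots {allowed}")
```

Each group of $G$ rows shares one contiguous run of $L$ slots, which is exactly the reuse the group-aware indexing exploits.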

## Empirical Validation / Results

### Main Results (Model Scaling)
Models were trained on 400B tokens following the OLMo2 recipe.

**Table 4: Downstream Task Performance (Accuracy %)**
| Model | PIQA | HellaSwag | WinoGrande | OpenBookQA | BoolQA | SciQ | ARC-E | ARC-C | COPA | MMLU | **Average** |
| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |
| **OLMo2 (700M)** | 73.72 | 58.77 | 55.33 | 35.60 | 56.24 | 89.50 | 66.84 | 33.44 | 77.00 | 24.69 | **57.11** |
| **MoDA (700M)** | 73.39 | 59.19 | 60.22 | 37.20 | 59.33 | 89.60 | 67.37 | 34.78 | 82.00 | 25.61 | **58.87** |
| **OLMo2 (1.5B)** | 76.55 | 65.86 | 63.22 | 38.80 | 63.61 | 90.60 | 72.98 | 42.47 | 81.00 | 27.73 | **62.28** |
| **MoDA (1.5B)** | 76.82 | 66.24 | 65.59 | 41.60 | 67.34 | 92.10 | 72.81 | 46.82 | 85.00 | 29.59 | **64.39** |

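*   MoDA provides **stable gains** across model scales (+1.76 average points at 700M, +2.11 at 1.5B).
*   Improvements are **broad-based**, spanning commonsense, reasoning, and knowledge tasks: MoDA outperforms the baseline on 9 of 10 tasks at each scale, the exceptions being PIQA at 700M and ARC-E at 1.5B.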
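The averages and per-task comparisons can be recomputed from Table 4. The snippet simply copies the reported accuracies; `summarize` is an illustrative helper.

```python
TASKS = ["PIQA", "HellaSwag", "WinoGrande", "OpenBookQA", "BoolQA",
         "SciQ", "ARC-E", "ARC-C", "COPA", "MMLU"]

# Accuracies (%) copied from Table 4.
SCORES = {
    "OLMo2 (700M)": [73.72, 58.77, 55.33, 35.60, 56.24, 89.50, 66.84, 33.44, 77.00, 24.69],
    "MoDA (700M)":  [73.39, 59.19, 60.22, 37.20, 59.33, 89.60, 67.37, 34.78, 82.00, 25.61],
    "OLMo2 (1.5B)": [76.55, 65.86, 63.22, 38.80, 63.61, 90.60, 72.98, 42.47, 81.00, 27.73],
    "MoDA (1.5B)":  [76.82, 66.24, 65.59, 41.60, 67.34, 92.10, 72.81, 46.82, 85.00, 29.59],
}


def average(xs):
    return round(sum(xs) / len(xs), 2)


def summarize(size):
    """Return (baseline avg, MoDA avg, gain in points, tasks where MoDA is lower)."""
    base, moda = SCORES[f"OLMo2 ({size})"], SCORES[f"MoDA ({size})"]
    gain = round(average(moda) - average(base), 2)
    regressions = [t for t, b, m in zip(TASKS, base, moda) if m < b]
    return average(base), average(moda), gain, regressions


for size in ("700M", "1.5B"):
    print(size, summarize(size))
```

The gains are absolute percentage points of average accuracy, not relative percentages.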

**Table 5: Per-Domain Validation Perplexity (Lower is Better)**
| Model | C4 | ICE | m2d2-s2orc | Pile | Wiki-text | Books | CC | peS2o | Reddit | Stack | **Average** |
| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |
| **OLMo2 (700M)** | 18.32 | 17.43 | 24.37 | 9.53 | 12.26 | 16.78 | 20.53 | 9.17 | 23.84 | 3.93 | **15.61** |
| **MoDA (700M)** | 18.29 | 17.24 | 23.64 | 9.48 | 12.06 | 16.58 | 20.52 | 9.14 | 23.75 | 3.90 | **15.46** |
| **OLMo2 (1.5B)** | 16.16 | 15.37 | 21.10 | 8.45 | 10.41 | 14.19 | 18.13 | 8.19 | 21.21 | 3.57 | **13.67** |
| **MoDA (1.5B)** | 15.97 | 15.08 | 20.92 | 8.33 | 10.16 | 13.95 | 17.88 | 8.09 | 20.85 | 3.52 | **13.47** |

*   MoDA consistently **lowers perplexity** across all ten validation domains at both model sizes.
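The same check for Table 5 confirms the claim domain by domain. The snippet copies the reported perplexities; `per_domain_drop` is an illustrative helper. Recomputing averages from these rounded per-domain values can differ from the reported averages by about 0.01 (for example 15.616 versus 15.61 for OLMo2 700M), presumably because the paper averages unrounded numbers.

```python
DOMAINS = ["C4", "ICE", "m2d2-s2orc", "Pile", "Wiki-text", "Books", "CC",
           "peS2o", "Reddit", "Stack"]

# (OLMo2, MoDA) per-domain perplexities copied from Table 5.
PPL = {
    "700M": ([18.32, 17.43, 24.37, 9.53, 12.26, 16.78, 20.53, 9.17, 23.84, 3.93],
             [18.29, 17.24, 23.64, 9.48, 12.06, 16.58, 20.52, 9.14, 23.75, 3.90]),
    "1.5B": ([16.16, 15.37, 21.10, 8.45, 10.41, 14.19, 18.13, 8.19, 21.21, 3.57],
             [15.97, 15.08, 20.92, 8.33, 10.16, 13.95, 17.88, 8.09, 20.85, 3.52]),
}


def per_domain_drop(size):
    """Perplexity reduction (OLMo2 minus MoDA) per domain; positive means MoDA is better."""
    base, moda = PPL[size]
    return {dom: round(b - m, 2) for dom, b, m in zip(DOMAINS, base, moda)}


for size in PPL:
    drops = per_domain_drop(size)
    print(size, "all lower:", all(v > 0 for v in drops.values()),
          "largest drop:", max(drops, key=drops.get))
```

The largest single-domain reduction is on m2d2-s2orc at 700M (0.73) and on Reddit at 1.5B (0.36); the smallest is CC at 700M (0.01).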

### Ablation Studies
**Table 3: MoDA Variants (700M Model)**
The study compares different design choices for incorporating depth information:
1.  **Baseline (Vanilla Attention)**: Only sequence KV.
2.  **+ Depth KV**: Reuses each preceding attention layer's KV as depth KV. Provides significant gains (+1.17 avg downstream) with only 0.12% extra FLOPs.
3.  **+ Extra FFN KV Proj.**: Adds lightweight KV projections for FFN layers. Yields the best accuracy-efficiency trade-off, improving over the baseline by +1.76 avg downstream.
4.  **+ Extra Attn KV Proj.**: Adds separate depth KV projections for attention layers. Offers only a marginal further gain (+0.10) at a substantial parameter and FLOPs cost, indicating saturation.

**Key Finding**: The primary benefit comes from incorporating **FFN layers' depth information**. The default MoDA variant uses **Sequence KV + Depth KV + Extra FFN KV Projection**.

### Kernel Efficiency
**Table 2: Kernel Efficiency vs. FlashAttention-2**
The hardware-aware MoDA kernel achieves high efficiency, especially for long sequences and large GQA group sizes $G$.

| Scaling Dimension | FA2-triton (ms) | MoDA-triton (ms) | Depth Util. ($\eta_{depth}$) | Extra Time ((MoDA - FA2) / MoDA) |
| :--- | :--- | :--- | :--- | :--- |
| $T=4096, G=8, L=64$ | 7.970 | 10.750 | 12.50% | +25.86% |
| $T=65536, G=8, L=64$ | 1831.668 | 1883.026 | 12.50% | **+2.73%** |
| $T=16384, G=2, L=64$ | 28.982 | 39.741 | 3.12% | +27.07% |
| $T=16384, G=32, L=64$ | 467.107 | 480.767 | 50.00% | **+2.84%** |

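*   The **relative overhead of MoDA decreases as sequence length $T$ increases**: the extra time drops to **~2.7%** of MoDA's runtime at $T=65536$, i.e., MoDA runs at about 97.3% of FA2's speed.
*   Larger group sizes $G$ improve depth utilization ($\eta_{depth}=G/C$) and reduce overhead; the utilization column is consistent with a chunk size of $C = 64$.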
*   Increasing model depth $L$ linearly increases MoDA's runtime, as expected.
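The "Extra Time" column can be reproduced from the two runtime columns, which also shows how it relates to the 97.3% efficiency figure. The snippet copies Table 2 and computes three views of the same gap; the observation that the column equals $(\text{MoDA} - \text{FA2})/\text{MoDA}$ comes from this arithmetic, not from a definition stated in the summary. `overheads` is an illustrative helper.

```python
# (T, G, L) -> (FA2-triton ms, MoDA-triton ms), copied from Table 2.
ROWS = {
    (4096, 8, 64): (7.970, 10.750),
    (65536, 8, 64): (1831.668, 1883.026),
    (16384, 2, 64): (28.982, 39.741),
    (16384, 32, 64): (467.107, 480.767),
}


def overheads(fa2_ms, moda_ms):
    """Return (extra time as % of MoDA runtime, slowdown as % of FA2 runtime,
    MoDA speed as % of FA2 speed), each rounded to two decimals."""
    extra = moda_ms - fa2_ms
    share_of_moda = 100 * extra / moda_ms  # reproduces the "Extra Time" column
    slowdown_vs_fa2 = 100 * extra / fa2_ms
    relative_speed = 100 * fa2_ms / moda_ms
    return tuple(round(x, 2) for x in (share_of_moda, slowdown_vs_fa2, relative_speed))


for (T, G, L), (fa2, moda) in ROWS.items():
    print(f"T={T}, G={G}, L={L}:", overheads(fa2, moda))
```

At $T = 65536$ the three views give 2.73%, 2.80%, and 97.27%: the last is the 97.3% figure quoted in the overview, while measured against FA2's own runtime the slowdown is about 2.8%.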

**Table 7: Kernel Optimization Ablation**
| Optimization Steps | Time (ms) | Speedup vs. Naive |
| :--- | :--- | :--- |
| (1) Naive PyTorch | 2128.900 | 1x |
| (2) + Flash-Compatible | 13.102 | ~162x |
| (3) + Chunk-Aware | 6.286 | ~339x |
| (4) + Group-Aware | **1.460** | **~1458x** |

The combined optimizations yield a **~1458x speedup** over the naive implementation.
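The cumulative and step-by-step speedups follow directly from Table 7. The snippet copies the reported times; `speedups` is an illustrative helper.

```python
# (optimization step, kernel time in ms), copied from Table 7.
TIMES_MS = [
    ("(1) Naive PyTorch", 2128.900),
    ("(2) + Flash-Compatible", 13.102),
    ("(3) + Chunk-Aware", 6.286),
    ("(4) + Group-Aware", 1.460),
]


def speedups(times):
    """Return (step, cumulative speedup vs. naive, speedup vs. previous step)."""
    naive_ms = times[0][1]
    rows = []
    for i, (name, ms) in enumerate(times):
        step = times[i - 1][1] / ms if i > 0 else 1.0
        rows.append((name, round(naive_ms / ms), round(step, 2)))
    return rows


for row in speedups(TIMES_MS):
    print(row)
```

Read step by step, the flash-compatible layout removes most of the naive cost (about 162x), and the chunk-aware and group-aware steps then add roughly 2.1x and 4.3x on top.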

### Additional Analyses
*   **Layer-Number Analysis**: MoDA improves validation loss for both shallow (24-layer) and deep (48-layer) models. Gains are more pronounced with **post-norm** than pre-norm in deeper models.
*   **Attention Visualization**: Heatmaps (Figure 5) show that models allocate **non-trivial and persistent attention mass to the depth-KV block**, confirming the mechanism's active use. The attention pattern also appears to differ from typical "attention sink" behavior, distributing probability more broadly.

## Theoretical and Practical Implications
*   **Theoretical**: MoDA provides a principled, **data-dependent** mechanism for mitigating information dilution in deep networks, bridging the gap between the stability of residual connections and the expressivity of dense connections. The unified softmax over sequence and depth creates a coherent representation space for cross-layer information retrieval.
*   **Practical**: The method offers a **cost-effective way to improve LLM performance**. With only a **~3.7% FLOPs overhead**, it delivers consistent gains in perplexity and downstream task accuracy. The provided **hardware-efficient kernel** makes it feasible for integration into large-scale training pipelines.
*   **Architectural Guidance**: The finding that **post-norm works better with MoDA** than pre-norm offers a new perspective on normalization choices in depth-scaled Transformers. The result that **FFN depth KV is crucial** highlights the importance of non-attention features in the depth stream.

## Conclusion
Mixture-of-Depths Attention (MoDA) is an effective and efficient architectural innovation for scaling Transformer depth. By enabling dynamic, attentive retrieval of information from all preceding layers, it alleviates the information dilution problem in deep LLMs. The proposed hardware-aware implementation ensures the approach is practical for long-context training. Empirical results across multiple scales and tasks confirm its robustness and performance benefits. MoDA establishes **depth-aware attention as a promising primitive for future LLM architecture design**.

**Future Directions** include scaling MoDA for industrial training via advanced CUDA engineering and investigating **bounded depth-KV slot caching** to mitigate memory bottlenecks when scaling to extreme depths. The mechanism is architecture-agnostic and could benefit multimodal and other Transformer-based models.

---

_Markdown view of https://picx.dev/p/7k53tv, served by PicX — AI-generated visual whiteboard summaries of research papers._
