Mean Mode Screaming: Mean–Variance Split Residuals for 1000-Layer Diffusion Transformers

Summary (Overview)

  • Identifies Mean Mode Screaming (MMS): A structural failure mode in ultra-deep Diffusion Transformers where networks enter a silent, mean-dominated collapse state, homogenizing token representations and suppressing centered variation.
  • Mechanistic Analysis: Shows MMS is triggered by an $O(T)$ mean-coherent gradient component in residual writers, compounded by Q/K gradient suppression via the Softmax Jacobian's null space once values homogenize.
  • Proposes MV-Split Residuals: A new residual design that decouples mean and centered gradient paths, combining a separately gained centered residual update with a leaky trunk-mean replacement to prevent collapse.
  • Empirical Validation: MV-Split prevents divergence in 400-layer DiTs, converges faster than isotropic gating methods like LayerScale, and enables stable training of a 1000-layer DiT as a scale-validation run.

Introduction and Theoretical Foundation

Scaling Diffusion Transformers (DiTs) to hundreds of layers introduces reliability issues beyond standard gradient heuristics. Networks can abruptly diverge into a mean-dominated collapse state, where token representations homogenize and centered variation is suppressed. This paper isolates the trigger event as Mean Mode Screaming (MMS), a mean-coherent backward shock on residual writers that opens deep residual branches and drives the network into collapse.

The failure exploits a geometric asymmetry in token space. Row-stochastic attention preserves pure-mean states but can become contractive on the centered subspace. The network relies on residual branches to replenish centered variation; if these updates become mean-dominated, collapse occurs.

Existing stabilizers (ReZero, LayerScale) apply isotropic token-space gating, damping both mean and centered components together, which slows convergence. This motivates MV-Split Residuals, which target the unstable mean mode selectively.

Methodology

Backbone: A stripped-down single-stream DiT with a Post-Norm residual chain, $X_{l+1} = \text{RMSNorm}(X_l + f_l(X_l))$, and no AdaLN, to isolate deep residual propagation. Image and text tokens are concatenated into a unified sequence for self-attention.

Initialization: Zero-initialization of the residual writers ($W_O$ and $W_2$) in the main runs, with internal parameters at standard initialization.
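
As a concrete reference, below is a minimal sketch of one such Post-Norm block with zero-initialized residual writers; the module layout, the inline RMSNorm, and all sizes are illustrative assumptions, not the paper's implementation.

```python
# Minimal sketch of one Post-Norm block with zero-initialized residual writers.
# Module layout and sizes are illustrative assumptions, not the paper's code.
import torch
import torch.nn as nn


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm without a learnable scale, for illustration.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


class PostNormBlock(nn.Module):
    def __init__(self, d_model: int = 512, n_heads: int = 8, d_ff: int = 2048):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )
        # Zero-init the residual writers (W_O and W_2); internal weights keep default init.
        nn.init.zeros_(self.attn.out_proj.weight)
        nn.init.zeros_(self.attn.out_proj.bias)
        nn.init.zeros_(self.mlp[2].weight)
        nn.init.zeros_(self.mlp[2].bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Post-Norm residual chain: X_{l+1} = RMSNorm(X_l + f_l(X_l)), per sub-block.
        a, _ = self.attn(x, x, x, need_weights=False)
        x = rms_norm(x + a)
        x = rms_norm(x + self.mlp(x))
        return x
```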

Training Objective: Rectified Flow matching:

$$L = E_{t,x_0,x_1}\left[\,\bigl\| v_\theta(z_t, X_{\text{txt}}) - (x_0 - x_1) \bigr\|^2\,\right]$$

where $z_t = (1-t)\,x_0 + t\,x_1$.
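
A hedged sketch of this objective follows, using the sign convention exactly as written above; `v_theta` stands in for the DiT velocity network and is an assumed callable, not a defined API.

```python
import torch

def rectified_flow_loss(v_theta, x0: torch.Tensor, x1: torch.Tensor, x_txt: torch.Tensor):
    """Rectified Flow matching loss as written above; v_theta is the (assumed) DiT velocity model."""
    t = torch.rand(x0.shape[0], 1, 1, device=x0.device)   # one timestep per sample, broadcast over tokens
    z_t = (1 - t) * x0 + t * x1                            # linear interpolant between the endpoints
    target = x0 - x1                                       # target, with the sign convention used in the paper
    return (v_theta(z_t, x_txt) - target).pow(2).mean()
```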

Token-Space Decomposition: For a sequence $X \in \mathbb{R}^{T \times D}$, define:

$$X = JX + PX \equiv \mu(X) + c(X)$$

where $J = \tfrac{1}{T}\mathbf{1}\mathbf{1}^\top$, $P = I - J$, $\mu(X)$ is the sequence-mean component, and $c(X)$ is the centered variation.
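
For concreteness, a small sketch of this $J$/$P$ split on a single sequence; the function name is illustrative.

```python
import torch

def mean_centered_split(X: torch.Tensor):
    """Split a (T, D) sequence into mu(X) = JX (token-mean broadcast) and c(X) = PX."""
    mu = X.mean(dim=0, keepdim=True).expand_as(X)   # JX: every row replaced by the sequence mean
    c = X - mu                                      # PX: centered variation
    return mu, c

X = torch.randn(16, 64)
mu, c = mean_centered_split(X)
assert torch.allclose(mu + c, X)                                    # X = mu(X) + c(X), exactly
assert torch.allclose(c.mean(dim=0), torch.zeros(64), atol=1e-6)    # centered part has zero token-mean
```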

Key Theoretical Results:

  • Proposition 1: For row-stochastic $A$ with $A\mathbf{1} = \mathbf{1}$, $A\,\mu(X) = \mu(X)$.
  • Proposition 2: $c(AX) = PAPX$, and $\|c(AX)\|_F \leq \|PAP\|_2\,\|c(X)\|_F$. Define $\mu_{\text{eff}}(A) = \|PAP\|_2$ (a numerical sketch follows this list).
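
A short numerical check of these two propositions on a softmax attention matrix; the helper name `mu_eff` is illustrative.

```python
import torch

def mu_eff(A: torch.Tensor) -> float:
    """mu_eff(A) = ||P A P||_2 for a (T, T) row-stochastic map A."""
    T = A.shape[0]
    P = torch.eye(T) - torch.full((T, T), 1.0 / T)
    return torch.linalg.matrix_norm(P @ A @ P, ord=2).item()

# Softmax rows sum to 1, so A @ 1 = 1 and pure-mean states pass through unchanged (Prop. 1);
# mu_eff(A) < 1 means A is contractive on the centered subspace (Prop. 2).
A = torch.softmax(torch.randn(16, 16), dim=-1)
X = torch.randn(16, 64)
mean_state = X.mean(dim=0, keepdim=True).expand_as(X)   # mu(X): all rows equal the sequence mean
assert torch.allclose(A @ mean_state, mean_state, atol=1e-5)
print(mu_eff(A))
```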

Empirical Validation / Results

Failure Dynamics Tracing (Figure 3):

  1. Backward trigger: A gradient spike concentrated in the mean-coherent component $G_{\text{mean}}$, while the centered component $G_{\text{ctr}}$ shows no amplification; Q/K gradients drop by ~4 orders of magnitude.
  2. Forward lock-in: The residual branch opens and the mean/centered energy ratio $\rho_T$ rises sharply. Deep attention remains contractive ($\mu_{\text{eff}} < 1$), and token representations homogenize (cosine similarity → 1); a diagnostic sketch for these quantities follows below.
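
A hedged monitoring sketch for the traced quantities, assuming $\rho_T$ denotes the ratio of mean to centered Frobenius energy and using mean pairwise token cosine similarity as the homogenization proxy; both definitions are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def collapse_diagnostics(X: torch.Tensor):
    """Diagnostics for a (T, D) layer activation X.

    rho_T is taken here as ||JX||_F^2 / ||PX||_F^2 (assumed definition); the second value is
    the mean pairwise token cosine similarity, which tends to 1 under collapse.
    """
    T = X.shape[0]
    mu = X.mean(dim=0, keepdim=True)            # single mean row of JX
    c = X - mu                                  # PX, the centered variation
    rho_T = T * mu.pow(2).sum() / (c.pow(2).sum() + 1e-12)
    Xn = F.normalize(X, dim=-1)
    cos = (Xn @ Xn.T).sum().sub(T) / (T * (T - 1))   # mean cosine, excluding self-pairs
    return rho_T.item(), cos.item()
```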

Gradient Decomposition Mechanism: For a token-wise linear map $W$, the gradient decomposes exactly as

$$\nabla_W L = \underbrace{T\,\bar{\delta}\,\bar{y}^\top}_{\Delta W_\mu:\ \text{mean-coherent, } O(T) \text{ when aligned}} \;+\; \underbrace{\sum_{t=1}^T \tilde{\delta}_t \tilde{y}_t^\top}_{\Delta W_c:\ \text{centered, diffusive}}$$

MMS occurs when signed cancellation across tokens disappears, allowing $\Delta W_\mu$ to approach its $O(T)$ regime.
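
The split is exact and easy to verify on random tensors; the shapes and names below are illustrative.

```python
import torch

T, D_in, D_out = 64, 32, 48
y = torch.randn(T, D_in)        # per-token inputs to a token-wise linear map W
delta = torch.randn(T, D_out)   # per-token upstream gradients

grad = delta.T @ y                                   # full gradient: sum_t delta_t y_t^T
d_bar, y_bar = delta.mean(dim=0), y.mean(dim=0)
dW_mu = T * torch.outer(d_bar, y_bar)                # mean-coherent term, O(T) when aligned
dW_c = (delta - d_bar).T @ (y - y_bar)               # centered term, relies on sign cancellation
assert torch.allclose(grad, dW_mu + dW_c, atol=1e-4)
```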

Alignment-Amplification Law (Eq. 6):

$$\frac{\|\nabla_W L\|_F^2}{\sum_t \|\delta_t\|^2 \|y_t\|^2} - 1 \;=\; \frac{\sum_{s \neq t} (\delta_s^\top \delta_t)(y_s^\top y_t)}{\sum_t \|\delta_t\|^2 \|y_t\|^2} \;\approx\; (T-1)\, E_{s \neq t}\!\left[\cos(y_s, y_t)\,\cos(\delta_s, \delta_t)\right]$$

Figure 4 validates this: at the spike step $t^* = 3400$, active layers approach the saturation envelope ($A - 1 \approx 167$, ~13× amplification).
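
A small sketch of the left-hand side of Eq. 6, the excess gradient energy $A - 1$, on synthetic tokens; the fully aligned case mirrors the $(T-1)$ saturation envelope, and the function name is illustrative.

```python
import torch

def excess_energy(delta: torch.Tensor, y: torch.Tensor) -> float:
    """Left-hand side of Eq. 6: ||grad_W L||_F^2 / sum_t ||delta_t||^2 ||y_t||^2  - 1."""
    grad = delta.T @ y
    denom = (delta.pow(2).sum(dim=1) * y.pow(2).sum(dim=1)).sum()
    return (grad.pow(2).sum() / denom - 1.0).item()

T, D = 128, 64
print(excess_energy(torch.randn(T, D), torch.randn(T, D)))   # incoherent tokens: near 0
v = torch.randn(1, D)
print(excess_energy(v.repeat(T, 1), v.repeat(T, 1)))          # fully aligned: reaches T - 1 = 127
```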

MV-Split Residuals Design: Replace the standard Post-Norm merge with a subspace-routed merge:

$$Z_l = X_l + \underbrace{\beta \odot (P F_l)}_{\text{centered path}} + \underbrace{\alpha \odot J(F_l - X_l)}_{\text{mean path}}, \qquad X_{l+1} = \text{RMSNorm}(Z_l)$$

where $\alpha, \beta \in \mathbb{R}^D$ are per-block learnable vectors. This yields a decoupled pre-normalization merge:

$$PZ_l = PX_l + \beta \odot (PF_l), \qquad JZ_l = (1-\alpha) \odot (JX_l) + \alpha \odot (JF_l)$$
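
A hedged sketch of this subspace-routed merge, with the RMSNorm applied outside the module; the class name and gain initializations are assumptions, not the paper's values.

```python
import torch
import torch.nn as nn

class MVSplitMerge(nn.Module):
    """Subspace-routed residual merge (sketch): beta gates the centered path,
    alpha gives a leaky replacement of the trunk mean."""

    def __init__(self, d_model: int, alpha_init: float = 0.1, beta_init: float = 1.0):
        super().__init__()
        self.alpha = nn.Parameter(torch.full((d_model,), alpha_init))   # mean-path gain, in R^D
        self.beta = nn.Parameter(torch.full((d_model,), beta_init))     # centered-path gain, in R^D

    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        # x: trunk X_l, f: branch output F_l, both (B, T, D); the mean is taken over tokens (dim=1).
        jx = x.mean(dim=1, keepdim=True)
        jf = f.mean(dim=1, keepdim=True)
        pf = f - jf
        # Z_l = X_l + beta ⊙ (P F_l) + alpha ⊙ J(F_l - X_l); RMSNorm is applied outside.
        return x + self.beta * pf + self.alpha * (jf - jx)
```

A block would then compute $X_{l+1} = \text{RMSNorm}(Z_l)$ with $Z_l$ from this merge, reproducing the decoupled form above: $PZ_l$ picks up only the $\beta$-gated centered branch, while $JZ_l$ leaks toward $JF_l$ at rate $\alpha$.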

Performance Comparison:

Table 1: Stability and convergence across 400-/1000-layer DiT runs (FID ↓ / IS ↑ at each training step)

| Method | @10k | @20k | @30k | @40k | @50k |
| --- | --- | --- | --- | --- | --- |
| 400L Base ($\eta$) | – | – | – | – | – |
| 400L Base ($\eta/2$) † | 5.92 / 108.6 | 3.22 / 152.2 | – | – | – |
| 400L LayerScale | 14.08 / 59.2 | 6.50 / 96.6 | 4.09 / 130.5 | 3.33 / 149.6 | 2.90 / 165.5 |
| 400L MV-Split | 7.23 / 89.8 | 3.64 / 139.9 | 3.09 / 166.5 | 2.79 / 182.0 | 2.60 / 185.5 |
| 1000L MV-Split | 5.47 / 117.3 | 2.92 / 178.2 | 2.68 / 196.6 | 2.64 / 209.4 | 2.77 / 217.3 |

MV-Split achieves the best stable 400-layer results, avoiding divergence while converging faster than LayerScale.

Writer-Gradient Mode Analysis (Figure 6):

  • LayerScale: Compresses both $G_{\text{mean}}$ and $G_{\text{ctr}}$ (isotropic gating).
  • MV-Split: Bounds $G_{\text{mean}}$ while preserving a higher, stable $G_{\text{ctr}}$ band (mode-selective control).

Theoretical and Practical Implications

Theoretical Implications:

  • Provides a mechanistic explanation for collapse in ultra-deep DiTs via exact gradient decomposition and Softmax null-space suppression.
  • Identifies MMS as a distinct failure pathway from general training spikes or rank collapse.
  • Demonstrates that token-mean subspace can act as an implicit global timestep carrier (Appendix H).

Practical Implications:

  • MV-Split Residuals offer a new stabilization method that decouples mean and centered paths, enabling deeper DiTs without sacrificing convergence speed.
  • Enables stable 1000-layer DiT training, validating scalability at boundary depths.
  • Shifts the stability-constrained quality frontier for deep generative models.

Conclusion

Mean Mode Screaming (MMS) is a specific residual-subspace failure pathway in ultra-deep Post-Norm DiTs, driven by $O(T)$ mean-coherent gradient accumulation and subsequent Q/K gradient suppression. MV-Split Residuals address this by damping the mean path without shrinking the centered path, preventing collapse while maintaining fast convergence. The method enables stable training at 400 layers and scales to 1000 layers, establishing a viable path for extreme-depth Diffusion Transformers.