Elucidating the SNR-t Bias of Diffusion Probabilistic Models
Summary (Overview)
- Identifies a novel SNR-t Bias: Diffusion Probabilistic Models (DPMs) suffer from a Signal-to-Noise Ratio-timestep (SNR-t) bias, where the SNR of the predicted sample during inference does not match its assigned timestep, unlike the strict coupling during training.
- Provides Empirical and Theoretical Evidence: Demonstrates that reverse denoising samples consistently have lower SNR than their forward counterparts at the same timestep, leading to inaccurate network predictions. A theoretical proof formalizes this bias.
- Proposes a Training-Free Correction Method: Introduces a Differential Correction in Wavelet domain (DCW) method. It uses the difference signal between the predicted and reconstructed samples to guide biased samples toward the ideal perturbed distribution, applying dynamic corrections to different frequency components.
- Achieves Broad Improvements with Negligible Overhead: The plug-and-play DCW method significantly improves the generation quality (FID, Recall) of various DPMs (e.g., IDDPM, ADM, EDM, FLUX) across multiple datasets and resolutions. It can also further enhance state-of-the-art bias-corrected models, at minimal computational cost (~0.1-0.5% overhead).
Introduction and Theoretical Foundation
Diffusion Probabilistic Models (DPMs) have achieved remarkable success in generative tasks. They operate via a forward noising process and a reverse denoising process. During training, a perturbed sample $x_t$ and its timestep $t$ are strictly coupled, with the sample's SNR defined as:
$$\mathrm{SNR}(t) = \frac{\bar{\alpha}_t}{1 - \bar{\alpha}_t}$$
where $\alpha_t = 1 - \beta_t$, $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$, and $\beta_t$ is the noise schedule.
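As a quick illustration of this definition, the following sketch computes the SNR at every timestep of a linear noise schedule. The schedule endpoints (`beta_start=1e-4`, `beta_end=0.02`) and the helper names are assumptions chosen for illustration, not values taken from the paper.

```python
def linear_betas(num_steps, beta_start=1e-4, beta_end=0.02):
    """Hypothetical linear noise schedule beta_1..beta_T (illustrative values)."""
    if num_steps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def snr_schedule(betas):
    """Return SNR(t) = alpha_bar_t / (1 - alpha_bar_t) for t = 1..T."""
    snrs = []
    alpha_bar = 1.0
    for beta in betas:
        alpha_bar *= 1.0 - beta  # alpha_bar_t = prod_{s<=t} (1 - beta_s)
        snrs.append(alpha_bar / (1.0 - alpha_bar))
    return snrs


# snrs[t - 1] is the SNR that training couples with timestep t.
snrs = snr_schedule(linear_betas(1000))
```

Because the cumulative product shrinks at every step, the SNR decreases strictly with the timestep, so each timestep corresponds to exactly one SNR level. This one-to-one coupling is what the SNR-t bias breaks at inference time.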
However, the authors identify a fundamental SNR-t bias: during inference, cumulative errors from model predictions and numerical solvers cause the denoising trajectory to deviate, breaking the SNR-t coupling. This is distinct from previously studied "exposure bias" (inter-sample discrepancies), as SNR-t bias is an intra-sample misalignment between the sample and its timestep condition.
Methodology
Background on DPMs
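The forward process perturbs a data sample $x_0$ with Gaussian noise:
$$q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t;\ \sqrt{1-\beta_t}\, x_{t-1},\ \beta_t I\right)$$
which, with $\bar{\alpha}_t = \prod_{s=1}^{t}(1-\beta_s)$, can be sampled as:
$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$
The reverse process is learned by a neural network $\epsilon_\theta$ that predicts the noise, trained with the objective:
$$L = \mathbb{E}_{x_0, \epsilon, t}\left[\left\| \epsilon - \epsilon_\theta(x_t, t) \right\|^2\right]$$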
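The closed-form forward sampling and the noise-prediction objective can be sketched in a few lines. This is a minimal standard-library illustration on flat lists of floats, not the authors' code; the function names, the constant schedule, and the all-zero "prediction" are hypothetical placeholders.

```python
import math
import random


def alpha_bar(betas, t):
    """Cumulative product of (1 - beta_s) for s = 1..t (t is 1-indexed)."""
    prod = 1.0
    for beta in betas[:t]:
        prod *= 1.0 - beta
    return prod


def q_sample(x0, betas, t, rng):
    """Draw x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    ab = alpha_bar(betas, t)
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    xt = [math.sqrt(ab) * v + math.sqrt(1.0 - ab) * e for v, e in zip(x0, eps)]
    return xt, eps


def noise_loss(eps, eps_pred):
    """Squared L2 distance between the true and the predicted noise."""
    return sum((e - p) ** 2 for e, p in zip(eps, eps_pred))


betas = [0.1] * 10                  # toy constant schedule (assumption)
x0 = [1.0, -1.0, 0.5]               # toy "data sample"
xt, eps = q_sample(x0, betas, t=5, rng=random.Random(0))
loss = noise_loss(eps, [0.0] * 3)   # stand-in for a network that predicts zeros
```

In actual training, `eps_pred` would come from the network evaluated on `xt` and `t`, and the loss would be averaged over random draws of the sample, the noise, and the timestep.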
Empirical Analysis of SNR-t Bias
Two key findings are established:
- Key Finding 1: A network conditioned on timestep $t$ produces inaccurate predictions when given a sample whose SNR does not match $t$. Samples with lower SNR cause overestimated noise predictions; samples with higher SNR cause underestimated predictions.
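- Key Finding 2: Reverse denoising samples consistently exhibit lower SNR than forward samples at the same timestep $t$, evidenced by the network's noise prediction on the reverse sample always being larger in norm than its prediction on the forward sample.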
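To make the mismatch concrete, the toy calculation below asks: if a sample nominally at timestep `t` carries some extra error variance, which timestep does its SNR actually correspond to? The additive-variance model, the value `extra_var=0.05`, and the linear schedule are assumptions made for illustration only; they are not the paper's Assumption 5.1 or Theorem 5.1.

```python
import math


def alpha_bars(betas):
    """Cumulative products alpha_bar_1..alpha_bar_T."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def degraded_snr(alpha_bar_t, extra_var):
    """Toy model: extra independent error variance is added to the noise power."""
    return alpha_bar_t / ((1.0 - alpha_bar_t) + extra_var)


def effective_timestep(sample_snr, snrs):
    """1-indexed timestep whose nominal SNR is closest to sample_snr in log space."""
    target = math.log(sample_snr)
    return 1 + min(range(len(snrs)), key=lambda i: abs(math.log(snrs[i]) - target))


# Hypothetical linear schedule with 1000 steps.
betas = [1e-4 + i * (0.02 - 1e-4) / 999 for i in range(1000)]
abars = alpha_bars(betas)
snrs = [ab / (1.0 - ab) for ab in abars]

t = 500
nominal = snrs[t - 1]                                # SNR the network expects at t
actual = degraded_snr(abars[t - 1], extra_var=0.05)  # SNR of the biased sample
t_eff = effective_timestep(actual, snrs)             # timestep matching that SNR
```

In this toy setting the biased sample's SNR matches a later (noisier) timestep than the one it is conditioned on, which is the kind of intra-sample misalignment the two findings describe.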
Theoretical Proof of SNR-t Bias
The authors propose Assumption 5.1 to model the reconstruction sample (the network's prediction of $x_0$ from the current reverse sample); the model includes a term that represents information loss. Based on this, they derive Theorem 5.1, which gives the analytical SNR of the biased reverse sample.
This SNR is always lower than the forward process SNR at the same timestep, which proves the existence of SNR-t bias. The paper also gives a more concise form of this result.
Differential Correction in Wavelet Domain (DCW)
To mitigate the bias, the method leverages the differential signal between the predicted sample (the output of the current denoising step) and the reconstruction of $x_0$ produced at that step. This signal contains directional information pointing towards the ideal perturbed sample. The core correction in pixel space shifts the predicted sample along this differential signal, scaled by a guidance factor.
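A minimal sketch of such a pixel-space correction is given below. It assumes the update is additive, moving the predicted sample along (predicted minus reconstructed) by a factor `lam`; the exact form and sign convention of the paper's equation are not reproduced in this summary, so treat the formula and the names as hypothetical.

```python
def differential_correction(x_pred, x0_recon, lam):
    """Assumed form: x_pred + lam * (x_pred - x0_recon), elementwise on flat lists."""
    return [p + lam * (p - r) for p, r in zip(x_pred, x0_recon)]


x_pred = [1.0, 2.0]      # toy predicted sample after a denoising step
x0_recon = [0.5, 2.5]    # toy reconstruction of the clean sample
corrected = differential_correction(x_pred, x0_recon, lam=0.5)
```

With `lam=0` the sample is returned unchanged, which makes the correction easy to switch off when comparing against a baseline sampler.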
To align with DPMs' coarse-to-fine denoising behavior (low-frequency structure first, high-frequency details later) and to reduce noise interference, DCW operates in the wavelet domain. The sample and its reconstruction are each decomposed via the Discrete Wavelet Transform (DWT) into four frequency subbands (LL, LH, HL, HH), and the correction is applied to each subband separately with its own weight.
The corrected subbands are then recomposed via inverse DWT (iDWT).
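The decompose, correct, recompose pipeline can be sketched with a hand-written single-level Haar transform. The choice of the Haar wavelet, the additive per-subband update, the use of one weight for LL and another shared by LH, HL, and HH, and all names are assumptions for illustration; the paper's actual wavelet and equation may differ.

```python
def haar_dwt2(img):
    """Single-level 2D Haar DWT of a 2D list with even height and width."""
    rows, cols = len(img), len(img[0])
    bands = {name: [[0.0] * (cols // 2) for _ in range(rows // 2)]
             for name in ("ll", "lh", "hl", "hh")}
    for i in range(0, rows, 2):
        for j in range(0, cols, 2):
            a, b = img[i][j], img[i][j + 1]
            c, d = img[i + 1][j], img[i + 1][j + 1]
            r, s = i // 2, j // 2
            # Subband naming conventions vary between libraries.
            bands["ll"][r][s] = (a + b + c + d) / 2.0
            bands["lh"][r][s] = (a + b - c - d) / 2.0
            bands["hl"][r][s] = (a - b + c - d) / 2.0
            bands["hh"][r][s] = (a - b - c + d) / 2.0
    return bands


def haar_idwt2(bands):
    """Inverse of haar_dwt2: rebuild each 2x2 pixel block from its coefficients."""
    rows, cols = len(bands["ll"]), len(bands["ll"][0])
    img = [[0.0] * (2 * cols) for _ in range(2 * rows)]
    for r in range(rows):
        for s in range(cols):
            ll, lh = bands["ll"][r][s], bands["lh"][r][s]
            hl, hh = bands["hl"][r][s], bands["hh"][r][s]
            img[2 * r][2 * s] = (ll + lh + hl + hh) / 2.0
            img[2 * r][2 * s + 1] = (ll + lh - hl - hh) / 2.0
            img[2 * r + 1][2 * s] = (ll - lh + hl - hh) / 2.0
            img[2 * r + 1][2 * s + 1] = (ll - lh - hl + hh) / 2.0
    return img


def dcw_correct(x_pred, x0_recon, lam_low, lam_high):
    """Assumed per-subband update: band + lam * (band - recon_band), then iDWT."""
    pred_bands = haar_dwt2(x_pred)
    recon_bands = haar_dwt2(x0_recon)
    corrected = {}
    for name, band in pred_bands.items():
        lam = lam_low if name == "ll" else lam_high
        corrected[name] = [
            [p + lam * (p - q) for p, q in zip(p_row, q_row)]
            for p_row, q_row in zip(band, recon_bands[name])
        ]
    return haar_idwt2(corrected)


x_pred = [[float(4 * i + j) for j in range(4)] for i in range(4)]
x0_recon = [[0.5 * (i + j) for j in range(4)] for i in range(4)]
out = dcw_correct(x_pred, x0_recon, lam_low=0.1, lam_high=0.3)
```

Because the transform is linear, setting `lam_low == lam_high` reduces this sketch to the same correction applied directly in pixel space; the benefit of the wavelet domain comes from letting the two weights differ.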
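Dynamic weights are assigned using the reverse process variance as a progress indicator:
- Low-frequency: a decaying schedule, so the correction weight shrinks as denoising progresses
- High-frequency: an increasing schedule, so the correction weight grows as denoising progresses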
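One way to realize a decaying and an increasing schedule from a single progress indicator is sketched below. The linear forms, the normalization by `sigma_max`, and the parameter names are hypothetical; the paper's exact weight formulas are not reproduced in this summary. The sketch also assumes that the reverse-process standard deviation `sigma_t` is largest at the start of sampling and shrinks towards the end.

```python
def dynamic_weights(sigma_t, sigma_max, lam_low, lam_high):
    """Hypothetical linear schedules driven by the normalized reverse-process std."""
    # s is 1 at the start of sampling and approaches 0 as denoising finishes.
    s = min(max(sigma_t / sigma_max, 0.0), 1.0)
    w_low = lam_low * s             # decays as denoising progresses
    w_high = lam_high * (1.0 - s)   # grows as denoising progresses
    return w_low, w_high


# Toy trajectory of sigma_t from the first to the last denoising step.
sigmas = [1.0, 0.75, 0.5, 0.25, 0.0]
schedule = [dynamic_weights(s, 1.0, lam_low=0.2, lam_high=0.4) for s in sigmas]
```

Early steps therefore emphasize the low-frequency correction and late steps emphasize the high-frequency one, matching the coarse-to-fine behavior described above.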
The overall DCW framework is illustrated below:
Figure 2. The overall framework of Differential Correction in Wavelet domain (DCW). At each denoising step, DPMs always generate a reconstructed sample that predicts $x_0$ from the current noisy sample. After each denoising step is completed, DCW maps the denoised sample and the reconstructed sample to the wavelet domain via DWT to obtain their subbands $f \in \{ll, lh, hl, hh\}$. Then, DCW corrects the different frequency components of the denoised sample using Eq. 18. Finally, DCW maps the corrected subbands back to the pixel space via iDWT.
Empirical Validation / Results
Extensive experiments validate DCW's effectiveness, generality, and superiority.
Results on Classic Diffusion Models
DCW improves baseline models (IDDPM, ADM, ADM-IP) across datasets of varying resolutions (CIFAR-10, CelebA, ImageNet, LSUN Bedroom).
Table 2. FID and Recall (Rec) on datasets with different resolutions.
| Model | Dataset | T=20 FID ↓ | T=20 Rec ↑ | T=50 FID ↓ | T=50 Rec ↑ |
|---|---|---|---|---|---|
| IDDPM | CIFAR-10 32 | 13.19 | 0.50 | 5.55 | 0.56 |
| +Ours | CIFAR-10 32 | 7.57 | 0.56 | 4.16 | 0.58 |
| ADM-IP | CelebA 64 | 11.95 | 0.42 | 4.52 | 0.55 |
| +Ours | CelebA 64 | 10.41 | 0.47 | 4.34 | 0.57 |
| ADM | ImageNet 128 | 12.28 | 0.52 | 5.18 | 0.58 |
| +Ours | ImageNet 128 | 10.34 | 0.54 | 4.52 | 0.58 |
| IDDPM | LSUN 256 | 18.69 | 0.27 | 8.42 | 0.41 |
| +Ours | LSUN 256 | 11.03 | 0.36 | 5.24 | 0.45 |
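To put the T=20 columns of Table 2 on a common scale, the snippet below computes the relative FID reduction for each baseline. The FID values are copied from the table; the helper name is illustrative.

```python
def relative_reduction(baseline, improved):
    """Percentage by which `improved` is lower than `baseline`."""
    return 100.0 * (baseline - improved) / baseline


# (model, dataset, baseline FID, +Ours FID) at T=20, copied from Table 2.
rows_t20 = [
    ("IDDPM", "CIFAR-10 32", 13.19, 7.57),
    ("ADM-IP", "CelebA 64", 11.95, 10.41),
    ("ADM", "ImageNet 128", 12.28, 10.34),
    ("IDDPM", "LSUN 256", 18.69, 11.03),
]
reductions = {(m, d): relative_reduction(b, o) for m, d, b, o in rows_t20}
```

Under this calculation, the two IDDPM baselines see the largest relative gains at T=20 (roughly 41 to 43% lower FID), while ADM-IP on CelebA and ADM on ImageNet improve by roughly 13% and 16%.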
DCW also outperforms recent exposure bias methods (DPM-AE, DPM-AT) when applied to the same baselines.
Table 3. FID ↓ on CIFAR-10 using ADM and DDIM.
| Model | DDIM (10/20/50 steps) | ADM (10/20/50 steps) |
|---|---|---|
| Base | 14.40 / 6.87 / 4.15 | 22.62 / 10.52 / 4.55 |
| Base-AE | 13.98 / 6.76 / 4.10 | - |
| Base-AT | - | 15.88 / 6.60 / 3.34 |
| Base+Ours | 9.36 / 4.64 / 3.33 | 13.01 / 5.59 / 2.95 |
Results on Bias-Corrected Diffusion Models
DCW can be integrated into state-of-the-art models designed to mitigate exposure bias (ADM-ES, DPM-FR) and provides further improvements on top of them, which indicates that it is complementary to these methods.
Tables 5 and 6. FID ↓ on CIFAR-10 using different fast samplers (Deterministic Sampling).
| Model | EDM (13/21/35 NFE) | PFGM++ (13/21/35 NFE) |
|---|---|---|
| Base | 10.66 / 5.91 / 3.74 | 12.92 / 6.53 / 3.88 |
| +Ours | 5.67 / 3.37 / 2.41 | 6.98 / 3.83 / 2.64 |
| Base-ES | 6.59 / 3.74 / 2.59 | 8.79 / 4.54 / 2.91 |
| +Ours | 6.13 / 3.57 / 2.50 | 8.00 / 4.41 / 2.84 |
| Base-FR | 4.68 / 2.84 / 2.13 | 6.62 / 3.67 / 2.53 |
| +Ours | 4.57 / 2.79 / 2.12 | 6.18 / 3.46 / 2.48 |
Qualitative Results
Visual comparisons on text-to-image models (FLUX, Qwen-Image) show that DCW significantly mitigates artifacts like over-smoothing and distortion, enhancing aesthetic quality and detail, especially in low-step sampling.
Ablation Study and Analysis
- Effect of Wavelet Domain: Ablation shows that applying correction to both high and low-frequency components (DCW) yields the best results, outperforming pixel-space only (DC) or single-frequency corrections (DH, DL).
- Parameter Sensitivity: DCW is robust to its two hyperparameters (the guidance factors for the low- and high-frequency components), with FID showing a convex trend across a wide range of values.
- Computational Overhead: DCW adds negligible latency (~0.1-0.5%), making it highly practical.
Table 7. Batch generation time on a single NVIDIA A6000 GPU.
| Model | Dataset | Baseline Time (s) | DCW Time (s) | Overhead |
|---|---|---|---|---|
| ADM-IP | CelebA 64 | 4.25 | 4.27 | 0.47% |
| ADM | ImageNet 128 | 12.59 | 12.60 | 0.08% |
| IDDPM | LSUN 256 | 15.57 | 15.61 | 0.26% |
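The overhead column can be reproduced directly from the two timing columns, which is a useful sanity check on the reported figures. The timings are copied from Table 7; the helper name is illustrative.

```python
def overhead_percent(baseline_s, dcw_s):
    """Extra time of DCW relative to the baseline, in percent."""
    return 100.0 * (dcw_s - baseline_s) / baseline_s


# (model, dataset, baseline time in s, DCW time in s), copied from Table 7.
timings = [
    ("ADM-IP", "CelebA 64", 4.25, 4.27),
    ("ADM", "ImageNet 128", 12.59, 12.60),
    ("IDDPM", "LSUN 256", 15.57, 15.61),
]
overheads = [round(overhead_percent(b, d), 2) for _, _, b, d in timings]
```

Rounded to two decimals, the results agree with the table's overhead column (0.47%, 0.08%, 0.26%).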
Theoretical and Practical Implications
- Theoretical: Provides a formal analysis and proof for a previously underexplored bias (SNR-t) in DPMs, establishing it as a fundamental cause of error accumulation distinct from exposure bias.
- Practical: Introduces a simple, training-free, plug-and-play correction method that broadly enhances the performance of existing DPMs and can be combined with other correction techniques for further gains. Its negligible computational cost makes it suitable for real-world deployment.
Conclusion
The paper identifies and thoroughly analyzes the SNR-t bias in Diffusion Probabilistic Models, where the SNR of denoising samples deviates from their timestep during inference. Comprehensive empirical and theoretical evidence is provided. To mitigate this bias, the authors propose DCW, a differential correction method in the wavelet domain that aligns with DPMs' coarse-to-fine denoising characteristics. Extensive experiments demonstrate that DCW significantly improves the generation quality of a wide range of DPMs across various datasets and can further enhance state-of-the-art corrected models, all with minimal computational overhead.