# Beyond SFT-to-RL: Pre-alignment via Black-Box On-Policy Distillation for Multimodal RL

> PRISM introduces a three-stage pipeline that uses adversarial on-policy distillation to correct SFT-induced distributional drift, improving downstream RL performance for multimodal models.

- **Source:** [arXiv](https://arxiv.org/abs/2604.28123)
- **Published:** 2026-05-07
- **Permalink:** https://picx.dev/p/jKiYlA
- **Whiteboard:** https://picx.dev/p/jKiYlA/image

## Summary

*   **Core Problem:** The standard two-stage post-training pipeline for Large Multimodal Models (LMMs)—Supervised Fine-Tuning (SFT) followed by Reinforcement Learning with Verifiable Rewards (RLVR)—introduces **distributional drift**. SFT causes the model to deviate from both its original capabilities and the target supervision distribution, a problem amplified in multimodal reasoning where perception and reasoning errors compound.
*   **Proposed Solution:** **PRISM** (Pre-alignment via black-box on-policy dIStillation for Multimodal reinforcement learning), a **three-stage pipeline** that inserts an explicit distribution-alignment stage between SFT and RLVR. This stage uses **adversarial on-policy distillation** with a **Mixture-of-Experts (MoE) discriminator** to correct drift.
*   **Key Innovation:** The alignment is formulated as a **logit-free, adversarial game** between the policy and an MoE discriminator with dedicated **perception** and **reasoning experts**, providing disentangled corrective signals without needing teacher model logits.
*   **High-Quality Data:** The authors curate a **113K high-fidelity multimodal reasoning corpus** from Gemini 3 Flash, featuring dense visual grounding and step-by-step reasoning on hard problems, supplementing 1.26M public demonstrations for SFT.
*   **Empirical Results:** Experiments on Qwen3-VL-4B/8B show PRISM **consistently improves downstream RLVR performance** across multiple algorithms (GRPO, DAPO, GSPO). PRISM+GRPO outperforms the SFT→RLVR baseline by **+4.4** and **+6.0** average accuracy points on the 4B and 8B models, respectively.

## Introduction and Theoretical Foundation
The prevailing paradigm for improving LMMs involves **Supervised Fine-Tuning (SFT)** on curated demonstrations followed by **Reinforcement Learning with Verifiable Rewards (RLVR)**. While SFT bootstraps capabilities and RLVR refines performance, recent studies reveal a counterintuitive issue: SFT can introduce **distributional drift**. The model neither perfectly matches the demonstration policy nor retains its original favorable distribution. This drift is especially costly for stronger base models and is **heterogeneous in multimodal settings**—visual grounding and logical reasoning degrade in qualitatively different ways, compounding errors during RL.

The paper posits that repairing this drift *before* RL is crucial. It builds upon **On-Policy Distillation (OPD)**, which mitigates exposure bias by training a model on its own generations rather than static teacher-forced targets. PRISM repositions OPD as a **standalone intermediate alignment stage**, extending the SFT→RL recipe to **SFT → Alignment → RLVR**.

## Methodology
PRISM is a three-stage pipeline: **1) Cold-Start SFT, 2) Distribution Alignment via On-Policy Distillation, and 3) RLVR**.

**Stage 1: Cold-Start SFT**
*   **Objective:** Provide an initial multimodal reasoning policy.
*   **Data:** Combines a curated **113K corpus** (107K for SFT) from Gemini 3 Flash with **1.26M public demonstrations**. The curated data targets hard problems and includes detailed visual descriptions and reasoning traces.
*   **Process:** Standard fine-tuning for 1 epoch, updating all parameters except the vision tower, which stays frozen.

**Stage 2: Distribution Alignment via On-Policy Distillation**
This core stage repairs SFT-induced drift. It is formulated as a **minimax game** between the policy ($G$) and an **MoE Discriminator**.

*   **MoE Discriminator Design:** Comprises two experts providing disentangled feedback:
    *   **Perception Expert ($D_v$):** Evaluates the visual description ($c$) for grounding fidelity.
    *   **Reasoning Expert ($D_r$):** Evaluates the reasoning trace ($t$) for logical consistency.
    *   The combined discriminator reward for a response $y$ (with components $c$, $t$) given input $x$ is:
        $$r(x, y) = \alpha \cdot D_v(x, c) + (1 - \alpha) \cdot D_r(x, t) \quad \text{(Eq. 1)}$$
        where $\alpha$ controls the trade-off (set to 0.5).
*   **Adversarial Training:**
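    *   **Discriminator Loss:** Each expert is trained with a Bradley-Terry loss to score supervision responses ($y^+$) above policy rollouts ($y^-$):
        $$\mathcal{L}_{D_k} = -\mathbb{E}_{(x,y^+,y^-)\sim \mathcal{T}} \left[ \log \sigma\left( D_k(x, y^+_k) - D_k(x, y^-_k) \right) \right], \quad k \in \{v, r\} \quad \text{(Eq. 2)}$$
        where $\mathcal{T}$ is the set of (prompt, supervision response, policy rollout) triples, $\sigma$ is the logistic sigmoid, and $y^\pm_k$ is the component of $y^\pm$ scored by expert $k$ (the visual description $c$ for $k = v$, the reasoning trace $t$ for $k = r$).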
    *   **Policy Update:** For a prompt $x$, sample $N$ responses $\{y^-_i\}_{i=1}^N$ from the current policy. Compute normalized advantages within the group:
        $$A_i = \frac{r(x, y^-_i) - \text{mean}(\{r(x, y^-_j)\})}{\text{std}(\{r(x, y^-_j)\})} \quad \text{(Eq. 3)}$$
        The policy is then updated with GRPO using these advantages. **KL regularization is removed** so the policy is free to move away from the SFT checkpoint.
*   **Initialization:** The policy is initialized from the SFT checkpoint. The MoE discriminator experts are warm-started on preference pairs from their respective components (visual descriptions, reasoning traces).
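
The scalar quantities in Eqs. 1 to 3 can be sketched in plain Python. This is a minimal illustration, not the authors' implementation: the expert scores below are made-up numbers standing in for the outputs of the learned discriminators $D_v$ and $D_r$, the use of the population standard deviation and the small `eps` in the advantage normalization are assumptions, and the GRPO policy-gradient step that consumes the advantages is omitted.

```python
import math
import statistics
from typing import List, Sequence

ALPHA = 0.5  # perception/reasoning trade-off; 0.5 is the value reported above


def combined_reward(d_v: float, d_r: float, alpha: float = ALPHA) -> float:
    """Eq. 1: mix the perception-expert score D_v(x, c) and reasoning-expert score D_r(x, t)."""
    return alpha * d_v + (1.0 - alpha) * d_r


def bradley_terry_loss(pos_scores: Sequence[float], neg_scores: Sequence[float]) -> float:
    """Eq. 2 for one expert k: mean of -log sigmoid(D_k(x, y+_k) - D_k(x, y-_k)) over pairs."""
    total = 0.0
    for s_pos, s_neg in zip(pos_scores, neg_scores):
        margin = s_pos - s_neg
        # -log(sigmoid(m)) equals softplus(-m); this form avoids overflow for large |m|
        total += max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))
    return total / len(pos_scores)


def group_advantages(rewards: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Eq. 3: normalize each rollout's reward by the mean and std of its group."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std (assumption)
    return [(r - mean) / (std + eps) for r in rewards]


# Hypothetical expert scores for N = 4 rollouts sampled for one prompt.
d_v = [0.8, 0.2, 0.5, 0.5]
d_r = [0.4, 0.6, 0.1, 0.9]
rewards = [combined_reward(v, r) for v, r in zip(d_v, d_r)]
advantages = group_advantages(rewards)
print([round(a, 3) for a in advantages])
```

Because each group is normalized by its own mean and standard deviation, the advantages are essentially unchanged if all discriminator scores are shifted or rescaled by a positive factor; what drives the update is how rollouts compare within their group.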

**Stage 3: Reinforcement Learning with Verifiable Rewards**
*   **Start Point:** Policy checkpoint from the alignment stage.
*   **Data:** A **difficulty-filtered subset** (~2K samples) of the 6K curated samples held out from SFT (113K minus the 107K used in Stage 1), keeping prompts on which the aligned policy's pass rate is between 0.2 and 0.8.
*   **Reward:** Switches from the learned MoE reward to a **deterministic verifiable reward**:
    $$r_v(x, y) = r_{\text{acc}}(x, y) + r_{\text{fmt}}(x, y) \quad \text{(Eq. 5)}$$
*   **Optimization:** Standard outcome-based RLVR (compatible with GRPO, DAPO, GSPO).
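
A sketch of the Stage 3 difficulty filter and the verifiable reward of Eq. 5, under stated assumptions: pass rates are estimated from a few sampled rollouts per prompt, the 0.2 and 0.8 bounds are treated as inclusive, and both $r_{\text{acc}}$ and $r_{\text{fmt}}$ are assumed to be binary. The `<caption>`, `<think>` and `<answer>` tags are hypothetical stand-ins for the caption/reasoning/answer output format; the summary specifies neither the actual delimiters nor the answer-matching rule, so exact string match is used here for simplicity.

```python
import re
from typing import Dict, List, Sequence

# Hypothetical delimiters for the three-part caption / reasoning / answer format.
FORMAT_RE = re.compile(
    r"^\s*<caption>.+?</caption>\s*<think>.+?</think>\s*<answer>(.+?)</answer>\s*$",
    re.DOTALL,
)


def pass_rate(outcomes: Sequence[bool]) -> float:
    """Fraction of sampled rollouts for one prompt that were judged correct."""
    return sum(outcomes) / len(outcomes)


def difficulty_filter(pool: Dict[str, List[bool]], low: float = 0.2, high: float = 0.8) -> List[str]:
    """Keep prompts the aligned policy solves sometimes, but not always and not never."""
    return [pid for pid, outcomes in pool.items() if low <= pass_rate(outcomes) <= high]


def verifiable_reward(response: str, reference: str) -> float:
    """Eq. 5: r_v = r_acc + r_fmt, with both terms assumed to be 0 or 1."""
    match = FORMAT_RE.match(response)
    r_fmt = 1.0 if match else 0.0
    # Assumption: the answer is only extracted from well-formed responses.
    answer = match.group(1).strip() if match else None
    r_acc = 1.0 if answer == reference.strip() else 0.0
    return r_acc + r_fmt


# Usage with made-up rollout outcomes (True = correct).
pool = {
    "q1": [True, False, False, False, False],   # pass rate 0.2, kept
    "q2": [True, True, True, True, True],       # pass rate 1.0, too easy
    "q3": [True, True, False, True, False],     # pass rate 0.6, kept
    "q4": [False, False, False, False, False],  # pass rate 0.0, too hard
}
selected = difficulty_filter(pool)

response = (
    "<caption>A right triangle with legs 3 and 4.</caption>"
    "<think>By the Pythagorean theorem, 9 + 16 = 25, so the hypotenuse is 5.</think>"
    "<answer>5</answer>"
)
print(selected, verifiable_reward(response, "5"))
```

Filtering out prompts the policy always or never solves matters for group-normalized methods such as GRPO: when every rollout in a group receives the same reward, the within-group advantages are all zero and the prompt contributes no learning signal.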

## Empirical Validation / Results
Experiments are conducted on **Qwen3-VL-4B and 8B** across mathematical reasoning (MathVista, MathVerse, MathVision, WeMath) and general multimodal understanding benchmarks (MMMU, MMMU-Pro, HallusionBench).

### Main Results
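**Table 1: Main results on mathematical reasoning and general multimodal benchmarks (accuracy, %).** Rows under each model are cumulative: "+ SFT" is the SFT checkpoint, "+ GRPO" is SFT followed by GRPO (the SFT→RLVR baseline), "PRISM" is the post-alignment checkpoint before RLVR, and "PRISM + GRPO" adds GRPO on top of it.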

| Method | MathVista | MathVerse | MathVision | WeMath | MMMU | MMMU-Pro | Hallusion Bench | **Avg.** |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| **Qwen3-VL-4B** | | | | | | | | |
| Instruct | 74.9 | 59.0 | 36.5 | 70.7 | 63.6 | 45.1 | 68.2 | 59.7 |
| + SFT | 71.5 | 58.4 | 31.9 | 70.6 | 53.6 | 42.8 | 69.1 | 56.8 |
| + GRPO | 75.7 | 64.5 | 35.5 | 77.8 | 60.1 | 47.3 | 72.0 | 61.8 |
| **PRISM** | 71.0 | 59.5 | 30.6 | 67.5 | 56.3 | 42.8 | 72.6 | 57.2 |
| **PRISM + GRPO** | **77.9** | **68.6** | **45.4** | **82.9** | **64.1** | **49.7** | **74.8** | **66.2** |
| **Qwen3-VL-8B** | | | | | | | | |
| Instruct | 76.0 | 62.4 | 43.7 | 71.7 | 65.6 | 52.3 | 71.6 | 63.3 |
| + SFT | 70.2 | 60.4 | 32.6 | 73.4 | 56.3 | 42.9 | 71.2 | 58.1 |
| + GRPO | 75.9 | 66.9 | 37.1 | 79.7 | 62.6 | 48.8 | 71.9 | 63.3 |
| **PRISM** | 71.4 | 62.2 | 37.1 | 73.1 | 58.4 | 43.4 | 69.5 | 59.3 |
| **PRISM + GRPO** | **78.3** | **71.3** | **52.0** | **86.4** | **66.6** | **53.3** | **77.2** | **69.3** |

*   **PRISM consistently improves RLVR:** PRISM+GRPO beats SFT→GRPO by **+4.4 avg (4B)** and **+6.0 avg (8B)**. Gains are consistent across DAPO and GSPO.
*   **Alignment corrects distribution, not immediate accuracy:** The PRISM checkpoint (post-alignment, pre-RLVR) has accuracy similar to SFT, confirming its role is distributional correction.
*   **SFT drift is more severe for stronger models:** SFT costs the 8B model 5.2 average points versus 2.9 for the 4B model, and standard RLVR only brings the 8B model back to the Instruct level (63.3). PRISM+GRPO exceeds the Instruct model by 6.0 points at 8B (6.5 at 4B).
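
The Avg. column appears consistent with an unweighted mean over the seven benchmarks. The short check below recomputes the averages and the headline gains from the per-benchmark numbers in Table 1; it is a verification aid, not part of the method.

```python
from statistics import fmean

# Per-benchmark accuracies (%) copied from Table 1, in column order:
# MathVista, MathVerse, MathVision, WeMath, MMMU, MMMU-Pro, HallusionBench.
TABLE1 = {
    "4B": {
        "Instruct": [74.9, 59.0, 36.5, 70.7, 63.6, 45.1, 68.2],
        "SFT->GRPO": [75.7, 64.5, 35.5, 77.8, 60.1, 47.3, 72.0],
        "PRISM+GRPO": [77.9, 68.6, 45.4, 82.9, 64.1, 49.7, 74.8],
    },
    "8B": {
        "Instruct": [76.0, 62.4, 43.7, 71.7, 65.6, 52.3, 71.6],
        "SFT->GRPO": [75.9, 66.9, 37.1, 79.7, 62.6, 48.8, 71.9],
        "PRISM+GRPO": [78.3, 71.3, 52.0, 86.4, 66.6, 53.3, 77.2],
    },
}

# Unweighted mean over the seven benchmarks, rounded to one decimal like the table.
AVG = {
    scale: {method: round(fmean(scores), 1) for method, scores in rows.items()}
    for scale, rows in TABLE1.items()
}

for scale, avg in AVG.items():
    gain_vs_sft_grpo = round(avg["PRISM+GRPO"] - avg["SFT->GRPO"], 1)
    gain_vs_instruct = round(avg["PRISM+GRPO"] - avg["Instruct"], 1)
    print(f"{scale}: {avg} | vs SFT->GRPO: +{gain_vs_sft_grpo} | vs Instruct: +{gain_vs_instruct}")
```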

### Ablation Study
**Table 2: Ablation study results (Qwen3-VL-4B with GRPO).**

| Setting | MathVista | MathVerse | MathVision | WeMath | MMMU | MMMU-Pro | Hallusion Bench | **Avg.** |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| **PRISM (full)** | 77.9 | 68.6 | 45.4 | 82.9 | 64.1 | 49.7 | 74.8 | **66.2** |
| *Discriminator Design* | | | | | | | | |
| Dense 4B disc. | 74.6 | 63.7 | 41.8 | 76.9 | 61.3 | 47.1 | 74.0 | 62.8 |
| Text-only disc. | 74.0 | 59.5 | 42.8 | 76.8 | 62.7 | 48.5 | 71.6 | 62.3 |
| *Pipeline Stages* | | | | | | | | |
| w/o Alignment | 75.7 | 64.5 | 35.5 | 77.8 | 60.1 | 47.3 | 72.0 | 61.8 |
| w/o SFT | 62.4 | 47.6 | 25.9 | 55.7 | 51.4 | 36.5 | 66.1 | 49.4 |
| *SFT Data Scale* | | | | | | | | |
| SFT-107K | 72.3 | 67.0 | 43.1 | 76.9 | 60.6 | 49.0 | 68.3 | 62.5 |

*   **MoE Discriminator is crucial:** Replacing it with a single dense discriminator causes a **-3.4 avg drop**. The MoE design provides sharper, disentangled signals.
*   **Three-stage pipeline is necessary:** Removing alignment reverts to the inferior baseline (**-4.4 avg**). Removing SFT causes catastrophic failure (**-16.8 avg**), as the initial capability gap is too large for adversarial training.
*   **Vision-language discriminator is needed:** A text-only discriminator causes a **-3.9 avg** drop and leads to "parrot alignment", with the largest losses on perception-heavy tasks such as MathVerse (-9.1).
*   **SFT data scale matters:** Using only the curated 107K samples for SFT results in weaker initialization and **-3.7 avg** performance.

### Analysis
*   **Training Dynamics:** Tracking the reward gap $D_k(x, y^+_k) - D_k(x, y^-_k)$ shows that the perception expert converges quickly, while the reasoning expert converges more gradually and oscillates, reflecting the different kinds of correction each expert provides. Both stabilize after ~500 steps.
*   **Structural Proxies of Distribution Alignment:** Analyzing proxies like the number of reasoning steps and descriptive items in captions shows:
    *   The base model deviates from supervision.
    *   SFT reduces but does not eliminate the mismatch.
    *   The alignment stage substantially aligns the policy distribution with supervision along both dimensions.
    *   These alignment gains persist through subsequent RLVR.
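
The structural-proxy comparison can be sketched as follows. The summary does not say how reasoning steps or descriptive caption items are counted, so the counting rules here (one step per non-empty reasoning line, one item per caption sentence) and the field names `caption` and `reasoning` are hypothetical. The gap is the absolute difference between the policy's and the supervision's mean proxy values; a smaller gap indicates closer structural alignment.

```python
import re
from statistics import fmean
from typing import Dict, List


def count_reasoning_steps(trace: str) -> int:
    """Hypothetical proxy: one reasoning step per non-empty line of the trace."""
    return sum(1 for line in trace.splitlines() if line.strip())


def count_caption_items(caption: str) -> int:
    """Hypothetical proxy: one descriptive item per sentence ending in '.', '!' or '?'."""
    return sum(1 for part in re.split(r"[.!?]+", caption) if part.strip())


def proxy_gaps(policy: List[Dict[str, str]], supervision: List[Dict[str, str]]) -> Dict[str, float]:
    """Absolute difference between policy and supervision means for each proxy."""
    proxies = {
        "reasoning_steps": ("reasoning", count_reasoning_steps),
        "caption_items": ("caption", count_caption_items),
    }
    gaps = {}
    for name, (field, counter) in proxies.items():
        policy_mean = fmean([counter(sample[field]) for sample in policy])
        supervision_mean = fmean([counter(sample[field]) for sample in supervision])
        gaps[name] = abs(policy_mean - supervision_mean)
    return gaps


# Made-up samples: the policy output is terser than the supervision.
supervision = [{
    "caption": "A bar chart with three bars. The tallest bar is blue. The axis is labeled in percent.",
    "reasoning": "Read each bar height.\nCompare the three heights.\nThe blue bar is highest.",
}]
policy = [{
    "caption": "A chart.",
    "reasoning": "The blue bar is highest.",
}]
print(proxy_gaps(policy, supervision))
```

Running the same comparison on checkpoints from each stage (base, SFT, alignment, RLVR) would yield the kind of trend described above, though the actual proxies used in the paper may be defined differently.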

## Theoretical and Practical Implications
*   **Theoretical:** PRISM formally addresses the **heterogeneous distributional drift problem** in multimodal post-training. It demonstrates that **on-policy distillation can be effectively decoupled from RL** and serve as a standalone alignment mechanism, especially when enhanced with task-aware, disentangled reward signals (MoE).
*   **Practical:** The pipeline provides a **reliable method to improve RLVR outcomes** for LMMs, making RL optimization more stable and effective by starting from a better-aligned policy. The curated high-quality dataset and the structured three-part output format (caption, reasoning, answer) offer a valuable resource and template for multimodal reasoning training. The framework is **agnostic to the specific RL algorithm**, enhancing its general applicability.

## Conclusion
PRISM introduces a novel three-stage post-training pipeline that mitigates SFT-induced distributional drift via an explicit alignment stage based on **black-box adversarial on-policy distillation**. Key innovations include:
*   An **MoE discriminator** with dedicated perception and reasoning experts for disentangled corrective feedback.
*   A **logit-free formulation** that relies only on response samples, removing dependency on teacher model internals.
*   A **high-quality, densely grounded multimodal reasoning corpus** for supervision.

Extensive experiments validate that PRISM **consistently improves downstream RLVR performance** across model scales, benchmarks, and RL algorithms. Analysis confirms that the alignment stage successfully narrows the distributional gap, providing a stronger initialization for RL. Future work may focus on reducing the training overhead, extending the framework to tasks without natural output decomposition, and developing better alignment metrics.

---

_Markdown view of https://picx.dev/p/jKiYlA, served by PicX — AI-generated visual whiteboard summaries of research papers._
