Summary of "Beyond SFT-to-RL: Pre-alignment via Black-Box On-Policy Distillation for Multimodal RL"

Summary (Overview)

  • Core Problem: The standard two-stage post-training pipeline for Large Multimodal Models (LMMs)—Supervised Fine-Tuning (SFT) followed by Reinforcement Learning with Verifiable Rewards (RLVR)—introduces distributional drift. SFT causes the model to deviate from both its original capabilities and the target supervision distribution, a problem amplified in multimodal reasoning where perception and reasoning errors compound.
  • Proposed Solution: PRISM (Pre-alignment via black-box on-policy dIStillation for Multimodal reinforcement learning), a three-stage pipeline that inserts an explicit distribution-alignment stage between SFT and RLVR. This stage uses adversarial on-policy distillation with a Mixture-of-Experts (MoE) discriminator to correct drift.
  • Key Innovation: The alignment is formulated as a logit-free, adversarial game between the policy and an MoE discriminator with dedicated perception and reasoning experts, providing disentangled corrective signals without needing teacher model logits.
  • High-Quality Data: The authors curate a 113K high-fidelity multimodal reasoning corpus from Gemini 3 Flash, featuring dense visual grounding and step-by-step reasoning on hard problems, supplementing 1.26M public demonstrations for SFT.
  • Empirical Results: Experiments on Qwen3-VL-4B/8B show PRISM consistently improves downstream RLVR performance across multiple algorithms (GRPO, DAPO, GSPO). PRISM+GRPO outperforms the SFT→RLVR baseline by +4.4 and +6.0 average accuracy points on the 4B and 8B models, respectively.

Introduction and Theoretical Foundation

The prevailing paradigm for improving LMMs involves Supervised Fine-Tuning (SFT) on curated demonstrations followed by Reinforcement Learning with Verifiable Rewards (RLVR). While SFT bootstraps capabilities and RLVR refines performance, recent studies reveal a counterintuitive issue: SFT can introduce distributional drift. The model neither perfectly matches the demonstration policy nor retains its original favorable distribution. This drift is especially costly for stronger base models and is heterogeneous in multimodal settings—visual grounding and logical reasoning degrade in qualitatively different ways, compounding errors during RL.

The paper posits that repairing this drift before RL is crucial. It builds upon On-Policy Distillation (OPD), which mitigates exposure bias by training a model on its own generations rather than static teacher-forced targets. PRISM repositions OPD as a standalone intermediate alignment stage, extending the SFT→RL recipe to SFT → Alignment → RLVR.

Methodology

PRISM is a three-stage pipeline: 1) Cold-Start SFT, 2) Distribution Alignment via On-Policy Distillation, and 3) RLVR.

Stage 1: Cold-Start SFT

  • Objective: Provide an initial multimodal reasoning policy.
  • Data: Combines a curated 113K corpus (107K for SFT) from Gemini 3 Flash with 1.26M public demonstrations. The curated data targets hard problems and includes detailed visual descriptions and reasoning traces.
  • Process: Standard full-parameter fine-tuning (vision tower frozen) for 1 epoch.

Stage 2: Distribution Alignment via On-Policy Distillation

This core stage repairs SFT-induced drift. It is formulated as a minimax game between the policy ($G$) and an MoE discriminator.

  • MoE Discriminator Design: Comprises two experts providing disentangled feedback:
    • Perception Expert ($D_v$): Evaluates the visual description ($c$) for grounding fidelity.
    • Reasoning Expert ($D_r$): Evaluates the reasoning trace ($t$) for logical consistency.
    • The combined discriminator reward for a response $y$ (with components $c$ and $t$) given input $x$ is:
      $$r(x, y) = \alpha \cdot D_v(x, c) + (1 - \alpha) \cdot D_r(x, t) \qquad \text{(Eq. 1)}$$
      where $\alpha$ controls the trade-off (set to 0.5).
  • Adversarial Training:
    • Discriminator Loss: Each expert $D_k$ is trained with a Bradley-Terry loss to distinguish supervision responses ($y^+$) from policy rollouts ($y^-$):
      $$\mathcal{L}_{D_k} = -\mathbb{E}_{(x, y^+, y^-) \sim T}\left[\log \sigma\left(D_k(x, y^+_k) - D_k(x, y^-_k)\right)\right], \quad k \in \{v, r\} \qquad \text{(Eq. 2)}$$
      where $\sigma$ is the logistic sigmoid and $y_k$ is the component of $y$ scored by expert $k$ (the visual description $c$ for $k = v$, the reasoning trace $t$ for $k = r$).
    • Policy Update: For a prompt $x$, sample $N$ responses $\{y^-_i\}_{i=1}^N$ from the current policy and compute normalized advantages within the group:
      $$A_i = \frac{r(x, y^-_i) - \text{mean}\left(\{r(x, y^-_j)\}_{j=1}^N\right)}{\text{std}\left(\{r(x, y^-_j)\}_{j=1}^N\right)} \qquad \text{(Eq. 3)}$$
      The policy is then updated with the GRPO objective using these advantages. KL regularization is removed so that the policy's distribution can shift freely.
  • Initialization: The policy is initialized from the SFT checkpoint. The MoE discriminator experts are warm-started on preference pairs from their respective components (visual descriptions, reasoning traces).
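The three Stage 2 formulas (Eq. 1 to Eq. 3) can be sketched in plain Python as below. This is a minimal illustration, not the authors' implementation. It assumes the expert scores $D_v(x, c)$ and $D_r(x, t)$ are already available as floats, approximates the expectation in Eq. 2 by a batch mean, and, as an assumption not stated in the summary, uses a population standard deviation plus a small `eps` in Eq. 3 so that a group with identical rewards yields zero advantages instead of a division by zero.

```python
import math
from statistics import mean, pstdev
from typing import List, Sequence


def moe_reward(d_v: float, d_r: float, alpha: float = 0.5) -> float:
    """Eq. 1: blend the perception-expert and reasoning-expert scores."""
    return alpha * d_v + (1.0 - alpha) * d_r


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def bradley_terry_loss(pos_scores: Sequence[float], neg_scores: Sequence[float]) -> float:
    """Eq. 2 for one expert k: batch mean of -log sigma(D_k(x, y+_k) - D_k(x, y-_k)).

    pos_scores[i] is the expert's score on the supervision component y+_k and
    neg_scores[i] its score on the matching policy-rollout component y-_k.
    """
    if len(pos_scores) != len(neg_scores) or not pos_scores:
        raise ValueError("need equally sized, non-empty score lists")
    total = sum(log_sigmoid(p - n) for p, n in zip(pos_scores, neg_scores))
    return -total / len(pos_scores)


def group_advantages(rewards: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Eq. 3: normalize the rewards of the N rollouts sampled for one prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std (assumption); eps guards against sigma == 0
    return [(r - mu) / (sigma + eps) for r in rewards]


if __name__ == "__main__":
    # Hypothetical (D_v, D_r) scores for N = 4 rollouts of one prompt.
    scores = [(0.9, 0.4), (0.2, 0.6), (0.5, 0.5), (0.1, 0.1)]
    rewards = [moe_reward(d_v, d_r) for d_v, d_r in scores]
    print(rewards)
    print(group_advantages(rewards))
```

Because the advantages are centred and rescaled within each group, shifting or positively rescaling all discriminator scores for a prompt leaves them essentially unchanged (up to `eps`). The policy therefore responds to how rollouts compare with one another under the MoE reward, not to the reward's absolute level.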

Stage 3: Reinforcement Learning with Verifiable Rewards

  • Start Point: Policy checkpoint from the alignment stage.
  • Data: A difficulty-filtered subset (~2K samples) of the 6K curated samples held out from SFT (113K curated minus the 107K used in Stage 1), keeping only prompts on which the aligned policy's pass rate is between 0.2 and 0.8.
  • Reward: Switches from the learned MoE reward to a deterministic verifiable reward:
    $$r_v(x, y) = r_{\text{acc}}(x, y) + r_{\text{fmt}}(x, y) \qquad \text{(Eq. 5)}$$
    where $r_{\text{acc}}$ rewards answer correctness and $r_{\text{fmt}}$ rewards adherence to the output format.
  • Optimization: Standard outcome-based RLVR (compatible with GRPO, DAPO, GSPO).
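The Stage 3 data filter and Eq. 5 can be sketched as follows. Everything concrete here is an assumption made for illustration. The pass rate is estimated from a list of 0/1 outcomes of the aligned policy per prompt, and the 0.2 to 0.8 band is treated as inclusive. The final answer is assumed to sit inside hypothetical `<answer>...</answer>` tags, and both reward terms are taken to be binary with equal weight. The paper's actual answer matching and term weights may differ.

```python
import re
from typing import Dict, List, Sequence

# Hypothetical answer markup; the actual output template may differ.
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def pass_rate(outcomes: Sequence[int]) -> float:
    """Fraction of sampled attempts (1 = correct, 0 = wrong) that were correct."""
    return sum(outcomes) / len(outcomes)


def difficulty_filter(samples: List[Dict], low: float = 0.2, high: float = 0.8) -> List[Dict]:
    """Keep prompts the aligned policy solves sometimes but not always (inclusive band)."""
    return [s for s in samples if low <= pass_rate(s["outcomes"]) <= high]


def verifiable_reward(response: str, gold: str) -> float:
    """Eq. 5: r_v = r_acc + r_fmt, both modelled here as hypothetical 0/1 indicators."""
    match = ANSWER_RE.search(response)
    r_fmt = 1.0 if match else 0.0
    # Exact string match after stripping whitespace; real graders are usually more lenient.
    r_acc = 1.0 if match and match.group(1).strip() == gold.strip() else 0.0
    return r_acc + r_fmt
```

A band of this kind is a common choice for group-normalized RL. Prompts the policy always solves, or never solves, show little or no variation in the accuracy reward within a group, so their advantages carry almost no learning signal.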

Empirical Validation / Results

Experiments are conducted on Qwen3-VL-4B and 8B across mathematical reasoning (MathVista, MathVerse, MathVision, WeMath) and general multimodal understanding benchmarks (MMMU, MMMU-Pro, HallusionBench).

Main Results

Table 1: Main results on mathematical reasoning and general multimodal benchmarks (Accuracy %).

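Method           | MathVista | MathVerse | MathVision | WeMath | MMMU | MMMU-Pro | HallusionBench | Avg.
-----------------|-----------|-----------|------------|--------|------|----------|----------------|-----
Qwen3-VL-4B
Instruct         | 74.9      | 59.0      | 36.5       | 70.7   | 63.6 | 45.1     | 68.2           | 59.7
+ SFT            | 71.5      | 58.4      | 31.9       | 70.6   | 53.6 | 42.8     | 69.1           | 56.8
+ SFT + GRPO     | 75.7      | 64.5      | 35.5       | 77.8   | 60.1 | 47.3     | 72.0           | 61.8
PRISM            | 71.0      | 59.5      | 30.6       | 67.5   | 56.3 | 42.8     | 72.6           | 57.2
PRISM + GRPO     | 77.9      | 68.6      | 45.4       | 82.9   | 64.1 | 49.7     | 74.8           | 66.2
Qwen3-VL-8B
Instruct         | 76.0      | 62.4      | 43.7       | 71.7   | 65.6 | 52.3     | 71.6           | 63.3
+ SFT            | 70.2      | 60.4      | 32.6       | 73.4   | 56.3 | 42.9     | 71.2           | 58.1
+ SFT + GRPO     | 75.9      | 66.9      | 37.1       | 79.7   | 62.6 | 48.8     | 71.9           | 63.3
PRISM            | 71.4      | 62.2      | 37.1       | 73.1   | 58.4 | 43.4     | 69.5           | 59.3
PRISM + GRPO     | 78.3      | 71.3      | 52.0       | 86.4   | 66.6 | 53.3     | 77.2           | 69.3
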
  • PRISM consistently improves RLVR: PRISM+GRPO beats SFT→GRPO by +4.4 avg (4B) and +6.0 avg (8B). Gains are consistent across DAPO and GSPO.
  • Alignment corrects distribution, not immediate accuracy: The PRISM checkpoint (post-alignment, pre-RLVR) scores close to the SFT checkpoint (57.2 vs. 56.8 avg on 4B; 59.3 vs. 58.1 on 8B), consistent with its role as a distributional correction rather than an accuracy booster.
  • SFT drift is more severe for stronger models: SFT lowers the 8B average by 5.2 points (63.3 → 58.1) versus 2.9 points for 4B (59.7 → 56.8), and SFT→GRPO only returns the 8B model to its Instruct average (63.3). PRISM+GRPO exceeds the base Instruct model by 6.0 points (69.3 vs. 63.3).
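The headline gaps follow directly from Table 1. The short script below copies the per-benchmark scores and checks that the Avg. column is the unweighted mean of the seven benchmarks, rounded to one decimal. It then recomputes the gaps quoted in the bullets (PRISM+GRPO vs. SFT→GRPO, and the drop from Instruct to SFT). It is only a sanity check on the reported figures; the dictionary keys are labels chosen for this sketch.

```python
from typing import Dict, List, Tuple

# Per-benchmark accuracies from Table 1, in the order MathVista, MathVerse,
# MathVision, WeMath, MMMU, MMMU-Pro, HallusionBench, paired with the reported Avg.
TABLE1: Dict[Tuple[str, str], Tuple[List[float], float]] = {
    ("4B", "Instruct"): ([74.9, 59.0, 36.5, 70.7, 63.6, 45.1, 68.2], 59.7),
    ("4B", "SFT"): ([71.5, 58.4, 31.9, 70.6, 53.6, 42.8, 69.1], 56.8),
    ("4B", "SFT->GRPO"): ([75.7, 64.5, 35.5, 77.8, 60.1, 47.3, 72.0], 61.8),
    ("4B", "PRISM"): ([71.0, 59.5, 30.6, 67.5, 56.3, 42.8, 72.6], 57.2),
    ("4B", "PRISM+GRPO"): ([77.9, 68.6, 45.4, 82.9, 64.1, 49.7, 74.8], 66.2),
    ("8B", "Instruct"): ([76.0, 62.4, 43.7, 71.7, 65.6, 52.3, 71.6], 63.3),
    ("8B", "SFT"): ([70.2, 60.4, 32.6, 73.4, 56.3, 42.9, 71.2], 58.1),
    ("8B", "SFT->GRPO"): ([75.9, 66.9, 37.1, 79.7, 62.6, 48.8, 71.9], 63.3),
    ("8B", "PRISM"): ([71.4, 62.2, 37.1, 73.1, 58.4, 43.4, 69.5], 59.3),
    ("8B", "PRISM+GRPO"): ([78.3, 71.3, 52.0, 86.4, 66.6, 53.3, 77.2], 69.3),
}


def avg(scores: List[float]) -> float:
    """Unweighted mean over the seven benchmarks, rounded like the table."""
    return round(sum(scores) / len(scores), 1)


def gap(size: str, a: str, b: str) -> float:
    """Average-accuracy difference a - b for one model size."""
    return round(avg(TABLE1[(size, a)][0]) - avg(TABLE1[(size, b)][0]), 1)


if __name__ == "__main__":
    for key, (scores, reported) in TABLE1.items():
        assert avg(scores) == reported, key
    for size in ("4B", "8B"):
        print(size, "PRISM+GRPO minus SFT->GRPO:", gap(size, "PRISM+GRPO", "SFT->GRPO"))
        print(size, "Instruct minus SFT:", gap(size, "Instruct", "SFT"))
```

Running it reports gains of 4.4 (4B) and 6.0 (8B), and SFT drops of 2.9 (4B) and 5.2 (8B).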

Ablation Study

Table 2: Ablation study results (Qwen3-VL-4B with GRPO).

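Setting          | MathVista | MathVerse | MathVision | WeMath | MMMU | MMMU-Pro | HallusionBench | Avg.
-----------------|-----------|-----------|------------|--------|------|----------|----------------|-----
PRISM (full)     | 77.9      | 68.6      | 45.4       | 82.9   | 64.1 | 49.7     | 74.8           | 66.2
Discriminator Design
Dense 4B disc.   | 74.6      | 63.7      | 41.8       | 76.9   | 61.3 | 47.1     | 74.0           | 62.8
Text-only disc.  | 74.0      | 59.5      | 42.8       | 76.8   | 62.7 | 48.5     | 71.6           | 62.3
Pipeline Stages
w/o Alignment    | 75.7      | 64.5      | 35.5       | 77.8   | 60.1 | 47.3     | 72.0           | 61.8
w/o SFT          | 62.4      | 47.6      | 25.9       | 55.7   | 51.4 | 36.5     | 66.1           | 49.4
SFT Data Scale
SFT-107K         | 72.3      | 67.0      | 43.1       | 76.9   | 60.6 | 49.0     | 68.3           | 62.5
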
  • MoE Discriminator is crucial: Replacing it with a single dense discriminator causes a -3.4 avg drop. The MoE design provides sharper, disentangled signals.
  • Three-stage pipeline is necessary: Removing alignment reverts to the inferior baseline (-4.4 avg). Removing SFT causes catastrophic failure (-16.8 avg), as the initial capability gap is too large for adversarial training.
  • Vision-language discriminator is needed: A text-only discriminator leads to "parrot alignment" and a -3.9 avg drop, especially on perception-heavy tasks (its largest per-benchmark drop is on MathVerse, -9.1).
  • SFT data scale matters: Using only the curated 107K samples for SFT gives a weaker initialization and a -3.7 avg drop.
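The ablation deltas can be recomputed from Table 2 in the same way. The sketch below also reports, for each ablation, the benchmark with the largest drop relative to the full model. The setting names are the table's row labels.

```python
from typing import Dict, List, Tuple

BENCHMARKS = ["MathVista", "MathVerse", "MathVision", "WeMath",
              "MMMU", "MMMU-Pro", "HallusionBench"]

# Per-benchmark accuracies from Table 2 (Qwen3-VL-4B with GRPO).
FULL = [77.9, 68.6, 45.4, 82.9, 64.1, 49.7, 74.8]
ABLATIONS: Dict[str, List[float]] = {
    "Dense 4B disc.": [74.6, 63.7, 41.8, 76.9, 61.3, 47.1, 74.0],
    "Text-only disc.": [74.0, 59.5, 42.8, 76.8, 62.7, 48.5, 71.6],
    "w/o Alignment": [75.7, 64.5, 35.5, 77.8, 60.1, 47.3, 72.0],
    "w/o SFT": [62.4, 47.6, 25.9, 55.7, 51.4, 36.5, 66.1],
    "SFT-107K": [72.3, 67.0, 43.1, 76.9, 60.6, 49.0, 68.3],
}


def avg(scores: List[float]) -> float:
    """Unweighted mean over the seven benchmarks, rounded to one decimal."""
    return round(sum(scores) / len(scores), 1)


def summarize(setting: str) -> Tuple[float, str, float]:
    """Return (avg delta vs. full PRISM, worst benchmark, its delta)."""
    scores = ABLATIONS[setting]
    delta_avg = round(avg(scores) - avg(FULL), 1)
    per_bench = [round(s - f, 1) for s, f in zip(scores, FULL)]
    worst = min(range(len(per_bench)), key=per_bench.__getitem__)
    return delta_avg, BENCHMARKS[worst], per_bench[worst]


if __name__ == "__main__":
    for name in ABLATIONS:
        print(name, summarize(name))
```

Beyond the averages, the per-benchmark view shows where each ablation hurts. For example, removing the alignment stage costs the most on MathVision (-9.9), and the text-only discriminator costs the most on MathVerse (-9.1).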

Analysis

  • Training Dynamics: The reward gap $D_k(x, y^+_k) - D_k(x, y^-_k)$ shows that the perception expert converges quickly, while the reasoning expert converges more gradually and with oscillation, reflecting the different kinds of correction each expert provides. Both stabilize after ~500 steps.
  • Structural Proxies of Distribution Alignment: Analyzing proxies like the number of reasoning steps and descriptive items in captions shows:
    • The base model deviates from supervision.
    • SFT reduces but does not eliminate the mismatch.
    • The alignment stage substantially aligns the policy distribution with supervision along both dimensions.
    • These alignment gains persist through subsequent RLVR.
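The structural proxies can be approximated with simple text heuristics. The sketch below is hypothetical throughout. It assumes responses use `<caption>`, `<reasoning>` and `<answer>` tags for the three-part format, counts one reasoning step per non-empty line, and counts one descriptive item per sentence or semicolon-separated clause in the caption. The paper's exact proxy definitions are not given in this summary, so treat the numbers as illustrative. The point is the comparison of mean counts between policy samples and supervision responses.

```python
import re
from statistics import mean
from typing import Dict, List

# Hypothetical tags for the caption / reasoning / answer output format.
PARTS = ("caption", "reasoning", "answer")
TAG_RE = {name: re.compile(rf"<{name}>(.*?)</{name}>", re.DOTALL) for name in PARTS}


def split_response(text: str) -> Dict[str, str]:
    """Extract each tagged part; a missing part becomes an empty string."""
    parts = {}
    for name, pattern in TAG_RE.items():
        match = pattern.search(text)
        parts[name] = match.group(1).strip() if match else ""
    return parts


def count_reasoning_steps(reasoning: str) -> int:
    """Heuristic proxy: one step per non-empty line."""
    return sum(1 for line in reasoning.splitlines() if line.strip())


def count_caption_items(caption: str) -> int:
    """Heuristic proxy: one item per sentence or ';'-separated clause.

    Splitting only on '.' or ';' followed by whitespace or end of text keeps
    decimals such as '3.5' intact.
    """
    return len([p for p in re.split(r"[.;](?:\s+|$)", caption) if p.strip()])


def proxy_profile(responses: List[str]) -> Dict[str, float]:
    """Mean step count and caption-item count over a set of responses."""
    parts = [split_response(r) for r in responses]
    return {
        "steps": mean(count_reasoning_steps(p["reasoning"]) for p in parts),
        "caption_items": mean(count_caption_items(p["caption"]) for p in parts),
    }


def proxy_gap(policy: List[str], supervision: List[str]) -> Dict[str, float]:
    """Absolute gap between policy and supervision proxy means (smaller = closer)."""
    pol, sup = proxy_profile(policy), proxy_profile(supervision)
    return {k: abs(pol[k] - sup[k]) for k in pol}
```

Under this reading, the pattern described above would appear as a proxy gap that is large for the base model, smaller after SFT, and smallest after alignment.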

Theoretical and Practical Implications

  • Theoretical: PRISM formally addresses the heterogeneous distributional drift problem in multimodal post-training. It demonstrates that on-policy distillation can be effectively decoupled from RL and serve as a standalone alignment mechanism, especially when enhanced with task-aware, disentangled reward signals (MoE).
  • Practical: The pipeline provides a reliable method to improve RLVR outcomes for LMMs, making RL optimization more stable and effective by starting from a better-aligned policy. The curated high-quality dataset and the structured three-part output format (caption, reasoning, answer) offer a valuable resource and template for multimodal reasoning training. The framework is agnostic to the specific RL algorithm, enhancing its general applicability.

Conclusion

PRISM introduces a novel three-stage post-training pipeline that mitigates SFT-induced distributional drift via an explicit alignment stage based on black-box adversarial on-policy distillation. Key innovations include:

  • An MoE discriminator with dedicated perception and reasoning experts for disentangled corrective feedback.
  • A logit-free formulation that relies only on response samples, removing dependency on teacher model internals.
  • A high-quality, densely grounded multimodal reasoning corpus for supervision.

Extensive experiments validate that PRISM consistently improves downstream RLVR performance across model scales, benchmarks, and RL algorithms. Analysis confirms that the alignment stage successfully narrows the distributional gap, providing a stronger initialization for RL. Future work may focus on reducing the training overhead, extending the framework to tasks without natural output decomposition, and developing better alignment metrics.