MARS: Enabling Autoregressive Models Multi-Token Generation - Summary
Summary (Overview)
- Lightweight Fine-tuning for Multi-Token Generation: MARS (Mask AutoRegreSion) is a method that fine-tunes an instruction-tuned autoregressive (AR) model to optionally generate multiple tokens per forward pass, with no architectural changes, no extra parameters, and using only the original instruction data.
- Strict Superset of AR Models: The resulting model is a strict superset of the original AR model. When generating one token per step (τ=1.0), it matches or exceeds baseline quality on average. When confident, it can generate multiple tokens per step for speedup.
- Key Insights for Preservation: The method closes three eliminable design gaps between AR and block-masked prediction (attention pattern, logits alignment, generation order). It also adds an auxiliary SFT loss that preserves the model's AR competence during training and prevents degradation at larger block sizes.
- Practical Speedups and Control: MARS achieves 1.5–1.7× algorithmic throughput while maintaining baseline-level accuracy. With a block-level KV caching strategy, it achieves up to 1.71× wall-clock speedup. A confidence threshold (τ) provides a real-time, on-the-fly latency–quality knob for deployment.
Introduction and Theoretical Foundation
Autoregressive (AR) language models (LMs) generate text one token at a time and spend the same computation on every token, however predictable it is. Existing approaches to multi-token generation modify the deployment stack, adding complexity and parameters. Speculative decoding requires a separate draft model, and multi-head approaches such as Medusa add extra prediction heads.
This paper takes a different angle: Can lightweight fine-tuning alone teach an AR model to optionally generate multiple tokens per forward pass while preserving its original AR behavior? The goal is a model that is a strict superset of the original.
A natural starting point is block masked diffusion, which trains a model to predict multiple masked tokens in parallel. However, prior conversions of AR models to this paradigm result in significant quality degradation. The authors identify four gaps between AR and block diffusion:
- (1) Token Masking: Inherent to multi-token prediction.
- (2) Attention Pattern: Prior work uses bidirectional attention within blocks, breaking AR causality.
- (3) Logits Alignment: Prior work may not use the AR convention of predicting x_{t+1} from position t.
- (4) Generation Order: Confidence-based diffusion unmasking breaks the strict left-to-right AR order.
MARS is designed to close gaps (2), (3), and (4), leaving only the inherent masking gap (1). This ensures the model remains a fully functional AR model.
Methodology
Training Objective and Mask Design
MARS starts from an AR Supervised Fine-Tuning (SFT) checkpoint. The core idea is to process two copies of a sequence in parallel within a single forward pass using a structured attention mask:
- Clean Stream (x): Original tokens, used for the standard AR next-token prediction loss (L_AR).
- Noisy Stream (\tilde{x}): Tokens are divided into blocks of size B. All tokens in each block are replaced with [MASK] placeholders, and the model is trained to predict the masked tokens.
The concatenated input is z = [x; \tilde{x}] of length 2L. The attention mask M ∈ {0, -∞}^{2L×2L} enforces:
- Clean Causal: Clean positions attend causally to previous clean positions.
- Noisy Intra-block Causal: Noisy (masked) positions attend causally only within their own block (to other [MASK] tokens).
- Cross-Stream: A noisy block k can attend to all clean tokens from blocks 1, ..., k-1.
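The three mask rules can be made concrete with a minimal pure-Python sketch. It assumes that the sequence length L is a multiple of B and that clean positions occupy indices 0..L-1 of z while noisy positions occupy L..2L-1. It also ignores any prompt/response split. `mars_attention_mask` is a hypothetical helper, not the authors' code. As in M ∈ {0, -∞}^{2L×2L}, 0.0 marks an allowed attention edge and -inf a blocked one.

```python
import math

def mars_attention_mask(seq_len: int, block_size: int) -> list[list[float]]:
    """Additive mask M of shape 2L x 2L for z = [x; x_tilde] (hypothetical helper).

    Row q is the querying position and column k the attended position;
    0.0 means "may attend", -inf means "blocked".
    """
    if seq_len % block_size != 0:
        raise ValueError("this sketch assumes seq_len is a multiple of block_size")
    L, B = seq_len, block_size
    M = [[-math.inf] * (2 * L) for _ in range(2 * L)]
    for q in range(2 * L):
        for k in range(2 * L):
            q_noisy, k_noisy = q >= L, k >= L
            if not q_noisy and not k_noisy:
                # Clean causal: a clean token sees itself and earlier clean tokens.
                allowed = k <= q
            elif q_noisy and k_noisy:
                # Noisy intra-block causal: a [MASK] slot sees earlier slots of its own block.
                qi, ki = q - L, k - L
                allowed = qi // B == ki // B and ki <= qi
            elif q_noisy:
                # Cross-stream: noisy block j sees every clean token in blocks before j.
                allowed = k // B < (q - L) // B
            else:
                # The clean stream never attends to the noisy stream.
                allowed = False
            if allowed:
                M[q][k] = 0.0
    return M

M = mars_attention_mask(8, 4)
# First [MASK] of the second noisy block: clean block 1 plus itself.
print([k for k, v in enumerate(M[12]) if v == 0.0])  # [0, 1, 2, 3, 12]
```

A production version would build this as a tensor with vectorized comparisons. The nested loops here just keep each rule readable.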
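The training loss combines the masked prediction loss and the AR loss:
$$L = L_{\text{mask}} + L_{\text{AR}}$$
where L_mask is the cross-entropy on the masked positions of the noisy stream and L_AR is the standard next-token cross-entropy on the clean stream.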
Preserving Autoregressive Competence: The SFT Loss
Without the SFT loss (L_AR), the training signal that resembles standard AR next-token prediction decays as the block size B increases. Only one of every B masked predictions is made from a fully clean context, so the fraction of positions with fully clean context is:
$$\frac{1}{B}$$
For B=16, this is only 6.25%, leading to degradation.
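The combined loss keeps the AR-equivalent signal fraction above 50%. Out of 2L supervised positions, the clean stream contributes L AR-equivalent positions and the noisy stream contributes L/B:
$$\frac{L + L/B}{2L} = \frac{1}{2}\left(1 + \frac{1}{B}\right)$$
For B=16, this is 53.1%. The SFT loss is therefore crucial for maintaining the model's AR capabilities while it learns masked prediction.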
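Both fractions can be checked by counting positions. This sketch assumes that every supervised position in both streams carries equal weight, which the 53.1% figure implies. `ar_signal_fraction` is an illustrative name.

```python
from fractions import Fraction

def ar_signal_fraction(block_size: int, with_sft_loss: bool) -> Fraction:
    """Share of supervised positions whose prediction uses a fully clean context."""
    # Noisy stream alone: one of every B masked predictions sees only clean context.
    noisy_only = Fraction(1, block_size)
    if not with_sft_loss:
        return noisy_only
    # Clean stream (L positions, all AR-equivalent) plus noisy stream (L/B of L),
    # averaged over 2L supervised positions: (L + L/B) / 2L.
    return (1 + noisy_only) / 2

for B in (4, 8, 16):
    without = ar_signal_fraction(B, with_sft_loss=False)
    with_sft = ar_signal_fraction(B, with_sft_loss=True)
    print(f"B={B:2d}  without SFT loss: {float(without):.2%}  with SFT loss: {float(with_sft):.2%}")
```

As B grows, the combined fraction approaches 50% from above but never reaches it, while the masked-only fraction goes to zero.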
Inference: Sliding Window with Confidence Thresholding
At inference, B [MASK] tokens are appended to the current prefix. The model runs one forward pass (with pure causal attention) to get logits for all B positions.
- Starting from the leftmost position, tokens are accepted strictly left-to-right while the confidence of the top token, max_v p(x_t = v), exceeds a threshold τ.
- At least one token is always accepted (ensuring fallback to AR).
- The N accepted tokens join the prefix, and N new [MASK] tokens are appended, maintaining the window size B.
- The threshold τ controls the speed-quality tradeoff and can be adjusted in real time during serving.
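A minimal sketch of the sliding-window rule follows. It reads the acceptance rule as "N = length of the leading run of positions whose top probability strictly exceeds τ, floored at 1". Under that reading, τ=1.0 reduces to one token per pass. `forward`, `generate`, and `toy_forward` are hypothetical stand-ins, and `toy_forward` is a fake model, not MARS.

```python
from typing import Callable, Sequence

Dist = Sequence[float]  # probability distribution over the vocabulary
ForwardFn = Callable[[list[int], int], list[Dist]]

def num_accepted(window: Sequence[Dist], tau: float) -> int:
    """Length of the leading run of positions whose top probability exceeds tau, floored at 1."""
    n = 0
    for probs in window:
        if max(probs) > tau:
            n += 1
        else:
            break  # strictly left-to-right: stop at the first unconfident position
    return max(n, 1)  # always accept at least one token (AR fallback)

def argmax(probs: Dist) -> int:
    return max(range(len(probs)), key=lambda v: probs[v])

def generate(forward: ForwardFn, prompt: list[int], block_size: int,
             tau: float, max_new_tokens: int) -> tuple[list[int], int]:
    """Sliding-window decoding; forward(prefix, B) returns distributions for the B [MASK] slots."""
    seq = list(prompt)
    passes = 0
    while len(seq) - len(prompt) < max_new_tokens:
        window = forward(seq, block_size)  # one pass over prefix + B [MASK] tokens
        passes += 1
        n = num_accepted(window, tau)
        seq.extend(argmax(p) for p in window[:n])  # accepted tokens join the prefix
    return seq[len(prompt):len(prompt) + max_new_tokens], passes

def toy_forward(prefix: list[int], block_size: int) -> list[Dist]:
    """Hypothetical stand-in model: 3-token vocab, confident on the first two slots only."""
    window = []
    for i in range(block_size):
        conf = 0.99 if i < 2 else 0.5
        probs = [(1 - conf) / 2] * 3
        probs[(len(prefix) + i) % 3] = conf
        window.append(probs)
    return window

print(generate(toy_forward, [0], block_size=4, tau=0.95, max_new_tokens=8))  # ([1, 2, 0, 1, 2, 0, 1, 2], 4)
print(generate(toy_forward, [0], block_size=4, tau=1.0, max_new_tokens=8))   # ([1, 2, 0, 1, 2, 0, 1, 2], 8)
```

The toy model's multi-token guesses agree with its one-token predictions by construction, so both settings produce the same text. With a real model, tokens accepted at τ < 1.0 can differ from greedy AR output, and that is the quality cost the threshold trades against.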
Block-Level KV Caching for Batch Inference
To achieve wall-clock speedup in batch inference, MARS uses a block-level KV cache:
- Compute the prefix KV cache once per block via a full forward pass.
- Iterate within the block using the cached prefix; each inner step only forwards B new tokens.
- Once all samples in the batch have filled the current block, extend the cache with the completed block and advance to the next. Faster samples idle at block boundaries.
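The batch schedule can be illustrated with a scheduling-only simulation that counts forward passes and idle slots but performs no attention or caching. The per-step acceptance counts are supplied as input. Clipping accepted tokens at the block boundary is an assumption of this sketch, and all names (`simulate_block_cache`, `ScheduleStats`) are hypothetical.

```python
import itertools
from dataclasses import dataclass
from typing import Iterable

@dataclass
class ScheduleStats:
    prefix_passes: int = 0  # full forward passes that build or extend the prefix KV cache
    inner_steps: int = 0    # batched inner steps that reuse the cached prefix
    idle_slots: int = 0     # (sample, step) pairs spent waiting at a block boundary

def simulate_block_cache(accepts: list[Iterable[int]], cache_block: int,
                         total_tokens: int) -> ScheduleStats:
    """accepts[b] yields how many tokens sample b accepts at each inner step."""
    streams = [iter(a) for a in accepts]
    produced = [0] * len(streams)
    stats = ScheduleStats()
    block_start = 0
    while block_start < total_tokens:
        stats.prefix_passes += 1  # cache everything before block_start once per block
        boundary = min(block_start + cache_block, total_tokens)
        while any(p < boundary for p in produced):
            stats.inner_steps += 1
            for b, stream in enumerate(streams):
                if produced[b] < boundary:
                    # At least one token per step; clipped at the block boundary (assumption).
                    produced[b] = min(produced[b] + max(1, next(stream)), boundary)
                else:
                    stats.idle_slots += 1  # finished early; waits for the slowest sample
        block_start = boundary  # every sample has filled the block: advance together
    return stats

# Sample 0 accepts 2 tokens per step, sample 1 accepts 1 token per step.
stats = simulate_block_cache([itertools.cycle([2]), itertools.cycle([1])],
                             cache_block=4, total_tokens=8)
print(stats)  # ScheduleStats(prefix_passes=2, inner_steps=8, idle_slots=4)
```

In the example, the faster sample fills each 4-token block in two steps and then waits two steps for the slower one. This is the idling at block boundaries described above.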
Empirical Validation / Results
Setup
- Models: Qwen2.5-0.5B-Instruct and Qwen2.5-7B-Instruct, fine-tuned on Dolci-Instruct-SFT.
- Training: 5 epochs of standard AR SFT, followed by 5 epochs of MARS training.
- Evaluation: Six benchmarks: IFEval, BBH, MMLU-Pro, GPQA, GSM8K, HumanEval. Greedy decoding, max 256 new tokens.
One-Token Mode (τ = 1.0) Preserves/Improves Quality
Table 2: One-token mode (τ=1.0): MARS vs. AR SFT, compute-matched AR SFT (10 epochs), and Block Diffusion.
| Model | IFEval | BBH | MMLU-Pro | GPQA | GSM8K | HumanEval | Avg |
|---|---|---|---|---|---|---|---|
| Qwen2.5-0.5B-Instruct | |||||||
| AR SFT (5 ep) | 48.4 | 26.3 | 11.9 | 17.9 | 32.0 | 35.4 | 28.7 |
| AR SFT (10 ep) | 47.8 | 26.3 | 9.3 | 14.1 | 28.3 | 32.3 | 26.4 |
| Block Diffusion (B=4) | 47.1 | 7.5 | 2.0 | 17.9 | 30.6 | 31.7 | 22.8 |
| MARS-0.5B w/o SFT loss (B=4) | 48.5 | 27.4 | 12.3 | 19.0 | 29.5 | 33.5 | 28.4 |
| MARS-0.5B (B=4) | 51.3 | 26.6 | 12.4 | 19.4 | 32.8 | 40.2 | 30.4 |
| Qwen2.5-7B-Instruct | |||||||
| AR SFT | 67.0 | 54.0 | 43.9 | 27.5 | 68.7 | 78.7 | 56.6 |
| MARS-7B (B=4) | 68.2 | 54.6 | 44.4 | 26.6 | 73.2 | 81.7 | 58.1 |
Key Findings:
- MARS exceeds the AR SFT baseline on average at both scales (+1.7 avg at 0.5B, +1.5 avg at 7B).
- Extra AR SFT epochs (the compute-matched baseline) cause overfitting and degrade performance (28.7 → 26.4 avg).
- Block Diffusion, which uses bidirectional attention, collapses on reasoning tasks (BBH 7.5 vs. 26.3 and MMLU-Pro 2.0 vs. 11.9 for AR SFT).
The SFT Loss is Crucial for Larger Blocks
Table 3: Effect of SFT loss across block sizes (0.5B, τ=1.0).
| Model | IFEval | BBH | MMLU-Pro | GPQA | GSM8K | HumanEval | Avg |
|---|---|---|---|---|---|---|---|
| AR SFT | 48.4 | 26.3 | 11.9 | 17.9 | 32.0 | 35.4 | 28.7 |
| MARS w/o SFT loss | |||||||
| B=4 | 48.5 | 27.4 | 12.3 | 19.0 | 29.5 | 33.5 | 28.4 |
| B=8 | 48.5 | 24.3 | 11.1 | 20.8 | 24.9 | 28.7 | 26.4 |
| B=16 | 42.6 | 21.7 | 10.6 | 16.3 | 21.0 | 20.7 | 22.2 |
| MARS with SFT loss | |||||||
| B=4 | 51.3 | 26.6 | 12.4 | 19.4 | 32.8 | 40.2 | 30.4 |
| B=8 | 49.6 | 26.9 | 12.0 | 19.6 | 32.8 | 37.2 | 29.7 |
| B=16 | 50.7 | 27.0 | 12.1 | 17.9 | 33.8 | 36.6 | 29.7 |
- Without SFT loss, performance degrades sharply as B increases (Avg: 28.4 → 22.2 from B=4 to B=16).
- With SFT loss, performance remains stable across block sizes (Avg: 30.4 → 29.7) and stays above the AR SFT baseline (28.7).
Smooth Speed-Quality Tradeoff
Table 4: Multi-token mode (τ=0.95). Accuracy with change from τ=1.0 in parentheses; Tok/Fwd is the average number of tokens accepted per forward pass.
| Model | IFEval | BBH | MMLU-Pro | GPQA | GSM8K | HumanEval | Avg | Tok/Fwd |
|---|---|---|---|---|---|---|---|---|
| MARS-0.5B B=8 | 44.0 (-5.6) | 28.3 (+1.4) | 11.9 (-0.1) | 19.0 (-0.6) | 31.7 (-1.1) | 36.6 (-0.6) | 28.6 | 1.49 |
| MARS-7B B=4 | 63.0 (-5.2) | 54.3 (-0.3) | 44.2 (-0.2) | 27.7 (+1.1) | 71.0 (-2.2) | 80.5 (-1.2) | 56.8 | 1.68 |
- At τ=0.95, MARS achieves 1.49–1.68 tokens per forward pass. Average accuracy drops by only 1.1 points (0.5B) and 1.3 points (7B), and most of the loss is concentrated in IFEval (-5.6 and -5.2).
- The 7B model at τ=0.95 (56.8 avg) still exceeds the AR SFT baseline (56.6 avg).
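The summary figures in this list can be recomputed from the tables. The average change is the difference of the six-benchmark means. A Tok/Fwd of r means one-token mode would need r times as many forward passes, so multi-token mode saves a fraction 1 - 1/r of them. A small check (variable names are illustrative):

```python
from statistics import mean

# Accuracy copied from Tables 2-4 (IFEval, BBH, MMLU-Pro, GPQA, GSM8K, HumanEval).
one_token = {  # tau = 1.0
    "MARS-0.5B B=8": [49.6, 26.9, 12.0, 19.6, 32.8, 37.2],
    "MARS-7B B=4": [68.2, 54.6, 44.4, 26.6, 73.2, 81.7],
}
multi_token = {  # tau = 0.95
    "MARS-0.5B B=8": [44.0, 28.3, 11.9, 19.0, 31.7, 36.6],
    "MARS-7B B=4": [63.0, 54.3, 44.2, 27.7, 71.0, 80.5],
}
tok_per_fwd = {"MARS-0.5B B=8": 1.49, "MARS-7B B=4": 1.68}

def summarize(name: str) -> tuple[float, float]:
    """Return (change in average accuracy, fraction of forward passes saved)."""
    delta = mean(multi_token[name]) - mean(one_token[name])
    # One-token mode spends one pass per token; r tokens per pass saves 1 - 1/r of them.
    saved = 1 - 1 / tok_per_fwd[name]
    return delta, saved

for name in one_token:
    delta, saved = summarize(name)
    print(f"{name}: avg change {delta:+.1f} points, {saved:.0%} fewer forward passes")
```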
Wall-Clock Speedup with Block-Level KV Cache
Table 5: Batch inference on GSM8K (256 questions) with Qwen2.5-7B at τ=0.95 (batch sizes 4, 8, and 16 were evaluated).
| Method | B_cache | tok/s | time (s) | Speedup (×) |
|---|---|---|---|---|
| AR (KV cache) | – | 143.4 | 276.2 | 1.00 |
| MARS + block cache | 32 | 241.9 | 161.2 | 1.71 |
| MARS + block cache | 16 | 228.5 | 170.5 | 1.62 |
- Without caching, MARS is slower than AR at larger batch sizes.
- With block-level KV caching, MARS achieves up to 1.71× wall-clock speedup over AR with KV cache.
- The optimal cache granularity (B_cache) depends on batch size.
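The × column in Table 5 appears to be the ratio of total generation times rather than of tok/s. The two ratios differ slightly because throughput × time is not identical across rows. A quick check with the table's numbers:

```python
# Values copied from Table 5 (Qwen2.5-7B, GSM8K, tau = 0.95).
AR_TOK_S, AR_TIME_S = 143.4, 276.2
mars_rows = {32: (241.9, 161.2), 16: (228.5, 170.5)}  # B_cache -> (tok/s, time in s)

def speedups(tok_s: float, time_s: float) -> tuple[float, float]:
    """Return (wall-clock speedup from total time, throughput ratio from tok/s) versus AR."""
    return AR_TIME_S / time_s, tok_s / AR_TOK_S

for b_cache, (tok_s, time_s) in mars_rows.items():
    wall, thr = speedups(tok_s, time_s)
    print(f"B_cache={b_cache}: wall-clock {wall:.2f}x, tok/s ratio {thr:.2f}x")
# B_cache=32: wall-clock 1.71x, tok/s ratio 1.69x
# B_cache=16: wall-clock 1.62x, tok/s ratio 1.59x
```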
Theoretical and Practical Implications
- Theoretical: Demonstrates that causality is compatible with multi-token generation. The quality degradation in prior block-masked approaches was largely due to eliminable design choices (bidirectional attention, misaligned logits, non-left-to-right generation), not an inherent limitation.
- Practical Deployment: Provides a single-model solution that serves as both a high-quality AR model and a faster multi-token model. The confidence threshold (τ) acts as a real-time "latency–quality knob," allowing serving systems to dynamically increase throughput under high load without model swaps.
- Efficiency: Offers a lightweight fine-tuning path to significant inference acceleration (1.5–1.7× throughput, ~1.7× wall-clock speedup) without architectural changes or extra parameters, making it easy to integrate into existing production pipelines.
Conclusion
MARS is a lightweight fine-tuning method that transforms an instruction-tuned AR model into a strict superset capable of optional multi-token generation. By closing three eliminable gaps between AR and block-masked prediction and using a combined SFT+masked loss, it preserves baseline AR quality while enabling 1.5–1.7× throughput gains. With block-level KV caching, these translate to up to 1.71× wall-clock speedup. The method provides a practical, controllable speed-quality frontier for efficient LLM deployment.
Future Work & Limitations:
- Future: Exploring adaptive block sizes, cursor-based cache management, and integration with speculative decoding.
- Limitations: Training compute is ~2× standard SFT due to sequence concatenation. Aggressive thresholds (τ < 0.7) show quality loss. Block-level KV cache requires batch synchronization at boundaries.