MEDS: Memory-Enhanced Dynamic Reward Shaping for LLM Reinforcement Learning

Summary (Overview)

  • Core Problem: On-policy reinforcement learning (RL) for large language models (LLMs) often suffers from "error collapse," where the policy degenerates into generating highly repetitive erroneous reasoning patterns.
  • Proposed Solution: Introduces MEDS (Memory-Enhanced Dynamic reward Shaping), a framework that penalizes repeated historical error patterns to encourage broader exploration and reduce recurrent mistakes.
  • Key Mechanism: Uses layer-wise logits as lightweight, low-overhead representations of the model's internal reasoning process. These are clustered (using HDBSCAN) to identify and track recurring error patterns across training epochs.
  • Reward Shaping: Dynamically adjusts the reward for a new response based on the size of its assigned error cluster: $\tilde{r}(x, \tilde{y}) = r(x, \tilde{y}) - \min(\alpha \log(|C_k| + 1), \beta)$.
  • Empirical Results: MEDS consistently outperforms baselines (GRPO, DAPO) across five mathematical reasoning benchmarks and three base models, achieving gains of up to 4.13 pass@1 points and 4.37 pass@128 points. It also demonstrably increases behavioral diversity during sampling.

Introduction and Theoretical Foundation

Reinforcement learning has become a key method for improving LLM performance. However, a persistent failure mode is error collapse: as training progresses, the policy often converges to a narrow set of behaviors, repeatedly generating similar erroneous reasoning trajectories. This reduces exploration and entrenches the model in self-reinforcing failure modes.

While entropy regularization encourages randomness under the current policy distribution, it does not explicitly discourage recurrent error patterns that appear across different rollouts and training steps. The vast action space of LLMs makes distribution-level stochasticity insufficient to escape these stable error basins.

Inspired by dynamic reward mechanisms in human learning (where stronger penalties are imposed on recurring mistakes), this paper proposes MEDS. The core idea is to dynamically record historical error patterns and impose incremental penalties on repetitive failure paths. This provides more informative supervision than static rewards, helping the model escape local optima.

Theoretical Analysis: The paper provides a proof (Theorem 1) that penalizing repeated errors improves the expected return. Given two reward signals:

  • $\mu_1 = r(x, y)$ (standard reward)
  • $\mu_2 = r(x, y) - \lambda c(y)$ (reward with penalty for repetition)

Let $q_1$ and $q_2$ be the corresponding updated policies, and let $J(q) = \mathbb{E}_{x, y \sim q}[r(x, y)]$ be the expected return. The theorem states:

$$J(q_2) \geq J(q_1)$$

where $\lambda$ is the penalty coefficient and $c(y)$ is an indicator function that increases monotonically with the number of past occurrences of error $y$. The proof leverages the Gibbs form of the KL-regularized policy update and Chebyshev's rearrangement inequality to show that the covariance between the reward $r$ and the penalty weight $w(y) = \exp(-\eta \lambda c(y))$ is non-negative, leading to the inequality.
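As a numerical illustration of the inequality (not taken from the paper), the sketch below builds the two Gibbs-form policies $q \propto p_{\text{ref}} \exp(\eta \mu)$ on a toy problem with four candidate responses and compares their expected returns. The response set, the reference probabilities, and the values of `eta`, `lam`, and `c` are all hypothetical, and the penalty is assumed to be nonzero only for incorrect responses.

```python
import math

def gibbs_policy(p_ref, signal, eta):
    """Closed-form KL-regularized update: q(y) proportional to p_ref(y) * exp(eta * signal(y))."""
    weights = [p * math.exp(eta * s) for p, s in zip(p_ref, signal)]
    z = sum(weights)
    return [w / z for w in weights]

def expected_return(q, r):
    """J(q) = sum over y of q(y) * r(y), always measured with the ORIGINAL reward r."""
    return sum(qi * ri for qi, ri in zip(q, r))

# Hypothetical toy setup: two wrong responses (reward 0), two correct ones (reward 1).
p_ref = [0.4, 0.3, 0.2, 0.1]
r = [0.0, 0.0, 1.0, 1.0]
# Assumed repetition penalty: only the wrong responses have been seen before.
c = [3.0, 1.0, 0.0, 0.0]
eta, lam = 1.0, 0.5

mu1 = r
mu2 = [ri - lam * ci for ri, ci in zip(r, c)]

q1 = gibbs_policy(p_ref, mu1, eta)
q2 = gibbs_policy(p_ref, mu2, eta)

j1 = expected_return(q1, r)
j2 = expected_return(q2, r)
print(f"J(q1) = {j1:.3f}, J(q2) = {j2:.3f}")
```

In this toy case the penalized policy moves probability mass away from the repeated wrong responses, so its expected return under the original reward is higher. The sketch only checks one instance and is not a substitute for the proof.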

Methodology

MEDS operates in three stages, as illustrated in Figure 2 of the paper.

1. Preliminaries: The standard RL objective for an LLM policy $p_\theta$ is:

$$\max_\theta \mathbb{E}_{x \sim D}\left[ \mathbb{E}_{y \sim p_\theta(\cdot | x)}[r(x, y)] - \frac{1}{\eta} \text{KL}\left(p_\theta(\cdot | x) \,\|\, p_{\text{ref}}(\cdot | x)\right) \right]$$

where $p_{\text{ref}}$ is a reference policy and $\eta$ controls the regularization strength.

2. Logic Feature Extraction: To implement the indicator function $c(y)$ efficiently, MEDS reuses the model's internal layer-wise logits. For a response $y$, let $y^* = y_{L-t}$ be the first token of the final answer. The logit for this token at layer $n$ is $l^{(n)}_{y_{L-t}} \in \mathbb{R}^V$. The scalar value for the chosen token is denoted $l^{*(n)} \in \mathbb{R}$.

Since earlier layers model simpler semantics, the feature vector is constructed from the latter half of the Transformer layers:

$$f(y) = \text{concat}\left( l^{*(n)} \ \middle| \ n = \frac{N}{2}, \dots, N \right) \in \mathbb{R}^{\frac{N}{2}}$$

This vector serves as a compact proxy for the model's reasoning trajectory.
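The feature construction can be sketched as follows. This is a minimal illustration, not the authors' implementation: `layer_logits` is a hypothetical nested list indexed as `[layer][vocab_id]` that holds the per-layer logits at the position of the first answer token, however they were obtained, and `token_id` is the id of the chosen token. Layers are 0-indexed here, so the latter half is taken as indices N/2 to N-1, which yields the N/2 entries stated above.

```python
def layerwise_feature(layer_logits, token_id):
    """Collect the chosen token's logit from each layer in the latter half.

    layer_logits: list of N per-layer logit vectors (each of length V),
                  taken at the first answer-token position (assumed input).
    token_id:     vocabulary index of the chosen answer token.
    Returns a list of N // 2 scalars, one per layer in the latter half.
    """
    n_layers = len(layer_logits)
    start = n_layers // 2  # skip the earlier layers
    return [layer_logits[n][token_id] for n in range(start, n_layers)]

# Toy example: 4 layers, vocabulary of size 3, chosen token id 1.
toy_logits = [
    [0.1, 0.2, 0.3],
    [0.0, 0.5, 0.1],
    [0.2, 1.5, 0.4],
    [0.1, 2.5, 0.3],
]
feature = layerwise_feature(toy_logits, token_id=1)
print(feature)  # [1.5, 2.5]
```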

3. Cluster-based Reward Shaping: For each prompt $x$, a memory $G_x$ is maintained containing feature vectors of all historical responses. HDBSCAN clustering is applied to $G_x$ to obtain clusters $\{C_1, C_2, \dots, C_K, C_{\text{noise}}\}$.

For a new response $\tilde{y}$ assigned to cluster $C_k$, the penalty is defined as a monotonic function of the cluster size:

$$c(\tilde{y}) = \log(|C_k| + 1)$$

The shaped reward is then:

$$\tilde{r}(x, \tilde{y}) = r(x, \tilde{y}) - \min\left(\alpha \log(|C_k| + 1), \beta\right), \quad \text{s.t.} \ f(\tilde{y}) \in C_k$$

where $\alpha$ and $\beta$ are hyperparameters controlling penalty strength and its upper bound. This directly implements the theoretical penalty on repeated errors.
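The shaped reward above can be sketched in a few lines. The example assumes cluster labels have already been produced for the memory of one prompt (HDBSCAN implementations conventionally mark noise points with the label -1) and only shows the size lookup and the capped logarithmic penalty. The label list, the values of `alpha` and `beta`, and the choice to give noise points no penalty are assumptions made for illustration; the source does not specify how the noise cluster is treated.

```python
import math
from collections import Counter

NOISE_LABEL = -1  # conventional HDBSCAN label for points outside every cluster

def cluster_sizes(labels):
    """Count members per cluster, ignoring noise points."""
    return Counter(label for label in labels if label != NOISE_LABEL)

def shaped_reward(reward, cluster_size, alpha, beta):
    """r~ = r - min(alpha * log(|C_k| + 1), beta)."""
    penalty = min(alpha * math.log(cluster_size + 1), beta)
    return reward - penalty

# Hypothetical cluster labels for the historical responses to one prompt.
history_labels = [0, 0, 0, 1, -1, 0, 1]
sizes = cluster_sizes(history_labels)

# A new incorrect response (base reward 0) falls into cluster 0.
new_label = 0
size = sizes.get(new_label, 0)  # assumption: noise or unseen clusters count as size 0
shaped = shaped_reward(0.0, size, alpha=0.1, beta=0.5)
print(f"cluster size = {size}, shaped reward = {shaped:.4f}")
```

Because the penalty grows with the logarithm of the cluster size and is capped at `beta`, a heavily repeated pattern is penalized more than a rare one, but the penalty cannot grow without bound.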

Empirical Validation / Results

Experimental Setup:

  • Models: Qwen3-1.7B, Qwen2.5-Math-7B, Qwen3-8B.
  • Baselines: Base model, GRPO, DAPO, GRPO with Entropy Advantage.
  • Benchmarks: Five mathematical reasoning datasets (AIME24, AMC23, MATH500, Minerva, OlympiadBench).
  • Evaluation: pass@1 and pass@128.
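
The summary does not say how pass@k is estimated. One widely used choice is the unbiased estimator $1 - \binom{n-c}{k} / \binom{n}{k}$, computed per problem from $n$ sampled responses of which $c$ are correct. The sketch below assumes that estimator; the function name and the example counts are hypothetical.

```python
from math import comb

def pass_at_k(n, c, k):
    """Unbiased pass@k estimate for one problem.

    n: number of sampled responses
    c: number of correct responses among them
    k: sampling budget being evaluated (1 <= k <= n)
    """
    if not 0 <= c <= n or not 1 <= k <= n:
        raise ValueError("require 0 <= c <= n and 1 <= k <= n")
    if n - c < k:
        # Fewer than k incorrect samples: every size-k subset has a correct one.
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

# Example: 4 samples, 1 correct.
print(pass_at_k(4, 1, 1))  # 0.25
print(pass_at_k(4, 1, 4))  # 1.0
```

A benchmark score is then the mean of this quantity over all problems.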

Main Results

Table 1: Main results on five mathematical reasoning benchmarks. MEDS consistently achieves the best overall performance.

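(Only the AIME24, AMC23, and MATH500 columns are reproduced below; the Minerva, OlympiadBench, and Average values are not included.)

| Method | AIME24 pass@1 | AIME24 pass@128 | AMC23 pass@1 | AMC23 pass@128 | MATH500 pass@1 | MATH500 pass@128 |
|---|---|---|---|---|---|---|
| Qwen3-1.7B | 22.58 | 66.67 | 58.57 | 87.50 | 80.86 | 95.20 |
| + GRPO | 18.20 | 46.67 | 51.58 | 92.50 | 79.58 | 95.60 |
| + DAPO | 13.46 | 63.33 | 57.65 | 95.00 | 80.25 | 96.60 |
| + GRPO w/ Entropy Adv | 12.50 | 50.00 | 51.80 | 95.00 | 78.72 | 94.00 |
| + MEDS | 23.65 | 70.00 | 65.51 | 95.00 | 84.66 | 97.40 |
| Qwen2.5-Math-7B | 6.43 | 56.67 | 19.82 | 92.50 | 30.84 | 96.00 |
| + GRPO | 23.85 | 66.67 | 70.08 | 97.50 | 84.10 | 96.20 |
| + DAPO | 32.27 | 70.00 | 72.71 | 97.50 | 85.61 | 96.00 |
| + GRPO w/ Entropy Adv | 29.74 | 63.33 | 70.64 | 97.50 | 84.50 | 96.20 |
| + MEDS | 34.32 | 76.67 | 74.38 | 97.50 | 86.33 | 96.00 |
| Qwen3-8B | 34.51 | 70.00 | 65.72 | 90.00 | 83.62 | 95.40 |
| + GRPO | 22.45 | 70.00 | 58.79 | 90.00 | 82.38 | 95.40 |
| + DAPO | 45.42 | 73.33 | 81.37 | 95.00 | 89.18 | 96.60 |
| + GRPO w/ Entropy Adv | 27.32 | 63.33 | 65.49 | 92.50 | 85.89 | 96.00 |
| + MEDS | 45.78 | 76.67 | 82.62 | 97.50 | 92.51 | 98.20 |

  • Key Findings: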
    • MEDS achieves the best average performance across all three model scales.
    • The improvement is most pronounced on the larger Qwen3-8B model, suggesting benefits scale with model capability.
    • On Qwen3-8B, MEDS improves pass@128 from 70.81 to 82.67 on OlympiadBench, a 17% relative gain.
    • Pass@k curves (Figure 4) show MEDS consistently matches or outperforms baselines across all k.

Impact on Exploration Behavior:

  • Diversity Metrics: Using LLM-based annotation, MEDS shows higher Within-Step Diversity (diversity of rollouts at the same step) and Across-Step Diversity (novelty of later rollouts compared to earlier ones) than DAPO.
  • Representation Diversity: The Top-1 Eigen Ratio of the logit covariance matrix is lower for MEDS, indicating a more uniform spread of representations and less concentration on a single error direction.

    Top-1 Eigen Ratio $= \frac{\lambda_1}{\sum_{j=1}^d \lambda_j}$, where $\lambda_j$ are the eigenvalues of the logit covariance matrix and $\lambda_1$ is the largest.
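
A dependency-free sketch of this ratio is shown below. It relies on two standard facts: the sum of the eigenvalues of a covariance matrix equals its trace, and the largest eigenvalue can be approximated by power iteration. The function names, the iteration count, and the toy inputs are hypothetical, and power iteration is used only to keep the example self-contained (a library eigensolver would normally be preferred).

```python
import math

def covariance_matrix(rows):
    """Sample covariance (d x d) of n feature vectors of dimension d."""
    n, d = len(rows), len(rows[0])
    means = [sum(row[j] for row in rows) / n for j in range(d)]
    centered = [[row[j] - means[j] for j in range(d)] for row in rows]
    return [[sum(c[i] * c[j] for c in centered) / (n - 1) for j in range(d)]
            for i in range(d)]

def matvec(m, v):
    """Matrix-vector product for nested lists."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]

def top1_eigen_ratio(rows, iters=500):
    """Largest covariance eigenvalue divided by the sum of all eigenvalues."""
    cov = covariance_matrix(rows)
    trace = sum(cov[i][i] for i in range(len(cov)))  # equals the eigenvalue sum
    if trace <= 0.0:
        raise ValueError("features have zero variance")
    # Fixed, non-symmetric start vector (an arbitrary choice for this sketch).
    v = [1.0 + 0.1 * i for i in range(len(cov))]
    norm = math.sqrt(sum(x * x for x in v))
    v = [x / norm for x in v]
    for _ in range(iters):
        w = matvec(cov, v)
        norm = math.sqrt(sum(x * x for x in w))
        if norm == 0.0:
            break
        v = [x / norm for x in w]
    lam1 = sum(a * b for a, b in zip(v, matvec(cov, v)))  # Rayleigh quotient
    return lam1 / trace

# Features on a single line: all variance lies in one direction (ratio near 1).
collinear = top1_eigen_ratio([[0, 0], [1, 1], [2, 2], [3, 3]])
# Features spread equally along both axes (ratio near 0.5 in two dimensions).
isotropic = top1_eigen_ratio([[1, 0], [-1, 0], [0, 1], [0, -1]])
print(collinear, isotropic)
```

A ratio close to 1 means the representations are concentrated along a single direction, while a lower ratio means the variance is spread over more directions.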

Logits Reflect Reasoning Patterns:

  • Qualitative Case Study (Figure 6): Responses with similar underlying reasoning patterns (even if producing different final answers) exhibit correlated layer-wise logit trajectories, especially in later layers.
  • Quantitative Validation: Logit-based clustering shows a 61.2% correlation with error categories annotated by Claude-Haiku-4.5, confirming logits serve as an acceptable proxy for reasoning structure.

Clustering Quality Matters:

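  • Table 2: Comparison of different clustering feature constructions on Qwen2.5-Math-7B. Using the last 14 layers (MEDS-14) yields the best performance.

(Only the AIME24, AMC23, and MATH500 columns are reproduced below; the Minerva, OlympiadBench, and Average values are not included.)

| Method | AIME24 pass@1 | AIME24 pass@128 | AMC23 pass@1 | AMC23 pass@128 | MATH500 pass@1 | MATH500 pass@128 |
|---|---|---|---|---|---|---|
| DAPO | 32.27 | 70.00 | 72.71 | 97.50 | 85.61 | 96.00 |
| MEDS-single cluster | 29.70 | 70.00 | 73.60 | 100.00 | 86.60 | 96.80 |
| MEDS-28-diff | 31.88 | 70.00 | 72.34 | 97.50 | 85.80 | 96.00 |
| MEDS-14-diff | 32.40 | 73.33 | 71.64 | 100.00 | 86.40 | 96.60 |
| MEDS-28 | 35.63 | 70.00 | 73.13 | 97.50 | 86.20 | 96.40 |
| MEDS-14 | 34.32 | 76.67 | 74.38 | 97.50 | 86.33 | 96.00 |

  • Key Findings: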
    • Any variant using clustering information outperforms the DAPO baseline.
    • The degenerate single-cluster baseline performs worse, highlighting the importance of meaningful clustering.
    • The ranking of downstream performance closely mirrors the ranking of clustering consistency with LLM annotations, showing that better clustering quality leads to larger performance gains.

Theoretical and Practical Implications

Theoretical Implications:

  • Provides a formal justification for incorporating historical error information into reward shaping, proving that the penalty does not decrease the expected return.
  • Establishes layer-wise logits as a valid, efficient proxy for capturing the model's internal reasoning state, connecting representation-level analysis with RL optimization.

Practical Implications:

  • Efficient Exploration: MEDS offers a low-overhead method (reusing existing logits) to significantly improve exploration diversity and mitigate error collapse in LLM RL.
  • Reward Design: Demonstrates the value of dynamic, memory-aware reward functions over static ones, opening a new direction for RL reward engineering.
  • Model Analysis: The success of logit-based clustering validates the use of internal model representations for understanding and steering model behavior beyond surface-level text.

Conclusion

MEDS addresses the critical problem of error collapse in LLM reinforcement learning by introducing a memory-enhanced dynamic reward shaping framework. Key takeaways:

  1. Core Innovation: MEDS is the first method to explicitly incorporate historical error patterns into reward modeling, penalizing repetitive failures to encourage broader exploration.
  2. Efficient Implementation: By leveraging layer-wise logits as lightweight reasoning representations, it achieves significant gains with minimal computational overhead.
  3. Consistent Improvements: Empirical results across multiple models and benchmarks show MEDS consistently improves both pass@1 and pass@128 performance and increases behavioral diversity.
  4. Validated Mechanism: Theoretical analysis and empirical correlations confirm that logit-based clustering effectively captures reasoning patterns and that better clustering leads to better downstream performance.

Limitations and Future Work: The current logit aggregation strategies are relatively simple. Future work could explore more sophisticated feature extraction methods and apply the MEDS framework to other domains beyond mathematical reasoning.

Code: Available at https://github.com/Linxi000/MEDS
