# The Past Is Not Past: Memory-Enhanced Dynamic Reward Shaping

> MEDS uses memory-enhanced dynamic reward shaping to penalize recurrent error patterns, improving LLM reasoning accuracy by up to 4.13 pass@1 points on math benchmarks and increasing sampling diversity.

- **Source:** [arXiv](https://arxiv.org/abs/2604.11297)
- **Published:** 2026-04-15
- **Permalink:** https://picx.dev/p/GIpUP0
- **Whiteboard:** https://picx.dev/p/GIpUP0/image

## Summary

# MEDS: Memory-Enhanced Dynamic Reward Shaping for LLM Reinforcement Learning

## Summary (Overview)
*   **Core Problem**: On-policy reinforcement learning (RL) for large language models (LLMs) often suffers from "error collapse," where the policy degenerates into generating highly repetitive erroneous reasoning patterns.
*   **Proposed Solution**: Introduces **MEDS (Memory-Enhanced Dynamic reward Shaping)**, a framework that penalizes repeated historical error patterns to encourage broader exploration and reduce recurrent mistakes.
*   **Key Mechanism**: Uses **layer-wise logits** as lightweight, low-overhead representations of the model's internal reasoning process. These are clustered (using HDBSCAN) to identify and track recurring error patterns across training epochs.
*   **Reward Shaping**: Dynamically adjusts the reward for a new response based on the size of its assigned error cluster: $\tilde{r}(x, \tilde{y}) = r(x, \tilde{y}) - \min(\alpha \log(|C_k| + 1), \beta)$.
*   **Empirical Results**: MEDS consistently outperforms baselines (GRPO, DAPO) across five mathematical reasoning benchmarks and three base models, achieving gains of up to **4.13 pass@1 points** and **4.37 pass@128 points**. It also demonstrably increases behavioral diversity during sampling.

## Introduction and Theoretical Foundation

Reinforcement learning has become a key method for improving LLM performance. However, a persistent failure mode is **error collapse**: as training progresses, the policy often converges to a narrow set of behaviors, repeatedly generating similar erroneous reasoning trajectories. This reduces exploration and entrenches the model in self-reinforcing failure modes.

While entropy regularization encourages randomness under the *current* policy distribution, it does not explicitly discourage *recurrent* error patterns that appear across different rollouts and training steps. The vast action space of LLMs makes distribution-level stochasticity insufficient to escape these stable error basins.

Inspired by dynamic reward mechanisms in human learning (where stronger penalties are imposed on recurring mistakes), this paper proposes **MEDS**. The core idea is to **dynamically record historical error patterns and impose incremental penalties on repetitive failure paths**. This provides more informative supervision than static rewards, helping the model escape local optima.

**Theoretical Analysis**: The paper proves (Theorem 1) that penalizing repeated errors does not decrease the expected return. Consider two reward signals:
*   $\mu_1 = r(x, y)$ (standard reward)
*   $\mu_2 = r(x, y) - \lambda c(y)$ (reward with penalty for repetition)

Let $q_1$ and $q_2$ be the corresponding updated policies, and $J(q) = \mathbb{E}_{x,y \sim q}[r(x, y)]$ be the expected return. The theorem states:

> $J(q_2) \geq J(q_1)$

where $\lambda$ is the penalty coefficient and $c(y)$ is a non-negative repetition score (referred to as an indicator function, although it is not binary) that increases monotonically with the number of past occurrences of error $y$. The proof uses the Gibbs form of the KL-regularized policy update and Chebyshev's rearrangement inequality to show that the covariance between the reward $r$ and the penalty weight $w(y) = \exp(-\eta\lambda c(y))$ is non-negative, which yields the inequality.
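
The covariance step can be made concrete. Assuming the updated policy has the Gibbs form $q(y) \propto p_{\text{ref}}(y)\exp(\eta\,\mu(y))$ for a reference policy $p_{\text{ref}}$, the two policies satisfy $q_2(y) \propto q_1(y)\,w(y)$, so $J(q_2) = \mathbb{E}_{q_1}[r\,w] / \mathbb{E}_{q_1}[w]$, which is at least $J(q_1)$ exactly when $\text{Cov}_{q_1}(r, w) \geq 0$. The sketch below checks this numerically on a toy single-prompt problem. The four candidate responses, their occurrence counts, and the values of `eta` and `lam` are illustrative assumptions, not settings from the paper.

```python
import math

def gibbs_policy(p_ref, scores, eta):
    """Closed-form KL-regularized update: q(y) proportional to p_ref(y) * exp(eta * score(y))."""
    weights = [p * math.exp(eta * s) for p, s in zip(p_ref, scores)]
    z = sum(weights)
    return [w / z for w in weights]

def expected_return(q, rewards):
    """J(q) = E_{y ~ q}[r(y)] for a single prompt."""
    return sum(p * r for p, r in zip(q, rewards))

# Toy setup (all numbers are assumptions made for illustration).
p_ref = [0.25, 0.25, 0.25, 0.25]   # reference policy over 4 candidate responses
rewards = [1.0, 0.0, 0.0, 0.0]     # only response 0 is correct
counts = [0, 5, 1, 0]              # past occurrences of each response pattern
eta, lam = 2.0, 0.5

# c(y) grows with the number of past occurrences and applies to errors only.
c = [math.log(n + 1) if r == 0.0 else 0.0 for n, r in zip(counts, rewards)]

mu1 = rewards                                      # standard reward
mu2 = [r - lam * ci for r, ci in zip(rewards, c)]  # reward with repetition penalty

q1 = gibbs_policy(p_ref, mu1, eta)
q2 = gibbs_policy(p_ref, mu2, eta)
j1 = expected_return(q1, rewards)
j2 = expected_return(q2, rewards)
print(f"J(q1) = {j1:.4f}, J(q2) = {j2:.4f}")
```

Because the penalty only removes probability mass from incorrect responses, the penalized policy places more mass on the correct one, which is the intuition behind $J(q_2) \geq J(q_1)$.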

## Methodology

MEDS operates in three stages, as illustrated in Figure 2 of the paper.

**1. Preliminaries**: The standard RL objective for an LLM policy $p_\theta$ is:
$$
\max_\theta \mathbb{E}_{x \sim D}\left[ \mathbb{E}_{y \sim p_\theta(\cdot|x)}[r(x, y)] - \frac{1}{\eta} \text{KL}(p_\theta(\cdot | x) \| p_{\text{ref}}(\cdot | x)) \right]
$$
where $p_{\text{ref}}$ is a reference policy and $\eta$ controls the regularization strength.

**2. Logic Feature Extraction**: To compute the repetition term $c(y)$ efficiently, MEDS reuses the model's internal **layer-wise logits**. For a response $y$ of length $L$, let $y^* = y_{L-t}$ be the first token of the final answer. The layer-$n$ logit vector at this position is $l^{(n)}_{y_{L-t}} \in \mathbb{R}^V$, where $V$ is the vocabulary size, and its scalar entry for the chosen token $y^*$ is denoted $l^{*(n)} \in \mathbb{R}$.

Since earlier layers model simpler semantics, the feature vector is constructed from the latter half of the Transformer layers:
$$
f(y) = \text{concat}\left( l^{*(n)} \ \middle| \ n = \frac{N}{2} + 1, \dots, N \right) \in \mathbb{R}^{\frac{N}{2}}
$$
This vector serves as a compact proxy for the model's reasoning trajectory.
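
A minimal sketch of this feature construction follows. It assumes a logit-lens style readout, in which each layer's hidden state at the answer position is projected through the output embedding to obtain that layer's logits; whether the paper also applies a final normalization before projecting is not stated here. The function name, argument names, and toy shapes are hypothetical.

```python
import numpy as np

def layerwise_logit_features(hidden_states, unembed, token_id):
    """
    hidden_states: array [N, d], hidden state at the answer position after each of N layers.
    unembed:       array [d, V], output projection (assumed shared across layers).
    token_id:      vocabulary index of the chosen answer token y*.
    Returns f(y), a vector with one scalar logit per layer for the last N // 2 layers.
    """
    num_layers = hidden_states.shape[0]
    later_layers = hidden_states[num_layers - num_layers // 2:]
    # Projecting onto a single unembedding column gives only the chosen token's logit.
    return later_layers @ unembed[:, token_id]

# Toy shapes: N = 4 layers, hidden size d = 3, vocabulary V = 5.
rng = np.random.default_rng(0)
hidden = rng.normal(size=(4, 3))
unembed = rng.normal(size=(3, 5))
features = layerwise_logit_features(hidden, unembed, token_id=2)
print(features.shape)
```

Selecting one column of the projection avoids materializing the full $V$-dimensional logit vector per layer, which keeps the overhead small.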

**3. Cluster-based Reward Shaping**: For each prompt $x$, a memory $G_x$ is maintained containing feature vectors of all historical responses. HDBSCAN clustering is applied to $G_x$ to obtain clusters $\\{C_1, C_2, \dots, C_K, C_{\text{noise}}\\}$.
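
The per-prompt memory can be sketched as a small container that stores features and delegates clustering to a pluggable function. The class name `ErrorMemory`, the `cluster_fn` interface, and the decisions to re-cluster on every new response, to count only past members, and to treat noise as a non-recurring pattern are all assumptions made for illustration. HDBSCAN implementations such as scikit-learn's label noise points `-1`, which is the convention used here; `toy_cluster_fn` is only a stand-in so the example runs without extra dependencies.

```python
from collections import Counter, defaultdict

class ErrorMemory:
    """Per-prompt store G_x of feature vectors from past responses."""

    def __init__(self, cluster_fn):
        # cluster_fn maps a list of feature vectors to integer labels, -1 meaning noise.
        self.cluster_fn = cluster_fn
        self.features = defaultdict(list)

    def cluster_size_for(self, prompt_id, feature):
        """Cluster history plus the new feature; return how many past responses share its cluster."""
        history = self.features[prompt_id]
        labels = self.cluster_fn(history + [feature])
        new_label = labels[-1]
        history.append(feature)
        if new_label == -1:
            return 0  # assumption: noise is not a recurring pattern
        return Counter(labels[:-1])[new_label]

def toy_cluster_fn(points):
    # Stand-in for HDBSCAN: bucket 1-D features by rounding; singletons become noise.
    buckets = [round(p[0]) for p in points]
    freq = Counter(buckets)
    return [b if freq[b] > 1 else -1 for b in buckets]

memory = ErrorMemory(toy_cluster_fn)
sizes = [memory.cluster_size_for("prompt-1", [v]) for v in (0.1, 0.2, 5.0, -0.1)]
print(sizes)
```

Responses near `0` accumulate into one growing cluster, while the outlier at `5.0` stays noise and would receive no penalty under these assumptions.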

For a new response $\tilde{y}$ assigned to cluster $C_k$, the penalty is defined as a monotonic function of the cluster size:
$$
c(\tilde{y}) = \log(|C_k| + 1)
$$
The shaped reward is then:
$$
\tilde{r}(x, \tilde{y}) = r(x, \tilde{y}) - \min\left(\alpha \log(|C_k| + 1), \beta\right), \quad \text{s.t.} \ f(\tilde{y}) \in C_k
$$
where $\alpha$ and $\beta$ are hyperparameters controlling penalty strength and its upper bound. This directly implements the theoretical penalty on repeated errors.
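
The shaped reward is a one-line computation once the cluster size is known. In the sketch below, the function name and the values `alpha=0.1` and `beta=0.5` are placeholders chosen for illustration, not the paper's settings.

```python
import math

def shaped_reward(reward, cluster_size, alpha, beta):
    """r~(x, y~) = r(x, y~) - min(alpha * log(|C_k| + 1), beta)."""
    penalty = min(alpha * math.log(cluster_size + 1), beta)
    return reward - penalty

# The penalty grows logarithmically with cluster size and saturates at beta.
for size in (0, 1, 10, 1000):
    print(size, round(shaped_reward(0.0, size, alpha=0.1, beta=0.5), 4))
```

The logarithm keeps the penalty from growing linearly as a cluster fills up over many epochs, and the cap $\beta$ bounds how far the shaped reward can drift from the task reward.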

## Empirical Validation / Results

**Experimental Setup**:
*   **Models**: Qwen3-1.7B, Qwen2.5-Math-7B, Qwen3-8B.
*   **Baselines**: Base model, GRPO, DAPO, GRPO with Entropy Advantage.
*   **Benchmarks**: Five mathematical reasoning datasets (AIME24, AMC23, MATH500, Minerva, OlympiadBench).
*   **Evaluation**: pass@1 and pass@128.
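
The summary does not say how pass@k is estimated. A common choice, shown here as an assumption rather than the paper's confirmed procedure, is the unbiased estimator $1 - \binom{n-c}{k} / \binom{n}{k}$ computed from $n$ samples of which $c$ are correct.

```python
from math import comb

def pass_at_k(n, c, k):
    """Unbiased pass@k estimate from n samples with c correct ones (requires k <= n)."""
    if n - c < k:
        # Fewer than k incorrect samples: every size-k subset contains a correct one.
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

# With n = 128 samples, pass@1 is the fraction correct and pass@128 is 1.0 if any sample is correct.
print(pass_at_k(128, 32, 1), pass_at_k(128, 32, 128))
```

Under this estimator, pass@1 reflects average accuracy while pass@128 reflects coverage, which is why the two metrics can move differently when exploration changes.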

### Main Results

**Table 1: Main results on five mathematical reasoning benchmarks.** MEDS consistently achieves the best overall performance.

| Method | AIME24 pass@1 | AIME24 pass@128 | AMC23 pass@1 | AMC23 pass@128 | MATH500 pass@1 | MATH500 pass@128 | Minerva pass@1 | Minerva pass@128 | OlympiadBench pass@1 | OlympiadBench pass@128 | **Average** pass@1 | **Average** pass@128 |
| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |
| **Qwen3-1.7B** | 22.58 | 66.67 | 58.57 | 87.50 | 80.86 | 95.20 | 37.82 | 63.60 | 41.97 | 64.15 | 48.36 | 75.42 |
| +GRPO | 18.20 | 46.67 | 51.58 | 92.50 | 79.58 | 95.60 | 38.62 | 65.81 | 40.19 | 70.37 | 45.63 | 74.19 |
| +DAPO | 13.46 | 63.33 | 57.65 | 95.00 | 80.25 | 96.60 | 39.45 | 69.49 | 42.20 | 77.04 | 46.60 | 80.29 |
| +GRPO w/ Entropy Adv | 12.50 | 50.00 | 51.80 | 95.00 | 78.72 | 94.00 | 38.91 | 62.87 | 38.47 | 66.96 | 44.08 | 73.77 |
| **+ MEDS** | **23.65** | **70.00** | **65.51** | **95.00** | **84.66** | **97.40** | **41.77** | **68.75** | **46.86** | **77.63** | **52.49** | **81.76** |
| **Qwen2.5-Math-7B** | 6.43 | 56.67 | 19.82 | 92.50 | 30.84 | 96.00 | 6.62 | 63.24 | 9.27 | 70.07 | 14.60 | 75.70 |
| +GRPO | 23.85 | 66.67 | 70.08 | 97.50 | 84.10 | 96.20 | 39.07 | 73.16 | 42.57 | 73.63 | 51.93 | 81.43 |
| +DAPO | 32.27 | 70.00 | 72.71 | 97.50 | 85.61 | 96.00 | 43.01 | 75.74 | 44.43 | 75.00 | 55.61 | 82.85 |
| +GRPO w/ Entropy Adv | 29.74 | 63.33 | 70.64 | 97.50 | 84.50 | 96.20 | 40.68 | 70.96 | 42.57 | 73.63 | 53.63 | 80.32 |
| **+ MEDS** | **34.32** | **76.67** | **74.38** | **97.50** | **86.33** | **96.00** | **42.51** | **74.26** | **44.80** | **75.56** | **56.47** | **84.00** |
| **Qwen3-8B** | 34.51 | 70.00 | 65.72 | 90.00 | 83.62 | 95.40 | 40.90 | 61.03 | 44.69 | 61.63 | 53.89 | 75.61 |
| +GRPO | 22.45 | 70.00 | 58.79 | 90.00 | 82.38 | 95.40 | 44.07 | 64.71 | 44.22 | 69.78 | 50.38 | 77.98 |
| +DAPO | 45.42 | 73.33 | 81.37 | 95.00 | 89.18 | 96.60 | 46.82 | 65.44 | 52.77 | 70.81 | 63.11 | 80.24 |
| +GRPO w/ Entropy Adv | 27.32 | 63.33 | 65.49 | 92.50 | 85.89 | 96.00 | 45.32 | 65.44 | 49.92 | 70.67 | 54.79 | 77.59 |
| **+ MEDS** | **45.78** | **76.67** | **82.62** | **97.50** | **92.51** | **98.20** | **51.58** | **68.02** | **61.12** | **82.67** | **66.72** | **84.61** |

*   **Key Findings**:
    *   MEDS achieves the best average performance across all three model scales.
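    *   The average pass@128 gain over the strongest baseline is largest on Qwen3-8B (+4.37 over DAPO), suggesting benefits may scale with model capability; the largest average pass@1 gain (+4.13) occurs on Qwen3-1.7B, where the base model is the strongest baseline.
    *   On Qwen3-8B, MEDS reaches 82.67 pass@128 on OlympiadBench versus 70.81 for DAPO, an 11.86-point difference or roughly a **17% relative gain**.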
    *   Pass@k curves (Figure 4) show MEDS consistently matches or outperforms baselines across all k.

**Impact on Exploration Behavior**:
*   **Diversity Metrics**: Using LLM-based annotation, MEDS shows higher **Within-Step Diversity** (diversity of rollouts at the same step) and **Across-Step Diversity** (novelty of later rollouts compared to earlier ones) than DAPO.
*   **Representation Diversity**: The **Top-1 Eigen Ratio** of the logit covariance matrix is lower for MEDS, indicating a more uniform spread of representations and less concentration on a single error direction.
    > Top-1 Eigen Ratio $= \frac{\lambda_1}{\sum_{j=1}^d \lambda_j}$, where $\lambda_i$ are eigenvalues of the logit covariance matrix.
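
The Top-1 Eigen Ratio can be computed directly from a matrix of logit features. The sketch below is a minimal version under the assumption that rows are sampled responses and columns are feature dimensions; the function name and toy inputs are hypothetical.

```python
import numpy as np

def top1_eigen_ratio(features):
    """features: array [num_samples, d]. Returns lambda_1 / sum_j lambda_j of the covariance matrix."""
    cov = np.atleast_2d(np.cov(features, rowvar=False))
    eigvals = np.linalg.eigvalsh(cov)  # real eigenvalues in ascending order
    return float(eigvals[-1] / eigvals.sum())

# Features spread evenly over two directions versus features collapsed onto one direction.
spread = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
collapsed = np.outer(np.arange(4.0), np.array([1.0, 2.0]))
print(top1_eigen_ratio(spread), top1_eigen_ratio(collapsed))
```

A ratio close to 1 means almost all variance lies along a single direction, while a ratio near $1/d$ indicates an even spread across the $d$ feature dimensions.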

**Logits Reflect Reasoning Patterns**:
*   **Qualitative Case Study** (Figure 6): Responses with similar underlying reasoning patterns (even if producing different final answers) exhibit correlated layer-wise logit trajectories, especially in later layers.
*   **Quantitative Validation**: Logit-based clustering shows a **61.2% correlation** with error categories annotated by Claude-Haiku-4.5, confirming logits serve as an acceptable proxy for reasoning structure.

**Clustering Quality Matters**:
*   **Table 2: Comparison of different clustering feature constructions on Qwen2.5-Math-7B.** Using the last 14 layers (`MEDS-14`) yields the best performance.

| Method | AIME24 pass@1 | AIME24 pass@128 | AMC23 pass@1 | AMC23 pass@128 | MATH500 pass@1 | MATH500 pass@128 | Minerva pass@1 | Minerva pass@128 | OlympiadBench pass@1 | OlympiadBench pass@128 | **Average** pass@1 | **Average** pass@128 |
| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |
| DAPO | 32.27 | 70.00 | 72.71 | 97.50 | 85.61 | 96.00 | 43.01 | 75.74 | 44.43 | 75.00 | 55.61 | 82.85 |
| MEDS-single cluster | 29.70 | 70.00 | 73.60 | 100.00 | 86.60 | 96.80 | 41.20 | 71.69 | 44.60 | 74.52 | 55.14 | 82.60 |
| MEDS-28-diff | 31.88 | 70.00 | 72.34 | 97.50 | 85.80 | 96.00 | 43.38 | 73.53 | 45.48 | 74.37 | 55.78 | 82.28 |
| MEDS-14-diff | 32.40 | 73.33 | 71.64 | 100.00 | 86.40 | 96.60 | 43.01 | 72.43 | 44.15 | 74.52 | 55.52 | 83.38 |
| MEDS-28 | 35.63 | 70.00 | 73.13 | 97.50 | 86.20 | 96.40 | 41.90 | 73.16 | 44.00 | 73.78 | 56.17 | 82.17 |
| **MEDS-14** | **34.32** | **76.67** | **74.38** | **97.50** | **86.33** | **96.00** | **42.51** | **74.26** | **44.80** | **75.56** | **56.47** | **84.00** |

*   **Key Findings**:
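    *   Every variant with real clustering beats the DAPO baseline on at least one averaged metric, but only `MEDS-14` beats it on both pass@1 (56.47 vs. 55.61) and pass@128 (84.00 vs. 82.85).
    *   The degenerate `single cluster` variant falls below DAPO on both averages (55.14 pass@1 and 82.60 pass@128), highlighting the importance of meaningful clustering.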
    *   The ranking of downstream performance closely mirrors the ranking of clustering consistency with LLM annotations, showing that **better clustering quality leads to larger performance gains**.

## Theoretical and Practical Implications

**Theoretical Implications**:
*   Provides a formal justification for incorporating historical error information into reward shaping, proving that the penalty does not decrease the expected return.
*   Establishes layer-wise logits as a valid, efficient proxy for capturing the model's internal reasoning state, connecting representation-level analysis with RL optimization.

**Practical Implications**:
*   **Efficient Exploration**: MEDS offers a low-overhead method (reusing existing logits) to significantly improve exploration diversity and mitigate error collapse in LLM RL.
*   **Reward Design**: Demonstrates the value of dynamic, memory-aware reward functions over static ones, opening a new direction for RL reward engineering.
*   **Model Analysis**: The success of logit-based clustering validates the use of internal model representations for understanding and steering model behavior beyond surface-level text.

## Conclusion

MEDS addresses the critical problem of error collapse in LLM reinforcement learning by introducing a memory-enhanced dynamic reward shaping framework. Key takeaways:

1.  **Core Innovation**: MEDS is presented as the first method to explicitly incorporate historical error patterns into reward modeling, penalizing repetitive failures to encourage broader exploration.
2.  **Efficient Implementation**: By leveraging layer-wise logits as lightweight reasoning representations, it achieves significant gains with minimal computational overhead.
3.  **Consistent Improvements**: Empirical results across multiple models and benchmarks show MEDS consistently improves both pass@1 and pass@128 performance and increases behavioral diversity.
4.  **Validated Mechanism**: Theoretical analysis and empirical correlations confirm that logit-based clustering effectively captures reasoning patterns and that better clustering leads to better downstream performance.

**Limitations and Future Work**: The current logit aggregation strategies are relatively simple. Future work could explore more sophisticated feature extraction methods and apply the MEDS framework to other domains beyond mathematical reasoning.

**Code**: Available at https://github.com/Linxi000/MEDS

