
Learning to Focus: Causal Attention Distillation via Gradient-Guided Token Pruning

Published: 06/09/2025
This analysis is AI-generated and may not be fully accurate. Please refer to the original paper.

TL;DR Summary

LeaF uses gradient-guided token pruning to remove confounding tokens, aligning student attention with causal focus from teachers, improving reasoning accuracy and interpretability across multiple benchmarks.

Abstract

Large language models (LLMs) have demonstrated significant improvements in contextual understanding. However, their ability to attend to truly critical information during long-context reasoning and generation still falls behind the pace. Specifically, our preliminary experiments reveal that certain distracting patterns can misdirect the model's attention during inference, and removing these patterns substantially improves reasoning accuracy and generation quality. We attribute this phenomenon to spurious correlations in the training data, which obstruct the model's capacity to infer authentic causal instruction-response relationships. This phenomenon may induce redundant reasoning processes, potentially resulting in significant inference overhead and, more critically, the generation of erroneous or suboptimal responses. To mitigate this, we introduce a two-stage framework called Learning to Focus (LeaF) leveraging intervention-based inference to disentangle confounding factors. In the first stage, LeaF employs gradient-based comparisons with an advanced teacher to automatically identify confounding tokens based on causal relationships in the training corpus. Then, in the second stage, it prunes these tokens during distillation to enact intervention, aligning the student's attention with the teacher's focus distribution on truly critical context tokens. Experimental results demonstrate that LeaF not only achieves an absolute improvement in various mathematical reasoning, code generation and multi-hop question answering benchmarks but also effectively suppresses attention to confounding tokens during inference, yielding a more interpretable and reliable reasoning model.


1. Bibliographic Information

1.1. Title

Learning to Focus: Causal Attention Distillation via Gradient-Guided Token Pruning

1.2. Authors

Yiju Guo, Wenkai Yang, Zexu Sun, Ning Ding, Zhiyuan Liu, Yankai Lin. The authors are affiliated with Gaoling School of Artificial Intelligence, Renmin University of China; Department of Computer Science and Technology, Tsinghua University; and Baidu Inc. Their research background primarily lies in natural language processing, large language models, and machine learning, with a focus on improving LLM reasoning, efficiency, and interpretability.

1.3. Journal/Conference

The paper is published on arXiv, a preprint server for scientific papers. It is not a peer-reviewed journal or conference publication in its current form, but arXiv is widely used in the AI community to disseminate new research rapidly, and papers posted there are often later submitted to top-tier conferences (e.g., ACL, EMNLP, NeurIPS, ICML) or journals. This version was posted in June 2025.

1.4. Publication Year

2025 (posted to arXiv on 2025-06-09, 15:16:39 UTC)

1.5. Abstract

This paper addresses the challenge of Large Language Models (LLMs) struggling to focus on truly critical information during long-context reasoning and generation, often misdirected by distracting patterns attributed to spurious correlations in training data. This leads to redundant reasoning, inference overhead, and erroneous outputs. To mitigate this, the authors introduce Learning to Focus (LeaF), a two-stage framework. The first stage employs gradient-based comparisons with an advanced teacher model to automatically identify confounding tokens based on causal relationships. In the second stage, these tokens are pruned during distillation to align the student model's attention with the teacher's focus on critical context tokens. Experimental results demonstrate that LeaF achieves significant absolute improvements across various mathematical reasoning, code generation, and multi-hop question answering benchmarks. Furthermore, it effectively suppresses attention to confounding tokens, resulting in a more interpretable and reliable reasoning model.

Official Source: https://arxiv.org/abs/2506.07851
PDF Link: https://arxiv.org/pdf/2506.07851v2.pdf
Publication Status: This is a preprint on arXiv.

2. Executive Summary

2.1. Background & Motivation

The core problem the paper aims to solve is the inability of Large Language Models (LLMs) to consistently focus on truly critical information during complex, long-context reasoning and generation tasks. Despite LLMs' advanced capabilities in contextual understanding and language generation, the authors' preliminary experiments reveal that distracting patterns in the input can misdirect the model's attention. When these distracting patterns are removed, reasoning accuracy and generation quality substantially improve.

This problem is important because spurious correlations within the vast training data of LLMs can obstruct the model's capacity to infer authentic causal instruction-response relationships. This phenomenon leads to several critical issues:

  1. Redundant Reasoning Processes: LLMs might engage in unnecessary computational steps, increasing inference overhead.

  2. Erroneous or Suboptimal Responses: The misdirected attention can cause the model to generate incorrect or less-than-ideal outputs.

  3. Lack of Interpretability: Models that rely on spurious correlations are harder to understand and debug.

    The paper's key idea is to address these issues from a causal perspective: it treats distracting patterns as spurious confounders that interfere with the LLM's reasoning. By identifying and mitigating the influence of these confounders, the paper aims to guide LLMs to focus on truly relevant information, thereby improving their robustness, accuracy, and interpretability.

2.2. Main Contributions / Findings

The paper's primary contributions are:

  1. Introduction of Learning to Focus (LeaF) Framework: A novel two-stage framework that leverages intervention-based inference to disentangle confounding factors in LLM reasoning.

  2. Gradient-Guided Confounding Token Detection: LeaF introduces a method to automatically identify confounding tokens by comparing gradient sensitivities between a high-capacity teacher model and a student model. This approach is rooted in causal relationships identified within the training corpus.

  3. Causal Attention Distillation: LeaF incorporates span pruning of identified confounding tokens during a hybrid distillation process. This aligns the student model's attention with the teacher's focus distribution on critical context tokens, effectively enacting a causal intervention.

  4. Demonstrated Performance Improvements: LeaF achieves significant absolute performance gains across a diverse set of benchmarks:

    • Mathematical Reasoning: Absolute improvements in tasks like GSM8K, MATH-500, and OlympiadBench.
    • Code Generation: Absolute improvements on HumanEval+, LeetCode, and LivecodeBench.
    • Multi-hop Question Answering: Absolute improvements on HotpotQA, 2WikiMultiHopQA, and Musique.
  5. Enhanced Interpretability and Reliability: Experimental results, including attention visualizations and case studies, show that LeaF suppresses attention to confounding tokens, leading to a more interpretable and reliable reasoning model by encouraging focus on causally relevant information.

  6. Response-Level Pruning Strategy: The paper shows that extending pruning from instruction-level to response-level further enhances performance, indicating that distracting patterns exist and influence generation at both stages.

    These findings collectively demonstrate that explicitly mitigating spurious correlations through causal attention distillation is a highly effective approach to improve LLM reasoning capabilities, robustness, and interpretability.

3. Prerequisite Knowledge & Related Work

3.1. Foundational Concepts

To understand this paper, a reader should be familiar with the following fundamental concepts:

  • Large Language Models (LLMs): These are advanced artificial intelligence models trained on vast amounts of text data to understand, generate, and process human language. They typically employ transformer architectures with billions of parameters, enabling them to perform a wide range of tasks like text generation, translation, summarization, and question answering. Examples include GPT-3, LLaMA, and Qwen.
  • Contextual Understanding & Long-Context Reasoning: This refers to an LLM's ability to grasp the meaning of information within a given context (e.g., a long document, a conversation history) and use that understanding to perform reasoning tasks. Long-context reasoning specifically deals with contexts that are very long, posing challenges for LLMs to maintain focus and integrate information from disparate parts of the input.
  • Attention Mechanism: A core component of transformer-based LLMs. It allows the model to weigh the importance of different parts of the input sequence when processing each token. For example, when generating a word, the attention mechanism determines which other words in the input (or previously generated output) are most relevant. The foundational form is self-attention, in which each token attends to the tokens of its own sequence (in decoder-only LLMs, only to itself and the preceding tokens, enforced by a causal mask).
    • The attention mechanism computes a weighted sum of value vectors, where the weight assigned to each value is determined by the similarity between the query vector (representing the current token) and the key vectors (representing other tokens).
    • $ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $
      • $Q$: Query matrix, derived from the input embeddings, representing the current token's focus.
      • $K$: Key matrix, derived from the input embeddings, representing other tokens' potential relevance.
      • $V$: Value matrix, derived from the input embeddings; its rows (the value vectors) contain the information that is weighted and summed to produce the output of the attention layer.
      • $QK^T$: Dot product between the query and key matrices, giving the similarity scores.
      • $\sqrt{d_k}$: Scaling factor, where $d_k$ is the dimension of the key vectors, used to prevent large dot products from pushing the softmax function into regions with tiny gradients.
      • $\mathrm{softmax}$: Applied row-wise, it normalizes each query's similarity scores into a probability distribution, so the weights for each query sum to 1.
  • Spurious Correlations: These are statistical relationships observed in data that appear to be causal but are not. In the context of LLMs, spurious correlations might arise when the model learns to associate certain superficial patterns or confounding tokens in the training data with specific outputs, even though these patterns do not represent a true causal link to the desired response. For example, if a dataset frequently includes a specific boilerplate phrase alongside correct answers, an LLM might learn to rely on the boilerplate rather than the actual reasoning content.
  • Knowledge Distillation (KD): A technique where a smaller, "student" model is trained to mimic the behavior of a larger, more powerful "teacher" model. The goal is to transfer the knowledge from the teacher to the student, often resulting in a smaller, faster, and more efficient student model that retains much of the teacher's performance. This typically involves minimizing the difference between the teacher's and student's output distributions (e.g., using Kullback-Leibler divergence).
    • Kullback-Leibler (KL) Divergence: A non-symmetric measure of the difference between two probability distributions. It quantifies how much one probability distribution diverges from a second, expected probability distribution.
      • $ D_{\mathrm{KL}}(P \parallel Q) = \sum_{i} P(i) \log\left(\frac{P(i)}{Q(i)}\right) $
      • $P$: The true or reference probability distribution (e.g., the teacher's output).
      • $Q$: The approximate or learned probability distribution (e.g., the student's output).
      • $\sum_i$: Sums over all possible outcomes $i$.
      • $P(i)$, $Q(i)$: Probabilities of outcome $i$ under distributions $P$ and $Q$, respectively.
      • A lower KL divergence indicates greater similarity between the two distributions.
  • Causal Inference / Structural Causal Models (SCM) / Directed Acyclic Graphs (DAGs): A framework developed by Judea Pearl for reasoning about cause-and-effect relationships. An SCM consists of a set of variables and structural equations that describe how each variable is determined by its direct causes. A DAG visually represents these causal relationships, with nodes as variables and directed edges as causal influences. In this paper, DAGs are used to model how input tokens ($X$), confounding tokens ($A$), and the model's output ($Y$) interact.
  • Gradients in Deep Learning: In machine learning, gradients represent the rate of change of a function (e.g., a loss function) with respect to its inputs (e.g., model parameters, input embeddings). They are central to optimization algorithms like stochastic gradient descent. In this paper, gradient sensitivity is used to measure how much a change in a specific input token affects the model's output or loss, indicating the token's importance or influence.
  • Counterfactuals: In causal inference, a counterfactual describes what would have happened if a particular event had not occurred, or if an intervention had been made. In this paper, counterfactual samples are created by pruning (removing) identified confounding tokens from an instruction. By comparing the model's behavior on original vs. counterfactual samples, the framework aims to isolate and learn true causal dependencies.
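The attention and KL-divergence formulas defined in this list can be checked numerically with a short NumPy sketch. It is an illustration written for this analysis, not code from the paper: it uses a single attention head with no causal mask (a simplification, since decoder-only LLMs mask future tokens), and a discrete KL divergence in which a small `eps` (our own choice) guards against $\log 0$ when $Q$ assigns zero mass to an outcome.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtracting the max does not change the result but avoids overflow in exp().
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray):
    """softmax(Q K^T / sqrt(d_k)) V for Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v).

    Single head, no causal mask (a simplifying assumption).
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)     # (n_q, n_k) query-key similarity scores
    weights = softmax(scores, axis=-1)  # each row is a distribution over keys
    return weights @ V, weights         # (n_q, d_v) weighted sum of value vectors


def kl_divergence(p, q, eps: float = 1e-12) -> float:
    """D_KL(P || Q) = sum_i P(i) log(P(i) / Q(i)) for discrete distributions."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    support = p > 0  # outcomes with P(i) = 0 contribute 0 by convention
    log_ratio = np.log(p[support]) - np.log(np.maximum(q[support], eps))
    return float(np.sum(p[support] * log_ratio))


rng = np.random.default_rng(0)
Q = rng.normal(size=(2, 4))
K = rng.normal(size=(3, 4))
V = rng.normal(size=(3, 5))
out, w = scaled_dot_product_attention(Q, K, V)
print(out.shape)                              # (2, 5)
print(kl_divergence([0.5, 0.5], [0.5, 0.5]))  # 0.0
# KL is not symmetric: swapping the arguments changes the value.
print(kl_divergence([0.9, 0.1], [0.5, 0.5]) != kl_divergence([0.5, 0.5], [0.9, 0.1]))  # True
```

The asymmetry in the last line is why the argument order matters in distillation objectives: $D_{\mathrm{KL}}(p_T \parallel p_S)$ weights the log-ratio by the teacher's probabilities.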

3.2. Previous Works

The paper positions itself within the context of several lines of research:

  • Chain-of-Thought (CoT) Knowledge Distillation: This paradigm aims to transfer complex reasoning abilities from large teacher models to smaller student models, often by distilling the step-by-step reasoning (Chain-of-Thought) rather than just the final answer.

    • CoTKD [41]: Pioneered using teacher-generated CoT explanations to fine-tune student models, enhancing their reasoning.
    • Data-focused approaches: These methods improve distillation by enhancing the quality and diversity of training data. Examples include:
      • CD [13] (Counterfactual Distillation): Reasoning through counterfactuals for smaller models.
      • SCORE [56]: Small language models need strong verifiers to self-correct reasoning.
      • Skintern [28]: Internalizing symbolic knowledge for distilling better CoT capabilities.
    • Model-focused approaches: These focus on improving model architectures and inference strategies for efficiency and reasoning. Examples include:
      • PRR [41] (Problem-Reduction Reasoning): A method that breaks down complex problems.
      • ATM [7] (Adaptive Thinking Models): Distilling reasoning ability with adaptive thinking.
    • Differentiation: While these methods primarily focus on output imitation and improving data/model structures, LeaF explicitly distills the teacher's ability to focus on critical information during inference, enabling context-aware reasoning by targeting spurious correlations.
  • Critic Token Identification: This area explores strategies to identify and mitigate redundant or less important tokens during LLM reasoning.

    • LLMLingua [23]: Relies on a model's self-assessment of token importance, which can introduce biases.
    • RHO-1 [29]: Introduces Selective Language Modeling (SLM) during Supervised Fine-Tuning to prioritize informative tokens.
    • TokenSkip [47]: Selectively skips low-impact tokens in CoT stages to compress reasoning paths.
    • cDPO [30]: Enhances Direct Preference Optimization by isolating critical tokens through contrastive learning.
    • Differentiation: Most existing methods concentrate on output tokens or reasoning steps, overlooking the influence of prior contextual information from the instruction phase. LeaF, in contrast, uses an advanced teacher model to identify confounding tokens based on gradient differences in both instructions and responses, establishing stronger connections between instruction and output for more holistic filtering.
  • Reasoning Consistency: This research line focuses on improving the stability and reliability of LLM outputs, especially in multi-step reasoning.

    • Self-Consistency [46]: Samples multiple reasoning chains and uses majority voting to stabilize final answers.
    • Adaptive Consistency [2] and Early Scoring Self-Consistency [27]: Introduce stopping criteria to reduce inference costs.
    • Reasoning Aware Self-Consistency [45]: Weights sample quality and reasoning-path importance.
    • Differentiation: These methods primarily aim to stabilize the final answer and reasoning paths. LeaF, however, defines consistency as both answer stability and context adherence. It uses a distillation strategy to systematically suppress misleading signals, enhancing the student's focus on key contextual information and promoting more robust context adherence throughout the generation process.

3.3. Technological Evolution

The field of LLM reasoning has evolved from:

  1. Basic Language Models: Focused on next-token prediction, often struggling with complex reasoning.

  2. Emergence of Transformers and Attention: Revolutionized contextual understanding.

  3. Chain-of-Thought (CoT) Prompting: Enabled LLMs to perform multi-step reasoning by generating intermediate steps.

  4. Knowledge Distillation for Reasoning: Applied KD to transfer CoT capabilities from large to small models.

  5. Focus on Efficiency and Robustness: Development of methods for token pruning, selective modeling, and consistency techniques to optimize performance and reliability.

    This paper's work (LeaF) fits into the latest stage of this evolution by addressing a critical, often overlooked aspect: the intrinsic issue of spurious correlations in training data leading to misdirected attention. It moves beyond simply mimicking teacher outputs or pruning less important tokens to specifically identify and intervene on causal confounders, thereby enhancing the fundamental attention mechanism and causal reasoning of student models.

3.4. Differentiation Analysis

Compared to existing methods, LeaF offers several core differences and innovations:

  • Causal Perspective on Confounding Factors: Unlike many previous methods that focus on redundant or less informative tokens, LeaF explicitly frames distracting patterns as spurious confounders in a causal framework. This theoretical grounding provides a principled way to identify and mitigate their influence.
  • Gradient-Guided Teacher-Student Comparison: LeaF uses gradient sensitivity from both a powerful teacher model and a student model to identify confounding tokens. This goes beyond perplexity-based or self-assessment methods, leveraging the teacher's superior reasoning to pinpoint where the student's attention is spuriously drawn.
  • Intervention-Based Distillation with Counterfactuals: The framework generates counterfactual samples by span pruning the identified confounding tokens. The hybrid distillation loss then explicitly aligns the student's attention by contrasting its behavior on original and counterfactual samples, forcing it to learn genuine causal dependencies. This is a direct intervention on the causal graph.
  • Holistic Token Filtering (Instruction + Response): While some methods focus on output tokens, LeaF considers confounding tokens in both instruction-level and response-level contexts, providing a more comprehensive approach to guide attention throughout the entire reasoning and generation process.
  • Focus on Attention Alignment, not just Output Imitation: The core goal is to align the student's attention with the teacher's focus distribution on truly critical context tokens, rather than merely imitating the teacher's final output or intermediate steps. This aims to improve the underlying reasoning process itself.

4. Methodology

The Learning to Focus (LeaF) framework aims to mitigate the adverse effects of distracting patterns (termed confounding tokens) on LLM reasoning by explicitly identifying and intervening on them. It leverages a causal perspective to guide a student model's attention toward truly critical information, drawing knowledge from a more advanced teacher model.

4.1. Principles

The core idea of LeaF is rooted in causal inference, specifically Pearl's Structural Causal Model (SCM) [39]. The method posits that confounding tokens introduce spurious correlations that mislead the LLM's reasoning process. By intervening on these confounders—effectively removing their influence—the student model can learn to focus on authentic causal instruction-response relationships. This is achieved through gradient-based comparisons to detect confounders and a hybrid distillation strategy that encourages the student to mirror the teacher's attention to causally relevant information. The theoretical basis is that by blocking non-causal influences from confounding tokens, the student model learns attention patterns grounded in the true causal structure, leading to improved robustness and interpretability.

4.2. Core Methodology In-depth (Layer by Layer)

4.2.1. Causal Framework

The paper models the reasoning process using a Directed Acyclic Graph (DAG). As shown in Figure 3, the input tokens $X$ influence the model's output $Y$. A subset of these input tokens, denoted as confounding tokens $A$, introduces spurious correlations. These confounders $A$ simultaneously influence the true causal input $X \setminus A$ (represented as $X_i$ in the formula below, where $X_i$ is a specific part of the input excluding $A$) and the output $Y$, thereby distorting the observed relationship between $X$ and $Y$.


Figure 3: Causal graph of the reasoning process. $X$ represents the input prompt, and $Y$ denotes the model's output. A subset of tokens in $X$, identified as confounding tokens ($A$), introduces spurious correlations that disrupt the reasoning process. Our method detects and masks $A$, effectively eliminating the spurious edge from $A$ to $Y$ and restoring the true causal dependency.

The observed conditional distribution, which includes the biasing effect of confounding tokens $A$, is given by: $ P(Y \mid X_i = x) = \sum_{A} P(Y \mid X_i = x, A)\, P(A \mid X_i = x) $

  • $P(Y \mid X_i = x)$: The probability of observing output $Y$ given an input $X_i = x$. This is the observed conditional distribution that LLMs typically learn.

  • $X_i = x$: A specific input (or part of the input) instance $x$.

  • $A$: The set of confounding tokens.

  • $P(Y \mid X_i = x, A)$: The probability of output $Y$ given both the input $X_i = x$ and the presence of confounding tokens $A$.

  • $P(A \mid X_i = x)$: The probability of confounding tokens $A$ being present given the input $X_i = x$.

  • $\sum_A$: Summation over all possible states or values of the confounding tokens $A$.

    This formula shows that the observed relationship between $X_i$ and $Y$ is an average over $A$ weighted by $P(A \mid X_i = x)$, so any dependence between $A$ and $X_i$ leaks into the estimate: $A$ introduces an indirect influence on $Y$ that is not part of the true causal instruction-response relationship. As a result, the observed conditional distribution $P(Y \mid X_i = x)$ deviates from the interventional distribution $P(Y \mid \mathsf{do}(X_i = x))$, which represents the causal effect of $X_i$ on $Y$. The operator $\mathsf{do}(X_i = x)$ sets $X_i$ to $x$ by external intervention, severing the edges into $X_i$ (including the one from $A$), so the spurious path through $A$ no longer contributes.
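To make the gap between $P(Y \mid X_i = x)$ and $P(Y \mid \mathsf{do}(X_i = x))$ concrete, here is a toy calculation with binary variables and made-up probabilities (purely hypothetical, not taken from the paper). It assumes the graph of Figure 3 with $A$ as the only confounder ($A \to X_i$, $A \to Y$, $X_i \to Y$). Under that assumption, the interventional distribution can be computed with the standard backdoor adjustment $\sum_A P(Y \mid X_i = x, A)\,P(A)$, which replaces the weights $P(A \mid X_i = x)$ of the observational decomposition with the marginal $P(A)$.

```python
# Hypothetical toy model: A -> X_i, A -> Y, X_i -> Y, all binary.
P_A = {0: 0.5, 1: 0.5}                      # P(A = a)
P_X1_GIVEN_A = {0: 0.2, 1: 0.8}             # P(X_i = 1 | A = a)
P_Y1_GIVEN_XA = {(0, 0): 0.1, (0, 1): 0.5,  # P(Y = 1 | X_i = x, A = a)
                 (1, 0): 0.3, (1, 1): 0.7}


def p_x_given_a(x: int, a: int) -> float:
    return P_X1_GIVEN_A[a] if x == 1 else 1.0 - P_X1_GIVEN_A[a]


def p_a_given_x(a: int, x: int) -> float:
    # Bayes' rule: P(A = a | X_i = x) is proportional to P(X_i = x | A = a) P(A = a)
    evidence = sum(p_x_given_a(x, b) * P_A[b] for b in P_A)
    return p_x_given_a(x, a) * P_A[a] / evidence


def observational(x: int) -> float:
    # P(Y=1 | X_i=x) = sum_A P(Y=1 | X_i=x, A) P(A | X_i=x)
    return sum(P_Y1_GIVEN_XA[(x, a)] * p_a_given_x(a, x) for a in P_A)


def interventional(x: int) -> float:
    # P(Y=1 | do(X_i=x)) = sum_A P(Y=1 | X_i=x, A) P(A)   (backdoor adjustment)
    return sum(P_Y1_GIVEN_XA[(x, a)] * P_A[a] for a in P_A)


for x in (0, 1):
    print(f"x={x}: observational={observational(x):.2f}, interventional={interventional(x):.2f}")
# x=0: observational=0.18, interventional=0.30
# x=1: observational=0.62, interventional=0.50
```

In this toy setting the observed difference between $x = 1$ and $x = 0$ (0.44) is more than twice the causal effect (0.20), because $A$ raises both $P(X_i = 1)$ and $P(Y = 1)$.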

To address this bias, the paper proposes causal pruning, which involves detecting and masking the confounding tokens AA. This effectively eliminates the spurious edge from AA to YY (dashed arrow in Figure 3), encouraging the student model to learn attention patterns based on the true causal structure.

4.2.2. LeaF: Learning to Focus Framework

The LeaF framework, illustrated in Figure 4, consists of two main stages:

Figure 4: Method Overview. The training pipeline comprises two key stages: (1) Confounding Token Detection: gradient-based comparisons between an advanced teacher model and the student model are used to identify confounding tokens in the training samples, and counterfactual samples are constructed by pruning these tokens; and (2) Causal Attention Distillation: the identified confounders are pruned during training to align the student's attention with the teacher's and capture causal relationships. This targeted intervention steers the model toward actual causal dependencies, enhancing both robustness and interpretability.

  1. Confounding Token Detection (Section 4.2.2.1): In this stage, confounding tokens are identified using gradient-based comparisons between a teacher model ($\pmb{\theta}_T$) and a student model ($\pmb{\theta}_S$). Once identified, counterfactual samples are constructed by pruning these tokens.
  2. Causal Attention Distillation (Section 4.2.2.2): This stage uses a hybrid distillation loss that operates on both original and counterfactual samples. This loss aligns the student's output distribution and attention with the teacher's, specifically encouraging the student to capture causal dependencies by learning from the teacher's behavior before and after the confounding token intervention.

4.2.2.1. Confounding Token Detection

To identify confounding tokens $A$, LeaF employs a gradient-based approach [42, 43] to quantify the influence of each token $x_i \in X$ on the model's output $Y$. This process focuses on instances where the student model makes an incorrect prediction, but the more capable teacher model correctly handles it, indicating a potential misdirection of the student's attention by confounding tokens.

For each token xix_i, the gradient sensitivity (absolute gradient of the loss with respect to the token embedding) is calculated for both the teacher and student models: $ g _ { i } ^ { ( T ) } = \left| \frac { \partial \ell ( x _ { i } \mid X ; \pmb { \theta } _ { T } ) } { \partial x _ { i } } \right| , \quad g _ { i } ^ { ( S ) } = \left| \frac { \partial \ell ( x _ { i } \mid X ; \pmb { \theta } _ { S } ) } { \partial x _ { i } } \right| . $

  • $g_i^{(T)}$: The absolute gradient sensitivity of the teacher model to token $x_i$.

  • $g_i^{(S)}$: The absolute gradient sensitivity of the student model to token $x_i$.

  • $\ell(x_i \mid X; \pmb{\theta})$: The loss function of the model $\pmb{\theta}$ (teacher $\pmb{\theta}_T$ or student $\pmb{\theta}_S$) with respect to the input token $x_i$ within the context $X$. This loss is typically calculated between the model's predicted logits and the gold reference (true answer).

  • $\frac{\partial \ell}{\partial x_i}$: The partial derivative of the loss with respect to the embedding of token $x_i$. This measures how sensitive the loss is to changes in $x_i$.

  • $|\cdot|$: Absolute value, indicating the magnitude of sensitivity regardless of direction.

    To allow for a fair comparison between models with potentially different gradient scales, these sensitivity values are normalized using min-max normalization: $ \hat { g } _ { i } ^ { ( T ) } = \frac { g _ { i } ^ { ( T ) } - \mathrm { m i n } _ { j } g _ { j } ^ { ( T ) } } { \mathrm { m a x } _ { j } g _ { j } ^ { ( T ) } - \mathrm { m i n } _ { j } g _ { j } ^ { ( T ) } } , \quad \hat { g } _ { i } ^ { ( S ) } = \frac { g _ { i } ^ { ( S ) } - \mathrm { m i n } _ { j } g _ { j } ^ { ( S ) } } { \mathrm { m a x } _ { j } g _ { j } ^ { ( S ) } - \mathrm { m i n } _ { j } g _ { j } ^ { ( S ) } } . $

  • $\hat{g}_i^{(T)}$, $\hat{g}_i^{(S)}$: The normalized gradient sensitivities for token $x_i$ for the teacher and student models, respectively. These values will be between 0 and 1.

  • $\min_j g_j^{(T)}$, $\max_j g_j^{(T)}$: The minimum and maximum gradient sensitivities across all tokens $j$ in the input for the teacher model.

  • $\min_j g_j^{(S)}$, $\max_j g_j^{(S)}$: The minimum and maximum gradient sensitivities across all tokens $j$ in the input for the student model.

    To identify confounding tokens, the difference in normalized gradient sensitivity between the teacher and student is computed for each token: $ \Delta \hat { g } _ { i } = \hat { g } _ { i } ^ { ( T ) } - \hat { g } _ { i } ^ { ( S ) } . $

  • $\Delta \hat{g}_i$: The difference in normalized gradient sensitivity for token $x_i$. A negative value indicates that the student model is more sensitive to $x_i$ than the teacher, while a positive value indicates the opposite.

    A token xix_i is classified as a Confounding Token if two conditions are met:

  1. Sensitivity Discrepancy: It receives significant attention from the student model but negligible attention from the teacher during inference. This is formalized by normalizing the gradient difference and comparing it to a threshold: $ \frac { \Delta \hat { g } _ { i } - \operatorname* { m i n } _ { j } \Delta \hat { g } _ { j } } { \operatorname* { m a x } _ { j } \Delta \hat { g } _ { j } - \operatorname* { m i n } _ { j } \Delta \hat { g } _ { j } } \leq \tau _ { \mathrm { c o n f o u n d e r } } , $

    • $\frac{\Delta \hat{g}_i - \min_j \Delta \hat{g}_j}{\max_j \Delta \hat{g}_j - \min_j \Delta \hat{g}_j}$: Min-max normalized value of the gradient difference $\Delta \hat{g}_i$. This ensures the threshold works consistently across different instances.
    • $\tau_{\mathrm{confounder}}$: A predefined threshold (determined via validation set analysis). A low normalized value means that $\Delta \hat{g}_i$ lies near the minimum across the input, i.e., the student is much more sensitive to this token than the teacher, which largely ignores it.
    • Intuitively, this condition identifies tokens where the student model's attention is disproportionately high compared to the teacher's, suggesting a spurious dependency.
  2. Correct Prediction After Removal: The removal of token xix_i results in correct predictions from both the teacher and student models. This ensures that the identified token is indeed confounding and not causally essential for correct reasoning.
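The two detection conditions can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation. It takes the per-token sensitivities $g_i^{(T)}$ and $g_i^{(S)}$ as already computed, assuming each token's gradient has been reduced to one non-negative number, and it returns zeros when all values are equal (where min-max normalization is undefined). Condition 2 is modeled with hypothetical callables that report whether each model answers correctly once a token is removed. The sensitivities and the threshold value are also hypothetical.

```python
from typing import Callable, List

import numpy as np


def minmax(v) -> np.ndarray:
    """Min-max normalize to [0, 1]; returns zeros if all entries are equal (our convention)."""
    v = np.asarray(v, dtype=float)
    span = v.max() - v.min()
    if span == 0:
        return np.zeros_like(v)
    return (v - v.min()) / span


def sensitivity_discrepancy(g_teacher, g_student, tau: float) -> np.ndarray:
    """Condition 1: indices whose normalized teacher-student gap is <= tau."""
    delta = minmax(g_teacher) - minmax(g_student)  # delta_i = g_hat_i^(T) - g_hat_i^(S)
    return np.flatnonzero(minmax(delta) <= tau)    # low score: student far more sensitive


def confirm_confounders(candidates, teacher_correct: Callable[[int], bool],
                        student_correct: Callable[[int], bool]) -> List[int]:
    """Condition 2: keep i only if BOTH models are correct once token i is removed.

    teacher_correct / student_correct are hypothetical hooks that would rerun each model.
    """
    return [int(i) for i in candidates if teacher_correct(int(i)) and student_correct(int(i))]


g_T = [0.9, 0.1, 0.5, 0.1]  # hypothetical teacher sensitivities
g_S = [0.8, 0.9, 0.4, 0.2]  # hypothetical student sensitivities
cands = sensitivity_discrepancy(g_T, g_S, tau=0.2)
print(cands)                                                       # [1]
print(confirm_confounders(cands, lambda i: True, lambda i: True))  # [1]
```

Token 1 is flagged because the student's normalized sensitivity is at its maximum while the teacher's is at its minimum, which is exactly the "student attends, teacher ignores" pattern described above.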

    Pruning Strategies: The paper explores two strategies for constructing counterfactual samples by removing identified confounding tokens from the instruction XX:


Figure 5: Illustration of Collective Pruning and Span Pruning.

  1. Collective Pruning: Removes the entire set of all identified confounders $A$ from $X$, resulting in $X \setminus A$. (See Figure 5, right side, top example: all blue-shaded tokens are removed simultaneously.)

  2. Span Pruning: Removes only one contiguous span of confounding tokens $A_i$ at a time, yielding $X \setminus A_i$. (See Figure 5, right side, bottom examples: only a single blue-shaded span is removed in each counterfactual.)

    Preliminary experiments (Appendix C) show that Span Pruning outperforms Collective Pruning, as removing all distracting patterns simultaneously can disrupt sentence integrity. Therefore, LeaF adopts Span Pruning to construct counterfactual samples: $ \mathcal{D}_{\mathrm{pruned}} = \left\{ (X \setminus A_i, y) \right\}_{i=1}^{k}, $

  • $\mathcal{D}_{\mathrm{pruned}}$: The set of counterfactual samples generated.

  • $(X \setminus A_i, y)$: A single counterfactual sample, where a distinct confounding span $A_i$ has been removed from the original instruction $X$, and $y$ is the corresponding ground-truth output.

  • $k$: The total number of distinct confounding spans identified in the original instruction $X$.

    This augmentation creates multiple counterfactual examples for a single original instruction, encouraging the model to learn reasoning paths that are invariant to specific confounders.
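Grouping flagged tokens into contiguous spans and building the span-pruned set $\mathcal{D}_{\mathrm{pruned}}$ can be sketched as below. It assumes the instruction is a list of tokens and the flagged indices come from the Sensitivity Discrepancy step; the optional `keep` callback is a hypothetical stand-in for the Correct Prediction After Removal check. Collective Pruning is included only for contrast, and all names are illustrative.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Sample = Tuple[List[str], str]  # (pruned instruction tokens, ground-truth output y)


def group_spans(indices: Sequence[int]) -> List[Tuple[int, int]]:
    """Group flagged token indices into contiguous half-open spans [start, end)."""
    spans: List[Tuple[int, int]] = []
    for i in sorted(set(indices)):
        if spans and spans[-1][1] == i:
            spans[-1] = (spans[-1][0], i + 1)  # token i extends the current span
        else:
            spans.append((i, i + 1))  # token i starts a new span
    return spans


def span_pruned_samples(tokens: List[str], flagged: Sequence[int], y: str,
                        keep: Optional[Callable[[List[str]], bool]] = None) -> List[Sample]:
    """Span Pruning: one counterfactual (X \\ A_i, y) per contiguous span A_i."""
    samples: List[Sample] = []
    for start, end in group_spans(flagged):
        pruned = tokens[:start] + tokens[end:]  # remove only span A_i
        if keep is None or keep(pruned):  # hypothetical correctness filter
            samples.append((pruned, y))
    return samples


def collective_pruned_sample(tokens: List[str], flagged: Sequence[int], y: str) -> Sample:
    """Collective Pruning: a single (X \\ A, y) with every flagged token removed."""
    drop = set(flagged)
    return ([t for i, t in enumerate(tokens) if i not in drop], y)
```

With tokens `["A", "b", "c", "D", "e", "F"]` and flagged indices `[1, 2, 4]`, Span Pruning yields two counterfactuals (one without `b c`, one without `e`), while Collective Pruning yields the single, more heavily edited `["A", "D", "F"]`.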

4.2.2.2. Causal Attention Distillation

After generating both original and counterfactual samples, LeaF optimizes two complementary distillation objectives to steer the student model towards true causal dependencies:

  1. Standard Distillation ($\mathcal{L}_{kd}$): This objective aligns the student's output distribution ($p_S$) with the teacher's output distribution ($p_T$) on the original instructions $X$. This is a typical knowledge distillation loss. $ \mathcal{L}_{kd} = D_{\mathrm{KL}}\big(p_T(y \mid X) \parallel p_S(y \mid X)\big), $

    • $D_{\mathrm{KL}}(\cdot \parallel \cdot)$: The Kullback-Leibler divergence (explained in Section 3.1).
    • $p_T(y \mid X)$: The probability distribution of the teacher model's output $y$ given the original instruction $X$.
    • $p_S(y \mid X)$: The probability distribution of the student model's output $y$ given the original instruction $X$.
  2. Counterfactual Distillation ($\mathcal{L}_{cd}$): This objective aligns the student's output distribution ($p_S$) with the teacher's output distribution ($p_T$) on the counterfactual instructions $X \setminus A$, where confounders have been pruned (with Span Pruning, these are the instructions $X \setminus A_i$ in $\mathcal{D}_{\mathrm{pruned}}$). This is the key causal intervention component. $ \mathcal{L}_{cd} = D_{\mathrm{KL}}\big(p_T(y \mid X \setminus A) \parallel p_S(y \mid X \setminus A)\big). $

    • $p_T(y \mid X \setminus A)$: The probability distribution of the teacher model's output $y$ given the counterfactual instruction $X \setminus A$.

    • $p_S(y \mid X \setminus A)$: The probability distribution of the student model's output $y$ given the counterfactual instruction $X \setminus A$.

      These two objectives are combined into a hybrid distillation loss using a weighting factor $\lambda$: $ \mathcal{L} = \lambda \mathcal{L}_{kd} + (1 - \lambda) \mathcal{L}_{cd}, $

  • $\mathcal{L}$: The total hybrid distillation loss that the student model minimizes during training.
  • $\lambda$: A hyperparameter in the range [0, 1] that controls the trade-off between Standard Distillation and Counterfactual Distillation.
    • If $\lambda = 1$, only Standard Distillation is used.

    • If $\lambda = 0$, only Counterfactual Distillation is used.

    • Values between 0 and 1 allow balancing both objectives.

      This composite loss encourages the student to preserve semantic knowledge from the teacher on original inputs while simultaneously enforcing genuine causal dependencies by observing the teacher's behavior when confounding factors are removed.
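The two KL objectives and their weighted combination can be sketched in plain Python. The sketch assumes that teacher and student next-token distributions have already been computed at each position of the output $y$ (teacher forcing) and averages the token-level KL over positions; the paper's exact reduction (sum vs. mean, any temperature) is not specified here, and all names are hypothetical. The same `sequence_kl` gives $\mathcal{L}_{kd}$ when fed distributions conditioned on $X$ and $\mathcal{L}_{cd}$ when fed distributions conditioned on a pruned instruction.

```python
import math
from typing import List

Dist = List[float]  # a next-token probability distribution over the vocabulary


def kl_divergence(p: Dist, q: Dist, eps: float = 1e-12) -> float:
    """D_KL(p || q) = sum_v p(v) log(p(v) / q(v)); entries with p(v) = 0 add nothing."""
    return sum(pv * math.log(pv / max(qv, eps)) for pv, qv in zip(p, q) if pv > 0)


def sequence_kl(teacher: List[Dist], student: List[Dist]) -> float:
    """Mean token-level D_KL(p_T || p_S) over the positions of the output y."""
    if not teacher or len(teacher) != len(student):
        raise ValueError("need one teacher and one student distribution per position")
    return sum(kl_divergence(t, s) for t, s in zip(teacher, student)) / len(teacher)


def hybrid_loss(l_kd: float, l_cd: float, lam: float) -> float:
    """L = lam * L_kd + (1 - lam) * L_cd with lam in [0, 1]."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    return lam * l_kd + (1.0 - lam) * l_cd
```

A training step would then evaluate `hybrid_loss(sequence_kl(t_orig, s_orig), sequence_kl(t_cf, s_cf), lam)`, where the first pair of distribution lists is conditioned on $X$ and the second on a pruned instruction $X \setminus A_i$; `lam` is a free choice here, not the paper's setting. Note the argument order: the teacher distribution comes first, matching $D_{\mathrm{KL}}(p_T \parallel p_S)$.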

Response Splitting Strategies: Beyond instruction-level pruning, the paper also considers applying the confounding token detection and pruning to the model's generated response. This is important because previously generated tokens can also act as contextual input for subsequent generation steps and might contain misleading patterns.

Figure 6: Illustration of Response Splitting Strategies: Language CoT, Instruct-level Pruning, and Response-level Pruning (2-segment and 3-segment splits). Highlighted white areas represent the input, and blue underlined areas represent the outputs used for cross-entropy loss computation.

The paper considers two variants for response splitting:

  1. Instruct-level Pruning: Confounding tokens are detected and pruned only in the instructions. The LeaF framework is applied to these instruction-level pruned samples. (Figure 6, middle)
  2. Both Instruct- and Response-level Pruning: Confounding tokens are detected and pruned in both the instructions and preceding generations (the model's partial output). This helps the model produce more accurate continuations. (Figure 6, right)
    • 2-segment splits: The response is split into two segments, and confounding tokens are identified and pruned within each.

    • 3-segment splits: The response is split into three segments, allowing for more granular detection and pruning of confounding tokens.

      This approach aims to address the dynamic nature of distracting patterns that can emerge not just in the initial prompt but also during the model's own generation process, further enhancing reasoning capabilities.
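How the response is cut into segments is not specified here, so the sketch below assumes equal-length splits by token count and shows one plausible way to turn a response into training pairs. Each pair's context is the instruction plus all earlier response segments (the white input region in Figure 6), and its target is the current segment (the blue underlined output region); confounder detection and span pruning would then be applied to each context. The function names and the splitting rule are hypothetical.

```python
from typing import List, Tuple


def segment_bounds(length: int, n_segments: int) -> List[int]:
    """Boundaries cutting `length` tokens into n_segments near-equal parts."""
    if n_segments < 1:
        raise ValueError("n_segments must be >= 1")
    return [length * j // n_segments for j in range(n_segments + 1)]


def response_level_samples(instruction: List[str], response: List[str],
                           n_segments: int) -> List[Tuple[List[str], List[str]]]:
    """Return (context, target) pairs; context = instruction + preceding segments."""
    b = segment_bounds(len(response), n_segments)
    pairs = []
    for j in range(n_segments):
        context = instruction + response[: b[j]]  # region where pruning would apply
        target = response[b[j]: b[j + 1]]          # tokens scored by the loss
        pairs.append((context, target))
    return pairs
```

With `n_segments=1` this reduces to the instruction-level setting (a single pair whose context is the instruction alone), while `n_segments=2` and `n_segments=3` correspond to the 2-segment and 3-segment splits.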

5. Experimental Setup

5.1. Datasets

The experiments evaluate LeaF on mathematical reasoning, code generation, and multi-hop question answering tasks.

Training Datasets:

  • Mathematical Reasoning:
    • NuminaMath-CoT [26]: A large math dataset with chain-of-thought solutions. To ensure models encounter an equal number of confounding tokens across tasks, the training set consists of 30k instances, with 7.5k randomly selected from each of:
      • Olympiads [16]: Advanced math and physics problems, filtered for pure-text tasks.
      • AMC_AIME [26]: American Mathematics Competitions and American Invitational Mathematics Examination problems.
      • GSM8K [8]: Grade-school level math word problems.
      • MATH [17]: Competition-style math problems across various topics.
  • Code Generation:
    • AceCode-87K [53]: A dataset for competitive coding, from which a subset of 120k instances is randomly selected.
  • Multi-hop Question Answering:
    • KILT [40] datasets provided in Helmet [52]: A benchmark for knowledge-intensive language tasks. The training set is constructed by merging data from:
      • HotpotQA [50]: Requires reasoning over multiple Wikipedia paragraphs.
      • NQ [1]: Natural Questions, typically single-paragraph answers.
      • PopQA [36]: Factoid questions.
    • Totaling 3k annotated samples, drawn equally from HotpotQA, NQ, and PopQA, where each query is linked to gold passages containing the answers.

Evaluation Datasets:

  • Mathematical Reasoning:
    • GSM8K [8]: 8500 grade-school-level word problems, requiring 2-8 steps of basic arithmetic. Assesses multi-step reasoning in natural language.
    • MATH-500 [17]: A subset of the MATH dataset (12500 competition-style problems) focusing on diverse mathematical topics.
    • OlympiadBench [16]: A challenging benchmark for Olympiad-level math and physics problems. The text-only subset of 674 tasks is used for advanced symbolic reasoning.
  • Code Generation:
    • HumanEval+ [32]: Extends the original HumanEval with additional Python programming tasks and augmented unit tests, targeting functional correctness.
    • LeetCode [9]: Samples real-world algorithmic challenges (arrays, trees, dynamic programming) to assess the ability to generate correct and efficient solutions.
    • LiveCodeBench (v4) [22]: A large-scale suite of real-world coding tasks with comprehensive unit tests and human preference annotations, for functional accuracy and coding style.
  • Multi-hop Question Answering:
    • HotpotQA [50]: Requires reasoning over multiple Wikipedia paragraphs to answer complex questions.
    • 2WikiMultiHopQA [19]: Multi-hop questions generated from Wikipedia, requiring reasoning across at least two documents.
    • Musique [44]: Provides multi-hop questions with fine-grained decomposition and supporting evidence, testing compositional reasoning and factual consistency.

5.2. Evaluation Metrics

Each evaluation metric below is described by its conceptual definition, mathematical formula, and symbol explanation:

  • Accuracy:

    1. Conceptual Definition: Accuracy measures the proportion of correctly predicted instances out of the total number of instances. It is a common metric for classification tasks where the output is a discrete value. In reasoning tasks, it typically refers to the percentage of problems for which the model provides the exact correct answer.
    2. Mathematical Formula: $ \text{Accuracy} = \frac{\text{Number of Correct Predictions}}{\text{Total Number of Predictions}} $
    3. Symbol Explanation:
      • Number of Correct Predictions: The count of instances where the model's output matches the ground truth.
      • Total Number of Predictions: The total count of instances evaluated.
  • Pass@k:

    1. Conceptual Definition: Used in code generation tasks, pass@k measures the probability that at least one of $k$ generated code samples passes the unit tests. If a model generates $k$ solutions for a problem and at least one is functionally correct, the problem is considered solved. It accounts for the stochastic nature of code generation and how many attempts are needed to get a correct solution. The variants used in the paper are pass@1 for HumanEval+ and LeetCode, and pass@10 for LiveCodeBench.
    2. Mathematical Formula: If exactly $k$ samples are generated per problem, pass@k is the empirical success rate: $ \text{Pass@k} = \frac{\text{Number of problems with at least one correct sample}}{\text{Total number of problems}} $ More generally, if $n \ge k$ samples are generated for a problem and $c$ of them are correct, the probability that $k$ samples drawn without replacement from the $n$ are all incorrect is $\binom{n-c}{k} / \binom{n}{k}$, so the probability that at least one is correct is $1 - \binom{n-c}{k} / \binom{n}{k}$. Averaging this unbiased estimate over all problems gives: $ \text{Pass@k} = \frac{1}{M} \sum_{m=1}^{M} \left(1 - \frac{\binom{n_m - c_m}{k}}{\binom{n_m}{k}}\right) $ With greedy decoding (Section 5.4), a single deterministic sample is produced per problem, so pass@1 is simply the fraction of problems whose generated solution passes all unit tests.
    3. Symbol Explanation:
      • $M$: Total number of coding problems in the benchmark.
      • $n_m$: Number of code samples generated for problem $m$.
      • $c_m$: Number of functionally correct samples among the $n_m$ generated for problem $m$.
      • $k$: The number of samples considered (e.g., 1 or 10).
      • $\binom{n}{k}$: The binomial coefficient ("n choose k"), i.e., the number of ways to choose $k$ items from a set of $n$ items.
  • EM (Exact Match):

    1. Conceptual Definition: Exact Match is a strict metric used in question answering that measures whether the model's generated answer string is identical to any of the ground-truth answer strings, usually after light normalization (e.g., lowercasing and removing punctuation and articles). It is often used for tasks with short, factual answers.
    2. Mathematical Formula: $ \text{EM} = \frac{1}{N} \sum_{i=1}^{N} \mathbb{I}(\text{model\_answer}_i = \text{gold\_answer}_i) $
    3. Symbol Explanation:
      • $N$: Total number of questions.
      • $\mathbb{I}(\cdot)$: The indicator function, which is 1 if the condition inside is true, and 0 otherwise.
      • $\text{model\_answer}_i$: The answer generated by the model for question $i$.
      • $\text{gold\_answer}_i$: The ground-truth answer for question $i$. (Often there are multiple valid gold answers; if the model matches any of them, the answer is counted as correct.)
  • F1 Score:

    1. Conceptual Definition: The F1 score is a measure of a test's accuracy, often used for question answering tasks, especially when answers can be phrases or spans of text. It is the harmonic mean of precision and recall.
      • Precision measures how many of the model's predicted tokens are correct (i.e., also in the gold answer).
      • Recall measures how many of the ground-truth answer tokens were captured by the model's prediction. The F1 score provides a single metric that balances both precision and recall, being high only when both are high.
    2. Mathematical Formula: $ F1 = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} $ where, for text spans: $ \text{Precision} = \frac{\text{Number of overlapping tokens}}{\text{Number of tokens in model\_answer}} $ $ \text{Recall} = \frac{\text{Number of overlapping tokens}}{\text{Number of tokens in gold\_answer}} $
    3. Symbol Explanation:
      • Number of overlapping tokens: The number of tokens shared by the model's answer and the ground-truth answer after tokenization and normalization. Standard SQuAD-style implementations count shared tokens with multiplicity (a multiset intersection) rather than as unique tokens.
      • Number of tokens in model_answer: The total number of tokens in the model's generated answer.
      • Number of tokens in gold_answer: The total number of tokens in the ground-truth answer.
  • Jaccard Similarity:

    1. Conceptual Definition: Jaccard Similarity (also known as the Jaccard Index or Intersection over Union) is a statistic used for gauging the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets. In the context of text, it measures the overlap between two sets of words or tokens.
    2. Mathematical Formula: For two sets $A$ and $B$: $ J(A, B) = \frac{|A \cap B|}{|A \cup B|} $
    3. Symbol Explanation:
      • $A$: The set of tokens in the first text (e.g., the student model's response).
      • $B$: The set of tokens in the second text (e.g., the ground-truth response or the teacher model's response).
      • $|A \cap B|$: The number of tokens common to sets $A$ and $B$.
      • $|A \cup B|$: The total number of unique tokens across both sets $A$ and $B$.
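The pass@k, EM, F1, and Jaccard definitions above can be sketched in Python. This is a minimal illustration under stated assumptions: answer normalization follows the widely used SQuAD evaluation-script convention (lowercasing, stripping punctuation and the articles a/an/the, collapsing whitespace), and tokenization is plain whitespace splitting. The paper's evaluation frameworks (EvalPlus, Skythought-Evals, LongMab-PO) may normalize or tokenize differently, and all function names here are hypothetical.

```python
import re
import string
from collections import Counter
from math import comb
from typing import Iterable, Sequence, Tuple


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased per-problem pass@k: 1 - C(n - c, k) / C(n, k)."""
    if not 0 <= c <= n or not 1 <= k <= n:
        raise ValueError("require 0 <= c <= n and 1 <= k <= n")
    # math.comb returns 0 when k > n - c, which correctly gives pass@k = 1.
    return 1.0 - comb(n - c, k) / comb(n, k)


def mean_pass_at_k(results: Sequence[Tuple[int, int]], k: int) -> float:
    """Average pass@k over problems, given one (n_m, c_m) pair per problem."""
    return sum(pass_at_k(n, c, k) for n, c in results) / len(results)


def normalize_answer(text: str) -> str:
    """SQuAD-style normalization: lowercase, drop punctuation and articles."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, gold_answers: Sequence[str]) -> int:
    """1 if the normalized prediction equals any normalized gold answer, else 0."""
    pred = normalize_answer(prediction)
    return int(any(pred == normalize_answer(g) for g in gold_answers))


def token_f1(prediction: str, gold: str) -> float:
    """Token-level F1 using a multiset (Counter) intersection for the overlap."""
    pred_toks = normalize_answer(prediction).split()
    gold_toks = normalize_answer(gold).split()
    overlap = sum((Counter(pred_toks) & Counter(gold_toks)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_toks)
    recall = overlap / len(gold_toks)
    return 2 * precision * recall / (precision + recall)


def jaccard(a_tokens: Iterable[str], b_tokens: Iterable[str]) -> float:
    """|A ∩ B| / |A ∪ B| over token sets; two empty sets count as identical."""
    a, b = set(a_tokens), set(b_tokens)
    if not a and not b:
        return 1.0
    return len(a & b) / len(a | b)
```

For example, `token_f1("the Eiffel tower in Paris", "Eiffel Tower")` has precision 2/4 and recall 2/2, giving F1 = 2/3, whereas `exact_match` on the same pair returns 0.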

5.3. Base Models and Baselines

Base Models (Student Models):

  • LLaMA Family:
    • LLaMA3.2-1B-Instruct [34]
    • LLaMA3.2-3B-Instruct [34]
  • Qwen Family:
    • Qwen2.5-Math-1.5B [48]

Teacher Models:

  • For LLaMA-based experiments: LLaMA3.3-70B-Instruct [34]
  • For Qwen-based experiments: Qwen2.5-72B-Instruct [48]

These teachers are chosen as high-capacity models that are expected to demonstrate strong reasoning capabilities and accurately identify critical information.

Baselines:

  • Instruct Model (Pre-KD): The student model before any knowledge distillation (i.e., the original instruction-tuned version). This serves as a starting point to measure improvements.
  • KD w/o Mask (Standard Knowledge Distillation): This is the primary baseline. It involves standard knowledge distillation where the student model is trained to mimic the teacher's outputs without any confounding token identification or pruning.
    • For math tasks, a CoT-based variant is used.
    • For code generation and multi-hop QA, a vanilla variant is used.

5.4. Training and Evaluation Settings

Training Settings:

  • Framework: Models are trained using the Alpaca-LoRA framework. Although this codebase is commonly used for Low-Rank Adaptation (LoRA), the distillation here is full-parameter (see below).
  • Distillation: Full-parameter logits knowledge distillation is used: all student parameters are updated, and the student is trained to match the teacher's full output probability distribution (logits) rather than just the hard labels.
  • Learning Rate Schedule: A cosine learning rate schedule is employed, starting with a maximum learning rate of $10^{-5}$.
  • Epochs: Training is conducted for 3 epochs.
  • Batch Size:
    • LLaMA-based models: 64
    • Qwen-based models: 32
  • Detailed Hyperparameters: Further details are provided in Appendix J, including maximum sequence length (4096 for LLaMA-1B/Qwen-1.5B, 3000 for LLaMA-3B) and warmup steps (5).
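The learning-rate schedule above can be sketched as a function of the training step. The maximum learning rate ($10^{-5}$) and the 5 warmup steps come from the settings above; the linear warmup shape, decay to zero, and the function name are assumptions. Under those assumptions the curve has the same shape as the default half-cycle schedule from Hugging Face transformers' `get_cosine_schedule_with_warmup`, though the paper does not say which implementation it uses.

```python
import math


def cosine_lr(step: int, total_steps: int, max_lr: float = 1e-5,
              warmup_steps: int = 5, min_lr: float = 0.0) -> float:
    """Learning rate at `step`: linear warmup, then cosine decay to min_lr."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps  # linear ramp from 0 toward max_lr
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)  # fraction of decay done
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

The rate peaks at `max_lr` right after warmup, passes through half of it midway through the decay phase, and reaches `min_lr` at `total_steps`.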

Evaluation Settings:

  • Decoding Strategy: Greedy decoding is used for both teacher and student models during evaluation.
  • Generation Length:
    • Code tasks: Maximum generation length of 1024 tokens.
    • Math tasks: Maximum generation length of 16,384 tokens (reflecting the potentially longer reasoning chains).
  • Chat Templates: Official chat templates are followed during inference to ensure consistency with how models are typically used.
  • Evaluation Frameworks:
    • Math tasks: Evaluated using a modified Step-DPO framework [24].
    • Code generation: Evaluated using EvalPlus [32] (for HumanEval+ and LeetCode) and Skythought-Evals framework [25] (for LiveCodeBench).
    • Multi-hop QA: Evaluated using the LongMab-PO framework [12], with datasets from LongBench [5].

6. Results & Analysis

6.1. Core Results Analysis

The main results demonstrate LeaF's effectiveness in enhancing LLM reasoning across various tasks.

The following are the results from Table 1 of the original paper:

| Model | GSM8K | MATH-500 | OlympiadBench | MathBench Avg. | HumanEval+ | LeetCode | LiveCodeBench | CodeBench Avg. |
|---|---|---|---|---|---|---|---|---|
| Teacher Models | | | | | | | | |
| LLaMA3.3-70B-Instruct | 95.60 | 70.40 | 36.50 | 67.50 | 78.05 | 53.90 | 45.02 | 58.99 |
| Qwen2.5-72B-Instruct | 95.45 | 73.80 | 41.25 | 70.17 | 81.71 | 69.40 | 54.42 | 68.51 |
| LLaMA3.2-1B-Instruct | | | | | | | | |
| Instruct Model (Pre-KD) | 44.88 | 24.20 | 5.79 | 24.96 | 29.27 | 7.22 | 9.68 | 15.39 |
| KD w/o Mask | 56.79 | 33.40 | 8.90 | 33.03 | 32.32 | 6.11 | 13.74 | 17.39 |
| LeaF (Instr Mask) | 57.70 | 35.40 | 10.09 | 34.40 | 39.02 | 6.67 | 13.60 | 19.76 |
| LeaF (Instr & Resp Mask) | 58.98 | 35.20 | 9.94 | 34.71 | 39.63 | 7.22 | 12.48 | 19.77 |
| LLaMA3.2-3B-Instruct | | | | | | | | |
| Instruct Model (Pre-KD) | 76.88 | 42.80 | 13.20 | 44.29 | 48.78 | 13.89 | 20.34 | 27.67 |
| KD w/o Mask | 82.87 | 49.00 | 18.99 | 50.29 | 54.88 | 16.67 | 24.12 | 31.89 |
| LeaF (Instr Mask) | 83.09 | 51.80 | 20.77 | 51.88 | 55.49 | 19.44 | 25.39 | 33.44 |
| LeaF (Instr & Resp Mask) | 84.69 | 52.40 | 22.55 | 53.21 | 56.10 | 21.67 | 25.81 | 34.53 |
| Qwen2.5-Math-1.5B | | | | | | | | |
| Base Model (Pre-KD) | 65.20 | 41.40 | 21.96 | 42.85 | 35.37 | 6.67 | 1.26 | 14.43 |
| KD w/o Mask | 82.18 | 67.80 | 31.16 | 60.38 | 41.46 | 7.78 | 10.10 | 19.78 |
| LeaF (Instr Mask) | 84.69 | 68.60 | 32.79 | 62.03 | 42.68 | 9.94 | 10.80 | 20.97 |
| LeaF (Instr & Resp Mask) | 85.29 | 70.60 | 31.75 | 62.54 | 43.29 | 9.94 | 13.04 | 21.92 |

The following are the results from Table 2 of the original paper:

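| Model | 2WikiMultiHopQA EM | 2WikiMultiHopQA F1 | 2WikiMultiHopQA SubEM | Musique EM | Musique F1 | Musique SubEM | HotpotQA EM | HotpotQA F1 | HotpotQA SubEM | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|
| LLaMA3.2-3B-Instruct | | | | | | | | | | |
| Instruct Model (Pre-KD) | 22.50 | 34.73 | 45.50 | 9.50 | 16.56 | 18.50 | 19.50 | 32.67 | 36.00 | 26.16 |
| KD w/o Mask | 43.50 | 51.93 | 53.00 | 20.50 | 28.56 | 22.50 | 39.00 | 49.30 | 43.00 | 39.03 |
| LeaF (Instr Mask) | 46.50 | 53.89 | 55.00 | 27.00 | 33.00 | 28.00 | 40.50 | 52.06 | 44.50 | 42.27 |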

Several conclusions can be drawn from these tables:

  1. General Improvement with Distillation: Both standard knowledge distillation (KD w/o Mask) and LeaF significantly improve the performance of smaller student models (LLaMA3.2-1B/3B-Instruct, Qwen2.5-Math-1.5B) across almost all math, code, and multi-hop QA tasks compared to their Pre-KD versions. This validates the effectiveness of knowledge distillation in transferring capabilities from large teacher models.
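  2. LeaF Outperforms Standard KD: LeaF outperforms standard knowledge distillation (KD w/o Mask) on every benchmark average (MathBench, CodeBench, and multi-hop QA) and on nearly all individual benchmarks. The one exception is LiveCodeBench for LLaMA3.2-1B, where both LeaF variants trail KD (13.60 and 12.48 vs. 13.74).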
    • MathBench (Avg. Accuracy): For LLaMA3.2-1B, KD (33.03%) -> LeaF Instr Mask (34.40%) -> LeaF Instr & Resp Mask (34.71%). For LLaMA3.2-3B, KD (50.29%) -> LeaF Instr Mask (51.88%) -> LeaF Instr & Resp Mask (53.21%). For Qwen2.5-Math-1.5B, KD (60.38%) -> LeaF Instr Mask (62.03%) -> LeaF Instr & Resp Mask (62.54%).
    • CodeBench (Avg. Accuracy): Similar trends are observed with consistent gains. For LLaMA3.2-1B, KD (17.39%) -> LeaF Instr Mask (19.76%) -> LeaF Instr & Resp Mask (19.77%). For LLaMA3.2-3B, KD (31.89%) -> LeaF Instr Mask (33.44%) -> LeaF Instr & Resp Mask (34.53%). For Qwen2.5-Math-1.5B, KD (19.78%) -> LeaF Instr Mask (20.97%) -> LeaF Instr & Resp Mask (21.92%).
    • Multi-hop QA (Avg. Score): LeaF (Instr Mask) achieves 42.27% average, outperforming KD w/o Mask (39.03%) for LLaMA3.2-3B. This demonstrates that causal attention distillation via gradient-guided token pruning is more effective than standard distillation in enhancing reasoning capabilities.
  3. Benefits of Response-Level Pruning: Extending LeaF to include response-level pruning (LeaF (Instr & Resp Mask)) generally yields further performance improvements over instruction-level pruning (LeaF (Instr Mask)) in most tasks for both LLaMA and Qwen series. This suggests that distracting patterns can also emerge during the model's generation process, affecting subsequent tokens, and addressing them is beneficial. The paper hypothesizes that instruction-level and response-level distracting patterns may differ, and learning both can enhance reasoning.
  4. Absolute Gains: LeaF achieves an average absolute accuracy gain of $2.41\%$ on MathBench, $2.48\%$ on CodeBench, and $3.24\%$ on multi-hop QA benchmarks compared to standard KD. These are substantial improvements for complex reasoning tasks.

6.2. Ablation Studies / Parameter Analysis

6.2.1. Masking Strategies Analysis

This analysis investigates the effectiveness of LeaF's gradient-based masking strategy by comparing it against two alternatives: Random Masking and PPL-based Masking.

Figure 7: Comparison of accuracy improvement with masking strategies over baseline (KD).

Results (Figure 7):

  1. Gradient-based Masking (LeaF) Superiority: LeaF's Gradient-based Masking consistently outperforms both Random Masking and PPL-based Masking. It shows the highest accuracy improvements, particularly on MATH-500 and OlympiadBench, which are more complex tasks.
  2. Random Masking Deterioration: Random Masking often leads to performance degradation (e.g., on GSM8K and OlympiadBench) or only slight improvements (on MATH-500). This indicates that indiscriminately masking tokens without an informed selection process can be detrimental to distillation performance, as it might remove critical information.
  3. PPL-based Masking Limitations: PPL-based masking (masking tokens with highest perplexity) offers modest improvements on simpler tasks like GSM8K and MATH-500 but performs comparably to Random Masking on the more complex OlympiadBench. This suggests that perplexity alone may not be sufficient for accurately identifying confounding tokens in challenging reasoning scenarios, highlighting the necessity of an advanced teacher model to guide token selection.
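The two alternative strategies can be sketched as follows to make the comparison concrete. The sketch assumes per-token log-probabilities from some scoring model are already available and fixes the masking budget `n_mask` by hand; how the paper chooses the budget and the scoring model for PPL-based masking is not stated here, so both are assumptions, and the function names are hypothetical. For a single token, perplexity is taken as $\exp(-\log p)$, so ranking by it is equivalent to ranking by negative log-likelihood.

```python
import math
import random
from typing import List, Sequence


def ppl_mask(token_logprobs: Sequence[float], n_mask: int) -> List[int]:
    """Indices of the n_mask tokens with the highest per-token perplexity exp(-log p)."""
    ppl = [math.exp(-lp) for lp in token_logprobs]
    ranked = sorted(range(len(ppl)), key=lambda i: ppl[i], reverse=True)
    return sorted(ranked[:n_mask])  # report positions in document order


def random_mask(n_tokens: int, n_mask: int, seed: int = 0) -> List[int]:
    """Baseline: n_mask distinct positions drawn uniformly at random."""
    rng = random.Random(seed)
    return sorted(rng.sample(range(n_tokens), n_mask))
```

Neither function consults a teacher model: PPL-based masking only reflects how surprising a token is, which is consistent with the observation above that an advanced teacher is needed to separate confounders from merely hard tokens.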

6.2.2. Response Splitting Strategies

The paper explores different strategies for processing responses, considering the presence of distracting patterns at both instruction and response levels.

Figure 8: Comparison of accuracy improvement with splitting strategies over baseline (KD).

Results (Figure 8):

  1. Response-level Pruning Benefits: Response-level pruning (both 2-segment and 3-segment splits) significantly outperforms instruction-level pruning (response without split). This underscores the importance of extending the confounding token detection and pruning beyond just the initial instruction to the model's generated response. This improvement is hypothesized to stem from the distinct nature of distracting patterns that appear during generation.
  2. Diminishing Returns of Further Segmentation: The performance of 3-segment splits is comparable to that of 2-segment splits. This suggests that beyond a certain point, additional segmentation at the response level yields diminishing returns. The hypothesis is that distracting patterns at the response level may exhibit certain regularities, and data generated by 2-segment splits is sufficient for the model to learn these patterns effectively.

6.2.3. Threshold Sensitivity Analysis

This analysis examines the impact of the confounding token threshold ($\tau_{\mathrm{confounder}}$) on model performance.

Figure 9: Instruct-level in MathBench.

Figure 10: Response-level in MathBench.

Results (Figure 9 and Figure 10):

  • Optimal Thresholds:

    • Instruct-level (Figure 9): LLaMA3.2-LeaF-1B performs best at $\tau = 0.10$, while LLaMA3.2-LeaF-3B performs best at $\tau = 0.05$.
    • Response-level (Figure 10): LLaMA3.2-LeaF-1B performs best at $\tau = 0.15$, and LLaMA3.2-LeaF-3B at $\tau = 0.10$.
  • Model Size and Sensitivity: Smaller models (LLaMA3.2-LeaF-1B) generally achieve optimal performance at higher confounder thresholds than larger models (LLaMA3.2-LeaF-3B). Because a token is flagged when its normalized gradient difference is at most $\tau_{\mathrm{confounder}}$, a higher threshold flags more tokens. This suggests that smaller models, being more susceptible to confounding tokens, benefit from this more aggressive filtering, which removes disruptive tokens more effectively.

  • Cross-Domain Stability: Detailed results in Appendix G (Table 7) further confirm that $\tau_{\mathrm{confounder}} = 0.10$ generally yields stable performance across various tasks (math, code, QA), suggesting it is a robust default value.

    The following are the results from Table 7 of the original paper:

    | Task | Model | τ = 0.05 | τ = 0.10 | τ = 0.15 |
    |---|---|---|---|---|
    | MathBench | LLaMA-3.2-1B-Instruct | 32.43 | 34.40 | 33.17 |
    | MathBench | LLaMA-3.2-3B-Instruct | 51.88 | 51.07 | 50.55 |
    | CodeBench | LLaMA-3.2-1B-Instruct | 17.85 | 18.27 | 19.76 |
    | CodeBench | LLaMA-3.2-3B-Instruct | 32.83 | 33.55 | 32.95 |
    | SQuAD 2.0 | LLaMA-3.2-1B-Instruct | 72.12 | 73.33 | 74.78 |
    | SQuAD 2.0 | LLaMA-3.2-3B-Instruct | 88.67 | 89.44 | 83.66 |

Table 7 demonstrates that $\tau = 0.10$ provides a good balance, often yielding the best or near-best performance across diverse tasks for different model sizes.
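The claim above can be checked directly against the Table 7 numbers. The short script below (its names are hypothetical) finds the best threshold for each task/model row and how far the default $\tau = 0.10$ falls short of it. On these numbers, $\tau = 0.10$ is the best setting in three of the six rows and trails the best by at most about 1.5 points in the other three.

```python
from typing import Dict, Tuple

# Scores copied from Table 7: (task, model) -> {tau: score}.
TABLE7: Dict[Tuple[str, str], Dict[float, float]] = {
    ("MathBench", "LLaMA-3.2-1B-Instruct"): {0.05: 32.43, 0.10: 34.40, 0.15: 33.17},
    ("MathBench", "LLaMA-3.2-3B-Instruct"): {0.05: 51.88, 0.10: 51.07, 0.15: 50.55},
    ("CodeBench", "LLaMA-3.2-1B-Instruct"): {0.05: 17.85, 0.10: 18.27, 0.15: 19.76},
    ("CodeBench", "LLaMA-3.2-3B-Instruct"): {0.05: 32.83, 0.10: 33.55, 0.15: 32.95},
    ("SQuAD 2.0", "LLaMA-3.2-1B-Instruct"): {0.05: 72.12, 0.10: 73.33, 0.15: 74.78},
    ("SQuAD 2.0", "LLaMA-3.2-3B-Instruct"): {0.05: 88.67, 0.10: 89.44, 0.15: 83.66},
}


def best_tau(scores: Dict[float, float]) -> float:
    """Threshold with the highest score in one row."""
    return max(scores, key=scores.get)


def default_gap(scores: Dict[float, float], default: float = 0.10) -> float:
    """How many points the default threshold trails the best one (0 if it is best)."""
    return scores[best_tau(scores)] - scores[default]


for (task, model), scores in TABLE7.items():
    print(f"{task:10s} {model:22s} best tau={best_tau(scores):.2f} "
          f"gap of tau=0.10: {default_gap(scores):.2f}")
```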

6.2.4. Collective Masking vs. Span Masking (Appendix C)

This comparison evaluates the two pruning strategies for creating counterfactual samples.

The following are the results from Table 3 of the original paper:

| Model / Method | MATH-500 |
|---|---|
| LLaMA3.2-1B-Instruct | |
| Instruct Model (Pre-KD) | 24.20 |
| KD w/o Mask | 34.00 |
| Collective Pruning | 34.20 |
| Span Pruning | 37.40 |
| LLaMA3.2-3B-Instruct | |
| Instruct Model (Pre-KD) | 42.80 |
| KD w/o Mask | 50.00 |
| Collective Pruning | 49.20 |
| Span Pruning | 54.40 |

Results (Table 3):

  1. Span Pruning Superiority: Span Pruning substantially outperforms both Collective Pruning and the native knowledge distillation (KD w/o Mask). For LLaMA3.2-1B, Span Pruning (37.40%) significantly beats Collective Pruning (34.20%) and KD (34.00%). For LLaMA3.2-3B, Span Pruning (54.40%) also clearly surpasses Collective Pruning (49.20%) and KD (50.00%).
  2. Collective Pruning Degradation: Collective Pruning not only fails to improve performance but actually degrades it for the LLaMA3.2-3B-Instruct model (49.20% vs. 50.00% for KD). The paper attributes this to Collective Pruning disrupting sentence integrity by removing all confounding tokens simultaneously. This justifies the adoption of Span Pruning for all experiments in LeaF.

6.2.5. Ablation Study: Importance of Contrastive Pairs (Appendix F)

This ablation studies the impact of including student-wrong originals in the training data, which LeaF leverages as part of its contrastive signal.

The following are the results from Table 6 of the original paper:

| Model | GSM8K | MATH-500 | OlympiadBench | Avg. |
|---|---|---|---|---|
| LLaMA3.2-1B-Instruct | | | | |
| LeaF | 57.70 | 35.40 | 10.09 | 34.40 |
| w/o Student-Wrong Originals | 58.15 | 34.80 | 7.42 | 33.46 |
| LLaMA3.2-3B-Instruct | | | | |
| LeaF | 83.09 | 51.80 | 20.77 | 51.88 |
| w/o Student-Wrong Originals | 84.08 | 47.80 | 16.02 | 49.30 |

Results (Table 6):

  1. Value of Student-Wrong Originals: LeaF consistently outperforms the ablated variant (w/o Student-Wrong Originals) on more challenging datasets like MATH-500 and OlympiadBench. This indicates that retaining original samples where the student model initially fails due to misleading patterns (alongside counterfactual samples) provides a stronger contrastive signal. This signal helps the student downweight spurious correlations and learn to attend to causally relevant information.
  2. Reduced Generalization: Excluding student-wrong originals limits the model's exposure to problematic cases, thereby hindering its ability to generalize and robustly perform under confounding conditions.
  3. GSM8K Anomaly: The ablated variant shows a slight performance gain on GSM8K. This is attributed to GSM8K being a simpler dataset. After filtering, student-correct samples form a larger share of the training data, leading to a distributional bias towards easier problems, which the ablated model, focusing only on 'correct' examples, might temporarily benefit from.

6.2.6. Robustness Analysis (Appendix E)

This analysis evaluates LeaF's test-time robustness by assessing its performance on perturbed versions of the MathBench datasets (GSM8K, MATH-500, OlympiadBench). Perturbations are generated using back-translation to create realistic paraphrastic variations.

The following are the results from Table 5 of the original paper:

| Model | GSM8K | MATH-500 | OlympiadBench | Avg. |
|---|---|---|---|---|
| LLaMA3.2-1B-Instruct | | | | |
| Instruct (Pre-KD) | 39.65 | 24.80 | 3.86 | 22.77 |
| KD (no mask) | 50.42 | 31.80 | 5.19 | 29.14 |
| LeaF (Instr Mask) | 51.10 | 32.00 | 6.53 | 29.88 |
| LeaF (Instr & Resp Mask) | 51.71 | 34.20 | 6.97 | 30.96 |
| LLaMA3.2-3B-Instruct | | | | |
| Instruct (Pre-KD) | 70.43 | 40.80 | 9.64 | 40.29 |
| KD (no mask) | 74.00 | 48.20 | 14.09 | 45.43 |
| LeaF (Instr Mask) | 74.07 | 50.20 | 15.58 | 46.62 |
| LeaF (Instr & Resp Mask) | 76.12 | 51.20 | 19.88 | 49.07 |

Results (Table 5):

  1. Superior Robustness: LeaF consistently outperforms standard KD (KD (no mask)) under noisy conditions. This indicates that LeaF's pruning and attribution mechanisms are robust to moderate linguistic perturbations.
  2. Preserved Reasoning: Relative to the clean MathBench averages in Table 1, LeaF's average drops by roughly 3.8 to 5.3 points under perturbation, comparable to the drop for standard KD (3.9 and 4.9 points), so LeaF retains its advantage over KD when input distributions shift. These findings highlight LeaF's robustness and practical applicability in real-world scenarios involving noisy inputs or distributional shifts.

6.2.7. Computational Overhead Analysis (Appendix D)

This section details the runtime and memory analysis of LeaF compared to standard Knowledge Distillation (KD).

The following are the results from Table 4 of the original paper:

| Model | KD Runtime (h) | LeaF Runtime (h) | Overhead (%) |
|---|---|---|---|
| LLaMA-3.2-1B-Instruct | 26.2 | 29.2 | +11.5 |
| LLaMA-3.2-3B-Instruct | 23.7 | 26.2 | +10.6 |
| Qwen2.5-Math-1.5B | 27.3 | 30.9 | +13.3 |

Results (Table 4):

  1. Moderate Training Overhead: LeaF incurs approximately 10-13% additional training time compared to standard KD. For example, LLaMA-3.2-1B-Instruct training time increases from 26.2 hours to 29.2 hours (+11.5%).
  2. Justified Overhead: This moderate increase in training time is justified by the consistent 2-3% absolute accuracy gains achieved across various benchmarks.
  3. Offline & Parallelizable Auxiliary Procedures: The auxiliary procedures of LeaF, including gradient computation, gradient normalization, span pruning, and counterfactual generation, are performed offline and are fully parallelizable.
    • Gradient computation: A one-time offline process taking about 3 hours for 7K samples on 8× NVIDIA A100 (80 GB) GPUs.
    • Counterfactual generation: An offline process taking 50 minutes for smaller models (26K samples) and up to 2.85 hours for larger ones on 4× NVIDIA A100 (80 GB) GPUs. These steps incur no extra cost during online training or inference, making LeaF practical and scalable.
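The Overhead column in Table 4 is the relative increase in training time, (LeaF - KD) / KD. A short sketch of that arithmetic using the rounded hours as printed in the table; the names `RUNTIMES_H` and `overhead_pct` are illustrative.

```python
# Training runtimes in hours from Table 4: (KD, LeaF), as reported (rounded to 0.1 h).
RUNTIMES_H = {
    "LLaMA-3.2-1B-Instruct": (26.2, 29.2),
    "LLaMA-3.2-3B-Instruct": (23.7, 26.2),
    "Qwen2.5-Math-1.5B": (27.3, 30.9),
}


def overhead_pct(kd_hours: float, leaf_hours: float) -> float:
    """Extra training time of LeaF relative to KD, in percent."""
    return 100.0 * (leaf_hours - kd_hours) / kd_hours


for model, (kd, leaf) in RUNTIMES_H.items():
    print(f"{model:24s} +{overhead_pct(kd, leaf):.1f}%")
```

Recomputed from the rounded hours, this gives about +11.5%, +10.5% and +13.2%, versus the reported +11.5, +10.6 and +13.3; the small gaps are presumably due to the hours in the table being rounded.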

6.2.8. Case Study in an Interpretable Perspective

A case study is presented (Figure 11) to illustrate how LeaF improves interpretability by enabling the model to focus on critical information and avoid confounding tokens, compared to standard knowledge distillation (KD).

The image is a schematic from the paper showing the two-stage LeaF pipeline: confounding tokens are identified through gradient-based comparison with the teacher model, then pruned during distillation to align causal attention.

The case study involves a mathematical problem: "Let $a$ be a positive real number such that all the roots of $x^3 + ax^2 + ax + 1 = 0$ are real. Find the smallest possible value of $a$."

Comparison (Figure 11):

  • LLaMA3.2-3B-Instruct (Distilled by KD) - Left (Incorrect):
    • The KD model starts by incorrectly attempting to apply the AM-GM inequality to the roots.
    • It misapplies AM-GM to potentially negative values (roots can be negative; e.g., $x = -1$ is a root), leading to an incorrect bound on $a$ ($a \le \sqrt[3]{3}$).
    • The model overlooks the critical constraint that "all roots must be real" and follows a flawed reasoning chain, highlighted in blue, resulting in an incorrect final answer.
  • LLaMA3.2-3B-Instruct (Distilled by LeaF) - Right (Correct):
    • The LeaF model correctly identifies $x = -1$ as an evident real root through factoring (Step 2: $(x+1)(x^2 + (a-1)x + 1)$).

    • It then correctly applies the discriminant condition (Steps 3 and 4) to ensure the quadratic factor $x^2 + (a-1)x + 1$ also yields real solutions. This involves requiring the discriminant $\Delta = (a-1)^2 - 4 \ge 0$.

    • Solving the inequality (Steps 5 and 6) correctly leads to $a \le -1$ or $a \ge 3$.

    • Considering the problem's constraint that $a$ is a positive real number, LeaF correctly deduces that the smallest possible value for $a$ is 3.

    • The pink highlights mark the areas where LeaF correctly focuses on critical information and follows a coherent reasoning chain.

      This case study visually demonstrates that LeaF guides the model to attend to the causally relevant parts of the problem (e.g., factoring, discriminant condition) and correctly interpret constraints, whereas standard KD fails to prevent the model from being misled by spurious patterns or incorrect mathematical heuristics. This highlights LeaF's ability to create a more interpretable and reliable reasoning model.
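The algebra behind the correct (LeaF) solution can be checked as a short worked derivation. Grouping $x^3 + 1$ in the first line is one convenient route to the factorization and is not claimed to be the model's exact Step 2.

```latex
\begin{aligned}
x^3 + ax^2 + ax + 1 &= (x^3 + 1) + ax(x + 1) \\
  &= (x + 1)(x^2 - x + 1) + ax(x + 1) \\
  &= (x + 1)\bigl(x^2 + (a - 1)x + 1\bigr), \\
\Delta = (a - 1)^2 - 4 \ge 0 &\iff |a - 1| \ge 2 \iff a \le -1 \ \text{or}\ a \ge 3, \\
a > 0 &\implies a \ge 3, \quad \text{and at } a = 3:\ x^3 + 3x^2 + 3x + 1 = (x + 1)^3 .
\end{aligned}
```

At $a = 3$ every root equals $-1$, so the bound is attained and the minimum is 3.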

6.3. Preliminary Experiments

6.3.1. Gradient Heatmap Comparison (Appendix A.1, Figure 12)

This preliminary experiment compares gradient heatmaps of a small student model and a large teacher model to visualize differences in their attention to critical tokens.

Figure 11: Case study comparing LeaF and knowledge distillation (KD) performance on MATH-500. Top: Heatmap showing the attention score differences between KD and LeaF on instruction tokens. Dark b…
The image is an illustration of a specific error in a model's answer to a math problem, where the expression $\sqrt[3]{3}$ is marked as incorrect.

Figure 12: A case study comparing the performance of the student model (Llama-3.2-1B-Instruct) and the teacher model (Llama-3.3-70B-Instruct) on the MATH task.

Results (Figure 12):

  • The teacher model (Llama-3.3-70B-Instruct) successfully captures a key contextual relation, such as "Five containers of blueberries can be traded for two zucchinis," showing focused attention on relevant tokens.
  • The student model (Llama3.2-1B-Instruct) exhibits dispersed attention, meaning its focus is spread out and not precisely aligned with the critical information. This observation motivates the core hypothesis: by pruning distracting patterns (confounding tokens), the student model can be guided to better focus on salient information, thereby enhancing its reasoning capabilities.
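A minimal sketch of a teacher-student gradient comparison of the kind motivated by Figure 12. It assumes per-token gradient magnitudes (for example, norms of the loss gradient with respect to each input embedding) have already been computed for both models. The min-max normalization, the rule "flag a token when the student's normalized score exceeds the teacher's by more than `tau`", the function names, and the example values are all hypothetical choices for illustration, not the paper's exact criterion; `tau` plays the role of a threshold such as $\tau_{\mathrm{confounder}}$.

```python
from typing import List


def min_max_normalize(scores: List[float]) -> List[float]:
    """Rescale per-token scores to [0, 1]; a constant input maps to all zeros."""
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [0.0 for _ in scores]
    return [(s - lo) / (hi - lo) for s in scores]


def flag_confounders(student_grad: List[float],
                     teacher_grad: List[float],
                     tau: float) -> List[bool]:
    """Hypothetical rule: flag token i when the student's normalized gradient
    magnitude exceeds the teacher's by more than tau."""
    if len(student_grad) != len(teacher_grad):
        raise ValueError("student and teacher scores must cover the same tokens")
    s = min_max_normalize(student_grad)
    t = min_max_normalize(teacher_grad)
    return [si - ti > tau for si, ti in zip(s, t)]


# Hypothetical per-token gradient magnitudes for a toy instruction.
tokens = ["Five", "containers", "trade", "for", "two", "zucchinis", ".", "Fun", "fact", "."]
student = [0.2, 0.3, 0.1, 0.1, 0.2, 0.3, 0.1, 0.9, 0.8, 0.1]
teacher = [0.9, 0.8, 0.5, 0.2, 0.9, 0.8, 0.1, 0.1, 0.1, 0.1]

flags = flag_confounders(student, teacher, tau=0.5)
print([tok for tok, f in zip(tokens, flags) if f])  # ['Fun', 'fact']
```

Here the student concentrates on the aside that the teacher ignores, so only those two tokens are flagged; a larger `tau` flags fewer tokens.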

6.3.2. Performance Gains from Token Pruning (Appendix A.2, Figure 1)

This pilot study assesses the performance improvement of student models after simply removing distracting patterns (confounding tokens).


Figure 1: Accuracy improvements achieved by removing confounding tokens from small models on the math and code training corpora. The results demonstrate a significant increase in performance, with over 20% improvement on the math corpus and more than 10% on the code corpus. (For further details on these categories, see Appendix A.)

Results (Figure 1):

  • Substantial Gains: Simply pruning distracting patterns without any additional training yields substantial improvements.
    • Over 20% improvement in average accuracy on the MATH training corpus.
    • More than 10% improvement on the Code training corpus.
  • Impact on Complexity: Greater improvements are observed on AMC_AIME (more complex) compared to GSM8K (simpler). This suggests that complex reasoning problems tend to contain more distracting patterns that interfere with model inference. These findings strongly support the idea that mitigating the influence of distracting patterns is crucial for improving the robustness and accuracy of LLM reasoning.

A further representative case is presented in Figure 2:

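The image is a text diagram showing the step-by-step calculation of a bacterial colony's growth pattern, highlighting the bacteria count of 1536 computed in step 11, with pink and blue bold text emphasizing key content.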

Figure 2: Comparison of reasoning before and after pruning distracting patterns. Blue-shaded regions indicate pruned confounding tokens. Pink highlights mark areas that require focus, while blue highlights show where excessive attention caused errors.

Figure 2 illustrates a math problem about bacterial growth. Removing distracting patterns (blue-shaded regions) from the instruction helps the model focus on critical information and enhance reasoning. This visually reinforces the hypothesis that improving attention to relevant details is a promising direction.
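The pruning step itself is simple once tokens are flagged. The sketch below groups consecutive flagged tokens into spans (analogous to the blue-shaded regions in Figure 2) and drops them from the instruction. Whitespace tokenization, the example sentence, and the names `flagged_spans` and `prune` are assumptions for illustration, not the paper's implementation.

```python
from typing import List, Optional, Tuple


def flagged_spans(flags: List[bool]) -> List[Tuple[int, int]]:
    """Group consecutive flagged positions into half-open [start, end) spans."""
    spans: List[Tuple[int, int]] = []
    start: Optional[int] = None
    for i, flag in enumerate(flags):
        if flag and start is None:
            start = i                      # a new span opens
        elif not flag and start is not None:
            spans.append((start, i))       # the current span closes
            start = None
    if start is not None:
        spans.append((start, len(flags)))  # span runs to the end
    return spans


def prune(tokens: List[str], flags: List[bool]) -> Tuple[List[str], List[Tuple[int, int]]]:
    """Drop every flagged span; return the kept tokens and the removed spans."""
    if len(tokens) != len(flags):
        raise ValueError("one flag per token is required")
    spans = flagged_spans(flags)
    removed = {i for start, end in spans for i in range(start, end)}
    kept = [tok for i, tok in enumerate(tokens) if i not in removed]
    return kept, spans


# Hypothetical instruction with a distracting aside flagged at positions 6-11.
tokens = "The colony doubles every hour . Fun fact : it glows . How many after 3 hours ?".split()
flags = [6 <= i < 12 for i in range(len(tokens))]
kept, spans = prune(tokens, flags)
print(spans)           # [(6, 12)]
print(" ".join(kept))  # The colony doubles every hour . How many after 3 hours ?
```

Returning the spans alongside the pruned text makes it easy to highlight what was removed when inspecting cases like Figure 2.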

6.3.3. Generation Quality Improvements from Token Pruning (Appendix A.3, Figure 13)

This analysis quantifies the similarity between student model outputs and teacher model outputs using Jaccard Similarity under two conditions: original instruction and instruction with pruned distracting patterns.

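Figure 13: Jaccard similarity distribution between student model responses (original instructions vs. instructions with distracting patterns pruned) and ground-truth responses on math and code datasets.
The image is a chart comparing the probability distributions of Jaccard similarity between student model responses and ground-truth answers on the math and code datasets, for original instructions versus instructions with misleading patterns removed, along with fitted curves.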


Results (Figure 13):

  • Improved Alignment: After removing confounding tokens from the instruction, there is a clear shift in the Jaccard Similarity distribution for student model responses on both code and math tasks. The distribution shifts towards higher similarity values, and the mode of the distribution increases. This indicates that by ignoring distracting patterns, the student model not only improves reasoning accuracy but also generates responses that are more aligned with the teacher model's outputs, thereby enhancing output quality and semantic coherence.
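Jaccard similarity between two responses is the size of the intersection of their token sets divided by the size of their union. The analysis does not say how responses are tokenized, so this sketch assumes lower-cased word tokens extracted with a regular expression; `token_set`, `jaccard` and `mean_similarity` are hypothetical helper names.

```python
import re
from statistics import mean
from typing import List, Set, Tuple


def token_set(text: str) -> Set[str]:
    """Lower-cased word/number tokens (a deliberately simple tokenizer)."""
    return set(re.findall(r"\w+", text.lower()))


def jaccard(a: str, b: str) -> float:
    """|A ∩ B| / |A ∪ B| over token sets; two empty texts count as identical."""
    sa, sb = token_set(a), token_set(b)
    if not sa and not sb:
        return 1.0
    return len(sa & sb) / len(sa | sb)


def mean_similarity(pairs: List[Tuple[str, str]]) -> float:
    """Average Jaccard similarity over (student_response, reference_response) pairs."""
    return mean(jaccard(s, r) for s, r in pairs)


print(jaccard("The answer is 3", "So the answer is 3"))  # 0.8
```

Scoring every response under both conditions (original vs. pruned instruction) and plotting the two sets of scores as histograms gives distributions of the kind compared in Figure 13.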

7. Conclusion & Reflections

7.1. Conclusion Summary

The paper introduces Learning to Focus (LeaF), a novel two-stage framework designed to enhance the reasoning consistency and interpretability of Large Language Models (LLMs). LeaF addresses the problem of LLMs being misled by distracting patterns or spurious correlations in their training data. By leveraging a causal analysis and a gradient-based pruning mechanism, LeaF effectively identifies confounding tokens that introduce these spurious correlations. In its first stage, it uses gradient-based comparisons between a teacher and student model to detect these tokens. In the second stage, it employs a hybrid distillation loss (combining standard and counterfactual distillation with span pruning) to align the student's attention with the teacher's focus on truly critical context tokens.
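To make the hybrid objective concrete, here is a minimal sketch that mixes a distillation term on the original input with a counterfactual term on the pruned input. The forward-KL form, per-position averaging, the assumption that both teacher and student distributions in the counterfactual term are conditioned on the pruned input, the mixing weight `lam`, and all function names are assumptions for illustration; the paper's exact loss and weighting may differ.

```python
import math
from typing import List, Sequence


def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """KL(p || q) for two discrete distributions over the same vocabulary."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)


def mean_kl(teacher: List[Sequence[float]], student: List[Sequence[float]]) -> float:
    """Average per-position KL(teacher || student)."""
    return sum(kl_divergence(t, s) for t, s in zip(teacher, student)) / len(teacher)


def hybrid_distill_loss(teacher_orig, student_orig, teacher_pruned, student_pruned, lam=0.5):
    """lam * (distillation on original input) + (1 - lam) * (counterfactual term on pruned input)."""
    return (lam * mean_kl(teacher_orig, student_orig)
            + (1 - lam) * mean_kl(teacher_pruned, student_pruned))


# Toy next-token distributions over a 2-word vocabulary (hypothetical values).
t_orig, s_orig = [[0.9, 0.1]], [[0.6, 0.4]]
t_prun, s_prun = [[0.8, 0.2]], [[0.8, 0.2]]
print(round(hybrid_distill_loss(t_orig, s_orig, t_prun, s_prun, lam=0.5), 4))
```

In this toy case the student already matches the teacher on the pruned input, so only the original-input term contributes; setting `lam` closer to 0 shifts weight toward the counterfactual term.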

Experimental results across mathematical reasoning, code generation, and multi-hop question answering benchmarks consistently demonstrate that LeaF achieves significant absolute improvements in accuracy and robustness compared to standard knowledge distillation. Furthermore, case studies and ablation analyses confirm that LeaF successfully suppresses attention to confounding tokens, leading to a more interpretable and reliable reasoning model. The framework also benefits from response-level pruning, reflecting the observation that distracting patterns can appear throughout the generation process.

7.2. Limitations & Future Work

The authors identify the following limitations for LeaF:

  1. Dependence on an Advanced Teacher Model: The core mechanism of LeaF for identifying confounding tokens relies on the availability and superior performance of a teacher model. This means that if a strong teacher is not available, or if the teacher itself struggles with certain types of confounding, the effectiveness of LeaF might be limited.
    • Future Work: The authors suggest exploring self-improvement mechanisms where models can refine their attention to critical tokens and boost reasoning without relying on an external advanced model. This would increase the framework's autonomy and applicability.
  2. Limited Scalability to Long-Form Generation: The current validation of LeaF is primarily on math and code tasks, which, while complex, might not represent the full spectrum of long-form text generation (e.g., creative writing, detailed explanations, summarization of very long documents).
    • Future Work: Investigating LeaF's applicability to long-text generation and other diverse domains is proposed for future research.

7.3. Personal Insights & Critique

  • Strong Theoretical Grounding: The paper's adoption of a causal framework to explain and mitigate spurious correlations is a significant strength. Framing distracting patterns as confounders provides a principled, theoretically sound approach that distinguishes it from many heuristic-based token selection methods. This interventionist perspective is powerful for truly understanding and improving model behavior, rather than just optimizing performance.

  • Interpretability as a Core Benefit: Beyond performance gains, the emphasis on interpretability through attention visualization and case studies is highly valuable. Showing why a model makes a correct decision (by focusing on the right tokens) or an incorrect one (by being distracted) is crucial for building trust and for future model development. The contrastive case study in Figure 11 effectively illustrates this.

  • Practicality and Scalability: The analysis of computational overhead (Appendix D) demonstrates that while LeaF adds some training time, its auxiliary procedures are offline and parallelizable. This ensures that the method is practical for real-world application, offering significant accuracy gains for a manageable increase in training cost, with no extra cost at inference time.

  • Potential for Broader Application: While the paper notes limitations in long-form generation, the core idea of causal attention distillation could potentially be applied to many other domains where LLMs struggle with context or rely on superficial cues. Any task susceptible to spurious correlations in data could benefit from this approach, such as complex scientific text understanding, legal document analysis, or even multimodal tasks where certain visual or auditory elements might act as confounders.

  • The "Student-Wrong Originals" Insight: The ablation study highlighting the importance of including student-wrong originals is a subtle but critical insight. It shows that learning from failure cases, specifically when the student is misled by confounders and the teacher is not, provides a powerful contrastive signal that strengthens the student's ability to resist spurious patterns. This goes beyond simply learning from 'good' examples.

  • Further Refinements of Confounder Identification: While gradient-based sensitivity is a strong method, further research could explore more nuanced ways to define and identify confounding tokens. For instance, integrating linguistic features or semantic parse information into the gradient analysis might refine the detection. Also, the reliance on a threshold $\tau_{\mathrm{confounder}}$ still introduces a hyperparameter that requires careful tuning.

  • Self-Improvement Without a Teacher: The authors wisely point out the teacher dependence as a limitation. Developing self-improvement mechanisms that allow a model to internally identify and mitigate its own spurious correlations would be a groundbreaking advancement, making such frameworks more autonomous and potentially unlocking even greater capabilities.

    Overall, LeaF presents a rigorous and impactful contribution to improving LLM reasoning, offering a clear path toward more reliable, accurate, and understandable AI systems.
