Learning to Focus: Causal Attention Distillation via Gradient-Guided Token Pruning
TL;DR Summary
LeaF uses gradient-guided token pruning to remove confounding tokens, aligning student attention with causal focus from teachers, improving reasoning accuracy and interpretability across multiple benchmarks.
Abstract
Large language models (LLMs) have demonstrated significant improvements in contextual understanding. However, their ability to attend to truly critical information during long-context reasoning and generation still lags behind. Specifically, our preliminary experiments reveal that certain distracting patterns can misdirect the model's attention during inference, and removing these patterns substantially improves reasoning accuracy and generation quality. We attribute this phenomenon to spurious correlations in the training data, which obstruct the model's capacity to infer authentic causal instruction-response relationships. This phenomenon may induce redundant reasoning processes, potentially resulting in significant inference overhead and, more critically, the generation of erroneous or suboptimal responses. To mitigate this, we introduce a two-stage framework called Learning to Focus (LeaF) that leverages intervention-based inference to disentangle confounding factors. In the first stage, LeaF employs gradient-based comparisons with an advanced teacher to automatically identify confounding tokens based on causal relationships in the training corpus. Then, in the second stage, it prunes these tokens during distillation to enact the intervention, aligning the student's attention with the teacher's focus distribution on truly critical context tokens. Experimental results demonstrate that LeaF not only achieves absolute improvements on various mathematical reasoning, code generation, and multi-hop question answering benchmarks but also effectively suppresses attention to confounding tokens during inference, yielding a more interpretable and reliable reasoning model.
Mind Map
In-depth Reading
English Analysis
1. Bibliographic Information
1.1. Title
Learning to Focus: Causal Attention Distillation via Gradient-Guided Token Pruning
1.2. Authors
Yiju Guo, Wenkai Yang, Zexu Sun, Ning Ding, Zhiyuan Liu, Yankai Lin. The authors are affiliated with the Gaoling School of Artificial Intelligence, Renmin University of China; the Department of Computer Science and Technology, Tsinghua University; and Baidu Inc. Their research background primarily lies in natural language processing, large language models, and machine learning, with a focus on improving LLM reasoning, efficiency, and interpretability.
1.3. Journal/Conference
The paper is published on arXiv, a preprint server for scientific papers. While not a peer-reviewed journal or conference proceeding in its current form, arXiv is a widely used platform in the AI community for disseminating cutting-edge research rapidly. Papers posted on arXiv are often submitted to top-tier conferences (e.g., ACL, EMNLP, NeurIPS, ICML) or journals later. The publication date suggests it is recent research.
1.4. Publication Year
2025 (first posted to arXiv on 2025-06-09).
1.5. Abstract
This paper addresses the challenge of Large Language Models (LLMs) struggling to focus on truly critical information during long-context reasoning and generation, often misdirected by distracting patterns attributed to spurious correlations in training data. This leads to redundant reasoning, inference overhead, and erroneous outputs. To mitigate this, the authors introduce Learning to Focus (LeaF), a two-stage framework. The first stage employs gradient-based comparisons with an advanced teacher model to automatically identify confounding tokens based on causal relationships. In the second stage, these tokens are pruned during distillation to align the student model's attention with the teacher's focus on critical context tokens. Experimental results demonstrate that LeaF achieves significant absolute improvements across various mathematical reasoning, code generation, and multi-hop question answering benchmarks. Furthermore, it effectively suppresses attention to confounding tokens, resulting in a more interpretable and reliable reasoning model.
1.6. Original Source Link
Official Source: https://arxiv.org/abs/2506.07851 PDF Link: https://arxiv.org/pdf/2506.07851v2.pdf Publication Status: This is a preprint on arXiv.
2. Executive Summary
2.1. Background & Motivation
The core problem the paper aims to solve is the inability of Large Language Models (LLMs) to consistently focus on truly critical information during complex, long-context reasoning and generation tasks. Despite LLMs' advanced capabilities in contextual understanding and language generation, the authors' preliminary experiments reveal that distracting patterns in the input can misdirect the model's attention. When these distracting patterns are removed, reasoning accuracy and generation quality substantially improve.
This problem is important because spurious correlations within the vast training data of LLMs can obstruct the model's capacity to infer authentic causal instruction-response relationships. This phenomenon leads to several critical issues:
- Redundant Reasoning Processes: LLMs may engage in unnecessary computational steps, increasing inference overhead.
- Erroneous or Suboptimal Responses: Misdirected attention can cause the model to generate incorrect or less-than-ideal outputs.
- Lack of Interpretability: Models that rely on spurious correlations are harder to understand and debug.

The paper's entry point is to address these issues from a causal perspective. It views distracting patterns as spurious confounders that interfere with the LLM's reasoning. By identifying and mitigating the influence of these confounders, the paper aims to guide LLMs to focus on the truly relevant information, thereby improving their robustness, accuracy, and interpretability.
2.2. Main Contributions / Findings
The paper's primary contributions are:
- Introduction of the Learning to Focus (LeaF) Framework: A novel two-stage framework that leverages intervention-based inference to disentangle confounding factors in LLM reasoning.
- Gradient-Guided Confounding Token Detection: LeaF introduces a method to automatically identify confounding tokens by comparing gradient sensitivities between a high-capacity teacher model and the student model. This approach is rooted in causal relationships identified within the training corpus.
- Causal Attention Distillation: LeaF incorporates span pruning of identified confounding tokens during a hybrid distillation process. This aligns the student model's attention with the teacher's focus distribution on critical context tokens, effectively enacting a causal intervention.
- Demonstrated Performance Improvements: LeaF achieves significant absolute performance gains across a diverse set of benchmarks:
  - Mathematical Reasoning: Absolute improvements on GSM8K, MATH-500, and OlympiadBench.
  - Code Generation: Absolute improvements on HumanEval+, LeetCode, and LivecodeBench.
  - Multi-hop Question Answering: Absolute improvements on HotpotQA, 2WikiMultiHopQA, and Musique.
- Enhanced Interpretability and Reliability: Experimental results, including attention visualizations and case studies, show that LeaF suppresses attention to confounding tokens, leading to a more interpretable and reliable reasoning model by encouraging focus on causally relevant information.
- Response-Level Pruning Strategy: The paper shows that extending pruning from the instruction level to the response level further enhances performance, indicating that distracting patterns exist and influence generation at both stages.
These findings collectively demonstrate that explicitly mitigating spurious correlations through causal attention distillation is a highly effective approach to improve LLM reasoning capabilities, robustness, and interpretability.
3. Prerequisite Knowledge & Related Work
3.1. Foundational Concepts
To understand this paper, a reader should be familiar with the following fundamental concepts:
- Large Language Models (LLMs): These are advanced artificial intelligence models trained on vast amounts of text data to understand, generate, and process human language. They typically employ transformer architectures with billions of parameters, enabling them to perform a wide range of tasks like text generation, translation, summarization, and question answering. Examples include GPT-3, LLaMA, and Qwen.
- Contextual Understanding & Long-Context Reasoning: This refers to an LLM's ability to grasp the meaning of information within a given context (e.g., a long document or a conversation history) and use that understanding to perform reasoning tasks. Long-context reasoning specifically deals with very long contexts, which make it hard for LLMs to maintain focus and integrate information from disparate parts of the input.
- Attention Mechanism: A core component of transformer-based LLMs. It allows the model to weigh the importance of different parts of the input sequence when processing each token. For example, when generating a word, the attention mechanism determines which other words in the input (or previously generated output) are most relevant. The foundational concept is self-attention, where a token attends to all other tokens in the sequence. The attention mechanism computes a weighted sum of value vectors, where the weight assigned to each value is determined by the similarity between the query vector (representing the current token) and key vectors (representing other tokens); a minimal code sketch follows this item.
$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$
  - Q: Query matrix, derived from the input embeddings, representing the current token's focus.
  - K: Key matrix, derived from the input embeddings, representing other tokens' potential relevance.
  - V: Value matrix, derived from the input embeddings, containing the information to be aggregated.
  - QK^T: Dot product between query and key matrices, calculating similarity scores.
  - sqrt(d_k): Scaling factor, where d_k is the dimension of the key vectors, used to prevent large dot products from pushing the softmax function into regions with tiny gradients.
  - softmax: Normalizes the similarity scores into probability distributions, ensuring the weights sum to 1.
  - The softmax weights are then applied to the value vectors, whose weighted sum is the output of the attention layer.
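As a concrete reference, here is a minimal NumPy sketch of scaled dot-product attention. It is a generic textbook implementation (not code from the paper), with toy shapes chosen for illustration only.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Q, K: (seq_len, d_k); V: (seq_len, d_v). Returns one output vector per query token."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                     # query-key similarity, scaled by sqrt(d_k)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)      # softmax over the key dimension
    return weights @ V                                  # weighted sum of value vectors

# Toy usage: 4 tokens with 8-dimensional queries, keys, and values.
rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(4, 8)) for _ in range(3))
print(scaled_dot_product_attention(Q, K, V).shape)      # (4, 8)
```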
- Spurious Correlations: Statistical relationships observed in data that appear to be causal but are not. In the context of LLMs, spurious correlations might arise when the model learns to associate certain superficial patterns or confounding tokens in the training data with specific outputs, even though these patterns do not represent a true causal link to the desired response. For example, if a dataset frequently includes a specific boilerplate phrase alongside correct answers, an LLM might learn to rely on the boilerplate rather than the actual reasoning content.
- Knowledge Distillation (KD): A technique where a smaller "student" model is trained to mimic the behavior of a larger, more powerful "teacher" model. The goal is to transfer knowledge from the teacher to the student, often resulting in a smaller, faster, and more efficient student model that retains much of the teacher's performance. This typically involves minimizing the difference between the teacher's and student's output distributions (e.g., using the Kullback-Leibler divergence).
  - Kullback-Leibler (KL) Divergence: A non-symmetric measure of the difference between two probability distributions. It quantifies how much one probability distribution diverges from a second, expected probability distribution.
$ D_{\mathrm{KL}}(P \parallel Q) = \sum_{i} P(i) \log\left(\frac{P(i)}{Q(i)}\right) $
    - P: The true or reference probability distribution (e.g., the teacher's output).
    - Q: The approximate or learned probability distribution (e.g., the student's output).
    - i: Indexes the possible outcomes; the sum runs over all of them.
    - P(i), Q(i): Probabilities of outcome i under distributions P and Q, respectively.
    - A lower KL divergence indicates greater similarity between the two distributions.
- Causal Inference / Structural Causal Models (SCM) / Directed Acyclic Graphs (DAGs): A framework developed by Judea Pearl for reasoning about cause-and-effect relationships. An SCM consists of a set of variables and structural equations that describe how each variable is determined by its direct causes. A DAG visually represents these causal relationships, with nodes as variables and directed edges as causal influences. In this paper, DAGs are used to model how input tokens (X), confounding tokens (A), and the model's output (Y) interact.
- Gradients in Deep Learning: Gradients represent the rate of change of a function (e.g., a loss function) with respect to its inputs (e.g., model parameters, input embeddings). They are central to optimization algorithms like stochastic gradient descent. In this paper, gradient sensitivity is used to measure how much a change in a specific input token affects the model's output or loss, indicating the token's importance or influence.
- Counterfactuals: In causal inference, a counterfactual describes what would have happened if a particular event had not occurred, or if an intervention had been made. In this paper, counterfactual samples are created by pruning (removing) identified confounding tokens from an instruction. By comparing the model's behavior on original vs. counterfactual samples, the framework aims to isolate and learn true causal dependencies.
3.2. Previous Works
The paper positions itself within the context of several lines of research:
- Chain-of-Thought (CoT) Knowledge Distillation: This paradigm aims to transfer complex reasoning abilities from large teacher models to smaller student models, often by distilling the step-by-step reasoning (Chain-of-Thought) rather than just the final answer.
  - CoTKD [41]: Pioneered using teacher-generated CoT explanations to fine-tune student models, enhancing their reasoning.
  - Data-focused approaches: These methods improve distillation by enhancing the quality and diversity of training data. Examples include CD [13] (Counterfactual Distillation), which reasons through counterfactuals for smaller models; SCORE [56], which shows that small language models need strong verifiers to self-correct reasoning; and Skintern [28], which internalizes symbolic knowledge for distilling better CoT capabilities.
  - Model-focused approaches: These focus on improving model architectures and inference strategies for efficiency and reasoning. Examples include PRR [41] (Problem-Reduction Reasoning), which breaks down complex problems, and ATM [7] (Adaptive Thinking Models), which distills reasoning ability with adaptive thinking.
  - Differentiation: While these methods primarily focus on output imitation and improving data/model structures, LeaF explicitly distills the teacher's ability to focus on critical information during inference, enabling context-aware reasoning by targeting spurious correlations.
- Critic Token Identification: This area explores strategies to identify and mitigate redundant or less important tokens during LLM reasoning.
  - LLMLingua [23]: Relies on the model's self-assessment of token importance, potentially introducing biases.
  - RHO-1 [29]: Introduces Selective Language Modeling (SLM) during supervised fine-tuning to prioritize informative tokens.
  - TokenSkip [47]: Selectively skips low-impact tokens in CoT stages to compress reasoning paths.
  - cDPO [30]: Enhances Direct Preference Optimization by isolating critical tokens through contrastive learning.
  - Differentiation: Most existing methods concentrate on output tokens or reasoning steps, overlooking the influence of prior contextual information from the instruction phase. LeaF, in contrast, uses an advanced teacher model to identify confounding tokens based on gradient differences in both instructions and responses, establishing stronger connections between instruction and output for more holistic filtering.
- Reasoning Consistency: This research line focuses on improving the stability and reliability of LLM outputs, especially in multi-step reasoning.
  - Self-Consistency [46]: Samples multiple reasoning chains and uses majority voting to stabilize final answers.
  - Adaptive Consistency [2] and Early Scoring Self-Consistency [27]: Introduce stopping criteria to reduce inference costs.
  - Reasoning Aware Self-Consistency [45]: Weights sample quality and reasoning-path importance.
  - Differentiation: These methods primarily aim to stabilize the final answer and reasoning paths. LeaF, however, defines consistency as both answer stability and context adherence. It uses a distillation strategy to systematically suppress misleading signals, enhancing the student's focus on key contextual information and promoting more robust context adherence throughout the generation process.
3.3. Technological Evolution
The field of LLM reasoning has evolved from:
- Basic Language Models: Focused on next-token prediction, often struggling with complex reasoning.
- Emergence of Transformers and Attention: Revolutionized contextual understanding.
- Chain-of-Thought (CoT) Prompting: Enabled LLMs to perform multi-step reasoning by generating intermediate steps.
- Knowledge Distillation for Reasoning: Applied KD to transfer CoT capabilities from large to small models.
- Focus on Efficiency and Robustness: Development of methods for token pruning, selective modeling, and consistency techniques to optimize performance and reliability.

This paper's work (LeaF) fits into the latest stage of this evolution by addressing a critical, often overlooked aspect: the intrinsic issue of spurious correlations in training data leading to misdirected attention. It moves beyond simply mimicking teacher outputs or pruning "less important" tokens to specifically identify and intervene on causal confounders, thereby enhancing the fundamental attention mechanism and causal reasoning of student models.
3.4. Differentiation Analysis
Compared to existing methods, LeaF offers several core differences and innovations:
- Causal Perspective on Confounding Factors: Unlike many previous methods that focus on redundant or less informative tokens, LeaF explicitly frames distracting patterns as spurious confounders in a causal framework. This theoretical grounding provides a principled way to identify and mitigate their influence.
- Gradient-Guided Teacher-Student Comparison: LeaF uses gradient sensitivity from both a powerful teacher model and the student model to identify confounding tokens. This goes beyond perplexity-based or self-assessment methods, leveraging the teacher's superior reasoning to pinpoint where the student's attention is spuriously drawn.
- Intervention-Based Distillation with Counterfactuals: The framework generates counterfactual samples by span pruning the identified confounding tokens. The hybrid distillation loss then explicitly aligns the student's attention by contrasting its behavior on original and counterfactual samples, forcing it to learn genuine causal dependencies. This is a direct intervention on the causal graph.
- Holistic Token Filtering (Instruction + Response): While some methods focus on output tokens, LeaF considers confounding tokens in both instruction-level and response-level contexts, providing a more comprehensive approach to guide attention throughout the entire reasoning and generation process.
- Focus on Attention Alignment, not just Output Imitation: The core goal is to align the student's attention with the teacher's focus distribution on truly critical context tokens, rather than merely imitating the teacher's final output or intermediate steps. This aims to improve the underlying reasoning process itself.
4. Methodology
The Learning to Focus (LeaF) framework aims to mitigate the adverse effects of distracting patterns (termed confounding tokens) on LLM reasoning by explicitly identifying and intervening on them. It leverages a causal perspective to guide a student model's attention toward truly critical information, drawing knowledge from a more advanced teacher model.
4.1. Principles
The core idea of LeaF is rooted in causal inference, specifically Pearl's Structural Causal Model (SCM) [39]. The method posits that confounding tokens introduce spurious correlations that mislead the LLM's reasoning process. By intervening on these confounders—effectively removing their influence—the student model can learn to focus on authentic causal instruction-response relationships. This is achieved through gradient-based comparisons to detect confounders and a hybrid distillation strategy that encourages the student to mirror the teacher's attention to causally relevant information. The theoretical basis is that by blocking non-causal influences from confounding tokens, the student model learns attention patterns grounded in the true causal structure, leading to improved robustness and interpretability.
4.2. Core Methodology In-depth (Layer by Layer)
4.2.1. Causal Framework
The paper models the reasoning process using a Directed Acyclic Graph (DAG).
As shown in Figure 3, the input tokens X influence the model's output Y. A subset of these input tokens, denoted as confounding tokens A, introduces spurious correlations. These confounders simultaneously influence the true causal input (represented as X_i in the formula below, a specific part of the input excluding A) and the output Y, thereby distorting the observed relationship between X_i and Y.

Figure 3: Causal graph of the reasoning process. X represents the input prompt, and Y denotes the model's output. A subset of tokens in X, identified as confounding tokens (A), introduces spurious correlations that disrupt the reasoning process. The method detects and masks A, effectively eliminating the spurious edge from A to Y and restoring the true causal dependency.
The observed conditional distribution, which includes the biasing effect of the confounding tokens A, is given by:
$
P(Y \mid X_i = x) = \sum_{A} P(Y \mid X_i = x, A)\, P(A \mid X_i = x),
$
- P(Y | X_i = x): The probability of observing output Y given an input X_i = x. This is the observed conditional distribution that LLMs typically learn.
- X_i = x: A specific input (or part of the input) instance x.
- A: The set of confounding tokens.
- P(Y | X_i = x, A): The probability of output Y given both the input x and the presence of the confounding tokens A.
- P(A | X_i = x): The probability of the confounding tokens A being present given the input x.
- The sum over A: Summation over all possible states or values of the confounding tokens.

This formula shows that the observed relationship between X_i and Y is an average over the influence of A, meaning A introduces an indirect influence on Y that is not part of the true causal instruction-response relationship. This observed conditional distribution deviates from the interventional distribution, which represents the causal effect of X_i on Y if the influence of A were explicitly removed (i.e., setting X_i to x while holding other factors constant, including breaking the spurious links).
To address this bias, the paper proposes causal pruning, which involves detecting and masking the confounding tokens A. This effectively eliminates the spurious edge from A to Y (dashed arrow in Figure 3), encouraging the student model to learn attention patterns based on the true causal structure.
4.2.2. LeaF: Learning to Focus Framework
The LeaF framework, illustrated in Figure 4, consists of two main stages:

Figure 4: Method Overview. The training pipeline comprises two key stages: (1) Confounding Token Detection: gradient-based comparisons between an advanced teacher model and the student model are used to identify confounding tokens in the training samples and to construct counterfactual samples by pruning these tokens; and (2) Causal Attention Distillation: the identified confounders are pruned during training to align the student's attention with the teacher's and to capture causal relationships. This targeted intervention steers the model toward actual causal dependencies, enhancing both robustness and interpretability.
- Confounding Token Detection (Section 2.2.1): In this stage, confounding tokens are identified using gradient-based comparisons between a teacher model (with parameters θ_T) and a student model (θ_S). Once identified, counterfactual samples are constructed by pruning these tokens.
- Causal Attention Distillation (Section 2.2.2): This stage uses a hybrid distillation loss that operates on both original and counterfactual samples. It aligns the student's output distribution and attention with the teacher's, specifically encouraging the student to capture causal dependencies by learning from the teacher's behavior before and after the confounding-token intervention.
4.2.2.1. Confounding Token Detection
To identify confounding tokens A, LeaF employs a gradient-based approach [42, 43] to quantify the influence of each token x_i on the model's output. This process focuses on instances where the student model makes an incorrect prediction but the more capable teacher model answers correctly, indicating a potential misdirection of the student's attention by confounding tokens.
For each token x_i, the gradient sensitivity (the absolute gradient of the loss with respect to the token embedding) is calculated for both the teacher and student models:
$
g_i^{(T)} = \left| \frac{\partial \ell(x_i \mid X; \theta_T)}{\partial x_i} \right|, \quad g_i^{(S)} = \left| \frac{\partial \ell(x_i \mid X; \theta_S)}{\partial x_i} \right|.
$
- g_i^(T): The absolute gradient sensitivity of the teacher model to token x_i.
- g_i^(S): The absolute gradient sensitivity of the student model to token x_i.
- ℓ(x_i | X; θ): The loss of the model (teacher θ_T or student θ_S) with respect to the input token x_i within the context X. This loss is typically calculated between the model's predicted logits and the gold reference (true answer).
- ∂ℓ/∂x_i: The partial derivative of the loss with respect to the embedding of token x_i, measuring how sensitive the loss is to changes in x_i.
- | · |: Absolute value, indicating the magnitude of sensitivity regardless of direction.
To allow for a fair comparison between models with potentially different gradient scales, these sensitivity values are normalized using min-max normalization:
$
\hat{g}_i^{(T)} = \frac{g_i^{(T)} - \min_j g_j^{(T)}}{\max_j g_j^{(T)} - \min_j g_j^{(T)}}, \quad \hat{g}_i^{(S)} = \frac{g_i^{(S)} - \min_j g_j^{(S)}}{\max_j g_j^{(S)} - \min_j g_j^{(S)}}.
$
- ĝ_i^(T), ĝ_i^(S): The normalized gradient sensitivities of token x_i for the teacher and student models, respectively. These values lie between 0 and 1.
- min_j g_j^(T), max_j g_j^(T): The minimum and maximum gradient sensitivities across all tokens in the input for the teacher model.
- min_j g_j^(S), max_j g_j^(S): The minimum and maximum gradient sensitivities across all tokens in the input for the student model.
To identify confounding tokens, the difference in normalized gradient sensitivity between the teacher and student is computed for each token (see the sketch below):
$ \Delta \hat{g}_i = \hat{g}_i^{(T)} - \hat{g}_i^{(S)}. $
- Δĝ_i: The difference in normalized gradient sensitivity for token x_i. A negative value indicates that the student model is more sensitive to x_i than the teacher, while a positive value indicates the opposite.
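A minimal PyTorch-style sketch of this sensitivity comparison is shown below. It is illustrative only: the helper names are ours, a Hugging Face-style causal LM interface is assumed, and the per-embedding gradient is reduced to one scalar per token with an L2 norm, a detail the paper does not pin down.

```python
import torch

def token_sensitivity(model, input_ids, labels):
    """Per-token gradient magnitude |d loss / d embedding|, one scalar per input token."""
    embeds = model.get_input_embeddings()(input_ids).detach().requires_grad_(True)
    loss = model(inputs_embeds=embeds, labels=labels).loss
    loss.backward()
    return embeds.grad.norm(dim=-1).squeeze(0)          # shape: (seq_len,)

def min_max(x, eps=1e-8):
    return (x - x.min()) / (x.max() - x.min() + eps)

def sensitivity_gap(teacher, student, input_ids, labels):
    """Delta g_hat_i = normalized teacher sensitivity minus normalized student sensitivity."""
    g_t = min_max(token_sensitivity(teacher, input_ids, labels))
    g_s = min_max(token_sensitivity(student, input_ids, labels))
    return g_t - g_s
```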
A token x_i is classified as a Confounding Token if two conditions are met:

1. Sensitivity Discrepancy: It receives significant attention from the student model but negligible attention from the teacher during inference. This is formalized by min-max normalizing the gradient difference and comparing it to a threshold:
$ \frac{\Delta \hat{g}_i - \min_j \Delta \hat{g}_j}{\max_j \Delta \hat{g}_j - \min_j \Delta \hat{g}_j} \leq \tau_{\mathrm{confounder}}, $
   - The left-hand side is the min-max normalized value of the gradient difference Δĝ_i, which makes the threshold comparable across different instances.
   - τ_confounder: A predefined threshold (determined via validation-set analysis). A low value of the normalized difference means Δĝ_i is small or negative, i.e., the teacher is less sensitive to this token than the student, or even actively ignores it.
   - Intuitively, this condition identifies tokens on which the student's attention is disproportionately high compared to the teacher's, suggesting a spurious dependency.

2. Correct Prediction After Removal: Removing token x_i results in correct predictions from both the teacher and student models. This ensures that the identified token is indeed confounding and not causally essential for correct reasoning.

Pruning Strategies: The paper explores two strategies for constructing counterfactual samples by removing identified confounding tokens from the instruction X:
Figure 5: Illustration of Collective Pruning and Span Pruning. Colored regions mark valid tokens, confounding tokens, and distracting-pattern spans in the training corpus.
- Collective Pruning: Removes the entire set of all identified confounders A from X, resulting in X \ A. (See Figure 5, right side, top example: all blue-shaded tokens are removed simultaneously.)
- Span Pruning: Removes only one contiguous span of confounding tokens at a time, yielding X \ A_i. (See Figure 5, right side, bottom examples: only a single blue-shaded span is removed in each counterfactual.)

Preliminary experiments (Appendix C) show that Span Pruning outperforms Collective Pruning, as removing all distracting patterns simultaneously can disrupt sentence integrity. Therefore, LeaF adopts Span Pruning to construct counterfactual samples:
$
\mathcal{D}_{\mathrm{pruned}} = \left\{ (X \setminus A_i,\, y) \right\}_{i=1}^{k},
$
- D_pruned: The set of counterfactual samples generated.
- (X \ A_i, y): A single counterfactual sample, where a distinct confounding span A_i has been removed from the original instruction X, and y is the corresponding ground-truth output.
- k: The total number of distinct confounding spans identified in the original instruction X.

This augmentation creates multiple counterfactual examples for a single original instruction, encouraging the model to learn reasoning paths that are invariant to specific confounders (a code sketch of this detection-and-pruning step follows).
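Putting the threshold test and span pruning together, a simplified sketch of counterfactual-sample construction could look like the following. The helper names are hypothetical, and the paper's condition 2 (both teacher and student answer correctly after removal) is abstracted behind the keeps_answer_correct callback.

```python
def confounding_spans(delta_g, tau=0.10):
    """Group token indices whose normalized sensitivity gap is <= tau into contiguous spans."""
    lo, hi = min(delta_g), max(delta_g)
    norm = [(g - lo) / (hi - lo + 1e-8) for g in delta_g]
    flagged = [i for i, g in enumerate(norm) if g <= tau]
    spans, current = [], []
    for i in flagged:
        if current and i == current[-1] + 1:
            current.append(i)                  # extend the current contiguous span
        else:
            if current:
                spans.append(current)
            current = [i]                      # start a new span
    if current:
        spans.append(current)
    return spans

def build_counterfactuals(tokens, spans, y, keeps_answer_correct):
    """Span pruning: drop one candidate span at a time, keeping the gold output y."""
    pruned = []
    for span in spans:
        if not keeps_answer_correct(tokens, span):   # condition 2 from the detection step
            continue
        kept = [tok for i, tok in enumerate(tokens) if i not in set(span)]
        pruned.append((kept, y))
    return pruned
```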
4.2.2.2. Causal Attention Distillation
After generating both original and counterfactual samples, LeaF optimizes two complementary distillation objectives to steer the student model towards true causal dependencies:
- Standard Distillation (L_kd): This objective aligns the student's output distribution p_S with the teacher's output distribution p_T on the original instructions. This is a typical knowledge distillation loss.
$
\mathcal{L}_{kd} = D_{\mathrm{KL}}\big( p_T(y \mid X) \parallel p_S(y \mid X) \big),
$
  - D_KL: The Kullback-Leibler divergence (explained in Section 3.1).
  - p_T(y | X): The probability distribution of the teacher model's output given the original instruction X.
  - p_S(y | X): The probability distribution of the student model's output given the original instruction X.
- Counterfactual Distillation (L_cd): This objective aligns the student's output distribution with the teacher's output distribution on the counterfactual instructions (where confounders have been pruned). This is the key causal intervention component.
$
\mathcal{L}_{cd} = D_{\mathrm{KL}}\big( p_T(y \mid X \setminus A) \parallel p_S(y \mid X \setminus A) \big).
$
  - p_T(y | X \ A): The probability distribution of the teacher model's output given the counterfactual instruction.
  - p_S(y | X \ A): The probability distribution of the student model's output given the counterfactual instruction.

These two objectives are combined into a hybrid distillation loss using a weighting factor λ:
$
\mathcal{L} = \lambda \mathcal{L}_{kd} + (1 - \lambda) \mathcal{L}_{cd},
$
- L: The total hybrid distillation loss that the student model minimizes during training.
- λ: A hyperparameter in the range [0, 1] that controls the trade-off between Standard Distillation and Counterfactual Distillation. If λ = 1, only Standard Distillation is used; if λ = 0, only Counterfactual Distillation is used; values between 0 and 1 balance both objectives.

This composite loss encourages the student to preserve semantic knowledge from the teacher on original inputs while simultaneously enforcing genuine causal dependencies by observing the teacher's behavior when confounding factors are removed, as sketched in the code below.
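A minimal PyTorch sketch of this combined objective is given below. It is our own illustrative code, not the paper's implementation: real logits-level distillation would also handle tokenizer alignment, padding masks, and possibly a temperature.

```python
import torch.nn.functional as F

def kl_distill(teacher_logits, student_logits):
    """Token-level KL(teacher || student) over the vocabulary."""
    t = F.softmax(teacher_logits, dim=-1)
    s = F.log_softmax(student_logits, dim=-1)
    return F.kl_div(s, t, reduction="batchmean")

def leaf_loss(t_orig, s_orig, t_pruned, s_pruned, lam=0.5):
    """Hybrid loss: lam * L_kd on original inputs + (1 - lam) * L_cd on confounder-pruned inputs."""
    l_kd = kl_distill(t_orig, s_orig)
    l_cd = kl_distill(t_pruned, s_pruned)
    return lam * l_kd + (1.0 - lam) * l_cd
```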
Response Splitting Strategies: Beyond instruction-level pruning, the paper also considers applying the confounding token detection and pruning to the model's generated response. This is important because previously generated tokens can also act as contextual input for subsequent generation steps and might contain misleading patterns.

Figure 6: Illustration of Response Splitting Strategies: Language CoT, Instruct-level Pruning, and Response-level Pruning (2-segment and 3-segment splits). Highlighted white areas represent the input, and blue underlined areas represent the outputs used for cross-entropy loss computation.
The paper considers two variants for response splitting:
- Instruct-level Pruning: Confounding tokens are detected and pruned only in the instructions, and the LeaF framework is applied to these instruction-level pruned samples. (Figure 6, middle)
- Both Instruct- and Response-level Pruning: Confounding tokens are detected and pruned in both the instructions and the preceding generations (the model's partial output), helping the model produce more accurate continuations. (Figure 6, right)
  - 2-segment splits: The response is split into two segments, and confounding tokens are identified and pruned within each.
  - 3-segment splits: The response is split into three segments, allowing for more granular detection and pruning of confounding tokens.

This approach addresses the dynamic nature of distracting patterns, which can emerge not just in the initial prompt but also during the model's own generation process, further enhancing reasoning capabilities.
5. Experimental Setup
5.1. Datasets
The experiments evaluate LeaF on mathematical reasoning, code generation, and multi-hop question answering tasks.
Training Datasets:
- Mathematical Reasoning:
NuminaMath-CoT [26]: A large dataset used to ensure models encounter an equal number of confounding tokens across tasks. The training set consists of 30k instances, randomly selected (7.5k each) from: Olympiads [16] (advanced math and physics problems, filtered for pure-text tasks), AMC_AIME [26] (American Mathematics Competitions and American Invitational Mathematics Examination problems), GSM8K [8] (grade-school level math word problems), and MATH [17] (competition-style math problems across various topics).
- Code Generation:
AceCode-87K [53]: A dataset for competitive coding, from which a subset of 120k instances is randomly selected.
- Multi-hop Question Answering:
KILT [40] datasets provided in Helmet [52]: A benchmark for knowledge-intensive language tasks. The training set is constructed by merging data from HotpotQA [50] (requires reasoning over multiple Wikipedia paragraphs), NQ [1] (Natural Questions, typically single-paragraph answers), and PopQA [36] (factoid questions).
- In total, 3k annotated samples are used, drawn equally from HotpotQA, NQ, and PopQA, where each query is linked to gold passages containing the answers.
Evaluation Datasets:
- Mathematical Reasoning:
GSM8K [8]: 8500 grade-school-level word problems, requiring 2-8 steps of basic arithmetic; assesses multi-step reasoning in natural language. MATH-500 [17]: A subset of the MATH dataset (12500 competition-style problems) covering diverse mathematical topics. OlympiadBench [16]: A challenging benchmark for Olympiad-level math and physics problems; the text-only subset of 674 tasks is used for advanced symbolic reasoning.
- Code Generation:
HumanEval+ [32]: Extends the original HumanEval with additional Python programming tasks and augmented unit tests, targeting functional correctness. LeetCode [9]: Samples real-world algorithmic challenges (arrays, trees, dynamic programming) to assess the ability to generate correct and efficient solutions. LiveCodeBench (v4) [22]: A large-scale suite of real-world coding tasks with comprehensive unit tests and human preference annotations, for functional accuracy and coding style.
- Multi-hop Question Answering:
HotpotQA [50]: Requires reasoning over multiple Wikipedia paragraphs to answer complex questions. 2WikiMultiHopQA [19]: Multi-hop questions generated from Wikipedia, requiring reasoning across at least two documents. Musique [44]: Provides multi-hop questions with fine-grained decomposition and supporting evidence, testing compositional reasoning and factual consistency.
5.2. Evaluation Metrics
Each evaluation metric is described below with its conceptual definition, mathematical formula, and symbol explanation:
-
Accuracy:
- Conceptual Definition: Accuracy measures the proportion of correctly predicted instances out of the total number of instances. It is a common metric for classification tasks where the output is a discrete value. In reasoning tasks, it typically refers to the percentage of problems for which the model provides the exact correct answer.
- Mathematical Formula: $ \text{Accuracy} = \frac{\text{Number of Correct Predictions}}{\text{Total Number of Predictions}} $
- Symbol Explanation:
Number of Correct Predictions: The count of instances where the model's output matches the ground truth.Total Number of Predictions: The total count of instances evaluated.
-
Pass@k:
- Conceptual Definition: Used in code generation tasks, pass@k measures the probability that at least one of k generated code samples passes the unit tests. If a model generates several candidate solutions for a problem and at least one is functionally correct, the problem is considered solved; the metric thus accounts for the stochastic nature of code generation. The paper reports pass@1 for HumanEval+ and LeetCode, and pass@10 for LivecodeBench (see the estimator sketch below).
- Mathematical Formula: For each problem, let n be the number of generated code samples and c the number of those samples that pass all unit tests. The probability that none of k randomly chosen samples (out of the n generated) is correct is C(n - c, k) / C(n, k), so
$ \text{Pass@k} = \frac{1}{\text{num\_problems}} \sum_{\text{problem}=1}^{\text{num\_problems}} \left( 1 - \frac{\binom{n - c}{k}}{\binom{n}{k}} \right) $
- Symbol Explanation:
  - num_problems: Total number of coding problems in the benchmark.
  - n: Total number of code samples generated for a particular problem.
  - c: Number of functionally correct samples among the n generated for that problem.
  - k: The number of samples considered (e.g., 1 or 10).
  - C(n, k): The binomial coefficient ("n choose k"), i.e., the number of ways to choose k items from a set of n items.
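For reference, the standard unbiased pass@k estimator can be computed as follows; this is generic evaluation code (the helper name is ours), not code from LeaF.

```python
import math

def pass_at_k(n, c, k):
    """Unbiased pass@k for one problem: n samples generated, c of them correct, k <= n."""
    if n - c < k:          # every size-k subset must contain at least one correct sample
        return 1.0
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)

# Example: 10 samples with 3 correct ones.
print(round(pass_at_k(10, 3, 1), 3))   # 0.3
```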
-
EM (Exact Match):
- Conceptual Definition: Exact Match is a strict metric used in question answering that measures whether the model's generated answer string is identical to any of the ground-truth answer strings. It is typically used for tasks with short, factual answers.
- Mathematical Formula: $ \text{EM} = \frac{1}{N} \sum_{i=1}^{N} \mathbb{I}(\text{model\_answer}_i = \text{gold\_answer}_i) $
- Symbol Explanation:
  - N: Total number of questions.
  - I(·): The indicator function, which is 1 if the condition inside is true and 0 otherwise.
  - model_answer_i: The answer generated by the model for question i.
  - gold_answer_i: The ground-truth answer for question i. (Often there are multiple valid gold answers; if the model matches any of them, the prediction counts as correct.)
-
F1 Score:
- Conceptual Definition: The F1 score measures answer quality and is often used for question answering, especially when answers are phrases or spans of text. It is the harmonic mean of precision and recall: precision measures how many of the model's predicted tokens are correct (i.e., also appear in the gold answer), and recall measures how many of the ground-truth answer tokens were captured by the model's prediction. F1 is high only when both are high.
- Mathematical Formula: $ F1 = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} $ where, for text spans: $ \text{Precision} = \frac{\text{Number of overlapping tokens}}{\text{Number of tokens in model\_answer}} $ and $ \text{Recall} = \frac{\text{Number of overlapping tokens}}{\text{Number of tokens in gold\_answer}} $
- Symbol Explanation:
  - Number of overlapping tokens: The count of tokens (after tokenization and normalization) present in both the model's answer and the gold answer.
  - Number of tokens in model_answer: The total token count of the model's generated answer.
  - Number of tokens in gold_answer: The total token count of the ground-truth answer.
-
Jaccard Similarity:
- Conceptual Definition: Jaccard Similarity (also known as the Jaccard Index or Intersection over Union) gauges the similarity and diversity of two sample sets. It is defined as the size of the intersection divided by the size of the union. For text, it measures the overlap between two sets of words or tokens (a combined code sketch for EM, F1, and Jaccard follows).
- Mathematical Formula: For two sets A and B: $ J(A, B) = \frac{|A \cap B|}{|A \cup B|} $
- Symbol Explanation:
  - A: The set of tokens in the first text (e.g., the student model's response).
  - B: The set of tokens in the second text (e.g., the ground-truth response or the teacher model's response).
  - |A ∩ B|: The number of common tokens between sets A and B.
  - |A ∪ B|: The total number of unique tokens across both sets.
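The three string-overlap metrics above can be summarized in a few lines of generic evaluation code. This is a simplified sketch: real QA scoring usually lowercases, strips punctuation and articles, and takes the maximum over multiple gold answers.

```python
from collections import Counter

def exact_match(pred: str, gold: str) -> float:
    return float(pred.strip() == gold.strip())

def token_f1(pred: str, gold: str) -> float:
    p, g = pred.split(), gold.split()
    overlap = sum((Counter(p) & Counter(g)).values())   # shared tokens, with multiplicity
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(p), overlap / len(g)
    return 2 * precision * recall / (precision + recall)

def jaccard(pred: str, gold: str) -> float:
    a, b = set(pred.split()), set(gold.split())
    return len(a & b) / len(a | b) if (a | b) else 0.0
```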
5.3. Base Models and Baselines
Base Models (Student Models):
- LLaMA Family:
LLaMA3.2-1B-Instruct [34] and LLaMA3.2-3B-Instruct [34]
- Qwen Family:
Qwen2.5-Math-1.5B [48]
Teacher Models:
- For LLaMA-based experiments:
LLaMA3.3-70B-Instruct [34] - For Qwen-based experiments:
Qwen2.5-72B-Instruct [48]. These are chosen as high-capacity models that are expected to demonstrate strong reasoning capabilities and accurately identify critical information.
Baselines:
- Instruct Model (Pre-KD): The student model before any knowledge distillation (i.e., the original instruction-tuned version). This serves as a starting point to measure improvements.
- KD w/o Mask (Standard Knowledge Distillation): This is the primary baseline. It involves standard knowledge distillation where the student model is trained to mimic the teacher's outputs without any
confounding tokenidentification or pruning.- For math tasks, a
CoT-based variantis used. - For code generation and multi-hop QA, a
vanilla variantis used.
- For math tasks, a
5.4. Training and Evaluation Settings
Training Settings:
- Framework: Models are trained using the
Alpaca-LoRA framework, which typically involves Low-Rank Adaptation (LoRA) for efficient fine-tuning.
Full-parameter logits knowledge distillationis used, implying that the student is trained to match the teacher's full output probability distribution (logits) rather than just the hard labels. - Learning Rate Schedule: A
cosine learning rate scheduleis employed, starting with a maximum learning rate of . - Epochs: Training is conducted for 3 epochs.
- Batch Size:
- LLaMA-based models: 64
- Qwen-based models: 32
- Detailed Hyperparameters: Further details are provided in Appendix J, including maximum sequence length (4096 for LLaMA-1B/Qwen-1.5B, 3000 for LLaMA-3B) and warmup steps (5).
Evaluation Settings:
- Decoding Strategy:
Greedy decodingis used for both teacher and student models during evaluation. - Generation Length:
- Code tasks: Maximum generation length of 1024 tokens.
- Math tasks: Maximum generation length of 16,384 tokens (reflecting the potentially longer reasoning chains).
- Chat Templates: Official
chat templatesare followed during inference to ensure consistency with how models are typically used. - Evaluation Frameworks:
- Math tasks: Evaluated using a modified
Step-DPO framework [24]. - Code generation: Evaluated using
EvalPlus [32](for HumanEval+ and LeetCode) andSkythought-Evals framework [25](for LiveCodeBench). - Multi-hop QA: Evaluated using the
LongMab-PO framework [12], with datasets fromLongBench [5].
- Math tasks: Evaluated using a modified
6. Results & Analysis
6.1. Core Results Analysis
The main results demonstrate LeaF's effectiveness in enhancing LLM reasoning across various tasks.
The following are the results from Table 1 of the original paper:
| Model | GSM8K | MATH-500 | OlympiadBench | MathBench Avg. | HumanEval+ | LeetCode | LivecodeBench | CodeBench Avg. |
| Teacher Model | ||||||||
| LLaMA3.3-70B-Instruct | 95.60 | 70.40 | 36.50 | 67.50 | 78.05 | 53.90 | 45.02 | 58.99 |
| Qwen2.5-72B-Instruct | 95.45 | 73.80 | 41.25 | 70.17 | 81.71 | 69.40 | 54.42 | 68.51 |
| LLaMA3.2-1B-Instruct | ||||||||
| Instruct Model (Pre-KD) | 44.88 | 24.20 | 5.79 | 24.96 | 29.27 | 7.22 | 9.68 | 15.39 |
| KD w/o Mask | 56.79 | 33.40 | 8.90 | 33.03 | 32.32 | 6.11 | 13.74 | 17.39 |
| LeaF (Instr Mask) | 57.70 | 35.40 | 10.09 | 34.40 | 39.02 | 6.67 | 13.60 | 19.76 |
| LeaF (Instr & Resp Mask) | 58.98 | 35.20 | 9.94 | 34.71 | 39.63 | 7.22 | 12.48 | 19.77 |
| LLaMA3.2-3B-Instruct | ||||||||
| Instruct Model (Pre-KD) | 76.88 | 42.80 | 13.20 | 44.29 | 48.78 | 13.89 | 20.34 | 27.67 |
| KD w/o Mask | 82.87 | 49.00 | 18.99 | 50.29 | 54.88 | 16.67 | 24.12 | 31.89 |
| LeaF (Instr Mask) | 83.09 | 51.80 | 20.77 | 51.88 | 55.49 | 19.44 | 25.39 | 33.44 |
| LeaF (Instr & Resp Mask) | 84.69 | 52.40 | 22.55 | 53.21 | 56.10 | 21.67 | 25.81 | 34.53 |
| Qwen2.5-Math-1.5B | ||||||||
| Base Model (Pre-KD) | 65.20 | 41.40 | 21.96 | 42.85 | 35.37 | 6.67 | 1.26 | 14.43 |
| KD w/o Mask | 82.18 | 67.80 | 31.16 | 60.38 | 41.46 | 7.78 | 10.10 | 19.78 |
| LeaF (Instr Mask) | 84.69 | 68.60 | 32.79 | 62.03 | 42.68 | 9.94 | 10.80 | 20.97 |
| LeaF (Instr & Resp Mask) | 85.29 | 70.60 | 31.75 | 62.54 | 43.29 | 9.94 | 13.04 | 21.92 |
The following are the results from Table 2 of the original paper:
| Model | 2WikiMultiHopQA EM | 2WikiMultiHopQA F1 | 2WikiMultiHopQA SubEM | Musique EM | Musique F1 | Musique SubEM | HotpotQA EM | HotpotQA F1 | HotpotQA SubEM | Avg. |
| LLaMA3.2-3B-Instruct | ||||||||||
| Instruct Model (Pre-KD) | 22.50 | 34.73 | 45.50 | 9.50 | 16.56 | 18.50 | 19.50 | 32.67 | 36.00 | 26.16 |
| KD w/o Mask | 43.50 | 51.93 | 53.00 | 20.50 | 28.56 | 22.50 | 39.00 | 49.30 | 43.00 | 39.03 |
| LeaF (Instr Mask) | 46.50 | 53.89 | 55.00 | 27.00 | 33.00 | 28.00 | 40.50 | 52.06 | 44.50 | 42.27 |
Several conclusions can be drawn from these tables:
- General Improvement with Distillation: Both standard knowledge distillation (KD w/o Mask) and LeaF significantly improve the performance of the smaller student models (LLaMA3.2-1B/3B-Instruct, Qwen2.5-Math-1.5B) across almost all math, code, and multi-hop QA tasks compared to their Pre-KD versions. This validates the effectiveness of knowledge distillation in transferring capabilities from large teacher models.
- LeaF Outperforms Standard KD: LeaF consistently outperforms standard knowledge distillation (KD w/o Mask) across all benchmarks.
  - MathBench (Avg. Accuracy): For LLaMA3.2-1B, KD (33.03%) -> LeaF Instr Mask (34.40%) -> LeaF Instr & Resp Mask (34.71%). For LLaMA3.2-3B, KD (50.29%) -> LeaF Instr Mask (51.88%) -> LeaF Instr & Resp Mask (53.21%). For Qwen2.5-Math-1.5B, KD (60.38%) -> LeaF Instr Mask (62.03%) -> LeaF Instr & Resp Mask (62.54%).
  - CodeBench (Avg. Accuracy): Similar trends hold. For LLaMA3.2-1B, KD (17.39%) -> LeaF Instr Mask (19.76%) -> LeaF Instr & Resp Mask (19.77%). For LLaMA3.2-3B, KD (31.89%) -> LeaF Instr Mask (33.44%) -> LeaF Instr & Resp Mask (34.53%). For Qwen2.5-Math-1.5B, KD (19.78%) -> LeaF Instr Mask (20.97%) -> LeaF Instr & Resp Mask (21.92%).
  - Multi-hop QA (Avg. Score): LeaF (Instr Mask) achieves a 42.27% average, outperforming KD w/o Mask (39.03%) for LLaMA3.2-3B.
  This demonstrates that causal attention distillation via gradient-guided token pruning is more effective than standard distillation at enhancing reasoning capabilities.
- Benefits of Response-Level Pruning: Extending LeaF to include response-level pruning (LeaF (Instr & Resp Mask)) generally yields further improvements over instruction-level pruning (LeaF (Instr Mask)) in most tasks for both the LLaMA and Qwen series. This suggests that distracting patterns can also emerge during the model's generation process and affect subsequent tokens, so addressing them is beneficial. The paper hypothesizes that instruction-level and response-level distracting patterns may differ, and learning both can enhance reasoning.
- Absolute Gains: Compared to standard KD, LeaF achieves average absolute accuracy gains of roughly 2-3 points on MathBench and CodeBench and over 3 points on the multi-hop QA benchmarks (computed from Tables 1 and 2), which are significant improvements for complex reasoning tasks.
6.2. Ablation Studies / Parameter Analysis
6.2.1. Masking Strategies Analysis
This analysis investigates the effectiveness of LeaF's gradient-based masking strategy by comparing it against two alternatives: Random Masking and PPL-based Masking.

Figure 7: Comparison of accuracy improvement with masking strategies over baseline (KD).
Results (Figure 7):
- Gradient-based Masking (LeaF) Superiority: LeaF's gradient-based masking consistently outperforms both Random Masking and PPL-based Masking, showing the largest accuracy improvements, particularly on MATH-500 and OlympiadBench, the more complex tasks.
- Random Masking Deterioration: Random Masking often leads to performance degradation (e.g., on GSM8K and OlympiadBench) or only slight improvements (on MATH-500). This indicates that indiscriminately masking tokens without an informed selection process can be detrimental to distillation, as it may remove critical information.
- PPL-based Masking Limitations: PPL-based masking (masking the tokens with the highest perplexity) offers modest improvements on simpler tasks like GSM8K and MATH-500 but performs comparably to Random Masking on the more complex OlympiadBench. This suggests that perplexity alone may not suffice to accurately identify confounding tokens in challenging reasoning scenarios, highlighting the need for an advanced teacher model to guide token selection.
6.2.2. Response Splitting Strategies
The paper explores different strategies for processing responses, considering the presence of distracting patterns at both instruction and response levels.

Figure 8: Comparison of accuracy improvement with splitting strategies over baseline (KD).
Results (Figure 8):
- Response-level Pruning Benefits: Response-level pruning (both 2-segment and 3-segment splits) significantly outperforms instruction-level pruning (response without splits). This underscores the importance of extending confounding token detection and pruning beyond the initial instruction to the model's generated response; the improvement is hypothesized to stem from the distinct nature of distracting patterns that appear during generation.
- Diminishing Returns of Further Segmentation: The performance of 3-segment splits is comparable to that of 2-segment splits, suggesting that beyond a certain point, additional segmentation at the response level yields diminishing returns. The hypothesis is that response-level distracting patterns exhibit certain regularities, and data generated with 2-segment splits is already sufficient for the model to learn them.
6.2.3. Threshold Sensitivity Analysis
This analysis examines the impact of the confounding token threshold (τ_confounder) on model performance.

Figure 9: Instruct-level threshold sensitivity in MathBench (average performance of KD vs. LeaF for the LLaMA3.2 models under different confounder thresholds).

Figure 10: Response-level threshold sensitivity in MathBench (average performance of KD vs. LeaF for the LLaMA3.2 models under different confounder thresholds).
Results (Figure 9 and Figure 10):
- Optimal Thresholds:
  - Instruct-level (Figure 9): LLaMA3.2-LeaF-1B performs best at τ = 0.10, while LLaMA3.2-LeaF-3B performs best at τ = 0.05 (consistent with the instruction-level MathBench results in Table 7).
  - Response-level (Figure 10): The exact optima differ between the two model sizes, following the same pattern in which the smaller model favors a higher threshold.
- Model Size and Sensitivity: Smaller models (LLaMA3.2-LeaF-1B) generally achieve optimal performance at higher misleading-token thresholds than larger models (LLaMA3.2-LeaF-3B). This suggests that smaller models, being more susceptible to confounding tokens, benefit from more aggressive filtering to remove disruptive tokens more effectively.
- Cross-Domain Stability: Detailed results in Appendix G (Table 7) further confirm that a moderate threshold generally yields stable performance across various tasks (math, code, QA), suggesting it is a robust default value.

The following are the results from Table 7 of the original paper:

| Task | Model | τ=0.05 | τ=0.10 | τ=0.15 |
| MathBench | LLaMA-3.2-1B-Instruct | 32.43 | 34.40 | 33.17 |
| MathBench | LLaMA-3.2-3B-Instruct | 51.88 | 51.07 | 50.55 |
| CodeBench | LLaMA-3.2-1B-Instruct | 17.85 | 18.27 | 19.76 |
| CodeBench | LLaMA-3.2-3B-Instruct | 32.83 | 33.55 | 32.95 |
| SQuAD 2.0 | LLaMA-3.2-1B-Instruct | 72.12 | 73.33 | 74.78 |
| SQuAD 2.0 | LLaMA-3.2-3B-Instruct | 88.67 | 89.44 | 83.66 |
Table 7 shows that a moderate threshold (around τ = 0.10) provides a good balance, often yielding the best or near-best performance across diverse tasks and model sizes.
6.2.4. Collective Masking vs. Span Masking (Appendix C)
This comparison evaluates the two pruning strategies for creating counterfactual samples.
The following are the results from Table 3 of the original paper:
| Model | MATH-500 |
| LLaMA3.2-1B-Instruct | |
| Instruct Model (Pre-KD) | 24.20 |
| KD w/o Mask | 34.00 |
| Collective Pruning | 34.20 |
| Span Pruning | 37.40 |
| LLaMA3.2-3B-Instruct | |
| Instruct Model (Pre-KD) | 42.80 |
| KD w/o Mask | 50.00 |
| Collective Pruning | 49.20 |
| Span Pruning | 54.40 |
Results (Table 3):
- Span Pruning Superiority: Span Pruning substantially outperforms both Collective Pruning and native knowledge distillation (KD w/o Mask). For LLaMA3.2-1B, Span Pruning (37.40%) clearly beats Collective Pruning (34.20%) and KD (34.00%); for LLaMA3.2-3B, Span Pruning (54.40%) likewise surpasses Collective Pruning (49.20%) and KD (50.00%).
- Collective Pruning Degradation: Collective Pruning not only fails to improve performance but actually degrades it for the LLaMA3.2-3B-Instruct model (49.20% vs. 50.00% for KD). The paper attributes this to Collective Pruning disrupting sentence integrity by removing all confounding tokens simultaneously. This justifies the adoption of Span Pruning for all experiments in LeaF.
6.2.5. Ablation Study: Importance of Contrastive Pairs (Appendix F)
This ablation studies the impact of including student-wrong originals in the training data, which LeaF leverages as part of its contrastive signal.
The following are the results from Table 6 of the original paper:
| Model | GSM8K | MATH-500 | OlympiadBench | Avg. |
| LLaMA 3.2-1B-Instruct | | | | |
| LeaF | 57.70 | 35.40 | 10.09 | 34.40 |
| w/o Student-Wrong Originals | 58.15 | 34.80 | 7.42 | 33.46 |
| LLaMA 3.2-3B-Instruct | | | | |
| LeaF | 83.09 | 51.80 | 20.77 | 51.88 |
| w/o Student-Wrong Originals | 84.08 | 47.80 | 16.02 | 49.30 |
Results (Table 6):
- Value of Student-Wrong Originals: LeaF consistently outperforms the ablated variant (w/o Student-Wrong Originals) on more challenging datasets like MATH-500 and OlympiadBench. This indicates that retaining original samples where the student model initially fails due to misleading patterns (alongside counterfactual samples) provides a stronger contrastive signal, helping the student downweight spurious correlations and learn to attend to causally relevant information.
- Reduced Generalization: Excluding student-wrong originals limits the model's exposure to problematic cases, thereby hindering its ability to generalize and perform robustly under confounding conditions.
- GSM8K Anomaly: The ablated variant shows a slight performance gain on GSM8K. This is attributed to GSM8K being a simpler dataset: after filtering, student-correct samples form a larger share of the training data, leading to a distributional bias towards easier problems, from which the ablated model, focusing only on 'correct' examples, might temporarily benefit. (A sketch of how such a corpus might be assembled follows this list.)
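As a rough illustration of how such a corpus might be assembled, the sketch below mixes counterfactual (pruned) samples with student-wrong originals. The field names and the correctness flag are hypothetical placeholders, not the paper's actual data schema.

```python
# Minimal sketch of assembling the distillation corpus with contrastive pairs.
# Field names ("counterfactual", "original", "student_correct_on_original") are
# illustrative assumptions, not the paper's exact pipeline.
def build_training_set(samples):
    """Keep counterfactual (pruned) samples plus originals the student got wrong,
    so the student sees contrastive pairs: misled-on-original vs. correct-after-pruning."""
    corpus = []
    for s in samples:
        corpus.append(s["counterfactual"])              # pruned version enacting the intervention
        if not s["student_correct_on_original"]:
            corpus.append(s["original"])                 # student-wrong original: contrastive signal
    return corpus
```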
6.2.6. Robustness Analysis (Appendix E)
This analysis evaluates LeaF's test-time robustness by assessing its performance on perturbed versions of the MathBench datasets (GSM8K, MATH-500, OlympiadBench). Perturbations are generated using back-translation to create realistic paraphrastic variations.
The following are the results from Table 5 of the original paper:
| Model | GSM8K | MATH-500 | OlympiadBench | Avg. |
| LLaMA 3.2-1B-Instruct | | | | |
| Instruct (Pre-KD) | 39.65 | 24.80 | 3.86 | 22.77 |
| KD (no mask) | 50.42 | 31.80 | 5.19 | 29.14 |
| LeaF (Instr Mask) | 51.10 | 32.00 | 6.53 | 29.88 |
| LeaF (Instr & Resp Mask) | 51.71 | 34.20 | 6.97 | 30.96 |
| LLaMA 3.2-3B-Instruct | | | | |
| Instruct (Pre-KD) | 70.43 | 40.80 | 9.64 | 40.29 |
| KD (no mask) | 74.00 | 48.20 | 14.09 | 45.43 |
| LeaF (Instr Mask) | 74.07 | 50.20 | 15.58 | 46.62 |
| LeaF (Instr & Resp Mask) | 76.12 | 51.20 | 19.88 | 49.07 |
Results (Table 5):
- Superior Robustness: LeaF consistently outperforms standard KD (KD (no mask)) under noisy conditions, indicating that LeaF's pruning and attribution mechanisms are robust to moderate linguistic perturbations.
- Preserved Reasoning: The performance degradation of LeaF under noise remains modest, suggesting that it effectively preserves its reasoning capability even when input distributions shift. These findings highlight LeaF's robustness and practical applicability in real-world scenarios involving noisy inputs or distributional shifts (a back-translation sketch follows below).
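For readers who want to reproduce this kind of perturbation, a minimal back-translation sketch is shown below. The English→German→English MarianMT pair is an assumption, since the paper does not specify its exact back-translation setup.

```python
# Minimal back-translation sketch for producing paraphrastic perturbations.
# The en->de->en pivot and the specific MarianMT checkpoints are assumptions.
from transformers import MarianMTModel, MarianTokenizer

def _translate(texts, model_name):
    tok = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    batch = tok(texts, return_tensors="pt", padding=True, truncation=True)
    out = model.generate(**batch)
    return tok.batch_decode(out, skip_special_tokens=True)

def back_translate(texts):
    """English -> German -> English paraphrases to perturb benchmark questions."""
    german = _translate(texts, "Helsinki-NLP/opus-mt-en-de")
    return _translate(german, "Helsinki-NLP/opus-mt-de-en")

perturbed = back_translate(["Natalia sold clips to 48 of her friends in April."])
```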
6.2.7. Computational Overhead Analysis (Appendix D)
This section details the runtime and memory analysis of LeaF compared to standard Knowledge Distillation (KD).
The following are the results from Table 4 of the original paper:
| Model | KD Runtime (h) | LeaF Runtime (h) | Overhead (%) |
| LLaMA-3.2-1B-Instruct | 26.2 | 29.2 | +11.5 |
| LLaMA-3.2-3B-Instruct | 23.7 | 26.2 | +10.6 |
| Qwen2.5-Math-1.5B | 27.3 | 30.9 | +13.3 |
Results (Table 4):
- Moderate Training Overhead: LeaF incurs additional training time of approximately 10-13% compared to standard KD (per Table 4). For example, LLaMA-3.2-1B-Instruct training time increases from 26.2 hours to 29.2 hours (+11.5%).
- Justified Overhead: This moderate increase in training time is justified by the consistent absolute accuracy gains achieved across various benchmarks.
- Offline & Parallelizable Auxiliary Procedures: The auxiliary procedures of LeaF, including gradient computation, gradient normalization, span pruning, and counterfactual generation, are performed offline and are fully parallelizable.
  - Gradient computation: a one-time offline process taking about 3 hours for 7K samples on 8x NVIDIA A100 (80GB) GPUs.
  - Counterfactual generation: an offline process taking 50 minutes for smaller models (26K samples) and up to 2.85 hours for larger ones on 4x NVIDIA A100 (80GB) GPUs.
  These steps incur no extra cost during online training or inference, making LeaF practical and scalable.
6.2.8. Case Study in an Interpretable Perspective
A case study is presented (Figure 11) to illustrate how LeaF improves interpretability by enabling the model to focus on critical information and avoid confounding tokens, compared to standard knowledge distillation (KD).
This figure is a schematic from the paper illustrating LeaF's two-stage pipeline: confounding tokens are first identified via gradient-based comparison with the teacher model, then pruned during distillation to align causal attention.
The case study involves a mathematical problem: finding the smallest positive real value of a parameter $a$ for which all the roots of a given cubic polynomial are real.
Comparison (Figure 11):
- LLaMA3.2-3B-Instruct (Distilled by KD) - Left (Incorrect):
  - The KD model starts by incorrectly attempting to apply the AM-GM inequality to the roots.
  - It misapplies AM-GM to potentially negative values (the roots can be negative), leading to an incorrect upper bound on $a$.
  - The model overlooks the critical constraint that "all roots must be real" and follows a flawed reasoning chain, highlighted in blue, resulting in an incorrect final answer.
- LLaMA3.2-3B-Instruct (Distilled by LeaF) - Right (Correct):
  - The LeaF model correctly identifies an evident real root of the cubic through factoring (Step 2).
  - It then correctly applies the discriminant condition (Steps 3 and 4) to ensure that the remaining quadratic factor also yields real solutions, requiring its discriminant to be non-negative.
  - Solving this inequality (Steps 5 and 6) correctly yields the admissible range for $a$.
  - Considering the problem's constraint that $a$ is a positive real number, LeaF correctly deduces that the smallest possible value for $a$ is 3.
  - The pink highlights mark the areas where LeaF correctly focuses on critical information and follows a coherent reasoning chain.
This case study visually demonstrates that LeaF guides the model to attend to the causally relevant parts of the problem (e.g., factoring, the discriminant condition) and to correctly interpret constraints, whereas standard KD fails to prevent the model from being misled by spurious patterns or incorrect mathematical heuristics. This highlights LeaF's ability to produce a more interpretable and reliable reasoning model. The generic form of the correct reasoning is made explicit below.
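The correct reasoning chain can be written generically; the block below states the discriminant argument for any monic cubic with one known real root (the concrete polynomial from Figure 11 is not reproduced here).

```latex
% Generic form of the reasoning LeaF follows in Figure 11: once one real root r
% is factored out, the remaining roots are real iff the quadratic discriminant
% is non-negative.
\[
p(x) = (x - r)\,(x^{2} + bx + c), \qquad r \in \mathbb{R},
\]
\[
\text{all roots of } p \text{ are real} \iff b^{2} - 4c \ge 0 .
\]
```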
6.3. Preliminary Experiments
6.3.1. Gradient Heatmap Comparison (Appendix A.1, Figure 12)
This preliminary experiment compares gradient heatmaps of a small student model and a large teacher model to visualize differences in their attention to critical tokens.

This figure is an illustration showing a specific error in a model's answer to a math problem, where a formula is displayed and marked as incorrect.
Figure 12: A case study comparing the performance of the student model (LLaMA3.2-1B-Instruct) and the teacher model (LLaMA3.3-70B-Instruct) on the MATH task.
Results (Figure 12):
- The teacher model (LLaMA3.3-70B-Instruct) successfully captures a key contextual relation, such as "Five containers of blueberries can be traded for two zucchinis," showing focused attention on relevant tokens.
- The student model (LLaMA3.2-1B-Instruct) exhibits dispersed attention: its focus is spread out and not precisely aligned with the critical information. This observation motivates the core hypothesis: by pruning distracting patterns (confounding tokens), the student model can be guided to better focus on salient information, thereby enhancing its reasoning capabilities (a minimal saliency-map sketch follows this list).
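A minimal sketch of how such a gradient heatmap could be computed is given below. Defining saliency as the L2 norm of the loss gradient with respect to input embeddings is an assumption for illustration, and the model name is a placeholder rather than the paper's exact setup.

```python
# Minimal sketch of a gradient-based token saliency map, in the spirit of Figure 12.
# The saliency definition (L2 norm of the loss gradient w.r.t. input embeddings) and
# the model name are illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def token_saliency(model_name: str, prompt: str, target: str):
    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()
    ids = tok(prompt + target, return_tensors="pt").input_ids
    # Feed embeddings directly so gradients can flow back to each input token.
    embeds = model.get_input_embeddings()(ids).detach().requires_grad_(True)
    out = model(inputs_embeds=embeds, labels=ids)      # LM loss over the full sequence (simplification)
    out.loss.backward()
    saliency = embeds.grad.norm(dim=-1).squeeze(0)     # one score per token
    return list(zip(tok.convert_ids_to_tokens(ids[0]), saliency.tolist()))

# Comparing these per-token scores between a teacher and a student reveals where
# the student's attention diverges from the teacher's causal focus.
```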
6.3.2. Performance Gains from Token Pruning (Appendix A.2, Figure 1)
This pilot study assesses the performance improvement of student models after simply removing distracting patterns (confounding tokens).

This is Figure 1 of the paper, showing the accuracy improvements of small models on the math and code training corpora after removing confounding tokens. The results show gains of over 20% on the math corpus and over 10% on the code corpus, demonstrating the effectiveness of the approach.
Figure 1: Accuracy improvements achieved by removing confounding tokens from small models on the math and code training corpora. The results demonstrate a significant increase in performance, with over 20% improvement on the math corpus and more than 10% on the code corpus. (For further details on these categories, see Appendix A.)
Results (Figure 1):
- Substantial Gains: Simply pruning distracting patterns, without any additional training, yields substantial improvements:
  - over 20% improvement in average accuracy on the MATH training corpus;
  - more than 10% improvement on the Code training corpus.
- Impact on Complexity: Greater improvements are observed on AMC_AIME (more complex) than on GSM8K (simpler). This suggests that complex reasoning problems tend to contain more distracting patterns that interfere with model inference. These findings strongly support the idea that mitigating the influence of distracting patterns is crucial for improving the robustness and accuracy of LLM reasoning.
A further representative case is presented in Figure 2:

This figure is a text illustration showing the step-by-step calculation of a bacterial colony growth pattern, highlighting that step 11 yields a bacterial count of 1536, with pink and blue bold text emphasizing key content.
Figure 2: Comparison of reasoning before and after pruning distracting patterns. Blue-shaded regions indicate pruned confounding tokens. Pink highlights mark areas that require focus, while blue highlights show where excessive attention caused errors.
Figure 2 illustrates a math problem about bacterial growth. Removing distracting patterns (blue-shaded regions) from the instruction helps the model focus on critical information and enhance reasoning. This visually reinforces the hypothesis that improving attention to relevant details is a promising direction.
6.3.3. Generation Quality Improvements from Token Pruning (Appendix A.3, Figure 13)
This analysis quantifies the similarity between student model outputs and teacher model outputs using Jaccard Similarity under two conditions: original instruction and instruction with pruned distracting patterns.

This figure is a chart comparing the probability distributions of Jaccard similarity between student-model responses and ground-truth answers on the math and code datasets, contrasting original instructions with instructions whose misleading patterns were removed, together with fitted curves.
Figure 13: Jaccard similarity distribution between student model responses (original vs. instruction pruned distracting patterns) and ground-truth responses on math and code datasets.
Results (Figure 13):
- Improved Alignment: After removing confounding tokens from the instruction, there is a clear shift in the Jaccard similarity distribution for student-model responses on both code and math tasks: the distribution moves towards higher similarity values and its mode increases. This indicates that by ignoring distracting patterns, the student model not only improves reasoning accuracy but also generates responses more aligned with the teacher model's outputs, thereby enhancing output quality and semantic coherence (a minimal similarity computation is sketched below).
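For reference, a minimal token-level Jaccard similarity computation is sketched below; whitespace tokenization and lowercasing are simplifying assumptions, and the paper's exact tokenization may differ.

```python
# Minimal sketch of token-level Jaccard similarity between a response and a reference.
def jaccard_similarity(response: str, reference: str) -> float:
    a, b = set(response.lower().split()), set(reference.lower().split())
    return len(a & b) / len(a | b) if (a | b) else 0.0

print(jaccard_similarity("the answer is 42", "thus the answer is 42"))  # -> 0.8
```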
7. Conclusion & Reflections
7.1. Conclusion Summary
The paper introduces Learning to Focus (LeaF), a novel two-stage framework designed to enhance the reasoning consistency and interpretability of Large Language Models (LLMs). LeaF addresses the problem of LLMs being misled by distracting patterns or spurious correlations in their training data. By leveraging a causal analysis and a gradient-based pruning mechanism, LeaF effectively identifies confounding tokens that introduce these spurious correlations. In its first stage, it uses gradient-based comparisons between a teacher and student model to detect these tokens. In the second stage, it employs a hybrid distillation loss (combining standard and counterfactual distillation with span pruning) to align the student's attention with the teacher's focus on truly critical context tokens.
Experimental results across mathematical reasoning, code generation, and multi-hop question answering benchmarks consistently demonstrate that LeaF achieves significant absolute improvements in accuracy and robustness compared to standard knowledge distillation. Furthermore, case studies and ablation analyses confirm that LeaF successfully suppresses attention to confounding tokens, leading to a more interpretable and reliable reasoning model. The framework also shows benefits from response-level pruning, acknowledging that distracting patterns can appear throughout the generation process.
7.2. Limitations & Future Work
The authors identify the following limitations for LeaF:
- Dependence on an Advanced Teacher Model: The core mechanism of LeaF for identifying confounding tokens relies on the availability and superior performance of a teacher model. If a strong teacher is not available, or the teacher itself struggles with certain types of confounding, the effectiveness of LeaF may be limited.
  - Future Work: The authors suggest exploring self-improvement mechanisms in which models refine their attention to critical tokens and boost reasoning without relying on an external advanced model, increasing the framework's autonomy and applicability.
- Limited Scalability to Long-Form Generation: The current validation of LeaF is primarily on math and code tasks, which, while complex, may not represent the full spectrum of long-form text generation (e.g., creative writing, detailed explanations, summarization of very long documents).
  - Future Work: Investigating LeaF's applicability to long-text generation and other diverse domains is proposed for future research.
7.3. Personal Insights & Critique
- Strong Theoretical Grounding: The paper's adoption of a causal framework to explain and mitigate spurious correlations is a significant strength. Framing distracting patterns as confounders provides a principled, theoretically sound approach that distinguishes it from many heuristic-based token selection methods. This interventionist perspective is powerful for truly understanding and improving model behavior, rather than just optimizing performance.
- Interpretability as a Core Benefit: Beyond performance gains, the emphasis on interpretability through attention visualization and case studies is highly valuable. Showing why a model makes a correct decision (by focusing on the right tokens) or an incorrect one (by being distracted) is crucial for building trust and for future model development. The contrastive case study in Figure 11 effectively illustrates this.
- Practicality and Scalability: The analysis of computational overhead (Appendix D) demonstrates that while LeaF adds some training time, its auxiliary procedures are offline and parallelizable. This ensures that the method is practical for real-world application, offering significant accuracy gains for a manageable increase in training cost, with no extra cost at inference time.
- Potential for Broader Application: While the paper notes limitations in long-form generation, the core idea of causal attention distillation could potentially be applied to many other domains where LLMs struggle with context or rely on superficial cues. Any task susceptible to spurious correlations in its data could benefit from this approach, such as complex scientific text understanding, legal document analysis, or even multimodal tasks where certain visual or auditory elements might act as confounders.
- The "Student-Wrong Originals" Insight: The ablation study highlighting the importance of including student-wrong originals is a subtle but critical insight. It shows that learning from failure cases, specifically when the student is misled by confounders and the teacher is not, provides a powerful contrastive signal that strengthens the student's ability to resist spurious patterns. This goes beyond simply learning from 'good' examples.
- Further Refinements of Confounder Identification: While gradient-based sensitivity is a strong method, further research could explore more nuanced ways to define and identify confounding tokens. For instance, integrating linguistic features or semantic parse information into the gradient analysis might refine the detection. Also, the reliance on a threshold introduces a hyperparameter that requires careful tuning.
- Self-Improvement Without a Teacher: The authors wisely point out the teacher dependence as a limitation. Developing self-improvement mechanisms that allow a model to internally identify and mitigate its own spurious correlations would be a groundbreaking advancement, making such frameworks more autonomous and potentially unlocking even greater capabilities.

Overall, LeaF presents a rigorous and impactful contribution to improving LLM reasoning, offering a clear path toward more reliable, accurate, and understandable AI systems.