Learning to Focus: Causal Attention Distillation via Gradient-Guided Token Pruning
TL;DR 精炼摘要
本文提出LeaF框架,通过梯度引导的词元剪枝识别并移除训练数据中的混淆词元,消除虚假关联,促使学生模型聚焦于关键上下文,实现因果注意力蒸馏。该方法显著提升了数学推理、代码生成和多跳问答的准确性,增强模型推理的可靠性和可解释性。
摘要
Large language models (LLMs) have demonstrated significant improvements in contextual understanding. However, their ability to attend to truly critical information during long-context reasoning and generation still falls behind the pace. Specifically, our preliminary experiments reveal that certain distracting patterns can misdirect the model's attention during inference, and removing these patterns substantially improves reasoning accuracy and generation quality. We attribute this phenomenon to spurious correlations in the training data, which obstruct the model's capacity to infer authentic causal instruction-response relationships. This phenomenon may induce redundant reasoning processes, potentially resulting in significant inference overhead and, more critically, the generation of erroneous or suboptimal responses. To mitigate this, we introduce a two-stage framework called Learning to Focus (LeaF) leveraging intervention-based inference to disentangle confounding factors. In the first stage, LeaF employs gradient-based comparisons with an advanced teacher to automatically identify confounding tokens based on causal relationships in the training corpus. Then, in the second stage, it prunes these tokens during distillation to enact intervention, aligning the student's attention with the teacher's focus distribution on truly critical context tokens. Experimental results demonstrate that LeaF not only achieves an absolute improvement in various mathematical reasoning, code generation and multi-hop question answering benchmarks but also effectively suppresses attention to confounding tokens during inference, yielding a more interpretable and reliable reasoning model.
思维导图
论文精读
中文精读
1. 论文基本信息
1.1. 标题
学习聚焦:通过梯度引导的词元剪枝进行因果注意力蒸馏 (Learning to Focus: Causal Attention Distillation via Gradient-Guided Token Pruning)
1.2. 作者
Yiju Guo, Wenkai Yang, Zexu Sum, Ning Ding, Zhiyuan Liu, Yankai Lin
1.3. 发表期刊/会议
预印本 (Preprint),发表于 arXiv。
1.4. 发表年份
2025年。
1.5. 摘要
大型语言模型 (Large Language Models, LLMs) 在上下文理解方面取得了显著进展。然而,它们在长上下文推理和生成过程中聚焦真正关键信息的能力仍有不足。初步实验表明,某些干扰模式 (distracting patterns) 会在推理时误导模型注意力,移除这些模式能大幅提高推理准确性和生成质量。作者将此现象归因于训练数据中的虚假关联 (spurious correlations),这些关联阻碍了模型推断真实的因果指令-响应关系。这可能导致冗余的推理过程、显著的推理开销,更严重的是,可能产生错误或次优的响应。为解决此问题,论文提出了一个两阶段框架——学习聚焦 (Learning to Focus, LeaF),它利用基于干预的推理 (intervention-based inference) 来解耦混淆因子 (confounding factors)。在第一阶段,LeaF 通过与高级教师模型进行基于梯度的比较,根据训练语料中的因果关系自动识别混淆词元 (confounding tokens)。在第二阶段,它在蒸馏过程中剪枝这些词元以实施干预,使学生模型的注意力与教师模型在真正关键上下文词元上的聚焦分布对齐。实验结果表明,LeaF 不仅在数学推理、代码生成和多跳问答等多个基准测试中取得了绝对性能提升,而且在推理过程中有效抑制了对混淆词元的注意力,从而得到了更具可解释性和可靠性的推理模型。
1.6. 原文链接
https://arxiv.org/abs/2506.07851 PDF 链接: https://arxiv.org/pdf/2506.07851v2.pdf 发布状态:预印本 (Preprint)
2. 整体概括
2.1. 研究背景与动机
当前的大型语言模型 (LLMs) 尽管在上下文理解方面表现出色,但在处理长上下文推理和复杂指令任务时,仍难以有效聚焦于真正关键的信息。这种注意力分散或误导,会导致推理准确性和生成质量的下降。
论文的初步实验揭示了一个关键现象:训练数据中存在某些“干扰模式”或“混淆词元” (distracting patterns/confounding tokens)。这些模式与模型输出之间存在虚假关联 (spurious correlations),导致模型在推理时被误导,无法准确捕捉指令与响应之间真实的因果关系。例如,在数学问题中,一些与解题无关的冗余信息可能会分散模型的注意力,使其未能专注于核心逻辑。移除这些干扰模式后,模型的推理准确性甚至能提升超过20%,这强有力地证明了这些虚假关联的负面影响。
因此,论文的动机在于解决 LLMs 在复杂任务中“注意力不集中”的问题,即通过识别和消除训练数据中由虚假关联引起的混淆因子,引导模型学习更鲁棒和可解释的因果依赖关系,从而提高其推理能力和生成质量。
2.2. 核心贡献/主要发现
论文的核心贡献在于提出了一个名为 学习聚焦 (Learning to Focus, LeaF) 的两阶段框架,旨在提升 LLMs 在长上下文推理和生成中聚焦关键信息的能力,并取得了以下主要发现:
- 提出了基于因果推断的两阶段框架 LeaF: 该框架将干扰模式视为混淆因子,并利用干预 (intervention) 的思想来解耦这些因子。
- 第一阶段:混淆词元检测 (Confounding Token Detection): 通过基于梯度的师生模型敏感性比较,自动识别训练语料中引入虚假关联的混淆词元。随后,通过区间剪枝 (span pruning) 策略生成反事实样本 (counterfactual samples)。
- 第二阶段:因果注意力蒸馏 (Causal Attention Distillation): 采用混合蒸馏损失,同时对原始样本和反事实样本进行蒸馏,促使学生模型捕获真实的因果依赖关系,并使其注意力与教师模型在关键上下文词元上的聚焦分布对齐。
- 显著的性能提升: LeaF 在多个基准测试中表现出超越标准知识蒸馏 (Knowledge Distillation, KD) 的性能。
- 数学推理: 在 GSM8K、MATH-500 和 OlympiadBench 上平均准确率提升 2.41%。
- 代码生成: 在 HumanEval+、LeetCode 和 LivecodeBench 上平均准确率提升 2.48%。
- 多跳问答: 在 HotpotQA、2WikiMultiHopQA 和 Musique 上平均性能提升 3.24%。
- 增强模型可解释性和可靠性: 实验结果和案例分析(如
Figure 11)表明,LeaF 能有效抑制模型对混淆词元的注意力,使模型在推理时更专注于关键信息,从而产生更准确、更具解释性的推理过程。 - 响应级剪枝的有效性: 论文发现,将混淆词元剪枝的范围从指令级扩展到指令与响应级,能带来进一步的性能提升,表明响应生成过程中也存在干扰模式。
3. 预备知识与相关工作
3.1. 基础概念
为了更好地理解本论文,我们需要先了解以下核心概念:
- 大语言模型 (Large Language Models, LLMs): 指的是拥有数亿甚至数千亿参数的深度学习模型,它们在大量文本数据上进行预训练,展现出强大的自然语言理解、生成和推理能力。它们通常基于
Transformer架构,通过自注意力机制 (self-attention mechanism) 处理序列数据。 - 知识蒸馏 (Knowledge Distillation, KD): 是一种模型压缩和加速技术,旨在将一个大型、高性能的教师模型 (teacher model) 的知识迁移到一个小型、高效的学生模型 (student model) 中。通常通过最小化学生模型输出与教师模型输出之间的差异(例如,使用
KL 散度)来实现。 - 因果推断 (Causal Inference): 是一门研究如何从数据中识别和量化因果关系而非仅仅是统计关联的科学。本论文中引用了
Pearl's Structural Causal Model,它使用有向无环图 (Directed Acyclic Graph, DAG)来表示变量之间的因果结构。- 混淆因子 (Confounding Factors): 在因果推断中,混淆因子是一个影响自变量和因变量,并因此可能导致虚假关联的变量。在本论文中,
混淆词元 (confounding tokens)被定义为一种特殊的混淆因子,它们既与模型的输入 中的其他部分相关,又与模型的输出 相关,从而引入虚假关联,阻碍模型发现真实的因果指令-响应关系。 - 干预 (Intervention): 在因果推断中,干预是指主动改变某个变量的值,并观察这对其因果后代变量的影响,以揭示真实的因果关系。本论文通过剪枝混淆词元来模拟干预,以消除其对输出的虚假影响。
- 混淆因子 (Confounding Factors): 在因果推断中,混淆因子是一个影响自变量和因变量,并因此可能导致虚假关联的变量。在本论文中,
- 梯度 (Gradient): 在深度学习中,梯度是损失函数 (loss function) 相对于模型参数或输入变化的偏导数向量。它指示了损失函数增长最快的方向。在本论文中,梯度被用来衡量输入词元对模型输出损失的敏感性,即一个词元的小变化如何影响模型的预测结果。
梯度敏感性 (gradient sensitivity)越高,表示该词元对模型决策的影响越大。 - 词元 (Token): 是
LLM处理文本的最小单位。在自然语言处理中,文本通常被分解成词元,可以是单词、子词或字符。 - 困惑度 (Perplexity, PPL): 语言模型中衡量模型对给定文本序列预测能力的一个指标。困惑度越低,表示模型对文本的预测或拟合能力越强,认为该文本序列在模型下出现的可能性越高。
- KL 散度 (Kullback-Leibler Divergence, KL Divergence): 衡量两个概率分布之间差异的非对称指标。在本论文的知识蒸馏中,
KL 散度被用来量化学生模型输出的概率分布与教师模型输出的概率分布之间的差异。
3.2. 前人工作
论文将相关工作分为三个主要方面:
-
因果链知识蒸馏 (Chain-of-Thought Knowledge Distillation, CoT KD):
- 传统知识蒸馏
KD是一种将大模型知识迁移到小模型的技术。 CoTKD[41] 是最早将CoT(思维链) 解释从教师模型蒸馏到学生模型的工作,使学生模型获得高级推理能力。- 近期工作分为两类:
- 数据聚焦方法 (Data-focused approaches): 通过提升训练数据的质量和多样性来增强蒸馏效果,例如
CD[13]、SCORE[56] 和Skintern[28]。这些方法侧重于生成高质量的教师解释或筛选数据。 - 模型聚焦方法 (Model-focused approaches): 关注改进模型架构和推理策略以提高效率和推理能力,例如
PRR[41] 和ATM[7]。
- 数据聚焦方法 (Data-focused approaches): 通过提升训练数据的质量和多样性来增强蒸馏效果,例如
- 本文的差异: 现有方法主要关注输出模仿,而 LeaF 明确蒸馏教师模型在推理过程中捕获关键信息的能力,使学生模型具备上下文感知推理 (context-aware reasoning)。
- 传统知识蒸馏
-
关键词元识别 (Critic Token Identification):
- 现有工作探索了识别和缓解冗余步骤 [10, 33, 35] 或不重要词元 [15, 47] 的策略。
LLMLingua[23] 依赖模型自评估来判断词元重要性,但可能引入模型自身推理能力的偏见。RHO-1[29] 在SFT(监督微调) 阶段引入选择性语言建模 (Selective Language Modeling, SLM),优先处理信息量大的词元。TokenSkip[47] 聚焦于CoT阶段,选择性跳过低影响词元来压缩推理路径。cDPO[30] 通过对比学习 (contrastive learning) 隔离关键词元。- 本文的差异: 上述方法主要关注输出词元,而 LeaF 引入先进的教师模型,基于梯度差异识别混淆词元,强调指令阶段上下文信息的影响,建立了指令与输出之间更强的联系,实现了更全面的词元过滤。
-
推理一致性 (Reasoning Consistency):
Self-Consistency[46] 通过采样多条推理链并多数投票来稳定最终答案。Adaptive Consistency[2] 和Early Scoring Self-Consistency[27] 引入停止标准来降低推理成本。Reasoning Aware Self-Consistency[45] 通过加权样本质量和推理路径重要性来增强答案一致性。- 本文的差异: 这些方法主要关注稳定最终答案,而 LeaF 将一致性定义为答案稳定性和上下文依从性 (context adherence),并通过蒸馏策略系统性地抑制误导信号,增强学生模型对关键上下文信息的聚焦,促进更鲁棒的上下文依从性。
3.3. 技术演进
LLMs 在其发展初期展现了强大的语言理解和生成能力,但随着上下文长度的增加和任务复杂性的提高,模型难以持续聚焦于所有关键信息,容易受到冗余或误导性信息的干扰。这种“注意力不集中”导致了性能瓶颈和错误响应。
为了解决这一问题,研究社区开始探索多种方法。一方面,知识蒸馏被广泛用于将大型、高性能模型的推理能力迁移到小型模型,以提高效率。另一方面,为了提升推理的质量和效率,研究者开始关注如何识别和利用文本中的关键信息,例如通过 CoT 提示工程引导模型生成中间推理步骤,并通过蒸馏这些步骤来增强学生模型的推理能力。此外,还出现了多种词元选择和剪枝策略,旨在优化模型对输入信息的处理。
本文的工作 LeaF 正是建立在这些技术演进之上,但其创新之处在于,它不仅关注输出模仿或简单的词元重要性,而是从因果推断的视角出发,明确将“干扰模式”识别为“混淆因子”,并通过“干预”的机制(即剪枝这些词元)来强制模型学习真实的因果关系,从而在注意力机制层面进行更深层次的优化,提升模型的鲁棒性和可解释性。
3.4. 差异化分析
LeaF 方法与现有工作的主要区别和创新点在于:
- 因果视角与干预机制: 现有知识蒸馏和词元识别方法大多侧重于输出模仿、提升效率或筛选出“重要”词元。
LeaF则独辟蹊径,将“干扰模式”上升到因果混淆因子 (causal confounding factors) 的理论高度。通过识别这些混淆因子并对其进行干预 (intervention)(即剪枝),LeaF旨在消除虚假关联,迫使学生模型学习输入与输出之间真正的因果依赖 (causal dependencies),而非仅仅是统计关联。这是其与众不同的核心理论基础。 - 梯度引导的混淆词元识别:
LeaF利用教师模型与学生模型的梯度差异来精确识别混淆词元。这种方法比基于困惑度 (perplexity) 或模型自评估 (self-assessment) 的方法更具鲁棒性和准确性,尤其是在复杂推理任务中,因为它结合了更强大的教师模型的“真知灼见”。而现有关键词元识别方法通常不涉及跨模型的梯度比较来专门定位“混淆”信息。 - 反事实蒸馏 (Counterfactual Distillation): 论文引入了独特的反事实蒸馏损失,在剪枝了混淆词元后的样本上进行蒸馏。这与传统的只在原始样本上进行蒸馏不同,它通过对比“有混淆”和“无混淆”两种情况,强化学生模型识别和忽略混淆信息的能力,从而实现因果注意力对齐 (causal attention alignment)。
- 指令与响应级别的全面干预: 许多现有方法主要关注输出词元或推理链中的关键步骤。
LeaF不仅在指令级别 (instruction-level) 识别和剪枝混淆词元,还进一步扩展到响应级别 (response-level),处理生成过程中可能出现的干扰信息,从而实现更全面的注意力聚焦,这对于长上下文和多步推理任务尤为重要。
4. 方法论
4.1. 方法原理
LeaF (Learning to Focus) 的核心思想是利用因果推断来识别和消除大型语言模型 (LLMs) 推理过程中由虚假关联引起的混淆因子。该方法假设训练数据中存在某些词元,它们与模型的输入和期望输出之间存在非因果的虚假关联,从而误导模型注意力并降低推理准确性。为了解决这个问题,LeaF 框架采用梯度引导的词元剪枝和因果注意力蒸馏,使学生模型能够聚焦于真正关键的上下文词元,从而学习更鲁棒和可解释的因果关系。
其直觉在于:一个强大的教师模型通常能更好地理解输入中的真实因果关系,而学生模型可能更容易被虚假关联所误导。通过比较师生模型对特定词元的敏感性差异,可以定位那些学生模型过度关注但对教师模型影响较小的“混淆词元”。一旦识别出这些词元,便可以通过“干预”——即在训练数据中剪枝这些词元——来创建反事实情景,迫使学生模型在没有这些干扰的情况下进行学习,从而将注意力转移到真正有助于解决问题的关键信息上。
4.2. 核心方法详解
4.2.1. 因果框架
论文基于 Pearl's Structural Causal Model [39] 构建了一个 有向无环图 (Directed Acyclic Graph, DAG) 来建模推理过程中不同组件之间的因果关系。
下图(原文 Figure 3)展示了推理过程的因果图:
该图像是论文中展示的图3因果图示意图,描述了推理过程中输入提示 和模型输出 之间的因果关系。图中标注的混淆令牌 引入了干扰的伪相关,该方法通过检测并屏蔽 ,消除 到 的伪因果边,恢复真实因果依赖。
图 3: 推理过程的因果图。 表示输入提示 (input prompt), 表示模型的输出 (model's output)。 中的一部分词元,被识别为混淆词元 (A),引入了干扰推理过程的虚假关联。我们的方法检测并掩码 ,有效消除了从 到 的虚假边,并恢复了真实的因果依赖。
在这个框架中:
-
代表输入词元 (input tokens)。
-
代表模型的输出 (model's output)。
-
混淆词元 (confounding tokens)被定义为 的一个子集 ,它们通过与输出 和输入 的互补部分建立虚假关联,从而模糊了真实的因果关系。这些误导性依赖会扭曲模型的注意力机制,并偏向其推理过程,最终导致不可靠的预测。当混淆词元 同时影响 和 时(如
Figure 3中的虚线箭头所示),观察到的条件分布(observed conditional distribution)变为: 这个分布偏离了干预分布 (interventional distribution) ,反映了 通过虚假路径对 的间接影响所引入的偏差。其中, 表示在观察到输入 时输出 的概率, 表示通过干预将 设置为 时输出 的概率。do-operator(do 算子) 是因果推断中的一个核心概念,用于表示对某个变量进行干预,而不是简单地观察它。例如, 意味着我们强制 等于 ,而不是观察到 自然地等于 。
为了阻断这些非因果影响,论文提出了因果剪枝 (causal pruning),即在蒸馏之前移除 的影响。这鼓励学生模型学习基于真实因果结构的注意力模式,从而提高鲁棒性和可解释性。
4.2.2. LeaF: 学习聚焦框架
为了消除虚假依赖,论文引入了 Learning to Focus (LeaF) 框架,它由两个主要阶段组成:
下图(原文 Figure 4)展示了 LeaF 方法的概览:
该图像是示意图,展示了论文中提出的两阶段框架“Learning to Focus (LeaF)”的流程。第一阶段通过梯度比较识别混淆因子并进行裁剪;第二阶段进行因果注意力蒸馏,令学生模型学习聚焦重要上下文,提升推理准确性和模型可靠性。
图 4: 方法概览。训练流程包括两个关键阶段:(1) 混淆词元检测 (Confounding Token Detection):使用先进教师模型和学生模型之间基于梯度的比较来识别训练样本中的混淆词元,并通过剪枝这些词元来构建反事实样本;(2) 因果注意力蒸馏 (Causal Attention Distillation):在蒸馏过程中分别剪枝识别出的混淆词元,以使学生的注意力与教师模型对齐,并捕获因果关系。这种有针对性的干预引导模型走向实际的因果依赖,增强了鲁棒性和可解释性。
这两个阶段分别是:
- 混淆词元检测 (Confounding Token Detection):LeaF 通过教师模型和学生模型之间基于梯度的比较来识别混淆词元,并通过剪枝这些词元来构建反事实样本。
- 因果注意力蒸馏 (Causal Attention Distillation):LeaF 通过一个混合蒸馏损失来捕获因果依赖关系,该损失使学生模型在原始样本和反事实样本上与教师模型对齐。
4.2.3. 混淆词元检测 (Confounding Token Detection)
为了识别引入虚假关联的混淆词元 ,论文采用了基于梯度的方法 [42, 43] 来定量衡量每个词元对模型输出 的影响。具体来说,论文利用了教师模型 和学生模型 的梯度敏感性 (gradient sensitivity)。论文特别关注学生模型预测错误但教师模型正确处理的数据实例,以隔离混淆词元。对于每个词元 ,计算两个模型预测的 logits 和黄金参考 (gold references) 之间的损失相对于 嵌入的梯度。
每个模型对词元 的梯度敏感性表示为:
- :教师模型对词元 的梯度敏感性。
- :学生模型对词元 的梯度敏感性。
- :在给定完整输入 和模型参数 的情况下,与词元 相关的损失。这里计算的是损失函数 相对于词元 嵌入的偏导数的绝对值。
- :教师模型的参数。
- :学生模型的参数。 这些梯度反映了每个模型对 扰动的敏感性。
为了在梯度尺度不同的模型之间进行词元级别的比较,论文对敏感性值进行 min-max 归一化:
- :归一化后的教师模型对词元 的敏感性。
- :归一化后的学生模型对词元 的敏感性。
- 和 :在当前输入序列中所有词元 的梯度敏感性的最小值和最大值。 归一化的目的是将不同模型的梯度敏感性值映射到相同的 [0, 1] 范围,以便进行公平比较。
为了识别混淆词元,论文通过计算每个词元的梯度差异来捕获教师模型和学生模型之间在词元级别注意力上的差异:
-
:词元 的归一化梯度差异。正值表示教师模型更敏感,负值表示学生模型更敏感。
为了确保梯度差异在不同实例之间的一致缩放,论文再次对梯度差异进行归一化,并根据以下两个条件将词元 分类为混淆词元:
-
条件 (i): 词元 在推理过程中获得了学生模型显著的关注,但教师模型对其关注度可忽略不计,形式化为:
- :一个通过对验证集进行统计分析确定的阈值。
- 这个表达式意味着,如果归一化后的梯度差异 (即教师模型对 的关注度减去学生模型对 的关注度) 低于某个阈值 ,则该词元被认为是混淆词元。直观上,这意味着学生模型对该词元过度关注( 较高),而教师模型则不那么关注( 较低),表明学生可能被该词元误导。
-
条件 (ii): 移除该词元后,教师模型和学生模型都能产生正确的预测。这个条件进一步验证了该词元确实是“混淆”的,而不是解决问题所必需的。
剪枝策略 (Pruning Strategies): 论文研究了两种从指令 中移除混淆词元的剪枝策略:
-
集体剪枝 (Collective Pruning): 移除所有被识别的混淆词元集合 ,得到 。
-
区间剪枝 (Span Pruning): 一次只移除一个连续的混淆词元区间 ,得到 。
下图(原文 Figure 5)展示了集体剪枝和区间剪枝的示意图:
该图像是图示图,展示了论文中图5所示的集体剪枝(Collective Pruning)和区间剪枝(Span Pruning)方法。通过不同颜色区域标识有效token、混淆token及干扰pattern区域,说明剪枝操作在训练语料中的具体应用。
图 5: 集体剪枝和区间剪枝的说明。
初步实验表明,区间剪枝优于集体剪枝 (详见附录 ),因为同时剪枝所有干扰模式可能会破坏句子的完整性。因此,论文通过区间剪枝策略构建反事实样本:
- :反事实数据集。
- :一个反事实样本,其中 表示从原始输入 中移除了第 个混淆词元区间 。
- :在原始输入 中识别出的混淆词元区间的数量。 这种数据增强鼓励模型学习对特定混淆因子不变的推理路径。
4.2.4. 因果注意力蒸馏 (Causal Attention Distillation)
在生成原始样本和反事实样本后,论文通过优化两个互补的蒸馏目标来引导学生模型学习真实的因果依赖关系:
-
标准蒸馏 (Standard Distillation): 使学生模型在原始指令上的输出分布与教师模型对齐:
- :标准蒸馏损失。
- :
KL 散度,衡量两个概率分布之间的差异。 - :教师模型在给定输入 时生成输出 的概率分布。
- :学生模型在给定输入 时生成输出 的概率分布。
-
反事实蒸馏 (Counterfactual Distillation): 使学生模型在反事实指令 (混淆词元被剪枝)上的输出分布与教师模型对齐:
-
:反事实蒸馏损失。
-
:教师模型在给定剪枝后的输入 时生成输出 的概率分布。
-
:学生模型在给定剪枝后的输入 时生成输出 的概率分布。
论文将这些目标与一个权重因子 融合:
-
- :总的混合蒸馏损失。
- :控制标准蒸馏和反事实蒸馏之间权衡的因子。 这个复合损失引导学生模型在保留语义知识的同时,强制学习真正的因果依赖关系。
响应拆分策略 (Response Splitting Strategies):
论文考虑了两种 LeaF 变体:
-
指令级剪枝 (Instruct-level Pruning): 仅在指令中检测和剪枝混淆词元,并在指令级别剪枝的样本上执行
LeaF。 -
指令与响应级剪枝 (Both Instruct- and Response-level Pruning): 将先前生成的词元视为上下文输入,并剪枝那些对后续生成具有误导性的词元,以帮助模型产生更准确的延续。因此,论文在指令和先前的生成中都检测和剪枝混淆词元。
下图(原文 Figure 6)展示了响应拆分策略的示意图:
该图像是一个示意图,展示了图6中不同的回复拆分策略,包括语言CoT、指令级剪枝和响应级剪枝(2段和3段拆分)。图中用白色高亮表示输入部分,蓝色下划线表示用于计算交叉熵损失的输出部分。
图 6: 响应拆分策略的说明:Language CoT、指令级剪枝和响应级剪枝(2段和3段拆分)。高亮白色区域表示输入,蓝色下划线区域表示用于计算交叉熵损失的输出。
5. 实验设置
5.1. 数据集
实验评估了 LeaF 方法在数学推理、代码生成和多跳问答任务上的有效性。
训练数据集:
- 数学推理 (Mathematical Reasoning): 从
NuminaMath-CoT[26] 数据集中随机选择 30k 实例,每个子集(Olympiads[16]、AMC_AIME[26]、GSM8K[8] 和MATH[17])各 3k 实例,以确保模型在不同任务中遇到等量的混淆词元。 - 代码生成 (Code Generation): 从
AceCode-87K[53] 数据集中随机选择 120k 实例。 - 多跳问答 (Multi-hop Question Answering): 构建自
Helmet[52] 中提供的KILT[40] 数据集,总计 3k 注释样本,平均来自HotpotQA[50]、NQ[1] 和PopQA[36],其中每个查询都明确链接到包含答案的对应黄金段落。
评估数据集:
- 数学任务 (Math Task): 选择三个不同难度级别的广泛使用的基准测试:
GSM8K[8]MATH-500[17]OlympiadBench[16]
- 代码领域 (Code Domain): 在以下基准测试上进行评估:
- [32]
LeetCode[9]LivecodeBench(v4) [22]
- 多跳问答 (Multi-hop Question Answering): 在以下三个代表性基准测试上进行评估:
HotpotQA[50]2WikiMultiHopQA[19]Musique[4]
数据集中的具体样本示例:
下图(原文 Figure 2)展示了去除混淆词元前后的推理过程对比。蓝色阴影区域表示被剪枝的混淆词元。粉色高亮标记需要聚焦的区域,而蓝色高亮显示了过度关注导致错误的地方。
该图像是文本示意图,展示了细菌群落增长模式的逐步计算过程,重点突出在第11步计算出细菌数量1536,包含粉色和蓝色加粗文本强调关键内容。
图 2: 剪枝干扰模式前后的推理过程对比。蓝色阴影区域表示被剪枝的混淆词元。粉色高亮标记需要聚焦的区域,而蓝色高亮显示了过度关注导致错误的地方。 Distracting Patterns Region (干扰模式区域): 在某个细菌群落中,细菌数量每天翻倍。群落从3个细菌开始,在第1天结束时有6个,依此类推。那么,在哪个第一天结束时,群落的细菌数量将超过100个?
Before Confounding Tokens Pruning (去除混淆词元前): Step 1: 识别细菌的初始数量和生长模式。群落从3个细菌开始,每天翻倍。 Step 2: 计算每天结束时的细菌数量,以找出何时超过100个。 Day 1: Day 2: Day 3: (注意这里是错误的,应为 ) Day 4: Day 5: Day 6: Step 3: 确定群落细菌数量超过100个的第一天。从计算中可以看出,群落细菌数量在第6天超过100个。 最终答案是: 6 0
After Confounding Tokens Pruning (去除混淆词元后): Step 1: 识别细菌的初始数量和生长模式。群落从3个细菌开始,每天翻倍。 Step 2: 计算每天结束时的细菌数量,以找出何时超过100个。 Day 1: Day 2: Day 3: Day 4: Day 5: Day 6: Step 3: 确定群落细菌数量超过100个的第一天。从计算中可以看出,群落细菌数量在第6天超过100个。 最终答案是: 6
上面这个样本示例说明了在数学推理任务中,原始指令中可能包含的“干扰模式”(如 cYanmaGllowb 这种无意义的词元,或者在 Day 3 计算中的 ÷ 2 错误操作)如何误导模型。去除这些混淆词元后,模型能够更准确地执行推理步骤,得到正确的结果。
5.2. 评估指标
对论文中出现的评估指标,提供如下详细说明:
-
准确率 (Accuracy)
- 概念定义: 准确率是最常见的分类评估指标之一,它衡量模型正确预测的样本数量占总样本数量的比例。在推理和问答任务中,通常指模型给出正确答案的比例。
- 数学公式:
- 符号解释:
Number of Correct Predictions: 模型正确预测的样本数量。Total Number of Predictions: 所有进行预测的样本总数。
-
Pass@k
- 概念定义:
Pass@k是代码生成任务中常用的评估指标,尤其在HumanEval等基准测试中。它衡量的是模型生成 个代码解决方案中,至少有一个能够通过所有单元测试的比例。Pass@1表示模型第一次尝试就生成了正确代码的概率。 - 数学公式 (通用公式,论文未直接给出): 其中, 是问题总数, 是为每个问题生成的候选解决方案总数, 是针对问题 通过测试用例的解决方案数量。
- 符号解释:
- : 测试问题总数。
- : 为每个问题生成的候选解决方案的总数。
- : 对于第 个问题,通过所有单元测试的正确解决方案的数量。
- : 组合数,表示从 个元素中选择 个元素的方案数。
- 概念定义:
-
EM (Exact Match)
- 概念定义:
Exact Match是问答任务中一个严格的评估指标,它要求模型生成的答案与参考答案完全一致(通常忽略大小写和标点符号)。如果模型答案与任何一个参考答案完全匹配,则得分为 1,否则为 0。 - 数学公式 (概念性定义,无通用数学公式):
- 符号解释:
- : 问题总数。
- : 指示函数,如果条件为真则为 1,否则为 0。
- : 模型对第 个问题的答案。
- : 第 个问题的参考答案。
- 概念定义:
-
F1 (F1 Score)
- 概念定义:
F1 分数是精确率 (Precision) 和召回率 (Recall) 的调和平均值,通常用于衡量问答或信息检索任务中模型性能。它能更好地平衡模型的查全率和查准率。在问答中,通常将模型答案和参考答案视为词袋,计算它们之间的词重叠度。 - 数学公式: 其中,
- 符号解释:
True Positives: 模型答案中同时存在于参考答案中的词元数量。False Positives: 模型答案中存在,但参考答案中不存在的词元数量。False Negatives: 参考答案中存在,但模型答案中不存在的词元数量。
- 概念定义:
-
Jaccard 相似度 (Jaccard Similarity)
- 概念定义:
Jaccard 相似度(或Jaccard 系数)用于衡量两个集合之间的相似性。它定义为两个集合交集的大小除以它们并集的大小。在文本相似度评估中,通常将两个文本(如模型响应和真实标注)视为词元集合。 - 数学公式:
- 符号解释:
- : 第一个文本的词元集合。
- : 第二个文本的词元集合。
- : 两个集合的交集大小(共同词元数量)。
- : 两个集合的并集大小。
- , : 集合 和 的大小(词元数量)。
- 概念定义:
5.3. 对比基线
论文将 LeaF 方法与标准知识蒸馏 (Knowledge Distillation, KD) 进行了比较,其中标准 KD 在没有进行词元剪枝 (no mask) 的情况下进行。
- 对于数学任务,基线采用基于
CoT(Chain-of-Thought) 的变体。 - 对于代码生成和多跳问答任务,基线采用香草 (vanilla) 变体。
5.4. 模型和超参数
基础模型 (Base Models):
实验在两个不同的模型家族上进行:LLaMA 家族 [38, 34] 和 Qwen 家族 [48],涵盖不同规模的模型。
- LLaMA 系列:
- 学生模型 (Student Models):
LLaMA3.2-1B-Instruct[34] 和LLaMA3.2-3B-Instruct[34]。 - 教师模型 (Teacher Model):
LLaMA3.3-70B-Instruct[34]。
- 学生模型 (Student Models):
- Qwen 系列:
- 学生模型 (Student Model):
Qwen2.5-Math-1.5B[48]。 - 教师模型 (Teacher Model):
Qwen2.5-72B-Instruct[48]。
- 学生模型 (Student Model):
训练和评估设置:
-
训练 (Training):
-
模型使用
Alpaca-LoRA框架进行训练。 -
采用全参数 logits 知识蒸馏 (full-parameter logits knowledge distillation)。
-
学习率调度器 (LR Scheduler):
cosine学习率调度,最大学习率为 。 -
训练周期 (Epochs): 3 个 epoch。
-
批次大小 (Batch Size):
LLaMA系列模型为 64,Qwen系列模型为 32。 -
详细超参数参见附录 中的
Table 10。以下是原文 Table 10 的结果:
Model Hyper-parameter Value LLaMA3.2-1B-Instruct LR 1 × 10−5 LR Scheduler cosine Batch Size 64 Epochs 3 Maximum Sequence Length 4096 Warmup Steps 5 Distill Loss Type KL Validation Set Size (Math) 1035 Validation Set Size (Code) 2000 LLaMA3.2-3B-Instruct LR 1 × 10-5 LR Scheduler cosine Batch Size 64 Epochs 3 Maximum Sequence Length 3000 Warmup Steps 5 Distill Loss Type KL Validation Set Size (Math) 1035 Validation Set Size (Code) 2000 Qwen2.5-Math-1.5B LR 1 × 10-5 LR Scheduler cosine Batch Size 32 Epochs 3 Maximum Sequence Length 4096 Warmup Steps 5 Distill Loss Type KL Validation Set Size (Math) 1200 Validation Set Size (Code) 2000
-
-
评估 (Evaluation):
- 教师模型和学生模型在数学、代码和问答任务上使用贪婪解码 (greedy decoding) 进行评估。
- 最大生成长度 (Maximum generation length): 代码任务为 1024 词元,数学任务为 16384 词元。
- 推理过程中遵循官方聊天模板 (official chat templates)。
- 详细评估设置参见附录 。
6. 实验结果与分析
6.1. 核心结果分析
论文通过全面的实验证明了 Learning to Focus (LeaF) 框架的有效性。
初步实验结果:去除混淆词元的显著效果 下图(原文 Figure 1)展示了去除混淆词元后小模型在数学和代码训练语料上的准确率提升:
该图像是论文中图1,展示了去除困惑词后小模型在数学和代码训练语料上的准确率提升。结果显示数学语料提升超过20%,代码语料提升超过10%,体现了方法的有效性。
图 1: 去除混淆词元后小模型在数学和代码训练语料上的准确率提升。结果表明性能显著提高,数学语料提升超过 20%,代码语料提升超过 10%。
-
Figure 1显示,简单地剪枝干扰模式就能带来显著的性能提升:在数学语料上平均准确率提高超过 20%,在代码语料上提高超过 10%。 -
在
AMC_AIME上比GSM8K有更大的改进,这表明复杂的推理问题可能包含更多干扰模型推理的模式。 -
这强有力地支持了论文的核心假设:减轻干扰模式的影响对于提高
LLM推理的鲁棒性和准确性至关重要。下图(原文 Figure 2)展示了去除混淆词元前后推理过程对比的案例研究:
该图像是文本示意图,展示了细菌群落增长模式的逐步计算过程,重点突出在第11步计算出细菌数量1536,包含粉色和蓝色加粗文本强调关键内容。
图 2: 剪枝干扰模式前后的推理过程对比。蓝色阴影区域表示被剪枝的混淆词元。粉色高亮标记需要聚焦的区域,而蓝色高亮显示了过度关注导致错误的地方。
-
Figure 2的案例研究直观地展示了去除指令中的干扰模式如何帮助模型聚焦于关键信息,从而改善推理。在没有额外训练的情况下,这一操作就能提升模型的推理能力。下图(原文 Figure 13)展示了学生模型响应(原始 vs. 指令剪枝干扰模式)与数学和代码数据集上的真实标注响应之间的
Jaccard相似度分布:
该图像是图表,展示了数学和代码数据集中学生模型响应与真实答案之间Jaccard相似度的概率分布比较,比较了原始指令与去除误导模式后的效果及其拟合曲线。
图 13: 学生模型响应(原始 vs. 指令剪枝干扰模式)与数学和代码数据集上的真实标注响应之间的 Jaccard 相似度分布。
Figure 13显示,在去除混淆词元后,学生模型在代码和数学任务上生成的响应的Jaccard相似度分布发生了变化。这表明通过忽略干扰模式,学生模型不仅提高了推理准确性,而且生成了与教师模型更一致的响应,从而提高了输出质量。
主要性能提升:LeaF 相较于标准 KD 的优势
以下是原文 Table 1 的结果:
| Model | MathBench | CodeBench | ||||||
| GSM8K MATH-500 | Olympiad- Bench | Avg. | Human- Eval+ | Leet- Code | Livecode- Bench | Avg. | ||
| Teacher Model | ||||||||
| LLaMA3.3-70B-Instruct | 95.60 | 70.40 | 36.50 | 67.50 | 78.05 | 53.90 | 45.02 | 58.99 |
| Qwen2.5-72B-Instruct | 95.45 | 73.80 | 41.25 | 70.17 | 81.71 | 69.40 | 54.42 | 68.51 |
| LLaMA3.2-1B-Instruct | ||||||||
| Instruct Model (Pre-KD) | 44.88 | 24.20 | 5.79 | 24.96 | 29.27 | 7.22 | 9.68 | 15.39 |
| KD w/o Mask | 56.79 | 33.40 | 8.90 | 33.03 | 32.32 | 6.11 | 13.74 | 17.39 |
| LeaF (Instr Mask) | 57.70 | 35.40 | 10.09 | 34.40 | 39.02 | 6.67 | 13.60 | 19.76 |
| LeaF (Instr & Resp Mask) | 58.98 | 35.20 | 9.94 | 34.71 | 39.63 | 7.22 | 12.48 | 19.77 |
| LLaMA3.2-3B-Instruct | ||||||||
| Instruct Model (Pre-KD) | 76.88 | 42.80 | 13.20 | 44.29 | 48.78 | 13.89 | 20.34 | 27.67 |
| KD w/o Mask | 82.87 | 49.00 | 18.99 | 50.29 | 54.88 | 16.67 | 24.12 | 31.89 |
| LeaF (Instr Mask) | 83.09 | 51.80 | 20.77 | 51.88 | 55.49 | 19.44 | 25.39 | 33.44 |
| LeaF (Instr & Resp Mask) | 84.69 | 52.40 | 22.55 | 53.21 | 56.10 | 21.67 | 25.81 | 34.53 |
| Qwen2.5-Math-1.5B | ||||||||
| Base Model (Pre-KD) | 65.20 | 41.40 | 21.96 | 42.85 | 35.37 | 6.67 | 1.26 | 14.43 |
| KD w/o Mask | 82.18 | 67.80 | 31.16 | 60.38 | 41.46 | 7.78 | 10.10 | 19.78 |
| LeaF (Instr Mask) | 84.69 | 68.60 | 32.79 | 62.03 | 42.68 | 9.94 | 10.80 | 20.97 |
| LeaF (Instr & Resp Mask) | 85.29 | 70.60 | 31.75 | 62.54 | 43.29 | 9.94 | 13.04 | 21.92 |
-
Table 1显示,LeaF在数学推理 (MathBench) 和代码生成 (CodeBench) 任务上,始终优于标准知识蒸馏 (KD w/o Mask)。- 在
LLaMA3.2-1B-Instruct模型上,LeaF (Instr & Resp Mask)在MathBench平均准确率达到 34.71%,高于KD w/o Mask的 33.03%;在CodeBench平均准确率达到 19.77%,高于KD w/o Mask的 17.39%。 - 在
LLaMA3.2-3B-Instruct模型上,LeaF (Instr & Resp Mask)在MathBench平均准确率达到 53.21%,高于KD w/o Mask的 50.29%;在CodeBench平均准确率达到 34.53%,高于KD w/o Mask的 31.89%。 - 在
Qwen2.5-Math-1.5B模型上,LeaF (Instr & Resp Mask)在MathBench平均准确率达到 62.54%,高于KD w/o Mask的 60.38%;在CodeBench平均准确率达到 21.92%,高于KD w/o Mask的 19.78%。
- 在
-
指令与响应级别剪枝 (Instr & Resp Mask) 效果最佳: 在大多数任务中,从指令级别扩展到响应级别的剪枝进一步提升了性能,这表明响应生成过程中也存在干扰模式,需要同时处理。
以下是原文 Table 2 的结果:
Model 2WikiMultiHopQA Musique HotpotQA Avg. EM F1 SubEM EM F1 SubEM EM F1 SubEM LLaMA3.2-3B-Instruct Instruct Model re-22.50 34.73 45.50 9.50 16.56 18.50 19.50 32.67 36.00 26.16 KD w/o Mask 43.50 51.93 53.00 20.50 28.56 22.50 39.00 49.30 43.00 39.03 LeaF (Instr Mask) 46.50 53.89 55.00 2700 33.00 28.00 40.50 52.06 44.50 42.27 -
Table 2展示了LeaF在多跳问答任务上的表现。由于多跳问答任务的响应通常较短,只评估了指令级剪枝 (LeaF (Instr Mask)) 变体。-
LLaMA3.2-3B-Instruct模型上,LeaF (Instr Mask)的平均性能达到 42.27%,显著高于KD w/o Mask的 39.03%。这些结果共同验证了
LeaF能够通过增强对关键信息的注意力来提升LLM的推理性能,这与论文的假设一致。
-
6.2. 掩码策略分析 (Masking Strategies Analysis)
论文通过比较 LeaF (基于梯度掩码) 与其他两种词元掩码策略——随机掩码 (Random Masking) 和基于困惑度掩码 (PPL-based Masking)——来验证其梯度引导掩码策略的有效性。
下图(原文 Figure 7)展示了不同掩码策略相对于基线 (KD) 的准确率提升比较:
该图像是图表,展示了三种掩码策略(随机掩码、基于PPL掩码和基于梯度掩码)在不同数据集(GSM8K、MATH、OlympiadBench)及平均值上的准确率提升对比,基于梯度掩码策略表现最优。
图 7: 不同掩码策略相对于基线 (KD) 的准确率提升比较。
结果分析:
-
梯度引导掩码 (Gradient-based Masking) (本文方法): 在
GSM8K、MATH-500和OlympiadBench等数据集上始终优于其他两种基线策略,尤其在MATH-500和OlympiadBench上取得了最高的准确率提升。这表明梯度引导的方法能更准确地识别出真正的混淆词元。 -
随机掩码 (Random Masking): 在
GSM8K和Olympiad上导致性能下降,尽管在MATH-500上略有改善。这说明简单地随机掩码词元,如果没有信息指导,反而会破坏模型的推理过程,损害蒸馏性能。 -
基于困惑度掩码 (PPL-based Masking): 在
GSM8K和MATH-500上提供了适度的改进,但在OlympiadBench等复杂任务上表现与随机掩码相当。这表明困惑度在简单场景下可能足以检测混淆词元,但在挑战性基准测试中缺乏必要的敏感性。结论: 梯度引导的掩码策略(利用教师模型的指导)在复杂推理场景下识别混淆词元方面至关重要,优于无信息量的随机掩码和仅依赖学生模型自身困惑度的掩码。
6.3. 响应拆分策略分析 (Response Splitting Strategies)
论文比较了三种响应拆分策略,以探究指令级别和响应级别干扰模式的影响。
下图(原文 Figure 8)展示了不同拆分策略相对于基线 (KD) 的准确率提升比较:
该图像是图表,展示了图8中不同分割策略相较基线(KD)的准确率提升。纵轴为绝对提升百分比,不同颜色柱状对应未分割、分割为2段和3段的响应,显示在多个测试集上的表现差异。
图 8: 不同拆分策略相对于基线 (KD) 的准确率提升比较。
结果分析:
-
响应级别剪枝的显著优势:
响应级别剪枝(包括 2-segment 和 3-segment splits) 显著优于指令级别剪枝。这表明将混淆词元学习从指令级别扩展到响应级别是重要且有益的。 -
干扰模式的差异性: 论文推测,指令级别和响应级别的干扰模式可能有所不同,结合两者的学习可以进一步增强模型的推理能力。
-
分割段数的影响: 3-segment splits 的性能与 2-segment splits 相当,这表明在响应级别进行更细粒度的分割可能回报递减。论文假设响应级别的干扰模式具有一定的规律性,并且 2-segment splits 生成的数据已经足以让模型有效学习这些模式,额外的分割变得不必要。
结论: 响应级别的干扰模式对后续生成影响显著,并且
LeaF能够通过扩展剪枝策略到响应级别来有效缓解其负面影响,从而进一步提升模型性能。
6.4. 阈值敏感性分析 (Threshold Sensitivity Analysis)
论文对用于混淆词元剪枝的阈值 进行了敏感性分析,评估了 LLaMA3.2-1B-Instruct 和 LLaMA3.2-3B-Instruct 作为学生模型的性能。
下图(原文 Figure 9)展示了 MathBench 中指令级别的阈值敏感性分析:
该图像是图表,展示了不同混淆阈值下,LLaMA3.2模型在MathBench任务中平均性能的对比,涉及KD和LeaF两种方法及不同模型规模。
图 9: MathBench 中指令级别的阈值敏感性分析。
下图(原文 Figure 10)展示了 MathBench 中响应级别的阈值敏感性分析:
该图像是图表,展示了在 MathBench 基准测试中,不同阈值下多种模型(包括 LLaMA3.2-KD 和 LLaMA3.2-LeaF)的平均性能百分比。图中散点和虚线对比了两种方法在不同干扰阈值下的表现。
图 10: MathBench 中响应级别的阈值敏感性分析。
结果分析:
-
指令级别 (Instruct-level):
LLaMA3.2-LeaF-1B在阈值 0.10 时表现最佳。LLaMA3.2-LeaF-3B在阈值 0.05 时表现最佳。
-
响应级别 (Response-level):
LLaMA3.2-LeaF-1B在阈值 0.15 时表现最佳。LLaMA3.2-LeaF-3B在阈值 0.10 时表现最佳。
-
模型规模与阈值关系:
LLaMA3.2-LeaF-1B(较小模型) 在指令级别和响应级别都倾向于在更高的阈值下获得最佳性能,而LLaMA3.2-LeaF-3B(较大模型) 倾向于在较低或中等阈值下表现最佳。这表明较小的模型更容易受到混淆词元的影响,因此需要更高的阈值来更有效地过滤掉干扰性词元。 -
跨任务稳定性: 附录 中的详细跨领域结果进一步证实, 在不同任务中能产生稳定的性能。
以下是原文 Table 7 的结果:
Task Model τ=0.05 τ=0.10 τ=0.15 MathBench LLaMA-3.2-1B-Instruct 32.43 34.40 33.17 LLaMA-3.2-3B-Instruct 51.88 51.07 50.55 CodeBench LLaMA-3.2-1B-Instruct 17.85 18.27 19.76 LLaMA-3.2-3B-Instruct 32.83 33.55 32.95 SQuAD 2.0 LLaMA-3.2-1B-Instruct 72.12 73.33 74.78 LLaMA-3.2-3B-Instruct 88.67 89.44 83.66 -
Table 7进一步证实了 在数学、代码和问答任务中的稳定性,它通常能达到最佳或接近最佳的性能。结论: 阈值 的选择对模型性能有影响,且较小模型通常受益于更高的阈值。 是一个在各种任务中都能提供稳定性能的良好折衷点。
6.5. 可解释性案例研究 (Case Study in an Interpretable Perspective)
论文通过一个案例研究展示了 LeaF 相较于标准知识蒸馏 (KD) 如何使模型更好地聚焦于关键信息并避免混淆词元,从而提高可解释性。
下图(原文 Figure 11)展示了 LeaF 和 KD 在 MATH500 上的性能对比案例研究。顶部是 KD 和 LeaF 在指令词元上的注意力分数差异热图。深蓝色表示 KD 注意力更高,深粉色表示 LeaF 注意力更高。左侧是标准知识蒸馏后 LLaMA-3.2-3B-Instruct 的响应,蓝色文本标记了 KD 模型的错误步骤。右侧是 LLaMA-3.2-LeaF-3B 的响应,粉色文本高亮显示了本文模型的正确步骤。
该图像是论文中的示意图,展示了LeaF框架的两阶段流程:基于梯度与教师模型对比识别混淆token,再通过蒸馏阶段剪枝这些token,实现因果注意力的对齐。
图 11: LeaF 和 KD 在 MATH500 上的性能对比案例研究。顶部:热图显示 KD 和 LeaF 在指令词元上的注意力分数差异。深蓝色表示 KD 注意力更高,深粉色表示 LeaF 注意力更高。左侧:标准知识蒸馏后 LLaMA-3.2-3B-Instruct 的响应,蓝色文本标记了 KD 模型的错误步骤。右侧:LLaMA-3.2-LeaF-3B 的响应,粉色文本高亮显示了本文模型的正确步骤。
在给定“所有根都是实数”的条件下,找出 中 的最小可能值。
标准知识蒸馏模型 (LlaMa3.2-3b-Instruct (Distilled by KD)) 的响应:
- Step 1: 尝试应用
AM-GM 不等式到根上。 - Step 2: 利用 Vieta's formulas 得到 和 。
- Step 3: 代入
AM-GM不等式,得到 。因此 . - Step 4 & 5: 认为当 时等式成立,并计算出 。
- 错误分析:
KD模型在 Step 1 就走错了方向,它试图直接对根应用AM-GM 不等式。然而,AM-GM 不等式仅适用于非负实数。题目中并没有限定根是非负的,并且通过Vieta's formulas得到的 暗示了根可能是负数,因此直接应用AM-GM是错误的。这导致后续的推理链完全偏离,最终得出错误的答案。
LeaF 模型 (LIaMa3.2-3b-Instruct (Distilled by LeaF)) 的响应:
-
Step 1: 识别问题要求找出使所有根为实数的 的最小可能值。
-
Step 2: 将立方方程因式分解为 。这正确地识别出 是一个实根。
-
Step 3: 为了使所有根为实数,二次因子 的判别式必须是非负的。
-
Step 4: 设置判别式 。
-
Step 5: 解不等式 ,得到 或 ,即 或 。
-
Step 6: 结合题目中 是正实数的条件,得到 。因此, 的最小可能值为 3。
-
正确分析:
LeaF模型首先识别了 这个显而易见的实根,然后正确地将问题转化为二次方程的判别式条件。它遵循了一个连贯且正确的推理链,并最终考虑了 为正实数的约束,得到了正确答案。结论:
LeaF模型通过有效抑制对误导性信息的注意力(如AM-GM不等式在不适用情况下的诱导),使模型能够聚焦于关键的约束条件(所有根为实数)和正确的数学原理(因式分解和判别式),从而生成更准确、更可解释的推理过程。而KD模型则被虚假关联误导,未能正确应用数学工具。
6.6. 鲁棒性分析 (Robustness Analysis)
论文通过在扰动版本的 MathBench 基准测试(GSM8K、MATH-500 和 OlympiadBench)上评估 LeaF 来分析其测试时鲁棒性 (test-time robustness)。扰动通过回译 (back-translation) 生成,这是一种用于生成现实性释义变体的标准数据增强技术。
以下是原文 Table 5 的结果:
| Model | GSM8K | MATH-500 | OlympiadBench | Avg. |
| LLaMA 3.21B-Instruct | ||||
| Instruct (Pre-KD) | 39.65 | 24.80 | 3.86 | 22.77 |
| KD (no mask) | 50.42 | 31.80 | 5.19 | 29.14 |
| LeaF (Instr Mask) | 51.10 | 32.00 | 6.53 | 29.88 |
| LeaF (Instr & Resp Mask) | 51.71 | 34.20 | 6.97 | 30.96 |
| LLaMA 3.23B-Instruct | ||||
| Instruct (Pre-KD) | 70.43 | 40.80 | 9.64 | 40.29 |
| KD (no mask) | 74.00 | 48.20 | 14.09 | 45.43 |
| LeaF (Instr Mask) | 74.07 | 50.20 | 15.58 | 46.62 |
| LeaF (Instr & Resp Mask) | 76.12 | 51.20 | 19.88 | 49.07 |
结果分析:
-
Table 5显示,LeaF在噪声条件下始终优于标准KD。例如,在LLaMA3.2-3B-Instruct模型上,LeaF (Instr & Resp Mask)在平均准确率上达到 49.07%,高于KD (no mask)的 45.43%。 -
性能下降在 2-3% 范围内,这表明
LeaF的剪枝和归因机制对适度的语言扰动具有鲁棒性。 -
即使输入分布发生变化,
LeaF仍能有效保持推理能力。结论:
LeaF在面对输入扰动或分布噪声时表现出更强的鲁棒性,这对于实际应用场景至关重要。
6.7. 消融实验 (Ablation Study): 对比对的重要性
论文进行了一项消融实验,探究了排除学生模型因混淆词元而未能解决的原始样本 (student-wrong originals) 对性能的影响。LeaF 框架则同时利用了学生模型正确解决的原始样本、学生模型错误解决的原始样本以及反事实样本。
以下是原文 Table 6 的结果:
| Model | GSM8K | MATH-500 | OlympiadBench | Avg. |
| LLaMA 3.21B-Instruct | ||||
| LeaF | 57.70 | 35.40 | 10.09 | 34.40 |
| w/o Student-Wrong Originals | 58.15 | 34.80 | 7.42 | 33.46 |
| LLaMA 3.23B-Instruct | ||||
| LeaF | 83.09 | 51.80 | 20.77 | 51.88 |
| w/o Student-Wrong Originals | 84.08 | 47.80 | 16.02 | 49.30 |
结果分析:
-
LeaF优于消融变体: 在MATH-500和OlympiadBench上,LeaF性能优于排除了学生模型错误原始样本的变体 (w/o Student-Wrong Originals)。这表明保留包含误导性模式的原始样本与反事实样本一起,能提供更强的对比信号。这种对比帮助学生模型降低虚假关联的权重,并关注因果相关的真实信息。 -
排除错误样本的限制: 排除学生模型错误原始样本会限制模型接触误导性情况,从而限制了模型在混淆条件下泛化的能力。
-
GSM8K上的例外: 消融变体在GSM8K上略有性能提升,这归因于数据集的构成。GSM8K是一个相对简单的任务,过滤后其学生模型正确解决的样本在训练数据中占据更大比例。这种分布偏向于简单问题,导致在GSM8K上取得更高的准确率。结论: 包含学生模型错误解决的原始样本对于
LeaF学习识别和避免混淆信息至关重要,因为它提供了必要的对比学习信号。
6.8. 计算开销分析 (Computational Overhead Analysis)
论文报告了 LeaF 与标准知识蒸馏 (KD) 在相同硬件配置下的详细运行时和内存分析。
梯度计算 (Gradient Computation):
- 这是一个一次性的离线过程,在 8x NVIDIA A100 (80GB) GPU 上联合计算教师模型和学生模型的梯度。
- 处理约 7K 样本需要约 3 小时。
反事实生成 (Counterfactual Generation):
- 反事实响应是离线生成的,使用
vLLM在 4x NVIDIA A100 (80GB) GPU 上进行。 - 处理 26K 样本:
- 小型模型 (如
LLaMA-3.2-1B-Instruct) 需要约 50 分钟。 - 大型模型 (如
LLaMA3.3-70B-Instruct或Qwen2.5-72B-Instruct) 需要长达 2.85 小时。
- 小型模型 (如
- 此步骤可并行化,且只需执行一次。
训练开销 (Training Overhead): 论文测量了在 4x NVIDIA A100 (80GB) GPU 上进行 3 个 epoch 的端到端训练开销。
以下是原文 Table 4 的结果:
| Model | KD Runtime (h) | LeaF Runtime (h) | Overhead (%) |
| LLaMA-3.2-1B-Instruct | 26.2 | 29.2 | +11.5 |
| LLaMA-3.2-3B-Instruct | 23.7 | 26.2 | +10.6 |
| Qwen2.5-Math-1.5B | 27.3 | 30.9 | +13.3 |
结果分析:
-
Table 4显示,LeaF相较于标准KD引入了中等程度的训练开销,额外训练时间为 10-13%。 -
尽管有额外的计算开销,
LeaF在数学、代码和多跳问答基准测试上持续带来了 2-3% 的绝对准确率提升。 -
LeaF的辅助过程(包括梯度归一化、区间剪枝和反事实生成)都是一次性、离线且完全可并行化的,因此在训练或推理过程中不会产生额外成本。结论:
LeaF的额外计算成本是可接受的,并且其带来的性能提升证明了这种额外计算的有效利用。其离线和并行化的特性也保证了其在大规模应用中的实用性和可扩展性。
6.9. 数据集统计 (Dataset Statistics): 反事实样本和原始样本
论文报告了训练中使用的原始样本和反事实样本的词元长度统计和样本数量。
以下是原文 Table 8 的结果:
| Model | Min | Max | Avg. | Total | AIME | MATH | GSM8K | OlympiadBench |
| Original Samples | 23 | 2711 | 98.50 | 12000 | 3000 | 3000 | 3000 | 3000 |
| LLaMA-3.2-1B CF | 24 | 2710 | 160.78 | 4576 | 1620 | 1029 | 907 | 1020 |
| LLaMA-3.2-3B CF | 23 | 2711 | 148.79 | 2843 | 1228 | 709 | 172 | 734 |
以下是原文 Table 9 的结果:
| Model | Min | Max | Avg. | Total |
| Original Samples | 34 | 527 | 140.17 | 12000 |
| LLaMA-3.2-1B CF | 47 | 507 | 158.61 | 6210 |
| LLaMA-3.2-3B CF | 46 | 446 | 159.72 | 4059 |
结果分析:
-
Table 8和Table 9总结了主要实验中使用的训练数据的分布特性。 -
观察到反事实样本 (
CF) 的平均词元长度略高于原始样本,无论是在数学任务还是代码任务中。 -
这种模式可能因为更长、更复杂的问题往往包含更多对学生模型具有误导性的模式,从而增加了生成反事实样本的可能性。
结论: 反事实样本的生成与原始样本的复杂性和长度相关,反映了模型在处理复杂长文本时更容易遇到混淆信息。
6.10. 数据呈现 (表格)
本小节将完整呈现论文中涉及实验结果的所有表格。
以下是原文 Table 1 的结果:
| Model | MathBench | CodeBench | ||||||
| GSM8K MATH-500 | Olympiad- Bench | Avg. | Human- Eval+ | Leet- Code | Livecode- Bench | Avg. | ||
| Teacher Model | ||||||||
| LLaMA3.3-70B-Instruct | 95.60 | 70.40 | 36.50 | 67.50 | 78.05 | 53.90 | 45.02 | 58.99 |
| Qwen2.5-72B-Instruct | 95.45 | 73.80 | 41.25 | 70.17 | 81.71 | 69.40 | 54.42 | 68.51 |
| LLaMA3.2-1B-Instruct | ||||||||
| Instruct Model (Pre-KD) | 44.88 | 24.20 | 5.79 | 24.96 | 29.27 | 7.22 | 9.68 | 15.39 |
| KD w/o Mask | 56.79 | 33.40 | 8.90 | 33.03 | 32.32 | 6.11 | 13.74 | 17.39 |
| LeaF (Instr Mask) | 57.70 | 35.40 | 10.09 | 34.40 | 39.02 | 6.67 | 13.60 | 19.76 |
| LeaF (Instr & Resp Mask) | 58.98 | 35.20 | 9.94 | 34.71 | 39.63 | 7.22 | 12.48 | 19.77 |
| LLaMA3.2-3B-Instruct | ||||||||
| Instruct Model (Pre-KD) | 76.88 | 42.80 | 13.20 | 44.29 | 48.78 | 13.89 | 20.34 | 27.67 |
| KD w/o Mask | 82.87 | 49.00 | 18.99 | 50.29 | 54.88 | 16.67 | 24.12 | 31.89 |
| LeaF (Instr Mask) | 83.09 | 51.80 | 20.77 | 51.88 | 55.49 | 19.44 | 25.39 | 33.44 |
| LeaF (Instr & Resp Mask) | 84.69 | 52.40 | 22.55 | 53.21 | 56.10 | 21.67 | 25.81 | 34.53 |
| Qwen2.5-Math-1.5B | ||||||||
| Base Model (Pre-KD) | 65.20 | 41.40 | 21.96 | 42.85 | 35.37 | 6.67 | 1.26 | 14.43 |
| KD w/o Mask | 82.18 | 67.80 | 31.16 | 60.38 | 41.46 | 7.78 | 10.10 | 19.78 |
| LeaF (Instr Mask) | 84.69 | 68.60 | 32.79 | 62.03 | 42.68 | 9.94 | 10.80 | 20.97 |
| LeaF (Instr & Resp Mask) | 85.29 | 70.60 | 31.75 | 62.54 | 43.29 | 9.94 | 13.04 | 21.92 |
以下是原文 Table 2 的结果:
| Model | 2WikiMultiHopQA | Musique | HotpotQA | Avg. | ||||||
| EM | F1 | SubEM | EM | F1 | SubEM | EM | F1 | SubEM | ||
| LLaMA3.2-3B-Instruct | ||||||||||
| Instruct Model re-22.50 | 34.73 | 45.50 | 9.50 | 16.56 | 18.50 | 19.50 | 32.67 | 36.00 | 26.16 | |
| KD w/o Mask | 43.50 | 51.93 | 53.00 | 20.50 | 28.56 | 22.50 | 39.00 | 49.30 | 43.00 | 39.03 |
| LeaF (Instr Mask) | 46.50 | 53.89 | 55.00 | 2700 | 33.00 | 28.00 | 40.50 | 52.06 | 44.50 | 42.27 |
以下是原文 Table 3 的结果:
| Model | MATH-500 |
| LLaMA3.2-1B-Instruct | |
| Instruct Model (Pre-KD) | 24.20 |
| KD w/o Mask | 34.00 |
| Collective Pruning | 34.20 |
| Span Pruning | 37.40 |
| LLaMA3.2-3B-Instruct | |
| Instruct Model (Pre-KD) | 42.80 |
| KD w/o Mask | 50.00 |
| Collective Pruning | 49.20 |
| Span Pruning | 54.40 |
以下是原文 Table 4 的结果:
| Model | KD Runtime (h) | LeaF Runtime (h) | Overhead (%) |
| LLaMA-3.2-1B-Instruct | 26.2 | 29.2 | +11.5 |
| LLaMA-3.2-3B-Instruct | 23.7 | 26.2 | +10.6 |
| Qwen2.5-Math-1.5B | 27.3 | 30.9 | +13.3 |
以下是原文 Table 5 的结果:
| Model | GSM8K | MATH-500 | OlympiadBench | Avg. |
| LLaMA 3.21B-Instruct | ||||
| Instruct (Pre-KD) | 39.65 | 24.80 | 3.86 | 22.77 |
| KD (no mask) | 50.42 | 31.80 | 5.19 | 29.14 |
| LeaF (Instr Mask) | 51.10 | 32.00 | 6.53 | 29.88 |
| LeaF (Instr & Resp Mask) | 51.71 | 34.20 | 6.97 | 30.96 |
| LLaMA 3.23B-Instruct | ||||
| Instruct (Pre-KD) | 70.43 | 40.80 | 9.64 | 40.29 |
| KD (no mask) | 74.00 | 48.20 | 14.09 | 45.43 |
| LeaF (Instr Mask) | 74.07 | 50.20 | 15.58 | 46.62 |
| LeaF (Instr & Resp Mask) | 76.12 | 51.20 | 19.88 | 49.07 |
以下是原文 Table 6 的结果:
| Model | GSM8K | MATH-500 |
| LLaMA 3.21B-Instruct | ||
| LeaF | 57.70 | 35.40 |
| w/o Student-Wrong Originals | 58.15 | 34.80 |
| LLaMA 3.23B-Instruct | ||
| LeaF | 83.09 | 51.80 |
| w/o Student-Wrong Originals | 84.08 | 47.80 |
以下是原文 Table 7 的结果:
| Task | Model | τ=0.05 | τ=0.10 | τ=0.15 |
| MathBench | LLaMA-3.2-1B-Instruct | 32.43 | 34.40 | 33.17 |
| LLaMA-3.2-3B-Instruct | 51.88 | 51.07 | 50.55 | |
| CodeBench | LLaMA-3.2-1B-Instruct | 17.85 | 18.27 | 19.76 |
| LLaMA-3.2-3B-Instruct | 32.83 | 33.55 | 32.95 | |
| SQuAD 2.0 | LLaMA-3.2-1B-Instruct | 72.12 | 73.33 | 74.78 |
| LLaMA-3.2-3B-Instruct | 88.67 | 89.44 | 83.66 |
以下是原文 Table 8 的结果:
| Model | Min | Max | Avg. | Total | AIME | MATH | GSM8K | OlympiadBench |
| Original Samples | 23 | 2711 | 98.50 | 12000 | 3000 | 3000 | 3000 | 3000 |
| LLaMA-3.2-1B CF | 24 | 2710 | 160.78 | 4576 | 1620 | 1029 | 907 | 1020 |
| LLaMA-3.2-3B CF | 23 | 2711 | 148.79 | 2843 | 1228 | 709 | 172 | 734 |
以下是原文 Table 9 的结果:
| Model | Min | Max | Avg. | Total |
| Original Samples | 34 | 527 | 140.17 | 12000 |
| LLaMA-3.2-1B CF | 47 | 507 | 158.61 | 6210 |
| LLaMA-3.2-3B CF | 46 | 446 | 159.72 | 4059 |
以下是原文 Table 10 的结果:
| Model | Hyper-parameter | Value |
| LLaMA3.2-1B-Instruct | LR | 1 × 10−5 |
| LR Scheduler | cosine | |
| Batch Size | 64 | |
| Epochs | 3 | |
| Maximum Sequence Length | 4096 | |
| Warmup Steps | 5 | |
| Distill Loss Type | KL | |
| Validation Set Size (Math) | 1035 | |
| Validation Set Size (Code) | 2000 | |
| LLaMA3.2-3B-Instruct | LR | 1 × 10-5 |
| LR Scheduler | cosine | |
| Batch Size | 64 | |
| Epochs | 3 | |
| Maximum Sequence Length | 3000 | |
| Warmup Steps | 5 | |
| Distill Loss Type | KL | |
| Validation Set Size (Math) | 1035 | |
| Validation Set Size (Code) | 2000 | |
| Qwen2.5-Math-1.5B | LR | 1 × 10-5 |
| LR Scheduler | cosine | |
| Batch Size | 32 | |
| Epochs | 3 | |
| Maximum Sequence Length | 4096 | |
| Warmup Steps | 5 | |
| Distill Loss Type | KL | |
| Validation Set Size (Math) | 1200 | |
| Validation Set Size (Code) | 2000 |
7. 总结与思考
7.1. 结论总结
本论文提出了 Learning to Focus (LeaF) 框架,这是一种新颖的策略,旨在通过增强大型语言模型 (LLMs) 的上下文依从性来提高其推理能力和生成质量。LeaF 的核心在于利用因果分析和梯度引导的词元剪枝,有效地识别并消除了训练数据中引入虚假关联的混淆词元。通过两阶段的框架,即混淆词元检测和因果注意力蒸馏,LeaF 强制学生模型学习真实的因果依赖关系,并使其注意力与教师模型在真正关键上下文词元上的聚焦分布对齐。
实验结果强有力地证明了 LeaF 的有效性。相较于标准知识蒸馏,LeaF 在数学推理、代码生成和多跳问答等多个领域的基准测试中,取得了显著的绝对准确率提升。此外,案例研究和可视化结果还表明,LeaF 有效抑制了模型对混淆词元的注意力,使得模型在推理时更具可解释性和可靠性。论文还发现,将混淆词元剪枝的范围从指令级别扩展到响应级别,可以带来进一步的性能增益。尽管 LeaF 引入了适度的计算开销,但其辅助过程的离线和并行化特性,以及带来的显著性能提升,证明了其在实际应用中的实用性和可扩展性。
7.2. 局限性与未来工作
论文作者指出了以下局限性:
-
对高级教师模型的依赖性 (Dependence on an advanced teacher model):
LeaF方法的核心在于教师模型能够识别混淆词元。这意味着其性能在一定程度上受限于教师模型的知识和能力。如果教师模型本身就存在偏见或无法识别某些混淆模式,那么学生模型也可能无法完全摆脱这些问题。 -
对长文本生成的有限可扩展性 (Limited scalability to long-form generation): 目前,该方法主要在数学和代码生成任务上进行了验证。对于其他长文本生成任务(如故事创作、长篇摘要等),其适用性和效果仍有待进一步探究。长文本生成可能涉及更复杂的上下文依赖和更难以识别的混淆模式。
基于这些局限性,论文提出了以下未来研究方向:
-
不依赖高级模型实现模型自改进 (Self-improvement mechanisms without relying on advanced models): 探索如何通过模型自身的机制,使其能够自我完善其注意力,识别并修正对关键词元的关注,从而提升推理能力,是一个有前景的方向。这可能涉及无监督或自监督的学习范式。
-
将方法扩展到长文本生成和其他领域 (Applicability to long-text generation and other domains): 研究
LeaF框架在更广泛的任务和领域中的有效性,特别是在长文本生成场景下,如何有效地识别和处理混淆词元,是一个重要的挑战。
7.3. 个人启发与批判
7.3.1. 个人启发
这篇论文提供了一些深刻的启发:
- 因果推断在LLM解释性和鲁棒性上的潜力: 将
LLM的“注意力不集中”问题提升到“因果混淆”的高度,并利用因果干预的思想来解决,这种理论视角非常新颖且有力。它超越了简单的性能提升,深入到模型决策的“原因”层面,为LLM的可解释性研究开辟了新的路径。 - 梯度敏感性作为诊断工具的价值: 教师模型和学生模型之间的梯度敏感性差异,提供了一个直观且量化的方式来诊断学生模型被误导的区域。这种方法比单纯依赖困惑度或模型内部的注意力权重更能揭示模型关注点的“质量”差异。这对于理解为什么小模型会犯错,以及如何修正这些错误具有重要指导意义。
- 对比学习在蒸馏中的精妙应用:
LeaF提出的反事实蒸馏损失,通过在“有混淆”和“无混淆”两种情境下进行学习,本质上是一种对比学习。它巧妙地利用了教师模型在两种情境下的“正确”行为,引导学生模型区分哪些是真正的因果信号,哪些是虚假关联。这种“学会忽略”的能力,对于构建更鲁棒的LLM至关重要。 - 全面干预的重要性: 发现指令级别和响应级别都存在干扰模式,并设计了相应的剪枝策略,这表明在整个生成过程中持续关注和纠正模型注意力是必要的。这提示我们在设计
LLM优化方法时,需要考虑更全面的上下文影响。
7.3.2. 批判
尽管 LeaF 取得了显著进展,但仍有一些潜在问题和可以改进的地方:
-
教师模型的“金标准”问题:
LeaF性能的高度依赖于教师模型的“智慧”。如果教师模型本身就存在对某些模式的虚假关联或理解偏差,那么这些缺陷可能会被传递给学生模型,甚至可能导致新的偏差。教师模型是否总是能提供“真实的因果关系”,这是一个值得探讨的前提。 -
“混淆词元”的定义和普适性: 论文将混淆词元定义为学生模型敏感但教师模型不敏感且移除后都能正确预测的词元。这个定义在数学和代码任务中可能较为明确,但在更开放、更主观的文本生成任务中,如何准确界定“混淆词元”可能会面临挑战。例如,在创意写作中,一些看似冗余的描述可能并非混淆,而是为了风格或情绪渲染。
-
梯度计算的计算成本: 虽然论文提到梯度计算是离线且可并行化的,但在处理超大规模数据集和超长上下文时,为每个词元计算梯度并进行归一化和比较,其计算和存储开销仍然不容忽视,尤其对于资源受限的研究者而言。
-
阈值 的敏感性: 尽管论文进行了敏感性分析并给出了一个相对稳定的值 (0.10),但这个阈值依然需要手动调优。不同任务、不同模型或不同数据分布下,最佳阈值可能有所不同,这增加了方法的复杂性。
-
剪枝对文本流的影响: 区间剪枝虽然优于集体剪枝,但移除输入序列中的连续词元仍然可能在一定程度上破坏文本的自然流和句法结构,这是否会对模型理解剩余上下文造成微弱但持续的负面影响?论文并未深入探讨。
-
解释性评估的局限: 虽然案例研究提供了直观的可解释性证据,但对“注意力聚焦”和“因果依赖”的量化评估仍然是挑战。目前的方法主要通过下游任务性能间接验证,未来可以探索更直接、更全面的可解释性评估指标。
总的来说,
LeaF提供了一个强大而富有洞察力的方法来解决LLM的注意力聚焦问题。其因果推断的视角和梯度引导的干预机制为LLM的鲁棒性和可解释性研究树立了新的标杆。未来的工作可以在减少教师模型依赖、提升普适性以及优化计算效率方面进一步探索。
相似论文推荐
基于向量语义检索推荐的相关论文。