Mitigating Forgetting in LLM Supervised Fine-Tuning and Preference Learning
TL;DR 精炼摘要
本文针对大语言模型后训练中监督微调与偏好学习顺序执行导致的灾难性遗忘问题,理论证明了其次优性,提出了联合后训练框架及算法,实现两阶段平衡且训练效率无损,显著提升模型的综合性能和安全性。
摘要
Post-training of pre-trained LLMs, which typically consists of the supervised fine-tuning (SFT) stage and the preference learning (RLHF or DPO) stage, is crucial to effective and safe LLM applications. The widely adopted approach in post-training popular open-source LLMs is to sequentially perform SFT and RLHF/DPO. However, sequential training is sub-optimal in terms of SFT and RLHF/DPO trade-off: the LLM gradually forgets about the first stage's training when undergoing the second stage's training. We theoretically prove the sub-optimality of sequential post-training. Furthermore, we propose a practical joint post-training framework with theoretical convergence guarantees and empirically outperforms sequential post-training framework, while having similar computational cost. Our code is available at https://github.com/heshandevaka/XRIGHT.
思维导图
论文精读
中文精读
1. 论文基本信息 (Bibliographic Information)
- 标题 (Title): Mitigating Forgetting in LLM Supervised Fine-Tuning and Preference Learning (在 LLM 监督微调和偏好学习中缓解遗忘问题)
- 作者 (Authors): Heshan Fernando, Han Shen, Parikshit Ram, Yi Zhou, Horst Samulowitz, Nathalie Baracaldo, Tianyi Chen. 作者分别来自伦斯勒理工学院 (Rensselaer Polytechnic Institute) 和 IBM 研究院 (IBM Research)。
- 发表期刊/会议 (Journal/Conference): 本文是一篇提交到 arXiv 的预印本论文,版本为 v3。arXiv 是一个公开的学术论文预印本平台,在该领域的论文通常会先发布于此,以快速分享研究成果。
- 发表年份 (Publication Year): 2024
- 摘要 (Abstract): 预训练大语言模型(LLM)的后训练(Post-training)对于实现有效和安全的 LLM 应用至关重要,该过程通常包括监督微调(SFT)和偏好学习(RLHF 或 DPO)两个阶段。目前在主流开源 LLM 的后训练中,广泛采用的方法是顺序执行 SFT 和 RLHF/DPO。然而,顺序训练在 SFT 和 RLHF/DPO 的权衡方面是次优的:当 LLM 进入第二阶段训练时,会逐渐忘记第一阶段的训练内容。我们从理论上证明了顺序后训练的次优性。此外,我们提出了一个实用的联合后训练框架,该框架具有理论上的收敛保证,并且在实验中表现优于顺序后训练框架,同时计算成本相近。
- 原文链接 (Source Link):
- ArXiv 链接: https://arxiv.org/abs/2410.15483v3
- PDF 链接: https://arxiv.org/pdf/2410.15483v3.pdf
- 发布状态:预印本 (Preprint)
2. 整体概括 (Executive Summary)
-
研究背景与动机 (Background & Motivation - Why):
- 核心问题: 大语言模型(LLM)的后训练通常包含两个关键阶段:监督微调 (Supervised Fine-Tuning, SFT),用于教会模型遵循指令;以及偏好学习 (Preference Learning),如
DPO或RLHF,用于使模型行为与人类偏好对齐。目前行业标准做法是顺序训练 (Sequential Training),即先完成一个阶段再开始另一个。这种方法存在一个严重的缺陷:灾难性遗忘 (Catastrophic Forgetting)。当模型在第二阶段(如 SFT)进行优化时,它会逐渐丧失在第一阶段(如 DPO)学到的能力,导致最终模型在两个目标之间无法取得良好的平衡。 - 问题重要性: SFT 和偏好学习对于构建有用且安全的 LLM 都至关重要。如果模型在微调后忘记了偏好对齐,它可能会重新生成有害或无用的内容;反之,如果过度对齐而忘记了通用指令遵循能力,其应用范围会受限。因此,如何高效地平衡这两个目标是一个关键挑战。
- 现有挑战与空白: 一个简单的想法是同时优化两个目标的混合损失 (Mixed Loss),但这种“朴素混合” (
naive mixing) 方法在实践中计算成本极高,因为它需要为每个目标构建独立的计算图,对于 LLM 这样的大模型来说,内存和时间开销是无法接受的。 - 切入点与创新思路: 本文的核心问题是:我们能否设计一个后训练框架,它既能比顺序训练取得更好的性能权衡,又比朴素混合方法的计算成本低得多?
- 核心问题: 大语言模型(LLM)的后训练通常包含两个关键阶段:监督微调 (Supervised Fine-Tuning, SFT),用于教会模型遵循指令;以及偏好学习 (Preference Learning),如
-
核心贡献/主要发现 (Main Contribution/Findings - What):
- C1) 理论上揭示顺序训练的“遗忘”问题: 首次从理论上证明了顺序执行 DPO 和 SFT 训练是次优的,并指出其优化差距 (
optimality gap) 不会随训练时间的增加而消失(即存在一个恒定的性能下限)。这为“遗忘”问题提供了坚实的数学依据。 - C2) 提出两种高效的联合训练方法:
- ALRIGHT (ALternating supeRvised fInetuninG and Human preference alignmenT): 一种交替优化算法,在每次迭代中随机选择优化 DPO 或 SFT 目标。该方法有理论收敛保证,可以达到任意期望的性能权衡。
- MAXRIGHT (MAXimum supeRvised fIne-tuninG and Human preference alignmenT): 一种更智能的交替优化算法,在每次迭代中自适应地选择当前表现较差的目标进行优化。该方法在实践中能更快地收敛到理想的性能点。
- 关键优势: 这两种方法都实现了与顺序训练相近的低计算成本,远低于朴素混合法。
- C3) 强大的实证效果: 在
LLAMA3-8B等模型上的实验表明,与顺序训练相比,新方法在 MMLU 基准测试中最高提升 3%,在 RLHF 偏好评估中胜率(Win rate)最高提升 31%,且几乎没有额外的计算开销。
- C1) 理论上揭示顺序训练的“遗忘”问题: 首次从理论上证明了顺序执行 DPO 和 SFT 训练是次优的,并指出其优化差距 (
3. 预备知识与相关工作 (Prerequisite Knowledge & Related Work)
-
基础概念 (Foundational Concepts):
- 监督微调 (Supervised Fine-Tuning, SFT): 这是训练 LLM 遵循指令的第一步。研究人员会收集大量高质量的“指令-回答”对,然后像训练普通语言模型一样,使用这些数据对预训练好的 LLM 进行微调。其目标是最大化模型生成正确回答的对数概率,这通常通过一个负对数似然损失 (Negative Log-Likelihood Loss) 来实现。
- 偏好学习 (Preference Learning): SFT 之后,为了让模型的回答更符合人类的偏好(例如,更有用、更无害、更诚实),需要进行偏好学习。这通常需要包含人类偏好信息的数据集,例如,对于同一个问题,给出两个回答,并标注哪一个“更好”(chosen),哪一个“更差”(rejected)。
- 从人类反馈中进行强化学习 (Reinforcement Learning from Human Feedback, RLHF): 经典的三阶段偏好学习方法:1) 训练一个 SFT 模型;2) 用偏好数据训练一个奖励模型 (Reward Model),让它学会给回答打分;3) 使用强化学习算法(如
PPO)来微调 SFT 模型,使其生成的回答能在奖励模型上获得高分。这个过程复杂且不稳定。 - 直接偏好优化 (Direct Preference Optimization, DPO): 一种更简单、更稳定的偏好学习方法。它跳过了奖励建模和强化学习的复杂步骤,直接使用偏好数据(更好/更差的回答对)来优化 LLM。其核心思想是通过一个简单的二元交叉熵损失 (Binary Cross-entropy Loss),直接增加模型生成“更好”回答的概率,同时降低生成“更差”回答的概率。DPO 的损失函数中包含一个
参考模型(),用于约束模型不要偏离其初始能力太远,以及一个温度系数(),用于控制对齐的强度。 - 帕累托最优 (Pareto Optimality): 在多目标优化问题中,一个解如果不能在不损害任何一个目标的情况下改善其他任何一个目标,那么这个解就是帕累托最优的。所有帕累托最优解的集合构成了帕累托前沿 (Pareto Front)。在本文中,这意味着在 SFT 和 DPO 两个目标之间找到一系列“最佳”的权衡点。
-
前人工作 (Previous Works):
- 顺序训练 (Sequential Training): 这是当前开源社区(如
LLAMA-3,PHI-3)普遍采用的后训练流程,即先进行 SFT,再进行 DPO/RLHF,或者反之。本文的主要批判对象就是这种方法。 - 问题观察: 已有研究(如
Qi et al., 2023)观察到,在对齐后继续进行 SFT 会损害对齐效果,证实了“遗忘”现象的存在。 - 朴素混合 (Naive Mixing): 将 SFT 和 DPO 的损失函数加权求和,然后进行优化。这种方法虽然直观,但因需要同时处理两种不同格式的数据和计算图,导致计算成本(特别是内存)急剧增加,不适用于大型模型。
- 自适应模型平均 (Adaptive Model Averaging, AMA): (Lin et al., 2023) 分别优化 SFT 和 RLHF 模型,然后进行模型参数的加权平均。这种方法计算成本高昂,需要维护三组模型参数。
- 专用数据方法: (Yang et al., 2024; Guo et al., 2024) 通过在提示中编码不同的偏好来修改训练数据,以控制 SFT 和 RLHF 的权衡。
- 顺序训练 (Sequential Training): 这是当前开源社区(如
-
差异化分析 (Differentiation): 与上述工作相比,本文提出的
ALRIGHT和MAXRIGHT方法的核心创新在于:- 高效性: 它们实现了与朴素混合法相当的性能权衡,但计算成本却与低效的顺序训练法持平。
- 通用性: 它们不需要像某些方法那样对数据集进行特殊构造或修改,可以直接应用于标准的 SFT 和 DPO 数据集。
- 理论支撑: 不仅提出了算法,还从理论上证明了顺序训练的缺陷和
ALRIGHT算法的收敛性,使方法更具说服力。
4. 方法论 (Methodology - Core Technology & Implementation Details)
本部分详细拆解论文中关于顺序训练的分析以及提出的两种新方法。
-
问题形式化 (Problem Formulation):
-
DPO 目标函数: 设模型参数为 ,DPO 的目标是最小化以下损失函数 : 符号解释:
- : 一条偏好数据,其中 是输入提示, 是被偏好的回答 (winner), 是被拒绝的回答 (loser)。
- : 当前模型 生成回答 的概率。
- : 一个固定的参考模型生成回答 的概率,用于正则化。
- : 温度系数,控制对齐强度。
- : Sigmoid 函数。
- 该公式的直观含义是:让模型 相对于参考模型 更倾向于生成 而不是 。
-
SFT 目标函数: SFT 的目标是最小化标准的负对数似然损失 : 符号解释:
(x, y): 一条指令数据,其中 是输入指令, 是期望的输出。- 该公式的含义是:最大化模型生成正确回答 的概率。
-
性能权衡指标 (Performance Trade-off Metric): 为了评估两种目标的平衡,论文定义了一个混合目标 : 其中 是一个权衡系数。 越大,表示越重视 DPO 性能。一个好的算法应该能够针对不同的 找到使该混合目标最小的解。
-
-
顺序训练的次优性分析 (Sub-optimality of Sequential Training):
-
算法流程: 如 Algorithm 1 所示,该方法分为两个阶段:先用 DPO 数据集训练 步,得到模型 ;然后将 作为初始模型,再用 SFT 数据集训练 步。
-
理论证明 (Theorem 3.3): 论文通过构建一个特定的数据分布例子,从理论上证明了对于任何权衡系数 ,当训练步数 足够大时,顺序训练得到的模型 与理想最优解之间的差距是一个常数,即: 这从数学上表明,顺序训练由于“遗忘”效应,永远无法收敛到真正的帕累托最优解,其性能存在一个无法通过增加训练时长来消除的硬伤。
该图像是论文中图5的示意图,展示了定理3.3下界推导所用的例子。左图是函数、及其线性组合函数的曲线,右图展示了随训练轮次变化的误差。
-
-
方法一: ALRIGHT (交替优化)
- 核心思想: 与其在两个阶段中分别只关注一个目标,不如在每次迭代中随机选择一个目标进行优化。这样,两个目标都能在整个训练过程中被持续关注,从而缓解遗忘。
- 算法流程 (Algorithm 2):
- 在每次训练迭代 开始时,从一个伯努利分布 中采样。
- 如果 (以概率 发生),则从 DPO 数据集中采样,并执行一步 DPO 梯度更新。
- 如果 (以概率 发生),则从 SFT 数据集中采样,并执行一步 SFT 梯度更新。
- 理论保证 (Theorem 4.1): 论文证明了
ALRIGHT算法的优化差距会随着训练步数 的增加而趋近于 0: 这个结果表明,与顺序训练不同,ALRIGHT可以收敛到任意期望的权衡点,从根本上解决了次优性问题。
-
方法二: MAXRIGHT (自适应最大化优化)
- 核心思想: 随机交替可能不是最高效的方式。一个更智能的策略是优先优化当前表现更差的目标。
MAXRIGHT正是基于这一思想,它自适应地选择能最大程度降低整体“加权次优性”的目标进行更新。 - 算法流程 (Algorithm 3):
- 在每次迭代 ,分别计算 DPO 和 SFT 目标的“加权次优性”:
- 其中 和 是各自目标的理论最优值(实践中需要预先估计)。
- 比较两者大小,如果 ,则说明 DPO 目标的“相对表现”更差,执行一步 DPO 更新。
- 否则,执行一步 SFT 更新。
- 在每次迭代 ,分别计算 DPO 和 SFT 目标的“加权次优性”:
- 实践中的优化: 每次迭代都计算两个损失会增加开销。论文提出了一种内存高效的变体:每隔 步才同时评估两个损失,在中间的
k-1步,则基于“过时”的损失值来选择更新目标,并只更新被选中的那个损失的“过时值”。这在保持自适应性的同时,将计算成本降至与ALRIGHT相当。
- 核心思想: 随机交替可能不是最高效的方式。一个更智能的策略是优先优化当前表现更差的目标。
5. 实验设置 (Experimental Setup)
- 数据集 (Datasets):
- DPO 数据集:
DAHOAs/RM-HH-RLHF,这是一个包含人类反馈的数据集,旨在将模型与人类偏好对齐。 - SFT 数据集:
VICGALLE/ALPACA-GPT4,这是一个由 GPT-4 生成的英语指令遵循数据集,用于微调 LLM 的指令能力。
- DPO 数据集:
- 模型 (Models):
PYTHIA-1B: 一个中等规模的模型,用于分析优化动态、性能权衡和资源使用情况。LLAMA3-8B: 一个更大、更先进的模型,用于评估在真实下游任务上的表现。- 微调技术: 所有实验均采用
LoRA(Low-Rank Adaptation) 进行参数高效微调,以降低计算资源需求。
- 评估指标 (Evaluation Metrics):
- DPO 最优性差距 (DPO Optimality Gap):
- 概念定义: 该指标衡量当前模型在 DPO 任务上的损失值与理论上的最优 DPO 损失值之间的差距。它直接反映了模型在偏好对齐方面的性能,值越小表示对齐得越好。
- 数学公式:
- 符号解释: 是当前模型 的 DPO 损失; 是通过单独、长时间优化 DPO 目标得到的近似最优值。
- SFT 最优性差距 (SFT Optimality Gap):
- 概念定义: 衡量当前模型在 SFT 任务上的损失值与理论最优 SFT 损失值之间的差距。它反映了模型遵循指令的能力,值越小表示能力越强。
- 数学公式:
- 符号解释: 是当前模型 的 SFT 损失; 是通过单独、长时间优化 SFT 目标得到的近似最优值。
- 理想距离 (Ideal Distance):
- 概念定义: 该指标衡量在二维损失空间中,当前模型的性能点
(SFT 损失, DPO 损失)与两个目标都达到最优的“理想点”(SFT 最优损失, DPO 最优损失)之间的欧几里得距离。这是一个综合性指标,距离越小,说明模型在两个目标上的综合表现越接近完美。 - 数学公式:
- 符号解释: 各符号含义同上。
- 概念定义: 该指标衡量在二维损失空间中,当前模型的性能点
- MMLU (1-shot):
- 概念定义: 一个大规模多任务语言理解基准,包含 57 个不同学科(如数学、历史、法律、计算机科学等)的选择题。它用于评估模型的知识广度和问题解决能力。
1-shot表示在测试时,会给模型一个包含正确答案的例子作为提示。
- 概念定义: 一个大规模多任务语言理解基准,包含 57 个不同学科(如数学、历史、法律、计算机科学等)的选择题。它用于评估模型的知识广度和问题解决能力。
- 胜率 (Win rate):
- 概念定义: 这是一种基于模型间直接比较的评估方法。将本文方法训练出的模型与基线模型生成的回答放在一起,让一个强大的“裁判”模型(如
GPT-4-TURBO)判断哪个更好。胜率即本文模型被判为“胜利”的次数占总比较次数的比例。它直接衡量了模型回答是否更符合人类(或强大 AI)的偏好。
- 概念定义: 这是一种基于模型间直接比较的评估方法。将本文方法训练出的模型与基线模型生成的回答放在一起,让一个强大的“裁判”模型(如
- DPO 最优性差距 (DPO Optimality Gap):
- 对比基线 (Baselines):
- Sequential (顺序训练): 先 DPO 后 SFT,这是本文要批判和超越的主要基线。
- Mix (朴素混合): 同时优化 DPO 和 SFT 的加权损失。它被用作性能的“上限”参考,但计算成本高昂。
6. 实验结果与分析

*该图像是论文中用于比较不同后训练策略(Sequential、ALRIGHT、Mix、MAXRIGHT)在DPO和SFT最优性差距、理想距离、运行时间和GPU利用率提升方面表现的图表。左侧曲线图展示了与关系;右侧条形图对比各指标。*
-
核心结果分析 (Core Results Analysis):
ALRIGHT和MAXRIGHT实现了更好的性能权衡: 从上图左侧的优化轨迹可以看出,Sequential方法(绿色曲线)在 SFT 阶段会严重“忘记”DPO 目标,导致最终模型(绿色十字星)偏离 DPO 最优点很远。相比之下,ALRIGHT(红色星)和Mix(紫色星)能够形成一条清晰的帕累托前沿,表明它们能在两个目标间取得良好的权衡。MAXRIGHT表现最接近理想点:MAXRIGHT(蓝色星)的最终落点在所有方法中最靠近左下角的“理想点”(黑色星),这在上图右侧的Ideal Distance条形图中也得到了证实。这说明其自适应策略能更有效地平衡两个目标。ALRIGHT和MAXRIGHT的计算成本极低: 上图右下角的资源使用图显示,ALRIGHT和MAXRIGHT的运行时间和 GPU 使用率相比Sequential几乎没有增加(甚至有时为负,可能是因为实现细节的差异)。而Mix方法的运行时间增加了超过 50%,GPU 使用率增加了超过 35%,成本极高。这证明了本文方法在实现“更好权衡”的同时做到了“低成本”。
-
消融实验/参数分析 (Ablation Studies / Parameter Analysis):
该图像是论文中图4,展示了使用Pythia-1b模型时不同最大评估步骤(Max Eval. Steps)对MAXRIGHT性能的影响。左侧四个子图分别对应1、10、100和1000步,横纵轴为和,星形标记不同值。右侧柱状图分别比较了DPO和SFT的最优差距、理想距离、运行时间和GPU利用率的变化。MAXRIGHT的评估步数 对性能的影响: 上图展示了内存高效MAXRIGHT中,“最大评估步数” 的选择至关重要。- 当 时(每次都评估),算法轨迹平滑,能很好地逼近理想点,但运行时间成本最高。
- 当 增大时(如 10 或 100),成本降低,但轨迹开始出现小幅振荡。
- 当 过大时(如 1000),算法几乎退化为随机或分块更新,轨迹剧烈振荡,完全偏离了理想的权衡路径,理想距离也变得很差。这说明 是一个需要在效率和性能之间进行权衡的关键超参数。
-
下游任务性能 (LLAMA3-8B):
以下是论文中 Table 1 的转录结果,展示了在
LLAMA3-8B模型上的真实任务表现。| | \multicolumn{3}{c}{MMLU (1-shot) (%)} | \multicolumn{3}{c}{Win rate (%)} | :---------- | :--- | :--- | :--- | :--- | :--- | :--- | λ/(T_SFT, T_DPO) | 0.25/(3,1) | 0.5/(2, 2) | 0.75/(1, 3) | 0.25/(3, 1) | 0.5/(2,2) | 0.75/(1, 3) | Sequential | 73.18 | 72.80 | 72.68 | 57.19 | 65.62 | 59.38 | Mix | 73.45 | 73.40 | 72.29 | 81.88 | 84.22 | 88.42 | ALRIGHT | 74.66 | 72.65 | 75.50 | 88.28 | 85.78 | 87.34 | MAXRIGHT | 72.35 | 73.42 | 74.24 | 86.56 | 86.09 | 83.75
- 分析: 从表格数据可以看出,
ALRIGHT和MAXRIGHT在 MMLU 知识问答和 Win rate 偏好对齐两个维度的评估中,几乎全面优于Sequential基线。例如,当 SFT 权重较大时(λ=0.25),ALRIGHT的 MMLU 得分比Sequential高出 1.48%;当 DPO 权重较大时(λ=0.75),高出 2.82%。在胜率方面,提升更为显著,ALRIGHT和MAXRIGHT的胜率普遍在 80% 以上,远高于Sequential的 50-60% 区间。这强有力地证明了新方法在实际应用中的优越性。
- 分析: 从表格数据可以看出,
7. 总结与思考 (Conclusion & Personal Thoughts)
-
结论总结 (Conclusion Summary): 本文有力地论证了在 LLM 后训练中广泛采用的顺序训练范式存在根本性的“遗忘”缺陷,并从理论和实践上证明了其次优性。作为替代方案,论文提出了两种创新、高效且低成本的联合训练方法:
ALRIGHT和MAXRIGHT。ALRIGHT通过随机交替优化,保证了理论上的收敛性;而MAXRIGHT通过自适应地选择优化目标,在实践中取得了更接近理想的性能。实验结果表明,这两种方法不仅在 SFT 和 DPO 目标上取得了远优于顺序训练的平衡,而且在下游任务中也表现出显著的性能提升,最重要的是,这一切几乎没有增加额外的计算成本。 -
局限性与未来工作 (Limitations & Future Work):
- 最优值估计:
MAXRIGHT方法依赖于对 和 的预先估计,这本身需要额外的计算步骤,并且估计的准确性可能会影响算法的性能。 - 理论模型简化: 论文的理论分析基于一个简化的
softmax模型。虽然实验结果验证了其结论,但该理论是否能完全捕捉真实 Transformer 模型的复杂动态仍有待进一步研究。 - 算法适用范围: 本文主要关注 DPO,但这些联合训练的思想是否能直接推广到更复杂的偏好学习算法,如经典的 RLHF(使用 PPO),是一个值得探索的方向。
- 多目标扩展: 实际应用中可能需要平衡更多目标,例如 SFT、偏好对齐、安全性、知识真实性等。如何将
ALRIGHT和MAXRIGHT的思想扩展到三个或更多目标的多目标优化场景,是一个有价值的未来研究方向。
- 最优值估计:
-
个人启发与批判 (Personal Insights & Critique):
- 实用价值极高: 这篇论文解决了一个非常普遍且痛点的工程问题。许多团队在实践中都可能遇到“微调后模型变笨”或“对齐后模型能力下降”的困扰。论文提出的方法简单、有效且易于实现,为解决这一问题提供了“即插即用”的方案。
- 思想的启发性:
ALRIGHT和MAXRIGHT的核心思想——用“交替”代替“混合”来降低多目标优化的成本——非常巧妙。特别是MAXRIGHT的自适应机制,体现了从“盲目轮换”到“智能调度”的进步,这种思想可以迁移到其他多任务学习或持续学习的场景中。 - 对现有范式的挑战: 本文的工作有力地挑战了“分阶段、顺序式”的 LLM 训练思维定式,倡导一种更整体、更动态的联合优化视角。这提醒我们,在复杂的 AI 系统工程中,孤立地优化各个模块可能会导致全局次优,系统性的联合设计至关重要。
- 潜在的改进方向:
MAXRIGHT中对 值的选择目前似乎依赖经验,未来或许可以设计一种自适应调整 值的策略,使其在训练的不同阶段动态平衡效率与性能。此外,选择优化目标的标准也可以从“损失值”扩展到“梯度范数”或“梯度相似度”等更丰富的指标。
相似论文推荐
基于向量语义检索推荐的相关论文。