PEARL: Towards Permutation-Resilient LLMs
TL;DR 精炼摘要
本文提出排列弹性学习(PEARL),通过分布鲁棒优化和排列提议网络结合最优传输算法,针对最坏排列情况提升大型语言模型的鲁棒性。PEARL有效抵御排列攻击,并在多样上下文场景下显著增强模型性能和泛化能力。
摘要
The in-context learning (ICL) capability of large language models (LLMs) enables them to perform challenging tasks using provided demonstrations. However, ICL is highly sensitive to the ordering of demonstrations, leading to instability in predictions. This paper shows that this vulnerability can be exploited to design a natural attack - difficult for model providers to detect
- that achieves nearly 80% success rate on LLaMA-3 by simply permuting the demonstrations. Existing mitigation methods primarily rely on post-processing and fail to enhance the model's inherent robustness to input permutations, raising concerns about safety and reliability of LLMs. To address this issue, we propose Permutation-resilient learning (PEARL), a novel framework based on distributionally robust optimization (DRO), which optimizes model performance against the worst-case input permutation. Specifically, PEARL consists of a permutation-proposal network (P-Net) and the LLM. The P-Net generates the most challenging permutations by treating it as an optimal transport problem, which is solved using an entropy-constrained Sinkhorn algorithm. Through minimax optimization, the P-Net and the LLM iteratively optimize against each other, progressively improving the LLM's robustness. Experiments on synthetic pre-training and real-world instruction tuning tasks demonstrate that PEARL effectively mitigates permutation attacks and enhances performance. Notably, despite being trained on fewer shots and shorter contexts, PEARL achieves performance gains of up to 40% when scaled to many-shot and long-context scenarios, highlighting its efficiency and generalization capabilities.
思维导图
论文精读
中文精读
1. 论文基本信息
1.1. 标题
PEARL: Towards Permutation-Resilient LLMs
1.2. 作者
Liang Chen, Li Shen, Yang Deng, Xiaoyan Zhao, Bin Liang, Kam-Fai Wong。 作者主要来自香港中文大学 (The Chinese University of Hong Kong)、中山大学深圳校区 (Shenzhen Campus of Sun Yat-sen University) 和新加坡管理大学 (SMU)。
1.3. 发表期刊/会议
该论文发布于 arXiv 预印本平台,尚未明确说明最终发表的期刊或会议。arXiv 是一个开放获取的预印本服务器,允许研究者在正式同行评审前分享其研究成果。
1.4. 发表年份
2025年2月20日 (UTC)
1.5. 摘要
大型语言模型 (LLMs) 的上下文学习 (In-Context Learning, ICL) 能力使其能够利用提供的示例来执行复杂的任务。然而,ICL 对示例的排列顺序高度敏感,导致预测不稳定。本文展示了这种脆弱性可以被利用来设计一种难以被模型提供商检测到的自然攻击——通过简单地重新排列示例,在 LLaMA-3 上实现了近 80% 的成功率。现有的缓解方法主要依赖于后处理,未能增强模型对输入排列的固有鲁棒性,从而引发了对 LLMs 安全性和可靠性的担忧。为了解决这个问题,我们提出了排列弹性学习 (Permutation-resilient learning, PEARL),这是一种基于分布鲁棒优化 (Distributionally Robust Optimization, DRO) 的新颖框架,它针对最坏情况的输入排列优化模型性能。具体来说,PEARL 由一个排列提议网络 (Permutation-Proposal Network, P-Net) 和 LLM 组成。P-Net 通过将生成最具挑战性的排列视为一个最优传输 (Optimal Transport, OT) 问题来解决,该问题使用熵约束 Sinkhorn 算法 (entropy-constrained Sinkhorn algorithm) 求解。通过最小最大优化 (minimax optimization),P-Net 和 LLM 迭代地相互优化,逐步提高 LLM 的鲁棒性。在合成预训练和真实世界指令微调任务上的实验表明,PEARL 有效地缓解了排列攻击并增强了性能。值得注意的是,尽管在较少示例 (fewer shots) 和较短上下文 (shorter contexts) 下进行训练,PEARL 在扩展到多示例 (many-shot) 和长上下文 (long-context) 场景时,性能增益高达 40%,这突出了其效率和泛化能力。
1.6. 原文链接
原文链接: https://arxiv.org/abs/2502.14628 PDF 链接: https://arxiv.org/pdf/2502.14628v1.pdf 发布状态: 预印本 (Preprint)
2. 整体概括
2.1. 研究背景与动机
2.1.1. 论文试图解决的核心问题是什么?
论文试图解决的核心问题是大型语言模型 (LLMs) 在上下文学习 (In-Context Learning, ICL) 中对输入示例排列顺序的敏感性,这种敏感性导致模型预测的不稳定性,并可能被恶意利用进行攻击。
2.1.2. 为什么这个问题在当前领域是重要的?现有研究存在哪些具体的挑战或空白(Gap)?
- 重要性: ICL 是 LLMs 的一项关键能力,但其对排列的敏感性使得提示工程(
prompt engineering)变得困难,降低了模型在实际应用中的可靠性和安全性。论文发现,这种敏感性甚至可以被设计成一种难以检测的攻击,在先进的 LLaMA-3 模型上成功率高达 80%,这凸显了其作为一个严重漏洞的现实威胁。 - 现有挑战与空白:
- 现有方法不足: 大多数现有的 ICL 研究专注于提高平均性能,而对排列鲁棒性关注有限。
- 后处理限制: 现有的缓解方法(如输出校准、示例顺序优化)主要依赖于推理阶段的后处理技术,这会引入额外的计算开销,并且通常无法从根本上增强模型对输入排列的固有鲁棒性。它们没有解决 LLM 自身在训练阶段就存在的问题。
- 架构修改的局限性: 少数尝试通过修改模型架构(如单向注意力、排列等变架构)来增强鲁棒性的方法,往往缺乏可扩展性或通用性。
- ERM 的局限性: 传统的经验风险最小化 (Empirical Risk Minimization, ERM) 训练范式关注平均性能,而忽视了最坏情况的性能,导致模型在未见过的排列上表现不佳。
2.1.3. 这篇论文的切入点或创新思路是什么?
论文的创新切入点在于从分布鲁棒优化 (DRO) 的角度出发,提出了一种训练阶段 (training-stage) 的解决方案,旨在增强 LLM 对输入排列的固有鲁棒性 (inherent robustness),而不是依赖于推理阶段的修补。它将“寻找最差排列”这一挑战性问题形式化为一个最优传输问题,并通过对抗性训练来解决。
2.2. 核心贡献/主要发现
2.2.1. 论文最主要的贡献是什么?
- 揭示并量化了先进 LLM (LLaMA-3) 的排列脆弱性: 论文通过实验证明,即使是 SOTA(
state-of-the-art)的 LLM 也极易受到简单的排列攻击,成功率高达 80%,强调了这一问题的严重性和隐蔽性。 - 提出了 PEARL 框架: 引入了一种新颖的、基于分布鲁棒优化 (DRO) 的排列弹性学习框架,从根本上提升 LLM 对输入排列的鲁棒性。这是首次将 DRO 应用于解决 LLM 的 ICL 排列敏感性问题。
- 设计了 P-Net 和最优传输方法: 提出了一个排列提议网络 (P-Net),通过将生成对抗性排列视为最优传输 (OT) 问题,并利用熵约束 Sinkhorn 算法高效地找到最挑战性的排列。
- 采用最小最大优化策略: 通过对抗性训练,P-Net 和 LLM 迭代地相互优化,从而在训练阶段增强 LLM 对最坏情况排列的性能。
2.2.2. 论文得出了哪些关键的结论或发现?
- 有效缓解排列攻击: PEARL 显著提高了 LLM 对排列攻击的防御能力,并在合成预训练和真实世界指令微调任务中持续提升了平均和最坏情况下的性能。
- 显著的性能增益: 相比基线方法,PEARL 不仅提高了平均性能,更在最坏情况性能上取得了显著提升(例如,在某些任务上,针对最坏情况的性能提升高达 29.4%)。
- 高效性和泛化能力: 尽管在较少示例和较短上下文下训练,PEARL 在扩展到多示例 (many-shot) 和长上下文 (long-context) 场景时,性能增益可达 24% 到 40%,并能推广到不同的 LLM 家族(如 Mistral, Gemma, Llama),展现出强大的泛化能力和效率。
- 更强的样本效率 (shot efficiency): PEARL 训练的模型在达到与基线模型相同的平均性能时,所需的示例数量减少了 2 到 4 倍。
- 额外发现: 增加示例数量虽然通常会提高平均性能,但却可能大幅恶化最坏情况性能,因为排列组合数量呈指数增长,增加了模型遇到“坏”排列的风险。同时,LLM 家族对排列的敏感性不同,Llama 模型最为敏感。
3. 预备知识与相关工作
3.1. 基础概念
3.1.1. 上下文学习 (In-Context Learning, ICL)
In-Context Learning (ICL) 是大型语言模型 (LLMs) 的一种独特能力,它允许模型在不进行权重更新的情况下,仅仅通过在输入提示 (prompt) 中提供少量示例(demonstrations)来学习并执行新的任务。这些示例通常以 (输入, 输出) 对的形式给出,模型根据这些示例的模式来理解任务,并为给定的新查询生成相应的输出。ICL 使得 LLMs 能够像人类一样从少量示例中进行推理和泛化,是 LLMs 区别于传统监督学习模型的重要特性。
3.1.2. 大型语言模型 (Large Language Models, LLMs)
Large Language Models (LLMs) 是指参数量巨大(通常数十亿甚至数万亿)、在海量文本数据上预训练的深度学习模型。它们能够理解、生成和处理人类语言,执行各种自然语言处理 (NLP) 任务,例如问答、文本摘要、翻译、代码生成等。LLMs 的核心架构通常基于 Transformer。
3.1.3. 排列敏感性 (Permutation Sensitivity)
Permutation Sensitivity 指的是 LLMs 在 ICL 过程中,对提供给模型的示例(demonstrations)的排列顺序非常敏感。即使示例的实际内容不变,仅仅改变它们的顺序,也可能导致模型输出的显著变化,甚至性能大幅下降。这种敏感性是 LLMs 在实际应用中面临的一个重要挑战,影响了其预测的稳定性和可靠性。
3.1.4. 分布鲁棒优化 (Distributionally Robust Optimization, DRO)
Distributionally Robust Optimization (DRO) 是一种优化框架,旨在使模型在面对数据分布不确定性时仍能表现良好。与传统的经验风险最小化 (ERM) 仅优化训练数据的平均性能不同,DRO 考虑一个“模糊集” (ambiguity set),该集合包含了所有可能的、与经验分布“接近”的潜在数据分布。DRO 的目标是在这个模糊集内,找到能够最小化最坏情况(即导致最大损失的分布)下风险的模型参数。通过这种方式,DRO 训练出的模型对训练数据之外的分布变化具有更强的鲁棒性。
3.1.5. 最优传输 (Optimal Transport, OT)
Optimal Transport (OT) 是数学和优化领域的一个分支,最初由蒙日 (Monge) 在 18 世纪提出,并由坎托罗维奇 (Kantorovich) 在 20 世纪推广。它的核心思想是找到将一个概率分布“传输”到另一个概率分布的“最优”方式,使得传输成本最小。这个成本通常由一个距离函数定义。在机器学习中,OT 被用于比较和匹配不同的数据分布,例如在图像处理、领域适应和生成模型中。
3.1.6. Sinkhorn 算法 (Sinkhorn Algorithm)
Sinkhorn Algorithm 是一种迭代算法,用于近似求解离散 最优传输 (OT) 问题,特别是用于寻找近似双随机矩阵 (doubly stochastic matrix)。双随机矩阵的行和列之和都为 1。在 OT 中,它通常用于将一个给定矩阵(例如成本矩阵的指数)转化为一个双随机矩阵,这个矩阵可以被解释为两个分布之间的软分配(soft assignment)或传输计划。该算法通过交替地对矩阵的行和列进行归一化直到收敛来实现。其优点在于计算效率高,并且可以通过熵正则化使其数值稳定且可微分。
3.1.7. 最小最大优化 (Minimax Optimization)
Minimax Optimization 是一种优化问题,其目标是最小化一个函数的最坏情况值。这通常涉及两个参与者,一个试图最大化某个目标函数(“攻击者”),另一个则试图最小化这个目标函数(“防御者”)。在本文的 PEARL 框架中,P-Net 扮演“攻击者”的角色,试图找到使 LLM 损失最大的排列;而 LLM 扮演“防御者”的角色,试图在 P-Net 找到的最坏排列下最小化自己的损失。这种对抗性训练(adversarial training)使得模型能够更好地应对最坏情况。
3.1.8. 经验风险最小化 (Empirical Risk Minimization, ERM)
Empirical Risk Minimization (ERM) 是机器学习中训练模型的标准方法。它的目标是找到一组模型参数,使得模型在训练数据集上的平均损失 (loss) 最小化。假设训练数据是从真实数据分布中独立同分布地抽样而来,ERM 旨在通过最小化经验风险来近似最小化真实风险。然而,ERM 并不直接关注模型在训练数据中未充分代表的边缘情况或最坏情况下的表现,这可能导致模型在遇到分布偏移或对抗性输入时缺乏鲁棒性。
3.1.9. Gumbel-Sinkhorn / Gumbel Trick
Gumbel-Sinkhorn 是结合了 Gumbel-Softmax 技巧和 Sinkhorn 算法的一种方法,用于在离散变量上进行可微分的采样。Gumbel-Softmax (Gumbel Trick) 是一种技术,它允许从离散分布中进行近似可微分采样,通常用于神经网络中。通过引入 Gumbel 噪声并使用温度参数,它可以在反向传播时计算梯度,从而优化离散决策过程。在 PEARL 中,它被用于从 P-Net 生成的近似双随机矩阵中采样出具体的排列矩阵,使得整个排列生成过程是可微分的,从而能够通过梯度下降来训练 P-Net。
3.1.10. LoRA (Low-Rank Adaptation)
LoRA (Low-Rank Adaptation) 是一种用于高效微调大型预训练模型(如 LLMs)的技术。它通过向预训练模型的权重矩阵中注入低秩矩阵来实现微调,而不是直接更新整个巨大的预训练模型权重。在微调过程中,只有这些小得多的低秩矩阵的参数被训练,从而大大减少了可训练参数的数量和计算资源需求,同时保持了与全量微调相近的性能。
3.2. 前人工作
3.2.1. 现有 ICL 鲁棒性缓解方法
论文提及了两种主要的现有缓解方法:
- 修改训练目标: 针对
Transformer的单向注意力机制的局限性进行修改 (Xiang et al., 2024),或设计排列等变(permutation-equivariant)架构 (Chen et al., 2023b)。但这些方法通常缺乏可扩展性。例如,InfoAC(Xiang et al., 2024) 通过对比学习来打破自回归约束,但其成功有限且仅限于分类任务。DeepSet架构 (Chen et al., 2023b) 虽然表现出更好的排列不变性,但其MLP-based 架构太小,无法解决复杂的语言建模任务。 - 解码阶段技术(后处理): 如输出校准 (Zhao et al., 2021) 和示例顺序优化 (Lu et al., 2022)。这些方法在推理时会引入额外的计算开销,限制了它们的实用性。例如,输出校准对生成任务适用性较差,而顺序优化存在指数级复杂度问题。
3.2.2. 分布鲁棒优化 (DRO) 的应用
DRO 框架在优化领域已有广泛研究,其目标是在模糊集(ambiguity set)内针对最坏情况分布优化目标函数 (Ben-Tal et al., 2013; Lam & Zhou, 2015; Duchi et al., 2016; Miyato et al., 2018)。此前 DRO 主要应用于解决分布漂移(distributional shifts)问题,例如标签漂移 (label shift) (Hu et al., 2018)、数据源漂移 (data source shift) (Oren et al., 2019) 和群组漂移 (group shift) (Sagawa et al., 2020)。本文是首次将 DRO 应用于通过定义在经验分布的所有可能排列上的模糊集来增强 LLM 的 ICL 鲁棒性,从而提供性能保证。
3.2.3. 最优传输 (OT) 在机器学习中的应用
最优传输 (OT) 是一个数学分支,自蒙日 (Monge, 1781) 和坎托罗维奇 (Kantorovich, 1942) 奠基以来,在机器学习中得到了广泛应用,例如用于分布匹配 (Montesuma et al., 2024; Xiao et al., 2024)。本文的工作借鉴了 OT 学习排列结构的思想 (Mena et al., 2018),但将其应用于 LLMs 的上下文,并设计了结合 Sinkhorn 算子的 P-Net 来生成挑战性排列,从而实现 LLM 的 DRO 训练。
3.3. 技术演进
ICL 最初由 GPT-3 提出 (Brown et al., 2020),展示了 LLMs 通过少量示例学习新任务的强大能力。然而,很快研究发现 ICL 对示例的顺序敏感 (Zhao et al., 2021; Lu et al., 2022),这成为了 ICL 实际应用中的一个痛点。早期的缓解方案多集中于推理阶段的后处理,如通过校准输出或搜索最优排列,但这些方法存在计算开销大、无法从根本上解决模型固有缺陷的问题。随后的研究开始探索训练阶段的改进,包括尝试修改 Transformer 架构以使其对顺序更不敏感,但这些方法往往牺牲了模型的通用性或可扩展性。
本文的工作代表了 ICL 鲁棒性研究的一个重要进展,它在训练阶段引入了 DRO 框架,旨在从根本上提升 LLM 的固有鲁棒性。通过将寻找最差排列问题转化为最优传输问题,并利用对抗性训练,PEARL 提供了一种既能有效提升鲁棒性又保持 LLM 架构不变的通用解决方案,从而避免了推理开销并保持了模型的可扩展性。
3.4. 差异化分析
PEARL 与现有方法的核心区别和创新点在于:
- 训练阶段的固有鲁棒性 vs. 推理阶段的后处理: 大多数现有方法(如输出校准、顺序优化)是在推理阶段通过预处理或后处理来缓解敏感性,这增加了推理成本且未能解决模型内在的脆弱性。
PEARL是一种训练阶段 (training-stage) 的方法,通过修改模型的训练范式来使其自身对排列具有更强的固有鲁棒性 (inherent robustness),一旦训练完成,推理时无需额外开销。 - 分布鲁棒优化 (DRO) 的应用:
PEARL首次将DRO框架引入LLM的ICL鲁棒性问题。与传统的ERM仅优化平均性能不同,DRO明确地针对最坏情况的排列进行优化,从而提供更强的性能保证。 - 对抗性排列生成机制:
PEARL引入了一个专门的排列提议网络 (P-Net) 作为“攻击者”,它能够高效地生成对LLM最具挑战性的排列。这种动态的、对抗性的排列生成机制优于静态的随机打乱()或简单的混合(),因为它能主动探索和学习导致模型失败的排列模式。 - 最优传输 (OT) 解决排列问题:
P-Net将生成对抗性排列问题建模为最优传输问题,并使用熵约束 Sinkhorn 算法进行求解。这提供了一种新颖且高效的方式来探索和利用排列空间,而传统方法往往因指数级排列数量而受限。 - 通用性和可扩展性:
PEARL不修改Transformer架构或其自回归目标,因此保留了LLM的可扩展性。实验证明,它能泛化到不同的LLM家族和多示例、长上下文场景,而一些基于架构修改的方法可能缺乏这种通用性。
4. 方法论
4.1. 方法原理
PEARL 的核心思想是通过分布鲁棒优化 (DRO) 框架,使 LLM 不仅在训练数据上表现良好,更能在面对各种可能的输入示例排列时,尤其是那些最能挑战模型性能的“最坏情况”排列时,依然保持鲁棒性。这通过一个两阶段的对抗性游戏来实现:一个排列提议网络 (P-Net) 作为“攻击者”,其目标是为 LLM 生成最困难的示例排列;而 LLM 作为“防御者”,则努力在 P-Net 提出的这些挑战性排列下最小化损失。通过这种最小最大优化 (minimax optimization) 的迭代过程,LLM 逐步学会抵御排列敏感性。
P-Net 的设计灵感来源于最优传输 (OT) 理论,它将寻找最具挑战性的排列过程视为将输入示例的原始分布“传输”到最坏情况排列分布的问题。利用 Sinkhorn 算法及其熵约束变体,P-Net 能够高效且可微分地生成近似的排列矩阵,用于重排 LLM 的输入示例。
4.2. 核心方法详解 (PEARL 框架)
4.2.1. 基于 DRO 的指令微调 (Instruction Tuning via DRO)
在少样本学习 (few-shot learning) 的监督微调 (supervised fine-tuning, SFT) 中,LLM 被训练以预测给定输入 和少样本指令 的输出 。指令 通常由一系列示例(demonstrations)组成,每个示例都是一个输入-输出对。
设 为语言模型的参数空间, 是一个非负损失函数,衡量模型预测与真实输出之间的差异。
标准经验风险最小化 (Empirical Risk Minimization, ERM) 传统的训练方法是找到最小化训练数据上经验损失的参数 : 其中, 表示从训练数据集中导出的经验分布。 符号解释:
-
:通过经验风险最小化得到的模型参数。
-
:找到使目标函数最小化的参数。
-
: 模型参数 属于参数空间 。
-
:在经验分布 下对样本
(p, x, y)的期望。 -
:给定模型参数 时,对于指令 、输入 和真实输出 的损失函数值。
然而,
ERM训练的模型往往在面对相同示例的不同排列时泛化能力差,因为训练集只覆盖了部分排列组合。
分布鲁棒优化 (Distributionally Robust Optimization, DRO)
为了系统地解决排列敏感性问题,本文提出在 DRO 框架下进行微调,该框架在指定的模糊集(ambiguity set)内优化最坏情况下的风险。具体目标是解决:
符号解释:
-
:通过分布鲁棒优化得到的模型参数。
-
:找到使括号内表达式最小化的模型参数 。
-
:找到使期望损失最大化的分布 ,该分布属于模糊集 。
-
:在分布 下对样本
(p, x, y)的期望。 -
:损失函数。
模糊集 被构建为通过排列经验分布 中的提示(
prompts)所获得的所有分布的凸包。具体定义为: 符号解释: -
:模糊集,包含所有可能的排列分布。
-
:模糊集中的一个分布,是不同排列分布 的凸组合。
-
: 是一个排列矩阵,用于重新排序提示 中的示例序列, 是所有此类排列矩阵的集合。
-
:对应于排列 的概率权重。
-
:向量 位于 维概率单纯形(
probability simplex)中,意味着所有 且 。 -
:这是将经验分布 中的每个样本
(p, x, y)的提示 经过排列 后形成的分布。通过
Figure 2可以直观理解ERM和DRO的区别: 以下是原文 Figure 2 的示意图:

该图像是包含两部分的图表,比较了在ERM和DRO训练范式下模型学习的分布差异。蓝色柱状代表训练数据的经验分布 ,紫色曲线表示模型学习的分布 ,展示了在出现频率较低但有效的排列组合上的不同表现。
Figure 2: Comparison of models trained under ERM and DRO paradigms. The blue bars represent the empirical distribution of training data, showing different frequencies of six permutations in the training set. The purple curves denote the learned distribution by (a) ERM and (b) DRO models, illustrating their different behaviors on less appeared but valid permutations.
图 2 (a) 显示,ERM 训练的模型倾向于为训练数据中频繁出现的排列(如 0, 1, 4)分配更高的概率,而对不常出现的排列(如 2, 3, 5)概率较低,导致在测试时遇到这些未充分训练的排列时性能不佳。相反,图 2 (b) 中的 DRO 训练模型则更均匀地分布概率,因为它在学习过程中明确考虑了所有可能的排列,从而鼓励模型对所有有效排列分配合理的概率,无论它们在训练数据中的出现频率如何。
4.2.2. 通过 P-Net 学习生成排列 (LEARniNG tO GENERaTE PERMutatioNS via P-NET)
为了实现 DRO 框架,本文需要高效地找到模糊集中的最坏情况(即解决 sup 步骤)。由于排列组合空间呈指数级增长,直接通过穷举搜索是不可行的。
本文通过引入排列提议网络 (P-Net) 来解决这一挑战,P-Net 被定义为 P-Net:( \mathcal { P } \times \mathcal { X } \times \mathcal { Y } ) \to \Delta ( \Pi )
,它学习一个关于排列的分布,以增加给定输入示例对 `LLM` 的任务难度。`P-Net` 从这个分布中采样出挑战性排列,用于重新排序给定的示例。
`P-Net` 包含两个主要组件:一个参数组件(特征提取和关系建模)和一个非参数组件(使用 `Sinkhorn` 算法构建排列分布)。
**参数组件 (Parameter component)**
参数组件包含一个特征提取器和一层交叉关系建模层。
1. **特征提取器 (Feature extractor):** 这是一个编码器模型 (`Encoder model`),它接收一个由 个示例对组成的 `ICL` 提示 以及一个待预测样本 `( x , y )` 作为输入。它输出它们的表示:
\left( [ \mathrm { C L S } ] , ( x _ { 1 } , y _ { 1 } ) , \dots , [ \mathrm { C L S } ] , ( x _ { n } , y _ { n } ) , [ \mathrm { C L S } ] , ( x , y ) \right) \xrightarrow { \mathrm { Encoder } } \left( \mathbf { h } _ { 1 } , \mathbf { h } _ { 2 } , \dots , \mathbf { h } _ { n } , \mathbf { h } _ { n + 1 } \right) ,
**符号解释:**
* : 特殊的分类令牌,常用于提取序列的整体表示。
* `( x _ { i } , y _ { i } )`: 第 个示例的输入-输出对。
* `( x , y )`: 待预测样本的输入-输出对。
* : 编码器模型,如 `BERT` 的编码器。
* : 对应于第 个 `[CLS]` 令牌的表示向量,维度为 。
2. **交叉示例关系建模层 (Cross-relationship modeling layer):** 提取 个示例的表示后,我们得到 。然后,设计一个简单的交叉示例层来获取一个关系矩阵 ,该矩阵捕获了每个示例对之间的关系:
\mathbf { R } = g \left( H W H ^ { \top } \right) ,
**符号解释:**
* : 关系矩阵。
* : 非线性激活函数。
* : 包含 个示例表示的矩阵,每一行是一个示例的表示 。
* : 可学习的权重矩阵。
* : 矩阵 的转置。
矩阵 可以被理解为图论中的邻接矩阵,其中示例是节点,`R _ { i j }` 表示示例 和 之间的关系。具体来说,`R _ { i j }` 被定义为如果交换示例 和 ,`LLM` 任务难度可能增加的潜力;`R _ { i j }` 值越高,表示交换这两个示例可能对预测产生越显著的影响。因此,这个参数组件建模了一个边预测过程。
**非参数组件 (Non-parameter component)**
虽然 捕获了示例之间交换的潜力,但它还不能直接用于采样排列,因为它不是一个有效的概率分布。非参数组件旨在将邻接矩阵 转换为一个双随机矩阵,代表一个关于排列的概率分布。
1. **Sinkhorn 算子 (Sinkhorn operator):** 参照 (Adams & Zemel, 2011; Mena et al., 2018),本文采用 `Sinkhorn` 算子 通过迭代的行和列归一化过程来获得此类矩阵:
S ( R ) = \operatorname* { l i m } _ { l \to \infty } ( \mathcal { T } _ { c } ( \mathcal { T } _ { r } ( \exp ( R ) ) ) ) ,
\mathcal T _ { r } ( R ) = R \oslash \left( R \mathbf { 1 } _ { n } \mathbf { 1 } _ { n } ^ { \top } \right) , \quad \mathcal T _ { c } ( R ) = R \oslash \left( \mathbf { 1 } _ { n } \mathbf { 1 } _ { n } ^ { \top } R \right) ,
**符号解释:**
* `S ( R )`: 经过 `Sinkhorn` 算子处理后的矩阵,最终会收敛到一个双随机矩阵。
* : 当迭代次数 趋于无穷大时的极限。在实际应用中,会进行有限次迭代。
* : 列归一化算子。
* : 行归一化算子。
* : 对矩阵 的每个元素进行指数运算。
* : 逐元素除法。
* : 一个 维的全一列向量。
* : 一个 维的全一行向量。
`Sinkhorn` 算子通过交替的行和列归一化,使得矩阵的行和列之和都趋近于 1,从而得到一个双随机矩阵。
2. **Gumbel 技巧 (Gumbel trick) 进行可微分采样:** 为了确保从排列分布中采样时过程是可微分的,应用了 `Gumbel` 技巧:
\Pi = \operatorname* { l i m } _ { \tau \to 0 } S ( ( R + G ) / \tau ) ,
G _ { i j } = - \log \left( - \log G _ { i j } ^ { \prime } \right) , \quad G _ { i j } ^ { \prime } \sim U ( 0 , 1 ) ,
**符号解释:**
* : 最终采样的排列矩阵。
* : 当温度参数 趋近于 0 时的极限。在实际应用中, 是一个小的正数。
* : 将关系矩阵 加上 `Gumbel` 噪声 ,并通过温度 进行缩放后,再应用 `Sinkhorn` 算子。
* : `Gumbel` 噪声矩阵。
* `G _ { i j }`: 矩阵 中的元素,通过双对数变换从均匀分布 采样的 生成。
当 趋近于 0 时,上述操作的结果近似于一个排列矩阵 。通过将排列生成视为一个最优传输问题,并由 `P-Net` 实现,可以将输入排列分布转化为目标分布。接下来,介绍 `P-Net` 如何与 `LLM` 协同优化。
以下是原文 Figure 3 的示意图:

*该图像是论文图3的示意图,展示了PEARL框架中P-Net和LLM的联合训练流程。P-Net基于输入示例序列生成置换矩阵,通过采样得到最挑战性的排列,并作用于输入后传入LLM计算损失,最后通过反向传播迭代优化。*
Figure 3: An overview of the learning framework. The P-Net is a small model incorporating the Sinkhorn operator, trained jointly with the LLM under the adversarial optimization algorithm. Note that the permutation matrix operates on the input sequence's embeddings (simplified here as text sequences for clarity). After training, only the LLM is retained while the P-Net is discarded.
图 3 展示了 `PEARL` 的整体学习框架。`P-Net`(一个小型模型,包含 `Sinkhorn` 算子)与 `LLM` 在对抗性优化算法下联合训练。`P-Net` 接收原始的提示、输入和输出,通过特征提取、关系建模和 `Sinkhorn` 算子生成一个排列矩阵 。这个排列矩阵作用于输入示例的嵌入(为了清晰,图中简化为文本序列),生成一个重排后的提示 。这个重排后的提示连同输入 和输出 被送入 `LLM`,`LLM` 计算损失。`P-Net` 试图最大化这个损失(通过其参数 ),而 `LLM` 试图最小化这个损失(通过其参数 )。通过这种对抗性训练,`P-Net` 学会生成最具挑战性的排列,而 `LLM` 则学会在这些挑战性排列下保持鲁棒性。训练完成后,`P-Net` 会被丢弃,只保留 `LLM`。
### 4.2.3. 对抗性优化 (ADVERSARIAL OPTIMIZATION)
本文采用一个对抗性优化框架来联合训练 `LLM` 和 `P-Net`。设 和 分别表示 `LLM` 和 `P-Net` 的参数。对于从经验分布 中抽取的每个样本 `( p , x , y )`,`P-Net` 生成一个对抗性排列 ,以最大化 `LLM` 的损失。作为回应,`LLM` 旨在最小化其在 `P-Net` 操纵下的损失。
**LLM 的损失函数 (LLM's loss function)**
L _ { \mathrm { l m } } ( \phi ; \theta ) = \mathbb { E } _ { ( p , x , y ) \sim \hat { P } , \Pi \sim \mathrm { P - N e t } ( \phi ; p , x , y ) } [ \ell ( \theta ; ( \Pi \cdot p , x , y ) ) ]
**符号解释:**
* : `LLM` 的损失,依赖于 `P-Net` 的参数 (因为它决定了 )和 `LLM` 的参数 。
* : 在经验分布 中采样样本,并由 `P-Net`(参数为 )生成排列 的期望。
* : `LLM`(参数为 )在经过排列的提示 、输入 和真实输出 上的损失。
**熵正则化项 (Entropy-based regularization term)**
为了防止 `P-Net` 陷入琐碎解(例如,生成统一的排列以降低示例语义的有效性),引入了一个基于熵的正则化项:
L _ { \mathrm { e n t } } ( \phi ) = \mathbb { E } _ { ( p , x , y ) \sim \hat { P } , \Pi \sim \mathrm { P - N e t } ( \phi ; p , x , y ) } [ \mathcal { H } ( \Pi ) ] ,
**符号解释:**
* : 熵正则化项,鼓励 `P-Net` 生成更多样化的排列。
* : 逐元素熵函数,衡量排列 的随机性或多样性。
**两玩家最小最大优化问题 (Two-player min-max optimization problem)**
这导致了一个具有以下目标的 (`min-max`) 优化问题:
\operatorname* { m i n } _ { \theta } \operatorname* { m a x } _ { \phi } \left( L _ { \mathrm { l m } } ( \phi ; \theta ) - \beta L _ { \mathrm { e n t } } ( \phi ) \right) ,
**符号解释:**
* : `LLM` 的目标是最小化整个表达式。
* : `P-Net` 的目标是最大化整个表达式。
* : `LLM` 的损失。`P-Net` 试图通过选择 来最大化它,`LLM` 试图通过选择 来最小化它。
* : 熵正则化项。前面的负号和最大化操作意味着 `P-Net` 实际上是试图**最小化**熵,即找到更确定性的、更具挑战性的排列。然而,原文的意图是**防止 P-Net 坍缩到平凡解**,这通常需要**最大化熵**以鼓励多样性。这里原文公式 在 下,意味着 `P-Net` 试图最大化 并最小化 。这与一般认知中通过**最大化熵**来鼓励多样性以避免坍缩的目的略有出入。**此处应理解为,P-Net 在最大化 LLM 损失的同时,被 项以 为变量的梯度惩罚,从而阻止其生成过于单一的、坍缩的排列。** 也就是说,如果 `P-Net` 倾向于生成非常确定性的(低熵的)排列,它会因为负的熵项而受到惩罚,除非它能够显著提高 `LLM` 的损失。因此,`P-Net` 需要在最大化 `LLM` 损失和保持一定程度的排列多样性之间找到平衡。
* : 控制熵正则化强度的超参数。
**算法 (Algorithm 1)**
本文采用交替优化(`alternating optimization`)来迭代更新 和 。完整的训练过程在 `Algorithm 1` 中详细说明:
以下是原文 Algorithm 1 的内容:
Input: θ, φ (LLM, P-Net); ηθ, ηφ (learning rates); m (inner steps); β (entropy coefficient) repeat for t = 1 to m do (p, x, y) ∼ P ; // Sample training examples Ⅱ ∼ P-Net(φ, p, x, y) ; // Generate permutations L1m(φ, θ) ← e(θ; Π · p, x, y) ; // Compute LLM loss Lent(φ) ← H(Π) ; // Compute entropy regularization φ ← φ + ηφφ(Llm − βLent) ; // Update P-Net end θ ← θ − ηθθL1m(φ, θ) ; until convergence; // Update LLM
**算法 1: PEARL 的对抗性优化算法**
* **输入:**
* : `LLM` 的参数。
* : `P-Net` 的参数。
* : `LLM` 的学习率。
* : `P-Net` 的学习率。
* : `P-Net` 的内部更新步数。
* : 熵系数。
* **重复(直到收敛):**
* **P-Net 内部更新循环 (for t = 1 to m do):**
1. 从数据集 中采样训练示例 `(p, x, y)`。
2. 使用 `P-Net`(参数 )生成排列 :。
3. 计算 `LLM` 的损失:。
4. 计算熵正则化项:。
5. 更新 `P-Net` 的参数:。
(这里 表示对 `P-Net` 参数 的梯度。`P-Net` 目标是最大化 ,所以是梯度上升。)
* **LLM 更新 (在 P-Net 内部循环结束后):**
1. 更新 `LLM` 的参数:。
(这里 表示对 `LLM` 参数 的梯度。`LLM` 目标是最小化 ,所以是梯度下降。)
这种交替优化使得 `P-Net` 学习生成更困难的排列,而 `LLM` 则学习在这种困难下保持性能。理想情况下,当收敛时,`P-Net` 将代表一个均匀分布在所有排列上,因为 `LLM` 能够同样好地处理所有可能的排列。
# 5. 实验设置
## 5.1. 数据集
`PEARL` 方法在两种场景下进行验证:
1. **合成预训练任务 (Synthetic pre-training task):**
* **任务类型:** 上下文学习线性函数 。
* **数据生成:**
* 函数采样:权重向量 采样,定义线性函数 ,其中 是输入维度。
* 输入采样: 独立抽取。
* 输出生成:。
* **提示结构:** 提示 包含 个输入-输出示例对和一个查询输入 ,形式为 `( x _ { 1 } , f ( x _ { 1 } ) , x _ { 2 } , f ( x _ { 2 } ) , . . . , x _ { i } , f ( x _ { i } ) , x _ { i + 1 } )`。
* **数据集规模:** 预训练模型在生成了 40k 个线性函数的数据集上。
* **目的:** 评估模型在控制良好、可重复的数学任务中对排列的鲁棒性。
2. **真实世界指令微调任务 (Real-world instruction tuning tasks):**
* **数据集来源:** `Super-Natural Instructions` (Wang et al., 2022),该数据集是 `FLAN v2 benchmark` (Chung et al., 2024) 的一部分。
* **任务选择:** 选择了 17 个有代表性的任务,包括 9 个自然语言生成 (NLG) 任务和 8 个自然语言理解 (NLU) 任务。
* **数据划分:**
* 4 个数据集被随机指定为**保留测试集 (held-out test sets)**。
* 剩余 13 个数据集用于训练。
* **样本数量:** 每个训练数据集包含 150 个示例,每个测试数据集包含 100 个示例。
* **总计:** 训练集共 1,950 个示例,测试集共 400 个示例。
* **目的:** 评估 `PEARL` 在更复杂、多样化的真实世界 `NLP` 任务中的有效性。
以下是原文 Table 2 的数据:
Split Category
Training
NLG
7
# Tasks # Samples 1050
NLU
6
900
Testing
NLG
2
200
NLU
2
200
以下是原文 Table 5 的数据:
Task ID
Task Name
Source
Category
1297
QASC Question Answering
QASC
Question Answering
442
COM_QA Paraphrase Question Generation
COM_QA
Question Rewriting
908
DialogRE Identify Familial Relationships
DialogRE
Speaker Relation Classification
288
Gigaword Summarization
Gigaword
Title Generation
582
Natural Questions Answer Generation
Natural Questions
Question Answering
151
TOMQA Find Location Easy Clean
TOM_QA
Question Answering
1714
ConvAI3 Sentence Generation
ClariQ
Dialogue Generation
379
AGNews Topic Classification
AG News
Text Categorization
639
MultiWOZ User Utterance Generation
MultiWOZ 2.2
Dialogue Generation
209
Stance Detection Classification
StarCon
Stance Detection
1516
IMPPRES Natural Language Inference
IMPPRES
Textual Entailment
589
Amazon Food Summary Text Generation
Amazon Reviews
Summarization
1285
KPA Keypoint Matching
ArgKP
Text Matching
### 5.1.1. 数据集中的具体样本示例
原文中并未直接提供数据集的具体样本示例,但根据任务类型可以进行推断:
* **CurDial (Curiosity-based Dialog):** 涉及对话的生成或理解,例如在对话中找出好奇心相关的语句。
* **TMW (TellMeWhy QA):** 涉及问答,可能需要模型解释某个现象或事实的原因。
* **CSQA (CommonsenseQA):** 涉及常识问答,例如“什么能在水里漂浮?”。
* **CoLA (Corpus of Linguistic Acceptability):** 涉及判断句子的语法正确性。
## 5.2. 评估指标
### 5.2.1. 归一化均方误差 (Normalized Squared Error)
* **概念定义:** `Normalized Squared Error`(归一化均方误差)是衡量模型预测值与真实值之间差异的指标,特别是在本文的线性函数拟合任务中。它计算预测值与真实值之差的平方,然后除以问题的维度进行归一化,以消除维度对误差大小的影响,使得不同维度的问题之间可以进行比较。较低的值表示更好的性能。
* **数学公式:**
\overline { { ( L M ( p ) - w ^ { \top } } } \dot { x } _ { \mathrm { q u e r y } } ) ^ { 2 } / d )
* **符号解释:**
* `L M ( p )`: 语言模型 `LM` 在给定提示 下对查询的预测输出。
* : 查询输入 对应的真实输出(通过线性函数 计算得到)。
* : 问题的维度。
* 上方的横线表示取平均值。
### 5.2.2. ROUGE-L (Recall-Oriented Understudy for Gisting Evaluation - Longest Common Subsequence)
* **概念定义:** `ROUGE-L` 是一种用于评估文本生成任务(如摘要、翻译、问答)性能的指标。它通过计算候选文本(模型生成)与参考文本(真实标签)之间最长公共子序列(`Longest Common Subsequence, LCS`)的长度来衡量它们的相似性。`ROUGE-L` 关注的是文本的流畅性和结构匹配度,因为它不需要连续的匹配,只要单词顺序相对一致即可。值越高表示生成文本与参考文本越相似。
* **数学公式:**
`ROUGE-L` 通常定义为基于 `LCS` 的 F1 分数,但为了简化和在实践中常用,我们关注其基于 `LCS` 的召回率和精确率,并最终计算 F1。
设 `LCS(C, R)` 是候选文本 和参考文本 的最长公共子序列长度。
召回率 (Recall):
R_{LCS} = \frac{LCS(C, R)}{\text{length}(R)}
精确率 (Precision):
P_{LCS} = \frac{LCS(C, R)}{\text{length}(C)}
F1 分数:
F_{LCS} = \frac{(1 + \beta^2) R_{LCS} P_{LCS}}{R_{LCS} + \beta^2 P_{LCS}}
在 `ROUGE` 包的默认实现中,通常 ,即 F1 分数是召回率和精确率的调和平均值。
* **符号解释:**
* : 候选文本(`LLM` 生成的文本)。
* : 参考文本(真实输出或人类标注的文本)。
* `LCS(C, R)`: 候选文本 和参考文本 的最长公共子序列的长度。
* : 参考文本 的长度(通常是单词或词元的数量)。
* : 候选文本 的长度(通常是单词或词元的数量)。
* : 基于 `LCS` 的召回率。
* : 基于 `LCS` 的精确率。
* : 基于 `LCS` 的 F1 分数。
* : 用于调整召回率和精确率重要性的参数,通常取 1。
### 5.2.3. 攻击成功率 (Attack Success Rate, ASR)
* **概念定义:** `Attack Success Rate (ASR)` 衡量攻击者通过排列 `ICL` 示例来降低 `LLM` 性能的有效性。如果模型在攻击后的性能下降超过预设阈值 ,则认为攻击成功。更高的 `ASR` 表示模型对排列攻击越脆弱。
* **数学公式:**
\mathrm { A S R } ( D , \delta ) = \frac { 1 } { | D | } \sum _ { i = 1 } ^ { | D | } \mathbb { I } \big ( ( \mu _ { i } - \omega _ { i } ) / \mu _ { i } \geq \delta \big )
* **\text{符号解释}:**
* $\mathrm { A S R } ( D , \delta )$: \text{数据集} $D$ \text{上,给定性能下降阈值} $\delta$ \text{时的攻击成功率。}
* $| D |$: \text{数据集} $D$ \text{的大小(样本数量)。}
* $\sum _ { i = 1 } ^ { | D | }$: \text{对数据集中的所有样本求和。}
* $\mathbb { I } ( \cdot )$: \text{指示函数,如果条件为真则返回} 1\text{,否则返回} 0\text{。}
* $\mu _ { i }$: \text{第} $i$ \text{个样本的平均性能。}
* $\omega _ { i }$: \text{第} $i$ \text{个样本在攻击策略下的受损性能。}
* $\delta$: \text{性能下降的阈值(例如} 50%\text{)。}
<strong>\text{平均性能} (Average Performance) $\mu_i$:</strong>
\mu _ { i } = \mathbb { E } _ { \Pi \sim \mathbb { P } } [ g ( \Pi \cdot p _ { i } , x _ { i } ; y _ { i } ) ] = { \frac { 1 } { n ! } } \sum _ { j = 1 } ^ { n ! } g ( \Pi _ { j } \cdot p _ { i } , x _ { i } ; y _ { i } )
**符号解释:**
* : 第 个样本在所有可能排列下的平均性能。
* : 在所有可能排列 上取期望。
* : 性能指标函数(如 `ROUGE-L` 或 `Normalized Squared Error`),评估模型在经过排列 的提示 、输入 和真实输出 上的表现。
* `n!`: 示例数量 的阶乘,表示所有可能的排列数量。
**受损性能 (Compromised Performance) :**
* **穷举搜索攻击 (Exhaustive Search Attack):** 攻击者假设拥有无限尝试次数,通过测试所有可能的排列来找到产生最差性能的排列。
\omega _ { i } = \operatorname* { m i n } _ { \Pi \in \mathbb { P } } g ( \Pi \cdot p _ { i } , x _ { i } ; y _ { i } )
**符号解释:**
* : 第 个样本在穷举搜索攻击下的最差性能。
* : 找到在所有可能排列 中使性能指标 最小化的排列。
* **神经搜索攻击 (Neural Search Attack):** 使用 `P-Net` 来近似穷举搜索攻击的上限,以有限尝试次数生成最具挑战性的排列。
\begin{array} { r } { \omega _ { i } = g ( \Pi _ { i } \cdot p _ { i } , x _ { i } ; y _ { i } ) , \qquad \mathrm { s . t . } ~ \Pi _ { i } \sim \mathrm { P - N e t } ( p _ { i } , x _ { i } , y _ { i } ) } \end{array}
\$\$
**符号解释:**
* : 第 个样本在神经搜索攻击下的性能。
* : 排列 是由训练好的 `P-Net` 基于样本 生成的。
5.3. 对比基线
5.3.1. 线性函数 ICL 任务的基线
- ERM + CL (Empirical Risk Minimization + Curriculum Learning): 这是一种标准的经验风险最小化方法,结合了课程学习 (
curriculum learning)。训练过程逐渐增加呈现给模型的示例数量,以促进对更复杂模式的渐进式学习,并使训练更稳定。
5.3.2. 指令微调任务的基线
- ERM (Empirical Risk Minimization): 标准的监督微调方法,旨在最小化训练数据集上的平均损失。这是大多数主流指令微调模型(如
FLAN,Natural Instructions,MetaICL)采用的方法。 - ERM + DS (ERM with Demonstration Shuffling): 在每个训练步骤中,随机打乱每个样本内上下文示例的顺序,以此增强
ERM。这可以被视为一种epoch级别的数据增强形式,通过让模型接触不同的排列组合来引入鲁棒性。 - ERM + IM (ERM with Instance Mixup): 在每个训练步骤中融入
Instance Mixup技术。对于每个数据点,通过随机选择不同的上下文示例生成多个增强版本。为每个增强版本计算损失,然后对这些损失取平均,并执行一次反向传播。这提供了比简单打乱更细粒度的数据增强。通过与该基线比较,可以对比“min-mean optimization”与“min-max optimization”的区别。 - InfoAC (Information-Aware Contrastive Learning): 这是一种训练阶段的方法,通过对比学习让早期词元能够访问后期词元的信息,旨在缓解自回归语言模型固有的
ICL顺序敏感性。然而,其成功有限且通常限于分类任务。
5.3.3. 使用的 LLMs
Llama3-8BLlama2-7BLlama2-13BMistral-7BGemma-7B
5.4. 实现细节与超参数 (Implementation Details and Hyperparameters)
5.4.1. 线性函数 ICL 任务
- LLM 架构: 使用
GPT-2基础模型 (Radford et al., 2019),包含 12 层、8 个注意力头、隐藏维度 256。模型接收嵌入空间中的向量序列作为输入,并预测同一空间中的下一个向量。 - 训练: 从头开始在生成的 40k 个线性函数数据集上进行预训练,使用
AdamW优化器 (Loshchilov & Hutter, 2019)。批量大小为 128,训练 500k 步,根据验证集性能选择最佳检查点。 - P-Net 初始化: 随机初始化一个
BERT-base大小的Transformer编码器,也从头开始训练。 - 测试: 采样新函数来评估模型通过上下文示例推断新权重 的能力。
5.4.2. 指令微调任务
- LLM 架构:
Llama3-8B模型作为目标LLM。 - P-Net 架构:
FLAN-large编码器作为P-Net。 - 微调方法: 两个模型都使用
LoRA(Hu et al., 2022) 进行微调。P-Net的微调参数数量是LLM的 1/20。 - 训练环境: 单块 NVIDIA A40 GPU。
- 训练轮次: 在指令数据集上训练两个
epoch。 - 批量大小: 16。
- 训练步数: 总计 246 步。
- 优化器:
AdamW。 - 学习率:
P-Net: 。LLM: 。
- Sinkhorn 算法参数:
-
迭代次数: 80。
-
温度参数 : 0.1。
-
熵约束系数 : 1.0。
以下是原文 Table 6 的数据:
Category Hyperparameter Value LLMs Learning rate 3e-5 Batch size 16 Max sequence length 512 Weight decay coefficient 0.1 Epoch 2 LoRA Rank 8 Alpha 32 Dropout 0.1 P-Net target modules q, Vv LLMs target modules q_proj, k_proj, v_proj, o_proj, gate_proj, P-Net Temperature up_proj, down_proj 0.1 Iteration coefficient 80 Entropy constraint 1.0 Noise 0.3 Learning rate 1e-4 Batch size 16 Max sequence length 512
-
6. 实验结果与分析
6.1. 对 LLM 排列脆弱性的重新审视 (2 REVIsITinG PERMutatIoN VULnERABiLity In LLMS)
本节通过对 LLaMA-3 模型的广泛实验,重新评估了其对示例排列的脆弱性,并从对抗性角度探讨了这种脆弱性是否可以被利用来设计有效攻击。
以下是原文 Figure 1 的示意图:

该图像是图表,展示了Llama-3在CurDial和TMW数据集上的性能与攻击成功率。左侧图表显示不同示例数量下的随机、平均和最差表现;右侧图表展示在不同阈值下穷举和神经搜索攻击方法的成功率。
Figure 1: Performance and attack success rates of Llama-3 on CurDial and TMW datasets. Left panels: Random, average and worst-case performance as a function of shot number. Right panels: Attack success rates for exhaustive and neural search attack methods at different thresholds.
左图分析:不同示例数量下的性能表现
左图展示了 Llama-3 在 CurDial 和 TMW 数据集上,不同示例数量(shot number)下的随机、平均和最差性能。
- 平均性能 (Average performance): 随着示例数量的增加,模型的平均性能通常会提高,这表明更多的上下文信息有助于模型更好地理解任务。
- 最差性能 (Worst-case performance): 令人担忧的是,随着示例数量的增加,模型的最差性能反而会恶化。例如,在
CurDial数据集上,当shot number从 2 增加到 6 时,平均性能有所提升,但最差性能却明显下降。这印证了论文的观点:增加示例数量是一把双刃剑。虽然提供了更丰富的上下文,但呈指数级增长的排列数量 (n!) 增加了模型遇到表现极差的特定输入配置的风险。这强调了LLM在ICL中对排列的固有敏感性。
右图分析:排列攻击的有效性
右图展示了在不同阈值 下,穷举搜索攻击 (Exhaustive Search Attack) 和神经搜索攻击 (Neural Search Attack) 的攻击成功率 (Attack Success Rate, ASR)。
-
攻击有效性: 结果表明,排列攻击非常有效且易于实现。在 的情况下,穷举搜索攻击在两个数据集上分别成功攻击了超过 50% 和 80% 的样本。
-
神经搜索攻击的近似能力: 神经搜索攻击的成功率接近穷举搜索攻击的上限,这表明
P-Net能够有效地找到具有挑战性的排列,即使在有限尝试次数下也能逼近最坏情况。 -
漏洞的严重性: 即使是像
LLaMA-3这样先进的LLM,也对简单的、只改变示例顺序而不改变语义内容的攻击高度敏感。这种攻击难以被模型提供商检测,但却能显著损害LLM性能,凸显了严重的安全和可靠性问题。总结: 本节的实验结果有力地论证了
LLM在ICL中的排列脆弱性是一个真实且严重的挑战,为PEARL框架的提出提供了强烈的动机。这些缺陷可能直接源于标准ERM训练的局限性,它优化平均性能却忽视了最坏情况性能。
6.2. 线性函数上下文学习 (4 In-Context Learning with Linear Functions)
本节评估了 PEARL 方法在合成线性函数 ICL 任务中的性能,与基线 进行了比较,主要关注对排列的鲁棒性。
以下是原文 Table 1 的数据:
| Shot | Method | Avg. | Worst. |
|---|---|---|---|
| 3 | ERM+CL | 1.45 | 2.67 |
| PEARL | 0.86 (+40.7) | 0.92 (+65.5) | |
| 4 | ERM+CL | 1.20 | 3.34 |
| PEARL | 0.79 (+34.1) | 1.11 (+66.8) | |
| 5 | ERM+CL | 1.28 | 5.03 |
| PEARL | 0.87 (+32.0) | 1.33 (+73.6) |
表 1 分析:不同排列下的归一化 MSE
- 基线方法 的脆弱性:
- 方法在平均性能和最差性能之间存在显著差距,表明其对排列具有极强的脆弱性。
- 随着
shots数量的增加,最差性能相对于平均性能的下降幅度急剧增大:从 3shots时的 74.6% 性能下降到 4shots时的 84.1%。这意味着增加示例数量虽然可能提高平均性能,但却大大增加了模型在某些特定(糟糕)排列下表现极差的风险。
PEARL的性能优势:-
PEARL不仅提高了平均性能(例如,3shots时提升 40.7%),还显著增强了最差情况下的泛化性能。 -
值得注意的是,
PEARL在最差性能上的提升幅度持续增加:从 3shots时的 65.5% 提升到 5shots时的 73.6%。这表明PEARL能够有效抵御最恶劣的排列攻击,使得模型在各种排列下都更加稳定。以下是原文 Figure 4 的示意图:
该图像是图表,展示了不同训练方法下在不同阈值比例下的攻击成功率对比。图中曲线比较了ERM+CL和PEARL在3、4、5个样本条件下的表现,显示PEARL在多种阈值下均有较低的攻击成功率,体现其对排列攻击的鲁棒性。
-
Figure 4: Comparison of attack success rates.
图 4 分析:攻击成功率对比
图 4 比较了 PEARL 和基线方法在不同攻击成功阈值 下的攻击成功率,以及不同示例数量(shots)下的防御能力。
-
PEARL的防御能力:PEARL的防御能力(即攻击成功率较低)随着攻击阈值 的增加而更加显著。当 时,PEARL在所有shots数量下的防御成功率大约是基线方法的两倍。这意味着PEARL能够有效阻止那些导致模型性能大幅下降(悲观场景)的攻击。 -
可扩展性:
PEARL的性能随着shots数量的增加而提高,表明其相比基线方法具有更好的可扩展性。即使在更多的示例下,PEARL也能保持较低的攻击成功率,增强了其在复杂场景下的实用性。总结: 在合成线性函数
ICL任务中,PEARL显著提高了LLM对排列的鲁棒性,尤其是在最坏情况性能和防御攻击方面表现出色。这为PEARL在更复杂的真实世界任务中的应用奠定了基础。
6.3. 大型语言模型的指令微调 (5 Instruction Fine-Tuning of Large Language Models)
本节评估了 PEARL 在真实世界指令微调任务中的性能,从三个角度进行了分析:与训练阶段方法的比较、对不同类型 LLM 的泛化能力、以及在多示例和长序列场景下的可扩展性。
6.3.1. 与训练阶段方法的比较
以下是原文 Table 3 的数据:
| Average | CSQA | CurDial | CoLA | TMW | |||||||
| # Shot | Method | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. |
| 2 | ERM | 57.3 | 49.4 | 58.0 | 54.0 | 57.9 | 43.4 | 62.0 | 58.0 | 51.1 | 42.0 |
| ERM+DS | 57.5 (-0.2) | 48.6 (-1.6) | 62.0 | 54.0 | 54.1 | 37.8 | 61.0 | 60.0 | 51.5 | 42.7 | |
| ERM+IM | 53.5(-6.6) | 44.4 (-10.1) | 63.0 | 54.0 | 44.7 | 28.1 | 57.0 | 56.3 | 49.4 | 39.2 | |
| INFOAC | 55.7 (-2.9) | 47.6 (-3.7) | 57.5 | 56.0 | 53.4 | 36.4 | 63.0 | 61.5 | 48.7 | 37.3 | |
| PEARL | 62.9 (+9.8) | 56.4 (+14.2) | 65.0 | 62.0 | 60.3 | 50.7 | 71.0 | 68.0 | 55.1 | 44.8 | |
| 3 | ERM | 57.8 | 38.3 | 57.7 | 47.0 | 61.4 | 25.9 | 61.9 | 52.0 | 50.3 | 29.4 |
| ERM+DS | 56.1 (-2.9) | 39.7 (+3.7) | 60.0 | 46.0 | 54.1 | 25.4 | 60.0 | 56.0 | 50.3 | 31.5 | |
| ERM+IM | 55.3 (-4.3) | 39.8(+3.9) | 59.0 | 46.0 | 54.6 | 28.0 | 57.6 | 553.1 | 50.0 | 31.9 | |
| INFOAC | 56.3 (-2.6) | 39.5(+3.1) | 59.3 | 49.0 | 55.2 | 24.3 | 62.1 | 55.8 | 48.4 | 28.8 | |
| PEARL | 63.1 (+9.2) | 46.9 (+22.5) | 68.4 | 62.0 | 66.7 | 34.8 | 64.7 | 56.0 | 52.4 | 34.7 | |
| 4 | ERM | 59.7 | 30.6 | 61.3 | 38.0 | 62.9 | 21.3 | 63.3 | 45.8 | 51.1 | 17.5 |
| ERM+DS | 57.7 (-3.4) | 31.8 (+3.9) | 63.3 | 40.0 | 57.3 | 17.6 | 60.1 | 52.0 | 49.9 | 17.8 | |
| ERM+IM | 56.0 (-6.2) | 32.4(+5.9) | 63.2 | 42.0 | 53.7 | 17.8 | 57.6 | 48.5 | 49.6 | 21.3 | |
| INFOAC | 58.6(-1.8) | 33.0 (+7.8) | 63.7 | 44.0 | 58.7 | 19.0 | 63.9 | 51.0 | 48.1 | 17.0 | |
| PEARL | 63.1 (+5.7) | 39.6 (+29.4) | 68.4 | 52.0 | 69.2 | 31.3 | 64.7 | 52.0 | 50.1 | 23.0 | |
表 3 分析:Llama3-8B 在四个保留任务上的平均和最差性能
PEARL的全面提升:PEARL在所有未见任务中持续改进了平均性能和最差性能。- 最差性能增益显著: 随着
shots数量的增加,PEARL相对于ERM的最差性能增益逐渐增加,从 2shots时的 14.2% 提高到 4shots时的 29.4%。这表明PEARL能够有效提高模型在面对极端不利排列时的鲁棒性。 - 平均性能提升: 尽管
PEARL优化的是最坏情况性能,但它也实现了卓越的平均性能,增益范围在 5.7% 到 9.8% 之间。这可能是因为在先进的LLM(如Llama3)上,训练损失迅速收敛,此时专注于具有挑战性的排列比使用随机排列更有效。这与Xu等人 (2024) 的观察一致。 - 其他基线的局限性:
- (Demonstration Shuffling) 和 (Instance Mixup) 旨在通过数据增强提高鲁棒性,但在大多数情况下,它们对平均性能的提升有限,甚至可能略有下降,对最差性能的提升也不如
PEARL显著。 InfoAC的表现也不如PEARL,这可能因为其方法主要针对分类任务,且未能全面解决Transformer的自回归约束导致的排列敏感性问题。
- (Demonstration Shuffling) 和 (Instance Mixup) 旨在通过数据增强提高鲁棒性,但在大多数情况下,它们对平均性能的提升有限,甚至可能略有下降,对最差性能的提升也不如
6.3.2. 对不同类型 LLM 的泛化能力
以下是原文 Figure 5 的示意图:

该图像是图表,展示了论文中方法在不同大语言模型(LLM)及多示例任务中的泛化性能。左图为3-shot情况下在Mistral-7B、Gemma-7B、Llama2-7B及Llama3-8B不同模型的性能提升,右图显示基于5-shot训练时,在多示例数量(8、16、32、64)和长序列(8k标记)条件下的扩展表现,柱状图区分了平均性能提升和最差性能提升。
Figure 5: Generalization performance of our method across different types of LLMs and many-shot settings. Left: Performance gains on 3-shot across different LLMs (Mistral-7B, Gemma-7B, Llama 2-7B, and Llama3-8B). Right: Scaling behavior across many-shot settings (8, 16, 32, and 64 shots) and longer sequences tokens) when trained with 5 shots and a sequence length of 512 tokens.
左图分析:3-shot 下不同 LLM 的性能增益
左图展示了 PEARL 在 3-shot 场景下对 Mistral-7B、Gemma-7B、Llama2-7B 和 Llama3-8B 等不同类型 LLM 的泛化性能。
- 一致的鲁棒性提升:
PEARL在所有测试的LLM上都一致地提高了最坏情况性能,提升幅度超过 10%。这证明了PEARL方法的通用适用性。 LLM家族敏感性差异: 不同的LLM家族对输入排列表现出不同程度的敏感性,其中Llama模型最为敏感,其次是Gemma和Mistral。尽管存在这些差异,但所有模型在最差情况下的性能下降普遍超过 10%,证实了排列敏感性是一个普遍存在的问题。PEARL的鲁棒性: 即使面对不同程度敏感性的模型,PEARL在 3shots及以上的情况下都能实现超过 10% 的最坏情况性能提升,展现了其强大的鲁棒性。
6.3.3. 在多示例和长序列场景下的可扩展性
右图分析:多示例和长序列的可扩展性行为
右图展示了 PEARL 在训练时仅使用 5 shots 和 512 tokens 序列长度的情况下,如何泛化到多示例 (many-shot ICL,最高 64 shots) 和长序列(最高 8,000 tokens)场景。
- 显著的泛化增益:
PEARL在泛化到更多shots和更长序列时,实现了 24% 到 40% 的显著最差情况性能增益。 - 高效性: 尽管训练配置较小,
PEARL仍能有效泛化到更大的规模,这突显了其高效性和强大的泛化能力。它使得LLM能够学习到鲁棒的特征,这些特征在多示例ICL和长序列处理中都表现出色。
6.3.4. 示例效率 (Shot Efficiency)
以下是原文 Table 4 的数据:
| # Shots | 2 | 4 | 8 | 16 | 32 | 64 |
|---|---|---|---|---|---|---|
| ERM | 57.3 | 59.7 | 61.8 | 66.9 | 67.4 | 68.1 |
| PEARL | 62.9 | 63.1 | 66.5 | 70.5 | 70.0 | 70.4 |
表 4 分析:有无 PEARL 的平均性能
- 更少的
shots达到相同性能:PEARL训练的模型在达到与基线ERM模型相当的平均性能时,所需的shots数量减少了 2 到 4 倍。例如,PEARL在 2shots时的平均性能(62.9)已经高于ERM在 8shots时的平均性能(61.8)。 - 效率提升: 这强调了
PEARL方法的效率,它使得LLM能够在更有限的上下文示例下达到高性能,这在实际应用中具有重要意义,可以减少推理成本和用户等待时间。
6.3.5. 扩展指令微调至不同 LLM (Appendix D)
附录 D 提供了 Mistral-7B、Gemma-7B、Llama2-7B 和 Llama2-13B 的详细指令微调结果,进一步支持了 PEARL 的泛化能力。
以下是原文 Table 7 的数据:
| Average | CSQA | CurDial | CoLA | TMW | |||||||
| # Shot | Method | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. |
| 2 | ERM PEARL | 64.1 67.0 (+4.5) | 58.1 62.4 (+7.5) | 67.0 68.0 | 64.0 66.0 | 54.6 59.4 | 41.8 49.0 | 81.0 82.0 | 78.0 78.0 | 53.7 58.4 | 48.5 56.7 |
| 3 | ERM PEARL | 66.6 69.5 (+4.3) | 56.1 62.8 (+12.0) | 67.0 70.0 | 62.0 66.0 | 63.7 70.1 | 38.9 60.1 | 80.0 83.6 | 76.0 78.0 | 55.6 54.1 | 47.3 47.0 |
| 4 | ERM PEARL | 66.7 68.3 (+2.5) | 50.4 57.1 (+13.4) | 68.9 69.9 | 60.0 62.0 | 67.6 71.6 | 47.8 54.8 | 74.2 74.9 | 52.0 66.0 | 55.9 56.8 | 41.6 45.5 |
| 5 | ERM PEARL | 67.9 70.2 (+3.4) | 50.7 58.1 (+14.5) | 67.5 70.4 | 56.0 64.0 | 70.7 76.7 | 52.6 59.3 | 76.0 73.3 | 56.0 66.0 | 57.4 60.4 | 38.2 43.0 |
表 7: Mistral-7B 指令微调结果
PEARL 在 Mistral-7B 上也显示出了一致的平均和最差性能提升,尤其在最差性能上,随着 shots 数量的增加,提升幅度从 7.5% 增加到 14.5%。
以下是原文 Table 8 的数据:
| Average | CSQA | CurDial | CoLA | TMW | |||||||
| # Shot | Method | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. |
| 2 | ERM PEARL | 66.2 66.3 (+0.0) | 59.5 60.7 (+2.0) | 71.0 74.0 | 70.0 68.0 | 59.1 47.3 | 46.1 39.2 | 77.0 82.0 | 70.0 78.0 | 57.8 61.7 | 52.0 57.6 |
| 3 | ERM PEARL | 64.7 68.4 (+5.8) | 52.5 59.3 (+13.0) | 70.7 74.7 | 64.0 68.0 | 67.1 59.2 | 45.2 42.5 | 70.3 78.7 | 60.0 76.0 | 50.5 61.0 | 40.7 50.6 |
| 4 | ERM PEARL | 65.0 67.2 (+3.4) | 46.5 52.5(+13.0) | 65.0 71.4 | 54.0 60.0 | 71.4 60.7 | 41.1 38.9 | 72.5 75.9 | 58.0 66.0 | 51.1 60.8 | 32.9 45.2 |
| 5 | ERM PEARL | 64.3 66.3 (+3.1) | 46.3 51.0 (+10.2) | 65.9 70.3 | 54.0 60.0 | 73.4 63.4 | 48.3 43.6 | 65.6 71.3 | 50.0 60.0 | 52.3 60.2 | 32.9 40.4 |
表 8: Gemma-7B 指令微调结果
PEARL 在 Gemma-7B 上同样带来了性能提升,特别是在 3 shots 和 4 shots 场景下,最差性能增益达到 13.0%。
以下是原文 Table 9 的数据:
| Average | CSQA | CurDial | CoLA | TMW | |||||||
| # Shot | Method | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. |
| 2 | ERM PEARL | 56.6 57.4 (+1.5) | 46.3 46.5 (+0.4) | 56.0 58.0 | 50.0 48.0 | 61.3 55.2 | 50.2 44.7 | 58.2 62.0 | 42.0 48.0 | 50.7 54.4 | 43.1 45.4 |
| 3 | ERM PEARL | 58.2 59.6 (+2.3) | 34.0 40.4 (+19.1) | 52.7 56.3 | 34.0 40.0 | 64.0 66.2 | 36.4 46.2 | 66.0 67.0 | 36.0 42.0 | 50.1 48.7 | 29.4 33.5 |
| 4 | ERM PEARL | 58.9 60.5 (+2.7) | 19.9 31.6 (+59.1) | 60.0 61.2 | 26.0 40.0 | 68.1 69.4 | 24.4 40.1 | 60.2 62.4 | 14.0 24.0 | 47.3 48.9 | 15.1 22.4 |
| 5 | ERM PEARL | 61.9 62.9 (+1.6) | 25.8 32.1 (+24.7) | 59.0 62.4 | 32.0 38.0 | 74.2 73.3 | 43.9 43.4 | 65.7 64.8 | 10.0 24.0 | 48.6 51.0 | 17.1 23.0 |
表 9: Llama2-7B 指令微调结果
在 Llama2-7B 上,PEARL 的最差性能提升尤为显著,在 4 shots 场景下高达 59.1%,表明 PEARL 对这类敏感模型具有强大的纠正能力。
以下是原文 Table 10 的数据:
| Average | CSQA | CurDial | CoLA | TMW | |||||||
| # Shot | Method | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. |
| 2.0 | ERM PEARL | 66.3 67.9 (+2.4) | 56.6 60.7 (+7.3) | 56.0 64.0 | 46.0 58.0 | 72.6 73.8 | 56.2 64.2 | 83.0 81.0 | 76.0 76.0 | 53.4 52.6 | 48.0 44.4 |
| 3.0 | ERM PEARL | 65.7 68.5 (+4.2) | 46.2 50.3 (+8.7) | 55.7 62.7 | 38.0 44.0 | 76.4 81.0 | 51.3 58.4 | 77.7 76.7 | 56.0 56.0 | 53.1 53.5 | 39.6 42.6 |
| 4.0 | ERM PEARL | 65.8 66.4 (+0.9) | 33.2 40.2 (+21.1) | 58.2 63.3 | 28.0 42.0 | 79.6 80.4 | 41.6 45.5 | 73.7 69.4 | 38.0 42.0 | 51.8 53.1 | 25.0 29.1 |
表 10: Llama2-13B 指令微调结果
PEARL 在 Llama2-13B 上同样表现出稳健的最差性能提升,在 4 shots 场景下达到 21.1%。
总结泛化能力: PEARL 方法在三种或更多示例的情况下,持续展示出显著的性能提升,最差情况性能通常提升超过 10%。这进一步证实了该方法的鲁棒性和有效性,能够适应不同 LLM 家族的特点。
6.3.6. 扩展到多示例上下文学习 (Appendix E)
附录 E 评估了 PEARL 在多示例场景下的可扩展性,测试了 8 到 64 个上下文示例的性能。
以下是原文 Table 11 的数据:
| Average | CSQA | CurDial | CoLA | TMW | |||||||
| # Shot | Method | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. | Avg. | Worst. |
| 8 | ERM PEARL | 61.8 66.5 (+7.6) | 21.3 29.7 (+39.2) | 61.4 67.7 | 36.0 44.0 | 68.3 77.1 | 22.7 28.7 | 62.7 65.0 | 16.0 32.0 | 54.8 56.2 | 10.6 14.0 |
| 16 | ERM PEARL | 66.9 70.5 (+5.3) | 21.3 | 67.3 | 36.0 | 76.5 | 31.4 | 67.2 | 8.0 | 56.5 56.9 | 9.7 9.8 |
| 32 | ERM | 67.4 | 26.3 (+23.7) 19.3 | 70.9 67.5 | 46.0 32.0 | 83.9 77.8 | 37.5 30.7 | 70.1 68.2 | 12.0 6.0 | 56.1 | 8.6 |
| 64 | PEARL ERM PEARL | 70.0 (+3.8) 68.1 | 26.4 (+36.4) 20.6 | 70.0 68.1 | 44.0 38.0 | 82.6 76.9 | 40.3 27.7 | 70.6 72.2 | 12.0 8.7 | 56.6 55.0 | 9.1 8.0 8.1 |
表 11: 8-、16-、32- 和 64-shot 设置下的性能评估
- 训练与泛化差异: 尽管
PEARL仅使用 5shots进行训练,但在 8 到 64shots的多示例场景下,它依然展现出强大的泛化能力。 - 持续的性能优势: 在所有
shots数量下,PEARL始终优于ERM基线。最差情况性能增益尤其显著,在 8shots时高达 39.2%,在 64shots时仍能达到 36.4%。这表明PEARL能够有效地将学到的鲁棒性推广到训练时未见的、更大规模的上下文场景。
6.3.7. 最优性能 (Best-Case Performance) (Appendix F)
附录 F 还提供了一个对最优性能的评估,以提供一个平衡的视角。
以下是原文 Table 12 的数据:
| #Shot | Method | Average | Gain | CSQA | CurDial | CoLA | TMW |
|---|---|---|---|---|---|---|---|
| 2 | ERM | 64.1 | - | 68.8 | 64.4 | 64.1 | 59.2 |
| PEARL | 68.8 | 7.2% | 73.4 | 69.2 | 70.3 | 62.1 | |
| 3 | ERM | 72.8 | - | 70.3 | 85.0 | 65.6 | 70.3 |
| PEARL | 77.0 | 5.7% | 73.4 | 87.9 | 79.7 | 66.9 | |
| 4 | ERM | 82.9 | - | 81.3 | 92.4 | 78.1 | 79.7 |
| PEARL | 84.3 | 1.7% | 82.8 | 93.6 | 81.2 | 79.5 | |
| 5 | ERM | 86.8 | - | 84.4 | 95.3 | 81.3 | 86.2 |
| PEARL | 89.3 | 2.9% | 87.5 | 96.5 | 85.9 | 87.3 |
表 12: ERM 与 PEARL 的最优性能比较
PEARL提升最优性能: 令人惊讶的是,PEARL不仅在最坏情况场景下表现出色,在所有数据集和所有shot条件下,PEARL的最优性能也始终超过ERM。- 全面性能提升: 这表明
PEARL不仅能够通过优化最坏情况来提高模型下限,还能略微提升模型的上限性能,使其在不同排列组合下都能有更好的表现。
6.3.8. 超参数分析 (Analysis Of Hyperparameters In Instruction Finetuning) (Appendix C)
以下是原文 Figure 7 的示意图:

该图像是图表,展示了梯度范数与熵系数β之间的关系。图中曲线分别表示P-Net和LLM的梯度范数随β变化的趋势,反映了熵系数对模型训练过程的影响。
Figure 7: Impact of entropy coefficient.
图 7: 熵系数的影响
-
低熵系数 (0.3):
P-Net的梯度范数很小,表明学习不足。P-Net可能只生成简单的语义重叠排列来满足对抗性训练,导致LLM梯度范数难以降低。 -
中等熵系数 (1.0-3.0):
P-Net的梯度范数在 1.0 处达到峰值,表明此时学习效果最佳。这个范围提供了P-Net在从LLM中提取有意义信息与避免过度简化或复杂化任务之间的理想平衡。 -
高熵系数 (3.0 和 10.0):
P-Net的梯度范数再次下降,表明限制过多。 -
LLM梯度范数: 随着熵系数的增加,LLM的梯度范数持续下降,表明LLM对熵正则化有明确的响应。以下是原文 Table 6 的数据:
# Iter. Temperature 0.03 0.1 0.3 80 55.7 / 40.0 55.7 / 40.0 55.4 / 39.6 200 55.7 / 40.0 55.8 / 40.0 55.8 / 40.6
表 6: Sinkhorn 算法迭代次数和温度的影响
- 参数鲁棒性: 在熵正则化系数固定为 1 的情况下,即使
Sinkhorn算法的迭代次数(80, 200)和温度(0.03, 0.1, 0.3)发生了显著变化,模型的平均/最差性能(以Avg. / Worst.形式表示)变化却微乎其微。 - 结论: 这表明
PEARL框架中的Sinkhorn算法对这些超参数的敏感性低于预期,可能意味着在实际应用中具有更广泛的稳定配置范围。
6.4. 总结
总的来说,所有实验结果都强有力地支持了 PEARL 框架的有效性和优越性。它不仅显著增强了 LLM 对输入排列的鲁棒性,有效抵御了排列攻击,还在平均性能、示例效率和泛化能力方面超越了现有基线方法。PEARL 通过在训练阶段采用 DRO 和对抗性排列生成,为构建更安全、更可靠的 LLM 提供了重要途径。
7. 总结与思考
7.1. 结论总结
本文介绍了 PEARL (Permutation-resilient learning) 框架,旨在增强大型语言模型 (LLMs) 对输入示例排列顺序变化的鲁棒性。PEARL 的核心在于结合了分布鲁棒优化 (DRO) 和一个专门的排列提议网络 (P-Net)。P-Net 利用熵约束 Sinkhorn 算法,将生成最具挑战性的排列建模为一个最优传输问题。通过最小最大优化的对抗性训练,P-Net 学习生成能最大化 LLM 损失的排列,而 LLM 则学习在这些最坏情况排列下最小化损失,从而逐步提升其鲁棒性。
实验结果在合成预训练任务和真实世界的指令微调任务上都验证了 PEARL 的有效性。PEARL 能够显著缓解排列攻击,并在平均和最坏情况性能上均超越了现有基线方法。特别是,它在面对未见的排列时表现出更强的泛化能力,并且在扩展到多示例和长上下文场景时,即使在较少示例和较短上下文下训练,也能实现高达 40% 的性能增益,同时显著提高了示例效率。
7.2. 局限性与未来工作
7.2.1. 局限性
- P-Net 的额外开销(训练阶段): 尽管
PEARL在推理时没有额外开销,但在训练阶段,需要训练一个额外的P-Net。虽然P-Net相对较小,且采用了LoRA进行高效微调,但它仍然引入了额外的计算复杂性。对于资源受限的场景,这可能是一个考量因素。 - 熵正则化项的调优: 熵正则化系数 的选择对
P-Net的学习行为有影响。虽然论文进行了超参数分析,但在更复杂的任务或更大规模的模型上,其调优可能仍需仔细探索。 - 最优传输的近似性:
Sinkhorn算法是最优传输问题的近似解,而Gumbel技巧也引入了近似。这些近似可能在一定程度上影响P-Net生成“真正最坏情况”排列的能力。 - 排列的复杂性: 论文主要关注示例对的顺序排列。然而,在更广泛的上下文中,
prompt中还可能存在其他形式的结构变化(例如,示例内容的修改,指令本身的微调),PEARL目前主要解决了顺序问题。
7.2.2. 未来工作
- 扩展到其他结构化输入: 论文指出,
PEARL可以作为一个通用框架,用于处理具有顺序无关元素的集合结构输入,例如多个文档、图像或视频。未来可以探索PEARL在这些非文本模态或多模态场景下的应用。 - P-Net 的效率和架构探索: 进一步优化
P-Net的架构和训练效率,例如,探索更轻量级或更通用(task-agnostic)的P-Net设计,以减少训练开销并提高泛化能力。 - 与其他鲁棒性技术的结合: 将
PEARL与其他LLM鲁棒性增强技术(例如,对抗性示例训练、数据增强方法)结合,以实现更全面的模型保护。 - 理论分析: 对
DRO在ICL排列鲁棒性中的理论保证进行更深入的分析,例如,量化模糊集的大小和特性如何影响模型的泛化界限。 - 更复杂的攻击模式: 探索除了简单排列之外的更复杂攻击模式(例如,插入、删除、修改示例内容),并研究
PEARL或其扩展版本如何应对这些更具挑战性的场景。
7.3. 个人启发与批判
7.3.1. 个人启发
- DRO 在 LLM 鲁棒性中的潜力: 本文首次将
DRO引入LLM的ICL鲁棒性问题,开辟了一个新的研究方向。DRO这种关注最坏情况的优化范式,对于LLM在安全、公平和可靠性等方面的挑战具有巨大的潜力。未来的LLM训练可能不仅仅追求平均性能,更会通过DRO等方法来提升其在各种边缘和对抗性场景下的表现。 - 训练阶段解决根本问题:
PEARL的训练阶段方法比推理阶段的后处理更具根本性。它使得模型本身具备了抵抗排列敏感性的能力,避免了额外的推理开销,这对于LLM的实际部署至关重要。这强调了在模型设计和训练阶段就考虑鲁棒性的重要性。 - 对抗性训练的精妙应用:
P-Net和LLM之间的最小最大游戏设计非常巧妙,它模拟了攻击者和防御者之间的博弈,使得模型能够主动学习和适应那些最容易导致其失败的输入模式。这种动态的对抗机制远比静态的数据增强方法更有效。 - 最优传输的新颖结合: 将最优传输用于生成对抗性排列是一种创新。它提供了一种数学上严谨且可微分的方法来探索和利用排列空间,解决了传统方法在排列数量呈指数级增长时面临的计算挑战。
7.3.2. 批判
- P-Net 的可解释性: 尽管
P-Net能够生成挑战性排列,但这些排列背后的具体“为什么”仍然不够直观。例如, 值代表“交换 和 可能对LLM任务难度增加的潜力”,但这种“潜力”的具体语义含义是什么?是结构上的混乱?还是语义上的误导?如果能更深入地分析P-Net生成的“最坏排列”的特征,将有助于理解LLM失败的深层原因。 - 最优传输与排列的映射:
最优传输主要处理两个分布之间的软映射。将其应用于生成离散的排列矩阵,虽然通过Sinkhorn和Gumbel技巧实现了可微分性,但这种映射的理论完备性或最佳性仍值得深入探讨。例如,P-Net真的能找到所有可能的“最坏情况”排列吗?或者它只是找到了一个局部最优解? - 通用性限制: 尽管论文声称
PEARL可以应用于其他集合结构输入,但其核心P-Net的设计(如特征提取器和关系建模)是为文本序列中的ICL示例定制的。对于图像、视频等多模态数据,P-Net的具体实现可能需要进行大幅修改,其通用性仍需进一步验证。 LLM架构不变性:PEARL的优势在于不修改LLM架构,但这也意味着它没有从根本上解决Transformer自回归架构带来的某些固有局限性。未来的研究可以探索如何在保持PEARL优势的同时,结合一些轻量级的架构调整,以实现更深层次的鲁棒性。- 超参数的敏感性: 尽管
Sinkhorn参数不敏感,但熵系数 的选择仍然关键。在不同任务、不同LLM模型或不同shot数量下,beta的最优值可能不同,需要进行仔细调优。这种调优过程本身可能会增加训练复杂性。
相似论文推荐
基于向量语义检索推荐的相关论文。