论文状态:已完成

SampleAttention: Near-Lossless Acceleration of Long Context LLM Inference with Adaptive Structured Sparse Attention

发表:2024/06/17
原文链接PDF 下载
价格:0.100000
价格:0.100000
价格:0.100000
已有 1 人读过
本分析由 AI 生成,可能不完全准确,请以原文为准。

TL;DR 精炼摘要

大型语言模型支持超长上下文,但传统注意力机制导致显著的时间延迟。本文提出了SampleAttention,一种自适应的近乎无损稀疏注意力方法,通过动态捕获稀疏模式,显著降低了推理延迟,同时保持模型的准确性。评估表明,该方法在TTFT上减少高达2.42倍,几乎不损失准确性。

摘要

Large language models (LLMs) now support extremely long context windows, but the quadratic complexity of vanilla attention results in significantly long Time-to-First-Token (TTFT) latency. Existing approaches to address this complexity require additional pretraining or finetuning, and often sacrifice model accuracy. In this paper, we first provide both theoretical and empirical foundations for near-lossless sparse attention. We find dynamically capturing head-specific sparse patterns at runtime with low overhead is crucial. To address this, we propose SampleAttention, an adaptive structured and near-lossless sparse attention. Leveraging observed significant sparse patterns, SampleAttention attends to a fixed percentage of adjacent tokens to capture local window patterns, and employs a two-stage query-guided key-value filtering approach, which adaptively select a minimum set of key-values with low overhead, to capture column stripe patterns. Comprehensive evaluations show that SampleAttention can seamlessly replace vanilla attention in off-the-shelf LLMs with nearly no accuracy loss, and reduces TTFT by up to 2.42×2.42\times compared with FlashAttention.

思维导图

论文精读

中文精读

1. 论文基本信息

1.1. 标题

SampleAttention: 针对长上下文大语言模型推理的近乎无损的自适应结构化稀疏注意力加速 (SampleAttention: Near-Lossless Acceleration of Long Context LLM Inference with Adaptive Structured Sparse Attention)

1.2. 作者

Qianchao Zhu, Jiangfei Duan, Chang Chen, Xiuhong Li, Siran Liu, Guanyu Feng, Xin Lv, Chuanfu Xiao, Dahua Lin, Chao Yang

1.3. 发表期刊/会议

该论文作为预印本 (preprint) 发布在 arXiv 上。arXiv 是一个收录物理学、数学、计算机科学、量化生物学、量化金融、统计学、电气工程与系统科学、经济学等领域科学论文预印本的网站。在计算机科学领域,许多重要的研究成果在正式发表前都会首先在 arXiv 上发布,因此其影响力广泛且受到学术界的认可。

1.4. 发表年份

2024年 (具体发布时间为 2024-06-17T11:05:15.000Z)

1.5. 摘要

大型语言模型 (LLMs) 现在支持极长的上下文窗口,但传统的注意力机制的二次复杂度导致首次生成词元时间 (Time-to-First-Token, TTFT) 延迟显著。现有解决此复杂性的方法通常需要额外的预训练或微调 (finetuning),并且常常牺牲模型准确性。本文首先为近乎无损的稀疏注意力提供了理论和经验基础。研究发现,在运行时以低开销动态捕获头部特有的稀疏模式至关重要。为此,本文提出了 SampleAttention,一种自适应结构化且近乎无损的稀疏注意力方法。利用观察到的显著稀疏模式,SampleAttention 通过关注固定百分比的相邻词元来捕获局部窗口模式,并采用两阶段查询引导的键值 (key-value) 过滤方法,以低开销自适应地选择最小的关键键值集,从而捕获列条带模式。全面的评估表明,SampleAttention 可以在现成的 LLMs 中无缝替代传统注意力,几乎不损失准确性,并且与 FlashAttention 相比,TTFT 减少高达 2.42×2.42\times(摘要原文,正文提升至 5.29×5.29\times)。

1.6. 原文链接

  • 原文链接: https://arxiv.org/abs/2406.15486
  • PDF 链接: https://arxiv.org/pdf/2406.15486v3.pdf

2. 整体概括

2.1. 研究背景与动机

大型语言模型 (LLMs) 在处理复杂应用(如文档分析、代码辅助编程、长时间对话)方面取得了显著进展,这得益于它们支持的日益增长的上下文窗口。目前流行的 LLMs,如 GeminiClaudeKimi,甚至能够支持超过一百万词元 (token) 的上下文长度。

然而,这种上下文长度的显著增加也带来了一个核心挑战:传统注意力机制的二次复杂度。注意力机制的计算成本和内存占用与序列长度的平方成正比。这意味着当处理极长序列时,模型生成第一个词元的时间 (Time-to-First-Token, TTFT) 会急剧增加,形成严重的性能瓶颈。例如,论文指出在一个一百万词元的上下文情境下,ChatGLM3-6B 模型在 A100 GPU 上进行注意力计算需要 1555 秒,这占到了 TTFT 的 90% 以上。这种高延迟使得 LLMs 在需要实时响应的实际应用中变得不切实际。

现有解决这一问题的方法主要有两类:

  1. 静态稀疏注意力 (Static Sparse Attention): 这类方法预定义了稀疏模式(如 LongFormerBigBird),但它们往往无法捕获注意力模式的动态变化,从而牺牲了模型准确性。

  2. 动态稀疏注意力 (Dynamic Sparse Attention): 这类方法试图在运行时识别重要的注意力部分。例如,DSA 引入了高计算开销,HyperAttention 的粗粒度选择方法难以保持准确性,而 MInference 虽然考虑了列条带 (column stripe) 和斜杠条带 (slash stripe) 模式,但其依赖预定义模式和固定预算,未能充分适应注意力头 (attention head) 之间、输入内容之间以及模型架构之间变化的稀疏比率和动态模式。

    因此,当前研究面临的挑战在于:

  • 自适应稀疏比率 (Adaptive Sparsity Ratio): 最佳稀疏比率在不同注意力头、输入内容和模型架构之间是动态变化的,需要运行时确定。

  • 动态稀疏模式 (Dynamic Sparse Pattern): 注意力模式本身也是动态变化的,通常是列模式(全局上下文信息)和斜杠模式(局部窗口信息)的复杂组合,这使得精确选择稀疏模式变得困难。

    这些挑战共同构成了现有方法难以在效率和准确性之间取得最佳平衡的根本原因,促使研究人员寻求一种更具适应性、运行时高效的方法来同时确定稀疏比率和模式。

2.2. 核心贡献/主要发现

本文通过提出 SampleAttention 框架,旨在解决长上下文 LLM 推理中的 TTFT 延迟问题,同时保持近乎无损的模型准确性。其核心贡献和主要发现包括:

  1. 理论和经验基础 (Theoretical and Empirical Foundations): 提供了近乎无损稀疏注意力的理论和经验基础,强调了在运行时以低开销动态捕获头部特有稀疏模式的重要性。
  2. 提出 SampleAttention 方法 (Introduction of SampleAttention):
    • 自适应结构化稀疏注意力 (Adaptive Structured Sparse Attention): SampleAttention 能够动态地确定稀疏比率和模式,解决了现有方法在处理动态稀疏性方面的不足。
    • 两阶段查询引导键值过滤 (Two-stage Query-Guided Key-Value Filtering): 设计了一个创新的两阶段过滤算法,用于高效识别重要的列条带和斜杠条带模式。
      • 第一阶段:查询引导分块采样 (Query-Guided Chunked Sampling): 通过对查询进行分块采样,更准确地估计整个注意力分数矩阵的稀疏结构,克服了仅采样末尾查询的局限性。
      • 第二阶段:基于得分的键值过滤 (Score-Based Key-Value Filtering): 基于采样分数和 CRA 阈值,独立选择关键的列和斜杠条带,显著降低了计算复杂性。
  3. 引入累积残差注意力 (Cumulative Residual Attention, CRA) 指标: 提出 CRA 作为衡量模型准确性和注意力召回率的稳健指标。研究发现 CRA 阈值与模型准确性之间存在一致的正相关性,为在效率和准确性之间做出权衡提供了原则性的指导。
  4. 自动超参数调优方法 (Automated Hyperparameter Tuning): 提供了一种自动调优超参数 (αcαc, αsαs, chunkn) 的方法,使用小规模评测数据集为不同长度范围内的模型确定最优设置,从而在保持准确性的同时最大化计算效率。
  5. 硬件高效实现 (Hardware-efficient Implementation): SampleAttention 通过融合操作符和修改 FlashAttention2 内核实现,具有 IO 感知能力,显著提升了实际运行速度。
  6. 卓越的实验结果 (Outstanding Experimental Results):
    • 近乎无损的准确性 (Near-Lossless Accuracy):ChatGLMYIInternLM 等模型上,SampleAttentionRULERLongBenchInfiniteBench 等多个长上下文基准测试中表现出与全注意力 (Full Attention) 几乎一致的准确性(通常达到 99% 以上)。
    • 显著的 TTFT 加速 (Significant TTFT Acceleration):FlashAttention2 相比,SampleAttention 实现了高达 5.29×5.29\timesTTFT 减少,尤其在一百万词元的超长序列上效果显著。
    • 建立新的帕累托前沿 (New Pareto Frontier): 在准确性-效率权衡方面,SampleAttention 超越了现有方法,建立了新的帕累托前沿。

3. 预备知识与相关工作

3.1. 基础概念

为了理解 SampleAttention 的工作原理,需要首先理解大型语言模型 (LLMs) 的基本架构、推理过程以及核心的注意力机制。

3.1.1. 大型语言模型 (LLMs)

LLMs 是基于 Transformer 架构的深度学习模型,通过在海量文本数据上进行预训练,学习语言的统计规律和语义信息。它们能够理解、生成和处理人类语言,并在各种自然语言处理任务中表现出卓越的性能。

3.1.2. Transformer 架构

Transformer (Vaswani et al., 2017) 是 LLMs 的核心。它由多个编码器 (encoder) 和解码器 (decoder) 堆叠而成,但 LLMs 通常采用仅解码器 (decoder-only) 架构。每个解码器层包含一个自注意力 (Self-Attention) 层和一个前馈网络 (Feed-Forward Network, MLP)。自注意力机制允许模型在处理序列中的每个词元时,关注序列中的其他所有词元,从而捕获长距离依赖关系。

3.1.3. LLM 推理过程

LLM 的推理过程分为两个阶段:

  • 预填充阶段 (Prefill Phase): 模型并行处理整个输入提示 (prompt),并生成第一个输出词元。在此阶段,还会为提示中的每个词元生成并存储键值 (Key-Value, KV) 缓存,以供后续计算使用。
  • 解码阶段 (Decoding Phase): 在预填充阶段之后,模型根据所有先前的词元顺序生成每个新词元。模型每次接收一个输出词元作为输入,并利用 KV 缓存自回归地生成下一个新词元。新生成的词元的 KV 缓存也会被保存,作为后续生成的上下文。 本论文主要关注的是预填充阶段的优化,因为在长序列处理中,预填充阶段的计算量巨大,导致 TTFT 延迟成为瓶颈。

3.1.4. 注意力机制 (Attention Mechanism)

注意力机制是 Transformer 的核心。给定查询 (Query, Q\mathbf{Q})、键 (Key, K\mathbf{K}) 和值 (Value, V\mathbf{V}) 向量,注意力机制计算每个查询与所有键的相似度,然后将这些相似度作为权重,对所有值进行加权求和,得到输出。

对于一个注意力头,输入为查询 (Query) 矩阵 QRSq×d\mathbf{Q} \in \mathbb{R}^{S_q \times d},键 (Key) 矩阵 KRSk×d\mathbf{K} \in \mathbb{R}^{S_k \times d},以及值 (Value) 矩阵 VRSk×d\mathbf{V} \in \mathbb{R}^{S_k \times d},其中 SqS_q 是查询序列长度,SkS_k 是键值序列长度,dd 是头维度 (head dimension)。标准的注意力计算公式为:

P=softmax(QKTd)[0,1]Sq×Sk \mathbf{P} = \mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d}}\right) \in [0, 1]^{S_q \times S_k}

O=PVRSq×d \mathbf{O} = \mathbf{P}\mathbf{V} \in \mathbb{R}^{S_q \times d}

其中:

  • QKT\mathbf{Q}\mathbf{K}^T: 计算查询和键之间的相似度分数矩阵。

  • d\sqrt{d}: 缩放因子,用于防止点积结果过大,导致 softmax 函数梯度过小。

  • softmax\mathrm{softmax}: 行式 (row-wise) 应用的 softmax 函数,将分数归一化为概率分布 P\mathbf{P}

  • P\mathbf{P}: 注意力分数 (attention score) 矩阵,其元素表示每个查询对每个键的关注程度。

  • O\mathbf{O}: 注意力输出矩阵,是加权后的值向量之和。

    挑战: 在长序列中,注意力分数矩阵 P\mathbf{P} 的内存占用和计算复杂度都与序列长度的平方 (Sq×SkS_q \times S_k) 成正比。尽管 FlashAttention (Dao et al., 2022) 通过在线 softmax 计算有效解决了内存瓶颈,但二次方的计算复杂度问题仍然存在,导致长上下文推理的 TTFT 居高不下。

3.1.5. 键值缓存 (KV Cache)

在自回归生成过程中,为了避免重复计算已经生成词元的键和值,Transformer 模型会存储这些词元的 KV 向量,形成 KV 缓存。在生成每个新词元时,新的查询向量只需要与 KV 缓存中的所有键向量计算注意力,从而避免了对整个历史序列的重复键值计算。然而,KV 缓存的内存消耗也会随着序列长度的增加而线性增长,成为另一个挑战。

3.1.6. 稀疏注意力 (Sparse Attention)

鉴于全注意力 (Full Attention) 的高计算成本,稀疏注意力旨在仅计算注意力分数矩阵 P\mathbf{P} 中的一部分重要元素,从而降低复杂性。其形式化表示为:

P^=softmax(QKTdc(1M)) \hat{\mathbf{P}} = \mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d}} - c(1 - \mathbf{M})\right)

其中:

  • M{0,1}Sq×Sk\mathbf{M} \in \{0, 1\}^{S_q \times S_k} 是一个二值掩码矩阵 (binary mask matrix),用于确定稀疏注意力模式。
  • cc 是一个大的常数,它在 softmax 操作后有效地将被掩码的注意力分数归零。 稀疏比率 (sparsity ratio) 衡量被掩码的注意力分数所占的百分比。论文观察到 LLMs 本身就具有很高的稀疏性(平均稀疏比率超过 89.6%),这为稀疏注意力提供了基础。

3.2. 前人工作

现有研究在解决长上下文 LLM 推理效率方面主要有以下几类方法:

3.2.1. 近似注意力 (Approximate Attention)

这类方法旨在通过各种机制降低注意力机制的二次复杂度。

  • 静态稀疏模式 (Static Sparse Patterns):

    • LongFormer (Beltagy et al., 2020) 结合了局部窗口注意力 (window attention) 和全局注意力 (global attention)。
    • BigBird (Zaheer et al., 2020) 进一步引入了随机注意力 (random attention) 来增强长距离依赖捕获。
    • LongNet (Ding et al., 2023) 使用膨胀注意力 (dilated attention) 替代全注意力。
    • StreamingLLM (Xiao et al., 2023b) 通过保留注意力池 (attention sink) 和最近的词元来实现无限长度的生成。
    • 局限性: 这些方法使用预定义的静态稀疏模式,无法捕获注意力头和输入内容之间动态变化的稀疏模式,因此在不进行额外微调或预训练的情况下,往往无法达到全注意力的准确性。
  • 动态稀疏模式 (Dynamic Sparse Patterns):

    • Reformer (Kitaev et al., 2020) 通过局部敏感哈希 (Locality Sensitive Hashing, LSH) 减少计算成本。
    • Linformer (Wang et al., 2020) 使用低秩矩阵近似注意力。
    • HyperAttention (Han et al., 2023) 利用 LSH 识别注意力图中的重要条目,但其粗粒度的选择方法难以保持模型准确性。
    • DSA (Liu et al., 2022) 使用低秩隐维度近似注意力模式,但在长上下文下引入了显著的计算开销。
    • MInference (Jiang et al., 2024) 尝试通过预定义几种稀疏模式(A-shape, Vertical-Slash, Block-Sparse),并在离线阶段将注意力头匹配到最优模式,然后在预填充阶段动态搜索稀疏索引。局限性: 这种方法依赖预定义模式和固定预算,未能充分适应不同注意力头之间变化的稀疏比率,以及输入内容导致的动态稀疏模式。
  • 其他近似方法:FlashAttention (Dao et al., 2022) 解决了内存瓶颈,但二次方的计算复杂度仍未解决。

3.2.2. KV 缓存压缩 (KV Cache Compression)

这类方法主要关注减少 KV 缓存的内存消耗,通常在解码阶段发挥作用。

  • StreamingLLM (Xiao et al., 2023b) 保持注意力池和少量最新词元。
  • H2O (Zhang et al., 2024c) 根据注意力分数动态保留最近和“重击手” (heavy hitter) 词元。
  • FastGen (Ge et al., 2023) 根据观察到的头部特定策略自适应地构建 KV 缓存。
  • SnapKV (Li et al., 2024) 通过选择每个注意力头的关键聚类位置来压缩 KV 缓存。
  • CHAI (Agarwal et al., 2024) 剪枝 (pruning) 注意力头以降低 KV 缓存开销。
  • KV 缓存量化 (quantization) (Duanmu et al., 2024; Xiao et al., 2023a; Zhao et al., 2023) 也被用于降低内存消耗。
  • 与本文关系: SampleAttention 关注的是预填充阶段的计算开销,而 KV 缓存压缩关注解码阶段的内存消耗。两者是正交且可以结合使用的,以进一步优化性能。

3.3. 技术演进

从全注意力到稀疏注意力的技术演进可以概括为:

  1. 全注意力 (Vanilla Attention): 提供最强的表达能力,但计算复杂度高 (O(N2)O(N^2))。
  2. 静态稀疏注意力 (Static Sparse Attention): 通过预设规则(如窗口、全局、膨胀注意力)降低复杂度 (O(N)O(N)O(NlogN)O(N \log N)),但难以适应动态内容,通常需要重新训练或微调。
  3. 动态稀疏注意力 (Dynamic Sparse Attention): 尝试在运行时识别重要注意力区域,以期在效率和准确性之间取得更好的平衡。早期的 LSH 方法可能过于粗粒度。MInference 等工作开始关注注意力模式的结构化特点(列条带、斜杠条带)。
  4. SampleAttention 的位置: SampleAttention 建立在动态稀疏注意力的基础上,并进一步认识到稀疏比率和模式的高度自适应性。它通过引入 CRA 指标和两阶段查询引导过滤,旨在更精确、更高效地在运行时捕获这种动态稀疏性,从而在不牺牲准确性的前提下实现显著加速。

3.4. 差异化分析

SampleAttention 与现有工作的主要区别和创新点体现在以下几个方面:

  • 自适应稀疏比率与模式 (Adaptive Sparsity Ratio and Pattern):

    • 现有工作: MInference (Jiang et al., 2024) 和 DuoAttention (Xiao et al., 2024) 采用固定稀疏预算或预定义模式,无法有效适应注意力头、输入内容和模型架构之间变化的稀疏比率和动态模式。
    • SampleAttention: 强调并解决了稀疏比率在头部、内容和模型之间的自适应性。它在运行时动态确定每个注意力头和输入提示的最优稀疏比率,并通过两阶段过滤自适应地选择列条带和斜杠条带模式,能够捕获更复杂的混合模式。
  • CRA 指导下的准确性-效率权衡 (CRA-guided Accuracy-Efficiency Trade-off):

    • 现有工作: 通常缺乏一个通用的、可量化的指标来指导稀疏化决策,导致在效率提升的同时难以保证准确性。
    • SampleAttention: 引入了 Cumulative Residual Attention (CRA) 作为衡量模型准确性和注意力召回率的稳健指标。CRA 阈值与模型准确性之间的一致正相关性提供了一个原则性的方法来平衡效率和准确性,允许动态调整稀疏化程度。
  • 低开销的运行时模式选择 (Low-Overhead Runtime Pattern Selection):

    • 现有工作: 动态方法(如 DSA)可能引入显著的计算开销,而另一些方法(如 MInference)的模式匹配依赖离线分析。
    • SampleAttention: 提出了高效的两阶段查询引导键值过滤。第一阶段的分块采样以低开销估计注意力分数,克服了现有方法(如 MInference 的仅末尾查询采样)的局限性。第二阶段的独立过滤(列和斜杠)将复杂度从乘法降低到加法,进一步提高了效率。
  • 无需额外预训练/微调 (No Additional Pretraining/Finetuning Required):

    • 现有工作: 许多静态稀疏注意力方法(如 LongFormerBigBird)需要额外的预训练或微调才能保持准确性。
    • SampleAttention: 旨在作为现成 LLMs 的即插即用替代品,几乎不损失准确性,因为其核心思想是精确地近似全注意力,而非改变模型行为。
  • 自动化超参数调优 (Automated Hyperparameter Tuning):

    • 现有工作: 稀疏注意力方法通常涉及多个超参数,手动调优复杂。
    • SampleAttention: 提供了一种自动化的离线调优方法,使用小规模数据集为不同长度范围的序列确定最优超参数,简化了部署和使用。

4. 方法论

4.1. 方法原理

SampleAttention 的核心思想是动态地、以低开销地在运行时识别和选择注意力分数矩阵中最关键的“列条带”和“斜杠条带”模式,以近似全注意力,从而显著降低计算复杂度,同时通过 Cumulative Residual Attention (CRA) 指标保证近乎无损的准确性

论文观察到 LLMs 注意力机制固有的高稀疏性,但这种稀疏性并非均匀分布,而是呈现出自适应的稀疏比率(因注意力头、输入内容和模型架构而异)和动态的稀疏模式(通常是捕获全局上下文的“列模式”和捕获局部上下文的“斜杠模式”的组合)。现有的静态或粗粒度稀疏注意力方法难以捕捉这种动态性。

为了解决这一挑战,SampleAttention 提出了:

  1. CRA 作为指导指标: CRA 量化了稀疏化后每个查询剩余注意力概率的最小值总和,被证明与模型准确性呈正相关。这使得 SampleAttention 能够设定一个 CRA 阈值,以此为目标动态选择最关键的注意力索引,从而在效率和准确性之间找到最佳平衡。
  2. 两阶段查询引导键值过滤 (Two-stage Query-Guided Key-Value Filtering):
    • 第一阶段:查询引导分块采样 (Query-Guided Chunked Sampling): 针对整个注意力分数矩阵,通过采样少量查询块来估计其稀疏结构。这种分块采样比仅采样末尾查询更稳定和准确,尤其在处理复杂混合模式时。

    • 第二阶段:基于得分的键值过滤 (Score-Based Key-Value Filtering): 基于第一阶段得到的采样注意力分数,独立地筛选出满足 CRA 阈值要求的关键列条带和斜杠条带。这种分解策略显著降低了组合选择的计算开销。

      通过这种方式,SampleAttention 能够在运行时自适应地确定稀疏比率和模式,生成一个几乎无损的稀疏掩码,然后利用这个掩码进行稀疏计算,从而加速长上下文 LLM 的预填充过程。

4.2. 核心方法详解

SampleAttention 的核心是一个两阶段的查询引导键值过滤算法,旨在高效地识别并利用注意力矩阵中的自适应结构化稀疏模式。以下是详细的步骤和相关公式解释。

4.2.1. 累积残差注意力 (Cumulative Residual Attention, CRA)

CRASampleAttention 中用来平衡效率和准确性的关键指标。它被定义为在稀疏化后,每个查询所剩余的注意力概率的最小值总和。

具体来说,假设 P\mathbf{P} 是原始的全注意力分数矩阵。如果通过一个掩码 M\mathbf{M} 得到稀疏注意力分数矩阵 P^\hat{\mathbf{P}},则 CRA 衡量的是 P^\hat{\mathbf{P}} 相对于 P\mathbf{P} 的“召回率”或“保留度”。

CRA(P^,P)=i=1Sqj=1SkP^iji=1Sqj=1SkPij \mathrm{CRA}(\hat{\mathbf{P}}, \mathbf{P}) = \frac{\sum_{i=1}^{S_q} \sum_{j=1}^{S_k} \hat{\mathbf{P}}_{ij}}{\sum_{i=1}^{S_q} \sum_{j=1}^{S_k} \mathbf{P}_{ij}}

然而,论文中对 CRA 的描述更侧重于其作为最小注意力概率总和的指标,即在给定一个稀疏模式后,有多少重要的注意力值被保留下来。更直观的理解是,它反映了稀疏化后的注意力分布与原始全注意力分布的相似性。如图 5 所示,CRA 阈值与模型准确性之间存在正相关性,意味着通过控制 CRA 阈值,可以间接控制模型准确性。

SampleAttention 的目标是找到一个最小的键值索引集合,使得在这个集合上计算的稀疏注意力能够达到预设的 CRA 阈值 α\alpha。这个 α\alpha 被进一步分解为列方向的阈值 αc\alpha_c 和斜杠方向的阈值 αs\alpha_s

4.2.2. 两阶段查询引导键值过滤

SampleAttention 的两阶段实现如算法 1 所示。

算法 1:SampleAttention 的两阶段实现 Input: Q, K, V, αc, αs ∈ [0, 1], chunkn # Stage1: Query-Guided Chunked Attention Sampling Qslice ← [Q[i*itv-blk:i*itv] for i in range(1, chunkn +1)] Å ← softmax (QsliceK T /√d + mcasual) Åc, Ås ← block_reduction (Å, blk) # Stage2: Score-Based Key-Value Block Filtering kc ← find_k (cumsum(sort(Åc), αc) Ic ← arg_topk (Åc, kc) ks ← find_k (cumsum(sort(Ås)), αs) Is ← arg_topk (Ås, ks) # Extend and Merge Block-sparse Mask across Each Head M̂ ← merge_index(Ic, Is, itv) # Final Sparse FlashAttention with Block Index O ← sparse_flash_attn (Q, K, V, M̂)

输入:

  • Q,K,V\mathbf{Q}, \mathbf{K}, \mathbf{V}: 查询、键、值矩阵。
  • αc,αs[0,1]\alpha_c, \alpha_s \in [0, 1]: 分别是列和斜杠方向的 CRA 阈值。
  • chunkn: 采样分块 (chunk) 的数量。

4.2.2.1. 第一阶段:查询引导分块采样 (Query-Guided Chunked Attention Sampling)

这一阶段的目的是通过对注意力分数矩阵进行高效采样,来估计其整体的稀疏结构。

  1. 分块查询选择: Qslice [Q[iitvblk:iitv] for i in range(1,chunkn+1)]\leftarrow [\mathbf{Q}[i \cdot \text{itv} - \text{blk} : i \cdot \text{itv}] \text{ for } i \text{ in range}(1, \text{chunkn} + 1)]

    • itv (interval): 每个采样分块的间隔,通常是 Sq / chunkn
    • blk (block size): 采样块的实际大小。
    • Qslice: 这是一个包含 chunkn 个查询子矩阵的列表。SampleAttention 不是随机采样,也不是仅采样序列末尾的查询(如 MInference),而是将整个查询序列 Q\mathbf{Q} 分成 chunkn 个等长的段,并从每个段的末尾选择一个大小为 blk 的查询块进行采样。这种等距分块采样策略能够更全面地覆盖注意力矩阵,从而更准确地捕捉不同区域的稀疏模式,特别是那些在序列中间也可能出现的复杂混合模式(如图 4c 所示)。
  2. 注意力分数估计: Å softmax(QsliceKT/d+mcasual)\leftarrow \mathrm{softmax}(\mathbf{Qslice}\mathbf{K}^T / \sqrt{d} + \mathbf{mcasual})

    • Qslice 这里代表了所有采样查询块的堆叠。
    • QsliceKT/dQsliceK^T / \sqrt{d}: 对所有采样的查询块与所有键计算点积注意力分数,并进行缩放。
    • + mcaual: 添加因果掩码 (casual mask),确保每个查询只能关注其之前的键。
    • softmax: 对计算出的分数进行 softmax 操作,得到采样注意力分数矩阵 Å。这个矩阵包含了对全注意力分数矩阵的稀疏估计。
  3. 块级别分数规约 (block_reduction): Åc, Ås block_reduction(A˚,blk)\leftarrow \mathrm{block\_reduction}(\text{Å}, \text{blk})

    • Åc: 列方向的块级注意力分数。block_reduction 函数沿着列方向对 Å 进行规约(例如求和或平均),得到每个键块的聚合重要性分数。
    • Ås: 斜杠方向的块级注意力分数。block_reduction 函数沿着斜杠方向对 Å 进行规约,得到每个斜杠块的聚合重要性分数。
    • blk: 用于规约的块大小。
    • 这一步将词元级别 (token-level) 的采样分数转换为块级别 (block-level) 的分数,为第二阶段的过滤做准备,降低了后续处理的粒度。

4.2.2.2. 第二阶段:基于得分的键值过滤 (Score-Based Key-Value Block Filtering)

这一阶段的目标是基于第一阶段得到的块级别采样分数,并利用预设的 CRA 阈值,高效地选择出最关键的列和斜杠条带。

  1. 列条带过滤:

    • kc find_k(cumsum(sort(A˚c)),αc)\leftarrow \mathrm{find\_k}(\mathrm{cumsum}(\mathrm{sort}(\text{Åc})), \alpha_c)
      • sort(Åc): 将列方向的块级分数 Åc 进行排序。
      • cumsum(...): 计算排序后分数的累积和 (cumulative sum)。
      • findk(...,αc)find_k(..., αc): 找到最小的 kk 值(即 kc),使得这些 kk 个最高分数的累积和达到预设的列方向 CRA 阈值 αc\alpha_c。这实现了动态确定需要保留的列块数量。
    • Ic arg_topk(A˚c,kc)\leftarrow \mathrm{arg\_topk}(\text{Åc}, \text{kc})
      • arg_topk(Åc, kc): 找到在 Åc 中对应 kc 个最高分数的块索引 Ic。这些是需要保留的列条带的索引。
  2. 斜杠条带过滤:

    • ks find_k(cumsum(sort(A˚s)),αs)\leftarrow \mathrm{find\_k}(\mathrm{cumsum}(\mathrm{sort}(\text{Ås})), \alpha_s)
      • ks: 找到最小的 kk 值(即 ks),使得这些 kk 个最高斜杠分数的累积和达到预设的斜杠方向 CRA 阈值 αs\alpha_s
    • Is arg_topk(A˚s,ks)\leftarrow \mathrm{arg\_topk}(\text{Ås}, \text{ks})
      • Is: 找到在 Ås 中对应 ks 个最高分数的块索引 Is。这些是需要保留的斜杠条带的索引。

        通过将总 CRA 阈值 α\alpha 分解为 αc\alpha_cαs\alpha_s 并独立进行过滤,计算复杂度从 O(ncolumn×nslash)O(n_{column} \times n_{slash}) 降低到 O(ncolumn+nslash)O(n_{column} + n_{slash}),其中 ncolumnn_{column}nslashn_{slash} 分别是可能的列和斜杠条带数量。这大大提高了效率。

4.2.2.3. 构建稀疏掩码并进行稀疏注意力计算

  1. 扩展和合并稀疏掩码 (Extend and Merge Block-sparse Mask): M^\hat{\mathbf{M}} merge_index(Ic,Is,itv)\leftarrow \mathrm{merge\_index}(\text{Ic}, \text{Is}, \text{itv})

    • IcIs 是块级别索引。SampleAttention 会将这些索引扩展到整个注意力矩阵的词元级别,以形成完整的列条带和斜杠条带模式。
    • 如果来自不同采样分块的索引选择有所不同,merge_index 函数将它们合并成一个最终的块稀疏掩码 M^\hat{\mathbf{M}}。这个掩码涵盖了关键的注意力池 (attention sinks) 和局部窗口 (local window) 模式,确保了准确性。
  2. 最终稀疏 FlashAttention 计算 (Final Sparse FlashAttention): O\mathbf{O} sparse_flash_attn(Q,K,V,M^)\leftarrow \mathrm{sparse\_flash\_attn}(\mathbf{Q}, \mathbf{K}, \mathbf{V}, \hat{\mathbf{M}})

    • 使用生成的稀疏掩码 M^\hat{\mathbf{M}},在 FlashAttention2 的修改版内核中执行稀疏注意力计算,得到最终的注意力输出 O\mathbf{O}

4.2.3. 超参数调优 (Hyperparameter Tuning)

SampleAttention 引入了三个可调超参数,如 Table 1 所示,它们影响着效率和准确性之间的权衡:

Table 1. The meaning of hyperparameters and they will be tuned offine for different length ranges.

Hyperparameter Description
αcαc The desired CRA threshold for columns
αsαs The desired CRA threshold for slashes
chunkn The number of sampling chunks
  • 列和斜杠的阈值 (αcαc, αsαs):

    • 影响: 较大的阈值通常能提升模型准确性,但会降低加速比(因为需要保留更多的注意力元素)。较小的阈值则能带来更大的加速,但可能牺牲准确性。
    • 调优方法: 离线使用紧凑数据集预先确定具有成本效益的阈值。可以参考 FlashAttention2 的准确性和延迟作为基准。对于长上下文任务,可以根据长度范围对上下文进行分段,并为每个段单独调优阈值,以实现更优的稀疏性利用。
  • 采样分块数量 (chunkn):

    • 影响: chunkn 影响采样位置和选择索引的比率。

      • 过少的采样可能无法捕捉完整的稀疏结构,从而降低准确性。
      • 过多的采样会增加采样开销,并可能引入冗余计算,降低加速比。
    • 调优方法: 在调优过程中,尝试多个 chunkn 值以扩展搜索空间,找到更高效的配置。论文实验表明,适当增加 chunkn (例如从 1 到 2)可以在不显著影响加速比的情况下提高准确性。

      这些超参数的调优是自动化的,通过在小规模验证数据集上进行,使得 SampleAttention 能够快速适应不同的模型和任务,在不同长度范围内实现最佳的效率-准确性平衡。

4.2.4. 硬件高效实现 (Hardware-efficient Implementation)

为了实现显著的挂钟时间 (wall-clock time) 加速,SampleAttention 进行了 IO 感知的硬件优化:

  1. 操作符融合 (Operator Fusion): 查询引导键值过滤阶段涉及一系列小型操作(如 bmmmask_fillsoftmaxreduction),这些操作通常会读写大量的中间结果,导致 IO 开销。SampleAttention 通过融合这些操作符,显著减少了 IO 开销。

  2. 修改 FlashAttention2 内核 (Modified FlashAttention2 Kernel): SampleAttentionFlashAttention2 的基础上实现了高效的自适应结构化稀疏注意力内核。这意味着它利用了 FlashAttention2 已经实现的内存优化,并在此基础上增加了对稀疏模式的硬件高效支持。

    这些优化共同确保了 SampleAttention 不仅在理论上减少了计算量,而且在实际硬件上也能转化为显著的速度提升。

5. 实验设置

5.1. 数据集

为了全面评估 SampleAttention 的性能,研究在三个不同类型的长上下文任务基准测试上进行了实验:

5.1.1. RULER (Hsieh et al., 2024)

  • 特点: RULER 提供了一个灵活的配置平台,用于全面评估长上下文语言模型。它扩展了传统的“大海捞针” (needle-in-a-haystack) 测试,引入了多样化的“针”的类型和数量,并增加了多跳追踪和聚合等新任务,从而评估模型除简单检索之外更复杂的行为。
  • 规模: 包含 13 个任务。
  • 用途: 是测试长上下文理解能力的优秀工具。论文使用 RULER 生成小规模任务来调优 SampleAttention 的超参数,并在不同长度范围 (16K, 32K, 64K, 128K) 内进行评估。

5.1.2. LongBench (Bai et al., 2023)

  • 特点: 一个多任务基准测试,涵盖了单文档和多文档问答 (QA)、摘要 (summarization)、少样本学习 (few-shot learning)、合成任务 (synthetic tasks) 和代码补全 (code completion)。
  • 规模: 提供超过 4,750 个测试用例,任务长度范围从 4K35K
  • 用途: 用于评估 SampleAttention 在多种长上下文任务上的泛化能力和准确性。

5.1.3. InfiniteBench (Zhang et al., 2024b)

  • 特点: 专门设计用于评估语言模型在上下文长度超过 200K 的情境下处理、理解和推理的能力。
  • 规模: 包含 10 个独特的任务,每个任务都旨在评估在扩展上下文中的语言处理和理解的不同方面。
  • 用途: 用于测试 SampleAttention 在超长上下文情境下的性能极限。

5.2. 评估指标

论文主要通过以下指标来评估模型的性能:

5.2.1. 准确性 (Accuracy)

  • 概念定义: 衡量模型在特定任务上输出与真实标注数据 (Ground Truth) 匹配或接近的程度。在 LLM 任务中,准确性通常是根据任务类型(如问答的 F1 分数或精确匹配,摘要的 ROUGE 分数,代码补全的通过率等)而定的。论文中提到 SampleAttention 达到了全注意力模型 99% 以上的准确性,表明其在性能上近乎无损。
  • 数学公式: 具体的准确性公式取决于不同的任务。例如,对于分类任务: Accuracy=Number of Correct PredictionsTotal Number of Predictions \text{Accuracy} = \frac{\text{Number of Correct Predictions}}{\text{Total Number of Predictions}} 对于问答任务,常用的可能是精确匹配 (Exact Match, EM) 或 F1 分数。由于论文评估了多种任务,没有给出统一的公式,但通常指的是任务相关的标准评估指标。
  • 符号解释:
    • Number of Correct Predictions: 模型给出正确预测的数量。
    • Total Number of Predictions: 模型进行预测的总数量。

5.2.2. 首次生成词元时间 (Time-to-First-Token, TTFT) 延迟

  • 概念定义: TTFT 衡量的是用户发出请求到模型生成并输出第一个词元所需的时间。在交互式 LLM 应用中,TTFT 是一个关键的用户体验指标,因为它直接影响用户的感知延迟。它主要受预填充阶段的计算量影响。
  • 数学公式: 通常以毫秒 (ms) 或秒 (s) 为单位进行测量,没有复杂的公式,直接是时间差。 TTFT=TimeFirst Token OutputTimeRequest Received \text{TTFT} = \text{Time}_{\text{First Token Output}} - \text{Time}_{\text{Request Received}}
  • 符号解释:
    • TimeFirst Token Output\text{Time}_{\text{First Token Output}}: 模型输出第一个词元的时间点。
    • TimeRequest Received\text{Time}_{\text{Request Received}}: 系统接收到用户请求的时间点。

5.2.3. 加速比 (Speedup)

  • 概念定义: 衡量新方法相比于基线方法在执行时间上的提升倍数。加速比越高,说明新方法越高效。
  • 数学公式: Speedup=Execution Time of BaselineExecution Time of New Method \text{Speedup} = \frac{\text{Execution Time of Baseline}}{\text{Execution Time of New Method}}
  • 符号解释:
    • Execution Time of Baseline: 基线方法(例如 FlashAttention2)的执行时间。
    • Execution Time of New Method: 新方法(例如 SampleAttention)的执行时间。

5.3. 对比基线

为了评估 SampleAttention 的性能,论文将其与多种现有方法进行了比较:

  • 全注意力 (Full Attention): 作为黄金标准基线,代表了理论上最高的准确性(无稀疏化),但计算成本最高。论文使用 FlashAttention2 (Dao, 2023) 实现的全注意力作为性能基线,衡量 TTFT 加速比。
  • Minference (Jiang et al., 2024): 一种动态稀疏注意力方法,通过预分析确定注意力头的最优模式,并在运行时搜索稀疏索引。
  • BigBird (Zaheer et al., 2020): 一种静态稀疏注意力方法,结合了窗口注意力、全局注意力和随机注意力。论文中为其分配了 8% 的全局稀疏比率。
  • StreamingLLM (Xiao et al., 2023b): 一种用于长上下文的内存高效方法,通过保留注意力池 (attention sinks) 和最近的词元。其初始注意力池设置为 4 个词元。
  • HyperAttention (Han et al., 2023): 一种利用局部敏感哈希 (LSH) 识别重要注意力分数的动态稀疏方法。桶大小 (bucket size) 和采样列数均设置为 256。
  • Hash-Sparse (Pagliardini et al., 2023): 另一种利用哈希进行稀疏化的方法。

实验设置细节:

  • 所有实验均在单个 NVIDIA-A100 GPU (80GB) 上进行。
  • SampleAttention 仅在预填充阶段替换了全注意力实现,而解码阶段保留了未压缩的 KV 缓存和稠密注意力计算。
  • SampleAttention 的超参数通过在 RULER 基准测试的特定长度(16K, 32K, 64K, 128K)子集上进行调优,然后将这些参数应用于不同长度范围内的任务。

6. 实验结果与分析

6.1. 核心结果分析

本节详细分析 SampleAttention 在准确性和效率方面的实验结果,并与基线方法进行对比。

6.1.1. 准确性与效率的权衡

以下是原文 Figure 8 的图表,展示了 SampleAttention 在不同超参数下与各种稀疏方法在 RULER 基准测试上的准确性与相对加速(相对于 FlashAttention2)的权衡。

该图像是一个示意图,展示了不同模型在相对加速和评分上的表现。图中包含六个子图,分别针对 GLM 和其他模型(如 SampleAttention 和 FlashAttention2)进行了比较,量化了它们在处理任务中的速度与准确性之间的关系。

分析:

  • SampleAttention 的帕累托前沿: 图 8 清晰地表明,SampleAttention 在准确性-效率权衡方面建立了新的帕累托前沿。它能够在保持极高准确性的同时,实现显著的加速比。多个点显示 SampleAttention 的准确性接近甚至超过 FlashAttention2(作为全注意力的实现),同时提供了明显的加速。
  • 近乎无损的准确性: 如图中点所示,SampleAttention 能够始终实现接近全注意力 (Full Attention) 的准确性,这验证了其“近乎无损”的设计目标。
  • 显著的加速比: SampleAttention 相比 FlashAttention2 实现了显著的加速,一些配置下甚至超过 2×2\times
  • 基线方法的局限性:
    • Minference 在某些情况下准确性与全注意力相当,但未能提供任何加速(其点位于加速比为 1 的位置)。
    • BigBird 提供了稳定的加速,但通常伴随着明显的准确性下降。
    • StreamingLLMHyperAttention 则在所有任务中都表现出显著的性能下降,表明它们在预填充阶段无法有效捕获关键的 KV 元素。

6.1.2. 跨基准测试的准确性表现

以下是原文 Table 2 的结果,比较了不同稀疏方法在 LongBenchInfiniteBench 任务上的准确性。

Table 2. The accuracy of different sparse methods on LongBench and InfiniteBench tasks. The best results are marked with Bold, while the second best results are marked with an Underline.

Benchmark Baseline Task Type Total Score
Single-Doc QA Multi-Doc QA Summarization Few-shot Learning Synthetic Tasks Code Completion
LongBench Full Attention 213.12 174.35 109.69 273.87 231.49 121.52 1124.04
Ours 214.53 174.42 108.92 278.33 234.55 125.18 1135.93
Minference 212.14 173.37 110.02 274.45 231.87 124.37 1126.22
BigBird 207.57 146.45 95.64 272.17 161.60 117.38 1000.81
StreamingLLM 142.79 129.36 89.71 168.13 19.70 98.43 648.12
HyperAttention 125.74 119.08 88.05 206.35 32.69 86.35 658.26
En.Sum En.QA En.MC En.Dia Zh.QA Code.Debug Math. Retr.Find PassKey Retr.Number Retr.KV
InfiniteBench Full Attention 28.30 12.17 58.95 34.00 13.22 30.71 37.71 100 100 44.0
Ours 28.30 16.52 61.57 31.50 14.28 31.40 37.14 100 100 49.6
Minference 28.00 11.39 60.26 28.70 14.81 31.70 39.43 100 100 43.0

分析:

  • SampleAttention 的鲁棒性:LongBenchInfiniteBench 上,SampleAttention 的准确性表现始终优异,大部分任务甚至超越了全注意力,总分位居榜首。这表明 SampleAttention 的自适应稀疏模式选择能够比全注意力更好地关注关键信息或在某些情况下起到正则化的作用。
  • Minference 的表现: MinferenceLongBench 上的总分位居第二,但在 InfiniteBench 上的某些任务(如 En.QA, En.MC)表现略逊于 SampleAttention 和全注意力。
  • 其他方法的不足: BigBirdStreamingLLMHyperAttentionLongBench 上均表现出显著的准确性下降,尤其是 StreamingLLMHyperAttention 的分数远低于全注意力,这印证了它们在长序列预填充阶段难以捕获关键 KV 元素的局限性。

6.1.3. TTFT 加速基准测试

以下是原文 Figure 11 的图表,展示了 SampleAttention 在时间分配和 TTFT 加速方面的性能。

Figure 11. (a) The percentage of time spent on sampling and sparse computation in SampleAttention. (b) Comparison of the TTFT metric using FlashAttention2 as the baseline.

分析:

  • 采样开销分析 (图 11a):
    • 随着序列长度的增加,采样和稀疏计算所占的时间百分比(即采样开销)相对减少。这意味着 SampleAttention 在处理更长序列时,其采样开销变得不那么显著,从而能够带来更大的相对加速优势。
    • 在短序列场景下,采样和索引构建的计算开销可能导致性能增益不那么明显。例如,在 8K 序列长度下,SampleAttention 的延迟与 FlashAttention2 几乎相同。
  • TTFT 性能对比 (图 11b):
    • 图 11b 显示了 SampleAttentionFlashAttention2 相比的 TTFT 加速比。
    • 当序列长度达到 100 万 (1M) 时,SampleAttention 能够将 TTFT 显著降低 5.29×5.29\times。这证明了 SampleAttention 在处理超长上下文时的强大加速能力。
    • 即使在较低的序列长度(如 32K128K),SampleAttention 也实现了 1.24×1.24\times2.36×2.36\times 的加速。

6.2. 数据呈现 (表格)

6.2.1. LongBench 和 InfiniteBench 准确性对比 (Table 2)

在 6.1.2 节已呈现。

6.2.2. 超参数对准确性和加速比的影响 (Table 3)

以下是原文 Table 3 的结果,研究了不同 chunkn 值对分数和加速比的影响。

Table 3. The impact of changing chunkn on scores/speedup in different cases. The scores above are based on RULER, while the speedup below are relative to FlashAttention2. The best score results are highlighted in bold, while the best speedup results are marked with underline.

Model (αc,αs) | chunkn
1 2 4 6
GLM(128K) (0.90,0.90) 82.89/1.92 84.70/1.89 84.02/1.70 83.62/1.53
(0.95,0.95) 84.17/1.64 84.04/1.60 83.14/1.46 83.73/1.33
Yi(128K) (0.95,0.95) 52.81/2.17 54.54/2.12 53.10/1.89 54.58/1.71
(0.98,0.98) 56.24/1.29 58.58/1.25 59.37/1.21 59.36/1.13

分析:

  • chunkn 的影响: 适当增加 chunkn 可以提升准确性。例如,对于 GLM(128K)GLM(128K) 模型,在 (0.90,0.90)(0.90, 0.90) 的阈值下,chunkn=2chunkn=2 达到了最高的准确性 84.70,且加速比仍在较高水平。
  • 效率-准确性权衡: chunkn 过高(如 chunkn=4chunkn=46)可能导致加速比下降,因为采样开销增加。因此,选择合适的 chunkn 对于在准确性和加速比之间取得最佳平衡至关重要。

6.2.3. 稀疏性分析 (Table 4)

以下是原文 Table 4 的结果,展示了“大海捞针”任务中稀疏性随序列长度的变化。

Table 4. Sparsity analysis for the "Needle in a Haystack" task

Sequence Length Average Sparsity in ChatGLM-6B Average Sparsity in InternLM-7B
4K 88.00% 91.13%
8K 90.74% 92.72%
16K 92.52% 93.89%
32K 93.88% 94.83%
64K 94.89% 95.89%
128K 95.84% 96.67%

分析:

  • 稀疏性随长度增加: 随着序列长度的增加,ChatGLM-6BInternLM-7B 模型的平均稀疏比率都在增加。这表明在更长的上下文中,注意力矩阵中非关键的连接更多,从而为稀疏化提供了更大的潜力。
  • 模型差异: 不同模型的稀疏性水平略有不同,InternLM-7B 普遍比 ChatGLM-6B 具有更高的平均稀疏比率。

6.2.4. 采样有效性分析 (Table 5)

以下是原文 Table 5 的结果,展示了在不同采样率下选择 top-k 条带所能达到的 CRA 百分比。

Table 5. The CRA percentages can be achieved by selecting different ratios of top k stripes under varying sampling rates. The tests were conducted on the RULER task with a sequence length of 64K using the GLM4-9B model.

ratio of top-k 2.5% 0.4% 10% 0.4% 20% 0.4% 40% 0.4% 80% 0.4%
sampling ratio 100% 0.4% 100% 0.4% 100% 0.4% 100% 0.4% 100% 0.4%
HEAD-1 16.35% 12.74% 26.91% 23.33% 45.99% 42.14% 58.21% 55.34% 96.30% 93.65%
HEAD-2 55.43% 48.40% 63.89% 58.63% 85.92% 81.98% 89.07% 84.21% 99.15% 98.08%
HEAD-3 93.20% 90.44% 98.32% 97.62% 99.14% 98.43% 99.41% 99.12% 99.98% 99.66%

分析:

  • 低采样率的有效性: 即使在 0.4% 的低采样率下,通过选择 top-k 条带所获得的 CRA 值与 100% 全采样率下的 CRA 值非常接近。这表明 SampleAttention 中使用的分块采样方法是高效且准确的,能够用极低的开销近似全注意力的重要信息。
  • 差异随 top-k 比例减小: 随着 top-k 条带选择比例的增加,两种采样率下的 CRA 差异逐渐减小,这意味着采样方法在捕获大部分关键注意力信息方面表现良好。
  • 头部差异: 不同注意力头 (HEAD-1, HEAD-2, HEAD-3) 对稀疏化的敏感度不同,有些头即使在低 top-k 比例下也能保持很高的 CRA(如 HEAD-3),这验证了 SampleAttention 针对头部特异性稀疏性的设计需求。

6.2.5. 调优超参数 (Table 6)

以下是原文 Table 6 的结果,展示了针对不同长度范围和模型的调优超参数。

Table 6. T H highest speedups while retaining at least 90% of the accuracy.

range of sequence length < 16K [16K,48K) [48K,80K) [80K,112K) >=112K
GLM4 InternLM2 GLM4 InternLM2 GLM4 InternLM2 GLM4 InternLM2 GLM4 InternLM2
CRA Column 0.98 | 0.85 0.98 | 0.85 0.95 | 0.85 0.95 | 0.80 0.95 | 0.80 0.92 | 0.80 0.90 | 0.80 0.92 | 0.80 0.95 | 0.80 0.95 | 0.80
CRA Slash 0.90 | 0.85 0.95 | 0.85 0.95 | 0.80 0.90 | 0.80 0.95 | 0.80 0.92 | 0.80 0.90 | 0.80 0.90 | 0.80 0.85 | 0.80 0.90 | 0.80
Num of Chunks 111 111 111 111 111 111 211 211 111 211

分析:

  • 长度范围依赖性: 调优结果显示,超参数 (αcαc, αsαs, chunkn) 确实根据序列长度范围和模型而变化。例如,对于 GLM4 模型,在短序列(<16K<16K)时 CRA Column 阈值为 0.980.850.98 | 0.85,而在长序列(>=112K>=112K)时变为 0.950.800.95 | 0.80,这说明不同长度下稀疏化的策略是不同的。
  • 模型特异性: 即使在相同的长度范围内,GLM4InternLM2 的最优超参数也可能不同,这验证了论文中关于模型架构特异性稀疏比率的论断。
  • chunkn 的变化: chunkn 参数在某些长序列范围(如 [80K, 112K)InternLM2[48K, 80K)>=112K>=112K)下从 1 增加到 2,再次强调了其在捕获复杂稀疏模式中的重要性。

6.3. 消融实验/参数分析

6.3.1. CRA 阈值 (αcαc, αsαs) 的影响

以下是原文 Figure 9 的图表,展示了不同 αcαcαsαs 值对准确性的影响。

Figure 9. The heatmaps under different cases illustrate the impact of choosing different values of \(\\alpha _ { c }\) and \(\\alpha _ { s }\) on the accuracy of calculated blocks). The chun \(\\mathbf { \\nabla } _ { k _ { n } }\) values for the GLM and YI models are set to 2 and 4, respectively.

分析:

  • 阈值与准确性的关系: 图 9 的热力图直观地显示了 αcαcαsαs 阈值对准确性的影响。通常,增加任一阈值都会提高准确性,但同时也会增加计算负荷(减少稀疏性)。
  • 不同模型的敏感度:
    • 对于 GLM49B(32K)GLM4-9B (32K) (图 9a),准确性对 αcαcαsαs 的变化呈现出相对平衡的响应。
    • 对于 YI-9B (128K) (图 9c),模型对斜杠阈值 αsαs 的变化更为敏感,这可能意味着 YI 模型在长上下文中对局部信息捕获的需求更为强烈。
    • 图 9b (GLM49B(128K)GLM4-9B (128K)) 显示即使是较小的阈值也能提供足够的准确性,从而实现更高的加速比。
  • 权衡的重要性: 这些结果强调了对 αcαcαsαs 进行精细调优以找到最佳效率-准确性权衡的重要性。

6.3.2. 采样分块数量 (chunkn) 的影响

Table 3 中已分析了 chunkn 对准确性和加速比的影响。结论是:

  • 适当 chunkn 提升准确性: 增加 chunkn 从 1 到 2 可以在不显著牺牲加速比的情况下提高准确性,因为它能够更全面地捕获注意力模式。
  • 过大 chunkn 降低效率: 过大的 chunkn 值可能导致采样开销增加,从而降低整体加速比,即使准确性不再显著提升。因此,选择一个折衷的 chunkn 值至关重要。

6.3.3. 跨任务鲁棒性 (Cross-Task Robustness)

以下是原文 Figure 10 的图表,展示了离线调优的超参数在不同基准测试上的泛化能力。

Figure 10. Results from offline tuning and evaluation of (a) GLM4- 9B and (b) InternLM2-7B across RULER, LongBench, and InfiniteBench benchmarks. Different tasks share the same hyperparameters from offline tuning when sequence lengths fall within the same range.

分析:

  • 超参数的泛化性: 图 10 验证了 SampleAttention 的超参数在不同任务之间的鲁棒性。在 RULER 子集上调优的超参数,在 LongBenchInfiniteBench 上也表现出良好的泛化能力。
  • 近乎无损性能: GLM4 的“准确性优化”超参数配置在所有基准测试中都保持了近乎无损的性能,这证明了 SampleAttention 能够有效地适应不同任务而无需为每个任务单独调优。
  • 一致的加速比: 在相同模型和序列长度下,不同任务之间观察到一致的加速比,表明 SampleAttention 的效率提升是普遍性的。

6.3.4. 采样开销 (Sampling Overhead)

Figure 11(a) 中已分析了采样开销。关键发现是:

  • 随着序列长度的增加,采样开销占总时间的比例逐渐减少。这说明 SampleAttention 在处理超长序列时,其额外的采样步骤对整体性能的影响相对较小,从而能够实现更大的加速。

6.4. 注意力分数可视化

以下是原文 Figures 12 和 13 的可视化图,展示了 ChatGLM3-6B 模型中不同注意力头的稀疏模式。

该图像是示意图,展示了不同层(Layer0, Layer4, Layer8, Layer12)在序列长度变化下的注意力模式。每个子图呈现对应层的注意力矩阵,矩阵的颜色变化反映了注意力强度和稀疏性。 该图像是示意图,展示了不同层(Layer0, Layer4, Layer8, Layer12)在序列长度变化下的注意力模式。每个子图呈现对应层的注意力矩阵,矩阵的颜色变化反映了注意力强度和稀疏性。

该图像是示意图,展示了不同层级(Layer16、Layer20、Layer24)在序列长度为 k 时的注意力模式分布。各图分布了注意力权重,显示了层级特异性稀疏模式的动态变化。

分析:

  • 列条带和斜杠条带模式: 这些热力图直观地展示了注意力分数矩阵中普遍存在的列条带 (column stripe) 和斜杠条带 (slash stripe) 模式。
    • 列条带: 代表了对特定键(通常是早期词元或特殊词元,如注意力池)的全局关注,捕获全局上下文信息。例如,某些头在整个查询序列上都高度关注某个或某几个键,形成垂直的亮线。
    • 斜杠条带: 代表了对最近词元或局部窗口内词元的关注,捕获局部上下文信息。它们通常表现为对角线或接近对角线的亮带。
  • 头部特异性: 不同的注意力头展示出截然不同的稀疏模式。有些头主要显示列模式,有些主要显示斜杠模式,还有些头则显示这两种模式的复杂组合。这强化了论文中关于头部特异性稀疏模式的观察,并证明了 SampleAttention 动态捕获这些多样模式的必要性。
  • 支撑 SampleAttention 设计: 这些可视化结果直接支撑了 SampleAttention 的设计,即通过独立且自适应地捕获列和斜杠模式,能够有效地近似全注意力,因为这些模式构成了注意力分数矩阵中最重要的部分。

7. 总结与思考

7.1. 结论总结

本文深入探讨了长上下文大型语言模型 (LLMs) 在推理阶段面临的 Time-to-First-Token (TTFT) 延迟问题,其根源在于传统注意力机制的二次复杂度以及现有稀疏注意力方法在适应动态稀疏性方面的局限。

为了解决这些挑战,论文提出了 SampleAttention,一种自适应结构化且近乎无损的稀疏注意力方法。其核心创新点包括:

  1. 识别动态稀疏性: 强调并实证了注意力稀疏比率和模式在不同注意力头、输入内容和模型架构之间的高度自适应性。

  2. 引入 CRA 指标: 提出了 Cumulative Residual Attention (CRA) 作为衡量模型准确性的稳健指标,并利用其指导稀疏化过程中的效率-准确性权衡。

  3. 两阶段查询引导键值过滤: 设计了创新的两阶段算法,通过“查询引导分块采样”高效估计稀疏结构,并利用“基于得分的键值过滤”自适应地选择关键的列条带和斜杠条带。这种方法在运行时以低开销动态确定稀疏比率和模式。

  4. 自动化调优与硬件优化: 提供自动化超参数调优方法,并结合 IO 感知的操作符融合和优化的 FlashAttention2 内核,确保了在实际硬件上的显著加速。

    综合实验结果表明,SampleAttention 能够在 ChatGLMYIInternLM 等主流 LLMs 上实现近乎无损的准确性(通常保持 99% 以上),并在长上下文场景下将 TTFT 延迟降低高达 5.29×5.29\times,显著超越了现有基线方法,在效率-准确性权衡上建立了新的帕累托前沿。

7.2. 局限性与未来工作

论文中并未明确列出“局限性”部分,但从其描述和上下文可以推断出一些潜在的局限性和未来工作方向:

潜在局限性:

  • 离线调优的成本: SampleAttention 虽然提供了自动化调优方法,但超参数 αcαc, αsαs, chunkn 仍然需要通过在小规模数据集上进行离线调优来确定。虽然这比每次推理都进行调优要高效,但对于全新的模型或特定任务,仍然存在一定的预处理成本和调优时间。
  • 采样开销在短序列中的影响: 论文指出,在短序列(如 8K)中,SampleAttention 的延迟与 FlashAttention2 几乎相同,因为动态采样和索引构建的计算开销可能抵消了稀疏化带来的收益。这意味着 SampleAttention 的主要优势在于长上下文
  • 仅优化预填充阶段: SampleAttention 专注于优化预填充阶段的计算开销,而解码阶段仍保留未压缩的 KV 缓存和稠密注意力计算。虽然论文提到可以与 KV 缓存压缩方法结合,但这表明 SampleAttention 本身并非端到端 (end-to-end) 的全面优化方案。
  • 对特定模式的依赖: SampleAttention 的设计是基于注意力矩阵中普遍存在的列条带和斜杠条带模式。如果遇到完全不同且不规则的稀疏模式,其性能可能会受到影响。
  • 硬件实现复杂性: 尽管论文强调了硬件高效实现,但这意味着其自定义内核的开发和维护可能比直接使用标准库更复杂,对不同硬件平台的支持也可能需要额外的工作。

未来工作:

  • 结合 KV 缓存压缩: 论文明确指出 SampleAttention 可以与 KV 缓存驱逐或压缩方法(如 H2OSnapKV 等)结合,以进一步减少解码阶段的内存消耗。这将是实现更全面优化的重要方向。
  • 更细粒度的自适应: 探索更细粒度的自适应机制,例如不仅在头部层面,而是在更小的注意力块或查询层面动态调整稀疏策略。
  • 在线调优或学习机制: 研发能够在线学习或自适应调整超参数的方法,以消除对离线调优的需求,使方法更具通用性和即插即用性。
  • 支持更多稀疏模式: 探索除了列条带和斜杠条带之外的其他重要稀疏模式,以进一步提高近似准确性和泛化能力。
  • 端到端优化:SampleAttention 的稀疏化理念扩展到整个 LLM 推理流程,包括解码阶段的计算和内存优化。

7.3. 个人启发与批判

7.3.1. 个人启发

SampleAttention 的工作提供了几个重要的启发:

  1. 深度挖掘固有稀疏性: 论文再次强调了 LLMs 注意力机制中固有的高稀疏性,并成功地将其转化为性能优化的机会。这表明在 LLMs 领域,理解模型内部工作机制并利用其特性进行系统级优化,比单纯依赖通用优化技术更有效。
  2. 动态性是关键: 强调注意力稀疏比率和模式的动态性自适应性是区分优秀稀疏注意力方法和一般方法的关键。固定模式或粗粒度方法往往难以达到“近乎无损”的性能,因为它们无法捕捉到 LLM 处理不同上下文时的细微变化。
  3. 度量指标的重要性: CRA 的引入为稀疏化提供了一个可量化的指导原则,这对于平衡效率和准确性至关重要。一个好的度量指标能够使复杂的工程决策变得有据可循。
  4. 工程与算法的结合: SampleAttention 不仅提出了算法上的创新(两阶段过滤),还在硬件实现层面进行了优化(操作符融合、修改 FlashAttention2 内核),这体现了在高性能计算领域,算法设计与底层硬件优化的紧密结合是实现突破性加速的关键。
  5. 长上下文的未来: 随着 LLMs 上下文窗口的不断扩大,这类针对长上下文效率的优化工作将变得越来越重要,是推动 LLMs 走向更广泛实际应用的基础。

7.3.2. 批判性思考

尽管 SampleAttention 取得了令人印象深刻的成果,但仍有一些可以批判性思考的地方:

  1. CRA 的普适性: 虽然 CRA 被证明与模型准确性具有一致的正相关性,但这种相关性是否在所有 LLM 架构、所有任务类型上都同样稳健,仍需进一步验证。此外,CRA 阈值的选择本身就是一项工程权衡,它间接影响最终的用户体验,如何自动化或更智能地选择 CRA 阈值是一个挑战。

  2. “近乎无损”的定义: 论文宣称“近乎无损”,并在实验中展示了 99% 以上的准确性。然而,在某些对准确性极度敏感的下游任务中,即使是 1% 的准确性损失也可能是不可接受的。例如,在金融分析或医疗诊断等领域,微小的错误都可能导致严重后果。如何量化和管理这种“损失”的风险,仍然是一个值得探讨的问题。

  3. 计算开销的绝对值: 论文主要关注的是相对加速比采样开销的相对减少。虽然在 1M 序列长度下实现了 5.29×5.29\times 的加速,但 1M 序列的基线时间可能非常长(如 1555 秒),所以即使加速,绝对时间可能仍然较长。如何将 TTFT 降低到毫秒级别以支持真正的实时交互,仍需更多努力。

  4. 动态稀疏模式的复杂性: 列条带和斜杠条带是两种主要的稀疏模式,但注意力模式可能远比这两种简单组合更复杂。SampleAttention 的两阶段过滤是否能捕获所有关键的、非线性的、更复杂的稀疏结构?是否存在一些边缘情况或特定任务,其注意力模式不符合这种结构化假设?

  5. FlashAttention2 的依赖: SampleAttention 建立在 FlashAttention2 的基础上。如果未来的注意力计算基线发生重大变化,SampleAttention 可能需要进行相应的适配或重构。

  6. 可解释性: 虽然 SampleAttention 提高了效率,但稀疏掩码的动态生成过程可能会略微降低注意力机制的可解释性。理解为什么某些条带被选中,而另一些被忽略,可能需要更深入的分析工具。

    总而言之,SampleAttention 是一项非常有价值的工作,它在利用 LLMs 固有稀疏性以解决长上下文推理效率方面取得了显著进展。其方法论的严谨性和实验结果的强大性,为未来的 LLM 优化方向提供了宝贵的思路。

相似论文推荐

基于向量语义检索推荐的相关论文。

暂时没有找到相似论文。