MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention
TL;DR 精炼摘要
本文提出了MInference(百万标记推理),一种新颖的动态稀疏计算方法,旨在加速长上下文大型语言模型(LLM)的预填充阶段。通过识别长上下文注意力矩阵中的三种独特模式(A形、垂直斜杠和块稀疏),并在推理时动态构建稀疏索引,MInference显著降低了延迟,提高了效率与准确性,无需修改预训练设置或额外微调。
摘要
The computational challenges of Large Language Model (LLM) inference remain a significant barrier to their widespread deployment, especially as prompt lengths continue to increase. Due to the quadratic complexity of the attention computation, it takes 30 minutes for an 8B LLM to process a prompt of 1M tokens (i.e., the pre-filling stage) on a single A100 GPU. Existing methods for speeding up prefilling often fail to maintain acceptable accuracy or efficiency when applied to long-context LLMs. To address this gap, we introduce MInference (Milliontokens Inference), a sparse calculation method designed to accelerate pre-filling of long-sequence processing. Specifically, we identify three unique patterns in long-context attention matrices-the A-shape, Vertical-Slash, and Block-Sparsethat can be leveraged for efficient sparse computation on GPUs. We determine the optimal pattern for each attention head offline and dynamically build sparse indices based on the assigned pattern during inference. With the pattern and sparse indices, we perform efficient sparse attention calculations via our optimized GPU kernels to significantly reduce the latency in the pre-filling stage of long-context LLMs. Our proposed technique can be directly applied to existing LLMs without any modifications to the pre-training setup or additional fine-tuning. By evaluating on a wide range of downstream tasks, including InfiniteBench, RULER, PG-19, and Needle In A Haystack, and models including LLaMA-3-1M, GLM4-1M, Yi-200K, Phi-3-128K, and Qwen2-128K, we demonstrate that MInference effectively reduces inference latency by up to 10x for pre-filling on an A100, while maintaining accuracy. Our code is available at https://aka.ms/MInference.
思维导图
论文精读
中文精读
1. 论文基本信息
1.1. 标题
MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention
1.2. 作者
Huiqiang Jiang, Yucheng Li, Chengruidong Zhang, Qianhui Wu, Xufang Luo, Surin Ahn, Zhenhua Han, Amir H. Abdi, Dongsheng Li, Chin-Yew Lin, Yuqing Yang, Lili Qiu
- 隶属机构 (Affiliations): Microsoft Corporation, University of Surrey
1.3. 发表期刊/会议
- 类型: 预印本 (ArXiv preprint)
- 链接:
- 发布状态: 预印本,尚未正式发表。预印本平台如 ArXiv 允许研究者在同行评审 (peer review) 完成前分享其研究成果,但其内容未经正式的学术出版流程验证。
1.4. 发表年份
2024年7月2日 (UTC)
1.5. 摘要
大型语言模型 (Large Language Model, LLM) 推理的计算挑战仍然是其广泛部署的重要障碍,尤其随着提示词 (prompt) 长度的不断增加。由于注意力 (attention) 计算的二次方复杂度 (quadratic complexity),对于一个 8B (80亿参数) 的 LLM 来说,在单块 A100 GPU 上处理一个 1M (百万) 词元 (token) 的提示词(即预填充阶段 (pre-filling stage))需要 30 分钟。现有的预填充加速方法在应用于长上下文 (long-context) LLM 时,往往无法保持可接受的准确性 (accuracy) 或效率 (efficiency)。为了解决这一问题,本文引入了 MInference (Milliontokens Inference),这是一种稀疏计算 (sparse calculation) 方法,旨在加速长序列 (long-sequence) 处理的预填充阶段。具体来说,研究人员识别出长上下文注意力矩阵 (attention matrix) 中的三种独特模式:A形 (A-shape)、垂直斜杠 (Vertical-Slash) 和块稀疏 (Block-Sparse),这些模式可以用于 GPU 上的高效稀疏计算 (efficient sparse computation)。研究人员在离线 (offline) 阶段为每个注意力头 (attention head) 确定最佳模式,并在推理 (inference) 期间根据分配的模式动态 (dynamically) 构建稀疏索引 (sparse indices)。通过这些模式和稀疏索引,以及研究人员优化的 GPU 内核 (kernel),MInference 能够显著降低长上下文 LLM 预填充阶段的延迟 (latency)。本文提出的技术可以直接应用于现有 LLM,无需修改预训练设置 (pre-training setup) 或额外微调 (fine-tuning)。通过在 InfiniteBench、RULER、PG-19 和 Needle In A Haystack 等一系列下游任务,以及 LLaMA-3-1M、GLM4-1M、Yi-200K、Phi-3-128K 和 Qwen2-128K 等模型上进行评估,研究人员证明 MInference 在 A100 上可以将预填充延迟最多降低 10 倍,同时保持准确性。
1.6. 原文链接
https://arxiv.org/abs/2407.02490
1.7. PDF 链接
https://arxiv.org/pdf/2407.02490v2.pdf
2. 整体概括
2.1. 研究背景与动机
- 论文试图解决的核心问题: 大语言模型 (LLM) 在处理长上下文 (long-context) 时,由于注意力机制的二次方复杂度,导致预填充阶段 (pre-filling stage) 的计算成本极高,从而产生无法接受的延迟 (latency)。例如,处理 1M (百万) 词元 (token) 的提示词 (prompt) 可能需要长达 30 分钟,严重影响用户体验和 LLM 的实际部署。
- 为什么这个问题在当前领域是重要的: 随着 LLM 在各种复杂应用中(如代码理解、长文档问答、自玩推理等)被广泛采用,处理更长的上下文是解锁其全部潜力的关键。然而,当前的计算瓶颈阻碍了这些长上下文 LLM 的大规模应用。现有的一些加速预填充的方法,如固定稀疏注意力 (fixed sparse attention),往往在长上下文场景下无法保持足够的准确性 (accuracy) 和效率;而另一些动态稀疏注意力 (dynamic sparse attention) 方法则引入了过高的估计开销 (estimation overhead)。
- 这篇论文的切入点或创新思路是什么: 本文的创新点在于深入分析了长上下文 LLM 中注意力矩阵的稀疏性,并发现其并非随机分布,而是呈现出三种可被高效利用的独特空间模式:
A-shape(A形)、Vertical-Slash(垂直斜杠) 和Block-Sparse(块稀疏)。基于这些模式,论文提出了一种结合离线 (offline) 模式识别和在线 (online) 动态索引构建的稀疏计算框架,并通过优化的 GPU 内核 (kernel) 实现高效推理,旨在解决长上下文预填充的计算瓶颈,同时避免现有方法引入的准确性损失或过高开销。
2.2. 核心贡献/主要发现
- 提出了 MInference (Milliontokens Inference) 框架: 一个专门为加速长上下文 LLM 预填充阶段设计的动态稀疏注意力 (dynamic sparse attention) 方法。
- 识别出三种关键稀疏模式: 在长上下文注意力矩阵中发现了三种独特的、可用于高效稀疏计算的模式:A形、垂直斜杠和块稀疏。
- 开发核感知最优稀疏模式搜索 (Kernel-Aware Optimal Sparse Pattern Search): 提出了一种离线方法,能够为每个注意力头 (attention head) 自动识别并分配最合适的稀疏模式及其最佳设置,以在保持准确性的前提下最大限度地减少浮点运算 (FLOPs)。
- 实现动态稀疏索引构建: 在推理时,根据分配的模式和具体的输入动态地构建稀疏注意力掩码 (sparse attention mask),从而适应注意力模式的动态变化。
- 优化 GPU 内核: 为识别出的三种稀疏模式开发了高度优化的 GPU 内核,用于高效执行稀疏注意力计算,显著降低了计算时间。
- 无需模型修改或微调: 该方法可作为即插即用 (plug-and-play) 的解决方案直接应用于现有预训练 LLM,无需进行额外的训练或微调 (fine-tuning),极大地降低了部署门槛。
- 显著的性能提升和准确性保持:
- 在单块 A100 GPU 上,MInference 将 1M 词元提示词的预填充延迟从 30 分钟减少到 3 分钟,实现高达 10 倍的加速。
- 在 InfiniteBench、RULER、PG-19 和 Needle In A Haystack 等一系列长上下文基准测试上,MInference 在保持或甚至略微提高准确性的同时,实现了显著的加速。
3. 预备知识与相关工作
3.1. 基础概念
- 大型语言模型 (Large Language Model, LLM): 指的是具有庞大参数量、通过大规模文本数据训练的深度学习模型,通常基于 Transformer (变换器) 架构。LLM 能够理解、生成和处理人类语言,并在各种自然语言处理 (Natural Language Processing, NLP) 任务中表现出色,如问答、摘要、翻译等。
- 上下文窗口 (Context Window): 指 LLM 在生成下一个词元 (token) 时能够“看到”和利用的输入序列的长度。长上下文 LLM 意味着模型可以处理非常长的输入序列(例如数十万甚至数百万词元),这对于处理长文档、代码库或长时间对话等复杂任务至关重要。
- 预填充阶段 (Pre-filling Stage): LLM 推理过程通常分为两个主要阶段:
- 预填充 (Pre-filling): 当接收到一个新的用户提示词 (prompt) 时,模型会一次性处理整个输入序列以生成第一个输出词元,并构建好 Key-Value (KV) 缓存 (cache)。这个阶段的计算量与输入序列长度的平方成正比。
- 解码 (Decoding): 在生成第一个词元之后,模型会逐个生成后续词元。每个新生成的词元都会利用之前词元的 KV 缓存。这个阶段的计算量通常与生成词元的数量和 KV 缓存的大小成正比。
- 本文主要关注预填充阶段的优化,因为其二次方复杂度是长上下文 LLM 的主要瓶颈。
- 注意力机制 (Attention Mechanism): Transformer 架构的核心组件。它允许模型在处理序列中的每个词元时,动态地权衡输入序列中所有其他词元的重要性。注意力机制使得模型能够捕捉长距离依赖关系。
- 二次方复杂度: 标准的自注意力 (self-attention) 机制计算复杂度为 ,其中 是序列长度。这意味着随着序列长度的增加,计算成本会呈指数级增长,成为长上下文 LLM 的主要瓶颈。
- 注意力计算公式: 注意力 (Attention) 机制的计算通常涉及三个输入:查询 (Query, ) 矩阵、键 (Key, ) 矩阵和值 (Value, ) 矩阵。
符号解释:
- :查询矩阵,由输入序列中每个词元的查询向量组成,其中 是序列长度, 是键和查询向量的维度。
- :键矩阵,由输入序列中每个词元的键向量组成。
- :值矩阵,由输入序列中每个词元的值向量组成,其中 是值向量的维度。
- : 键矩阵的转置。
- : 查询和键的点积,表示查询与每个键之间的相似度得分。
- : 缩放因子,用于防止点积结果过大,导致
softmax函数梯度消失。 - :
softmax函数,将得分转换为概率分布,确保注意力权重之和为 1。 - : 值矩阵,注意力权重会乘以值矩阵,从而聚合来自不同位置的信息。
- 稀疏注意力 (Sparse Attention): 针对标准注意力机制二次方复杂度问题的一种优化方法。它通过限制每个词元只关注输入序列中的一部分词元,从而减少 矩阵中的非零元素数量,降低计算复杂度。稀疏注意力可以是固定的(预定义模式)或动态的(根据输入内容变化)。
- GPU 内核 (GPU Kernel): 是在图形处理器 (Graphics Processing Unit, GPU) 上执行的特定函数或程序。GPU 善于并行处理大量数据,因此优化 GPU 内核对于加速深度学习模型的计算至关重要。本文通过优化 GPU 内核来高效执行稀疏注意力计算。
- 词元 (Token): 在自然语言处理中,
token是文本被分割成的最小有意义的单位。它可以是一个单词、一个子词、一个字符或一个标点符号。LLM 处理文本时,首先会将文本转换为词元序列。 - 浮点运算 (Floating Point Operations, FLOPs): 是衡量计算复杂度的指标,表示执行一个浮点运算的次数。在深度学习中,FLOPs 通常用来估算模型的计算量。
3.2. 前人工作
论文在 “Related Works” 部分对大量相关工作进行了综述,主要分为以下几类:
- 稀疏注意力 (Sparse Attention):
- 静态稀疏模式 (Static sparse patterns): 这些方法预定义了注意力模式,如滑动窗口 (sliding windows) [JSM+23, AJA+24]、扩张注意力 (dilated attention) [CGRS19, SGR+21, DMD+23] 和混合稀疏模式 (mixed sparse patterns) [BPC20, ZGD+20, LCSR21]。典型的模型包括
Longformer[BPC20] 和BigBird[ZGD+20]。这些方法通常需要从头开始预训练模型,无法直接作为插件应用于现有 LLM。 - 基于聚类的稀疏方法 (Cluster-based sparse methods): 包括基于哈希 (hash-based) [KKL20] 和基于 kNN (kNN-based) [RSVG21, NEC+24] 的方法。这些也通常需要重新训练。
- 动态稀疏注意力 (Dynamic sparse attention): 一些工作 [WZH21, LOC+22, RCHG+24] 利用注意力机制的动态性来预测稀疏模式。然而,这些方法通常在动态模式近似过程中关注低秩隐藏状态 (low-rank hidden dimensions) 或使用后统计 (post-statistical) 方法获取稀疏掩码,导致估计步骤引入大量开销,对于长上下文 LLM 的实用性较低。MInference 旨在解决这一开销问题。
- 静态稀疏模式 (Static sparse patterns): 这些方法预定义了注意力模式,如滑动窗口 (sliding windows) [JSM+23, AJA+24]、扩张注意力 (dilated attention) [CGRS19, SGR+21, DMD+23] 和混合稀疏模式 (mixed sparse patterns) [BPC20, ZGD+20, LCSR21]。典型的模型包括
- 扩展 LLM 上下文窗口 (Scaling Context Windows of LLMs):
- 分阶段预训练 (Staged pre-training): [NXH+23, FPN+24] 通过在更长序列上继续训练来扩展上下文。
- 修改或插值位置编码 (Modifying or interpolating position embeddings): [PSL22, CWCT23, PQFS24, DZZ+24] 通过调整位置编码来使其能够处理更长的序列。
- 利用外部内存模块 (Utilizing external memory modules): [BANG23, TSP+23, XZH+24] 引入额外的内存单元来存储和检索上下文信息。
- 分布式计算 (Distributed manner): [LZA24] 通过在多个设备上分布计算来处理长上下文。
- 局限: 这些方法通常不直接解决长上下文处理中高昂的推理成本,而是侧重于提升模型的长文本处理能力。
- 长上下文 LLM 推理优化 (Long-Context LLM Inference):
- 预填充优化 (Pre-filling optimizations):
- 状态空间模型 (State Space Models) [GGR22, GD23] 和线性注意力 (linear attention) [SDH+23, PAA+23]:通常需要从头训练。
- 基于内存的方法 (Memory-based methods) [MFG24, HBK+24] 和混合方法 (Hybrid methods) [LLB+24, RLL+24]。
- 提示词压缩 (Prompt compression) [LDGL23, JWL+23, JW+24, PWJ+24]:通过压缩输入提示词来减少长度。
- 最近,一些研究 [MEL24, XZH+24, LCE+24] 专注于使用 kNN 或基于聚类的稀疏注意力来加速 LLM 推理,但往往导致准确性下降、加速有限或仅限于 CPU 场景。
- 解码阶段优化 (Decoding stage optimizations):
- KV 缓存重用 (Reusing attention KV) [Sha19, ALTdJ+23, SDZ+24, DA24, NEC+24]。
- 静态/动态 KV 缓存丢弃 (Static/Dynamic KV cache dropping) [XTC+24, HWP+24, ZSZ+24, LDL+24, GZL+24, OHAS24, LHY+24, APB+24]。
- 动态 KV 缓存卸载 (Dynamic KV cache offloading) [RCHG+24, DHJ+24, TZZ+24, LCL+24, CSY+24, SCB+24]。
- 恢复性能损失的方法 (Methods for restoring performance loss) [AAJ+24, DYZ+24]。
- 分层推测解码 (Hierarchical speculative decoding) [SCY+24, CTS+24]。
- KV 缓存量化 (KV cache quantitation) [LYJ+24]。
- 局限: 这些方法主要针对解码阶段,不解决预填充阶段的繁重计算负担。
- 预填充优化 (Pre-filling optimizations):
3.3. 技术演进
LLM 的注意力机制效率问题自 Transformer 模型诞生以来一直受到关注。技术演进大致遵循以下路径:
-
全注意力 (Full Attention): 最初的 Transformer 模型,简单直接,但具有二次方复杂度。
-
固定稀疏注意力 (Fixed Sparse Attention): 为了降低计算成本,研究者开始探索预设的稀疏模式,如局部窗口 (local windows) (例如
Longformer) 和扩张注意力 (dilated attention) (例如BigBird)。这些方法减少了计算量,但其固定模式可能无法适应不同输入的注意力分布,从而牺牲准确性。 -
动态稀疏注意力 (Dynamic Sparse Attention) (早期): 认识到注意力模式的动态性,一些工作尝试根据输入动态地选择或预测稀疏模式。然而,早期方法往往需要引入额外的网络来预测稀疏性,这本身就带来了显著的计算开销。
-
长上下文 LLM 的特定优化 (Specific Optimizations for Long-Context LLMs): 随着 LLM 能够处理更长的上下文,预填充阶段的二次方复杂度问题变得尤为突出。这一阶段的优化需要新的思路,既要利用稀疏性,又要避免大的估计开销,并且能够直接应用于已预训练的模型。
MInference 正是处于这一演进路径的最新阶段,它通过识别长上下文 LLM 中注意力矩阵的特定空间稀疏模式,并结合高效的离线模式搜索和在线动态索引构建,实现了在不牺牲准确性的前提下的显著加速,且无需对模型进行再训练。
3.4. 差异化分析
MInference 与相关工作中的主要方法相比,核心区别和创新点在于:
- 针对长上下文预填充阶段的专门优化: 现有许多方法要么侧重于解码阶段,要么是通用的稀疏注意力,未能充分利用长上下文预填充阶段注意力模式的独特特征。MInference 精准聚焦于这一瓶颈。
- 识别并利用三种独特的稀疏模式: 区别于仅使用单一固定模式 (如
StreamingLLM的 A形模式) 或通用动态稀疏方法,MInference 深入分析了长上下文注意力矩阵,并识别出 A形、垂直斜杠和块稀疏这三种具有不同空间分布特征的模式,从而更全面地捕捉稀疏性。 - 核感知最优模式搜索 (Kernel-Aware Optimal Pattern Search): 这种离线搜索方法不仅考虑了理论上的 FLOPs,更重要的是,它结合了实际 GPU 内核的效率,为每个注意力头分配了最适合其计算特性的稀疏模式和参数。这确保了在保持准确性的前提下,实现实际的、高效的加速。
- 低开销的动态索引构建: 相比于一些动态稀疏方法通过引入复杂的辅助网络来预测稀疏模式,MInference 采用轻量级的在线近似方法(如均值池化 (MeanPooling) 和对最后几个词元的注意力计算)来动态构建稀疏掩码,显著降低了估计开销,使其在实际推理中更具实用性。
- 无需模型训练或微调: 大多数传统的稀疏注意力方法或上下文扩展方法需要从头预训练或对模型进行微调,这对于已部署的大型 LLM 来说成本高昂。MInference 作为一种即插即用 (plug-and-play) 的解决方案,可以直接应用于现有的、未经修改的 LLM,极大地降低了集成和部署的门槛。
- 优化的 GPU 内核: 为每种识别出的稀疏模式开发了专门优化的 GPU 内核,确保了稀疏计算能够高效地在硬件上执行,而非仅仅是理论上的 FLOPs 减少。
4. 方法论
4.1. 方法原理
MInference 的核心原理在于,大型语言模型 (LLM) 在处理超长上下文 (long-context) 时,其自注意力 (self-attention) 矩阵并非密集计算,而是表现出高度的稀疏性,且这种稀疏性呈现出可被有效利用的特定空间模式。论文基于对这些模式的深入分析,提出了一种动态稀疏注意力 (dynamic sparse attention) 框架,旨在加速预填充阶段 (pre-filling stage) 的计算。
其直觉在于:
-
注意力稀疏性: 并非每个查询 (query) 都需要关注序列中的每个键 (key)。很多注意力权重 (attention weights) 在
softmax后趋近于零,对最终输出贡献微乎其微。 -
模式化稀疏性: 这种稀疏性并非随机的。通过观察发现,注意力权重往往集中在某些结构化的区域,这些区域可以被抽象为几种可识别的模式(A形、垂直斜杠、块稀疏)。
-
动态性: 尽管存在模式,但具体哪些区域是重要的,以及这些区域在不同输入下如何分布,是动态变化的。因此,需要一个能够在线 (online) 适应输入变化的机制。
-
硬件效率: 仅仅减少理论上的浮点运算 (FLOPs) 不够,还需要将稀疏计算映射到高效的 GPU 内核 (kernel) 上,以实现实际的加速。
MInference 的方法通过离线分析确定每个注意力头最适合的稀疏模式及其参数,然后在线近似构建动态稀疏掩码 (dynamic sparse mask),最后利用优化的 GPU 内核执行这些稀疏计算,从而在保持模型准确性 (accuracy) 的同时,显著降低预填充阶段的延迟。
4.2. 核心方法详解
MInference 框架主要包含三个步骤:1) 离线注意力模式识别;2) 动态稀疏索引构建;3) 稀疏注意力计算。在详细展开这三个步骤之前,我们首先形式化稀疏注意力计算的问题。
4.2.1. 问题形式化
在加速长上下文 LLM 预填充阶段时,稀疏注意力计算可以形式化为:
符号解释:
-
A(M): 经过稀疏掩码 作用后的注意力矩阵。 -
:
softmax激活函数,将注意力得分转换为概率分布。 -
: 查询 (Query) 矩阵,其中 是序列长度 (sequence length), 是注意力头维度 (head dimension)。
-
: 键 (Key) 矩阵。
-
: 键矩阵的转置。
-
: 缩放因子,用于防止点积结果过大。
-
: 一个非常大的常数,例如
1e5。 -
: 动态稀疏掩码 (dynamic sparse mask),表示注意力矩阵中位置
(i, j)是否被计算。如果 ,则 会变为一个非常大的正数,导致 在该位置的值变得非常小,经过softmax后 趋近于零。如果 ,则 为 0,不影响原始注意力计算。动态稀疏注意力系统的目标是:在保留尽可能多的注意力权重 (attention weights) 的同时,实现更大的加速和最小的开销。这可以形式化为优化以下两个目标:
符号解释:
- : 稀疏注意力矩阵
A(M)与密集注意力矩阵 之间的差异的绝对值,目标是最小化这种差异,以保持准确性。 - : 动态稀疏注意力计算所花费的时间。
- : 估计近似动态稀疏模式(即构建稀疏掩码 )所花费的时间。
- 目标是同时最小化准确性损失和总的计算时间(稀疏计算时间加上开销时间)。
4.2.2. 注意力稀疏模式的特征 (Attention Sparsity Exhibits Patterns)
通过对长上下文提示词的广泛分析,研究人员将注意力稀疏模式归类为三种:
-
A形模式 (A-shape pattern): 这种模式的注意力权重集中在序列的初始词元 (initial tokens) 和局部窗口 (local windows) 内。它通常表现出相对较高的稳定性,即模式的分布变化较小。类似于
StreamingLLM等方法中的全局-局部注意力。 -
垂直斜杠模式 (Vertical-Slash (VS) pattern): 注意力权重集中在特定的词元(垂直线
vertical lines) 和固定间隔的词元(斜杠线slash lines) 上。这种模式的特点是垂直线和斜杠线的位置会随着上下文内容动态变化,并且其稀疏性无法被简单的局部窗口或A形模式有效捕获。例如,模型可能需要关注提示词中的特定实体(垂直线)或周期性出现的结构(斜杠线)。 下图(原文 Figure 3a)可视化了不同注意力头的注意力权重,展示了这三种模式:
该图像是图表,展示了不同输入下注意力模式的可视化,包括A-shape、Vertical-Slash和Block-Sparse三种模式。图(b)展示了注意力的空间聚类情况,而图(c)展示了不同模式的注意力权重回忆与计算复杂度之间的关系。 -
块稀疏模式 (Block-Sparse pattern): 这是最具动态性的稀疏模式,注意力权重分布更为分散。尽管如此,它仍然表现出一定的空间聚类 (spatial clustering) 特征,即非零注意力权重往往聚集在较小的块 (blocks) 中。研究发现,在 128k 提示词中,非零注意力权重及其
top-k(最佳 k 个) 最近邻之间的距离通常集中在 5 左右(原文 Figure 3b),这表明了强烈的空间聚类。 下图(原文 Figure 3b)展示了在 128k 提示词中,非零注意力权重与其最近的top-k非零邻居之间的距离分布,验证了注意力权重的空间聚类特性。
该图像是图表,展示了不同输入下注意力模式的可视化,包括A-shape、Vertical-Slash和Block-Sparse三种模式。图(b)展示了注意力的空间聚类情况,而图(c)展示了不同模式的注意力权重回忆与计算复杂度之间的关系。论文还通过实验(原文 Figure 3c)证明,在相同的 FLOPs 预算下,这些识别出的模式相比于
Top-K等其他稀疏方法,能更高效地检索注意力分数,从而可能带来更好的准确性。
4.2.3. MInference 1.0 的三个步骤
MInference 的实施分为以下三个主要步骤:
4.2.3.1. 离线核感知最优稀疏模式搜索 (Offline Kernel-Aware Optimal Sparse Pattern Search)
这一步的目标是为每个注意力头 (attention head) 确定最合适的稀疏模式(A形、垂直斜杠或块稀疏)以及该模式的具体设置(例如,垂直/斜杠线的数量,或 top-k 块的数量)。“核感知 (Kernel-Aware)”意味着在搜索过程中,计算成本的估计是基于实际 GPU 内核的性能,而非纯粹的理论 FLOPs,这对于实现实际的加速至关重要。
算法 1: 核感知稀疏模式搜索 (Algorithm 1 Kernel-Aware Sparse Pattern Search)
Algorithm 1 Kernel-Aware Sparse Pattern Search Input: Q, K, V ∈ ℝ^S ×d_h, patterns p search space ρ, target FLOPs t, initialized search space σ # Build kernel-aware search space for i ← 1 to |σ| do ti ← FLOPs_in_kernel(σi) while |ti − t| > ϵ do σi ← ChangeSpace(σi, pi) ti ← FLOPs_in_kernel(σi) end while ρ ← ρ ∪ σi end for # Search for optimal head pattern p_best ← ∅ y ← Softmax(QKT / √d) for i ← 1 to |ρ| do yi ← SparseAttention(QKT / √d, ρi) p_best ← argmin(|yi − y|, p_best) end for return p_best
符号解释:
- : 查询 (Query)、键 (Key)、值 (Value) 矩阵,其中 是序列长度, 是头维度。
- patterns : 可用的稀疏模式集合(A形、垂直斜杠、块稀疏)。
- search space : 完整的核感知搜索空间,包含各种模式和设置的组合。
- target FLOPs : 为每个注意力头设定的目标浮点运算量预算。
- initialized search space : 初始化的模式和设置的候选集合。
- : 在实际 GPU 内核中,模式 的浮点运算量。
- : 一个小的阈值,用于判断实际 FLOPs 是否接近目标 FLOPs。
- : 根据目标 FLOPs 调整模式 的参数(例如,调整保留的垂直/斜杠线数量或块数量)。
- : 搜索到的最佳模式及其设置。
- : 密集注意力计算的输出。
- : 使用特定稀疏模式 进行稀疏注意力计算的输出。
- : 选择使稀疏注意力输出与密集注意力输出差异最小的模式作为最佳模式。
过程描述:
- 构建核感知搜索空间: 遍历初始候选模式和设置集合 。对于每个候选,通过
FLOPs_in_kernel函数计算其在实际 GPU 内核中的 FLOPs 。然后,通过反复调整模式参数 (ChangeSpace),直到其 FLOPs 接近target FLOPs t为止。所有经过 FLOPs 预算调整的候选模式被添加到最终的搜索空间 中。 - 搜索最佳头模式:
- 首先,计算完整、密集的注意力输出 作为基准。
- 然后,遍历搜索空间 中的每个候选模式 ,并计算其稀疏注意力输出 。
- 通过比较稀疏输出 与密集输出 之间的差异(以注意力输出的召回率 (recall) 为目标),选择差异最小的模式作为该注意力头的最佳模式 。
4.2.3.2. 动态构建稀疏索引和稀疏注意力计算 (Dynamic Build of Sparse Indices and Sparse Attention Calculation)
在推理阶段,根据离线搜索确定的模式和当前的输入,动态地构建稀疏掩码,并使用优化的 GPU 内核执行稀疏注意力计算。
-
A形头 (A-shape heads): 对于 A形模式的注意力头,稀疏掩码是静态的(例如,固定数量的全局词元和固定大小的局部窗口),因此在推理时无需额外开销来构建动态掩码,直接进行稀疏计算即可。
-
垂直斜杠头 (Vertical-Slash head): 这种模式的稀疏索引是动态构建的,主要涉及垂直线和斜杠线的识别。 算法 2: 垂直斜杠头 (Algorithm 2 Vertical-Slash Head)
Algorithm 2 Vertical-Slash Head Input: Q, K, V ∈ ℝ^S ×d_h, k_v, k_s ∈ ℕ # Approximate vertical and slash pattern (last_q = 64) Â ← softmax(Q_[-last_q:] K^⊤ / √d + m_casual) # Indices of top k_v vertical line, sum in vertical i_v ← argtopk(sum_v(Â), k_v) # Indices of top k_s slash line, sum in slash i_s ← argtopk(sum_s(Â), k_s) # Build sparse attention index i_vs ← sparseformat(i_v, i_s) # Final dynamic sparse attention scores (only index block) A ← softmax(sparse(QKT, i_vs)/√d) # Sparse mixed scores and values y ← sparse(AV, i_vs) return y符号解释:
- : 垂直线和斜杠线的数量,由离线搜索确定。
- : 用于近似计算的最后几个查询向量的数量,通常设置为 64。
- : 查询矩阵 的最后 行。
- : 估计的注意力矩阵,通过最后 个查询向量与所有键向量的乘积近似得到。
- : 因果掩码 (casual mask),确保查询只能关注到当前位置及之前的位置。
- : 沿垂直方向对 求和,用于识别垂直线的重要性。
- : 沿斜杠方向对 求和,用于识别斜杠线的重要性。
- : 返回
top-k值对应的索引。 - : 垂直线索引。
- : 斜杠线索引。
- : 将垂直线和斜杠线索引转换为稀疏格式 。
- : 根据稀疏索引 对 进行稀疏计算。
- : 最终的稀疏注意力矩阵。
- : 稀疏注意力输出。
过程描述:
- 模式近似: 利用最后 个查询向量 与所有键向量 进行矩阵乘法,得到一个较小的估计注意力矩阵 。这个小矩阵可以高效计算,并用于近似整个注意力矩阵的垂直和斜杠模式。
- 索引识别: 根据 ,通过对垂直方向和斜杠方向求和,并选择
top-k重要的线,识别出垂直线索引 和斜杠线索引 。 - 索引构建: 将 和 组合并格式化为稀疏索引 。
- 稀疏计算: 使用 对实际的 进行稀疏计算,然后进行
softmax和与 矩阵的乘法,得到稀疏注意力输出 。
-
块稀疏头 (Block-Sparse head): 这种模式的稀疏索引同样是动态构建的,主要涉及块级 (block-level) 注意力权重的识别。 算法 3: 块稀疏头 (Algorithm 3 Block-Sparse Head)
Algorithm 3 Block-Sparse Head Input: Q, K, V ∈ ℝ^S ×d_h, k_b ∈ ℕ # Approximate block-sparse pattern (block_size = 64) Q̂ ← MeanPooling(Q, block_size) K̂ ← MeanPooling(K, block_size) Â ← softmax(Q̂K̂^⊤ / √d + m_casual) # Indices of top k_b blocks i_b ← argtopk(Â, k_b) # Build sparse attention index i_b ← sparseformat(i_b) # Final dynamic sparse attention scores (only index block) A ← softmax(sparse(QKT, i_b)/√d) # Sparse mixed scores and values y ← sparse(AV, i_b) return y符号解释:
- : 要保留的块的数量,由离线搜索确定。
block_size: 块的大小,通常设置为 64。- : 对 矩阵进行均值池化 (mean pooling),将其下采样到块级表示。
- : 对 矩阵进行均值池化。
- : 估计的块级注意力权重,通过 计算得到。均值池化和矩阵乘法是可交换的,因此这可以有效近似实际注意力权重的块稀疏模式。
- :
top-k块的索引。 - : 将块索引转换为稀疏格式。
过程描述:
- 模式近似: 对 和 矩阵进行均值池化 (MeanPooling),将它们压缩到块级别,得到 和 。然后计算 ,得到估计的块级注意力权重 。这个过程通过在块级别进行计算,以最小的开销近似实际注意力权重的块稀疏模式。
- 索引识别: 从 中选择
top-k最重要的块,得到块索引 。 - 索引构建: 将 格式化为稀疏索引。
- 稀疏计算: 使用 对实际的 进行稀疏计算,然后进行
softmax和与 矩阵的乘法,得到稀疏注意力输出 。
-
GPU 内核实现细节 (Kernel Implementation Details): 为了确保稀疏计算在硬件上高效执行,研究人员开发了高度优化的 GPU 内核:
-
块稀疏 Flash Attention (Block-Sparse Flash Attention):
- 基于 Triton [TKC19] 版本的 FlashAttention 内核 [tri23]。
- 它通过将选定的块索引作为额外输入,使每个线程块 (thread block) 循环处理一行中的
top-K块。 - 理论加速比 ,其中 是序列长度, 是块大小, 是每行保留的块数。
-
垂直斜杠注意力内核 (Vertical-Slash Attention Kernel):
-
包括两个自定义内核:垂直斜杠稀疏索引内核和垂直斜杠稀疏 FlashAttention 内核。
-
垂直斜杠稀疏索引内核 (Algorithm 4): Algorithm 4 Vertical-Slash Sparse Index Kernel Input: vertical indexes i_v ∈ ℕ^k_v, slash indexes i_s ∈ ℕ^k_s # Sort vertical and slash indexes i_v ← IncrementalSort(i_v) i_s ← DescendingSort(i_s) # Calculate block number (block_size B) N ← ⌈S / B⌉ # Initialize outputs block count c_blk ∈ ℕ^N, block index i_blk ∈ ℕ^(N ×k_s), column count c_col ∈ ℕ^N, column index i_col ∈ ℕ^(N ×k_v) # Parallelized in GPU for i ← 1 to N do j_v ← 1 # Find the first slash line that crosses the row j_s ← biset_left(i_s, i × B) # Define the range by slash index r_start ← (i - 1) × B - i_s^(j_s) r_end ← i × B - i_s^(j_s) # Merge points (vertical indexes) and ranges (slash indexes) while s_v ≤ k_s do if j_v ≤ k_v and i_v^(j_v) < r_end then # Record the point if not in the range if i_v^(j_v) < r_start then c_col^i ← c_col^i + 1 i_col^(i, c_col^i) ← i_v^(j_v) end j_v ← j_v + 1 else s_v ← s_v + 1 # Update the range if (i - 1) × B - i_s^(j_s) > r_end then # Record the last range s ← r_start while s < r_end do c_blk^i ← c_blk^i + 1 i_blk^(i, c_blk^i) ← s s ← s + B end while # Calculate the new range r_start ← (i - 1) × B - i_s^(j_s) r_end ← i × B - i_s^(j_s) else # Extend the range r_end ← r_end + B end end end while # Record the last range s ← r_start while s < r_end do c_blk^i ← c_blk^i + 1 i_blk^(i, c_blk^i) ← s s ← s + B end while end for return c_blk, i_blk, c_col, i_col 符号解释:
- : 垂直索引 (vertical indexes)。
- : 斜杠索引 (slash indexes)。
- : 块大小 (block size)。
- : 序列中的块数量。
- : 每行(块)的块计数。
- : 每行(块)的块索引。
- : 每行(块)的列计数。
- : 每行(块)的列索引。
IncrementalSort,DescendingSort: 增量排序和降序排序函数。biset_left: 查找第一个斜杠线与给定行交叉的位置。- : 定义斜杠索引的范围。 过程描述: 这个内核并行处理每个查询行对应的块。它首先对垂直和斜杠索引进行排序。然后,对于每个块,它通过合并垂直点索引和斜杠范围索引来构建稀疏索引。由于斜杠线段可以被方块掩盖,所以注意力掩码是块和列的混合。索引构建的时间复杂度为 。 下图(原文 Figure 7)展示了垂直斜杠模式的动态稀疏掩码:
该图像是示意图,展示了在摘要任务中使用 LLaMA-3-8B 模型的垂直斜线模式的动态稀疏掩码。黄色区域表示计算部分,斜线部分使用 的块,而垂直线部分使用 的块。 -
垂直斜杠稀疏 FlashAttention 内核 (Algorithm 5): Algorithm 5 Vertical-Slash Sparse FlashAttention Kernel Input: Q, K, V ~ ∈ ℝ^S ×d_h, block count c_blk ~ ∈ ℕ^N, block index i_blk ∈ ℕ^(N ×k_s), column count c_col ∈ ℕ^N, column index i_col ∈ ℝ^(N ×k_v) Scale τ ← √1 Initialize O ← (0)^(S ×d_h) ∈ ℝ^(S ×d_h) # Parallelized in GPU for i ← 1 to N do Load Q_chip ← Q^(i×B:(i+1)×B) ∈ ℝ^(B ×d_h) Initialize O_chip ← (0)^(B ×d_h) ∈ ℝ^(B ×d_h) Initialize m ← (−inf)^B ∈ ℝ^B Initialize l ← (0)^B ∈ ℝ^B # Loop through block indexes: block sparse flash attention fr j ← 1 to c_blk^i do s ← i_blk^(i,j) K_chip ← K^(s:s+B) ∈ ℝ^(B ×d_h) Load V_chip ← V^(s:s+B) ∈ ℝ^(B ×d_h) S ← τ Q_chip K_chip^T S ← mask(S) m_new ← max(m_i, rowmax(S)) RB S ← S − m_new P ← exp(S) t_new ← rowsum(S)) α ← exp(m_i − m_new) l_i ← αl_i + l_new O_chip ← αO_chip + O_chip end for # Loop through column indexes : PIT sparse flash attention j ← 0 while j < c_col do cols ← i_col^(i,j:j+B) ∈ ℕ^B Load K_chip ← K^(cols) ∈ ℝ^(B ×d_h) Load V_chip ← V^(cols) ∈ ℝ^(B ×d_h) S ← τ Q_chip K_chip^T S ← mask(S) m_new ← max(m_i, rowmax(S)) RB r ← S−mx(m P ← exp(s) new ← rowsum(S)) α ← exp(m_i − m_new) l_i ← αl_i + l_new O_chip ← αO_chip + O_chip j ← j + B end while # Write outputs O_chip ← diag(l)−1O_chip Save O_i ← O_chip end for 符号解释:
- : 从垂直斜杠稀疏索引内核获得的块计数、块索引、列计数和列索引。
- : 缩放因子。
- : 最终的注意力输出矩阵。
- : 当前处理的查询、键、值矩阵的块。
m, l: 用于 FlashAttention 的 (最大值) 和 (归一化因子) 累积变量。- : 注意力得分矩阵。
mask(S): 应用因果掩码等。- : 新计算的 和 。
- :
exp因子,用于更新 和 。 过程描述: 这是一个混合内核,结合了块稀疏注意力和 PIT [ZJZ+23] 稀疏注意力。它并行处理每个查询块,首先循环遍历由i_blk定义的块索引(执行块稀疏 FlashAttention),然后循环遍历由i_col定义的列索引(执行 PIT 稀疏 FlashAttention)。PIT 是一种将稀疏数据加载到密集计算块的技术,通过排列不变变换 (Permutation Invariant Transformation)。这个混合内核的延迟与块和列的总面积呈线性关系。
-
-
4.2.4. 单一 A100 GPU 上的实现细节 (Single A100 Implementation Details)
为了在单个 A100 (80G) GPU 上运行 1M (百万) 词元提示词推理,研究人员进行了以下优化:
- 张量拆分 (Tensor Splitting): 将注意力计算按头 (head) 进行拆分,并将多层感知器 (Multi-Layer Perceptron, MLP) 按序列维度进行拆分。在长上下文场景中,计算是瓶颈,这种拆分能够保持 GPU 利用率在 100%,且拆分带来的开销可忽略不计。
- 减少中间变量 (Reduction of Intermediate Variables): 最小化中间变量的分配,通过移除注意力掩码 (attention mask) 并直接在内核中实现因果掩码 (causal mask) 逻辑。
- 消除不必要的计算 (Elimination of Unnecessary Computations): 在预填充阶段,只有对应于最后一个词元的
logits(对数几率) 对LM Head Linear层有意义。因此,只保留了最后一个词元的LM Head Linear层计算。
5. 实验设置
5.1. 数据集
- InfiniteBench [ZCH+24]:
- 来源与特点: 这是一个专门为评估 LLM 长上下文处理能力而设计的基准测试集。它包含 10 个任务,涵盖了长文本的各种处理方面,如:
- 检索任务 (Retrieval tasks): PassKey retrieval (密码键检索), Number retrieval (数字检索), KV retrieval (键值检索)。
- 现实任务 (Realistic tasks): Question-answering (问答), Coding (代码), Dialogue (对话), Summarization (摘要)。
- 规模: 平均上下文长度约为 214K 词元,总共有 3,992 个示例。
- 选择原因: InfiniteBench 包含了多种任务类型,能够全面评估 MInference 在不同长上下文场景下的有效性。
- 来源与特点: 这是一个专门为评估 LLM 长上下文处理能力而设计的基准测试集。它包含 10 个任务,涵盖了长文本的各种处理方面,如:
- RULER [HSK+24]:
- 来源与特点: 这是一个具有挑战性的合成长上下文基准测试套件,包含 4 个类别和 13 个复杂任务。
- 检索类别 (Retrieval category): 包括 Single Needle-in-a-Haystack (S-NIAH, 单针在草堆)、Multi-keys Needle-in-a-Haystack (MK-NIAH, 多键针在草堆)、Multi-values Needle-in-a-Haystack (MV-NIAH, 多值针在草堆) 和 Multi-queries Needle-in-a-Haystack (MQ-NIAH, 多查询针在草堆) 任务。
- 多跳追踪类别 (Multi-hop Tracing category): 如 Variable Tracking (VT, 变量追踪)。
- 聚合类别 (Aggregation category): 如 Common Words Extraction (CWE, 常见词提取) 和 Frequent Words Extraction (FWE, 频繁词提取)。
- 问答类别 (Question Answering, QA): 通过添加干扰段落扩展现有短上下文问答数据集。
- 规模: 包含不同提示词长度的子集,最长可达 128k 词元。每个长度有 2,600 个示例。
- 选择原因: RULER 能够测试模型在多跳推理、聚合和复杂问答等更复杂长上下文场景下的性能,有助于揭示 MInference 在复杂任务中的真实潜力。
- 来源与特点: 这是一个具有挑战性的合成长上下文基准测试套件,包含 4 个类别和 13 个复杂任务。
- Needle In A Haystack [Kam23]:
- 来源与特点: 这是一个流行的长上下文检索基准测试,通过在大量冗余文本(“草堆”)中插入一个特定信息(“针”),来评估 LLM 检索和利用关键信息的能力。测试通常会调整“针”的位置和上下文长度。
- 规模: 本文将该任务扩展到 1M 词元上下文长度,包含 750 个示例。
- 选择原因: 该任务直接量化了模型处理超长上下文时关键信息检索的能力,是评估长上下文 LLM 的标准方法。
- PG-19 [RPJ+20]:
- 来源与特点: 这是一个用于长上下文语言建模任务的数据集,包含长度可达 500K 词元的长文本。
- 规模: 本文使用 PG-19 中 1,000 个长度超过 100K 词元的随机样本进行评估。
- 选择原因: 困惑度 (perplexity) 是衡量语言模型性能的关键指标,PG-19 能够测试模型在长文本生成和理解中的语言建模能力。
5.2. 评估指标
-
准确率 (Accuracy):
- 概念定义: 准确率是最常见的分类和检索任务评估指标之一,它衡量模型预测结果与真实标注数据 (Ground Truth) 一致的比例。在检索任务中,通常指模型能否精确地找到并输出目标信息;在问答任务中,则指模型是否能正确回答问题。
- 数学公式:
- 符号解释:
- : 模型在给定测试集上做出正确预测的样本数量。
- : 测试集中的总样本数量,即模型做出的总预测次数。
- 在本文中的应用: 用于 InfiniteBench (各个子任务,如 Retr.PassKey, Retr.Num, Retr.KV 等)、RULER 和 Needle In A Haystack 任务。具体任务可能根据其性质采用不同的准确率计算方式(如精确匹配、F1 分数等)。
-
困惑度 (Perplexity, PPL):
- 概念定义: 困惑度是衡量语言模型质量的指标,尤其在语言建模任务中。它量化了模型预测序列中下一个词元 (token) 的不确定性或“惊讶”程度。困惑度越低,表示模型对测试数据越“不困惑”,即其预测与实际文本分布越吻合,模型表现越好。
- 数学公式: 对于一个包含 个词元的序列 ,其困惑度定义为:
- 符号解释:
- : 给定的词元序列。
- : 序列 的词元数量。
- : 序列 在语言模型下的联合概率。
- : 在给定前面
i-1个词元的情况下,模型预测第 个词元 的条件概率。
- 在本文中的应用: 用于 PG-19 数据集上的长上下文语言建模任务。
-
延迟 (Latency):
- 概念定义: 在 LLM 推理场景中,延迟通常指从模型接收到整个输入提示词 (prompt) 到开始生成第一个输出词元 (token) 所花费的时间。这正是预填充阶段 (pre-filling stage) 的主要耗时。低延迟对于提升用户体验至关重要。
- 数学公式: 延迟通常通过直接测量时间来获得,没有统一的标准化公式,但可以概念性地表示为:
- 符号解释:
- : 预填充阶段计算完成的时间点。
- : 预填充阶段计算开始的时间点。
- 在本文中的应用: 用于评估 MInference 在不同上下文窗口下对预填充阶段加速的效果,并与基线方法进行比较。
5.3. 对比基线
论文将 MInference 方法与多种训练无关 (training-free) 的稀疏注意力方法进行了比较:
- StreamingLLM [XTC+24]: 这是一个专注于高效流式推理的方法,其核心思想是使用 A形模式的注意力,即结合固定数量的全局词元 (global tokens) 和固定大小的局部窗口 (local windows) 来处理长序列。
- 代表性: 它代表了一种广泛采用的、在不改变模型结构下扩展上下文窗口的方法,是 MInference 中 A形模式的基石。
- 参数: 使用 1k 全局词元和 4k 局部窗口。
- StreamingLLM w/ dilated [BPC20]: 在 StreamingLLM 的基础上,引入了扩张 (dilated) 注意力窗口。扩张注意力允许模型跳过一些词元,从而在不增加计算量的同时扩大感受野。
- 代表性: 代表了在局部窗口中引入跳跃连接以捕获更长距离依赖的思路。
- 参数: 使用 1k 全局词元和 8k 扩张注意力窗口,扩张间隔为 1。
- StreamingLLM w/ strided [CGRS19]: 结合了局部窗口和扩张注意力,通常指在局部窗口内进行步幅 (strided) 跳跃,或者在不同层使用不同步幅。
- 代表性: 代表了通过混合局部和步幅/扩张注意力来平衡局部信息和长距离依赖的方法。
- 参数: 使用 1k 全局词元、2k 局部窗口和 4k 扩张注意力窗口,扩张间隔为 1。
- InfLLM [XZH+24]: 通过引入一个内存单元 (memory unit) 来处理流式的长序列,旨在揭示 LLM 理解极长序列的内在能力,无需训练。
- 代表性: 代表了利用外部记忆机制来扩展上下文的方法,无需对模型进行微调。
- 参数: 使用 128 全局词元和 8k 局部窗口。
- Ours w/ static: 这是 MInference 的一个消融版本,它在垂直斜杠 (Vertical-Slash) 和块稀疏 (Block-Sparse) 注意力头中使用了静态稀疏索引 (static sparse indices),而不是动态构建。
-
代表性: 用于验证 MInference 中动态索引构建的重要性。
共同策略: 所有基线方法在预填充阶段执行稀疏计算,但在解码阶段保留密集计算。这确保了比较的公平性,因为 MInference 也只在预填充阶段应用稀疏化。
-
5.4. 实现细节
- 基座模型 (Base Models): 实验使用了以下最先进的长上下文 LLM:
- LLaMA-3-8B-Instruct-262k: LLaMA-3 的变体,通过 NTK 感知插值 (NTK-aware interpolation) 和 Ring Attention (环形注意力) 进行少量微调,在长上下文评估中表现出色。
- LLaMA-3-8B-Instruct-1048k: 类似于 LLaMA-3-8B-Instruct-262k,但支持高达 1M 词元的上下文长度。
- GLM-4-9B-1M [GZX+24]: GLM 系列的最新模型,上下文窗口达 1M。
- Yi-9B-200K [YCL+24]: 在长上下文性能和通用能力之间取得平衡的 LLM。
- Phi-3-Mini-128K [AJA+24]: 小型但功能强大的 LLM,通过 LongRoPE [DZZ+24] 支持 128K 上下文。
- Qwen2-7B-128K [BBC+23]: Qwen 系列的更新版本,支持 128K 上下文。
- 解码策略 (Decoding Strategy): 所有实验均使用贪婪解码 (greedy decoding),以确保结果的稳定性。
- 实现框架 (Implementation Framework): MInference 的内核实现基于 PyTorch,并构建在 FlashAttention [Dao24]、Triton [TKC19] 和动态稀疏编译器 PIT [ZJZ+23] 之上。
- 目标 FLOPs (Target FLOPs): 在核感知最优稀疏模式搜索中,目标 FLOPs 被设定为与 A形模式中 1k 全局词元和 4k 局部窗口的 FLOPs 相同。
- 特定参数 (Specific Parameters):
- : 用于垂直斜杠模式中近似计算的查询向量数量。
block_size = 64: 用于块稀疏模式中均值池化 (MeanPooling) 和块级计算的块大小。
- 硬件 (Hardware): 延迟实验在单个 Nvidia A100 GPU 上进行。
- 数据类型 (Data Type): 使用 bfloat16 (Brain Floating Point 16) 格式。
- 搜索空间细节:
-
ChangeSpace的步长设置为 50。 -
搜索空间如原文 Table 7 所示: 以下是原文 Table:Kernal-aware optimal head pattern searc space. In this context, A-shape represents the global tokens and local window number, Vertical-Slash represents the Top-K number of vertical and diagonal lines, and Block-Sparse represents the Top-K number of blocks retained. 的结果:
Patterns Search Space A-shape {(1024, 4096)} Vertical-Slash {(30, 2048), (100, 1800), (500, 1500), (3000, 200)} Block-Sparse {100} -
使用来自 KV 检索合成数据的一个 30k 词元输入样本作为参考示例 (validation set) 进行模式搜索。
-
搜索时间大约为 15 分钟(单 A100)。
-
LLaMA-3-8B-Instruct-262K 和 LLaMA-3-8B-Instruct-1M 使用相同的最优稀疏模式配置。
-
6. 实验结果与分析
本节详细解读 MInference 在各种长上下文基准测试上的表现,包括其有效性、效率、消融实验和与现有方法的对比。
6.1. 核心结果分析
6.1.1. InfiniteBench 任务表现
以下是原文 Table 2: Performance of different methods with different base models on InfiniteBench 的结果:
| Methods | En.Sum | En.QA | En.MC | En.Dia | Zh.QA | Code.Debug | Math.Find | Retr.PassKey | Retr.Num | Retr.KV | Avg. |
| LLaMA-3-8B-262K | 20.2 | 12.4 | 67.3 | 6.0 | 12.9 | 22.1 | 26.6 | 100.0 | 100.0 | 14.4 | 38.2 |
| StreamingLLM | 21.0 | 8.2 | 40.2 | 10.0 | 10.4 | 25.9 | 30.0 | 86.8 | 5.1 | 0.8 | 23.8 |
| StreamingLLM w/ dilated | 20.1 | 9.4 | 44.5 | 15.5 | 11.2 | 20.5 | 27.5 | 5.0 | 87.5 | 0.5 | 24.2 |
| StreamingLLM w/ strided | 17.3 | 8.2 | 27.5 | 14.5 | 11.2 | 19.5 | 27.5 | 4.0 | 2.1 | 1.0 | 13.3 |
| InfLLM | 24.1 | 7.8 | 45.0 | 6.0 | 11.4 | 19.5 | 32.9 | 100.0 | 100.0 | 1.2 | 34.8 |
| Ours w/ static | 19.9 | 8.6 | 43.2 | 3.5 | 8.9 | 20.6 | 25.1 | 92.4 | 96.3 | 0.2 | 31.9 |
| Ours | 20.5 | 12.9 | 65.9 | 7.5 | 12.5 | 22.3 | 33.1 | 100.0 | 100.0 | 12.8 | 38.8 |
| Yi-9B-200K | 8.2 | 10.6 | 64.2 | 1.0 | 17.3 | 21.3 | 99.8 | 100.0 | 28.8 | 37.5 | |
| StreamingLLM | 5.4 | 14.2 | 38.0 | 4.0 | 18.8 | 18.8 | 23.4 | 22.3 | 39.2 | 6.1 | 1.6 |
| StreamingLLM w/ dilated | 5.7 | 4.2 | 15.0 | 0.0 | 18.2 | 0.0 | 2.9 | 3.1 | 0.0 | 0.0 | 4.2 |
| StreamingLLM w/ strided | 6.1 | 4.5 | 9.8 | 0.0 | 16.9 | 0.0 | 1.5 | 0.0 | 0.0 | 4.6 | |
| InfLLM | 6.3 | 13.0 | 45.9 | 2.5 | 21.5 | 20.6 | 85.3 | 88.1 | 1.4 | 31.9 | |
| Ours w/ static | 5.8 | 12.6 | 48.5 | 3.0 | 12.6 | 34.6 | 20.8 | 25.1 | 60.9 | 38.5 | 1.0 |
| Ours | 7.9 | 11.2 | 64.2 | 1.0 | 17.9 | 24.1 | 23.1 | 99.5 | 100.0 | 27.6 | 37.7 |
| GLM-4-9B-1M | 28.3 | 9.7 | 68.6 | 39.5 | 12.1 | 29.4 | 38.9 | 100.0 | 100.0 | 41.0 | 46.7 |
| StreamingLLM | 27.7 | 6.4 | 40.2 | 12.5 | 10.8 | 27.7 | 21.1 | 97.1 | 39.4 | 25.6 | 0.6 |
| nflLM | 28.0 | 7.3 | 45.0 | 14.0 | 10.7 | 27.9 | 98.0 | 100.0 | 2.6 | 37.3 | |
| Ours | 28.8 | 9.6 | 68.6 | 38.5 | 12.0 | 30.7 | 39.1 | 100.0 | 100.0 | 43.0 | 47.0 |
分析:
- MInference 的优越性: 在 LLaMA-3-8B-262K、Yi-9B-200K 和 GLM-4-9B-1M 三种模型上,MInference (Ours) 在
InfiniteBench的平均得分 (Avg.) 上都达到了最佳性能,甚至在某些任务上(如 LLaMA-3-8B 的En.QA,En.MC,Math.Find)超越了原始的全注意力 (full attention) 模型性能。这表明 MInference 在提供显著加速的同时,能够保持甚至在某些情况下提升模型的准确性。 - 检索任务的强劲表现:
Retr.PassKey和Retr.Num任务对长上下文 LLM 的关键信息检索能力要求很高。MInference 在这些任务上几乎总是达到 100% 的准确率,而StreamingLLM及其变体,以及Ours w/ static版本,在这些任务上表现不佳,甚至接近零(例如StreamingLLM在Retr.KV上)。这强调了 MInference 动态稀疏策略对于处理需要精准检索信息任务的重要性。 - 基线方法的局限性:
StreamingLLM及其w/ dilated和w/ strided变体在许多任务上(尤其是检索任务)表现显著低于全注意力模型和 MInference。这说明简单的固定稀疏模式难以适应复杂多变的长上下文注意力需求。 - 动态策略的必要性:
Ours w/ static(使用静态稀疏索引的 MInference 版本) 在Retr.KV等动态任务上表现极差,平均得分也显著低于完整版 MInference。这直接证明了 MInference 中动态构建稀疏索引的必要性和有效性。 - 不同任务的适应性: MInference 不仅在自然语言任务(如摘要、问答、代码)中表现出色,还能保持模型在检索相关任务中的原始性能。这表明 MInference 的稀疏模式设计能够捕获不同任务的注意力特征。
6.1.2. RULER 任务表现
以下是原文 Table 3: Performance of different models and different methods on RULER evaluated at lengths from 4k to 128k. 的结果:
| Methods | Claimed | Effective | 4K | 8K | 16K | 32K | 64K | 128K | Avg. |
| LLaMA-3-8B-262K | 262K | 16K | 97.2 | 91.8 | 87.3 | 80.8 | 77.4 | 72.2 | 84.4 |
| StreamingLLM | - | 4K | 97.2 | 38.1 | 37.5 | 17.2 | 14.2 | 9.4 | 35.0 |
| StreamingLLM w/ dilated | - | <4K | 23.4 | 0.7 | 1.4 | 18.8 | 16.5 | 15.6 | 12.7 |
| StreamingLLM w/ strided | - | <4K | 2.0 | 0.7 | 0.6 | 0.6 | 0.7 | 1.3 | 1.0 |
| InfLLM | - | 4K | 89.4 | 79.8 | 70.1 | 55.6 | 43.0 | 39.5 | 62.9 |
| Ours | 32K | 97.7 | 91.2 | 88.5 | 85.0 | 82.3 | 77.6 | 87.0 | |
| Yi-9B-200K | 200K | 8K | 91.9 | 90.2 | 78.8 | 76.3 | 68.1 | 62.9 | 78.1 |
| StreamingLLM | - | 4K | 91.9 | 37.8 | 33.9 | 18.6 | 13.0 | 12.8 | 34.3 |
| StreamingLLM w/ dilated | - | <4K | 44.8 | 42.8 | 38.5 | 29.8 | 26.8 | 23.9 | 34.4 |
| StreamingLLM w/ strided | <4K | 2.6 | 0.7 | 0.6 | 0.6 | 1.2 | 0.5 | 1.1 | |
| fLM | - | <4K | 80.3 | 83.9 | 60.7 | 45.2 | 38.6 | 30.2 | 56.5 |
| Ours | - | 8K | 92.3 | 89.7 | 79.0 | 73.8 | 64.7 | 56.9 | 74.7 |
| GLM-4-9B-1M | 1M | 64K | 93.8 | 91.6 | 89.3 | 87.4 | 85.2 | 80.8 | |
| StreamingLLM | - | 4K | 93.8 | 66.9 | 58.5 | 51.4 | 45.9 | 39.1 | 88.0 |
| InfLLM | - | 8K | 94.7 | 89.5 | 76.4 | 66.5 | 56.8 | 53.5 | 59.3 |
| Ours | - | 64K | 94.6 | 93.1 | 91.0 | 89.6 | 85.5 | 84.0 | 72.9 |
分析:
- 长上下文能力保持: MInference 在 RULER 基准测试上有效地保持了长上下文性能,即使面对复杂的多跳或聚合任务。
- 有效上下文窗口 (Effective Context Window) 扩展:
- 对于 LLaMA-3-8B-262K,MInference 将有效上下文窗口(性能高于 85%)从基线模型的 16K 扩展到 32K,并且在 32K 长度上的性能甚至超过了全注意力模型(85.0% vs 80.8%)。
- 对于 GLM-4-9B-1M,MInference 实现了 64K 的有效上下文窗口,再次优于全注意力模型在 64K 和 128K 长度上的表现。
- 基线的性能退化:
StreamingLLM及其变体,以及InfLLM,在上下文长度增加时性能急剧下降,尤其是在 8K 甚至更长的情况下,其有效上下文窗口远低于 MInference 和原始模型。这再次表明固定稀疏模式或简单的记忆机制不足以处理RULER这类复杂长上下文任务。 - MInference 优势显著: MInference 在所有测试长度(从 4K 到 128K)上都显著优于所有基线稀疏方法。这突出了其在保持长上下文能力方面的强大实力。
6.1.3. 语言建模 (PG-19) 表现
下图(原文 Figure 5)展示了在 PG-19 数据集上使用不同模型和方法的困惑度 (Perplexity, PPL) 结果:
分析:
- 低困惑度: MInference 在
PG-19语言建模任务上取得了最佳结果,其困惑度曲线最接近全注意力 (Full Attention) 基线。 - 优于其他稀疏方法: 对于 100K 词元提示词,MInference 的困惑度仅比全注意力高 0.2,但比
StreamingLLM在 Yi-9B-200K 模型上低 0.25,在 LLaMA-3-262K 模型上低 0.75。这表明 MInference 在实现加速的同时,能更好地保持语言模型的生成质量。 - 基线方法表现不佳:
StreamingLLM及其变体在长上下文下困惑度显著升高,表示它们对文本的预测能力下降。
6.1.4. Needle In A Haystack 任务表现
下图(原文 Figure 1a)展示了在 LLaMA-3-8B-1M 模型中,使用 MInference 方法的 Needle In A Haystack 结果。下图(原文 Figure 6)展示了 StreamingLLM 的结果。下图(原文 Figure 8)展示了 InfLLM 的结果。下图(原文 Figure 9)展示了在 GLM-4-9B-1M、Yi-9B-200K、Phi-3-Mini-128K 和 Qwen2-7B-128K 模型上的结果。
该图像是图表,展示了在 LLaMA-3-8B-1M 模型中,使用 MInference 方法的 Needle In A Haystack 结果。在 (a) 中,深度百分比随着上下文长度的变化而变化;在 (b) 中,展示了 MInference 在处理 1M 上下文时相较于 FlashAttention-2 的延迟减少,最高可达 10 imes 的加速。
该图像是一个热图,展示了在 Needle In A Haystack 任务中,使用 StreamingLLM 进行 LLaMA-3-8B-1M 模型的预填充结果。图中横轴为上下文长度,单位为 tokens,纵轴为深度百分比,颜色变化表示不同的准确率。
该图像是一个热图,展示了在使用 InfLLM 的 LLaMA-3-8B-1M 模型中,针对不同上下文长度(从 1k 到 1M)的 Needle In A Haystack 任务的深度百分比。横轴表示上下文长度,纵轴表示深度百分比。
分析:
- MInference 的鲁棒性: MInference 有效地保留了在不同位置、不同上下文窗口(从 1k 到 1M 词元)处理信息的能力。无论是“针”在提示词开头、中间还是结尾,MInference 都能稳定地检索到信息,这与全注意力模型的表现非常接近。
- 基线的性能下降: 相比之下,
StreamingLLM和InfLLM等方法,当关键信息超出其设定的全局词元和局部窗口范围时,性能会急剧下降,导致在某些深度和上下文长度下几乎无法检索到信息。 - 微小性能提升: 在某些情况下(如 Yi-9B-200K 和 Phi-3-Mini-128K 在 100K 左右的上下文长度),MInference 甚至略微提高了性能。这可能归因于稀疏性带来的噪声过滤或更聚焦的注意力。
6.1.5. 延迟与效率分析
下图(原文 Figure 1b)展示了 MInference 在处理 1M 上下文时相较于 FlashAttention-2 的延迟减少,最高可达 10 倍的加速。下图(原文 Figure 10)展示了三种注意力模式和 FlashAttention 在不同上下文窗口中的单注意力内核延迟分解:
该图像是图表,展示了在 LLaMA-3-8B-1M 模型中,使用 MInference 方法的 Needle In A Haystack 结果。在 (a) 中,深度百分比随着上下文长度的变化而变化;在 (b) 中,展示了 MInference 在处理 1M 上下文时相较于 FlashAttention-2 的延迟减少,最高可达 10 imes 的加速。
分析:
- 显著的加速比: MInference 在 100K、300K、500K 和 1M 词元上下文窗口下分别实现了 1.8x, 4.1x, 6.8x 和 10x 的加速。对于 1M 词元提示词,单 A100 上的预填充延迟从 30 分钟减少到 3 分钟。
- 动态索引构建开销: 动态稀疏索引构建的开销约为总时间的 5%-20%。虽然存在开销,但相比于稀疏计算带来的加速,这个开销是可接受的,尤其是在长上下文场景下,该开销的比例会逐渐降低。
- 模式的效率差异:
- 块稀疏 (Block-Sparse): 是最快的模式,在 1M 词元下比
FlashAttention快 30 倍。 - 垂直斜杠 (Vertical-Slash): 其次,快 13 倍。
- A形 (A-shape): 性能略低于垂直斜杠,但在 1M 词元下比垂直斜杠慢 50%(从图 10 看,A-shape 是 164ms,Vertical-Slash 大约 100ms 左右,所以 A-shape 比 Vertical-Slash 慢约 64%)。
- 块稀疏 (Block-Sparse): 是最快的模式,在 1M 词元下比
- 硬件通用性: 由于 MInference 的内核是基于 Triton 实现的,这使得它易于移植到其他设备(如 H100 或 MI300X),并有望实现类似的加速效果。
6.1.6. 与 KV 缓存压缩方法的集成
以下是原文 Table 5: Performance of different methods on InfiniteBench using SnapKV in the decoding stage. 的结果:
| Methods | En.Sum | En.QA | En.MC | En.Dia | Zh.QA | Code.Debug | Math.Find | Retr.PassKey | Retr.Num | Retr.KV | Avg. |
| LLaMA-3 w/ SnapKV | 18.0 | 11.8 | 65.5 | 2.5 | 12.0 | 21.3 | 26.6 | 100.0 | 100.0 | 1.8 | 36.0 |
| Ours w/ SnapKV | 18.9 | 11.7 | 66.4 | 6.5 | 12.1 | 21.8 | 33.1 | 100.0 | 100.0 | 2.0 | 37.3 |
分析:
- MInference 与
SnapKV[LHY+24](一种最先进的 KV 缓存压缩方法)相结合时,性能几乎没有变化,平均得分甚至略有增加 (36.0 -> 37.3)。 - 这证明了 MInference 与其他 KV 缓存压缩技术的兼容性,表明它可以作为 LLM 服务 (serving) 优化堆栈中的一个组成部分,进一步提升实际应用价值。
6.1.7. 在更大 LLM 上的扩展性
以下是原文 Table 6: Performance of different methods using LLaMA-3-70B-Instruct-262K on InfiniteBench . 的结果:
| Methods | |En.Sum | En.QA | En.MC | En.Dia | Zh.QA | Code.Debug | Math.Find | Retr.PassKey | Retr.Num | Retr.KV| | Avg. |
| LLaMA-3-70B-262K | 20.7 | 10.3 | 84.2 | 9.5 | 14.0 | 33.2 | 61.7 | 97.0 | 100.0 | 34.0 | 46.5 |
| StreamingLLM | 20.5 | 8.5 | 52.0 | 10.0 | 12.6 | 27.4 | 61.1 | 14.0 | 10.0 | 0.0 | 21.6 |
| InfLLM | 24.1 | 8.1 | 57.0 | 10.0 | 12.9 | 27.4 | 52.3 | 100.0 | 100.0 | 0.0 | 39.2 |
| Ours | 20.6 | 10.1 | 83.4 | 10.0 | 14.1 | 34.1 | 61.9 | 100.0 | 100.0 | 39.0 | 47. |
分析:
- MInference (Ours) 在 LLaMA-3-70B-1M 等更大模型上依然保持了强大的性能,在
InfiniteBench的平均得分上优于StreamingLLM和InfLLM,并与全注意力模型的性能非常接近 (47.0 vs 46.5)。 - 在动态任务如
KV retrieval中,MInference 甚至可以匹配或略微提升全注意力模型的性能 (39.0 vs 34.0)。这表明 MInference 的动态稀疏策略对于大型模型同样有效且稳定。 StreamingLLM和InfLLM在Retr.KV任务上表现不佳 (0.0%),再次凸显了 MInference 在处理动态检索任务上的优势。
6.1.8. 案例分析 (Case Study)
论文通过案例研究进一步展示了 MInference 在摘要 (Summarization) 和 KV 检索 (KV Retrieval) 任务中的优势。
表 9: 摘要任务生成结果对比 以下是原文 Table 9: Comparison of generated results by different methods using LLaMA-3-8B-Instruct-262K in summarization task . 的结果:
| Methods | Summary |
| Original | The story follows the Bronwyn family and their guests as they prepare for a daytrip to the Kindergarten, a mysterious destination that holds great significancefor them. The story is told from multiple perspectives, including that of PeteLucy, who is staying with the family and is struggling to find her place in theirlives. As the day progresses, the family members prepare for their trip, and Petereflects on her own past experiences with the Bronwyns. The story ends with thefamily setting off on their journey, leaving Pete alone on the lawn to contemplateher own feelings and the meaning of life. The Kindergarten is a symbol of thefamily's past and their memories, and the story explores themes of nostalgia,loss, and the passage of time. Ultimately, the story suggests that even though thepast cannot be changed, it can still be celebrated and remembered. |
| StreamingLLM | The story follows a family on a boat trip to the Kindergarten, a school for boys.The father, Mr Bronwyn, is reading a book and seems to be in a trance. Thechildren, Seb and Annalisa, are excited about the trip but also a little scared.They pass by a group of fishermen and Mr Bronwyn shares out sandwiches. Theycome across a spot where three men were drowned and Mr Bronwyn takes outhis watch, making a calculation. They reach the Kindergarten and Mr Bronwynjumps out of the boat, holding a parcel. The story ends with Pete Lucy, who hasbeen watching the family from the shore, feeling relieved that the trip is over andthat Mr Bronwyn has reached his destination. The story is told from multipleperspectives, including Pete Lucy's, who is watching the family from the shoreand reflecting on his own life and relationships. |
| StreamingLLMw/ dilated | Accccccccccccccccclccccccccccccccccccccccccccccccccccccccccccccc|ccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccCCCcccc((((cccccc(((cccc((((cccccc(((( |
| StreamingLLMw/ strided | The GraphicsUnit36.OOAKA new ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( ( (clccccccccccccccccccccccccccccccccccccccccccccccclcccccccccccccccccccccccccccccccccccccccccccccc|l|ccccc| |
| Ours w/ static | def solverome2 def solverome2 def solverome2 def solverome2 defsolverome2 def solverome2 def solverome2 def solverome2 def solverome2 defsolverome2 def solverome2 def solverome2 def |
| Ours | The story follows the Bronwyn family and their guests as they prepare for a tripto the Kindergarten, a remote island off the coast of Scotland. The family ismade up of Mr. Bronwyn, a philosopher, and his wife, Mrs. Bronwyn, who isknown for her beauty and her ability to surround and protect her family. Theguests include Kevin Emma, a young man who is studying for a fellowship, andCastor Kiki, a botanist who is also a widower. The story explores themes of love,marriage, and the search for meaning in life. It also touches on the idea of theother, or the person who is different from oneself and cannot be understood orjudged by conventional standards. Ultimately, the story is about the search fortruth and the desire to connect with others on a deeper level. |
分析:
- MInference (Ours) 的高质量摘要: MInference 生成的摘要详细且连贯,与原文的叙述和主题高度一致,例如对布朗温家族旅行准备、人物刻画以及爱情、婚姻、意义探索等主题的阐述。这表明 MInference 能够有效理解长上下文并生成高质量的文本。
- 基线方法的局限性:
StreamingLLM尽管生成了看似连贯的文本,但引入了原文中不存在的虚假信息(如“船旅行”、“渔民”、“男人溺水”),导致事实性错误。StreamingLLM w/ dilated和w/ strided生成的摘要则高度不连贯,充满了重复和无意义的字符,完全失败。Ours w/ static版本也生成了重复的代码片段,显示出其在处理动态文本任务时的失效。
- 结论: MInference 在文本生成任务中展现出卓越的能力,能够产生与人类生成文本相媲美的高质量摘要。
表 10: KV 检索任务生成结果对比 以下是原文 Table 10: Comparison of generated results by different methods using LLaMA-3-8B-Instruct-262K in KV Retrieval task . 的结果:
| Methods | Summary |
| Original | 1. Ground Truth: 5e6b7b90-710d-4953-9b18-3e96b2cadbf2Prediction: "5e6b7b90-710d-4953-9b18-3e96b2cadbf2"2. Ground Truth: f1ec3dd8-b6e7-4af5-a589-fda0b7693f93Prediction: "f1ec3dd8-b6e7-4af5-a589-fda0b7693f93" (a string of length 32). |
| StreamingLLM | 1. Ground Truth: 5e6b7b90-710d-4953-9b18-3e96b2cadbf2Prediction: 2b114db0-d87e-42d2-9b4c-0b1f115976ad.assistant2. Ground Truth: bbf3fa73-6217-4e6e-923c-8349bd286c3dPrediction: "d829ce18-6339-4f77-8c04-31fc7ec33619". |
| StreamingLLMw/ dilated | 1. Ground Truth: 5e6b7b90-710d-4953-9b18-3e96b2cadbf2Predictionl2. Ground Truth: f1ec3dd8-b6e7-4af5-a589-fda0b7693f93Prediction: "def solverome2 def solverome2 def solverome2 def solverome2 defsolverome2 def solverome2 def solverome2 def solverome2 def solverome2 defsolverome2 |
| StreamingLLMw/ strided | 1. Ground Truth: 5e6b7b90-710d-4953-9b18-3e96b2cadbf2Prediction: "def solverome2 def solverome2 def solverome2 def solverome2 defsolverome22. Ground Truth: f1ec3dd8-b6e7-4af5-a589-fda0b7693f93Prediction: "((((((((((((((((((((((((( |
| Ours w/ static | 1. Ground Truth: 5e6b7b90-710d-4953-9b18-3e96b2cadbf2Prediction: "def solverome2 def solverome2 def solverome2 def solverome2 def2. Ground Truth: f1ec3dd8-b6e7-4af5-a589-fda0b7693f93Prediction: "def solverome2 def solverome2 def solverome2 def solverome2 def |
| Ours | 1. Ground Truth: 5e6b7b90-710d-4953-9b18-3e96b2cadbf2Prediction: "5e6b7b90-710d-4953-9b18-3e96b2cadbf2"2. Ground Truth: f1ec3dd8-b6e7-4af5-a589-fda0b7693f93Prediction: "f1ec3dd8-b6e7-4af5-a589-fda0b7693f93" (a string of length 32). |
分析:
- MInference (Ours) 的完美检索: MInference 能够完美地检索并预测出与
Ground Truth(真实标注数据) 完全一致的键值对。这证明了其在要求精确匹配的检索任务中的卓越能力。 - 基线方法的失败:
StreamingLLM尽管生成了看似正确的格式,但预测的键值对与Ground Truth不符,属于事实性错误。StreamingLLM w/ dilated和w/ strided以及Ours w/ static均表现为重复的字符或无意义的字符串,完全无法完成检索任务。
- 结论: MInference 在 KV 检索任务中表现出极高的准确性和鲁棒性,远超其他基线方法,再次强调了其动态稀疏策略的有效性。
6.2. 消融实验/参数分析
论文进行了多组消融实验,以验证 MInference 各组件(尤其是动态稀疏策略和不同模式)的贡献。
6.2.1. 动态与静态索引的对比
Ours w/ static的性能: 在 Table 2 (InfiniteBench) 和 Table 4 (InfiniteBench 消融) 中,Ours w/ static的平均性能显著低于完整版 MInference (Ours)。尤其是在高度动态的任务如Retr.KV(键值检索) 中,其准确率几乎降为零(例如 LLaMA-3-8B 模型上为 0.2%,Yi-9B-200K 模型上为 1.0%)。在 Table 9 (摘要) 和 Table 10 (KV 检索) 的案例分析中,Ours w/ static也生成了重复且无意义的文本。- 结论: 这项结果有力地证明了 MInference 中动态构建稀疏索引的必要性和有效性。静态索引无法适应注意力模式的动态变化,导致在需要精准信息捕捉的任务中性能严重退化。
6.2.2. 不同稀疏模式的贡献
以下是原文 Table 4: Performance of different ablation methods using LLaMA-3-8B-Instruct-262K on InfiniteBench . 的结果:
| Methods | En.Sum | En.QA | En.MC | En.Dia | Zh.QA | Code.Debug | Math.Find | Retr.PassKey | Retr.Num | Retr.KV| | Avg. |
| Ours | 20.5 | 12.9 | 65.9 | 7.5 | 12.5 | 22.3 | 33.1 | 100.0 | 100.0 | 12.8 | 38.8 |
| Ours w/ only block-sparse | 12.4 | 3.4 | 5.7 | 6.0 | 3.1 | 12.2 | 24.0 | 59.5 | 60.3 | 0.0 | 18.7 |
| Ours w/ only vertical-slash | 19.6 | 12.0 | 62.1 | 9.5 | 11.7 | 21.6 | 29.1 | 100.0 | 100.0 | 5.0 | 37.1 |
分析:
Ours w/ only A-shape(等同于StreamingLLM): 从 Table 2 可以看出,StreamingLLM的平均得分远低于 MInference (Ours)。这表明仅使用 A形模式(即全局词元和局部窗口)只能捕获局部信息,无法满足长上下文 LLM 在各种任务中对更复杂和动态注意力模式的需求。Ours w/ only block-sparse: 仅使用块稀疏 (Block-Sparse) 模式导致了显著的性能下降,平均得分从 38.8 降至 18.7。尤其是在En.MC、Zh.QA和Retr.KV等任务中表现极差。这说明尽管块稀疏模式具有很高的动态性,但它自身不足以全面捕获所有关键注意力模式。Ours w/ only vertical-slash: 仅使用垂直斜杠 (Vertical-Slash) 模式在保持性能方面表现较好,平均得分达到 37.1,接近完整版 MInference 的 38.8。这表明垂直斜杠模式在动态性和类似于StreamingLLM的模式之间取得了良好的平衡,能够捕捉到许多关键的注意力信息。然而,在高度动态的任务如Retr.KV中,性能仍低于完整版 MInference (5.0 vs 12.8)。- 结论: 完整的 MInference 框架通过结合 A形、垂直斜杠和块稀疏这三种模式,并在核感知搜索的指导下为每个注意力头选择最佳模式,从而实现了最佳的性能。移除任何一种模式都会导致不同程度的性能下降,尤其是在需要捕捉动态注意力模式的任务中。
6.2.3. 垂直线与斜杠线的独立贡献
以下是原文 Table 8: Performance of different ablation methods using LLaMA-3-8B-Instruct-262K on InfiniteBench . 的结果:
| Methods | En.Sum | En.QA | En.MC | En.Dia | Zh.QA | Code.Debug | Math.Find | Retr.PassKey | Retr.Num | Retr.KV |
| Ours | 20.5 | 12.9 | 65.9 | 7.5 | 12.5 | 22.3 | 33.1 | 100.0 | 100.0 | 12.8 |
| Ours w/ only vertical | 13.7 | 6.2 | 30.1 | 2.0 | 6.5 | 22.3 | 7.9 | 33.1 | 65.4 | 52.7 |
| Ours w/ only slash | 18.4 | 11.5 | 60.1 | 3.0 | 11.4 | 22.1 | 28.4 | 100.0 | 100.0 | 4.2 |
分析:
Ours w/ only vertical: 仅使用垂直线会导致显著的性能下降,尤其在检索任务中表现很差(例如Retr.KV仅为 0.0%,在 Table 8 中显示为 52.7%, 这是与 Table 4 的Ours w/ only block-sparse0.0%的对比,这里的Retr.KV结果为 0.0%,与原始论文中描述的“similar to only using block-sparse”相符)。这表明单独的垂直线不足以捕获长上下文中的所有关键信息。Ours w/ only slash: 仅使用斜杠线能够保留大部分性能,但在高度动态的任务如Retr.KV中,性能仍进一步下降(平均性能下降 2.9%)。这表明斜杠线在捕获周期性或固定间隔信息方面很重要,但垂直线(捕捉特定词元)在某些场景下同样不可或缺。- 结论: 垂直斜杠模式的完整性(同时包含垂直线和斜杠线)对于保持高性能至关重要,特别是对于那些需要同时关注特定词元和周期性模式的任务。
6.3. 稀疏性在内核中的分布 (Sparsity in Kernel Distribution)
下图(原文 Figure 12)展示了在不同上下文窗口下,三种模式在实际内核计算过程中的稀疏性分布:
分析:
- 高稀疏率: 当上下文窗口超过 200K 词元时,所有三种模式的实际稀疏性都超过 90%。这意味着绝大多数注意力计算都可以被跳过。
- 理论加速比: 考虑到 20% 的索引构建开销,90% 的稀疏率仍能确保内核实现超过 8 倍的加速。
- 超长上下文的更高稀疏性: 当上下文窗口超过 500K 词元时,稀疏性相对于
FlashAttention超过 95%,理论上可实现超过 15 倍的加速。这说明 MInference 在处理极长上下文时效率更高。
6.4. 模式分布 (Pattern Distribution)
下图(原文 Figure 11)展示了通过核感知搜索获得的最佳注意力头配置的分布:
分析:
- 垂直斜杠模式的主导地位: 大部分注意力头(超过 90%)被分配为垂直斜杠模式。这表明在长上下文 LLM 中,关注特定词元和固定间隔词元的模式是普遍且重要的。
- 其他模式的战略性分布: 块稀疏模式主要分布在一些中后期层,而 A形模式则主要出现在中间层。这表明不同模式在模型中扮演着不同的角色,共同捕捉复杂的注意力分布。
- 泛化性: 尽管搜索是基于参考示例进行的,但相同的配置在 LLaMA-3 的两个版本(262K 和 1M)上都表现良好,尤其是在
Needle In A Haystack任务中取得了近乎完美的结果。这证明了所发现的最佳稀疏模式配置具有良好的泛化性。
7. 总结与思考
7.1. 结论总结
本文提出了 MInference (Milliontokens Inference),一个用于加速长上下文 LLM 预填充阶段 (pre-filling stage) 的高效动态稀疏注意力 (dynamic sparse attention) 框架。研究人员通过深入分析,识别出长上下文注意力矩阵中三种独特的空间稀疏模式:A形 (A-shape)、垂直斜杠 (Vertical-Slash) 和块稀疏 (Block-Sparse)。MInference 采用了一种创新的核感知最优稀疏模式搜索 (Kernel-Aware Optimal Sparse Pattern Search) 方法,离线 (offline) 为每个注意力头 (attention head) 分配最佳模式及其参数。在推理 (inference) 阶段,它通过轻量级的在线近似方法动态 (dynamically) 构建稀疏索引 (sparse indices),并利用为这些模式专门优化的 GPU 内核 (kernel) 执行稀疏注意力计算。
实验结果表明,MInference 能够显著降低长上下文 LLM 的推理延迟,在单块 A100 GPU 上,对于 1M 词元 (token) 提示词,预填充延迟从 30 分钟减少到 3 分钟,实现了高达 10 倍的加速。同时,MInference 在 InfiniteBench、RULER、PG-19 和 Needle In A Haystack 等一系列长上下文基准测试上,在保持甚至略微提高准确性 (accuracy) 的前提下,优于所有现有基线方法。更重要的是,MInference 作为一个即插即用 (plug-and-play) 的解决方案,无需修改现有 LLM 的预训练设置 (pre-training setup) 或进行额外微调 (fine-tuning)。此外,MInference 还被证明与 KV 缓存压缩方法 (KV cache compression methods) 兼容,进一步提升了其实用价值。研究发现,类似的动态稀疏注意力模式也存在于多模态 LLM (multi-modal LLM) 和编码器-解码器 LLM (encoder-decoder LLM) 中,这预示着 MInference 在更广泛领域的应用潜力。
7.2. 局限性与未来工作
7.2.1. 局限性
- 短上下文下的索引构建开销: 论文指出,当上下文长度减少时(例如 10K 词元),动态索引构建的时间开销变得更加显著,可能占总时间的 30%,导致总的端到端延迟接近 FlashAttention。这意味着 MInference 在短到中等上下文长度下的加速优势会减弱,甚至可能不如密集注意力。
- 稀疏率与性能的权衡: 论文提及,使用更高的稀疏率 (sparsity rate) 可能会导致模型性能显著下降。这表明稀疏度是一个敏感的超参数,需要仔细调优以平衡效率和准确性,过高的稀疏度会损害模型理解能力。
7.2.2. 未来工作
- 扩展到多模态 LLM: 论文发现多模态 LLM [WWL+24] 中也存在类似的垂直线 (vertical lines) 和斜杠线 (slash lines) 稀疏模式。未来工作可以探索将 MInference 应用于多模态 LLM 的预填充阶段推理加速。
- 扩展到编码器-解码器 LLM: 编码器-解码器 LLM [RSR+20] 的注意力模式中也观察到类似的稀疏性。MInference 在其预填充阶段的加速应用同样具有潜力。
- 进一步优化索引构建开销: 针对短上下文场景下动态索引构建的开销问题,可以探索更轻量级或自适应的索引构建策略。
7.3. 个人启发与批判
7.3.1. 个人启发
- 深入分析注意力模式的价值: 这项工作最令人启发的一点是,它并没有简单地将注意力稀疏性视为一个整体概念,而是通过详尽的分析,识别并分类出长上下文注意力矩阵中的具体空间模式。这种“从现象到模式”的归纳能力,是推动领域进步的关键。
- 核感知优化的重要性: “核感知 (Kernel-Aware)”的搜索策略是其成功的重要因素。它提醒我们,在实际硬件上实现高性能,不仅仅要考虑理论上的 FLOPs 减少,更要考虑如何将计算有效映射到处理器架构上,这在深度学习系统优化中至关重要。
- 即插即用 (Plug-and-play) 的巨大吸引力: MInference 无需对模型进行再训练或微调,可以直接应用于现有 LLM。这极大地降低了部署成本和风险,使其成为一个具有很高实用价值的解决方案,尤其对于那些已经投入大量资源训练和部署大型 LLM 的组织。
- 动态稀疏与静态稀疏的平衡艺术: 论文通过消融实验明确展示了动态稀疏策略对于保持准确性的必要性,特别是对于检索等敏感任务。同时,通过三种模式的组合,它巧妙地平衡了不同稀疏模式的特点,兼顾了局部稳定性(A形)和动态全局关注(垂直斜杠、块稀疏)。
- 跨领域应用潜力: 发现类似的稀疏模式也存在于多模态和编码器-解码器 LLM 中,这表明 MInference 的核心思想可能具有更广泛的适用性,为其他 Transformer 架构的加速提供了新的思路。
7.3.2. 批判与潜在改进
- 动态索引构建的鲁棒性与泛化性: 尽管论文提到索引构建开销在长上下文下比例较低,但在中短上下文下却成为瓶颈。此外,模式识别和索引构建的近似方法(例如
Vertical-Slash使用 个查询向量,Block-Sparse使用均值池化)是否在所有复杂场景下都能准确捕获最优稀疏模式?如果输入提示词的模式与用于离线搜索的参考示例差异较大,性能是否会下降?这种近似方法对于极端情况的鲁棒性需要更深入的探究。 - 超参数敏感性: 垂直斜杠模式中的 (垂直线和斜杠线的数量) 和块稀疏模式中的 (块数量) 是关键超参数。虽然有核感知搜索,但这些参数的确定是否对特定任务或数据分布敏感?是否存在更自适应的机制,能够在推理时根据输入内容动态调整这些参数,以避免人工调优或潜在的次优选择?
- 解码阶段的未解决问题: MInference 主要解决了预填充阶段的问题,但解码阶段的 KV 缓存存储和访问效率仍然是 LLM 推理的另一个主要瓶颈。虽然论文提到与 KV 缓存压缩方法兼容,但这只是并行的优化,而非 MInference 自身对解码阶段的贡献。未来的研究可以探索 MInference 发现的稀疏模式是否能为解码阶段的 KV 缓存管理提供新的启示。
- 多 GPU / 分布式推理场景的考虑: 论文的性能评估主要在单 A100 GPU 上进行。在实际大规模部署中,LLM 常常运行在多 GPU 或分布式环境中。在这种情况下,动态稀疏索引的生成、分发和同步可能会引入新的通信开销和复杂性,这需要进一步的系统级优化和评估。
- 模型架构无关性: 虽然 MInference 声称无需修改预训练模型,但其优化是针对 Transformer 架构中的注意力层。对于未来可能出现的非 Transformer 架构(如 Mamba 系列),MInference 的方法是否仍然适用或需要重大修改,是一个值得探讨的问题。
相似论文推荐
基于向量语义检索推荐的相关论文。