论文状态:已完成

Quest: Query-Aware Sparsity for Efficient Long-Context LLM Inference

发表:2024/06/16
原文链接PDF 下载
价格:0.100000
价格:0.100000
已有 1 人读过
本分析由 AI 生成,可能不完全准确,请以原文为准。

TL;DR 精炼摘要

随着长上下文大型语言模型(LLMs)的需求增加,推理效率面临挑战。论文提出了Quest,一个查询感知的KV缓存选择算法,通过跟踪最小和最大关键值,从而仅加载最关键的KV缓存,显著提升自注意力计算速度,最高可达7.03倍加速,同时在长依赖任务上保持良好准确性。

摘要

As the demand for long-context large language models (LLMs) increases, models with context windows of up to 128K or 1M tokens are becoming increasingly prevalent. However, long-context LLM inference is challenging since the inference speed decreases significantly as the sequence length grows. This slowdown is primarily caused by loading a large KV cache during self-attention. Previous works have shown that a small portion of critical tokens will dominate the attention outcomes. However, we observe the criticality of a token highly depends on the query. To this end, we propose Quest, a query-aware KV cache selection algorithm. Quest keeps track of the minimal and maximal Key values in KV cache pages and estimates the criticality of a given page using Query vectors. By only loading the Top-K critical KV cache pages for attention, Quest significantly speeds up self-attention without sacrificing accuracy. We show that Quest can achieve up to 2.23x self-attention speedup, which reduces inference latency by 7.03x while performing well on tasks with long dependencies with negligible accuracy loss. Code is available at http://github.com/mit-han-lab/Quest .

思维导图

论文精读

中文精读

1. 论文基本信息

1.1. 标题

Quest: Query-Aware Sparsity for Efficient Long-Context LLM Inference

1.2. 作者

论文由以下作者完成:

  • Jiaming Tang * (1, 2)

  • Yilong Zhao * (1, 3)

  • Kan Zhu (3)

  • Guangxuan Xiao (2)

  • Baris Kasikci (3)

  • Song Han (2, 4)

    其中:

  • 1 代表共同第一作者。

  • 2 代表麻省理工学院 (MIT)。

  • 3 代表密歇根大学 (University of Michigan)。

  • 4 代表 MIT-IBM Watson AI Lab。

1.3. 发表期刊/会议

论文作为预印本 (preprint) 发布在 arXiv 上。

1.4. 发表年份

2024年。

1.5. 摘要

随着对长上下文 (long-context) 大型语言模型 (LLMs) 需求的增长,具有高达128K或1M词元 (tokens) 上下文窗口的模型日益普及。然而,长上下文 LLM 推理 (inference) 具有挑战性,因为推理速度随着序列长度的增加而显著下降。这种减速主要由自注意力 (self-attention) 过程中加载大量键值缓存 (KV cache) 引起。以往的研究表明,一小部分关键词元 (critical tokens) 会主导注意力结果。然而,作者观察到词元的关键性高度依赖于查询 (query)。为此,论文提出了 Quest,一种查询感知 (query-aware) 的 KV cache 选择算法。Quest 通过跟踪 KV cache 页面 (KV cache pages) 中键 (Key) 值的最小值和最大值,并利用查询向量 (Query vectors) 估计给定页面的关键性。通过仅加载 Top-K 最关键的 KV cache 页面进行注意力计算,Quest 在不牺牲准确性的前提下显著加速了自注意力。实验结果表明,Quest 可以实现高达 7.03×7.03 \times 的自注意力加速,从而将推理延迟降低 2.23×2.23 \times,同时在长依赖任务上表现良好,且准确性损失可忽略不计。

1.6. 原文链接

2. 整体概括

2.1. 研究背景与动机

大型语言模型 (LLMs) 的长上下文能力需求: 随着 LLM 在多轮对话、长文档查询等复杂场景中的广泛应用,对更大上下文窗口的需求日益增长。当前 LLM 的上下文窗口已从早期模型的几千个词元 (tokens) 扩展到数十万甚至百万级别,如 GPT-4 Turbo 支持 128k 词元,Claude-2 支持 200k 词元,开源模型如 Yarn-Llama-2 也达到了 128k 词元。

长上下文推理的挑战: 尽管上下文窗口不断扩大,但长上下文 LLM 的推理效率却面临严峻挑战。主要障碍在于:

  1. 自回归 (auto-regressive) 特性: LLM 每次生成一个新词元时,都需要读取并处理之前所有词元的键值缓存 (KV cache)。
  2. KV 缓存规模巨大: 随着序列长度的增加,KV cache 的大小呈线性增长。例如,一个 32K 上下文长度的 Llama 7B 模型,其 KV cache 可能占用高达 16GB 的内存空间。加载这些海量数据可能占到单次词元生成延迟的 50% 以上。这种巨大的内存带宽 (memory bandwidth) 瓶颈严重限制了 LLM 的推理速度和整体吞吐量。

现有研究的不足与论文的切入点:

  • 稀疏性观察: 已有研究指出,在自注意力 (self-attention) 机制中,并非所有 KV cache 中的词元都同等重要,通常只有一小部分“关键词元 (critical tokens)”对注意力结果起主导作用。这一观察为通过稀疏化 KV cache 来提高推理效率提供了理论基础。
  • 现有稀疏化方法的局限: 大多数现有 KV cache 优化方法(如 H2OTOVAStreamingLLM)主要依赖于历史信息或固定上下文窗口来决定哪些词元应该被保留或丢弃。然而,作者观察到词元的“关键性”是动态变化的,并且高度依赖于当前的“查询词元 (query token)”。例如,一个词元在某个时间步可能被认为是无关紧要的,但在后续某个特定查询下,它可能变得至关重要。如果关键词元被过早地永久丢弃,将导致模型准确性下降,尤其是在需要处理长距离依赖关系的任务中。
  • 论文的创新思路: 针对上述现有方法的局限性,Quest 提出了一种新颖的“查询感知稀疏性 (query-aware sparsity)”方法。该方法不再永久丢弃 KV cache 中的词元,而是根据当前的查询向量,动态地估计 KV cache 中各个页面的关键性,并只加载并关注最关键的 Top-K 个页面进行注意力计算,从而大幅减少内存移动量并加速推理,同时避免了关键信息丢失。

2.2. 核心贡献/主要发现

论文的主要贡献包括:

  • 对自注意力机制的分析: 论文深入分析了自注意力机制,并明确指出了查询感知稀疏性的重要性。通过实验观察,作者证明了词元的关键性是动态变化的,并且与当前的查询向量高度相关。
  • 提出了 Quest 算法: 设计并实现了一种高效且准确的 KV cache 加速算法 Quest。该算法通过为每个 KV cache 页面维护键 (Key) 向量在各个维度上的最大值和最小值作为元数据,并利用当前查询向量来动态估计每个页面的关键性。通过这种查询感知的方法,Quest 能够高效地识别并选择最关键的词元进行注意力计算。该算法结合了专门的算子 (operator) 设计和底层实现。
  • 全面的评估:Quest 进行了广泛而全面的准确性 (accuracy) 和效率 (efficiency) 评估。
    • 效率方面: 实验结果显示,在 32K 词元序列长度下,Quest 可以实现高达 7.03×7.03 \times 的自注意力 (self-attention) 加速。在端到端 (end-to-end) 推理延迟方面,尤其是在使用 4 位量化 (4-bit quantization) 时,可以实现 2.23×2.23 \times 的速度提升。
    • 准确性方面: QuestPG19 语言建模 (language modeling)、passkey retrieval 任务和 LongBench 的多个长上下文任务中表现出色,在实现显著 KV cache 稀疏化的同时,能够保持与全 KV cache 模型相当的准确性,且准确性损失可忽略不计。与现有基线方法相比,Quest 在相同的准确性目标下,自注意力延迟最高可减少 4.5×4.5 \times

3. 预备知识与相关工作

3.1. 基础概念

3.1.1. 大型语言模型 (Large Language Models, LLMs)

LLM 是一类基于深度学习的计算模型,通常采用 Transformer 架构,通过在海量文本数据上进行预训练 (pre-training) 学习语言的统计规律和知识。它们能够理解、生成和处理人类语言,执行如文本补全、翻译、问答等多种任务。

3.1.2. 上下文窗口 (Context Window)

上下文窗口指的是 LLM 在生成下一个词元时能够考虑的先前词元的最大数量。上下文窗口越大,模型能够理解和利用的信息就越多,这对于处理长文档、多轮对话或需要长距离依赖的任务至关重要。

3.1.3. 自注意力机制 (Self-Attention Mechanism)

Self-AttentionTransformer 架构中的核心组成部分,它允许模型在处理序列中的每个词元时,对序列中的其他所有词元进行加权,从而捕捉词元之间的关系。 其基本计算公式为: Attention(Q,K,V)=softmax(QKTdk)V \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V 其中:

  • QQ (Query) 是查询矩阵,代表当前词元的信息,用于查询其他词元。
  • KK (Key) 是键矩阵,代表序列中所有词元的信息,用于被查询。
  • VV (Value) 是值矩阵,代表序列中所有词元的内容信息,根据注意力权重进行加权求和。
  • dkd_k 是键向量的维度,用于缩放点积,防止在维度较高时点积结果过大,导致 softmax 函数梯度过小。
  • QKTQK^T 计算了查询与所有键的点积,得到注意力得分 (attention scores)。
  • softmax\mathrm{softmax} 函数将注意力得分归一化为概率分布,表示每个词元对当前词元的“注意力”程度。
  • 最终,注意力输出是所有值向量的加权和。

3.1.4. 键值缓存 (KV Cache)

LLM 的自回归 (auto-regressive) 推理过程中,为了避免重复计算,模型会将之前生成词元的键 (Key) 和值 (Value) 向量存储起来。这些存储起来的 KKVV 向量的集合就称为 KV cache。在生成下一个词元时,只需要计算当前词元的查询 (Query) 向量,然后与 KV cache 中的所有 KK 向量进行注意力计算,再与对应的 VV 向量加权求和。 挑战: 随着序列长度的增加,KV cache 的大小会迅速增长,导致内存占用和内存加载时间成为推理速度的瓶颈。

3.1.5. PageAttention

PageAttention 是一种 KV cache 管理机制,它将 KV cache 存储在固定大小的“页面 (pages)”中,类似于操作系统中的内存分页。这种机制允许更高效地管理和访问 KV cache,例如支持非连续的 KV cache 存储,从而提高内存利用率和灵活性。Quest 算法也采用了 KV cache 的页面粒度管理。

3.1.6. 困惑度 (Perplexity, PPL)

困惑度是衡量语言模型性能的常用指标,尤其在语言建模任务中。它量化了一个概率分布预测一个样本的程度,通常用来评估模型预测下一个词元的准确性。 概念定义: 困惑度越低,说明模型对给定文本的预测能力越强,生成的文本越流畅、越符合语言习惯。直观上,困惑度可以理解为模型在预测下一个词元时“有多困惑”,或者每个词元平均有多少个等可能的选择。 数学公式: 给定一个测试序列 W=(w1,w2,,wN)W = (w_1, w_2, \dots, w_N),其困惑度定义为: PPL(W)=(i=1N1P(wiw1,,wi1))1/N=exp(1Ni=1NlogP(wiw1,,wi1)) \mathrm{PPL}(W) = \left( \prod_{i=1}^{N} \frac{1}{P(w_i | w_1, \dots, w_{i-1})} \right)^{1/N} = \exp\left( - \frac{1}{N} \sum_{i=1}^{N} \log P(w_i | w_1, \dots, w_{i-1}) \right) 符号解释:

  • WW: 测试序列。
  • NN: 测试序列中词元的数量。
  • wiw_i: 序列中的第 ii 个词元。
  • P(wiw1,,wi1)P(w_i | w_1, \dots, w_{i-1}): 模型根据前 i-1 个词元预测第 ii 个词元的概率。
  • \prod: 连乘符号。
  • exp\exp: 自然指数函数。
  • \sum: 求和符号。
  • log\log: 自然对数。

3.1.7. 召回率 (Recall Rate)

召回率是分类任务中常用的评估指标,尤其在信息检索或识别关键元素时。 概念定义:Quest 的语境中,召回率衡量的是算法识别出的 Top-K 关键词元中有多少比例确实是“真正”关键的词元(即在完整注意力计算中具有高注意力得分的词元)。高召回率意味着算法很少错过重要的词元。 数学公式: Recall=True PositivesTrue Positives+False Negatives \mathrm{Recall} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Negatives}} 符号解释:

  • True Positives (TP): 算法正确识别为关键的词元数量。
  • False Negatives (FN): 算法未能识别为关键但实际上是关键的词元数量(漏报)。

3.1.8. F1 分数 (F1 Score)

F1 分数是精确率 (Precision) 和召回率 (Recall) 的调和平均值,用于评估分类模型的性能,尤其是在类别不平衡的情况下。它综合考虑了模型的查准率和查全率。 概念定义: F1 分数取值范围在 0 到 1 之间,1 表示完美的精确率和召回率。F1 分数越高,表示模型的分类性能越好。 数学公式: F1=2PrecisionRecallPrecision+Recall \mathrm{F1} = 2 \cdot \frac{\mathrm{Precision} \cdot \mathrm{Recall}}{\mathrm{Precision} + \mathrm{Recall}} 其中,精确率 (Precision) 的计算公式为: Precision=True PositivesTrue Positives+False Positives \mathrm{Precision} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Positives}} 符号解释:

  • Precision: 精确率,表示被预测为关键的词元中有多少比例是真正关键的。
  • Recall: 召回率,表示真正关键的词元中有多少比例被模型识别出来。
  • True Positives (TP): 算法正确识别为关键的词元数量。
  • False Positives (FP): 算法错误识别为关键但实际上不是关键的词元数量(误报)。
  • False Negatives (FN): 算法未能识别为关键但实际上是关键的词元数量(漏报)。

3.2. 前人工作

论文将 Quest 与多种现有 KV cache 优化策略进行了比较:

  • H2O (Heavy-hitter Oracle for Efficient Generative Inference of Large Language Models) (Zhang et al., 2023b): 这是一种 KV cache 驱逐算法,通过保留基于历史注意力分数总和的少量重要 KV cache。其主要思想是,历史上获得高注意力分数的词元未来也可能很重要。
  • TOVA (Transformers Are Multi-State RNNs) (Oren et al., 2024): 该方法基于一个简化策略,仅根据当前查询来决定哪些词元将被永久丢弃。
  • StreamingLLM (Efficient Streaming Language Models with Attention Sinks) (Xiao et al., 2023): 该方法通过引入“注意力汇 (attention sinks)”和有限的 KV cache 来处理无限长的文本。它主要关注如何维护一个有限大小的上下文窗口以适应流式输入。
  • SparQ (SparQ Attention: Bandwidth-Efficient LLM Inference) (Ribar et al., 2023): 通过通道剪枝 (channel pruning) 计算近似注意力分数,并据此选择重要词元。论文指出,该方法在长依赖任务上尚未被广泛验证,且通道级稀疏性可能带来实际加速的挑战。
  • FlashInfer (Ye et al., 2024): 一个专门用于 LLM 推理的高性能 CUDA 内核 (kernel) 库,提供了高效的注意力计算实现,如 FlashAttentionQuest 的效率评估正是基于 FlashInfer 实现的。

3.3. 技术演进

LLM 的发展趋势是上下文窗口越来越大,从几千到上百万词元。这使得 KV cache 的规模成为推理效率的瓶颈。早期的优化方法主要集中在:

  1. 位置编码扩展: 例如 RoPE (Rotary Position Embeddings) 及其各种缩放方法(如 Yarn-Llama-2LongChat),以支持更长的序列。

  2. KV cache 压缩/驱逐: 意识到 KV cache 的冗余性后,研究开始探索如何减少其大小。这包括识别并丢弃不重要的词元(如 H2OTOVAStreamingLLM)。

    本文 Quest 处在 KV cache 优化这一技术脉络中,但它认识到现有驱逐算法的局限性:词元关键性是动态的,且与查询相关。因此,Quest 的创新在于引入了“查询感知”的机制,不再永久丢弃 KV cache 中的词元,而是动态选择,从而在保证准确性的前提下实现更大幅度的效率提升。

3.4. 差异化分析

Quest 与上述相关工作的主要区别和创新点在于:

  1. 查询感知 (Query-Aware) vs. 查询无关 (Query-Agnostic)/历史依赖:
    • H2OTOVA 等方法主要依赖历史注意力分数或当前状态来决定哪些词元可以被“丢弃”。一旦丢弃,这些词元就无法再被关注,即使它们在后续查询中变得重要。
    • StreamingLLM 关注固定窗口内的词元,超出窗口的词元基本被忽略。
    • Quest 则强调词元的关键性是动态且依赖于当前查询的。如图 2 所示,一个词元(如“B”)在不同查询(“D”与“is”)下其重要性可能截然不同。Quest 不丢弃 KV cache 中的任何词元,而是动态地、查询感知地选择最相关的词元页面进行注意力计算,从而避免了关键信息丢失。
  2. 页面粒度管理和近似计算:
    • Quest 结合了 PageAttention 的页面粒度管理,以降低元数据 (metadata) 维护的开销。
    • 它通过存储 KV cache 页面中 Key 向量的最小/最大值来高效地估计每个页面的注意力得分上界,从而快速选择 Top-K 关键页面。这种近似计算在保证准确性的同时提供了极高的效率。
  3. 实际加速而非仅理论稀疏性:
    • SparQ 提出了通道剪枝,但论文指出其在长依赖任务上的验证不足,且通道级稀疏性在实际加速方面可能存在挑战。
    • Quest 通过与 FlashInfer 集成并实现专门的 CUDA 内核,展示了显著的实际自注意力加速和端到端延迟降低。

4. 方法论

4.1. 方法原理

Quest 的核心思想是利用查询感知稀疏性 (query-aware sparsity) 来高效地进行长上下文 LLM 推理。它基于两个关键观察:

  1. 自注意力的高稀疏性:LLM 的自注意力计算中,只有一小部分 KV cache 词元是“关键的 (critical)”,对最终的注意力结果起主导作用。因此,如果能有效地识别这些关键词元,只对它们进行注意力计算,就可以大幅减少内存移动和计算量。

  2. 词元关键性的动态性与查询相关性: 词元的关键性不是固定不变的,而是随着当前查询向量 (Query vector) 的变化而动态变化的。一个词元在某个时间点可能不重要,但在下一个时间点,当新的查询词元出现时,它可能变得至关重要(如图 2 所示)。因此,查询无关 (query-agnostic) 或基于历史信息的 KV cache 驱逐策略可能错误地丢弃未来关键的词元,导致准确性下降。

    为了解决这些问题,Quest 提出了一个查询感知 KV cache 选择算法:

  • 不丢弃 KV cache Quest 不会永久丢弃任何 KV cache 中的词元,从而避免了关键信息丢失的风险。
  • 页面粒度管理: 为了管理和减少估计关键性的开销,Quest 采用 PageAttention 的思想,以页面 (page) 为粒度管理 KV cache
  • 高效的关键性估计: 对于每个 KV cache 页面,Quest 维护其所有 Key 向量在每个维度上的最小值 (minimal) 和最大值 (maximal) 作为元数据。在推理时,给定当前查询向量 QQQuest 利用这些元数据快速估计每个页面的注意力得分上界 (upper bound of attention weights),从而评估其关键性。
  • Top-K 选择与稀疏注意力: Quest 选择具有最高关键性得分的 Top-K 个页面,仅对这些页面执行实际的自注意力计算。这大大减少了需要加载的 KV cache 数据量,从而加速推理。

4.2. 核心方法详解

Quest 的工作流程可以分为两个主要阶段:元数据维护 (Metadata Maintenance)关键性估计与稀疏注意力执行 (Criticality Estimation and Sparse Self-Attention Execution)

4.2.1. 元数据维护

当新的词元被插入到 KV cache 中时,Quest 需要更新其所属页面的元数据。对于每个 KV cache 页面,Quest 维护一个最小键向量 mm 和一个最大键向量 MM。这些向量的维度与原始 Key 向量的维度相同。

以下是原文 Algorithm 1 中“When inserting new token to KV cache”部分的伪代码:

Algorithm 1 Token Criticality Estimation

When inserting new token to KV cache:
Input: Key vector K, Dimension of hidden states dim, Current maximal vector M_i, Current minimal vector m_i
for i = 1 to dim do
    M_i = max(M_i, k_i)
    m_i = min(m_i, k_i)
end for

符号解释:

  • KK: 新插入的 Key 向量。
  • kik_i: Key 向量 KK 的第 ii 个分量。
  • dim: Key 向量的维度(即隐藏状态维度)。
  • MiM_i: 当前页面中所有 Key 向量在第 ii 个维度上的最大值。
  • mim_i: 当前页面中所有 Key 向量在第 ii 个维度上的最小值。

流程: 当一个新词元及其 Key 向量 KK 被添加到 KV cache 中的某个页面时,Quest 会遍历 KK 的每个维度 ii (从 1 到 dim):

  • 将该维度上的 kik_i 与当前页面在该维度上的最大值 MiM_i 进行比较,并更新 M_i = \max(M_i, k_i)
  • 将该维度上的 kik_i 与当前页面在该维度上的最小值 mim_i 进行比较,并更新 m_i = \min(m_i, k_i)。 通过这种方式,每个页面的 MMmm 向量始终记录着该页面中所有 Key 向量在每个维度上的值范围。

4.2.2. 关键性估计与稀疏注意力执行

在每次进行自注意力计算时(特别是在解码阶段),给定当前查询向量 QQQuest 会利用预先计算的元数据 MMmm 来估计每个 KV cache 页面的关键性。

以下是原文 Algorithm 1 中“When perform self-attention”部分的伪代码:

Algorithm 1 Token Criticality Estimation
...

When perform self-attention:
Input: Query vector Q, Dimension of hidden states dim, Current maximal vector M_i, Current minimal vector m_i

Initialize score = 0.
for i = 1 to dim do
    score += MAX(q_i * max, q_i * min)
end for

符号解释:

  • QQ: 当前查询向量。
  • qiq_i: 查询向量 QQ 的第 ii 个分量。
  • dim: Key 向量的维度。
  • max: 当前页面中 Key 向量在第 ii 个维度上的最大值 MiM_i
  • min: 当前页面中 Key 向量在第 ii 个维度上的最小值 mim_i
  • score: 估计的页面关键性得分。

流程: 对于每个 KV cache 页面,Quest 执行以下操作来估计其关键性得分:

  1. 初始化得分: 将页面的 score 初始化为 0。

  2. 遍历维度并计算上界: 对于 Key 向量的每个维度 ii (从 1 到 dim):

    • 计算 qimaxq_i \cdot maxqiminq_i \cdot min
    • 取这两个乘积中的最大值:max(qimax,qimin)\max(q_i \cdot max, q_i \cdot min)
    • 将这个最大值累加到 score 中。 这个 score 代表了该页面中所有 Key 向量与当前 Query 向量的点积的上界直觉: 标准的注意力得分是 QK=iqikiQ \cdot K = \sum_i q_i \cdot k_i。为了在不知道 KK 确切值的情况下估计一个页面的最大可能注意力得分,Quest 利用了 minkimaxmin \le k_i \le max 的事实。
    • 如果 qi>0q_i > 0,那么 qikiq_i \cdot k_i 的最大可能值是 qimaxq_i \cdot max
    • 如果 qi<0q_i < 0,那么 qikiq_i \cdot k_i 的最大可能值是 qiminq_i \cdot min
    • 如果 qi=0q_i = 0,那么 qiki=0q_i \cdot k_i = 0。 因此,max(qimax,qimin)\max(q_i \cdot max, q_i \cdot min) 总是 qikiq_i \cdot k_i 在该维度上的最大可能值,无论 qiq_i 的符号如何。将这些维度上的最大可能值累加起来,就得到了页面内 Key 向量与 Query 向量点积的近似上界。这个上界越高,说明该页面包含与当前查询高度相关的 Key 向量的可能性越大,因此其关键性越高。
  3. Top-K 页面选择: 在计算出所有 KV cache 页面的关键性得分后,Quest 会选择得分最高的 Top-K 个页面。这里的 KK 是一个预设的常数超参数 (例如 128, 256)。

  4. 稀疏自注意力 (Sparse Self-Attention): 最后,Quest 仅对这 Top-K 个选定的页面中的 Key 向量和 Value 向量执行正常的自注意力计算。这意味着只有这些关键页面中的数据需要从内存中加载,从而大大减少了内存移动量。

4.2.3. 内存移动优化

Quest 通过利用查询感知稀疏性显著减少了自注意力过程中的内存移动。 假设:

  • 每个 Key 或 Value 向量的大小是 MM 字节。

  • KV cache 包含 LL 个词元。

  • 每个页面包含 SSKV 对 (页面大小)。

  • KV cache 大小是 2ML2 \cdot M \cdot L 字节。

    Quest 中:

  • 关键性估计阶段: 仅加载每个页面的最大值和最小值向量作为元数据。这大约是 2M(L/S)2 \cdot M \cdot (L/S) 字节。因为有 L/S 个页面,每个页面的元数据是 2M 字节 (一个最大值向量,一个最小值向量)。

  • 稀疏自注意力阶段: 仅加载 Top-K 个选定页面的所有 KV 对。这大约是 2MKS2 \cdot M \cdot K \cdot S 字节。

    因此,Quest 总共加载的 KV cache 数据量大约为 2M(L/S+KS)2 \cdot M \cdot (L/S + K \cdot S) 字节。 相比于加载整个 KV cache2ML2 \cdot M \cdot L 字节,Quest 加载的数据比例约为: 2M(L/S+KS)2ML=L/S+KSL=1S+KSL \frac{2 \cdot M \cdot (L/S + K \cdot S)}{2 \cdot M \cdot L} = \frac{L/S + K \cdot S}{L} = \frac{1}{S} + \frac{K \cdot S}{L} 这个表达式可以简化为: 1PageSize+KPageNum \frac { 1 } { \mathrm { P a g e \thinspace S i z e } } + \frac { K } { \mathrm { P a g e \thinspace N u m } } 其中 Page Size 是指每个页面包含的 KV 对数量 SSPage Num 是指总的页面数量 L/S

示例: 如果 Page Size 为 16 个 KV 对,上下文长度 LL 为 64K 词元,并且选择 Top-K=4K 词元(即 KS=4KK \cdot S = 4K),那么 Quest 将把内存加载量减少 8×8 \times。 这个内存加载的减少是与模型无关的,并且与现有的量化 (quantization) 机制兼容。

4.2.4. 分层应用策略

根据图 3 的分析,LongChat-7B 模型的前两层稀疏性较低(低于 10%),而后层稀疏性较高(大于 90%)。为了更好地保留模型准确性,Quest 及其所有基线方法都不应用于模型的前两层。这意味着前两层始终使用完整的 KV cache 进行自注意力计算,而从第三层开始才应用 Quest 算法进行稀疏化。这种策略确保了模型在早期阶段能够捕获到所有重要的上下文信息,同时在后续层中通过稀疏化提升效率。

4.3. Quest 系统架构

以下是原文 Figure 5 的 VLM 描述: VLM 描述: 该图像是一个示意图,展示了Quest算法在KV缓存选择中的工作流程。该流程分为两个阶段:第一阶段评估关键页面,包括元素乘积、每通道最大值和求和;第二阶段计算稀疏注意力。图中包含关键值、减少键、当前查询等信息,以及操作结果。

该图像是一个示意图,展示了Quest算法在KV缓存选择中的工作流程。该流程分为两个阶段:第一阶段评估关键页面,包括元素乘积、每通道最大值和求和;第二阶段计算稀疏注意力。图中包含关键值、减少键、当前查询等信息,以及操作结果。 该图像是一个示意图,展示了Quest算法在KV缓存选择中的工作流程。该流程分为两个阶段:第一阶段评估关键页面,包括元素乘积、每通道最大值和求和;第二阶段计算稀疏注意力。图中包含关键值、减少键、当前查询等信息,以及操作结果。

图 5 详细展示了 Quest 的系统架构:

  1. 输入 (Input): 接收当前查询向量 (Current Query) QQ 和完整的 KV cache

  2. 页面 KV 缓存 (Page KV Cache): 完整的 KV cache 被组织成若干个页面。每个页面 jj 都维护了其所有 Key 向量在每个维度上的最小值 mjm_j 和最大值 MjM_j 作为元数据 (Page Metadata)。

  3. 关键性估计 (Criticality Estimation): 对于每个页面 jjQuest 使用当前查询向量 QQ 和该页面的元数据 (mjm_j, MjM_j) 来计算一个关键性得分 (Score)。这个计算过程就是 Algorithm 1 中描述的,通过累加 max(qiMj,i,qimj,i)\max(q_i \cdot M_{j,i}, q_i \cdot m_{j,i}) 来估计注意力得分的上界。

  4. Top-K 页面选择 (Top-K Page Selection): 根据所有页面的关键性得分,Quest 选择得分最高的 Top-K 个页面。这些页面被认为是当前查询最相关的“关键页面”。

  5. 稀疏注意力 (Sparse Attention): 只有这 Top-K 个关键页面中的 Key 和 Value 向量被加载到计算单元中,用于执行实际的自注意力计算。这个阶段可以使用像 FlashAttention 这样高效的算子。

  6. 输出 (Output): 得到稀疏注意力计算的结果。

    整个流程的优势在于,它避免了加载和计算所有 KV cache 词元的注意力,而是通过一个轻量级的元数据计算和 Top-K 选择过程,动态地聚焦于最相关的上下文,从而在保持准确性的前提下大幅提高推理效率。

5. 实验设置

5.1. 数据集

论文使用了以下数据集来评估 Quest 的性能:

5.1.1. PG19

  • 来源: Rae et al., 2019
  • 特点: 包含 100 本书籍,平均每本书的长度约为 70K 词元 (tokens)。
  • 用途: 用于评估 LLM 的语言建模 (language modeling) 能力,特别是困惑度 (perplexity)。它是一个长文本数据集,能很好地测试模型在长序列上的上下文理解能力。

5.1.2. Passkey Retrieval Task (通行密钥检索任务)

  • 来源: Peng et al., 2023 (Yarn-Llama-2 的作者提出)
  • 特点: 这项任务旨在衡量模型从大量无意义文本中检索一个简单“通行密钥 (passkey)”的能力。通行密钥被放置在文本的不同深度位置。
  • 用途: 用于测试模型处理长距离依赖 (long-distance dependencies) 的能力。由于通行密钥可能出现在文本的非常靠前的位置,模型需要记住长时间的上下文信息才能正确检索。这对于 KV cache 优化算法来说是一个严峻的挑战,因为它们必须避免错误地丢弃包含通行密钥的词元。

5.1.3. LongBench

  • 来源: Bai et al., 2023
  • 特点: 一个双语 (bilingual)、多任务 (multitask) 的长上下文理解基准。论文选择了其中六个数据集进行评估,涵盖了不同的任务类型和文本长度。
  • 用途: 用于验证 Quest 在通用长上下文任务上的性能和泛化能力。
  • 具体子数据集:
    • 单文档问答 (Single-document QA):
      • NarrativeQA (Koisky et al., 2018): 基于小说和电影剧本的问答。
      • Qasper (Dasigi et al., 2021): 基于研究论文的问答。
      • MultiFieldQA (Bai et al., 2023): 多领域问答。
    • 多文档问答 (Multi-document QA):
      • HotpotQA (Yang et al., 2018): 需要从多个文档中进行多跳推理的问答。
    • 摘要 (Summarization):
      • GovReport (Huang et al., 2021): 政府报告摘要任务。
    • 少样本学习 (Few-shot Learning):
      • TriviaQA (Joshi et al., 2017): 问答任务,可能涉及长文本检索。

5.1.4. 模型选择

  • LongChat-v1.5-7b-32k (Li et al., 2023): 一个基于 Llama-2 扩展到 32K 上下文长度的开源模型。
  • Yarn-Llama-2-7b-128k (Peng et al., 2023): 另一个基于 Llama-2 扩展到 128K 上下文长度的开源模型。

5.2. 评估指标

论文使用了以下评估指标来衡量 Quest 的准确性和效率:

5.2.1. 语言建模:困惑度 (Perplexity, PPL)

概念定义: 困惑度是评估语言模型在给定文本上预测下一个词元的平均概率的指标。困惑度越低,表示模型对文本的预测能力越强,生成的文本越流畅和连贯。 数学公式: PPL(W)=exp(1Ni=1NlogP(wiw1,,wi1)) \mathrm{PPL}(W) = \exp\left( - \frac{1}{N} \sum_{i=1}^{N} \log P(w_i | w_1, \dots, w_{i-1}) \right) 符号解释:

  • WW: 评估的文本序列。
  • NN: 序列中词元的数量。
  • wiw_i: 序列中的第 ii 个词元。
  • P(wiw1,,wi1)P(w_i | w_1, \dots, w_{i-1}): 模型根据前 i-1 个词元预测第 ii 个词元的概率。
  • exp\exp: 自然指数函数。
  • log\log: 自然对数。

5.2.2. Passkey Retrieval Task & LongBench QA 任务:F1 分数 (F1 Score)

概念定义: F1 分数是精确率 (Precision) 和召回率 (Recall) 的调和平均值。它综合考虑了模型识别关键信息(如通行密钥或问答答案)的准确性(精确率)和完整性(召回率)。F1 分数越高,表示模型在该任务上表现越好。 数学公式: F1=2PrecisionRecallPrecision+Recall \mathrm{F1} = 2 \cdot \frac{\mathrm{Precision} \cdot \mathrm{Recall}}{\mathrm{Precision} + \mathrm{Recall}} 其中,精确率 (Precision) 和召回率 (Recall) 定义为: Precision=True PositivesTrue Positives+False Positives \mathrm{Precision} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Positives}} Recall=True PositivesTrue Positives+False Negatives \mathrm{Recall} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Negatives}} 符号解释:

  • True Positives (TP): 模型正确识别为正例(例如,正确检索到通行密钥或问答答案)的数量。
  • False Positives (FP): 模型错误识别为正例(例如,错误检索或错误答案)的数量。
  • False Negatives (FN): 模型未能识别出正例(例如,未能检索到通行密钥或正确答案)的数量。

5.2.3. 效率评估:自注意力延迟 (Self-Attention Latency) 和 端到端延迟 (End-to-End Latency)

  • 自注意力延迟: 衡量执行自注意力计算所需的时间。这是 KV cache 优化的直接目标。
  • 端到端延迟: 衡量生成一个词元的总时间,包括自注意力、前馈网络 (FFN) 等所有步骤。这反映了实际用户体验到的推理速度。
  • 加速比 (Speedup): 定义为基线方法的延迟除以优化方法的延迟,表示优化方法比基线快多少倍。

5.2.4. 辅助指标:召回率 (Recall Rate) of Top-10 Attention Scores

概念定义: 在图 4 中使用,衡量的是 Quest 选择的词元中,有多少比例是原始完整注意力计算中得分最高的 Top-10 词元。这用于验证 Quest 的关键性估计方法能否有效地捕获真正的关键词元。 数学公式: 未给出明确公式,但根据描述可理解为: Recall Rate=被 Quest 选中的 Top-K 词元中,属于完整注意力 Top-10 的数量完整注意力 Top-10 词元的总数量 \text{Recall Rate} = \frac{\text{被 Quest 选中的 Top-K 词元中,属于完整注意力 Top-10 的数量}}{\text{完整注意力 Top-10 词元的总数量}} 符号解释:

  • Top-10 attention scores: 指的是在完整注意力计算中,对当前查询具有最高注意力得分的 10 个词元。
  • Quest 选中的 Top-K 词元: 指的是 Quest 算法根据其关键性估计选中的词元。

5.3. 对比基线

论文将 Quest 与以下 KV cache 优化算法和实现进行了比较:

  • H2O (Heavy-hitter Oracle) (Zhang et al., 2023b): 一种基于历史注意力分数总和来保留重要 KV cache 词元的驱逐算法。
  • TOVA (Transformers Are Multi-State RNNs) (Oren et al., 2024): 一种根据当前查询决定永久丢弃哪些词元的策略。
  • StreamingLLM (Efficient Streaming Language Models with Attention Sinks) (Xiao et al., 2023): 一种通过固定大小的窗口和注意力汇来处理流式长文本的算法。
  • FlashInfer (Ye et al., 2024): 一个用于 LLM 推理的高性能 CUDA 内核库。Quest 的效率评估正是基于 FlashInfer 的实现,并与 FlashInfer 的正常注意力实现进行比较。
  • Full KV Cache (Oracle Baseline): 使用完整的 KV cache 进行注意力计算,不进行任何稀疏化或驱逐。这通常代表了性能的上限(尽管效率最低),用于衡量稀疏化方法带来的准确性损失。

特殊设置说明:

  • 前两层不应用稀疏化: Quest 和所有基线算法都不应用于模型的前两层。这是因为作者的分析(图 3)表明,模型前两层的稀疏性较低,为了更好地保留模型准确性,这些层会使用完整的 KV cache
  • H2O 的特殊处理: H2O 需要计算完整的 O(n2)O(n^2) 注意力图来收集历史注意力分数,因此无法直接与 FlashAttention 结合用于长上下文。在 100K 序列长度的 passkey retrieval 测试中,为了使 H2O 可行,论文在预填充 (prefill) 阶段使用 FlashAttention 处理上下文,并在解码 (decode) 阶段才开始为 H2O 收集历史注意力分数。
  • 模拟解码: 对于 passkey retrievalLongBench 任务,为了模拟实际推理场景中 KV cache 驱逐的影响,论文将输入分为“材料 (material)”部分和“问题/指令 (question/instruction)”部分。材料部分使用带有完整 KV cacheFlashAttention 进行预填充,而问题/指令部分则通过逐词元 (token by token) 模拟解码,以观察不同方法对 KV cache 管理的影响。

6. 实验结果与分析

6.1. 核心结果分析

6.1.1. PG19 上的语言建模

以下是原文 Figure 6 的 VLM 描述: VLM 描述: 该图像是图表,展示了Quest在PG19数据集上的语言建模评估结果。图中横轴表示输入长度,纵轴为输出的困惑度(perplexity),曲线分别代表不同算法的性能,包括H2O*、TOVA*、完整缓存(Full)和Quest(本研究提出的方法)。可以看到,Quest的表现与完整缓存相近,且在长输入序列上保持较低的困惑度。插图部分详细展示了在输入长度接近32000时的困惑度变化情况。

Figure 6. Language modeling evaluation of Quest on PG19 dataset. We prompt the model with 0 to 32000 tokens from the PG19 test set and measure the perplexity of output tokens. \(\\mathrm { H } 2 \\mathrm { O } ^ { \\ast }\) and TOVA\\* indicate that for the first two layers of models, we do not apply these two algorithms to prune the KV Cache, as analyzed in Sec 3.4, which better preserves the model performance. Quest also uses a full cache in the first two layers of the model. Quest can closely match the performance of the full cache model. 该图像是图表,展示了Quest在PG19数据集上的语言建模评估结果。图中横轴表示输入长度,纵轴为输出的困惑度(perplexity),曲线分别代表不同算法的性能,包括H2O、TOVA*、完整缓存(Full)和Quest(本研究提出的方法)。可以看到,Quest的表现与完整缓存相近,且在长输入序列上保持较低的困惑度。插图部分详细展示了在输入长度接近32000时的困惑度变化情况。*

图 6 展示了在 PG19 数据集上,不同方法在 LongChat-7b-v1.5-32k 模型上的语言建模困惑度 (perplexity) 表现。

  • 评估设置: 模型使用不同长度(0 到 32000 词元)的 PG19 测试集进行提示 (prompting),并测量生成词元的困惑度。
  • 稀疏预算: H2OTOVAQuest 都使用 4096 词元的 KV cache 预算,这大约是总词元长度的 1/8。
  • 结果分析:
    • Quest 的表现: Quest 的困惑度曲线(绿色线)与 Full 缓存(蓝色线)的曲线非常接近,表明 Quest 在大幅稀疏 KV cache 的情况下,几乎没有带来准确性损失,能够保持与完整 KV cache 模型相当的语言建模能力。
    • 基线表现: H2OH2O*TOVATOVA*(星号表示模型前两层不应用稀疏化)的困惑度显著高于 Full 缓存和 Quest,尤其是在较长的上下文长度下。这说明这些查询无关 (query-agnostic) 或基于历史的 KV cache 驱逐策略,即便在第一、二层保留了完整缓存,也难以在语言建模任务中匹配 Quest 的准确性。
  • 结论: Quest 通过其查询感知稀疏性,能够在语言建模任务中有效地识别并保留关键词元,从而在保持高准确性的同时实现 KV cache 的大幅压缩。

6.1.2. 长文本通行密钥检索任务

以下是原文 Table 1 的结果: Table 1. (i) Results of 10k length passkey retrieval test on LongChat-7b-v1.5-32k. (ii) Results of 100k length passkey retrieval test on Yarn-Llama-2-7b-128k. Quest can achieve nearly perfect accuracy with 64 and 1024 tokens KV cache budget, which is about 1%1 \% of the total sequence length, demonstrating that Quest can effectively preserve the model's ability to handle long-dependency tasks. However, KV cache eviction algorithms such as H2O, TOVA, and StreamingLLM incorrectly discard the KV cache of the answer before receiving the question, thus failing to achieve ideal accuracy.

(i) LongChat-7b-v1.5-32k (10k context)
Method / Budget 32 64 128 256 512
H20 0% 1% 1% 1% 3%
TOVA 0% 1% 1% 3% 8%
StreamingLLM 1% 1% 1% 3% 5%
Quest (ours) 65% 99% 99% 99% 100%
(ii) Yarn-Llama-2-7b-128k (100k context)
Method / Budget 256 512 1024 2048 4096
H2O 2% 2% 2% 2% 4%
TOVA 2% 2% 2% 2% 10%
StreamingLLM 1% 1% 1% 2% 4%
Quest (ours) 88% 92% 96% 100% 100%

结果分析:

  • Quest 的卓越表现: 在 10K 长度测试中,Quest 在仅有 64 词元预算时就能达到 99% 的准确率,并在 512 词元预算时达到 100%。在 100K 长度测试中,Quest 在 1024 词元预算时达到 96%,并在 2048 词元预算时达到 100%。这表明 Quest 能够以极小的 KV cache 预算(大约总序列长度的 1%)有效地保留模型处理长依赖任务的能力,精准识别并保留了包含通行密钥的关键词元。
  • 基线方法的局限性: 相比之下,H2OTOVAStreamingLLM 在所有预算下都表现非常差,准确率通常低于 10%。这是因为这些基于历史信息或固定窗口的驱逐算法,在模拟解码过程中,可能会在收到问题之前错误地丢弃包含答案(通行密钥)的 KV cache 词元。例如,StreamingLLM 只能关注最近的文本窗口,如果通行密钥在该窗口之外,它就无法提供正确答案。
  • 查询感知稀疏性的优势: Quest 不会丢弃 KV cache,而是使用查询感知 (query-aware) 的方法来识别关键词元。当问题词元作为查询出现时,Quest 能够动态地识别并加载包含通行密钥的页面,从而获得几乎完美的准确率。

6.1.3. LongBench 结果

以下是原文 Figure 7 的 VLM 描述: VLM 描述: 该图像是多组实验结果的折线图,显示了在不同 KV 缓存预算下,算法 StreamingLLM、H2O、TOVA、Quest(我们的算法)与 Full 的 F1 分数对比。每个子图分别对应不同的任务,包括 Qasper、HotpotQA、GovReport、TriviaQA、NarrativeQA 和 MultifieldQA。

该图像是多组实验结果的折线图,显示了在不同 KV 缓存预算下,算法 StreamingLLM、H2O、TOVA、Quest(我们的算法)与 Full 的 F1 分数对比。每个子图分别对应不同的任务,包括 Qasper、HotpotQA、GovReport、TriviaQA、NarrativeQA 和 MultifieldQA。 该图像是多组实验结果的折线图,显示了在不同 KV 缓存预算下,算法 StreamingLLM、H2O、TOVA、Quest(我们的算法)与 Full 的 F1 分数对比。每个子图分别对应不同的任务,包括 Qasper、HotpotQA、GovReport、TriviaQA、NarrativeQA 和 MultifieldQA。

图 7 展示了在 LongBench 的六个数据集上,不同方法在 LongChat-7b-v1.5-32k 模型上的 F1 分数表现,每个任务都有不同的 KV cache 预算(128 到 4096 词元)。

  • 评估设置: 对于“材料 (material)”部分,使用 FlashAttention 与完整 KV cache 进行预填充。对于“问题/指令 (question/instruction)”部分,则通过逐词元 (token by token) 模拟解码。
  • 结果分析:
    • Quest 的持续优越性: Quest 在所有六个长上下文数据集上(Qasper, HotpotQA, GovReport, TriviaQA, NarrativeQA, MultifieldQA)持续优于所有基线方法。其 F1 分数曲线总是最接近 Full 缓存模型(灰色虚线)。
    • 以小预算实现高准确性: Quest 在 1K 词元预算下即可达到与完整 KV cache 模型相当的性能,而其他基线即使在更大的预算下也与 Full 缓存性能存在显著差距。
    • 近似无损性能: 在考虑前两层使用完整缓存的情况下,Quest 可以在不同的稀疏化比例下(例如,Qasper 1/6,GovReport 1/5,TriviaQA 1/10)实现无损性能。这证明 Quest 能够保持模型在不同类型的长上下文任务中的能力,不会因 KV cache 的不当丢弃而导致生成错误答案。
    • 基线再次受限: 与通行密钥任务类似,H2OTOVAStreamingLLM 在这些通用长上下文任务上表现较差,因为它们在解码过程中难以有效管理 KV cache,可能丢弃了对后续问题回答至关重要的信息。

6.2. 效率评估

为了证明 Quest 的可行性,作者基于 FlashInfer (Ye et al., 2024) 实现了一个完整的 CUDA 内核框架。

6.2.1. 内核级评估 (Kernel Evaluation)

以下是原文 Figure 8 的 VLM 描述: VLM 描述: 该图像是图表,包括两个部分,分别为(a)关键性估计和(b)近似注意力。部分(a)展示了不同序列长度下,FlashInfer和Quest算法(Quest-4, Quest-8, Quest-16, Quest-32)的归一化延迟比较,部分(b)则显示了不同序列长度下的延迟时间,通过不同的Quest设置(Quest-512, Quest-1024, Quest-2048, Quest-4096, Quest-8192)进行比较。

该图像是图表,包括两个部分,分别为(a)关键性估计和(b)近似注意力。部分(a)展示了不同序列长度下,FlashInfer和Quest算法(Quest-4, Quest-8, Quest-16, Quest-32)的归一化延迟比较,部分(b)则显示了不同序列长度下的延迟时间,通过不同的Quest设置(Quest-512, Quest-1024, Quest-2048, Quest-4096, Quest-8192)进行比较。

图 8. (a) 关键性估计 (Criticality estimation)

  • 评估内容: 衡量 Quest 关键性估计操作在不同序列长度和页面大小 (Page Size) 下的延迟。页面大小是 Quest-X 中的 XX
  • 结果分析:
    • 在短序列长度下,关键性估计的内存带宽利用率低于 FlashInfer,因为总的内存加载量不足以充分利用 GPU 内存带宽。
    • 随着序列长度的增加,关键性估计的相对性能提高,并趋近于 1/Page Size1 / \text{Page Size}。这是因为估计操作每页只读取一个词元(用于元数据)。
    • 技术如量化或更大的页面尺寸可以进一步减少额外的内存使用。

图 8. (b) 近似注意力 (Approximate attention)

  • 评估内容: 衡量 Quest 近似注意力操作在不同序列长度和词元预算下的延迟。
  • 结果分析:
    • Quest 的近似注意力与 PageAttention 兼容,通过将 Top-K 页面索引作为稀疏加载索引来实现。
    • 在给定的词元预算 BB 下,近似注意力的延迟是一个常数,与序列长度无关。
    • 当序列长度等于预算 BB 时,近似注意力与 FlashInfer 具有相似的延迟。

Top-K 过滤 (Top-K filtering):

  • 评估内容: 衡量 Top-K 过滤操作的延迟。Quest 使用 RAFT (Zhang et al., 2023a) 中的批处理 Top-K CUDA 算子。

  • 结果分析: 由于关键性估计将一个完整的词元简化为一个关键性得分,Top-K 过滤的内存移动量非常有限,因此开销极低,对于 128K 以下的序列长度,延迟仅为 5-10 微秒。

    以下是原文 Figure 9 的 VLM 描述: VLM 描述: 该图像是一个示意图,展示了在不同序列长度下,使用 Quest 方法进行关键性估计、Top-K 过滤和近似注意力的延迟表现。图中显示,随着序列长度的增加,使用 Top-K 过滤和近似注意力方法,能显著降低延迟,提升效率。

该图像是一个示意图,展示了在不同序列长度下,使用 Quest 方法进行关键性估计、Top-K 过滤和近似注意力的延迟表现。图中显示,随着序列长度的增加,使用 Top-K 过滤和近似注意力方法,能显著降低延迟,提升效率。 该图像是一个示意图,展示了在不同序列长度下,使用 Quest 方法进行关键性估计、Top-K 过滤和近似注意力的延迟表现。图中显示,随着序列长度的增加,使用 Top-K 过滤和近似注意力方法,能显著降低延迟,提升效率。

图 9 展示了 QuestLlama2-7B 模型上,在不同序列长度下的自注意力时间分解(使用 PyTorch profiler)。

  • 评估设置: 在 32K 序列长度和 2048 词元预算下进行评估。
  • 结果分析:
    • Quest 能够将自注意力时间相比 FlashInfer 减少 7.03×7.03 \times。这得益于内存移动量的减少。
    • 在短序列(如 1K)下,FlashInferQuest 的自注意力时间相似。但随着序列长度的增加,FlashInfer 的自注意力时间线性增长,而 Quest 由于其稀疏性,增长速度显著放缓。
    • 这证明 Quest 在长上下文场景下,能够通过减少实际执行注意力计算的词元数量,大幅提升自注意力层的效率。

6.2.2. 端到端评估 (End-to-End Evaluation)

以下是原文 Figure 10 的 VLM 描述: VLM 描述: 该图像是一个图表,展示了FlashInfer与不同上下文长度下的延迟对比。图中提供了FP16权重和4位权重(AWQ)的延迟数据,分别显示了在32768上下文长度时的延迟加速比为1.74x和2.23x。

该图像是一个图表,展示了FlashInfer与不同上下文长度下的延迟对比。图中提供了FP16权重和4位权重(AWQ)的延迟数据,分别显示了在32768上下文长度时的延迟加速比为1.74x和2.23x。 该图像是一个图表,展示了FlashInfer与不同上下文长度下的延迟对比。图中提供了FP16权重和4位权重(AWQ)的延迟数据,分别显示了在32768上下文长度时的延迟加速比为1.74x和2.23x。

图 10 展示了 Quest 在实际单批次 (single-batch) 推理场景下的端到端 (end-to-end) 加速效果。

  • 评估设置: 测量在解码阶段生成一个词元的平均延迟,涵盖不同序列长度和词元预算。比较对象是原始 FlashInfer 实现。
  • 结果分析:
    • Quest 持续优于 FlashInfer: 在所有序列长度下,Quest 都优于 FlashInfer
    • 延迟增长慢: 随着序列长度的增加,FlashInfer 的延迟显著增加,而 Quest 的延迟增长速度明显较慢,因为 Quest 维持了相似的词元预算。
    • 显著加速: 在 32K 序列长度和 2048 词元预算下:
      • 使用 FP16 权重时,Quest 实现了 1.74×1.74 \times 的推理速度提升。
      • 使用 4 位量化 (4-bit quantized weight) 时(如 AWQ),Quest 实现了 2.23×2.23 \times 的推理速度提升。这表明 Quest 与现有的量化技术兼容,并能进一步放大其效率优势。
  • 结论: Quest 不仅在内核级别实现了自注意力加速,在实际的端到端推理中也带来了显著的延迟降低,从而提高了 LLM 的整体吞吐量。

6.2.3. 与基线的比较

以下是原文 Figure 11 的 VLM 描述: VLM 描述: 该图像是一个图表,展示了在不同基准下,全量、TOVA和Quest方法的平均上下文长度和推理延迟。图中显示Quest方法在多项任务中表现出显著的延迟改善,最大化上下文长度的同时有效降低推理延迟,表明了其在长依赖任务上的优势。

该图像是一个图表,展示了在不同基准下,全量、TOVA和Quest方法的平均上下文长度和推理延迟。图中显示Quest方法在多项任务中表现出显著的延迟改善,最大化上下文长度的同时有效降低推理延迟,表明了其在长依赖任务上的优势。 该图像是一个图表,展示了在不同基准下,全量、TOVA和Quest方法的平均上下文长度和推理延迟。图中显示Quest方法在多项任务中表现出显著的延迟改善,最大化上下文长度的同时有效降低推理延迟,表明了其在长依赖任务上的优势。

图 11 (a) 展示了不同注意力机制在 LongBench 六个任务上达到“无损准确性 (lossless accuracy)”所需的词元预算。

  • 无损准确性定义: 达到与 Full KV Cache 模型相当的准确性。
  • 结果分析:
    • Quest 更高的稀疏性: Quest 达到无损准确性所需的词元预算显著小于 TOVAStreamingLLM。例如,在 NarrativeQA(平均上下文长度 24K 词元)上,TOVA 需要 14K 词元预算,而 Quest 仅需要 5K 词元,实现了更高的稀疏性。

    • 基线效率不足: H2OStreamingLLM 在某些任务中甚至无法达到无损准确性。

      图 11 (b) 展示了在相同准确性目标下,不同注意力方法的自注意力延迟。

  • 评估设置: 考虑到基线方法缺乏 CUDA 内核实现,作者对基线的效率进行了定性分析,使用 FlashInfer 的推理延迟作为基准,忽略了其他运行时开销。Quest 则在实际环境中进行评估。
  • 结果分析:
    • Quest 显著超越基线: Quest 在自注意力延迟方面显著优于所有基线方法。例如,在 GovReport 任务中,QuestTOVA 提高了 3.82×3.82 \times 的速度;在 TriviaQA 任务中,提高了 4.54×4.54 \times 的速度。
    • 高查询感知稀疏性是关键: 这种优势主要归因于 Quest 实现了更高的查询感知稀疏性,从而可以在不牺牲准确性的情况下大幅减少实际参与计算的词元数量。
  • 结论: Quest 能够在保持甚至超越基线准确性的同时,实现显著更高的效率。

7. 总结与思考

7.1. 结论总结

论文提出了 Quest,一种高效且准确的键值缓存 (KV cache) 选择算法,旨在解决大型语言模型 (LLMs) 长上下文推理过程中的效率瓶颈。Quest 的核心创新在于利用了查询感知稀疏性 (query-aware sparsity)。它通过在页面粒度上维护键 (Key) 向量的最小值和最大值作为元数据,并在每次解码时,利用当前查询向量 (Query vector) 动态估计每个 KV cache 页面的关键性。随后,Quest 仅对 Top-K 个最关键的页面执行自注意力计算,从而大幅减少了内存移动量。

实验证明,Quest 在多个长上下文基准测试中表现出色:

  • 效率: 在 32K 序列长度下,自注意力 (self-attention) 速度提升高达 7.03×7.03 \times,端到端 (end-to-end) 推理延迟降低 2.23×2.23 \times (使用 4 位量化时)。
  • 准确性:PG19 语言建模、passkey retrieval 任务和 LongBench 的六个长上下文任务上,Quest 在高稀疏度下仍能保持与完整 KV cache 模型相当的准确性,性能显著优于 H2OTOVAStreamingLLM 等现有基线。
  • 内存优化: 相比于加载整个 KV cacheQuest 显著减少了内存加载量,并且与现有的量化机制兼容。

7.2. 局限性与未来工作

论文中并未明确指出 Quest 的具体局限性,但可以从其方法和实验中推断一些潜在的方面:

  • 元数据开销: 尽管 Quest 采用页面粒度管理来降低元数据 (metadata) 开销,但维护每个页面 Key 向量的最小值和最大值仍然需要额外的存储和计算。对于极长的上下文或极大的模型,这部分开销可能仍需进一步优化。
  • Top-K 超参数: Top-K 页面数量是一个预设的超参数。如何自适应地确定最佳 KK 值以平衡准确性和效率,可能是一个需要进一步探索的方向。目前的固定 KK 值策略可能无法在所有场景下都达到最优。
  • “近似”的潜在风险: Quest 通过计算注意力得分的“上界”来估计关键性,这是一种近似方法。虽然实验结果表明其非常有效,但在某些极端情况下,真正的关键词元可能隐藏在得分上界不那么高的页面中,导致漏选(尽管其召回率表现良好)。
  • 算子优化: 尽管 Quest 实现了 CUDA 内核,但稀疏注意力算子的持续优化仍是提升效率的关键。尤其是如何高效地处理非连续的内存访问(由于页面选择导致的稀疏性),是硬件和软件协同优化的方向。
  • 多头注意力 (Multi-Head Attention) 的复杂性: 论文主要关注单个注意力机制的优化。在多头注意力中,不同头可能关注不同的上下文信息。如何将 Quest 的查询感知稀疏性扩展到多头注意力,并利用头部之间的异构性进行更精细的稀疏化,可能是一个研究方向。

7.3. 个人启发与批判

7.3.1. 个人启发

  1. “查询感知”的价值: Quest 的核心思想“查询感知稀疏性”提供了一个重要的启发。在 LLM 推理中,简单地基于历史或固定窗口进行 KV cache 驱逐过于粗糙,因为模型对上下文的关注点是动态变化的。未来的 LLM 优化应更多地从当前任务和查询的角度出发,动态地、智能地管理计算资源。
  2. 元数据驱动的近似: 利用简单的元数据(如最小值和最大值)来快速近似复杂的计算结果(如注意力得分上界)是一个非常高效的策略。这在计算资源受限的场景下(如边缘设备或高吞吐量服务器)具有广泛的应用前景。这种思想可以推广到其他 LLM 操作的稀疏化或近似计算中。
  3. 软硬件协同设计: Quest 的成功也得益于其在 FlashInfer 基础上实现的 CUDA 内核优化。这再次强调了 LLM 时代软硬件协同设计的重要性。仅仅有好的算法是不够的,还需要高效的底层实现来将理论优势转化为实际的加速。
  4. 长上下文的通用挑战: 论文通过 passkey retrievalLongBench 任务再次验证了现有 KV cache 驱逐算法在长依赖任务上的不足。这表明在追求长上下文能力时,如何有效地“记忆”和“检索”关键信息仍然是 LLM 面临的普遍挑战,而 Quest 提供了一个有效的解决方案。

7.3.2. 批判与潜在改进

  1. 关键性估计的粒度: Quest 在页面粒度进行关键性估计。虽然这降低了开销,但页面内的词元可能存在异构性,即一个页面中可能只有一个词元是关键的,但为了它需要加载整个页面。未来的工作可以探索更细粒度(如块级或词元级)的关键性估计,或者开发更智能的页面聚合策略。

  2. 动态 Top-K 策略: 当前 Top-K 是一个固定的超参数。在不同的层、不同的任务、甚至不同的词元生成阶段,所需的关键词元数量可能不同。开发一种能够根据模型层、注意力头、或当前任务动态调整 KK 值的策略,可能会带来进一步的优化。例如,可以基于模型的置信度或注意力分布的熵来动态调整 KK

  3. 结合其他优化: Quest 专注于 KV cache 的稀疏化。可以探讨将其与其他的 LLM 推理优化技术相结合,例如 Speculative Decoding (推测解码)、更先进的量化技术、或者不同层之间的计算卸载等,以实现更全面的端到端加速。

  4. 对 Q 向量的敏感性: Quest 的核心在于查询感知,其关键性估计直接依赖于查询向量 QQ。如果 QQ 向量本身存在噪声或不稳定性,是否会影响关键性估计的准确性?论文未深入探讨 QQ 向量质量对 Quest 性能的影响。

  5. 跨模型和架构的泛化性: 论文主要在 Llama-2 及其变体上进行了评估。尽管方法原理是通用的,但在其他 Transformer 架构(如 MistralGemma 等)上,其性能表现和最佳超参数可能有所不同,这需要进一步的验证。

    总的来说,Quest 提出了一种新颖且高效的 KV cache 优化策略,通过引入查询感知稀疏性,在长上下文 LLM 推理中取得了显著的性能提升,为未来的 LLM 加速研究提供了宝贵的思路。

相似论文推荐

基于向量语义检索推荐的相关论文。

暂时没有找到相似论文。