论文状态:已完成

SpargeAttention: Accurate and Training-free Sparse Attention Accelerating Any Model Inference

发表:2025/02/25
原文链接PDF 下载
价格:0.100000
价格:0.100000
价格:0.100000
已有 2 人读过
本分析由 AI 生成,可能不完全准确,请以原文为准。

TL;DR 精炼摘要

本文提出了SpargeAttention,一种通用且无需训练的稀疏注意力机制,旨在加速各种大模型的推理过程。通过两阶段在线过滤器,首次准确预测注意力图,跳过部分矩阵乘法;第二阶段则采用无额外开销的softmax感知过滤器,进一步提升效率。实验结果证明该方法在语言、图像及视频生成等任务中显著提升了性能,保持了端到端指标。

摘要

An efficient attention implementation is essential for large models due to its quadratic time complexity. Fortunately, attention commonly exhibits sparsity, i.e., many values in the attention map are near zero, allowing for the omission of corresponding computations. Many studies have utilized the sparse pattern to accelerate attention. However, most existing works focus on optimizing attention within specific models by exploiting certain sparse patterns of the attention map. A universal sparse attention that guarantees both the speedup and end-to-end performance of diverse models remains elusive. In this paper, we propose SpargeAttn, a universal sparse and quantized attention for any model. Our method uses a two-stage online filter: in the first stage, we rapidly and accurately predict the attention map, enabling the skip of some matrix multiplications in attention. In the second stage, we design an online softmax-aware filter that incurs no extra overhead and further skips some matrix multiplications. Experiments show that our method significantly accelerates diverse models, including language, image, and video generation, without sacrificing end-to-end metrics. The code is available at https://github.com/thu-ml/SpargeAttn.

思维导图

论文精读

中文精读

1. 论文基本信息

1.1. 标题

SpargeAttention: Accurate and Training-free Sparse Attention Accelerating Any Model Inference (SpargeAttention:准确且无需训练的稀疏注意力,加速任意模型推理)

1.2. 作者

  • Intao Zhang: 清华大学 (Tsinghua University)
  • Chendong Xiang: 清华大学 (Tsinghua University)
  • Haofeng Huang: 清华大学 & 智源人工智能研究院 (Tsinghua University & Beijing Academy of Artificial Intelligence)
  • Jia Wei: 清华大学 (Tsinghua University)
  • Haocheng Xi: 普林斯顿大学 (Princeton University)
  • Jun Zhu: 清华大学 (Tsinghua University)
  • Jianfei Chen: 清华大学 (Tsinghua University)

1.3. 发表期刊/会议

该论文发布于 arXiv 预印本平台。

1.4. 发表年份

2025年。

1.5. 摘要

大型模型因其二次时间复杂度,高效的注意力实现至关重要。幸运的是,注意力机制通常表现出稀疏性,即注意力图中的许多值接近于零,这使得可以省略相应的计算。许多研究已经利用稀疏模式来加速注意力。然而,大多数现有工作通过利用注意力图的特定稀疏模式,专注于优化特定模型内的注意力。一种能保证多种模型的加速和端到端性能的通用稀疏注意力仍然难以实现。在本文中,我们提出了 SpargeAttn,一种用于任何模型的通用稀疏量化注意力。我们的方法采用两阶段在线过滤器:在第一阶段,我们快速准确地预测注意力图,从而跳过注意力中的一些矩阵乘法。在第二阶段,我们设计了一个无额外开销的在线 softmax-aware 滤波器,进一步跳过一些矩阵乘法。实验表明,我们的方法显著加速了包括语言、图像和视频生成在内的多种模型,且不牺牲端到端指标。代码可在 https://github.com/thu-ml/SpargeAttn 获取。

1.6. 原文链接

  • arXiv 链接: https://arxiv.org/abs/2502.18137 (预印本)
  • PDF 链接: https://arxiv.org/pdf/2502.18137v8.pdf

2. 整体概括

2.1. 研究背景与动机

  • 核心问题: 随着大型模型中序列长度的增加(例如视频生成和语言模型中的 45K-128K),注意力 (attention) 机制的计算开销因其二次时间复杂度 (quadratic time complexity) 变得巨大,占据了推理延迟的很大一部分。这使得高效的注意力实现成为当前大型模型推理中的关键瓶颈。
  • 问题重要性: 注意力计算是 Transformer 架构的核心,直接影响了模型处理长序列的能力和推理速度。
  • 现有研究的挑战与空白 (Gap):
    • 稀疏性 (Sparsity) 的利用: 现有研究发现注意力图通常表现出稀疏性,即许多值接近零,这为跳过不必要的计算提供了机会。
    • 现有方法的局限性:
      • 通用性不足 (Limited Universality): 大多数现有稀疏注意力方法是为特定任务(如语言建模)设计的,利用了特定任务的稀疏模式(如滑动窗口、注意力槽)。然而,注意力模式在不同任务和模型中差异很大(如 Figure 2 所示),导致这些方法难以泛化。
      • 可用性挑战 (Usability Challenges): 准确预测注意力图中的稀疏区域需要精确性,而实现效率又要求预测开销最小。现有方法通常难以同时满足这两点。例如,MInference 在序列长度达到 100K 时才能实现显著加速,表明其预测开销可能较高。
  • 本文的切入点/创新思路: 旨在设计一种训练无关 (training-free) 的稀疏注意力操作符,能够通用地 (universally) 加速所有模型 (all models) 的推理,同时不损失任何指标 (metrics)。其核心思路是开发一种低开销、高精度的在线稀疏模式预测和过滤机制。

2.2. 核心贡献/主要发现

本文提出了 SpargeAttn,一种通用的稀疏量化注意力 (universal sparse and quantized attention),其主要贡献和发现包括:

  • 提出了通用的、训练无关的稀疏注意力 SpargeAttn: 旨在克服现有稀疏注意力方法的通用性和可用性限制,适用于语言、图像、视频生成等多种模型,无需对模型进行重新训练或微调。
  • 两阶段在线过滤机制: 这是 SpargeAttn 的核心创新,有效平衡了预测精度和计算效率。
    • 第一阶段:选择性词元压缩 (Selective Token Compression) 进行稀疏预测:
      • 观察到 QQKK 矩阵中相邻词元 (token) 存在高度相似性。
      • 通过将高自相似度 (high self-similarity) 的块压缩为单个代表性词元来快速预测注意力图的稀疏区域。
      • 使用 TopCdf 机制结合超参数 τ\tau 筛选出关键块,生成初步的块掩码 MgM_g
      • 关键创新点: 仅对高自相似度的块进行压缩,对于“固定块 (fix blocks)”(非自相似块)则始终进行计算,避免关键信息丢失,提高了预测准确性。
    • 第二阶段:稀疏 Warp Online Softmax 滤波器:
      • 在在线 Softmax 过程中,进一步识别注意力值足够小(对最终输出贡献可忽略)的 PV 乘积,从而跳过这些计算。
      • 该滤波器设计为无额外开销 (no extra overhead),通过在 GPU Warp 级别判断局部最大值与全局最大值的差异,实现更细粒度的稀疏化。
  • 与 SageAttention 结合实现额外加速: 将上述稀疏机制集成到 8 位量化的 SageAttention 框架中,由于稀疏和量化操作是正交的 (orthogonal),二者可叠加提供进一步的加速。
  • 引入 HilbertCurve 置换: 针对图像和视频模型,通过 HilbertCurve 置换策略,有效提高相邻词元之间的相似性,从而减少非自相似块的数量,进一步提高稀疏性。
  • 广泛的实验验证:
    • Llama3.1 (语言模型)、CogvideoXMochiOpen-Sora-Plan (文本到视频模型) 和 FluxStable-Diffusion3.5 (文本到图像模型) 等多种代表性模型上进行了评估。
    • 主要发现: SpargeAttn 在不牺牲任何端到端性能指标(如困惑度、视频质量、图像质量)的前提下,实现了显著的推理加速,比现有密集和稀疏注意力模型快 2.5 倍到 5 倍,并优于现有稀疏注意力基线。在某些任务中,甚至能增强 LLM 在长上下文任务中的表现。

3. 预备知识与相关工作

3.1. 基础概念

为了更好地理解 SpargeAttn,需要了解以下基础概念:

  • 注意力机制 (Attention Mechanism):

    • 概念定义: 注意力机制是神经网络中的一种技术,允许模型在处理序列数据时,动态地“聚焦”于输入序列中最重要的部分。在 Transformer 架构中,它通过计算查询(Query, QQ)、键(Key, KK)和值(Value, VV)之间的关系来生成输出。
    • 计算流程: 经典的自注意力 (Self-Attention) 计算通常包括三个主要步骤:
      1. 相似度计算: 计算 QQKK 之间的相似度,通常通过矩阵乘法 QKTQK^T 来实现。
      2. 缩放与 Softmax: 将相似度结果除以 dk\sqrt{d_k} 进行缩放(防止梯度过小或过大),然后应用 softmax 函数将结果归一化为概率分布,得到注意力权重矩阵 PP
      3. 加权求和: 将注意力权重矩阵 PPVV 矩阵相乘,得到最终的注意力输出,即对 VV 进行加权求和。
    • 数学公式: Attention(Q,K,V)=Softmax(QKTdk)V \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
      • 符号解释:
        • QQ: 查询矩阵 (Query matrix),维度为 N×dkN \times d_k,其中 NN 是序列长度,dkd_k 是键/查询的维度。
        • KK: 键矩阵 (Key matrix),维度为 N×dkN \times d_k
        • VV: 值矩阵 (Value matrix),维度为 N×dvN \times d_v,其中 dvd_v 是值的维度。
        • QKTQK^T: 查询和键的点积,表示查询与每个键之间的相似度。
        • dk\sqrt{d_k}: 缩放因子,用于防止点积结果过大。
        • Softmax()\text{Softmax}(\cdot): softmax 函数,将相似度转换为概率分布,确保所有权重之和为 1。
        • P = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right): 注意力权重矩阵 (Attention Map),维度为 N×NN \times N
        • Attention(Q,K,V)\text{Attention}(Q, K, V): 最终的注意力输出矩阵,维度为 N×dvN \times d_v
  • 二次时间复杂度 (Quadratic Time Complexity):

    • 概念定义: 指算法的运行时间与输入规模的平方成正比。在注意力机制中,计算 QKTQK^T 矩阵乘法会产生一个 N×NN \times N 的注意力权重矩阵,其中 NN 是序列长度。这个矩阵的计算涉及 N2N^2 级别的操作,导致总时间复杂度为 O(N2)O(N^2)。当序列长度 NN 变得非常大时,计算开销会急剧增加,成为瓶颈。
  • 稀疏注意力 (Sparse Attention):

    • 概念定义: 是对标准注意力机制的一种优化,其核心思想是利用注意力权重矩阵(Attention Map)中固有的稀疏性。即,许多注意力权重的值非常小,对最终输出的贡献微乎其微,因此可以安全地省略这些部分的计算。稀疏注意力通过构建一个“稀疏掩码 (sparse mask)”来指示哪些重要的、非零的条目应该被计算,而哪些可以被跳过,从而减少计算量和内存占用。
  • FlashAttention (Dao et al., 2022; Dao, 2024):

    • 概念定义: 是一种内存高效且计算速度快的精确注意力实现。它通过平铺 (tiling) 策略将长序列的 QQ, KK, VV 矩阵分割成小块,并在 GPU 的共享内存中进行计算,减少了对昂贵的高带宽内存(HBM)的访问,从而显著加速。
    • 在线 Softmax (Online Softmax): FlashAttention 采用在线 softmax 算法,避免了在完整注意力权重矩阵上进行 softmax 归一化。它逐步计算 softmax 归一化因子,并将中间结果累积起来,从而避免存储整个 N×NN \times N 的注意力矩阵,进一步节省了内存。SpargeAttn 正是在 FlashAttention 的基础上进行稀疏化。
  • 量化 (Quantization):

    • 概念定义: 是一种模型优化技术,通过将模型参数(如权重、激活值)从高精度(如 FP32FP16)转换为低精度(如 INT8INT4)来减少模型的存储和计算开销。低精度数据类型占用内存更少,且在某些硬件上可以利用专门的低精度计算单元进行更快的矩阵乘法。SpargeAttn 结合了 SageAttention 提供的量化能力,实现了额外的加速。

3.2. 前人工作

论文将稀疏注意力方法分为三类,并提及了其他加速注意力的方法。

3.2.1. 稀疏注意力方法分类

  1. 基于模式的方法 (Pattern-based Methods):

    • 核心思想: 依赖于对注意力图的经验观察,预设一些固定的稀疏模式,例如滑动窗口 (sliding windows) 或注意力槽 (attention sinks)。
    • 示例:
      • H2O (Zhang et al., 2023), InfLLM (Xiao et al., 2024a), DUOAttention (Xiao et al., 2025) 依赖滑动窗口模式。
      • SampleAttention (Zhu et al., 2024), M0A (Fu et al., 2024), StreamingLLM (Xiao et al., 2024b) 依赖滑动窗口和注意力槽模式。
      • DitFastAttn (Yuan et al., 2024) 依赖滑动窗口模式以及不同注意力图之间的相似性,但其适用范围受限于简单的扩散 Transformer,不兼容语言模型和 MMDiT 模型。
    • 局限性: 这些方法通常缺乏通用性,因为注意力模式在不同模型和任务中变化很大。
  2. 动态稀疏方法 (Dynamic Sparse Methods):

    • 核心思想: 根据输入数据动态地构建稀疏掩码,无需预设模式,因此具有更大的通用潜力。
    • 子类别:
      • 通道压缩 (Channel Compression): 通过降低维度来计算注意力。
        • SparQAttn (Ribar et al., 2024) 和 LokiAttn (Singhania et al., 2024) 通过使用降低的维度进行全注意力计算来构建掩码。
        • 局限性: 常见注意力维度(如 64, 128)已经很小,进一步压缩带来的加速潜力有限。
      • 词元压缩 (Token Compression): 通过将词元块压缩为单个词元来计算注意力。
        • MInference (Jiang et al., 2024) 和 FlexPrefill (Lai et al., 2025) 将每个词元块压缩为单个词元,并在缩短的序列上计算注意力。
        • 局限性: 这种近似可能过于激进,如果压缩序列上的注意力分数不高,可能会遗漏注意力图中的重要块。
      • 训练额外参数: SeerAttention (Gao et al., 2024) 需要训练额外的参数来预测稀疏性,增加了使用成本。
    • 共同局限性: 多数动态稀疏方法是为语言模型设计的,在扩散模型等其他模型类型上的适用性尚不明确。
  3. 基于训练的方法 (Training-based Methods):

    • 核心思想: 修改注意力计算逻辑,需要对整个模型进行重新训练。
    • 示例: Reformer (Kitaev et al., 2020) 和 FastAttention (Pagliardini et al., 2023)。
    • 局限性: 相比训练无关的方法,使用成本更高,需要大量的计算资源和时间。

3.2.2. 其他注意力加速方法

这些方法与稀疏注意力是正交的 (orthogonal),意味着它们可以独立应用或组合使用以实现更大加速。

  • 内核实现优化 (Kernel Implementation Optimization):
    • FlashAttention (Dao et al., 2022), FlashAttention-2 (Dao, 2024), FlashAttention-3 (Shah et al., 2024) 等通过优化 GPU 底层内核,减少内存访问和提高并行度来加速注意力计算。
  • 量化 (Quantization):
    • SageAttention (Zhang et al., 2025d;a;g;e) 提出了一种量化方法来加速注意力。
  • 分布式计算 (Distributing Workload):
    • RingAttention (Liu et al., 2024a) 通过在多个设备上分布工作负载来处理超长序列。
  • 线性时间注意力 (Linear Time Attention):
    • Linformer (Wang et al., 2020), Performer (Choromanski et al., 2021), Metaformer (Yu et al., 2022), Linear Transformer (Katharopoulos et al., 2020) 等旨在将注意力机制的复杂度从二次降低到线性。

3.3. 技术演进与差异化分析

  • 技术演进: 注意力机制的优化历程从最初的精确但昂贵的 Softmax 注意力,逐步发展到利用其稀疏性来跳过计算(稀疏注意力),再到通过 FlashAttention 等技术优化底层实现和内存访问。同时,模型的量化也为推理加速提供了另一条途径。
  • 本文工作的定位: SpargeAttn 处于这一演进路径的前沿,它结合了动态稀疏化 (dynamic sparsification) 的思想(无需预设模式),FlashAttention 的高效计算框架,以及量化 (quantization) 的进一步加速潜力。
  • 与相关工作的核心区别和创新点:
    1. 通用性 (Universality): SpargeAttn 旨在解决现有稀疏注意力方法通用性不足的问题。它不依赖于任务特定的稀疏模式,而是基于更通用的“词元相似性”原理,这使得它能够有效应用于语言、图像和视频等多种模态。
    2. 两阶段在线过滤 (Two-stage Online Filter): 这是区别于单一稀疏预测方法的关键创新。
      • 第一阶段的选择性词元压缩 (Selective Token Compression) 在粗粒度上快速预测稀疏区域,并通过保留“固定块”来确保准确性。
      • 第二阶段的稀疏 Warp Online Softmax (Sparse Warp Online Softmax) 则在细粒度上、无额外开销地进一步优化,充分利用了在线 Softmax 的特性。这种分层过滤机制在准确性和效率之间取得了更好的平衡。
    3. 训练无关 (Training-free): 与需要重新训练模型的基于训练的方法(如 Reformer)不同,SpargeAttn 是一种即插即用 (plug-and-play) 的方案,大大降低了使用门槛。
    4. 与量化技术的正交集成 (Orthogonal Integration with Quantization): SpargeAttn 能够无缝集成到现有的量化框架(如 SageAttention)中,进一步叠加加速效果,这在工程实践中极具价值。
    5. HilbertCurve 置换 (HilbertCurve Permutation): 针对视觉模型,通过 HilbertCurve 置换策略提高相邻词元相似度,这是一个巧妙的领域知识利用,进一步优化了稀疏化潜力。

4. 方法论

4.1. 方法原理

SpargeAttn 的核心思想是在 FlashAttention 的基础上,引入一个两阶段在线过滤器 (two-stage online filter) 来识别并跳过注意力计算中不重要的部分。其直觉是,注意力图通常是稀疏的,许多值对最终输出的贡献很小,可以被安全地忽略。通过在计算密集型的矩阵乘法(QKTQK^TPV)发生之前进行智能过滤,可以显著减少计算量。

  • 第一阶段过滤: 在粗粒度上进行预测。通过分析查询(QQ)和键(KK)块内部的相似性,并结合压缩后的注意力图,预先判断哪些 QKTQK^TPV 块可以被完全跳过。这个阶段的目标是快速且准确地识别出大部分可跳过的区域。
  • 第二阶段过滤: 在细粒度上进行优化。在 FlashAttention 的在线 Softmax 计算过程中,进一步判断哪些 PV 乘积的贡献足够小,可以在 GPU Warp 级别进行跳过,且不引入额外的预测开销。
  • 结合量化: SpargeAttn 还集成了 8 位量化的 SageAttention,因为稀疏化和量化是正交的,可以叠加加速效果。

4.2. 核心方法详解

SpargeAttn 包含一个两阶段在线过滤器,并集成到稀疏 FlashAttention 框架中。

4.2.1. 稀疏 FlashAttention (Sparse FlashAttention)

SpargeAttn 采用 FlashAttention (Dao, 2024) 的平铺策略 (tiling strategy),并跳过被过滤掉的块。首先回顾 FlashAttention 的计算过程。

注意力操作定义为 S=QK/dS = QK^\top / \sqrt{d},然后 P = \sigma(S),最后 O=PVO = PV,其中 σ(S)ij=exp(Sij)/kexp(Sik)\sigma(S)_{ij} = \exp(S_{ij}) / \sum_k \exp(S_{ik})Softmax 操作。设序列长度为 NN,每个头维度为 ddQQ, KK, VV 的维度均为 N×dN \times d,而 SSPP 的维度为 N×NN \times N

FlashAttentionQQ, KK, VV 沿着词元维度 (token dimension) 分块,分别得到块 {Qi}\{Q_i\}, {Ki}\{K_i\}, {Vi}\{V_i\},块大小分别为 bqb_q, bkb_k, bkb_k。然后,它使用在线 Softmax (online softmax) (Milakov & Gimelshein, 2018) 逐步计算 OO 的每个块 OiO_i

Sij=QiKj/d,(mij,P~ij)=σ~(mi,j1,Sij),lij=exp(mi,j1mij)li,j1+rowsum(P~ij),Oij=diag(exp(mi,j1mij))Oi,j1+P~ijVj \begin{array}{r} S_{ij} = Q_i K_j^\top / \sqrt{d}, (m_{ij}, \widetilde{P}_{ij}) = \widetilde{\sigma}(m_{i, j-1}, S_{ij}), \\ l_{ij} = \exp(m_{i, j-1} - m_{ij}) l_{i, j-1} + \mathrm{rowsum}(\widetilde{P}_{ij}), \\ O_{ij} = \mathrm{diag}\left(\exp(m_{i, j-1} - m_{ij})\right) O_{i, j-1} + \widetilde{P}_{ij} V_j \end{array}

  • 符号解释:
    • SijS_{ij}: QiQ_i 块与 KjK_j 块之间的相似度得分矩阵。

    • (mij,P~ij)(m_{ij}, \widetilde{P}_{ij}): 由 FlashAttention 的在线 Softmax 机制计算得到。

    • σ~()\widetilde{\sigma}(\cdot): 这是一个类似 softmax 的操作符,它在每个迭代步骤中更新最大值 mijm_{ij} 和部分归一化的注意力权重 P~ij\widetilde{P}_{ij}。具体计算方式是:

      • mij=max{mi,j1,rowmax(Sij)}m_{ij} = \operatorname{max}\{m_{i, j-1}, \operatorname{rowmax}(S_{ij})\}:当前块 QiQ_i 处理到 KjK_j 块时,它会记录当前行(在 QiQ_i 内部)的最大 logitmlocal=rowmax(Sij)m_{\mathrm{local}} = \operatorname{rowmax}(S_{ij}),并与之前累积的最大 logitmi,j1m_{i, j-1} 取最大值,作为新的行最大值 mijm_{ij}。这个全局最大值用于后续的指数缩放,以避免数值溢出。
      • P~ij=exp(Sijmij)\widetilde{P}_{ij} = \exp(S_{ij} - m_{ij}):部分归一化的注意力权重,仅进行了指数化和减去最大 logit 值的操作,还未进行分母归一化。
    • mi,j1m_{i, j-1}: 前一个块累积的最大 logit 值,维度为 bq×1b_q \times 1

    • li,j1l_{i, j-1}: 前一个块累积的 softmax 分母,维度为 bq×1b_q \times 1

    • exp(mi,j1mij)\exp(m_{i, j-1} - m_{ij}): 缩放因子,用于调整之前累积的 softmax 分母 li,j1l_{i, j-1} 和输出 Oi,j1O_{i, j-1},以适应新的最大 logitmijm_{ij}

    • rowsum(P~ij)\mathrm{rowsum}(\widetilde{P}_{ij}): P~ij\widetilde{P}_{ij} 矩阵的行和,用于更新 softmax 分母 lijl_{ij}

    • Oi,j1O_{i, j-1}: 前一个块累积的部分输出。

    • P~ijVj\widetilde{P}_{ij} V_j: 当前块 QiQ_iKj,VjK_j, V_j 块计算得到的部分输出贡献。

    • mijm_{ij}lijl_{ij} 都是 bq×1b_q \times 1 的向量,初始值分别为 -\infty 和 0。

    • 最终输出 OiO_i 可以通过 Oi=diag(li,Tn)1Oi,TnO_i = \mathrm{diag}(l_{i, T_n})^{-1} O_{i, T_n} 计算得到,其中 TnT_nKKVV 的块数量。

      实现稀疏 FlashAttention 的方式很直观:通过跳过某些块矩阵乘法 QiKjQ_i K_j^\topP~ijVj\widetilde{P}_{ij} V_j 来加速注意力计算。本文基于 FlashAttention 提出了稀疏注意力的定义:

定义 1 (块掩码,Block Masks):MgM_gMpvM_{pv} 是维度为 N/bq×N/bk\left\lceil N / b_q \right\rceil \times \left\lceil N / b_k \right\rceil 的二元掩码,其中每个值都是 0 或 1。这些掩码决定了在稀疏注意力机制中跳过哪些计算。

  • 符号解释:
    • MgM_g: 全局稀疏掩码 (Global Sparse Mask),用于指示哪些 QiKjQ_i K_j^\topP~ijVj\widetilde{P}_{ij} V_j 乘积可以被完全跳过。
    • MpvM_{pv}: PV 稀疏掩码 (PV Sparse Mask),用于指示在 MgM_g 不跳过的情况下,哪些 P~ijVj\widetilde{P}_{ij} V_j 乘积可以被进一步跳过。

定义 2 (稀疏 FlashAttention,Sparse FlashAttention): 基于掩码的稀疏 FlashAttention 计算规则定义如下: QiKj,P~ijVjareskippedifMg[i,j]=0.P~ijVjis skippedifMpv[i,j]=0. \begin{array}{r l} & Q_i K_j^\top, \widetilde{P}_{ij} V_j \mathrm{are skipped if} M_g[i, j] = 0. \\ & \qquad \widetilde{P}_{ij} V_j \mathrm{is~skipped if} M_{pv}[i, j] = 0. \end{array}

  • 符号解释:
    • 如果 Mg[i,j]=0M_g[i, j] = 0,则 QiKjQ_i K_j^\topP~ijVj\widetilde{P}_{ij} V_j 都被跳过。
    • 如果 Mpv[i,j]=0M_{pv}[i, j] = 0,则只有 P~ijVj\widetilde{P}_{ij} V_j 被跳过。这意味着 QiKjQ_i K_j^\top 仍然被计算,但其后续的加权求和部分被跳过。

4.2.2. 选择性词元压缩用于稀疏预测 (Selective Token Compression for Sparse Prediction)

这是 SpargeAttn第一阶段在线过滤,目标是快速准确地预测注意力图中的稀疏块,从而跳过 QiKjQ_i K_j^\topP~ijVj\widetilde{P}_{ij} V_j 的部分计算。

关键思想 (Key Idea): 尽管注意力图在不同模型中有所不同 (Figure 2),但作者观察到不同模型的查询(QQ)和键(KK)矩阵中的大多数相邻词元 (neighboring tokens) 之间表现出高度相似性 (Figure 4)。基于此观察,对于由高度相似词元组成的块,可以将这些词元压缩 (consolidate) 为一个单一的代表性词元。

预测算法 (Prediction Algorithm):Figure 3 中的 Step1 所示,该算法通过以下步骤生成稀疏掩码 MgM_g

  1. 压缩查询和键块:

    • 对于 QQ 矩阵的每个块 QiRbq×dQ_i \in \mathbb{R}^{b_q \times d}KK 矩阵的每个块 KjRbk×dK_j \in \mathbb{R}^{b_k \times d}
      • 计算每个块内部词元的平均值 (mean),得到压缩后的单个词元 qiR1×dq_i \in \mathbb{R}^{1 \times d}kjR1×dk_j \in \mathbb{R}^{1 \times d}
      • 计算每个块内部词元的平均余弦相似度 (mean cosine similarity) sqis_{qi}skjs_{kj},这用于衡量块的自相似性。
    • CosSim(X)\mathrm{CosSim}(X) 的定义: CosSim(X) = mean(XXmax(XX)) \begin{array}{r} \mathrm{CosSim}(X) ~ = ~ \mathrm{mean}\big( \frac{X X^\top}{\mid \mathrm{max}(X X^\top) \mid} \big) \end{array}
      • 符号解释:
        • XX: 输入的词元块矩阵(可以是 QiQ_iKjK_j)。
        • XXX X^\top: 块内词元之间的点积矩阵,反映相似性。
        • max(XX)\mid \mathrm{max}(X X^\top) \mid: 块内词元点积的最大绝对值,用于归一化。
        • mean()\mathrm{mean}(\cdot): 对归一化后的相似度矩阵求平均,得到一个标量值,表示块的平均自相似度。
  2. 计算压缩注意力图 P^\hat{P}

    • 使用压缩后的词元 qiq_ikjk_j 计算一个压缩版的相似度矩阵 S^\hat{S}q={qi}={mean(Qi,axis=0)}k={kj}={mean(Kj,axis=0)}sqi=CosSim(Qi),skj=CosSim(Kj)S^[i]=qik \begin{array}{r l} & q = \{ q_i \} = \{ \mathrm{mean}(Q_i, \mathrm{axis}=0) \} \\ & k = \{ k_j \} = \{ \mathrm{mean}(K_j, \mathrm{axis}=0) \} \\ & s_{qi} = \mathrm{CosSim}(Q_i), s_{kj} = \mathrm{CosSim}(K_j) \\ & \hat{S}[i] = q_i k^\top \end{array}
      • 符号解释:
        • q={qi}q = \{q_i\}: 所有 QQ 块的压缩词元集合。
        • k={kj}k = \{k_j\}: 所有 KK 块的压缩词元集合。
        • sqis_{qi}: QiQ_i 块的自相似度。
        • skjs_{kj}: KjK_j 块的自相似度。
        • S^[i]\hat{S}[i]: 压缩后的第 ii 个查询块与所有键块的相似度得分行向量。
  3. 处理非自相似块(Fix Blocks):

    • 为了防止非自相似块(fix blocks,即自相似度低于超参数 θ\theta 的块)的干扰,将这些块对应的 S^\hat{S} 值设置为 -\inftyS^[:,j]=,If skj<θ \hat{S}[:, j] = -\infty, \mathrm{If~} s_{kj} < \theta
      • 符号解释: 如果键块 KjK_j 的自相似度 skjs_{kj} 小于阈值 θ\theta,则所有查询块 qiq_ikjk_j 的压缩相似度 S^[i,j]\hat{S}[i, j] 都被设置为 -\infty。这确保了在后续的 Softmax 计算中,这些非自相似键块不会被错误地赋予低注意力权重。
  4. 应用 Softmax 生成 P^\hat{P}

    • S^[i]\hat{S}[i] 应用 Softmax 函数,得到压缩注意力图 P^\hat{P}P^[i]=Softmax(S^[i]) \hat{P}[i] = \mathrm{Softmax}(\hat{S}[i])
      • 符号解释: P^[i]\hat{P}[i] 是压缩后的注意力分布,表示第 ii 个查询块对所有键块的注意力权重。
  5. 生成稀疏掩码 MgM_g

    • 对于 P^\hat{P} 的每一行 P^[i]\hat{P}[i],使用 TopCdf 函数选择累积和达到超参数 τP^[i]\tau \cdot \sum \hat{P}[i] 的位置,这些位置在 Mg[i,:]M_g[i, :] 中被设置为 1。
      • TopCdf 函数定义:
        def Top_Cdf(P[i], tau):
            sorted_P, idx = torch.sort(P[i], descending = True)
            cusum_P = torch.cumsum(sorted_P, dim != 0)
            mask = cusum_P <= tau * P[i].sum()
            M_i = torch.zeros_like(mask)
            M_i[idx] = mask
            return M_i
        
        • 符号解释:
          • P[i]: 压缩注意力图 P^\hat{P} 的第 ii 行,表示第 ii 个查询块对所有键块的注意力权重。
          • tau: 一个超参数,介于 (0, 1) 之间,表示累积注意力权重的阈值。
          • sortedP,idxsorted_P, idx: P[i] 降序排序后的值和原始索引。
          • cusumPcusum_P: 排序后 P[i] 的累积和。
          • mask: 布尔掩码,指示哪些位置的累积和小于等于 τP[i].sum()\tau \cdot \text{P[i].sum()}
          • MiM_i: 最终的二元掩码行,其原始索引对应位置被设置为 1。
        • 目的: TopCdf 函数旨在保留那些贡献了大部分注意力权重的键块,从而在保证精度的前提下实现稀疏化。
  6. 强制保留非自相似块的计算:

    • 为了确保包含非自相似块(fix blocks)的计算不被跳过,需要修正 MgM_g
    • 如果查询块 QiQ_i 的自相似度 sqis_{qi} 小于阈值 θ\theta,则 Mg[i,:]M_g[i, :] 的所有值被设置为 1,即该行查询块与所有键块的注意力都将被计算。
    • 如果键块 KjK_j 的自相似度 skjs_{kj} 小于阈值 θ\theta,则 Mg[:,j]M_g[:, j] 的所有值被设置为 1,即所有查询块与该列键块的注意力都将被计算。 Mg[i,:]=1, If sqi<θ;Mg[:,j]=1, If skj<θ M_g[i, :] = 1, \mathrm{~If~} s_{qi} < \theta ; \quad M_g[:, j] = 1, \mathrm{~If~} s_{kj} < \theta
      • 目的: 这一步至关重要,它确保了那些难以用单个压缩词元很好地表示的“复杂”块(fix blocks)不会因为压缩预测的近似性而被错误地忽略,从而保证了精度。

4.2.3. 稀疏 Warp 在线 Softmax (Sparse Warp Online Softmax)

这是 SpargeAttn第二阶段在线过滤,旨在进一步在在线 Softmax 过程中识别并跳过那些足够小的注意力值对应的 P~ijVj\widetilde{P}_{ij} V_j 乘积,且不引入额外开销 (no extra overhead)

关键思想 (Key Idea):FlashAttention 的内循环中,如果某个注意力块 P~ij\widetilde{P}_{ij} 的所有值都非常接近零,那么其对最终输出 OijO_{ij} 的贡献 P~ijVj\widetilde{P}_{ij} V_j 将是可忽略的。

原理分析: 回顾 FlashAttentionOijO_{ij} 的更新公式: mlocal=rowmax(Sij), mij=max{mi,j1,mlocal}Oij=diag(exp(mi,j1mij))Oi,j1+P~ijVj \begin{array}{r} m_{\mathrm{local}} = \mathrm{rowmax}(S_{ij}), ~ m_{ij} = \mathrm{max}\{m_{i, j-1}, m_{\mathrm{local}}\} \\ O_{ij} = \mathrm{diag}\left(\mathrm{exp}(m_{i, j-1} - m_{ij})\right) O_{i, j-1} + \widetilde{P}_{ij} V_j \end{array}

  • 符号解释:
    • mlocal=rowmax(Sij)m_{\mathrm{local}} = \mathrm{rowmax}(S_{ij}): 当前 SijS_{ij} 块中每行的最大 logit 值。

    • mij=max{mi,j1,mlocal}m_{ij} = \mathrm{max}\{m_{i, j-1}, m_{\mathrm{local}}\}: 当前迭代累积的最大 logit 值。

    • 其他符号同 4.2.1 节。

      如果 rowmax(Sij)<mij\operatorname{rowmax}(S_{ij}) < m_{ij} 成立,那么 mijm_{ij} 将等于 mi,j1m_{i, j-1}(即当前块的最大值并没有超过历史最大值)。在这种情况下,Oij=Oi,j1+P~ijVjO_{ij} = O_{i, j-1} + \widetilde{P}_{ij} V_j。 进一步地,如果 rowmax(Sij)\operatorname{rowmax}(S_{ij}) 显著小于 mijm_{ij},则 P~ij=exp(Sijmij)\widetilde{P}_{ij} = \exp(S_{ij} - m_{ij}) 中的所有值都将非常接近 0。这意味着 P~ijVj\widetilde{P}_{ij} V_j 的所有值也将接近 0,从而对 OijO_{ij} 的更新贡献可以忽略不计。

这个条件可以表述为: OijOi,j1,if max(exp(Sijmij))0max(exp(Sijmij))0max(mlocalmij)<λ \begin{array}{r} O_{ij} \approx O_{i, j-1}, \quad \mathrm{if~} \operatorname{max}\left( \exp(S_{ij} - m_{ij}) \right) \to 0 \\ \operatorname{max}\left( \exp(S_{ij} - m_{ij}) \right) \to 0 \Leftrightarrow \operatorname{max}(m_{\mathrm{local}} - m_{ij}) < \lambda \end{array}

  • 符号解释:
    • max(exp(Sijmij))0\max(\exp(S_{ij} - m_{ij})) \to 0 时,意味着 P~ij\widetilde{P}_{ij} 矩阵中最大的元素也接近于 0,整个 P~ijVj\widetilde{P}_{ij} V_j 贡献可忽略。
    • 这个条件等价于 max(mlocalmij)<λ\max(m_{\mathrm{local}} - m_{ij}) < \lambda,其中 λ\lambda 是一个负数超参数。当 mlocalm_{\mathrm{local}} 远小于 mijm_{ij} 时,差值 mlocalmijm_{\mathrm{local}} - m_{ij} 是一个较大的负数,指数 exp(mlocalmij)\exp(m_{\mathrm{local}} - m_{ij}) 就会非常小。当这个差值的最大值小于某个负阈值 λ\lambda 时,就认为可以跳过计算。

稀疏化机制 (Sparsification Mechanism): 基于上述分析,本文提出了一个简单而有效的稀疏方法来跳过 P~ijVj\widetilde{P}_{ij} V_j 的计算。在 FlashAttention 的内循环中,S_ij 会被分割给 cwc_wGPU Warps 进行并行处理。具体而言,对于第 iwi_wwarp 处理的 Sij[Iw]S_{ij}[I_w] 部分 (其中 Iw=[iwbqcw:(iw+1)bqcw]I_w = [\frac{i_w \cdot b_q}{c_w} : \frac{(i_w+1) \cdot b_q}{c_w}] 是该 warp 负责的行索引范围):

  • 如果 max(mlocal[Iw]mij[Iw])<λ\operatorname{max}(m_{\mathrm{local}}[I_w] - m_{ij}[I_w]) < \lambda 成立,则认为这部分 P~ij[Iw]Vj\widetilde{P}_{ij}[I_w] V_j 的贡献可以忽略。
  • 此时,对应的 Oij[Iw]Oi,j1[Iw]O_{ij}[I_w] \approx O_{i, j-1}[I_w],该 warp 将跳过 P~ij[Iw]Vj\widetilde{P}_{ij}[I_w] V_j 的计算和对 Oij[Iw]O_{ij}[I_w] 的更新。

4.2.4. 与 SageAttention 结合 (Combined with SageAttention)

为了进一步加速,SpargeAttn 将其方法集成到 SageAttention (Zhang et al., 2025a;d;g;b;e) 中。

  • 正交性: 量化 (quantization) 操作和稀疏 (sparse) 操作是正交的,因此稀疏计算可以直接应用于 SageAttentionSageAttention 是一种量化方法,通过将 QQ, KK 矩阵量化为 8 位整数(或更低精度)来加速计算。
  • 集成方式:
    • SageAttention 内循环的开始(对应于 Algorithm 1 中的第 10 行),添加一个判断,根据第一阶段生成的掩码 Mg[i,j]M_g[i, j] 来决定是否跳过整个内循环块的计算。
    • SageAttention 内循环中更新 OijO_{ij} 之前(对应于 Algorithm 1 中的第 15 行),添加另一个判断,根据第二阶段的稀疏 Warp Online Softmax 原理来决定是否跳过 P~ij[Iw]Vj\widetilde{P}_{ij}[I_w] V_j 的计算。
  • 预测开销优化: 为了最小化注意力图预测的开销,预测过程在 CUDA 中实现,并采用了内核融合 (kernel fusion) 技术。

算法 1: SpargeAttn 的实现 (Implementation of SpargeAttn)

Algorithm 1 Implementation of SpargeAttn.

1: Input: Matrices Q(FP16), K(FP16), V(FP16) ∈ ℝ^(N×d), block size b_q, b_kv, count of GPU Warps c_w, hyper-parameters τ, θ, and λ
2: Divide Q to T_m = N / b_q blocks {Q_i}; divide K, V to T_n = N / b_kv blocks {K_i} and {V_i}.
3: Q̂_i, K̂_j, δ_Q, δ_K = Quant(Q_i, K_j); // per-block quantization of SageAttention.
4: q = {q_i} = {mean(Q_i, axis=0)}; k = {k_j} = {mean(K_j, axis=0)};
5: Ŝ = qkᵀ; s_qi = CosSim(Q_i); s_kj = CosSim(K_j); Ŝ[:, j]} = −∞, If s_kj < θ;
6: P̂[i] = Softmax(Ŝ[i]); M[i, :] = TopCdf(P̂[i], τ); M[i, :] = 1, If s_qi < θ; M[:, j] = 1, If s_kj < θ;
7: for i = 1 to T_m do
8:   Load Q̂_i and δ_Q[i] into a SM;
9:   for j in [1, T_n] do
10:     if M[i, j] != 0 then // First stage filter: check if block (i, j) needs to be computed
11:       Load K̂_j, V̂_j, and δ_K[j] into the SM;
12:       S_ij = Matmul(Q̂_i, K̂_jᵀ) × δ_Q × δ_K; // Dequantization of SageAttention.
13:       m_local = rowmax(S_ij); m_ij = max(m_i, j-1, m_local); P̃_ij = exp(S_ij - m_ij); l_ij = e^(m_i, j-1 - m_ij) l_i, j-1 + rowsum(P̃_ij);
14:       i_w = range(c_w); I_w = [(i_w * b_q) / c_w : ((i_w + 1) * b_q) / c_w];
15:       if max(m_local[I_w] - m_ij[I_w]) > λ then // Second stage filter: check if this warp's P̃_ij V_j needs to be computed
16:         O_ij[I_w] = diag(e^(m_i, j-1[I_w] - m_ij[I_w])) O_i, j-1[I_w] + Matmul(P̃_ij[I_w], V_j); // Paralleled by c_w warps.
17:       end if
18:     end if
19:   end for
20:   O_i = diag(l_i, T_n)⁻¹ O_i, T_n;
21:   Write O_i;
22: end for
23: return O = {O_i};
  • 算法逐行解释:
    • 行 1: 输入 (Input): 接收 QQ, KK, VV 矩阵 (FP16 精度),块大小 bq,bkvb_q, b_{kv},GPU Warps 数量 cwc_w,以及超参数 τ,θ,λ\tau, \theta, \lambda
    • 行 2: 分块 (Block Division):QQ 矩阵分成 TmT_m 个块 {Qi}\{Q_i\},将 KKVV 矩阵分成 TnT_n 个块 {Ki}\{K_i\}{Vi}\{V_i\}
    • 行 3: 量化 (Quantization): 使用 SageAttention 的逐块量化方法,将 Qi,KjQ_i, K_j 量化为 Q^i,K^j\hat{Q}_i, \hat{K}_j,并记录量化比例因子 δQ,δK\delta_Q, \delta_K
    • 行 4: 计算压缩词元 (Compressed Tokens): 对每个 QiQ_iKjK_j 块,计算其内部词元的平均值,得到压缩词元 qiq_ikjk_j
    • 行 5: 计算压缩相似度与自相似度 (Compressed Similarity & Self-Similarity):
      • 计算压缩词元 qqkk 的点积 S^=qk\hat{S} = qk^\top
      • 计算每个 QiQ_i 块的自相似度 sqi=CosSim(Qi)s_{qi} = \mathrm{CosSim}(Q_i)KjK_j 块的自相似度 skj=CosSim(Kj)s_{kj} = \mathrm{CosSim}(K_j)
      • 如果 KjK_j 的自相似度 skjs_{kj} 小于阈值 θ\theta,则将 S^\hat{S} 中对应列 jj 的所有值设置为 -\infty,以在 Softmax 归一化后忽略这些非自相似键块的贡献。
    • 行 6: 生成稀疏掩码 MgM_g (Generate Sparse Mask MgM_g):
      • S^[i]\hat{S}[i] 应用 Softmax 函数得到压缩注意力图 P^[i]\hat{P}[i]
      • 使用 TopCdf 函数根据 P^[i]\hat{P}[i]τ\tau 生成初步的稀疏掩码行 M[i, :]
      • 自相似度修正 (Self-Similarity Correction): 如果 QiQ_i 的自相似度 sqis_{qi} 小于 θ\theta,则 M[i, :] 的所有值被强制设置为 1。如果 KjK_j 的自相似度 skjs_{kj} 小于 θ\theta,则 M[:, j] 的所有值被强制设置为 1。这确保了非自相似块的计算不被跳过。
    • 行 7-22: FlashAttention 的主循环 (Main FlashAttention Loop):
      • 行 7: 遍历所有查询块 QiQ_i
      • 行 8: 将量化后的 QiQ_iQ^i\hat{Q}_i 及其比例因子 δQ[i]\delta_Q[i] 加载到 GPU 的共享内存 (SM) 中。
      • 行 9: 遍历所有键值块 Kj,VjK_j, V_j
      • 行 10: 第一阶段过滤 (First Stage Filter): 检查掩码 M[i, j]。如果为 0,表示该块的计算可以完全跳过,直接进入下一次内循环迭代。
      • 行 11: 如果 M[i,j]0M[i, j] \ne 0,则将量化后的 KjK_jK^j,V^j\hat{K}_j, \hat{V}_j 及其比例因子 δK[j]\delta_K[j] 加载到共享内存 (SM) 中。
      • 行 12: 计算相似度矩阵 SijS_{ij} (Compute SijS_{ij}): 计算 Q^i\hat{Q}_iK^j\hat{K}_j^\top 的矩阵乘法,并乘以量化比例因子 δQ×δK\delta_Q \times \delta_K 进行反量化 (dequantization),得到 SijS_{ij}
      • 行 13: 在线 Softmax 核心计算 (Online Softmax Core): 计算局部最大值 mlocalm_{\mathrm{local}},更新全局最大值 mijm_{ij},计算部分归一化注意力权重 P~ij\widetilde{P}_{ij},并更新 softmax 分母 lijl_{ij}
      • 行 14: Warp 分割 (Warp Division): 将当前 QiQ_i 块中的行索引范围 bqb_q 分割给 cwc_wGPU Warps,每个 warp 负责 IwI_w 范围的行。
      • 行 15: 第二阶段过滤 (Second Stage Filter): 对于每个 warp,检查它负责的行范围内 max(mlocal[Iw]mij[Iw])\operatorname{max}(m_{\mathrm{local}}[I_w] - m_{ij}[I_w]) 是否大于阈值 λ\lambda。如果大于 λ\lambda,表示这部分注意力贡献显著,需要计算。
      • 行 16: 如果条件满足,则计算 P~ij[Iw]Vj\widetilde{P}_{ij}[I_w] V_j 并更新 Oij[Iw]O_{ij}[I_w]。此操作由 cwc_wwarps 并行执行。如果条件不满足(即 max()λ\operatorname{max}(\dots) \le \lambda),则跳过这部分 PV 乘积的计算。
      • 行 20: 在处理完所有键值块后,对 OiO_i 进行最终的归一化,得到最终输出 OiO_i
      • 行 21:OiO_i 写回全局内存。
    • 行 23: 返回 (Return): 返回所有输出块组成的最终输出 O={Oi}O = \{O_i\}

4.2.5. 超参数确定 (Hyper-parameters Determination for Model Layer)

SpargeAttn 涉及三个超参数:

  • τ(0,1)\tau \in (0, 1): 用于 TopCdf 函数,控制第一阶段的稀疏度,表示保留累积注意力权重的比例。
  • θ(1,1)\theta \in (-1, 1): 用于自相似度判断,控制哪些块被视为“非自相似块”而强制计算。
  • λ<0\lambda < 0: 用于稀疏 Warp Online Softmax,控制第二阶段的稀疏度,表示局部最大值与全局最大值的差值阈值。

确定过程: 超参数的确定是针对模型中每个注意力层 (each attention layer) 进行的,以确保最大化稀疏性的同时保持注意力精度。

  1. 精度评估指标: 使用严格的相对 L1 距离 (Relative L1 distance) 作为注意力误差指标,定义为 L1=OO/OL1 = \sum |O - O'| / \sum |O|。其中 OO 是全注意力 (Full-Attention) 的输出,OO' 是稀疏注意力的输出。
  2. 两阶段网格搜索 (Two-stage Grid Search):
    • 第一阶段 (τ,θ\tau, \theta): 首先对 τ\tauθ\theta 进行网格搜索,目标是找到最大化稀疏性的最优对,同时要求 L1<l1L1 < l_1。其中 l1l_1 是一个预设的 L1 误差阈值(例如 l1=0.05l_1 = 0.05)。
    • 第二阶段 (λ\lambda): 在确定了 τ\tauθ\theta 后,再对 λ\lambda 进行网格搜索,以进一步最大化稀疏性,同时要求 L1<l2L1 < l_2。其中 l2l_2 是另一个预设的 L1 误差阈值(例如 l2=0.06l_2 = 0.06),通常略高于 l1l_1,允许一定的额外误差以换取更高稀疏性。 该过程在五个不同模型输入 (five different model inputs) 上进行评估,以确保超参数的鲁棒性。

4.2.6. 希尔伯特曲线置换 (HilbertCurve Permutation)

关键思想 (Key Idea): 为了在保证准确性的前提下提高稀疏性,一个核心挑战是增加查询和键块的自相似度 (self-similarity)。自相似度越高,被识别为“固定块 (fix blocks)”(即需要强制计算的非自相似块)的数量就越少,从而更多的块可以参与 TopCdf 选择,提高整体稀疏性。由于注意力计算对词元置换是不变的 (computationally invariant to token permutations),问题就转化为找到一个能增强相邻词元相似性的置换方法。

方法 (Method):

  • 针对图像和视频模型,这些模型具有强大的先验知识 (priors):相邻像素或帧通常是相似的。

  • 本文提出了希尔伯特曲线置换 (HilbertCurve Permutation)。对于 3D 视觉词元 ΛˉQ,K,VRT×H×W×d\mathbf{\bar{\Lambda}} Q, K, V \in \mathbb{R}^{T \times H \times W \times d} (其中 TT 是时间,HH 是高度,WW 是宽度,dd 是维度),使用 HilbertCurve 填充 3D 空间,然后沿着曲线将词元展平为形状 RL×d\mathbb{R}^{L \times d} 的 1D 序列,其中 L=T×H×WL = T \times H \times W

  • 优势: HilbertCurve 有效地保留了局部性 (preserves locality),它在遍历整个 3D 空间时不会跨越行或列,从而增加了相邻词元之间的相似性,进而提高了注意力的稀疏性。

    下图 Figure 5 展示了 1x6x6 空间中 token 置换的例子。

    Figure 5. Illustration of different token permutation methods in \(1 \\times 6 \\times 6\) space, with block size of 4. 该图像是示意图,展示了在 1 imes 6 imes 6 空间中不同的 token 置换方法,块大小为 4。图中标示了视觉 tokens 的排列方式,以及块内和块间的相似性。箭头表示重新排列的路径,颜色深浅显示不同的区块性质。

Figure 5. Illustration of different token permutation methods in 1×6×61 \times 6 \times 6 space, with block size of 4.

  • 图 5 解释: 该图展示了 1x6x6 视觉词元在两种不同置换方式下的排列。左侧为行主序 (Rowmajor Order),词元按行优先顺序排列。右侧为希尔伯特曲线 (HilbertCurve) 置换,词元沿着希尔伯特曲线进行排列。图中的颜色深浅可能表示词元之间的相似性或某种属性,而箭头则指示了词元的连接顺序。希尔伯特曲线置换的优势在于,它能更好地将空间上接近的词元在序列中也保持接近,从而在视觉数据中增强相邻词元的相似度,这有利于稀疏注意力的性能。

5. 实验设置

5.1. 数据集

实验在语言、图像和视频生成等多样化模型上进行验证。

  • 语言模型 (Llama3.1 8B):
    • WikiText (Merity et al., 2017): 用于评估模型的困惑度 (perplexity, PPL),衡量模型对文本序列的预测置信度。
    • Longbench (Bai et al., 2024) 和 En.MC of InfiniteBench (Zhang et al., 2024): 用于全面评估模型的长上下文理解能力。
    • Needle-in-a-Haystack (Kamradt, 2023): 用于评估模型的检索能力,即在超长文本中找出关键信息的能力。
  • 文本到视频模型:
    • CogvideoX (2B), Mochi (Team, 2024), Open-Sora-Plan (Lin et al., 2024): 评估这些模型在 open-sora (Zheng et al., 2024c) 提示集上的视频生成质量。
  • 文本到图像模型:
    • Flux (.1-dev) (Black Forest Labs, 2023), Stable-Diffusion3.5 (large) (Stability AI, 2023): 评估这些模型在 COCO 标注 (Lin et al., 2014) 上的图像生成质量。

5.2. 评估指标

对论文中使用的每个评估指标进行详细说明。

  • 速度 (Speed): 1/t1/t

    • 概念定义: 衡量注意力计算效率的指标,表示单位时间内完成的注意力操作量。其中 O(attn)O(attn) 是标准注意力计算的总操作数,tt 是从给定 (Q, K, V) 到注意力输出的延迟时间(秒)。该指标能够公平地衡量包括稀疏区域预测时间在内的总时间开销。
    • 数学公式: Speed=O(attn)t \text{Speed} = \frac{O(\text{attn})}{t}
    • 符号解释:
      • O(attn)O(\text{attn}): 表示在标准(全)注意力计算中总的操作数。
      • tt: 表示从输入 Q, K, V 到注意力输出的实际延迟时间,以秒为单位。
  • 稀疏性 (Sparsity):

    • 概念定义: 稀疏性定义为被跳过的矩阵乘法 QiKjQ_i K_j^\topP~ijVj\widetilde{P}_{ij} V_j 的比例,相对于在全注意力计算中所需的总乘法量。更高的稀疏性意味着更多的计算被跳过,从而可能带来更大的加速。
    • 数学公式: Sparsity=Number of Skipped (QiKj+P~ijVj) operationsTotal Number of (QiKj+P~ijVj) operations for Full Attention \text{Sparsity} = \frac{\text{Number of Skipped } (Q_i K_j^\top + \widetilde{P}_{ij} V_j) \text{ operations}}{\text{Total Number of } (Q_i K_j^\top + \widetilde{P}_{ij} V_j) \text{ operations for Full Attention}}
    • 符号解释:
      • Number of Skipped (QiKj+P~ijVj) operations\text{Number of Skipped } (Q_i K_j^\top + \widetilde{P}_{ij} V_j) \text{ operations}: 通过稀疏机制被跳过的 QiKjQ_i K_j^\top 乘法和 P~ijVj\widetilde{P}_{ij} V_j 乘法的总数量。
      • Total Number of (QiKj+P~ijVj) operations for Full Attention\text{Total Number of } (Q_i K_j^\top + \widetilde{P}_{ij} V_j) \text{ operations for Full Attention}: 在不进行任何稀疏化的情况下,全注意力计算所需的 QiKjQ_i K_j^\top 乘法和 P~ijVj\widetilde{P}_{ij} V_j 乘法的总数量。
  • 困惑度 (Perplexity, PPL):

    • 概念定义: 在语言建模任务中,困惑度是衡量一个概率模型对样本或测试集预测好坏的指标。PPL 越低,表示模型预测的下一个词元越准确,模型的语言生成质量越高。它通常被定义为每个词元平均分支因子的几何平均。
    • 数学公式: PPL(W)=exp(1Ni=1NlogP(wiw1,,wi1)) \text{PPL}(W) = \exp\left(-\frac{1}{N} \sum_{i=1}^N \log P(w_i | w_1, \dots, w_{i-1})\right) 或者,等价于: PPL(W)=(i=1N1P(wiw1,,wi1))1N \text{PPL}(W) = \left( \prod_{i=1}^N \frac{1}{P(w_i | w_1, \dots, w_{i-1})} \right)^{\frac{1}{N}}
    • 符号解释:
      • W=w1,w2,,wNW = w_1, w_2, \dots, w_N: 一个包含 NN 个词元的序列或文本。
      • NN: 序列中词元的总数量。
      • P(wiw1,,wi1)P(w_i | w_1, \dots, w_{i-1}): 语言模型在给定前 i-1 个词元的情况下,预测第 ii 个词元的概率。
  • Longbench score / InfiniteBench:

    • 概念定义: LongbenchInfiniteBench 是专门设计用于评估大型语言模型 (LLMs) 长上下文理解能力的基准测试。它们包含一系列任务,要求模型处理和推理超长输入序列。分数通常是这些任务的平均性能指标,值越高表示长上下文理解能力越强。
  • 检索准确率 (Retrieval Accuracy, NIAH):

    • 概念定义:Needle-in-a-Haystack (大海捞针) 任务中,模型被要求从一个很长的、通常不相关的文本(干草堆)中检索出一个特定的、嵌入在其中的信息(针)。检索准确率衡量模型能否成功地找出这根“针”,即模型能否在长上下文环境中进行精确的局部信息检索。准确率通常表示为正确检索到的样本数占总样本数的比例。
    • 数学公式: Accuracy=Number of Correct RetrievalsTotal Number of Samples \text{Accuracy} = \frac{\text{Number of Correct Retrievals}}{\text{Total Number of Samples}}
    • 符号解释:
      • Number of Correct Retrievals\text{Number of Correct Retrievals}: 模型在测试集中正确检索出“针”的次数。
      • Total Number of Samples\text{Total Number of Samples}: 测试集中的总样本数。
  • CLIPSIM (CLIP Similarity):

    • 概念定义: 在文本到视频或文本到图像生成任务中,CLIPSIM 衡量生成内容(视频或图像)与输入文本提示之间的语义相似度。它通常通过使用预训练的 CLIP (Contrastive Language-Image Pre-training) 模型提取文本和视觉内容的特征,然后计算这些特征向量之间的余弦相似度来得到。值越高表示生成内容与文本描述越匹配。
    • 数学公式: CLIPSIM=1Tt=1Tcosine_similarity(CLIP_text_feature(P),CLIP_image_feature(Vt))\text{CLIPSIM} = \frac{1}{T} \sum_{t=1}^T \text{cosine\_similarity}(\text{CLIP\_text\_feature}(P), \text{CLIP\_image\_feature}(V_t)) 对于图像生成,T=1,即: CLIPSIM(I,P)=cosine_similarity(CLIP_image_feature(I),CLIP_text_feature(P))\text{CLIPSIM}(I, P) = \text{cosine\_similarity}(\text{CLIP\_image\_feature}(I), \text{CLIP\_text\_feature}(P))
    • 符号解释:
      • PP: 文本提示 (Prompt)。
      • VtV_t: 视频的第 tt 帧。
      • TT: 视频的总帧数。
      • II: 生成的图像。
      • CLIP_text_feature()\text{CLIP\_text\_feature}(\cdot): CLIP 模型提取文本特征的函数。
      • CLIP_image_feature()\text{CLIP\_image\_feature}(\cdot): CLIP 模型提取图像特征的函数。
      • cosine_similarity(a,b)=abab\text{cosine\_similarity}(\mathbf{a}, \mathbf{b}) = \frac{\mathbf{a} \cdot \mathbf{b}}{\Vert \mathbf{a} \Vert \Vert \mathbf{b} \Vert}: 两个向量的余弦相似度。
  • CLIP-T (CLIP Temporal):

    • 概念定义: 专门用于视频生成,衡量视频帧之间的时间一致性或平滑度。它通过计算视频中相邻帧的 CLIP 图像特征之间的余弦相似度来评估。值越高表示视频在时间上的连贯性越好。
    • 数学公式: CLIP-T=1T1t=1T1cosine_similarity(CLIP_image_feature(Vt),CLIP_image_feature(Vt+1))\text{CLIP-T} = \frac{1}{T-1} \sum_{t=1}^{T-1} \text{cosine\_similarity}(\text{CLIP\_image\_feature}(V_t), \text{CLIP\_image\_feature}(V_{t+1}))
    • 符号解释:
      • VtV_t: 视频的第 tt 帧。
      • TT: 视频的总帧数。
      • CLIP_image_feature()\text{CLIP\_image\_feature}(\cdot): CLIP 模型提取图像特征的函数。
      • cosine_similarity(,)\text{cosine\_similarity}(\cdot, \cdot): 余弦相似度函数。
  • VQA-a (Video Quality Assessment - aesthetic):

    • 概念定义: 视频美学质量评估,衡量视频在视觉上是否具有美感、吸引力。通常通过人工标注或预训练的美学质量预测模型来获得分数。值越高表示美学质量越好。没有统一的数学公式,通常是模型输出的评分。
  • VQA-t (Video Quality Assessment - technical):

    • 概念定义: 视频技术质量评估,衡量视频在技术层面(如清晰度、稳定性、无伪影)的质量。同样通过人工标注或技术质量预测模型来获得分数。值越高表示技术质量越好。没有统一的数学公式,通常是模型输出的评分。
  • Flow-score (FScore):

    • 概念定义: 衡量视频时间一致性,特别关注视频中运动的平滑性和真实感。它通常基于光流 (optical flow) 估计,通过量化帧间运动向量的连贯性来评估。值越高表示时间一致性越好。
    • 数学公式: Flow-score 的具体计算可能因不同的光流算法和聚合方式而异,论文未给出具体公式。一般而言,它会涉及计算帧间的运动向量场,并评估这些向量场的平滑度或误差。
    • 符号解释: FScore 是一个复合指标,用于量化视频中运动的流畅性和合理性,分数越高表示时间连贯性越好。
  • FID (Frechet Inception Distance):

    • 概念定义: 在图像生成任务中,FID 是一种衡量生成图像质量和多样性的指标。它通过比较生成图像和真实图像在预训练 Inception-v3 模型中间层的特征分布来计算。FID 值越低,表示生成图像的质量越高,与真实图像的分布越接近。
    • 数学公式: FID=μrμg22+Tr(Σr+Σg2(ΣrΣg)1/2) \text{FID} = \Vert \mu_r - \mu_g \Vert^2_2 + \text{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2})
    • 符号解释:
      • μr\mu_r: 真实图像在 Inception-v3 特征空间中的均值向量。
      • μg\mu_g: 生成图像在 Inception-v3 特征空间中的均值向量。
      • Σr\Sigma_r: 真实图像特征的协方差矩阵。
      • Σg\Sigma_g: 生成图像特征的协方差矩阵。
      • 22\Vert \cdot \Vert^2_2: 向量的 L2 范数的平方。
      • Tr()\text{Tr}(\cdot): 矩阵的迹(对角线元素之和)。
      • (ΣrΣg)1/2(\Sigma_r \Sigma_g)^{1/2}: 矩阵乘积的平方根。
  • ImageReward (IR):

    • 概念定义: 在文本到图像生成中,ImageReward 是一个旨在衡量生成图像符合人类偏好程度的指标。它通过一个在人类反馈数据上训练的模型来评估图像的“好坏”,从而量化生成图像对人类审美的吸引力。值越高表示图像更符合人类偏好。没有统一的数学公式,通常是模型输出的评分。

5.3. 对比基线

论文将 SpargeAttn 与以下方法进行了比较:

  • Full-Attention: 即标准的、未进行稀疏化的全注意力机制,作为性能和质量的基准。
  • MInference (Jiang et al., 2024): 一种动态稀疏方法,通过块稀疏化来加速长上下文 LLM 的预填充。论文使用其 30% 和 70% 的稀疏度设置进行对比。
  • FlexPrefill (Lai et al., 2025): 另一种上下文感知的稀疏注意力机制,用于高效的长序列推理。论文使用其论文中推荐的 γ=0.95\gamma = 0.95γ=0.99\gamma = 0.99 参数设置进行对比,这会产生不同的稀疏度。

实现和超参数:

  • SpargeAttn 使用 CUDA 实现。
  • 超参数 l1,l2l_1, l_2 用于控制 L1 误差阈值,用于指导超参数 τ,θ,λ\tau, \theta, \lambda 的确定:
    • Llama3.1: (l1=0.08,l2=0.09)(l_1 = 0.08, l_2 = 0.09)
    • CogvideoX, Mochi: (l1=0.05,l2=0.06)(l_1 = 0.05, l_2 = 0.06)
    • Stable-Diffusion3.5, Flux: (l1=0.07,l2=0.08)(l_1 = 0.07, l_2 = 0.08)
    • Open-Sora-Plan: (l1=0.03,l2=0.035)(l_1 = 0.03, l_2 = 0.035)

6. 实验结果与分析

6.1. 核心结果分析

6.1.1. 端到端指标和速度

以下是原文 Table 1 的结果,展示了 SpargeAttn 在文本、图像和视频生成模型上的端到端性能和加速效果。

Model (seq_len) Attention (Sparsity) Speed (1/t)↑ Llama3.1 (128K)
WikiText (Ppl.) ↓ Longbench ↑ InfiniteBench ↑ NIAH↑
Llama3.1(128K) Full-Attention 156.9 6.013 38.682 0.6594 0.907
Minference (0.5) 140.1 10.631 28.860 0.5152 0.832
FlexPrefill (0.5) 240.6 6.476 38.334 0.6460 0.858
Minference (0.3) 115.7 6.705 34.074 0.6532 0.870
FlexPrefill (0.42) 206.9 6.067 38.334 0.6581 0.878
SpargeAttn (0.54) 708.1 6.020 39.058 0.6638 0.909
  • 分析 Llama3.1 (128K):

    • SpargeAttn (稀疏度 0.54) 在 Llama3.1 上实现了高达 708.1Speed (1/t),远超 Full-Attention (156.9) 和所有基线方法。

    • 同时,SpargeAttnWikiText (Ppl.) (6.020)、Longbench (39.058)、InfiniteBench (0.6638) 和 NIAH (0.909) 指标都与 Full-Attention 几乎持平甚至略有提升,表明在大幅加速的同时没有牺牲模型质量。

    • 相比之下,MInferenceFlexPrefill 在实现一定稀疏度时,往往导致 PPL 显著升高或 NIAH 下降,表明其在长上下文任务中对模型性能有负面影响。SpargeAttn 的优势在于其在保持高精度的同时实现了高加速。

      Model (seq_len) Attention (Sparsity) Speed (1/t)↑ CogvideoX (17K)
      CLIPSIM ↑ CLIP-T ↑ VQA-a ↑ VQA-t ↑ FScore ↑
      CogvideoX(17K) Full-Attention 166.0 0.1819 0.9976 80.384 75.946 5.342
      Minference (0.5) 264.6 0.1728 0.9959 70.486 62.410 2.808
      FlexPrefill (0.6) 175.3 0.1523 0.9926 1.5171 4.5034 1.652
      Minference (0.3) 196.9 0.1754 0.9964 77.326 63.525 3.742
      FlexPrefill (0.45) 142.0 0.1564 0.9917 7.7259 8.8426 2.089
      SpargeAttn (0.46) 507.9 0.1798 0.9974 78.276 74.846 5.030
  • 分析 CogvideoX (17K):

    • SpargeAttn (稀疏度 0.46) 在 CogvideoX 上实现了 507.9Speed (1/t),显著高于 Full-Attention (166.0) 和所有基线。

    • CLIPSIM (0.1798)、CLIP-T (0.9974)、VQA-a (78.276)、VQA-t (74.846) 和 FScore (5.030) 等视频质量指标均与 Full-Attention 接近,而基线方法 FlexPrefill 在高稀疏度时 (FlexPrefill (0.6)) 甚至导致 VQA-aVQA-tFScore 严重下降,表明生成视频质量大幅劣化。

      Model (seq_len) Attention (Sparsity) Speed (1/t)↑ Mochi (22K)
      CLIPSIM ↑ CLIP-T ↑ VQA-a ↑ VQA-t ↑ FScore ↑
      Mochi(22K) Full-Attention 164.2 0.1725 0.9990 56.472 67.663 1.681
      Minference (0.5) 202.4 0.1629 0.9891 6.668 50.839 0.653
      FlexPrefill (0.48) 191.3 0.1667 0.9898 0.582 0.0043 X
      Minference (0.3) 147.7 0.1682 0.9889 14.541 42.956 0.833
      FlexPrefill (0.4) 171.7 0.1677 0.9909 2.941 0.7413 X
      SpargeAttn (0.47) 582.4 0.1720 0.9990 54.179 67.219 1.807
  • 分析 Mochi (22K):

    • SpargeAttn (稀疏度 0.47) 在 Mochi 上表现出惊人的 582.4 Speed (1/t),是 Full-Attention (164.2) 的约 3.5 倍,且远超所有基线。

    • 视频质量指标方面,SpargeAttnCLIPSIM (0.1720)、CLIP-T (0.9990)、VQA-a (54.179)、VQA-t (67.219) 和 FScore (1.807) 均与 Full-Attention 保持一致。

    • MInferenceFlexPrefillMochi 上表现不佳,特别是 FlexPrefillFScore 无法计算 (XX),且 VQA-aVQA-t 严重受损,表明其无法维持生成质量。

      Model (seq_len) Attention (Sparsity) CLIPSIM ↑ CLIP-T ↑ VQA-a ↑ VQA-t ↑ FScore ↑ Latency ↓
      Open-Sora-Plan (38K) Full-Attention 0.1650 0.9994 81.40 80.60 0.847 629s
      SpargeAttn (0.34) 0.1686 0.9985 77.59 76.91 0.839 393s
  • 分析 Open-Sora-Plan (38K):

    • SpargeAttn (稀疏度 0.34) 将 Open-Sora-Plan 的延迟从 629s 降低到 393s,实现了显著加速。

    • 视频质量指标 CLIPSIM, CLIP-T, VQA-a, VQA-t, FScore 均保持在与 Full-Attention 相近的水平。

      Model (seq_len) Attention (Sparsity) Speed (1/t)↑ Flux (4.5K)
      FID ↓ CLIP ↑ IR ↑
      Flux (4.5K) Full-Attention 158.2 166.103 31.217 0.8701
      Minference (0.5) 151.8 180.650 30.235 0.4084
      FlexPrefill (0.48) 47.7 443.928 18.3377 -2.2657
      Minference (0.3) 118.9 170.221 31.001 0.7701
      FlexPrefill (0.41) 40.9 405.043 19.5591 -2.2362
      SpargeAttn (0.38) 280.3 163.982 31.448 0.9207
  • 分析 Flux (4.5K):

    • SpargeAttn (稀疏度 0.38) 达到 280.3 Speed (1/t),显著高于 Full-Attention (158.2)。

    • 图像质量指标 FID (163.982)、CLIP (31.448)、IR (0.9207) 均优于 Full-Attention,表明 SpargeAttn 在某些情况下甚至能带来轻微的质量提升。

    • 基线方法在 Flux 上表现非常差,FlexPrefillFID 值高达 400+,IR 甚至为负值,质量严重受损。

      Model (seq_len) Attention (Sparsity) Speed (1/t)↑ Stable-Diffusion3.5 (4.5K)
      FID ↓ CLIP ↑ IR ↑
      Stable-Diffusion3.5 (4.5K) Full-Attention 164.2 166.101 32.007 0.9699
      Minference (0.5) 186.4 348.930 18.3024 -2.2678
      FlexPrefill (0.37) 23.1 350.497 18.447 -2.2774
      Minference (0.3) 150.3 337.530 18.099
      FlexPrefill (0.35) 22.7 348.612 18.147 -2.2647
      SpargeAttn (0.31) 293.0 166.193 32.114 0.9727
  • 分析 Stable-Diffusion3.5 (4.5K):

    • SpargeAttn (稀疏度 0.31) 达到 293.0 Speed (1/t),显著高于 Full-Attention (164.2)。

    • 图像质量指标 FID (166.193)、CLIP (32.114)、IR (0.9727) 均与 Full-Attention 相当或略优。

    • Flux 类似,基线方法在此模型上表现极差,FID 显著升高,CLIPIR 严重下降。

      总结: 综合 Table 1SpargeAttn 在各种模型和任务中,均能在保持甚至略微提升端到端性能指标的同时,显著提高注意力计算速度,表现出其通用性 (universality)高效性 (efficiency)。现有稀疏注意力基线方法通常难以在通用模型上兼顾性能和效率。

内核速度比较 (Kernel Speed Comparison): 下图 Figure 10 比较了不同方法在不同稀疏度下的内核速度。

Figure 10. Kernel speed comparison under varying sparsity. Input tensors have a sequence length of 22K and a head dimension of 128. SpargeAttn \(+ F A 2\) means deploying our method on FlashAttention2.
该图像是图表,展示了在不同稀疏度下的内核速度比较。图中展示了包括 SpargeAttn+FA2、Sage 和 Sage2 等不同方法的速度表现,横轴为稀疏度,纵轴为速度(1/t)。

Figure 10. Kernel speed comparison under varying sparsity. Input tensors have a sequence length of 22K and a head dimension of 128. SpargeAttn +FA2+ F A 2 means deploying our method on FlashAttention2.

  • 图 10 分析:
    • 该图展示了在序列长度为 22K、头维度为 128 的输入下,不同稀疏注意力实现(Sage, Sage2, SpargeAttn+FA2SpargeAttn+FA2, Full Attention (FA2))的内核速度(Speed (1/t))随稀疏度变化的趋势。
    • Full Attention (FA2) 的速度保持不变,因为它没有稀疏化。
    • SpargeAttn+FA2SpargeAttn+FA2 在所有稀疏度下都表现出最高的内核速度,且随着稀疏度增加,速度提升更为显著。这表明 SpargeAttn 的稀疏化机制非常高效,并且能很好地与 FlashAttention2 结合。
    • SageSage2 作为量化注意力方法,也比 Full Attention (FA2) 快,但其速度提升不如 SpargeAttn+FA2SpargeAttn+FA2 明显,且随着稀疏度增加,SpargeAttn+FA2SpargeAttn+FA2 的优势进一步扩大。
    • 这验证了 SpargeAttn 提出的稀疏方法在底层内核计算上的效率。

端到端加速 (End-to-end Speedup): 以下是原文 Table 2 的结果,展示了 SpargeAttn 在几个模型上的端到端推理延迟。

Model GPU Original SageAttn SpargeAttn
CogvideoX RTX4090 87 s 68 s 53 s
Mochi L40 1897 s 1544 s 1037 s
Llama3.1 (24K) RTX4090 4.01 s 3.53 s 2.6 s
Llama3.1(128K) L40 52 s 42s 29.98 s
  • 分析 Table 2:
    • SpargeAttn 显著降低了所有测试模型的端到端生成延迟。
    • Mochi 模型上,SpargeAttn 将延迟从 Original (可能是 Full-Attention 或未优化基线) 的 1897 秒降低到 1037 秒,实现了约 1.83x 的加速 (189710371.83\frac{1897}{1037} \approx 1.83)。
    • Llama3.1 (128K 序列长度) 上,延迟从 52 秒降低到 29.98 秒,实现了约 1.73x 的加速。
    • 这表明 SpargeAttn 不仅在注意力内核层面高效,在整个模型的推理流程中也能带来可观的实际加速。与 SageAttn 相比,SpargeAttn (结合了 SageAttn 的量化能力) 进一步提升了速度。

6.1.2. 可视化比较

以下是原文 Figure 1 的结果,展示了 SpargeAttnMochi 模型上的视觉质量对比。

Figure 1. SpargeAttn can achieve \(1 . 8 3 \\mathrm { x }\) speedup on Mochi on L40 GPU, with no video quality loss.
该图像是一个比较图,展示了SpargeAttn与全注意力机制的性能对比。全注意力的端到端时间为1897秒,而SpargeAttn的端到端时间仅为1037秒,实现了1.83倍的加速。图中显示了两种方法处理的视频质量未受损失。

Figure 1. SpargeAttn can achieve 1.83mathrmx1 . 8 3 \\mathrm { x } speedup on Mochi on L40 GPU, with no video quality loss.

  • 图 1 分析: 该图直观地展示了 SpargeAttnMochi 模型上实现 1.83 倍加速的同时,视频质量与 Full-Attention 相比没有损失 (no video quality loss)。左侧是 Full-Attention 生成的视频帧,右侧是 SpargeAttn 生成的视频帧。从视觉上看,两者的清晰度、细节和色彩几乎无法区分,证明了 SpargeAttn 在保持高质量生成方面的有效性。

    以下是原文 Figure 6 的结果,展示了 SpargeAttnCogvideoX 模型上的视觉质量对比。

    Figure 6. Visible examples on CogvideoX using SpargeAttention. 该图像是一个示意图,展示了使用 SpargeAttention 的结果对比,左侧为 Full Attention,右侧为 Sparge Attention。图中包含交通、自然风景和模型的不同生成效果,表明 SpargeAttention 在多种场景下的应用效果。

Figure 6. Visible examples on CogvideoX using SpargeAttention.

  • 图 6 分析: 该图展示了 CogvideoX 模型在 Full AttentionSpargeAttention 下的生成结果对比。左侧是 Full Attention 的输出,右侧是 SpargeAttention 的输出。在所有三个示例(交通场景、自然风景、模型)中,SpargeAttention 生成的图像质量与 Full Attention 几乎相同,细节、色彩和构图都得到了很好的保留。这进一步支持了 SpargeAttn 在视频生成任务中不牺牲视觉质量的结论。

    以下是原文 Figure 7 的结果,展示了 SpargeAttnFluxStable-Diffusion3.5 模型上的视觉质量对比。

    Figure 7. Comparison examples on Flux and Stable-Diffusion3.5. The sparsity of SpargeAttn, MInference and FlexPrefill is 0.38, 0.3, and 0.4 on F 1ux and 0.31, 0.3, and 0.35 on Stable-Diffusion3.5. 该图像是一个比较示意图,展示了在Flux和Stable-Diffusion3.5上使用Full Attention、SpargeAttention、MInference和FlexPrefill的结果。不同模型的输出在视觉效果上有明显差异,SpargeAttention在保留细节的同时有效减少了计算复杂度。

Figure 7. Comparison examples on Flux and Stable-Diffusion3.5. The sparsity of SpargeAttn, MInference and FlexPrefill is 0.38, 0.3, and 0.4 on F 1ux and 0.31, 0.3, and 0.35 on Stable-Diffusion3.5.

  • 图 7 分析: 该图对比了 FluxStable-Diffusion3.5 模型在 Full AttentionSpargeAttnMInferenceFlexPrefill 下的图像生成结果。
    • Flux 模型 (第一行): Full AttentionSpargeAttn 生成的图像(右二)质量都很高,细节丰富。而 MInference (右三) 和 FlexPrefill (右一) 生成的图像则出现明显的伪影和失真,质量显著下降。

    • Stable-Diffusion3.5 模型 (第二行): 类似地,Full AttentionSpargeAttn 生成的图像(右二)清晰且具有细节。而 MInference (右三) 和 FlexPrefill (右一) 的图像质量较差,颜色失真且细节模糊。

    • 这些可视化结果进一步强调了 SpargeAttn 在保持高质量生成方面的鲁棒性,以及现有稀疏注意力基线在通用模型上可能带来的严重质量退化。

      以下是原文 Figure 9 的结果,展示了 SpargeAttnLlama3.1 模型上的长上下文检索任务对比。

      Figure 9. A Needle-in-a-Haystack comparison example on Llama3.1. The sparsity of SpargeAttn, MInference, and FlexPrefill is 0.5, 0.5, and 0.54. 该图像是一个比较不同注意力机制性能的热图,展示了 Full Attention、SpargeAttention、MInference 和 FlexPrefill 模型在不同 Token 限制下的深度 (%) 和得分。SpargeAttention 的整体得分为 0.909,表现优于其他模型。

Figure 9. A Needle-in-a-Haystack comparison example on Llama3.1. The sparsity of SpargeAttn, MInference, and FlexPrefill is 0.5, 0.5, and 0.54.

  • 图 9 分析: 该图展示了 Llama3.1 模型在 Needle-in-a-Haystack (NIAH) 任务中的性能对比。NIAH 任务评估模型在长文本中定位和提取特定信息的能力。
    • Full AttentionSpargeAttn 在整个序列长度范围内都保持了较高的检索性能(条形图的高度)。
    • MInferenceFlexPrefill 的性能在某些 token 限制下出现明显下降,表明它们在处理长上下文时容易丢失关键信息。
    • 这支持了 SpargeAttn 在不牺牲长上下文理解能力的前提下实现加速的结论,甚至在某些情况下,通过稀疏化帮助模型更好地聚焦相关信息,从而略微提升了性能。

6.1.3. SpargeAttn 增强 LLM 性能

Table 1Figure 9 中可以观察到,SpargeAttn 在长上下文任务中增强了 LLM 的性能。例如,在 Llama3.1(128K)Llama3.1 (128K)LongbenchNIAH 任务中,SpargeAttn 的表现略优于 Full-Attention。作者推测,这种改进可能源于稀疏注意力帮助 LLM 专注于更相关的信息,减少了对不重要信息的处理,从而提高了长上下文的理解和检索能力。

6.2. 消融实验/参数分析

6.2.1. 稀疏块预测的开销

以下是原文 Table 3 的结果,展示了稀疏块预测的开销。

Sequence Len Prediction (ms) Full Attention (ms) Overhead
8k 0.251 6.649 3.78%
16k 0.487 26.83 1.82%
32k 0.972 106.68 0.911%
64k 2.599 424.24 0.612%
128k 8.764 1696.2 0.516%
  • 分析 Table 3:
    • 该表比较了 SpargeAttn 中动态稀疏块预测的开销与注意力执行延迟。
    • 随着序列长度的增加,预测开销的绝对时间(Prediction (ms))虽然增加,但相对于 Full Attention 的总执行时间(Full Attention (ms))的占比 (Overhead) 却显著下降。
    • 在 8K 序列长度下,预测开销占 3.78%。但在 128K 序列长度下,预测开销仅占 Full Attention 延迟的 0.516%
    • 这表明 SpargeAttn 的预测机制非常高效,对于长序列场景,其引入的额外开销几乎可以忽略不计,从而能够实现净加速。

6.2.2. 希尔伯特曲线置换 (HilbertCurve Permutation) 的影响

以下是原文 Table 4 的结果,展示了不同置换方法对稀疏度和精度的影响。

Method Sim-q ↑ Sim-k ↑ L1 ↓ Sparsity ↑
Random 0.321 0.019 0.0414 0.048
Rowmajor 0.551 0.390 0.0307 0.363
Timemajor 0.514 0.367 0.0342 0.338
HilbertCurve 0.572 0.479 0.0389 0.392
  • 分析 Table 4:
    • 该表比较了不同词元置换方法(Random, Rowmajor, Timemajor, HilbertCurve)对 Mochi 模型中 Sim-qQQ 的平均块自相似度)、Sim-kKK 的平均块自相似度)、L1 误差和稀疏性的影响。
    • Random 置换的自相似度最低,稀疏性也最低。
    • RowmajorTimemajor 置换相比 Random 有显著提升。
    • HilbertCurve 置换Sim-q (0.572) 和 Sim-k (0.479) 上均实现了最高的块自相似度。
    • 尽管 HilbertCurveL1 误差略高于 Rowmajor,但其实现的稀疏性 (0.392) 是所有方法中最高的。
    • 这验证了 HilbertCurve 置换能有效提高块的自相似度,从而提升稀疏性,仅带来可接受的微小精度差异。这对于视觉模型利用空间局部性非常有效。

6.2.3. 自相似度判断 (Self-Similarity Judge) 的消融研究

以下是原文 Table 5 的结果,展示了自相似度判断对模型性能的影响。

Method VQA-a ↑ VQA-t ↑ FScore ↑
W/o. self-sim Judge 34.664 44.722 1.138
With self-sim Judge 54.179 67.219 1.807
  • 分析 Table 5:
    • 该表比较了在 Mochi 模型上,有无自相似度判断 (Self-Similarity Judge) 对视频质量指标(VQA-a, VQA-t, FScore)的影响。
    • With self-sim Judge (带自相似度判断) 的方法在所有指标上都显著优于 W/o. self-sim Judge (不带自相似度判断)。例如,VQA-a 从 34.664 提升到 54.179,FScore 从 1.138 提升到 1.807。
    • 这强调了选择性词元压缩 (Selective Token Compression) 策略的必要性:通过判断块的自相似性,并强制计算非自相似块(fix blocks)的注意力,能够有效避免关键信息丢失,从而保证了端到端模型性能。如果没有这个判断,激进的稀疏化会导致模型性能大幅下降。

6.2.4. MgM_gMpvM_{pv} 带来的稀疏性分析

以下是原文 Table 6 的结果,展示了不同稀疏策略的贡献。

Strategy only Mg only Mpv Mg +Mpv
Sparsity 51.2% 27.7% 54%
  • 分析 Table 6:
    • 该表分析了在 Llama3.1 (128K 序列长度,Needle-in-a-Haystack 任务) 上,第一阶段掩码 MgM_g 和第二阶段掩码 MpvM_{pv} 对稀疏度的贡献。
    • 单独使用 MgM_g(第一阶段过滤)可以实现 51.2% 的稀疏性,这是主要的稀疏贡献来源。
    • 单独使用 MpvM_{pv}(第二阶段过滤)可以实现 27.7% 的稀疏性。
    • 结合 MgM_gMpvM_{pv} (Mg+MpvMg + Mpv) 能够实现 54% 的总稀疏性。
    • 这表明两阶段过滤协同工作,MgM_g 提供了大部分的粗粒度稀疏化,而 MpvM_{pv} 在此基础上提供了额外的、细粒度的稀疏化,进一步提升了总稀疏度。

6.2.5. 稀疏性随序列长度增加

以下是原文 Table 7 的结果,展示了稀疏性与序列长度的关系。

Sequence Len 8K 16K 24K 48K 128K
Sparsity 6.8% 26.4% 35.7% 49.8% 54%
  • 分析 Table 7:
    • 该表展示了在 Llama3.1 模型上,在恒定的精度约束下,注意力稀疏性随序列长度的增加而增加。
    • 从 8K 序列长度的 6.8% 稀疏性,一直增加到 128K 序列长度的 54% 稀疏性。
    • 这个趋势表明,对于更长的上下文,注意力图通常会变得更加稀疏,这意味着 SpargeAttn 在处理超长序列时能够实现更高的稀疏度和更大的加速潜力。

6.2.6. 扩散模型稀疏性分析 (Appendix A.4)

原文 Appendix A.4 提供了 CogvideoX 扩散模型在不同维度上的稀疏性分析(图 14-17)。

  • 层稀疏性 (Layer-wise Sparsity, Figure 14):

    • 下图是原文 Figure 14 的结果。

      Figure 14. Layer-wise sparsity of CogvideoX. 该图像是一个图表,展示了CogvideoX模型在不同层级的稀疏性分析。X轴为层索引,Y轴为平均稀疏度,红色虚线表示全局平均值0.27。不同层的稀疏度值变化显著,L5层的稀疏度最高,达到0.60。

    Figure 14. Layer-wise sparsity of CogvideoX.

    • 该图展示了 CogvideoX 模型在不同层(Layer Index)的平均稀疏度。稀疏度在不同层之间变化显著,例如,在第 5 层稀疏度最高(约 0.60),而在其他层稀疏度较低。这表明不同层的注意力机制有不同的稀疏模式,因此针对每个层设置不同的超参数是必要的,以实现最佳的稀疏化效果。
  • 时间步稀疏性 (Timestep-wise Sparsity, Figure 15):

    • 下图是原文 Figure 15 的结果。

      Figure 15. Timestep-wise sparsity of CogvideoX. 该图像是一个图表,展示了随着时间步的增加,平均稀疏度的变化。图中可以看到,稀疏度从0.2逐渐增加到接近0.3,说明随着时间步的推移,注意力机制的稀疏性在增强。

    Figure 15. Timestep-wise sparsity of CogvideoX.

    • 该图展示了 CogvideoX 模型中,稀疏度随采样时间步(Timestep)的变化。对于扩散模型,稀疏性随着时间步的增加而增加。这可能意味着在扩散过程的早期(噪声较多),模型需要关注更多的信息;而在后期(细节逐渐清晰),模型可以更聚焦于关键信息,从而可以进行更激进的稀疏化。
  • 样本稀疏性 (Sample-wise Sparsity, Figure 16):

    • 下图是原文 Figure 16 的结果。

      Figure 16. Sample-wise sparsity of CogvideoX. 该图像是一个图表,展示了不同提示的稀疏性(跨时间步、块、批次和头部的均值)。横轴为提示索引,纵轴为平均稀疏度,显示了各个提示的稀疏程度与整体均值0.271的对比。

    Figure 16. Sample-wise sparsity of CogvideoX.

    • 该图展示了在 CogvideoX 模型中,不同提示或样本(Prompt Index)的稀疏度(已在时间步、块、批次和头部上取平均)。稀疏度在不同样本之间也有所不同,但大部分样本的稀疏度围绕全局平均值 0.271 波动,表明其具有一定的稳定性,但也存在个体差异。
  • 头稀疏性 (Head-wise Sparsity, Figure 17):

    • 下图是原文 Figure 17 的结果。

      Figure 17. Head-wise sparsity of CogvideoX. 该图像是图表,展示了CogvideoX中多个层次的头稀疏性。每个图表对应一个层,表现出不同层次中的最大稀疏值,提供了对模型稀疏性特征的直观理解。

    Figure 17. Head-wise sparsity of CogvideoX.

    • 该图展示了 CogvideoX 模型在不同层中,每个注意力头(Head Index)的稀疏度。稀疏度在不同的注意力头之间也存在差异。这意味着某些注意力头可能比其他头更稀疏,进一步支持了为每个层和每个头设置不同超参数以优化稀疏化的策略。

      总结: 这些分析结果深入揭示了 SpargeAttn 方法的有效性、鲁棒性以及在不同模型和任务中的行为。消融实验验证了各个组件(如自相似度判断、两阶段过滤)的重要性,而稀疏性分析则为超参数设置和未来优化提供了指导。

6.3. 其他实验结果

以下是原文 Table 8 的结果,详细描述了不同置换方法。

Method Detailed Description
Random Random permutation of tokens, the order is recorded to perform inverse permutation.
Rowmajor Permutation following row-major order. Tokens are continuous along the W dimension.
Columnmajor Permutation following column-major order. Tokens are continuous along the H dimension.
Timemajor Permutation following time-major order. Tokens are continuous along the T dimension.
HilbertCurve Permutation following a Hilbert curve.

以下是原文 Table 9 的结果,详细展示了置换消融实验的结果。

Method Sim-q↑ Sim-k↑ Precision(L1)↓ Sparsity↑
CogvideoX Mochi CogvideoX Mochi CogvideoX Mochi CogvideoX Mochi
Random 0.502 0.321 0.025 0.019 0.0348 0.0414 0.027 0.048
Rowmajor 0.676 0.551 0.435 0.390 0.0265 0.0307 0.242 0.363
Columnmajor 0.633 0.547 0.335 0.394 0.0274 0.0342 0.198 0.366
Timemajor 0.692 0.514 0.479 0.367 0.0294 0.0342 0.238 0.338
HilbertCurve 0.709 0.572 0.523 0.479 0.0323 0.0389 0.265 0.392
  • 分析 Table 9: 该表详细展示了 Table 4 中提到的置换消融实验结果,包括 CogvideoXMochi 两个模型。

    • HilbertCurveCogvideoXMochi 上都实现了最高的 Sim-qSim-k (平均块自相似度),例如 CogvideoXSim-q 达到 0.709,Sim-k 达到 0.523。

    • 相应的,HilbertCurve 在两个模型上也都获得了最高的稀疏度(CogvideoX 为 0.265,Mochi 为 0.392)。

    • L1 误差(精度)方面,HilbertCurve 的值略高于 RowmajorColumnmajor,但仍保持在较低水平,例如 CogvideoX 为 0.0323,Mochi 为 0.0389。这说明 HilbertCurve 在提高稀疏性的同时,对精度影响很小。

    • Random 置换表现最差,自相似度低,稀疏度也最低。

    • 这进一步证实了 HilbertCurve 置换对于提高块自相似度和稀疏性的有效性。

      以下是原文 Table 10 的结果,展示了自相似度判断消融实验的 L1 误差和稀疏度。

      Method w/ judge w/o judge filter w/ judge filter w/o judge
      CogvideoX Mochi CogvideoX Mochi CogvideoX Mochi CogvideoX Mochi
      L1 error↓ 0.0316 0.0343 0.0325 0.0365 0.0843 0.0555 0.214 0.154
      Sparsity ↑ 0.199 0.301 0.203 0.305 0.242 0.371 0.275 0.392
  • 分析 Table 10: 该表进一步细化了自相似度判断的消融实验。

    • 前两列 (w/ judgew/o judge) 比较的是在完整 SpargeAttn 流程中,是否包含自相似度判断。结果显示,有判断时 L1 error 更低,稀疏度略低,这与 Table 5 的结论一致,即判断有助于精度。

    • 后两列 (filter w/ judgefilter w/o judge) 比较的是在稀疏性更高的设定下(可能是降低 τ\tauλ\lambda),自相似度判断的作用。在这里,filter w/o judge 导致 L1 error 显著增加(CogvideoX 从 0.0843 飙升到 0.214),而 Sparsity 略有提升。这明确表明,在追求更高稀疏度时,自相似度判断对于维持精度至关重要。没有这个判断,模型性能会严重退化。

      以下是原文 Table 11 的结果,展示了 Llama3.1 在 16-28K 序列长度下的 NIAH 任务结果。

      Model (seq_len) Attention (Sparsity) Speed (TOPS)↑ NIAH ↑
      Llama3.1 (24K) Full-Attention 156.9 0.838
      Minference (0.5) 122.5 0.635
      FlexPrefill (0.6) 179.6 0.776
      Minference (0.3) 102.3 0.652
      FlexPrefill (0.3) 117.6 0.797
      SpargeAttn (0.36) 443.6 0.863
  • 分析 Table 11:

    • 该表补充了 Llama3.1 在 24K 序列长度下 NIAH 任务的性能。
    • SpargeAttn (稀疏度 0.36) 再次展示了其卓越的性能,其 Speed (TOPS) 达到 443.6,远超 Full-Attention (156.9) 和所有基线。
    • NIAH 准确率上,SpargeAttn (0.863) 甚至略优于 Full-Attention (0.838),而基线方法 MInferenceFlexPrefill 的准确率均有不同程度的下降。
    • 这进一步支持了 SpargeAttn 不仅能加速,还能在长上下文任务中维持甚至提升 LLM 性能的观点。

7. 总结与思考

7.1. 结论总结

本文提出了 SpargeAttn,一种通用、准确且无需训练的稀疏注意力加速方案,适用于各种大型模型推理。其核心贡献在于:

  1. 两阶段在线过滤机制:
    • 第一阶段:选择性词元压缩,通过分析块内词元相似性并结合压缩注意力图,快速准确地预测稀疏块,有效跳过大量的 QKQK^\topP~V\widetilde{P}V 乘法。其关键在于对非自相似块进行强制计算,避免关键信息丢失。
    • 第二阶段:稀疏 Warp Online Softmax,在 GPU Warp 级别上,无额外开销地识别并跳过那些对最终输出贡献可忽略的 P~V\widetilde{P}V 乘积。
  2. 与量化技术的正交集成: SpargeAttn 无缝集成了 8 位量化的 SageAttention,实现了稀疏化和量化的叠加加速。
  3. HilbertCurve 置换优化: 针对视觉模型,通过 HilbertCurve 置换提高了相邻词元的相似度,进一步增强了稀疏化潜力。 实验结果在 Llama3.1 (语言)、CogvideoX, Mochi, Open-Sora-Plan (视频) 和 Flux, Stable-Diffusion3.5 (图像) 等多样化模型上,均验证了 SpargeAttn 能够在不牺牲端到端性能指标的前提下,实现 2.5x 到 5x 的显著推理加速,并且性能优于现有稀疏注意力基线。在某些长上下文 LLM 任务中,甚至能略微提升模型性能。

7.2. 局限性与未来工作

论文本身并未明确指出自身的局限性或未来的研究方向。然而,从其方法和实验设计中,我们可以推断出一些潜在的局限性和可以探索的未来工作:

  • 超参数调优的复杂性: SpargeAttn 引入了三个超参数 (τ,θ,λ\tau, \theta, \lambda),并且需要针对每个注意力层进行独立的网格搜索调优,以达到最佳的稀疏性和精度平衡。尽管论文称其过程“直接 (straightforward)”,但在实际部署到新模型或大规模模型时,这个过程可能仍然耗时且计算密集,尤其是在有大量层和头的情况下。
  • 自相似度度量(CosSim)的鲁棒性: CosSim 的定义 (mean(XXT/max(XXT))mean(XX^T / |max(XX^T)|)) 依赖于块内最大点积的绝对值进行归一化。这种归一化方式在某些边缘情况下(例如,所有点积都非常小但有一个异常值)可能不够鲁棒,或者其对不同数据分布的通用性有待进一步理论分析。
  • 通用性的边界: 尽管论文声称“加速任何模型推理”,并在多个模态上进行了验证,但“任何模型”的范围仍需更广泛的测试。例如,对于一些注意力模式非常密集或非常分散的模型,其稀疏化潜力可能有限,或者需要更复杂的预测策略。
  • 与训练的结合: SpargeAttn 被设计为“训练无关 (training-free)”,这大大降低了使用门槛。然而,如果能将稀疏模式的预测或学习集成到模型训练过程中,例如通过稀疏性感知训练 (sparsity-aware training) 或可学习的稀疏掩码,可能会进一步优化稀疏模式,实现更高的稀疏度或更好的精度。
  • 硬件通用性: 尽管 CUDA 实现和 Warp 级别的优化在 NVIDIA GPU 上表现出色,但在其他硬件平台(如 AMD GPUTPU 或其他加速器)上的性能和实现可能需要额外的适配和优化。

7.3. 个人启发与批判

7.3.1. 个人启发

  1. 通用性稀疏化的新范式: 论文提出通过块内词元相似性这一普适特征进行稀疏预测,而非依赖任务特定模式。这为实现真正通用的稀疏注意力提供了一条有前景的路径,打破了以往方法在特定领域或任务上的局限性。
  2. 预测与过滤的分层设计: 两阶段在线过滤机制是一个非常精妙的设计。第一阶段在粗粒度上快速识别大部分可跳过区域,通过“选择性压缩”保证精度;第二阶段则在细粒度上、无额外开销地进一步优化。这种分层处理策略有效平衡了预测的开销和稀疏化带来的收益,值得在其他计算优化问题中借鉴。
  3. 工程实践价值: SpargeAttn 的“训练无关”特性使其能够即插即用,与现有量化技术(如 SageAttention)和底层优化(如 FlashAttention)正交并叠加加速,这在实际工程部署中极具吸引力,能够快速为存量模型带来性能提升。
  4. 领域知识的巧妙运用: HilbertCurve 置换在视觉模型中的应用,展示了如何将特定领域(如视觉数据的空间局部性)的先验知识巧妙地融入到通用优化框架中,从而进一步提升性能。这提醒研究者在设计通用方法时,仍可结合特定领域特性进行微调。
  5. LLM 长上下文的潜在增益: 实验结果显示 SpargeAttn 甚至能略微提升 LLM 在长上下文任务中的性能。这暗示稀疏化不仅仅是加速,还可能通过强制模型聚焦于更相关的信息,起到一种“注意力正则化”的作用,对于信息过载的长序列处理可能有所裨益。

7.3.2. 批判性思考

  1. 超参数的“通用性”与“调优成本”矛盾: 论文声称其方法是“通用”的,但超参数 (τ,θ,λ\tau, \theta, \lambda) 却需要针对每个模型、每个注意力层进行网格搜索确定,并依赖于预设的 L1 误差阈值 (l1,l2l_1, l_2)。这种逐层调优的策略虽然能获得最优性能,但其“通用性”主要体现在方法原理上,而非开箱即用性。在实际应用中,尤其对于拥有成百上千个注意力层的超大规模模型,这个调优过程的计算和时间成本会非常高昂。如何实现这些超参数的自动化、自适应甚至可学习的确定,将是提升其真正“可用性”的关键。
  2. “自相似度判断”的阈值敏感性: 用于区分“选择块”和“固定块”的超参数 θ\theta 至关重要。一个不当的 θ\theta 值可能会导致大量重要信息被错误地归为“非自相似块”而强制计算,从而降低稀疏度;或者错误地将重要块归为“自相似块”进行激进压缩,导致精度下降。CosSim 作为自相似度度量,其计算方式是否能在所有模型和数据分布下都提供稳定的、有意义的区分,值得深入探讨。
  3. “无额外开销”的边界: 论文强调第二阶段的稀疏 Warp Online Softmax 滤波器“无额外开销”。这通常意味着其判断逻辑可以融合到底层 CUDA 核函数中,不引入额外的内存访问或计算步骤。然而,“无额外开销”是一个相对概念,它仍然涉及额外的逻辑判断(如 max(mlocal[Iw]mij[Iw])>λmax(m_local[I_w] - m_ij[I_w]) > λ),这些判断虽然高效,但并非零成本。在极端短序列或计算量极小的情况下,这些判断的相对开销可能变得显著。对“额外开销”的更严格量化和不同场景下的性能分析将有助于更全面的评估。
  4. 理论支撑的进一步强化: 论文主要侧重于工程实现和实验验证。对于为什么“长序列稀疏性更高”或“自相似块能被安全地压缩”这些关键观察,如果能提供更深入的理论分析或数学证明,将进一步增强方法的说服力。例如,从信息论或特征表示的角度解释为何长序列中存在更多冗余信息可被稀疏化。
  5. 对 FP16/FP4/INT8 的依赖与数值稳定性: 论文将稀疏化与 SageAttention 的量化技术结合,这涉及 FP16 甚至更低精度 (INT8, FP4) 的计算。低精度计算虽然能带来速度提升,但也可能引入数值稳定性问题。虽然论文宣称保持了精度,但在面对更复杂、更敏感的模型或极端长序列时,稀疏化和量化的叠加效应可能对数值稳定性提出更高挑战,需要进一步的鲁棒性测试。

相似论文推荐

基于向量语义检索推荐的相关论文。

暂时没有找到相似论文。