SampleAttention: Near-Lossless Acceleration of Long Context LLM Inference with Adaptive Structured Sparse Attention

Published: 06/17/2024
This analysis is AI-generated and may not be fully accurate. Please refer to the original paper.

TL;DR Summary

SampleAttention is introduced as an adaptive, near-lossless sparse attention method for long-context LLMs, significantly reducing Time-to-First-Token latency while maintaining accuracy, achieving up to 2.42x TTFT reduction compared to FlashAttention.

Abstract

Large language models (LLMs) now support extremely long context windows, but the quadratic complexity of vanilla attention results in significantly long Time-to-First-Token (TTFT) latency. Existing approaches to address this complexity require additional pretraining or finetuning, and often sacrifice model accuracy. In this paper, we first provide both theoretical and empirical foundations for near-lossless sparse attention. We find that dynamically capturing head-specific sparse patterns at runtime with low overhead is crucial. To address this, we propose SampleAttention, an adaptive structured and near-lossless sparse attention. Leveraging observed significant sparse patterns, SampleAttention attends to a fixed percentage of adjacent tokens to capture local window patterns, and employs a two-stage query-guided key-value filtering approach, which adaptively selects a minimum set of key-values with low overhead, to capture column stripe patterns. Comprehensive evaluations show that SampleAttention can seamlessly replace vanilla attention in off-the-shelf LLMs with nearly no accuracy loss, and reduces TTFT by up to 2.42× compared with FlashAttention.

1. Bibliographic Information

1.1. Title

The central topic of the paper is "SampleAttention: Near-Lossless Acceleration of Long Context LLM Inference with Adaptive Structured Sparse Attention". This title clearly indicates the paper's focus on improving the inference efficiency of Large Language Models (LLMs) with very long input sequences by employing a novel, adaptive sparse attention mechanism that aims to maintain high accuracy.

1.2. Authors

The authors are Qianchao Zhu, Jiangfei Duan, Chang Chen, Xiuhong Li, Siran Liu, Guanyu Feng, Xin Lv, Chuanfu Xiao, Dahua Lin, and Chao Yang. In the source text used for this analysis, affiliations appear only as numerical superscripts (most authors under institution 1, some under institutions 2 or 3), and the institution names behind these superscripts are not given.

1.3. Journal/Conference

The paper was published on arXiv, a preprint server, under the identifier arxiv:2406.15486. arXiv is a widely recognized platform for sharing research preprints in physics, mathematics, computer science, quantitative biology, quantitative finance, statistics, electrical engineering and systems science, and economics. Papers posted on arXiv are not peer-reviewed by arXiv itself, but they are often subsequently submitted to and published in reputable conferences or journals.

1.4. Publication Year

2024 (first posted to arXiv on 2024-06-17 at 11:05:15 UTC).

1.5. Abstract

The paper addresses the challenge of Time-to-First-Token (TTFT) latency in Large Language Models (LLMs) caused by the quadratic complexity of vanilla attention, especially with long context windows. Existing sparse attention methods often require additional pretraining or finetuning and sacrifice accuracy. The authors first provide theoretical and empirical foundations for near-lossless sparse attention, emphasizing the importance of dynamically capturing head-specific sparse patterns at runtime. To this end, they propose SampleAttention, an adaptive structured and near-lossless sparse attention mechanism. SampleAttention leverages significant sparse patterns by attending to a fixed percentage of adjacent tokens (for local window patterns) and employs a two-stage query-guided key-value filtering approach to adaptively select a minimum set of key-values (for column stripe patterns) with low overhead. Comprehensive evaluations show that SampleAttention can seamlessly replace vanilla attention in off-the-shelf LLMs with nearly no accuracy loss, reducing TTFT by up to 2.42× compared to FlashAttention (the introduction of the main paper reports an updated figure of up to 5.29×).

The original source link is https://arxiv.org/abs/2406.15486, and the PDF link is https://arxiv.org/pdf/2406.15486v3.pdf. Its publication status is a preprint on arXiv.

2. Executive Summary

2.1. Background & Motivation

The rapid advancement of Large Language Models (LLMs) has produced models that support exceptionally long context windows, some exceeding 1 million tokens. While this enables more complex applications such as document analysis, code copilots, and prolonged conversations, it also exposes a significant challenge: the quadratic computational complexity of the conventional attention mechanism. Because attention cost grows quadratically with sequence length, the prefill phase of inference becomes very slow, drastically increasing Time-to-First-Token (TTFT) latency. For instance, processing a 1-million-token context on ChatGLM3-6B can take over 1500 seconds, with the attention computation accounting for more than 90% of the TTFT. Such prohibitive latency makes real-time interaction impractical and hinders the deployment of long-context LLMs in many real-world applications.

Prior research has consistently shown that attention scores in LLMs exhibit high sparsity, meaning only a small fraction of query-key pairs significantly contribute to the final attention output. This inherent sparsity presents a promising avenue for optimizing the attention mechanism by computing attention selectively. However, existing sparse attention approaches face two key limitations:

  1. Static Patterns: Many methods use static sparse patterns (e.g., fixed windows, global tokens), failing to capture the dynamic and adaptive nature of attention sparsity across different attention heads, input contents, and model architectures. This often leads to significant accuracy degradation.

  2. Dynamic Methods' Drawbacks: More dynamic approaches either incur substantial computational overhead for index selection, rely on coarse-grained selection struggling with accuracy, or use predefined patterns and fixed budgets that don't adapt to varying sparsity ratios and dynamic patterns.

    The paper's entry point is to effectively exploit this inherent high attention sparsity while overcoming the challenges posed by its adaptive ratio and dynamic patterns. It seeks to develop a runtime-efficient sparse attention mechanism that can seamlessly integrate into off-the-shelf LLMs without compromising accuracy.

2.2. Main Contributions / Findings

The primary contributions and findings of the paper are:

  • Empirical and Theoretical Foundation for Near-Lossless Sparse Attention: The paper provides foundational insights demonstrating that dynamically capturing head-specific sparse patterns at runtime with low overhead is crucial for achieving near-lossless sparse attention. It reveals that sparsity ratios are adaptive across attention heads, input contents, and model architectures, and attention patterns are dynamic, often combining typical column and slash patterns.
  • Introduction of Cumulative Residual Attention (CRA): A novel and robust metric, CRA, is proposed to measure the percentage of attention recall, serving as a reliable indicator for guiding the trade-off between efficiency and accuracy in sparse attention. A consistent positive correlation between the CRA threshold and model accuracy is empirically shown.
  • Proposal of SampleAttention: An adaptive structured and near-lossless sparse attention mechanism is introduced. SampleAttention dynamically determines sparse ratios and patterns at runtime.
  • Novel Two-Stage Query-Guided Key-Value Filtering: This core methodology allows SampleAttention to efficiently identify important column and slash patterns.
    • Stage 1: Query-Guided Chunked Sampling: Estimates the full attention scores by computing exact scores for a few sampled queries in each of $chunk_n$ equal segments, enabling more accurate pattern detection than prior methods.
    • Stage 2: Score-Based Key-Value Filtering: Decomposes the CRA threshold into separate thresholds for columns ($\alpha_c$) and slashes ($\alpha_s$) to independently filter a minimal set of key-value blocks, which are then merged into a final sparse mask.
  • Automated Hyperparameter Tuning: A method is developed for offline tuning of SampleAttention's hyperparameters ($\alpha_c$, $\alpha_s$, $chunk_n$) on small profiling datasets across different length ranges, optimizing the accuracy-efficiency trade-off.
  • Hardware-Efficient Implementation: SampleAttention is implemented with IO-awareness through operator fusion and modification of FlashAttention2 kernels to achieve substantial wall-clock time speedups.
  • Significant Performance Gains: SampleAttention seamlessly replaces vanilla attention in off-the-shelf LLMs (ChatGLM, YI, InternLM) with nearly no accuracy loss, consistently retaining more than 99% of the full-attention baseline's scores. It reduces TTFT by up to 5.29× compared to FlashAttention2 on long contexts (e.g., 1 million tokens), establishing a new Pareto frontier in the accuracy-efficiency trade-off.

3. Prerequisite Knowledge & Related Work

3.1. Foundational Concepts

To understand SampleAttention, a foundational understanding of Large Language Models (LLMs), the Transformer architecture, and the attention mechanism is essential.

3.1.1. Large Language Models (LLMs) and the Transformer Architecture

Large Language Models (LLMs) are advanced artificial intelligence models, typically based on the Transformer architecture, that are trained on vast amounts of text data to understand, generate, and process human language. They excel at tasks like translation, summarization, question answering, and text generation. The Transformer architecture, introduced by Vaswani et al. (2017), revolutionized natural language processing by relying largely on self-attention rather than recurrent or convolutional layers. The original Transformer uses an encoder-decoder structure (for sequence-to-sequence tasks), while most generative LLMs use a decoder-only variant. Each Transformer block contains an attention layer (typically self-attention) followed by a feed-forward network (MLP).

3.1.2. Attention Mechanism

The attention mechanism is the core component of Transformers. It allows the model to weigh the importance of different tokens in an input sequence when processing a specific token. Instead of processing tokens sequentially, attention can process all tokens in parallel and dynamically adjust their influence on the output. For each token, the model generates three vectors: a query (Q), a key (K), and a value (V). The similarity between a query and all keys determines the attention scores, which are then used to weigh the values.

3.1.3. Quadratic Complexity of Attention

The standard attention mechanism computes pairwise interactions between all query tokens and all key tokens. For an input sequence of length $N$, the attention score matrix $\mathbf{Q}\mathbf{K}^T$ is an $N \times N$ matrix obtained by multiplying an $N \times d$ matrix by a $d \times N$ matrix, which costs $O(N^2 d)$ operations, i.e., $O(N^2)$ in the sequence length $N$. Storing this score matrix explicitly likewise requires $O(N^2)$ memory. For short sequences this is manageable, but for very long contexts (e.g., 100K or 1M tokens) the quadratic scaling becomes a severe bottleneck in both compute and memory.
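To make the scaling concrete, the sketch below counts the entries, matmul FLOPs, and fp16 storage of the score matrix for a single attention head. The head dimension of 128, the convention of 2 FLOPs per multiply-add, and the helper name score_matrix_cost are illustrative assumptions, not figures from the paper; a full model multiplies these numbers by its number of heads and layers.

```python
# Rough cost model for the score matrix Q K^T of ONE attention head.
# Assumptions (illustrative only): head_dim = 128, 2 FLOPs per multiply-add,
# fp16 storage (2 bytes per score) if the matrix were materialized.


def score_matrix_cost(n_tokens: int, head_dim: int = 128) -> dict:
    """Return entry count, matmul FLOPs and fp16 bytes for an n x n score matrix."""
    entries = n_tokens * n_tokens                # one score per (query, key) pair
    flops = 2 * n_tokens * n_tokens * head_dim   # (n x d) @ (d x n)
    fp16_bytes = 2 * entries
    return {"entries": entries, "flops": flops, "fp16_bytes": fp16_bytes}


for n in (1_000, 100_000, 1_000_000):
    cost = score_matrix_cost(n)
    print(f"N={n:>9,}: {cost['entries']:.1e} scores, "
          f"{cost['flops']:.1e} FLOPs, {cost['fp16_bytes'] / 2**30:,.1f} GiB")
```

Going from 100K to 1M tokens multiplies every quantity by 100, which is why long-context kernels avoid materializing this matrix and why the compute term still dominates prefill.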

3.1.4. Time-to-First-Token (TTFT) and Prefill Latency

LLM inference typically operates in two phases:

  • Prefill Phase: The model processes the entire input prompt (e.g., a long document) in parallel. Under causal masking, each prompt token attends to itself and to all preceding prompt tokens. This phase produces the first output token and populates the Key-Value (KV) cache for all input tokens. Because attention in this phase is quadratic in the prompt length, it directly determines the TTFT latency, i.e., the time taken to generate the very first token of the response. For long contexts, prefill latency dominates the overall response time.

  • Decoding Phase: After the first token is generated, subsequent tokens are generated one at a time (autoregressively). In this phase, each new query token attends to the KV cache, which holds the keys and values of the prompt and of all previously generated tokens. The attention cost per new token is therefore linear, $O(N)$, where $N$ is the current sequence length, while the KV cache memory grows linearly with the total sequence length.

    TTFT (or prefill latency) is a critical metric for user experience, especially in interactive applications, as it dictates how long a user waits for any response to appear.

3.1.5. KV Cache

The Key-Value (KV) cache is a memory optimization technique used during the decoding phase of LLM inference. Instead of recomputing the keys and values for all previously generated tokens at each step, these vectors are stored in memory after the prefill phase and appended with new key and value vectors from subsequently generated tokens. This significantly speeds up the decoding phase by avoiding redundant computations. However, for extremely long contexts, the KV cache itself can consume a substantial amount of memory.
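A minimal per-head sketch of this access pattern in plain Python. The class KVCache and its methods are hypothetical; real inference engines keep these tensors per layer and per head on the GPU, but the pattern is the same: store K/V for every prompt token at prefill, then append one entry per decoding step and attend over the whole cache.

```python
import math


def attend(q, keys, values):
    """Single-query scaled dot-product attention over cached keys/values."""
    d = len(q)
    logits = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


class KVCache:
    """Hypothetical minimal per-head KV cache: append-only lists of K and V."""

    def __init__(self):
        self.keys, self.values = [], []

    def extend(self, keys, values):
        # Prefill: store K/V for every prompt token at once.
        self.keys.extend(keys)
        self.values.extend(values)

    def decode_step(self, q, k_new, v_new):
        # Decoding: append ONE new K/V pair, then attend over the whole cache.
        # Old K/V are reused, never recomputed: O(current length) per token.
        self.keys.append(k_new)
        self.values.append(v_new)
        return attend(q, self.keys, self.values)


cache = KVCache()
cache.extend([[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]])
out = cache.decode_step([0.0, 0.0], [1.0, 1.0], [5.0, 6.0])
print(len(cache.keys), out)  # 3 [3.0, 4.0]  (a zero query gives uniform weights)
```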

3.1.6. Sparse Attention

Sparse attention is a general strategy to address the quadratic complexity of full attention. The core idea is that not all query-key interactions are equally important. By identifying and computing only a subset of the attention scores (i.e., making the attention matrix sparse), sparse attention aims to reduce computational and memory costs from $O(N^2)$ to a lower complexity, ideally $O(N)$. Various approaches define different sparse patterns, such as attending only to a fixed window of tokens, a few global tokens, or randomly selected tokens.
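As an illustration of a static sparse pattern, the sketch below builds a causal mask that keeps a few initial tokens plus a fixed local window (the sink-plus-window shape used by StreamingLLM) and measures what fraction of the causal score entries survive. The function names and the n_sink/window values are illustrative assumptions. Because each row keeps at most n_sink + window keys, the kept entries grow as O(N) while the causal matrix grows as O(N^2), so the kept fraction shrinks as N grows.

```python
def sink_window_mask(n, n_sink, window):
    """Causal 0/1 mask keeping the first n_sink keys and the `window` most recent keys.

    Illustrative static pattern only; SampleAttention chooses its mask adaptively.
    """
    mask = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):                  # causal: only keys j <= i
            if j < n_sink or i - j < window:
                mask[i][j] = 1
    return mask


def kept_fraction(mask):
    """Fraction of the n(n+1)/2 causal score entries that the mask keeps."""
    n = len(mask)
    return sum(map(sum, mask)) / (n * (n + 1) // 2)


for n in (64, 256, 512):
    print(n, round(kept_fraction(sink_window_mask(n, n_sink=4, window=32)), 3))
```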

3.1.7. FlashAttention

FlashAttention (Dao et al., 2022) is an optimized, exact implementation of full attention. It reduces the memory footprint and speeds up computation by reordering the attention computation with tiling and kernel fusion. Specifically, it computes the softmax in an online fashion, which avoids materializing the large $N \times N$ attention score matrix in GPU high-bandwidth memory (HBM). While FlashAttention removes this memory bottleneck and improves the constant factor of the computation, it does not change the fundamental $O(N^2)$ computational complexity of full attention. FlashAttention2 further improves parallelism and work partitioning.
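The online softmax is what lets FlashAttention skip the $N \times N$ matrix: scores are consumed tile by tile while a running maximum and normalizer are rescaled. The single-query sketch below uses scalar values in place of value vectors and hypothetical function names; it only checks that the tiled result matches an ordinary softmax-weighted sum. The real kernels apply the same recurrence to blocks of queries and keys held in on-chip SRAM.

```python
import math


def online_softmax_weighted_sum(logits, values, tile=2):
    """Return sum_j softmax(logits)_j * values[j], reading `tile` scores at a time.

    Only a running max m, a running normalizer l and a running unnormalized
    output acc are kept; old statistics are rescaled whenever the max grows.
    """
    m, l, acc = float("-inf"), 0.0, 0.0
    for start in range(0, len(logits), tile):
        xs, vs = logits[start:start + tile], values[start:start + tile]
        m_new = max(m, max(xs))
        scale = math.exp(m - m_new)            # 0.0 on the first tile (m = -inf)
        l = l * scale + sum(math.exp(x - m_new) for x in xs)
        acc = acc * scale + sum(math.exp(x - m_new) * v for x, v in zip(xs, vs))
        m = m_new
    return acc / l


def reference_weighted_sum(logits, values):
    """Ordinary two-pass softmax followed by the weighted sum, for comparison."""
    m = max(logits)
    w = [math.exp(x - m) for x in logits]
    return sum(wi * v for wi, v in zip(w, values)) / sum(w)


logits = [0.5, 3.0, -1.0, 2.0, 7.0]
values = [1.0, -2.0, 0.5, 4.0, 3.0]
for tile in (1, 2, 3):
    print(tile, online_softmax_weighted_sum(logits, values, tile),
          reference_weighted_sum(logits, values))
```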

3.2. Previous Works

The paper discusses various prior approaches to sparse attention and KV cache compression, highlighting their limitations that SampleAttention aims to overcome.

3.2.1. Full Attention Formula

The fundamental attention mechanism used in Transformers is Scaled Dot-Product Attention. For a single attention head, given query matrix $\mathbf{Q} \in \mathbb{R}^{S_q \times d}$, key matrix $\mathbf{K} \in \mathbb{R}^{S_k \times d}$, and value matrix $\mathbf{V} \in \mathbb{R}^{S_k \times d}$, where $S_q$ is the query sequence length, $S_k$ is the key/value sequence length, and $d$ is the head dimension, the output $\mathbf{O} \in \mathbb{R}^{S_q \times d}$ is computed as:

$ \mathbf { P } = { \mathsf { s o f t m a x } } ( { \frac { \mathbf { Q } \mathbf { K } ^ { T } } { { \sqrt { d } } } } ) \in [ 0 , 1 ] ^ { S _ q \times S_k } $

$ \mathbf { O } = \mathbf { P } \mathbf { V } \in \mathbb { R } ^ { S _ q \times d } $

Where:

  • $\mathbf{Q}$ represents the queries.
  • $\mathbf{K}$ represents the keys.
  • $\mathbf{V}$ represents the values.
  • $S_q$ is the number of queries (often the sequence length of the current input).
  • $S_k$ is the number of keys/values (often the sequence length of the context).
  • $d$ is the dimension of the query, key, and value vectors (head dimension).
  • $\mathbf{Q}\mathbf{K}^T$ computes the dot-product similarity between each query and each key.
  • $\sqrt{d}$ is a scaling factor that prevents large dot-product values from pushing the softmax into regions with tiny gradients.
  • $\mathsf{softmax}$ is applied row-wise (over the $S_k$ keys for each of the $S_q$ queries), normalizing each query's attention scores to sum to 1. This produces the attention weights matrix $\mathbf{P}$.
  • $\mathbf{P}$ is the attention weights matrix, indicating the importance of each key token for each query token.
  • $\mathbf{O}$ is the final output of the attention layer: for each query, a weighted sum of the value vectors using the attention weights.
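A direct, unoptimized transcription of these two equations in plain Python, with lists of rows standing in for matrices (the function name attention is illustrative). No causal mask is applied, matching the equations above, and the full $S_q \times S_k$ matrix $\mathbf{P}$ is materialized, which is exactly what becomes infeasible for long contexts.

```python
import math


def attention(Q, K, V):
    """P = softmax(Q K^T / sqrt(d)) and O = P V for one head (no mask).

    Q is S_q x d, K is S_k x d, V is S_k x d_v, all given as lists of rows.
    """
    d = len(Q[0])
    P = []
    for q in Q:
        logits = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(logits)                        # subtract the row max for stability
        exps = [math.exp(x - m) for x in logits]
        z = sum(exps)
        P.append([e / z for e in exps])        # row-wise softmax: each row sums to 1
    O = [[sum(p * v[c] for p, v in zip(row, V)) for c in range(len(V[0]))] for row in P]
    return P, O


P, O = attention([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]])
print(P)
print(O)
```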

3.2.2. Static Sparse Attention Methods

These approaches define fixed patterns for attention, often based on structural assumptions:

  • LongFormer (Beltagy et al., 2020): Combines local windowed attention with task-specific global attention.

  • BigBird (Zaheer et al., 2020): Uses a combination of local window attention, global tokens, and random attention to achieve $O(N)$ complexity. For the baselines, the paper configures BigBird with a window-size ratio of 8% and a global ratio of 8%.

  • LongNet (Ding et al., 2023): Replaces full attention with dilated attention to cover longer distances with sparse connections.

  • StreamingLLM (Xiao et al., 2023b): Focuses on infinite length generation by keeping attention sinks (a few initial tokens) and a fixed number of recent tokens. The paper sets its initial attention sink at 4 tokens for baselines.

    Limitation: These methods struggle to capture the dynamic and adaptive nature of attention patterns across different heads and inputs, often leading to accuracy degradation when applied to pretrained LLMs without fine-tuning.

3.2.3. Dynamic Sparse Attention Methods

These methods attempt to adapt the sparse pattern at runtime but face their own challenges:

  • DSA (Liu et al., 2022): Approximates attention patterns using low-rank hidden dimensions, but incurs significant computational overhead for long contexts.

  • HyperAttention (Han et al., 2023): Employs Locality Sensitive Hashing (LSH) to identify important attention scores. Its coarse-grained selection struggles to maintain model accuracy. The paper sets both bucket size and the number of sampled columns to 256 for baselines.

  • MInference (Jiang et al., 2024): A hybrid approach that predefines several sparse patterns (A-shape, Vertical-Slash, Block-Sparse) and matches each attention head to its optimal pattern offline. It then dynamically searches for sparse indices during the prefill phase within a fixed sparsification budget.

    Limitation: DSA has high overhead. HyperAttention is coarse-grained. MInference relies on predefined patterns and fixed budgets, failing to capture the varying sparsity ratios and dynamic patterns that adapt to different input contents and heads at runtime.

3.2.4. KV Cache Compression Methods

These works focus on reducing KV cache memory consumption, which is orthogonal to SampleAttention's focus on prefill computation speedup but can be combined:

  • StreamingLLM (Xiao et al., 2023b): Also applies here, since it bounds the KV cache by keeping only the attention-sink tokens and a window of recent tokens and evicting the rest.
  • H2O (Zhang et al., 2024c): Dynamically retains a balance of recent and heavy hitter tokens during decoding.
  • FastGen (Ge et al., 2023): Adaptively constructs KV cache based on head-specific policies.
  • SnapKV (Li et al., 2024): Compresses KV cache by selecting clustered critical positions.
  • CHAI (Agarwal et al., 2024): Applies head pruning to reduce KV cache overhead at the attention-head level.
  • Quantization approaches (Duanmu et al., 2024; Xiao et al., 2023a; Zhao et al., 2023) to lower KV cache precision.

3.3. Technological Evolution

The evolution of LLM inference efficiency for long contexts can be seen as a progression:

  1. Full Attention: The foundational Transformer mechanism, inherently quadratic.
  2. Optimized Full Attention: Implementations like FlashAttention that improve memory and speed constants but retain $O(N^2)$ complexity.
  3. Static Sparse Attention: Early attempts to break $O(N^2)$ by applying predefined, fixed sparse patterns (e.g., LongFormer, BigBird, LongNet). These often required retraining or fine-tuning and incurred accuracy penalties.
  4. Dynamic Sparse Attention: More recent efforts to adapt sparse patterns at runtime (DSA, HyperAttention, MInference). While more flexible, they often introduced significant overhead, lacked fine-grained adaptivity, or still relied on some static assumptions.
  5. Adaptive Structured Sparse Attention (SampleAttention): The current paper's contribution, aiming for a truly dynamic and adaptive approach that determines optimal sparsity ratios and patterns per head, per input, at runtime, guided by an accuracy metric (CRA), and implemented efficiently. This represents a step towards combining the best of dynamic adaptation with near-lossless accuracy.

3.4. Differentiation Analysis

Compared to main methods in related work, SampleAttention introduces several core differences and innovations:

  • Adaptive Sparsity Ratio: Unlike MInference (Jiang et al., 2024) and DuoAttention (Xiao et al., 2024) which use fixed sparsification budgets or predefined patterns, SampleAttention dynamically determines the sparsity ratio at runtime for each individual attention head and input prompt. This directly addresses the observation that optimal sparsity varies greatly across heads, contents, and models.

  • Dynamic Sparse Pattern Flexibility: While MInference acknowledges column and slash patterns, it classifies attention heads into fixed categories offline. SampleAttention's two-stage query-guided key-value filtering adaptively identifies critical column and slash patterns on-the-fly, allowing for flexible combinations of these patterns (e.g. Figure 4(c-d)) that are content-dependent. This avoids the limitations of static pattern classification.

  • Cumulative Residual Attention (CRA) as an Accuracy Proxy: SampleAttention introduces CRA as a principled, robust metric to guide the selection of sparse indices. This allows for a direct, data-driven way to balance efficiency and accuracy, ensuring "near-lossless" performance without needing extensive re-training or fine-tuning. Previous dynamic methods often lacked such a direct, runtime-evaluable accuracy indicator.

  • Query-Guided Chunked Sampling: Instead of relying solely on the last query tokens (like some prior works), SampleAttention partitions queries into chunks and samples from each, providing a more comprehensive and accurate estimation of the attention distribution across the entire attention matrix, especially for complex hybrid patterns.

  • Hardware-Efficient Implementation: SampleAttention is built upon FlashAttention2 with IO-awareness, including operator fusion and custom kernels, ensuring that the dynamic selection overhead is minimized and substantial wall-clock time speedups are achieved. This contrasts with some dynamic methods that might have high computational overhead for pattern discovery.

    In essence, SampleAttention stands out by offering a highly adaptive, runtime-efficient, and accuracy-preserving sparse attention mechanism that is guided by a novel metric, making it seamlessly applicable to off-the-shelf LLMs for long context inference.

4. Methodology

SampleAttention is designed to efficiently accelerate long context LLM inference by exploiting the inherent sparsity of the attention mechanism in a near-lossless manner. The core idea is to dynamically identify and retain only the most important key-value pairs (column and slash patterns) that contribute significantly to the attention scores, guided by a robust accuracy metric called Cumulative Residual Attention (CRA).

4.1. Principles

The fundamental principle behind SampleAttention rests on two key observations:

  1. Adaptive Sparsity: The optimal degree of sparsity (sparsity ratio) and the specific patterns of important attention scores are not static. They vary dynamically across different attention heads, input contents, and LLM architectures. This means a fixed, predefined sparse pattern or budget will inevitably lead to suboptimal performance or accuracy degradation.

  2. Structured Patterns: Despite their dynamism, significant attention scores often coalesce into identifiable structured patterns, primarily column stripes (capturing global context, like attention sinks) and slash stripes (capturing local context, like window patterns). These patterns, and their combinations, account for the majority of the attention recall.

    SampleAttention proposes to address these by:

  • Introducing Cumulative Residual Attention (CRA) as a principled metric to balance efficiency and accuracy. CRA measures the percentage of attention recall, indicating how much of the original attention mass is preserved after sparsification.

  • Employing a novel two-stage query-guided key-value filtering method to dynamically identify and select these critical column and slash patterns at runtime with low overhead.

  • Implementing hardware-efficient kernels to minimize the computational cost of this dynamic selection and subsequent sparse attention computation.

    The goal is to achieve maximal efficiency (high speedup) by aggressively sparsifying the attention computation, while maintaining accuracy by ensuring that the selected sparse patterns capture sufficient attention recall as guided by CRA.

4.2. Core Methodology In-depth (Layer by Layer)

The SampleAttention methodology operates in two main stages: Query-Guided Chunked Sampling and Score-Based Key-Value Filtering. These stages work together to construct a dynamic, adaptive sparse mask $\hat{\mathbf{M}}$ that guides the sparse FlashAttention computation.

4.2.1. Initial Full Attention & Sparse Attention Formulation

Before diving into SampleAttention's specific stages, let's reiterate the foundational full attention and the general formulation for sparse attention as presented in the paper.

For a given attention head, with query matrix $\mathbf{Q} \in \mathbb{R}^{S_q \times d}$, key matrix $\mathbf{K} \in \mathbb{R}^{S_k \times d}$, and value matrix $\mathbf{V} \in \mathbb{R}^{S_k \times d}$, the full attention output $\mathbf{O} \in \mathbb{R}^{S_q \times d}$ is:

$ \mathbf { P } = { \mathsf { s o f t m a x } } ( { \frac { \mathbf { Q } \mathbf { K } ^ { T } } { { \sqrt { d } } } } ) \in [ 0 , 1 ] ^ { S _ q \times S_k } $

$ \mathbf { O } = \mathbf { P } \mathbf { V } \in \mathbb { R } ^ { S _ q \times d } $

Where:

  • $\mathbf{Q}$: Query tensor.
  • $\mathbf{K}$: Key tensor.
  • $\mathbf{V}$: Value tensor.
  • $S_q$: Query sequence length.
  • $S_k$: Key sequence length.
  • $d$: Head dimension.
  • $\mathbf{P}$: Attention scores matrix after softmax.
  • $\mathbf{O}$: Output tensor of the attention layer.
  • $\mathsf{softmax}$: Row-wise softmax function.

    To approximate this with sparse attention, a binary mask matrix $\mathbf{M} \in \{0,1\}^{S_q \times S_k}$ is introduced. The sparse attention scores $\hat{\mathbf{P}}$ are computed as:

$ \hat { \mathbf { P } } = { \mathsf { s o f t m a x } } ( \frac { \mathbf { Q } \mathbf { K } ^ { T } } { \sqrt { d } } - c ( 1 - \mathbf { M } ) ) $

Where:

  • $\hat{\mathbf{P}}$: The sparse attention scores matrix.
  • $c$: A large constant.
  • $\mathbf{M}$: A binary mask matrix, where $M_{ij} = 1$ if the attention from query $i$ to key $j$ is computed and $M_{ij} = 0$ if it is masked out. The term $-c(1-\mathbf{M})$ subtracts $c$ from the logits of masked entries, pushing them to a very large negative value so that their softmax probabilities become nearly zero. The sparsity ratio is the percentage of attention scores that are masked (i.e., $M_{ij}=0$).
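The masking trick can be checked numerically for one query row: with a large $c$, masked logits underflow to zero probability after the softmax, and the result equals a softmax taken over the kept entries only. The helper names and the choice c = 1e4 are illustrative assumptions.

```python
import math


def softmax(xs):
    m = max(xs)                                   # subtract the max for stability
    e = [math.exp(x - m) for x in xs]
    z = sum(e)
    return [v / z for v in e]


def masked_softmax_row(logits, mask_row, c=1e4):
    """One row of P_hat = softmax(logits - c * (1 - M))."""
    return softmax([x - c * (1 - m) for x, m in zip(logits, mask_row)])


def sparsity_ratio(M):
    """Fraction of entries of the 0/1 mask M that are masked out (M_ij = 0)."""
    return sum(row.count(0) for row in M) / sum(len(row) for row in M)


logits, mask_row = [2.0, 1.0, 0.5, -1.0], [1, 0, 1, 0]
p_hat = masked_softmax_row(logits, mask_row)
kept = softmax([2.0, 0.5])                        # softmax over the kept logits only
print(p_hat)                                      # masked positions are 0.0
print(kept)
```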

4.2.2. Cumulative Residual Attention (CRA)

The paper introduces Cumulative Residual Attention (CRA) as a key insight for balancing efficiency and accuracy. CRA is defined as the minimum, over all queries, of the sum of attention probabilities that remain after sparsification: $\mathrm{CRA}(\mathbf{M}) = \min_i \sum_j \mathbf{P}_{ij} \mathbf{M}_{ij}$. In other words, for each query one sums the full-attention probabilities of the selected (unmasked) key-value pairs, which measures how much of that query's original attention mass (which sums to 1 in full attention) is retained; CRA is the worst case across queries.

The relationship between the CRA threshold and model accuracy is empirically observed to be a consistent positive correlation (e.g., Figure 5). This means that by setting a target CRA threshold, the model can dynamically select a minimal set of attention indices that achieve this threshold, thus maximizing computational efficiency while preserving accuracy.

The core challenge is that precisely calculating CRA for a given mask M\mathbf{M} requires computing the full attention score matrix first, which is computationally expensive. SampleAttention overcomes this by using its two-stage sampling and filtering process to estimate and select important column and slash patterns that approximate the desired CRA threshold. The insight (illustrated in Figure 7) is that selecting an appropriate number of key column and slash strips can accurately approximate the CRA.
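Once the full probabilities are known, CRA itself is cheap to evaluate; the expensive part is obtaining $\mathbf{P}$. The toy sketch below (hypothetical helper name, hand-picked 3 × 3 causal $\mathbf{P}$, illustrative threshold) computes CRA for two masks and shows that adding a slash (here, the main diagonal) to a single column raises the worst-case retained mass above the threshold.

```python
def cumulative_residual_attention(P, M):
    """CRA(M) = min_i sum_j P[i][j] * M[i][j]: the smallest attention mass any query keeps."""
    return min(sum(p * m for p, m in zip(p_row, m_row)) for p_row, m_row in zip(P, M))


# Toy causal attention probabilities (each row sums to 1).
P = [[1.0, 0.0, 0.0],
     [0.7, 0.3, 0.0],
     [0.6, 0.1, 0.3]]

column_only = [[1, 0, 0], [1, 0, 0], [1, 0, 0]]        # first key column
column_plus_diag = [[1, 0, 0], [1, 1, 0], [1, 0, 1]]   # plus the main diagonal (a slash)

alpha = 0.85                                           # illustrative CRA threshold
for name, M in (("column", column_only), ("column+diagonal", column_plus_diag)):
    cra = cumulative_residual_attention(P, M)
    print(name, round(cra, 3), "meets alpha" if cra >= alpha else "below alpha")
```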

4.2.3. Stage 1: Query-Guided Chunked Sampling

The first stage, Query-Guided Chunked Sampling, aims to efficiently estimate the full attention score distribution without incurring the $O(N^2)$ cost of computing all query-key interactions.

Motivation: Prior methods often sample only a few query tokens (e.g., the last $last_q$ queries) to estimate attention patterns. However, SampleAttention observes that attention patterns can be complex and vary across the entire attention matrix (e.g., Figure 4(c) shows a column pattern that is significant in the upper half but fades in the lower half). Sampling at a single concentrated location can therefore bias the estimate and miss important patterns.

Procedure:

  1. Query Partitioning: The full query sequence $\mathbf{Q}$ (of length $S_q$) is partitioned into $chunk_n$ equal segments (chunks).
  2. Segment Selection: From each chunk, a small block of queries is selected. The paper defines the sampled queries as Q_slice = [Q[i*itv-blk : i*itv] for i in range(1, chunk_n + 1)]: for chunk i, the sampled block consists of the blk queries ending at position i*itv, where itv is the chunk length (the interval between sampling points). The paper describes this equidistant sampling as low-overhead and more stable than random or bottom-only sampling, and it captures patterns that vary across the attention matrix.
  3. Attention Score Computation: For the selected query blocks ($\mathbf{Q_{slice}}$), attention scores are computed exactly against all key tokens $\mathbf{K}$. This is still a matrix multiplication, but with far fewer queries, so it is much cheaper than the full $\mathbf{Q}\mathbf{K}^T$. A causal mask ($\mathbf{m_{causal}}$) is added so that each query attends only to keys at or before its own position, as required by autoregressive LLMs:
    $ \mathbf{\dot{A}} \leftarrow \mathsf{softmax} (\mathbf{Q_{slice}} \mathbf{K}^T / \sqrt{d} + \mathbf{m_{causal}}) $
    Where:
    • $\mathbf{Q_{slice}}$: The sampled query blocks.
    • $\mathbf{K}$: All key tokens.
    • $\mathbf{m_{causal}}$: The additive causal mask (0 for allowed positions, $-\infty$ for future keys), which prevents information leakage from future tokens.
    • $\mathbf{\dot{A}}$: The attention scores for the sampled query blocks.
  4. Block Reduction: The resulting token-level attention scores A˙\mathbf{\dot{A}} are then reduced at a block size (blk) granularity in both column and slash directions. This means grouping adjacent attention scores into blocks and summarizing them (e.g., summing or averaging) to get block-level scores. $ \mathbf{\dot{A}_c}, \mathbf{\dot{A}_s} \leftarrow \mathsf{block_reduction} (\mathbf{\dot{A}}, \mathsf{blk}) $ Where:
    • A˙c\mathbf{\dot{A}_c}: Block-reduced scores representing column importance.
    • A˙s\mathbf{\dot{A}_s}: Block-reduced scores representing slash importance.
    • blk\mathsf{blk}: The block size for reduction. These block-level scores are then passed to the second stage for key-value filtering.
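The Stage 1 steps above can be sketched in plain Python on toy sizes. This is a minimal illustration under stated assumptions, not the paper's implementation: the helper names are hypothetical, itv is assumed to be S_q // chunk_n (the paper leaves it undefined), and block reduction is assumed to sum scores per column block and per slash (diagonal-offset) block, then average over the sampled rows.

```python
import math

def sample_rows(s_q, chunk_n, blk):
    """Query indices Q[i*itv - blk : i*itv] for i = 1..chunk_n."""
    itv = s_q // chunk_n  # assumption: itv = S_q // chunk_n
    if not 0 < blk <= itv:
        raise ValueError("blk must be positive and no larger than the chunk interval")
    return [r for i in range(1, chunk_n + 1) for r in range(i * itv - blk, i * itv)]

def causal_softmax_rows(q, k, rows):
    """softmax(q_r . K^T / sqrt(d) + causal mask) for each sampled query row r."""
    d = len(q[0])
    attn = []
    for r in rows:
        # Causal mask: query r only attends to keys 0..r (future keys get weight 0).
        logits = [sum(a * b for a, b in zip(q[r], k[j])) / math.sqrt(d) for j in range(r + 1)]
        peak = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(x - peak) for x in logits]
        z = sum(exps)
        attn.append([e / z for e in exps] + [0.0] * (len(k) - r - 1))
    return attn

def block_reduction(rows, attn, blk):
    """Aggregate token scores into column-block and slash-block scores (hypothetical helper)."""
    n_blocks = math.ceil(len(attn[0]) / blk)
    col, slash = [0.0] * n_blocks, [0.0] * n_blocks
    for r, row in zip(rows, attn):
        for j, p in enumerate(row[: r + 1]):
            col[j // blk] += p          # column stripe: key block containing j
            slash[(r - j) // blk] += p  # slash stripe: block of the diagonal offset r - j
    # Average over sampled rows so that each score vector sums to 1.
    return [c / len(rows) for c in col], [s / len(rows) for s in slash]

# Usage on a toy 16-token sequence with deterministic vectors.
S, D = 16, 4
q = [[math.sin(i + t) for t in range(D)] for i in range(S)]
k = [[math.cos(0.5 * i + t) for t in range(D)] for i in range(S)]
rows = sample_rows(S, chunk_n=2, blk=4)
attn = causal_softmax_rows(q, k, rows)
col_scores, slash_scores = block_reduction(rows, attn, blk=4)
```

With chunk_n = 2 and blk = 4, only 8 of the 16 query rows are scored, and the two score vectors are what Stage 2 consumes.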

4.2.4. Stage 2: Score-Based Key-Value Filtering

The second stage, Score-Based Key-Value Filtering, takes the block-level scores from Stage 1 to dynamically select the most important key-value blocks that meet the desired CRA thresholds.

Motivation: Selecting column and slash strips jointly would involve evaluating $n_{column} \times n_{slash}$ combinations, which is computationally prohibitive. SampleAttention therefore decomposes the problem.

Procedure:

  1. Threshold Decomposition: The single global CRA threshold $\alpha$ is decomposed into two separate thresholds: $\alpha_c$ for columns and $\alpha_s$ for slashes. This reduces the selection complexity to $n_{column} + n_{slash}$.
  2. Independent Selection of Column Strips:
    • The block-reduced column scores $\mathbf{\dot{A}_c}$ are sorted in descending order.
    • A cumulative sum (cumsum) is computed over the sorted column scores.
    • The minimum number of key-value blocks ($k_c$) whose cumulative sum reaches the threshold $\alpha_c$ is determined. This $k_c$ is the "quota" of column blocks needed.
    • A top-k operation is then applied to the original $\mathbf{\dot{A}_c}$ to obtain the indices ($\mathbf{I_c}$) of the $k_c$ most important column blocks. $ \mathsf{k_c} \leftarrow \mathsf{find\_k} (\mathsf{cumsum}(\mathsf{sort}(\mathbf{\dot{A}_c})), \alpha_c) $ $ \mathbf{I_c} \leftarrow \mathsf{arg\_topk} (\mathbf{\dot{A}_c}, \mathsf{k_c}) $
  3. Independent Selection of Slash Strips: The same process is applied to the slash scores $\mathbf{\dot{A}_s}$ with threshold $\alpha_s$ to obtain $k_s$ and $\mathbf{I_s}$. $ \mathsf{k_s} \leftarrow \mathsf{find\_k} (\mathsf{cumsum}(\mathsf{sort}(\mathbf{\dot{A}_s})), \alpha_s) $ $ \mathbf{I_s} \leftarrow \mathsf{arg\_topk} (\mathbf{\dot{A}_s}, \mathsf{k_s}) $ Where:
    • $\mathsf{find\_k}$: A function that returns the minimum number of top elements whose cumulative sum reaches the threshold.
    • $\mathsf{arg\_topk}$: A function that returns the indices of the top-k elements.
  4. Extension and Merging of Mask: The selected block indices $\mathbf{I_c}$ and $\mathbf{I_s}$ are then extended across the full attention matrix along their respective column and slash directions, and the resulting masks are merged into the final block-sparse mask $\hat{\mathbf{M}}$. $ \hat{\mathbf{M}} \leftarrow \mathsf{merge\_index} (\mathbf{I_c}, \mathbf{I_s}, \mathsf{itv}) $ Note: The itv parameter likely refers to the interval between sampling chunks used in Query-Guided Chunked Sampling, ensuring correct alignment during merging. The merged mask $\hat{\mathbf{M}}$ is designed to be a near-lossless approximation that captures critical attention sinks (via columns) and local window patterns (via slashes).
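A matching plain-Python sketch of Stage 2 follows. find_k and arg_topk mirror the operations named in the paper, but their bodies here are illustrative assumptions. extend_and_merge is a hypothetical stand-in for merge_index: it ignores itv, treats a slash index o simply as the block diagonal at offset o, and applies each column index to every query block at or after it.

```python
def find_k(scores, alpha):
    """Smallest k such that the k largest scores sum to at least alpha * total."""
    ranked = sorted(scores, reverse=True)  # descending, so the cumsum grows fastest
    target, running = alpha * sum(ranked), 0.0
    for k, s in enumerate(ranked, start=1):
        running += s
        if running >= target:
            return k
    return len(ranked)

def arg_topk(scores, k):
    """Indices of the k largest scores (ties keep the lower index first)."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

def extend_and_merge(col_idx, slash_idx, n_blocks):
    """Hypothetical block mask: mask[qb][kb] is True if that block pair is computed."""
    mask = [[False] * n_blocks for _ in range(n_blocks)]
    for qb in range(n_blocks):
        for c in col_idx:       # column stripe: same key block for all later query blocks
            if c <= qb:         # stay inside the causal (lower-triangular) region
                mask[qb][c] = True
        for o in slash_idx:     # slash stripe: key block at a fixed offset behind qb
            if qb - o >= 0:
                mask[qb][qb - o] = True
    return mask

# Usage with made-up block scores (each vector sums to 1).
col_scores = [0.55, 0.05, 0.10, 0.30]
slash_scores = [0.70, 0.20, 0.06, 0.04]
k_c, k_s = find_k(col_scores, 0.8), find_k(slash_scores, 0.85)
mask = extend_and_merge(arg_topk(col_scores, k_c), arg_topk(slash_scores, k_s), n_blocks=4)
```

Because the two thresholds are handled independently, each selection is a single sort plus a linear scan, which is the point of the threshold decomposition.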

4.2.5. Final Sparse Attention Computation

Finally, the constructed mask $\hat{\mathbf{M}}$ is used in a modified FlashAttention kernel to compute the sparse attention output: $ \mathbf{O} \leftarrow \mathsf{sparse\_flash\_attn} (\mathbf{Q}, \mathbf{K}, \mathbf{V}, \hat{\mathbf{M}}) $ This relies on the hardware-efficient kernels described in Section 4.2.7 to compute only the unmasked attention scores, significantly reducing computation and IO overhead.
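To pin down what the masked computation means (not how a fused GPU kernel achieves its speed), here is a naive reference in plain Python. It assumes a square boolean block mask over query and key blocks, such as the merged mask described above, and applies causality at token level; the function name is hypothetical. A real block-sparse kernel would skip unselected blocks entirely rather than filtering key indices one by one.

```python
import math

def block_sparse_attention(q, k, v, block_mask, blk):
    """Naive O = softmax(QK^T/sqrt(d)) V restricted to selected blocks and causal positions."""
    d = len(q[0])
    out = []
    for i in range(len(q)):
        # Keys visible to query i: inside a selected block and not in the future.
        keys = [j for j in range(i + 1) if block_mask[i // blk][j // blk]]
        if not keys:
            raise ValueError(f"query {i} has no selected keys; keep the diagonal blocks")
        logits = [sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d) for j in keys]
        peak = max(logits)  # subtract the max for numerical stability
        weights = [math.exp(x - peak) for x in logits]
        z = sum(weights)
        out.append([sum(w * v[j][t] for w, j in zip(weights, keys)) / z for t in range(len(v[0]))])
    return out

# Usage: with blk=1 and only diagonal blocks selected, every query attends to itself alone.
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
k = [[0.2, 0.1], [0.3, -0.2], [1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, -1.0]]
diag_only = [[i == j for j in range(4)] for i in range(4)]
full = [[True] * 4 for _ in range(4)]
sparse_out = block_sparse_attention(q, k, v, diag_only, blk=1)
dense_out = block_sparse_attention(q, k, v, full, blk=1)
```

With an all-True mask the function reduces to ordinary causal attention, which makes it a convenient correctness oracle when testing a faster sparse kernel.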

The entire two-stage process is summarized in Algorithm 1, which details the flow for a single attention head. The process is applied across all heads and layers of the LLM.

The following figure (Figure 6 from the original paper) visually summarizes the two-stage methodology of SampleAttention:

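This figure is a schematic of the SampleAttention method for long-context LLM inference. The left side depicts conventional block-wise full attention; the right side is divided into two stages: Stage 1, chunked attention sampling, and Stage 2, score-based block filtering. Different colors and marker shapes indicate the sampled regions, scores, and selected blocks, showing how the method captures local patterns to reduce inference latency.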

Algorithm 1: Two-stage Implementation of SampleAttention

Input: Q, K, V (query length S_q), α_c, α_s ∈ [0, 1], chunk_n
# Stage 1: Query-Guided Chunked Attention Sampling
Q_slice ← [Q[i*itv − blk : i*itv] for i in range(1, chunk_n + 1)]
Ȧ ← softmax(Q_slice Kᵀ / √d + m_casual)
Ȧ_c, Ȧ_s ← block_reduction(Ȧ, blk)
# Stage 2: Score-Based Key-Value Block Filtering
k_c ← find_k(cumsum(sort(Ȧ_c)), α_c)
I_c ← arg_topk(Ȧ_c, k_c)
k_s ← find_k(cumsum(sort(Ȧ_s)), α_s)
I_s ← arg_topk(Ȧ_s, k_s)
# Extend and Merge Block-sparse Mask across Each Head
M̂ ← merge_index(I_c, I_s, itv)
# Final Sparse FlashAttention with Block Index
O ← sparse_flash_attn(Q, K, V, M̂)

Note on Algorithm 1: The paper's listing does not define itv, blk, or mcasual and gives no precise function signatures; the explanation above clarifies the logical flow and the implied operations.

4.2.6. Hyperparameter Tuning

SampleAttention introduces three key hyperparameters that control the trade-off between efficiency and accuracy:

  • $\alpha_c$: The desired CRA threshold for columns.

  • $\alpha_s$: The desired CRA threshold for slashes.

  • $chunk_n$: The number of sampling chunks.

    These hyperparameters are tuned offline using a small profiling dataset. The tuning process involves:

  1. Length-based Segmentation: Discretizing sequence lengths into distinct intervals (e.g., <16K, [16K, 48K), etc.).
  2. Multi-task Tuning: Tuning is performed within each segment to find the $(\alpha_c, \alpha_s, chunk_n)$ configuration that maximizes computational efficiency while maintaining accuracy (e.g., relative to FlashAttention2's accuracy and latency). Larger $\alpha_c$/$\alpha_s$ values generally increase accuracy at the cost of speedup, while $chunk_n$ affects both sampling overhead and accuracy; an appropriate $chunk_n$ is needed to capture sparse structures without excessive overhead.
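The selection step of this tuning can be sketched as a constrained search: among profiled configurations, keep those whose score clears an accuracy floor and pick the fastest. The floor value and the function name are assumptions for illustration; the sample data are the GLM (128K) score/speedup pairs reported in Table 3, and the paper's actual procedure also spans multiple tasks within each length segment.

```python
def pick_config(results, min_score):
    """Highest-speedup (alpha_c, alpha_s, chunk_n) among configs scoring >= min_score."""
    feasible = [(speedup, cfg) for cfg, (score, speedup) in results.items() if score >= min_score]
    if not feasible:
        return None
    return max(feasible, key=lambda item: item[0])[1]

# GLM (128K) results from Table 3: (alpha_c, alpha_s, chunk_n) -> (score, speedup)
glm_128k = {
    (0.90, 0.90, 1): (82.89, 1.92), (0.90, 0.90, 2): (84.70, 1.89),
    (0.90, 0.90, 4): (84.02, 1.70), (0.90, 0.90, 6): (83.62, 1.53),
    (0.95, 0.95, 1): (84.17, 1.64), (0.95, 0.95, 2): (84.04, 1.60),
    (0.95, 0.95, 4): (83.14, 1.46), (0.95, 0.95, 6): (83.73, 1.33),
}
best = pick_config(glm_128k, min_score=84.0)  # hypothetical accuracy floor
```

With a floor of 84.0 the search picks (0.90, 0.90, 2); relaxing the floor to 80.0 would instead favor the 1.92x configuration with chunk_n = 1.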

4.2.7. Hardware-Efficient Implementation

To ensure practical speedups, SampleAttention is designed with IO-awareness and optimized for hardware:

  • Operator Fusion: The dynamic selection process involves a series of small operations (bmm, mask_fill, softmax, reduction). These are fused to minimize IO overhead, which is critical for performance on modern GPUs.

  • Custom Kernel Modification: SampleAttention modifies FlashAttention2 kernels to implement its adaptive structured sparse attention. This allows it to leverage the underlying hardware efficiency of FlashAttention2 while incorporating the dynamic masking logic.

    This integrated approach ensures that the overhead of dynamic pattern identification is mitigated, translating to significant wall-clock time speedups.

5. Experimental Setup

5.1. Datasets

SampleAttention is evaluated on a diverse suite of benchmarks designed for long-context LLMs.

  • RULER (Hsieh et al., 2024):

    • Description: A comprehensive benchmark for evaluating long-context language models. It provides flexible configurations for various sequence lengths.
    • Characteristics: Extends the traditional needle-in-a-haystack test by incorporating diverse types and quantities of "needles" and introduces new tasks like multi-hop tracing and aggregation. This evaluates more complex behaviors beyond simple retrieval.
    • Scale: Encompasses 13 distinct tasks.
    • Usage: Used for comprehensive evaluation of understanding capabilities and for tuning hyperparameters of SampleAttention at specific lengths (e.g., 16K, 32K, 64K, 128K). Synthetic data from RULER is also used for micro-benchmarks.
  • LongBench (Bai et al., 2023):

    • Description: A multi-task benchmark for long context understanding.
    • Characteristics: Comprises single and multi-document QA, summarization, few-shot learning, synthetic tasks, and code completion.
    • Scale: Over 4,750 test cases with task lengths ranging from 4K to 35K.
    • Domain: Bilingual (English and Chinese).
  • InfiniteBench (Zhang et al., 2024b):

    • Description: A benchmark specifically designed to evaluate language models' ability to handle, understand, and reason over contexts whose average length exceeds 200K tokens.
    • Characteristics: Comprises 10 unique tasks, each crafted to assess different aspects of language processing and comprehension in extended contexts.

Example Data Sample (Conceptual - as the paper doesn't provide specific examples for these benchmarks): For RULER's needle-in-a-haystack, a data sample might involve a very long document (the "haystack") with a specific, rare piece of information (the "needle") embedded within it, and the model is prompted to retrieve this information. For LongBench's multi-document QA, a query might ask a question that requires synthesizing information from several long documents provided as context. For InfiniteBench, a task could involve complex reasoning over an extremely long code base or a dense legal document.

These datasets are chosen because they collectively cover a wide range of sequence lengths (from 4K to over 200K, and even up to 1M with RULER's flexibility) and diverse tasks, effectively validating the method's performance across different long-context scenarios and ensuring robust evaluation.

5.2. Evaluation Metrics

The paper uses accuracy, speedup, and Time-to-First-Token (TTFT) to evaluate SampleAttention; the sparsity ratio is also used in the analysis.

5.2.1. Accuracy

Conceptual Definition: Accuracy, in the context of these benchmarks, refers to how well the LLM performs on specific tasks (e.g., answering questions correctly, generating coherent summaries, retrieving information accurately). For SampleAttention, a key goal is to achieve "near-lossless" accuracy, meaning the performance of the sparse attention mechanism should be almost identical to that of full attention. The specific calculation varies by task (e.g., F1 score for QA, ROUGE for summarization, exact match for retrieval), but the paper often reports a normalized score or a percentage relative to the full attention baseline. The Cumulative Residual Attention (CRA) metric internally helps guide this by ensuring a high percentage of attention recall.

Mathematical Formula: The paper does not provide a single overarching formula for "accuracy" as it is a composite measure across diverse tasks. Instead, it refers to "scores" from benchmarks. For specific tasks within RULER, LongBench, and InfiniteBench, common metrics would be:

  • Exact Match (EM): $ EM = \frac{\text{Number of predictions exactly matching the ground truth}}{\text{Total number of predictions}} $
  • F1 Score: The harmonic mean of precision and recall, often used for QA tasks. $ F1 = 2 \times \frac{\text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}} $ Where:
    • $\text{Precision} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Positives}}$
    • $\text{Recall} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Negatives}}$
  • ROUGE (Recall-Oriented Understudy for Gisting Evaluation): Used for summarization tasks; measures the overlap of n-grams, word sequences, and word pairs between the generated summary and reference summaries. (There is no single simple formula; it involves comparing n-gram counts.)

Symbol Explanation:

  • Number of predictions exactly matching the ground truth: Count of model outputs that are identical to the correct answer.
  • Total number of predictions: Total number of questions or tasks the model attempted.
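  • True Positives: Positive instances correctly predicted as positive (e.g., answer tokens that appear in both the prediction and the reference).
  • False Positives: Negative instances incorrectly predicted as positive (e.g., predicted tokens absent from the reference).
  • False Negatives: Positive instances incorrectly predicted as negative, i.e. missed (e.g., reference tokens absent from the prediction).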
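The EM and F1 formulas can be made concrete with a simplified scorer. This is an illustrative sketch in the style of common QA evaluation (token-overlap F1 over whitespace tokens); the official RULER, LongBench, and InfiniteBench scripts apply their own normalization, which is not reproduced here, and the function names are hypothetical.

```python
from collections import Counter

def exact_match(preds, golds):
    """Fraction of predictions identical to their reference after trimming whitespace."""
    return sum(p.strip() == g.strip() for p, g in zip(preds, golds)) / len(golds)

def token_f1(pred, gold):
    """Token-overlap F1: precision and recall over whitespace-separated tokens."""
    pred_toks, gold_toks = pred.split(), gold.split()
    true_pos = sum((Counter(pred_toks) & Counter(gold_toks)).values())
    if true_pos == 0:
        return 0.0
    precision = true_pos / len(pred_toks)  # TP / (TP + FP)
    recall = true_pos / len(gold_toks)     # TP / (TP + FN)
    return 2 * precision * recall / (precision + recall)

em = exact_match(["Paris", "42"], ["Paris", "41"])
f1 = token_f1("the cat sat", "the cat sat down")
```

Here the prediction "the cat sat" has precision 1.0 and recall 0.75 against "the cat sat down", giving F1 = 6/7; the multiset intersection of the two Counters counts repeated tokens correctly.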

5.2.2. Time-to-First-Token (TTFT)

Conceptual Definition: TTFT is the time elapsed from when a prompt is submitted to an LLM until the first output token of the response is generated. It directly reflects the prefill latency and is a critical measure of responsiveness for interactive applications. A lower TTFT indicates faster initial response.

Mathematical Formula: $ TTFT = T_{prefill} $ Where:

  • $T_{prefill}$: The total time taken by the model to process the entire input prompt and generate the first output token.

Symbol Explanation:

  • $T_{prefill}$: Time taken for the prefill phase.

5.2.3. Speedup

Conceptual Definition: Speedup quantifies the performance improvement of a proposed method relative to a baseline method. It is calculated as the ratio of the baseline's execution time to the proposed method's execution time. A speedup of $X\times$ means the proposed method is $X$ times faster.

Mathematical Formula: $ \text{Speedup} = \frac{TTFT_{Baseline}}{TTFT_{ProposedMethod}} $ Where:

  • $TTFT_{Baseline}$: The Time-to-First-Token of the baseline method (e.g., FlashAttention2).
  • $TTFT_{ProposedMethod}$: The Time-to-First-Token of the proposed method (SampleAttention).

Symbol Explanation:

  • $TTFT_{Baseline}$: The TTFT of the comparison method.
  • $TTFT_{ProposedMethod}$: The TTFT of the SampleAttention method.
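A minimal sketch of how TTFT and speedup could be measured for any callable prefill function; the helper names and the stand-in workloads are hypothetical. Because GPU kernels run asynchronously, a real measurement would need a device synchronization (for example torch.cuda.synchronize() in PyTorch) before each clock read, which this CPU-only sketch omits.

```python
import statistics
import time

def median_latency(fn, repeats=5):
    """Median wall-clock time in seconds of fn() over several runs."""
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)

def speedup(ttft_baseline, ttft_proposed):
    """Speedup = TTFT_Baseline / TTFT_ProposedMethod."""
    return ttft_baseline / ttft_proposed

# Usage with stand-in workloads (not real prefill passes).
base = median_latency(lambda: sum(i * i for i in range(200_000)))
fast = median_latency(lambda: sum(i * i for i in range(50_000)))
ratio = speedup(base, fast)
```

Taking the median rather than the mean keeps one-off outliers, such as a cold first run, from skewing the ratio.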

5.2.4. Sparsity Ratio

Conceptual Definition: The sparsity ratio measures the percentage of attention scores that are masked or not computed. A higher sparsity ratio implies greater computational savings, but must be balanced against accuracy.

Mathematical Formula: $ \text{Sparsity Ratio} = \frac{\text{Number of masked attention scores}}{\text{Total possible attention scores}} \times 100\% $ Where:

  • Number of masked attention scores: The count of query-key pairs where the attention weight is set to (effectively) zero by the mask.
  • Total possible attention scores: The total number of query-key pairs, which is $S_q \times S_k$ for a single head.
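The sparsity ratio is straightforward to compute from a boolean computed/skipped mask. The sketch below (a hypothetical helper) follows the formula with the full $S_q \times S_k$ grid as denominator. As an optional variant not specified in the paper, it can also restrict the count to the causal lower triangle, since the upper triangle is never computed in causal attention anyway.

```python
def sparsity_ratio(mask, causal_only=False):
    """Percentage of query-key pairs that are skipped; mask[i][j] is True if (i, j) is computed."""
    pairs = [(i, j) for i in range(len(mask)) for j in range(len(mask[0]))
             if not causal_only or j <= i]
    skipped = sum(1 for i, j in pairs if not mask[i][j])
    return 100.0 * skipped / len(pairs)

n = 4
dense_causal = [[j <= i for j in range(n)] for i in range(n)]
diagonal = [[j == i for j in range(n)] for i in range(n)]
```

A dense causal mask already scores 37.5% on a 4×4 grid under the full-grid denominator but 0% under the causal one, so the denominator should be stated whenever sparsity figures are compared.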

5.3. Baselines

SampleAttention is compared against full attention and several representative sparse attention methods. All experiments are conducted on a single NVIDIA-A100 GPU (80GB).

  • Full Attention (Gold Baseline):
    • Description: The standard attention mechanism without any sparsification. It serves as the accuracy reference, and its optimized FlashAttention2 implementation is the reference for speedup comparisons.
  • FlashAttention2 (Dao, 2023):
    • Description: An optimized implementation of full attention that significantly improves speed and memory efficiency but still has $O(N^2)$ computational complexity. This is the primary baseline for SampleAttention's speedup comparisons.
  • MInference (Jiang et al., 2024):
    • Description: A dynamic sparse attention method that uses pre-profiling to determine optimal predefined sparse patterns for each attention head.
    • Representativeness: Represents state-of-the-art dynamic sparse methods that attempt to adapt patterns.
  • BigBird (Zaheer et al., 2020):
    • Description: A static sparse attention method that combines local window attention, global tokens, and random attention.
    • Settings: Assigned a window size ratio of 8% and a global ratio of 8%.
    • Representativeness: Represents prominent static sparse methods.
  • StreamingLLM (Xiao et al., 2023b):
    • Description: A method for efficient long context LLM inference that maintains attention sinks and a fixed number of recent tokens. Primarily for decoding, but evaluated for prefill performance in this context.
    • Settings: Sets its initial attention sink at 4 tokens.
    • Representativeness: Represents methods focused on efficient KV cache management and streaming.
  • HyperAttention (Han et al., 2023):
    • Description: Utilizes Locality Sensitive Hashing (LSH) to identify important entries on the attention map.
    • Settings: Both bucket size and the number of sampled columns are set to 256.
    • Representativeness: Represents LSH-based approximation methods.
  • Hash-Sparse (Pagliardini et al., 2023):
    • Description: Another sparse attention method that likely uses hashing to approximate attention. (Specific details on its mechanism are not elaborated in this paper, but it's listed as a baseline).

    • Representativeness: A generalized sparse attention baseline.

      For SampleAttention's evaluation, hyperparameters are tuned using small-scale tasks from the RULER benchmark at specific lengths (e.g., 16K, 32K, 64K, 128K). These tuned parameters are then applied to tasks across different length ranges to achieve an optimal balance.

5.4. Backbones

The SampleAttention method is evaluated on three widely used open-source LLM variants:

  • ChatGLM4-9B: Based on GLM (Du et al., 2021; GLM et al., 2024), with a 1M context window.

  • YI-9B: Features a 200K context window (Young et al., 2024).

  • InternLM2-7B: Also with a 200K context window (Cai et al., 2024).

    All models are decoder-only transformers (Radford et al., 2018), pretrained via causal language modeling. They share architectural components like rotary positional encoding (Su et al., 2024) and grouped-query attention (Ainslie et al., 2023). Notably, some models achieve long context through continued training with extended sequences (e.g., ChatGLM4), while others use rope scaling for length extrapolation (e.g., YI, InternLM2).

SampleAttention and baseline methods only replace the full-attention implementation during the prompt prefill stage. The KV cache remains uncompressed, and dense attention is used in the decoding phase. This isolates the evaluation to prefill latency reduction.

6. Results & Analysis

6.1. Core Results Analysis

The experimental results highlight SampleAttention's superior accuracy-efficiency trade-off and significant TTFT reduction compared to existing sparse attention methods and FlashAttention2.

The following figure (Figure 8 from the original paper) illustrates the trade-offs between accuracy and speedup for various methods:

This figure plots score against relative speedup in six subplots, comparing methods such as SampleAttention and FlashAttention2 on GLM and other models and quantifying the trade-off between speed and accuracy.

Analysis of Figure 8 (Accuracy-Efficiency Trade-off): Figure 8 (from the original paper) presents a Pareto frontier analysis, comparing the relative speedup (vs. FlashAttention2) against the accuracy for different sparse attention methods on the RULER benchmark.

  • SampleAttention consistently achieves scores above 99% of full attention accuracy, demonstrating its "near-lossless" efficiency across various models and sequence lengths.

  • Crucially, SampleAttention establishes a new Pareto frontier, indicating that it achieves higher speedups for a given accuracy level, or higher accuracy for a given speedup, compared to existing methods.

  • For example, in the GLM4-9B model (top row):

    • At 32K context, SampleAttention achieves near full attention accuracy with a speedup of approximately 1.7x. MInference also achieves high accuracy but without a speedup. BigBird offers some speedup but with noticeable accuracy degradation. StreamingLLM and HyperAttention show significant accuracy drops despite some speedup.
    • The trend holds for longer contexts (64K and 128K), where SampleAttention consistently delivers the best combination of high accuracy and considerable speedup.
  • The results show that methods like StreamingLLM and HyperAttention suffer from significant performance degradation across tasks, suggesting their inability to capture critical KV elements during the prefill stage effectively. BigBird shows varying degradation but provides a relatively stable speedup due to its static pattern.

    The following are the results from Table 2 of the original paper:

LongBench:

| Baseline | Single-Doc QA | Multi-Doc QA | Summarization | Few-shot Learning | Synthetic Tasks | Code Completion | Total Score |
|---|---|---|---|---|---|---|---|
| Full Attention | 213.12 | 174.35 | 109.69 | 273.87 | 231.49 | 121.52 | 1124.04 |
| Ours | 214.53 | 174.42 | 108.92 | 278.33 | 234.55 | 125.18 | 1135.93 |
| MInference | 212.14 | 173.37 | 110.02 | 274.45 | 231.87 | 124.37 | 1126.22 |
| BigBird | 207.57 | 146.45 | 95.64 | 272.17 | 161.60 | 117.38 | 1000.81 |
| StreamingLLM | 142.79 | 129.36 | 89.71 | 168.13 | 19.70 | 98.43 | 648.12 |
| HyperAttention | 125.74 | 119.08 | 88.05 | 206.35 | 32.69 | 86.35 | 658.26 |

InfiniteBench:

| Baseline | En.Sum | En.QA | En.MC | En.Dia | Zh.QA | Code.Debug | Math.Find | Retr.PassKey | Retr.Number | Retr.KV |
|---|---|---|---|---|---|---|---|---|---|---|
| Full Attention | 28.30 | 12.17 | 58.95 | 34.00 | 13.22 | 30.71 | 37.71 | 100 | 100 | 44.0 |
| Ours | 28.30 | 16.52 | 61.57 | 31.50 | 14.28 | 31.40 | 37.14 | 100 | 100 | 49.6 |
| MInference | 28.00 | 11.39 | 60.26 | 28.70 | 14.81 | 31.70 | 39.43 | 100 | 100 | 43.0 |

Analysis of Table 2 (Accuracy on LongBench and InfiniteBench): Table 2 details the accuracy scores on LongBench and InfiniteBench for SampleAttention ("Ours") and various baselines.

  • LongBench:
    • SampleAttention (Ours) consistently achieves scores comparable to or even slightly exceeding Full Attention across almost all task types (e.g., Single-Doc QA, Multi-Doc QA, Few-shot Learning, Synthetic Tasks, Code Completion). Its total score (1135.93) is even slightly higher than Full Attention (1124.04), potentially due to subtle regularization effects or being within statistical noise.
    • MInference also stays very close to Full Attention in accuracy but, per the Figure 8 analysis, delivers little or no speedup.
    • BigBird shows a noticeable drop in accuracy compared to Full Attention (e.g., Multi-Doc QA 146.45 vs. 174.35, Summarization 95.64 vs. 109.69).
    • StreamingLLM and HyperAttention exhibit significant accuracy degradation, especially in Synthetic Tasks and Single-Doc QA, demonstrating their limitations for prefill accuracy.
  • InfiniteBench:
    • For tasks designed for extremely long contexts, SampleAttention again stays close to Full Attention. It scores higher on En.QA (16.52 vs. 12.17), En.MC, Zh.QA, Code.Debug, and Retr.KV (49.6 vs. 44.0), ties on En.Sum, Retr.PassKey, and Retr.Number, and trails slightly on En.Dia (31.50 vs. 34.00) and Math.Find (37.14 vs. 37.71).
    • MInference also performs well on InfiniteBench, with accuracy comparable to SampleAttention. Together, these results support SampleAttention's "near-lossless" accuracy claim across diverse tasks and extremely long sequence lengths.
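As a quick check on the LongBench figures in Table 2, the snippet below recomputes each method's total from its six task-type scores and expresses it relative to Full Attention. The numbers are copied from the table; the helper name is hypothetical.

```python
longbench = {  # Single-Doc QA, Multi-Doc QA, Summarization, Few-shot, Synthetic, Code
    "Full Attention": [213.12, 174.35, 109.69, 273.87, 231.49, 121.52],
    "Ours": [214.53, 174.42, 108.92, 278.33, 234.55, 125.18],
    "MInference": [212.14, 173.37, 110.02, 274.45, 231.87, 124.37],
    "BigBird": [207.57, 146.45, 95.64, 272.17, 161.60, 117.38],
    "StreamingLLM": [142.79, 129.36, 89.71, 168.13, 19.70, 98.43],
    "HyperAttention": [125.74, 119.08, 88.05, 206.35, 32.69, 86.35],
}

def relative_totals(table, reference="Full Attention"):
    """Total score of each method as a percentage of the reference method's total."""
    ref_total = sum(table[reference])
    return {name: 100.0 * sum(scores) / ref_total for name, scores in table.items()}

for name, pct in relative_totals(longbench).items():
    print(f"{name:15s} total={sum(longbench[name]):8.2f}  relative={pct:6.2f}%")
```

The recomputed totals match the table (e.g., 1135.93 for SampleAttention), which lands at about 101.1% of Full Attention, versus about 89.0% for BigBird and under 59% for StreamingLLM and HyperAttention.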

The following figure (Figure 11 from the original paper) presents the acceleration speedup benchmarking:

Figure 11. (a) The percentage of time spent on sampling and sparse computation in SampleAttention. (b) Comparison of the TTFT metric using FlashAttention2 as the baseline.

Analysis of Figure 11 (Acceleration Speedup Benchmarking):

  • Figure 11(a) - Sampling Overhead: This chart shows the time breakdown for a full 40-layer model of ChatGLM4-9B across sequence lengths from 32K to 128K. It differentiates between time spent on sampling (dynamic selection of sparse patterns) and sparse computation.
    • As sequence length increases, the relative contribution of sampling overhead to the total execution time diminishes. This is crucial because it implies that the overhead of dynamic pattern identification, while present, becomes proportionally smaller for the scenarios where SampleAttention provides the most benefit (very long contexts).
    • For shorter sequence lengths, the sampling overhead can be a larger fraction, which explains why the performance gains are less pronounced for short sequences (e.g., at 8K, SampleAttention latency is similar to FlashAttention2).
  • Figure 11(b) - TTFT Speedup: This chart compares the TTFT metric of SampleAttention against FlashAttention2 as the baseline, scaling up to 1M sequence length.
    • The TTFT metric shows significant reduction with SampleAttention.
    • At 1M sequence length, SampleAttention reduces TTFT by $5.29\times$ compared with FlashAttention2, enabling much faster initial responses for extremely long contexts.
    • The speedup increases with sequence length, demonstrating the scalability and effectiveness of SampleAttention for the target use case.

6.2. Ablation Studies and Parameter Analysis

The paper conducts ablation studies and analyzes the impact of SampleAttention's hyperparameters, specifically the CRA thresholds $\alpha_c$ and $\alpha_s$ and the number of sampling chunks ($chunk_n$).

The following figure (Figure 9 from the original paper) illustrates the impact of CRA thresholds:

Figure 9. The heatmaps under different cases illustrate the impact of choosing different values of $\alpha_c$ and $\alpha_s$ on accuracy (… calculated blocks). The $chunk_n$ values for the GLM and YI models are set to 2 and 4, respectively. The three panels are (a) GLM4-9B 32K, (b) GLM4-9B 128K, and (c) YI-9B 128K; color intensity reflects the change in accuracy, and each cell lists the corresponding accuracy and a percentage value.

Analysis of Figure 9 (Impact of $\alpha_c$ and $\alpha_s$ on Accuracy): Figure 9 displays heatmaps showing how varying the CRA thresholds for column ($\alpha_c$) and slash ($\alpha_s$) patterns influences accuracy under different conditions (GLM4-9B at 32K and 128K, and YI-9B at 128K).

  • General Trend: For all cases, increasing either $\alpha_c$ or $\alpha_s$ generally improves accuracy. This is expected: higher thresholds retain more attention mass, which reduces sparsity but improves attention recall.

  • Trade-off: This improvement comes at the expense of increased computational load (lower speedup), as more key-value pairs are selected.

  • Model/Pattern Specificity:

    • The YI model at 128K (Figure 9(c)) is noticeably more sensitive to changes in the slash threshold $\alpha_s$, suggesting that slash patterns matter more for accuracy in this model and context length.
    • Figure 9(a) (GLM4-9B at 32K) demonstrates a more balanced response to changes in both thresholds, suggesting column and slash patterns contribute more equally.
  • Efficiency at Lower Thresholds: Figure 9(b) (GLM4-9B at 128K) shows that even smaller thresholds can deliver sufficiently good scores, allowing for higher speedup while maintaining acceptable accuracy. This emphasizes the importance of tuning for the optimal balance.

    The following are the results from Table 3 of the original paper:

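| Model | $(\alpha_c, \alpha_s)$ | $chunk_n$ = 1 | $chunk_n$ = 2 | $chunk_n$ = 4 | $chunk_n$ = 6 |
|---|---|---|---|---|---|
| GLM (128K) | (0.90, 0.90) | 82.89 / 1.92 | 84.70 / 1.89 | 84.02 / 1.70 | 83.62 / 1.53 |
| GLM (128K) | (0.95, 0.95) | 84.17 / 1.64 | 84.04 / 1.60 | 83.14 / 1.46 | 83.73 / 1.33 |
| Yi (128K) | (0.95, 0.95) | 52.81 / 2.17 | 54.54 / 2.12 | 53.10 / 1.89 | 54.58 / 1.71 |
| Yi (128K) | (0.98, 0.98) | 56.24 / 1.29 | 58.58 / 1.25 | 59.37 / 1.21 | 59.36 / 1.13 |

Each cell reports score / speedup.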

Analysis of Table 3 (Impact of $chunk_n$): Table 3 investigates the impact of the number of sampling chunks ($chunk_n$ = 1, 2, 4, 6) on accuracy (score) and speedup for the GLM (128K) and YI (128K) models under fixed CRA thresholds.

  • Optimal $chunk_n$: An appropriate $chunk_n$ yields more cost-effective results. For instance, for GLM(128K) at $(\alpha_c = 0.90, \alpha_s = 0.90)$, increasing $chunk_n$ from 1 (82.89/1.92) to 2 (84.70/1.89) raises accuracy from 82.89 to 84.70 with only a slight drop in speedup (from 1.92x to 1.89x). Two chunks thus represent the attention pattern better without significantly increasing overhead.

  • Excessive $chunk_n$: However, an excessively large $chunk_n$ may not further improve accuracy while reducing speedup because of the added sampling overhead. For GLM(128K) with $(\alpha_c = 0.90, \alpha_s = 0.90)$, increasing $chunk_n$ from 2 to 4 or 6 slightly reduces accuracy (84.70 to 84.02 and 83.62) and steadily reduces speedup (1.89x to 1.70x and 1.53x). By contrast, for Yi(128K) at $(0.98, 0.98)$, accuracy keeps improving up to $chunk_n = 4$ (59.37), so the best value is model-dependent. More samples can improve pattern detection, but they also cost more computation.

    The following are the results from Table 4 of the original paper:

| Sequence Length | Average Sparsity in ChatGLM-6B | Average Sparsity in InternLM-7B |
|---|---|---|
| 4K | 88.00% | 91.13% |
| 8K | 90.74% | 92.72% |
| 16K | 92.52% | 93.89% |
| 32K | 93.88% | 94.83% |
| 64K | 94.89% | 95.89% |
| 128K | 95.84% | 96.67% |

Analysis of Table 4 (Sparsity Analysis for "Needle in a Haystack" task): This table quantifies the inherent sparsity in ChatGLM-6B and InternLM-7B as sequence length increases.

  • Increasing Sparsity with Length: As the sequence length increases, the average sparsity ratio consistently increases for both models. For ChatGLM-6B, it goes from 88.00% at 4K to 95.84% at 128K. For InternLM-7B, it increases from 91.13% to 96.67%.

  • Implication for Optimization: This finding provides strong empirical evidence for the potential of sparse attention methods. It suggests that for longer contexts, a smaller proportion of KV elements is truly essential, validating the core premise of SampleAttention. The paper notes that with each doubling of length, the proportion of needed KV elements (to maintain the same CRA threshold) decreases by approximately 20%.
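The "roughly 20% per doubling" observation can be checked directly against Table 4 by converting sparsity into the share of KV elements still needed (100% minus sparsity) and comparing consecutive lengths. A small sketch; the function name is hypothetical and the values are copied from the table.

```python
table4 = {  # average sparsity (%) at 4K, 8K, 16K, 32K, 64K, 128K
    "ChatGLM-6B": [88.00, 90.74, 92.52, 93.88, 94.89, 95.84],
    "InternLM-7B": [91.13, 92.72, 93.89, 94.83, 95.89, 96.67],
}

def reduction_per_doubling(sparsity_pct):
    """Relative drop in the needed-KV share between each length and its double."""
    needed = [100.0 - s for s in sparsity_pct]
    return [1.0 - later / earlier for earlier, later in zip(needed, needed[1:])]

for model, sparsity in table4.items():
    drops = reduction_per_doubling(sparsity)
    print(model, [f"{d:.1%}" for d in drops], f"mean={sum(drops) / len(drops):.1%}")
```

The per-doubling drops range from about 15% to 23%, averaging about 19% for ChatGLM-6B and 18% for InternLM-7B, consistent with the paper's approximate figure.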

    The following are the results from Table 5 of the original paper:

| Ratio of top-k | HEAD-1 (100% / 0.4%) | HEAD-2 (100% / 0.4%) | HEAD-3 (100% / 0.4%) |
|---|---|---|---|
| 2.5% | 16.35% / 12.74% | 55.43% / 48.40% | 93.20% / 90.44% |
| 10% | 26.91% / 23.33% | 63.89% / 58.63% | 98.32% / 97.62% |
| 20% | 45.99% / 42.14% | 85.92% / 81.98% | 99.14% / 98.43% |
| 40% | 58.21% / 55.34% | 89.07% / 84.21% | 99.41% / 99.12% |
| 80% | 96.30% / 93.65% | 99.15% / 98.08% | 99.98% / 99.66% |

Each cell gives the CRA obtained with a 100% / 0.4% sampling ratio.

Analysis of Table 5 (Effectiveness of Sampling): This table validates SampleAttention's chunked sampling by comparing the CRA achieved with full sampling (100% sampling ratio) against a low sampling ratio (0.4%, with $chunk_n = 2$) across different top-k stripe ratios for three heads.

  • Accurate Approximation: The CRA obtained from the low 0.4% sampling ratio is consistently very close to the CRA obtained from full attention scores. For instance, in HEAD-3, at a 40% top-k ratio, full sampling yields 99.41% CRA, while 0.4% sampling yields 99.12%, a negligible difference.

  • Decreasing Difference with More Strips: The CRA gap between the two sampling ratios generally shrinks as the ratio of top-k stripes grows, though not strictly monotonically (e.g., for HEAD-2 it is 3.94 points at 20% but 4.86 at 40%). As more important stripes are selected, the sampled estimate tends to become more accurate.

  • Efficiency and Stability: This demonstrates that chunked sampling is both simple and efficient in accurately estimating CRA, providing a reliable basis for the Score-Based Key-Value Filtering stage.
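To see how closely 0.4% sampling tracks full sampling, the snippet computes the CRA gap (full minus sampled, in percentage points) for each head and top-k ratio in Table 5. Values are copied from the table; the function name is hypothetical.

```python
table5 = {  # top-k ratio (%) -> (CRA with 100% sampling, CRA with 0.4% sampling)
    "HEAD-1": {2.5: (16.35, 12.74), 10: (26.91, 23.33), 20: (45.99, 42.14),
               40: (58.21, 55.34), 80: (96.30, 93.65)},
    "HEAD-2": {2.5: (55.43, 48.40), 10: (63.89, 58.63), 20: (85.92, 81.98),
               40: (89.07, 84.21), 80: (99.15, 98.08)},
    "HEAD-3": {2.5: (93.20, 90.44), 10: (98.32, 97.62), 20: (99.14, 98.43),
               40: (99.41, 99.12), 80: (99.98, 99.66)},
}

def cra_gaps(table):
    """CRA(full) - CRA(sampled) in percentage points, rounded to 2 decimals."""
    return {head: {r: round(full - sampled, 2) for r, (full, sampled) in cols.items()}
            for head, cols in table.items()}

for head, gaps in cra_gaps(table5).items():
    print(head, gaps)
```

The largest gap is 7.03 points (HEAD-2 at 2.5%), and at the 80% ratio every gap is at most 2.65 points, but the decrease is not strictly monotonic (HEAD-2 rises from 3.94 at 20% to 4.86 at 40%).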

    The following figure (Figure 10 from the original paper) presents the cross-task robustness of tuned hyperparameters:

    Figure 10. Results from offline tuning and evaluation of (a) GLM4-9B and (b) InternLM2-7B across RULER, LongBench, and InfiniteBench benchmarks. Different tasks share the same hyperparameters from offline tuning when sequence lengths fall within the same range.

    Analysis of Figure 10 (Cross-Task Robustness): Figure 10 demonstrates the cross-task robustness of SampleAttention's hyperparameters. It shows results for GLM4-9B (a) and InternLM2-7B (b) where hyperparameters were tuned offline on subsets of the RULER benchmark and then applied to LongBench and InfiniteBench tasks when sequence lengths fall within the same range.

  • Generalization: The results show that hyperparameters tuned on a subset of tasks (from RULER) generalize effectively across diverse benchmarks (LongBench, InfiniteBench). This is a crucial finding, indicating that SampleAttention's tuning method produces robust configurations.

  • Near-Lossless Performance: In the accuracy-optimized settings, GLM4 and InternLM2 remained near-lossless across different LongBench and InfiniteBench tasks, confirming robust cross-domain adaptability.

  • Consistent Speedup: Consistent speedup gains were observed for the same model under identical sequence lengths across tasks, further validating the tuning approach. This implies that users don't need to retune for every new task within a sequence length range, simplifying adoption.

Summary of Hyperparameter Tuning Results (Table 6 from Appendix): The following are the results from Table 6 of the original paper:

| Sequence length | GLM4 $\alpha_c$ | GLM4 $\alpha_s$ | GLM4 $chunk_n$ | InternLM2 $\alpha_c$ | InternLM2 $\alpha_s$ | InternLM2 $chunk_n$ |
|---|---|---|---|---|---|---|
| < 16K | 0.98 / 0.85 | 0.90 / 0.85 | 1 / 1 | 0.98 / 0.85 | 0.95 / 0.85 | 1 / 1 |
| [16K, 48K) | 0.95 / 0.85 | 0.95 / 0.80 | 1 / 1 | 0.95 / 0.80 | 0.90 / 0.80 | 1 / 1 |
| [48K, 80K) | 0.95 / 0.80 | 0.95 / 0.80 | 1 / 1 | 0.92 / 0.80 | 0.92 / 0.80 | 1 / 1 |
| [80K, 112K) | 0.90 / 0.80 | 0.90 / 0.80 | 2 / 1 | 0.92 / 0.80 | 0.90 / 0.80 | 2 / 1 |
| >= 112K | 0.95 / 0.80 | 0.85 / 0.80 | 1 / 1 | 0.95 / 0.80 | 0.90 / 0.80 | 2 / 1 |

Each cell gives the Acc / Spd setting.

Table 6 (from Appendix) provides the specific hyperparameter settings (CRA Column (α_c), CRA Slash (α_s), and Num of Chunks (chunk_n)) for GLM4 and InternLM2 across different sequence length ranges, optimized for the highest speedups while retaining at least 90% accuracy.

  • The values show dynamic adaptation: for instance, α_c and α_s vary across length ranges and between models, reflecting the adaptive nature of sparsity.
  • chunk_n is usually set to 1, indicating that even a single chunk can be effective, but for some longer ranges (e.g., [80K, 112K) for GLM4 and InternLM2, and >=112K for InternLM2) it is increased to 2, suggesting that more granular sampling is needed to maintain accuracy at greater lengths.
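At inference time, length-bucketed settings like those in Table 6 behave as a small lookup table. A minimal sketch, assuming the accuracy-optimized (Acc) entries of Table 6 and assuming "16K" means 16 × 1024 tokens (the paper may use a different convention); `Setting`, `TABLE6_ACC`, and `lookup` are hypothetical names:

```python
from bisect import bisect_right
from typing import NamedTuple

K = 1024  # assumption: "16K" means 16 * 1024 tokens

# Exclusive upper bounds of the first four ranges; lengths >= 112K fall in the last bucket.
BOUNDS = [16 * K, 48 * K, 80 * K, 112 * K]


class Setting(NamedTuple):
    alpha_c: float  # CRA threshold for column stripes
    alpha_s: float  # CRA threshold for slash stripes
    chunk_n: int    # number of sampled query chunks


# Accuracy-optimized (Acc) entries of Table 6, one per length range.
TABLE6_ACC = {
    "GLM4": [Setting(0.98, 0.90, 1), Setting(0.95, 0.95, 1), Setting(0.95, 0.95, 1),
             Setting(0.90, 0.90, 2), Setting(0.95, 0.85, 1)],
    "InternLM2": [Setting(0.98, 0.95, 1), Setting(0.95, 0.90, 1), Setting(0.92, 0.92, 1),
                  Setting(0.92, 0.90, 2), Setting(0.95, 0.90, 2)],
}


def lookup(model: str, seq_len: int) -> Setting:
    """Return the tuned setting for the length range that contains seq_len."""
    return TABLE6_ACC[model][bisect_right(BOUNDS, seq_len)]


print(lookup("GLM4", 100 * K))
print(lookup("InternLM2", 120 * K))
```

`bisect_right` places a length exactly on a boundary (e.g., 16K) in the higher bucket, matching half-open ranges such as [16K, 48K).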

6.3. Qualitative Analysis (Attention Visualizations)

The following figures (Figures 12 and 13 from the original paper) show visualizations of attention patterns:

Figure 12: Attention patterns at different layers (Layer 0, Layer 4, Layer 8, Layer 12) as the sequence length varies. Each panel shows the attention matrix of the corresponding layer; color reflects attention strength and sparsity.

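Figure 13: Attention patterns at different layers (Layer 16, Layer 20, Layer 24) at sequence length k. Each panel shows the distribution of attention weights, illustrating the dynamic variation of layer-specific sparse patterns.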

Analysis of Figures 12 & 13 (Visualization of Attention Score): Figures 12 and 13 display heatmaps of sparse attention patterns across different attention heads in ChatGLM3-6B (28 layers × 32 heads) for a 61K sequence length. These visualizations are based on row-by-row filtering with a CRA threshold of α = 0.95.

  • Column Stripes: Figure 12 (Layer 0, Head 10) clearly shows prominent column stripes, particularly towards the beginning of the context. These embody crucial global contextual information (e.g., attention sinks or important initial tokens).

  • Slash Stripes: Figure 12 (Layer 4, Head 19) and Figure 13 (Layer 16, Head 4) illustrate slash (diagonal) patterns, in which each query attends to keys at a fixed relative distance. These connect tokens separated by regular intervals; the band along the main diagonal captures the local window of recent tokens.

  • Hybrid Patterns: Figure 12 (Layer 8, Head 24) and Figure 13 (Layer 20, Head 29) demonstrate intricate combinations of both column and slash patterns on a single head. This provides strong visual evidence for the paper's claim that attention patterns are complex and often a composition of these two primary types.

  • Head-Specific Variation: The visualizations clearly show that the specific sparse patterns vary significantly across different attention heads and layers, reinforcing the need for an adaptive, head-specific approach like SampleAttention.

    These visualizations serve as strong qualitative evidence supporting the paper's core hypothesis about the existence of dynamic and structured column and slash patterns in LLM attention, which SampleAttention is designed to exploit.
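The column, slash, and local-window structures discussed above can be combined into one causal keep-mask. This is a minimal sketch with hypothetical parameters (`columns`, `slash_offsets`, `window_ratio`); it shows how the patterns compose, not how SampleAttention selects them.

```python
def structured_mask(n, columns, slash_offsets, window_ratio):
    """Causal keep-mask combining three structured patterns (hypothetical parameters):
    - column stripes: every query keeps the selected key columns,
    - slash stripes: query i keeps key i - o for each selected offset o,
    - local window: the nearest window_ratio * n keys (offsets 0 .. window - 1).
    """
    window = max(1, int(window_ratio * n))
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):  # causal: only keys j <= i are eligible
            offset = i - j
            mask[i][j] = j in columns or offset in slash_offsets or offset < window
    return mask


mask = structured_mask(12, columns={0, 1}, slash_offsets={6}, window_ratio=0.2)
for row in mask:
    print("".join("#" if keep else "." for keep in row))
kept = sum(map(sum, mask))
print(f"kept {kept} of {12 * 13 // 2} causal entries")
```

Printed as ASCII, column stripes appear as vertical runs of `#`, the slash stripe as a diagonal parallel to the main one, and the local window as the band along the main diagonal.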

7. Conclusion & Reflections

7.1. Conclusion Summary

This paper, "SampleAttention: Near-Lossless Acceleration of Long Context LLM Inference with Adaptive Structured Sparse Attention", successfully addresses the critical challenge of high Time-to-First-Token (TTFT) latency in Large Language Models (LLMs) operating with extremely long context windows. The authors identify that the inherent sparsity of the attention mechanism can be leveraged, but only if the adaptive sparsity ratios and dynamic structured patterns (primarily column and slash stripes) across different attention heads and input contents are effectively captured at runtime.

SampleAttention is proposed as an innovative solution that adaptively determines these sparse patterns. Its key contributions include the introduction of Cumulative Residual Attention (CRA) as a robust indicator for guiding accuracy-efficiency trade-offs, and a novel two-stage query-guided key-value filtering approach. This two-stage process efficiently estimates the attention distribution through chunked sampling and then filters a minimal set of critical key-value blocks based on decomposed CRA thresholds for column and slash patterns. Complemented by automated hyperparameter tuning and hardware-efficient FlashAttention2 kernels, SampleAttention achieves substantial speedups. Experimental results demonstrate that it can seamlessly replace vanilla attention in off-the-shelf LLMs (ChatGLM, Yi, InternLM) with consistently near-lossless accuracy (often >99% of full attention) while reducing TTFT by up to 5.29× for 1-million-token contexts.
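The two-stage query-guided filtering summarized above can be sketched as follows, under simplifying assumptions: token-level rather than block-level scores, a single head, and stage 1 computing exact attention only for a given set of sampled query rows. Column scores aggregate the mass each key receives; slash scores aggregate mass by relative offset i - j. Stage 2 keeps the fewest columns and offsets whose score mass reaches α_c and α_s respectively. All function names are hypothetical, and the actual method works on key-value blocks inside fused GPU kernels.

```python
import math
import random


def row_probs(qi, keys, i):
    """Causal softmax of one query row i over keys 0..i."""
    d = len(qi)
    logits = [sum(a * b for a, b in zip(qi, keys[j])) / math.sqrt(d) for j in range(i + 1)]
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def min_cover(scores, alpha):
    """Fewest indices whose scores sum to >= alpha * total (greedy by descending score)."""
    total = sum(scores.values())
    chosen, acc = set(), 0.0
    for idx, s in sorted(scores.items(), key=lambda kv: -kv[1]):
        if acc >= alpha * total:
            break
        chosen.add(idx)
        acc += s
    return chosen


def two_stage_filter(q, keys, sampled_rows, alpha_c, alpha_s):
    """Stage 1: score columns and slash offsets from the sampled rows only.
    Stage 2: keep the minimal column / offset sets reaching alpha_c / alpha_s."""
    col_scores, slash_scores = {}, {}
    for i in sampled_rows:
        for j, p in enumerate(row_probs(q[i], keys, i)):
            col_scores[j] = col_scores.get(j, 0.0) + p              # column stripe j
            slash_scores[i - j] = slash_scores.get(i - j, 0.0) + p  # diagonal at offset i - j
    return min_cover(col_scores, alpha_c), min_cover(slash_scores, alpha_s)


random.seed(1)
n, d = 128, 16
q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
keys = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
cols, offsets = two_stage_filter(q, keys, range(n - 16, n), alpha_c=0.95, alpha_s=0.90)
print(f"kept {len(cols)} of {n} columns and {len(offsets)} slash offsets")
```

Selecting greedily by descending score is what makes each set minimal: with nonnegative scores, no smaller set can reach the same cumulative mass.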

7.2. Limitations & Future Work

The paper implicitly or explicitly points to several limitations and potential avenues for future work:

  • Sampling Overhead for Short Sequences: As shown in Figure 11(a), for shorter sequence lengths (e.g., 8K), the sampling overhead of SampleAttention can be significant enough that its latency remains nearly identical to FlashAttention2. This implies that SampleAttention's benefits are primarily realized for very long contexts, and it may not be the optimal choice for shorter prefill prompts.
  • Hyperparameter Tuning Complexity: While the paper proposes an automated offline tuning method, this still requires a small profiling dataset and multi-task tuning across different length ranges. This process, though automated, adds an initial setup cost for each new model or architecture. The paper mentions "cost-effective threshold values," implying a balance that might not be universally simple to find for all scenarios.
  • Orthogonality with KV Cache Compression: The paper explicitly states that SampleAttention focuses on prefill computation overhead and is orthogonal to KV cache eviction or compression approaches. This means while SampleAttention improves TTFT, it doesn't directly address the memory consumption of the KV cache during the decoding phase.
  • Generalization of CRA: While CRA is shown to correlate with accuracy, its universal applicability and optimal threshold setting across all possible LLM tasks and architectures might require further theoretical backing or extensive empirical validation beyond the benchmarks presented.
  • Dynamic chunk_n: Currently, chunk_n is a tuned hyperparameter fixed per length range. Future work could explore choosing chunk_n adaptively at runtime, much as the selected column and slash sets already adapt to each input under the α_c and α_s thresholds, to further reduce the sampling overhead.
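As one illustration of the theoretical backing CRA could draw on, here is a formalization offered as an assumption rather than the paper's own definition. Let P be the full causal attention matrix, S_i the key indices retained for query row i, m_i the retained mass of that row, and CRA the worst-case retained mass over rows. If the sparse kernel renormalizes the softmax over S_i (producing õ_i), a short derivation, not taken from the paper, bounds the per-row output error by the discarded mass:

```latex
\begin{aligned}
P_{ij} &= \frac{\exp\!\left(q_i k_j^{\top}/\sqrt{d}\right)}{\sum_{l \le i} \exp\!\left(q_i k_l^{\top}/\sqrt{d}\right)}, \quad j \le i,
\qquad m_i = \sum_{j \in \mathcal{S}_i} P_{ij},
\qquad \mathrm{CRA}(\mathcal{S}) = \min_i m_i, \\
o_i &= \sum_{j \le i} P_{ij} v_j,
\qquad \tilde{o}_i = \frac{1}{m_i} \sum_{j \in \mathcal{S}_i} P_{ij} v_j, \\
o_i - \tilde{o}_i &= \sum_{j \le i,\, j \notin \mathcal{S}_i} P_{ij} v_j + \Bigl(1 - \frac{1}{m_i}\Bigr) \sum_{j \in \mathcal{S}_i} P_{ij} v_j, \\
\lVert o_i - \tilde{o}_i \rVert &\le (1 - m_i)\, V + \Bigl(\frac{1}{m_i} - 1\Bigr) m_i\, V = 2\,(1 - m_i)\, V,
\qquad V = \max_{j \le i} \lVert v_j \rVert .
\end{aligned}
```

Under these assumptions (each S_i non-empty, so m_i > 0), CRA(S) ≥ α implies an output error of at most 2(1 − α)V on every row, which is one way to read CRA as an accuracy guarantee.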

7.3. Personal Insights & Critique

The paper presents a highly practical and effective solution to a major bottleneck in long context LLM inference.

Personal Insights:

  • Principled Accuracy Guidance: The introduction of CRA as a robust indicator of attention recall is a significant insight. It provides a principled, quantifiable way to navigate the accuracy-efficiency trade-off, moving beyond heuristic or fixed sparse patterns. This is a crucial differentiator from many prior sparse attention methods that often sacrificed accuracy arbitrarily.
  • Adaptive Structured Sparsity: The empirical observation that attention patterns are head-specific, content-aware, and model-aware, and often manifest as combinations of column and slash stripes, is very valuable. SampleAttention's ability to dynamically identify and combine these patterns at runtime is a robust approach. The visualizations in the appendix strongly support this claim.
  • Practicality and Generalizability: The method's ability to seamlessly replace vanilla attention in off-the-shelf LLMs without additional pretraining or finetuning is a huge advantage for practical adoption. The cross-task robustness of tuned hyperparameters further enhances its usability.
  • Hardware-Aware Design: Integrating with FlashAttention2 and implementing operator fusion demonstrates a deep understanding of practical system-level optimization, ensuring that the theoretical gains translate into real-world wall-clock speedups.

Critique / Areas for Improvement:

  • Defining CRA more concretely: While the definition is given, a more precise mathematical formulation of CRA itself (beyond "minimum sum of remaining attention probabilities") might be beneficial for theoretical rigor. How is "remaining attention probability" specifically calculated after sparsification and how does it relate to the softmax in the sparse attention formula?

  • Clarity on block_reduction and merge_index: Algorithm 1 is concise but lacks specific details on the block_reduction and merge_index functions. For a truly beginner-friendly explanation, the exact operations (e.g., sum, max, average for reduction; logical OR for merging) would be helpful.

  • Sensitivity to blk parameter: The block size blk is mentioned in Algorithm 1 but not explicitly discussed as a hyperparameter or tuned. Its impact on the granularity of pattern detection and overhead could be significant.

  • Comparison to Hash-Sparse: Hash-Sparse is listed as a baseline, but its performance is not explicitly detailed in the main Result & Analysis tables (Table 2 and Figure 8 do not include it). Providing its performance would complete the comparison with all listed baselines.

  • Energy Efficiency: While speedup is crucial, long context inference also has significant energy consumption implications. An analysis of SampleAttention's energy efficiency compared to FlashAttention2 and other baselines would be a valuable addition.
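Because Algorithm 1 does not spell out `block_reduction` and `merge_index`, here is one plausible reading, labeled as an assumption: reduce per-token scores to per-block scores with a configurable reduction (sum by default, max or mean as alternatives), and merge the block indices chosen by the column and slash stages with a set union (logical OR). Both functions are hypothetical stand-ins, not the paper's code.

```python
from typing import Callable, Iterable, List, Sequence


def block_reduction(scores: Sequence[float], blk: int,
                    reduce: Callable[[Sequence[float]], float] = sum) -> List[float]:
    """Pool per-token scores into per-block scores (blocks of blk tokens; the last may be partial)."""
    return [reduce(scores[s:s + blk]) for s in range(0, len(scores), blk)]


def merge_index(*index_sets: Iterable[int]) -> List[int]:
    """Union (logical OR) of block indices selected by the column and slash stages."""
    merged = set()
    for indices in index_sets:
        merged.update(indices)
    return sorted(merged)


col_scores = [4, 1, 0, 0, 0, 0, 3, 2]
print(block_reduction(col_scores, blk=4))              # [5, 5]  (sum per block)
print(block_reduction(col_scores, blk=4, reduce=max))  # [4, 3]  (max per block)
print(merge_index([0, 2], [2, 3]))                     # [0, 2, 3]
```

The choice of `reduce` and `blk` changes what the filter sees: summation favors blocks with many moderately attended tokens, while max favors blocks containing a single very strong token, which is one concrete way the sensitivity to `blk` could surface.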

    Overall, SampleAttention is an impressive piece of work that makes significant strides in accelerating long context LLM inference while maintaining high accuracy, pushing the boundaries of what's possible for real-time, high-context LLM applications. Its adaptive and principled approach provides a strong foundation for future research in sparse attention.
