
Quest: Query-Aware Sparsity for Efficient Long-Context LLM Inference

Published: 06/16/2024
This analysis is AI-generated and may not be fully accurate. Please refer to the original paper.

TL;DR Summary

This paper presents Quest, a query-aware KV cache selection algorithm that improves the efficiency of long-context LLM inference. By tracking the minimal and maximal Key values of each KV cache page and using the current Query to select only the critical pages, Quest achieves up to 7.03x speedup in self-attention while maintaining high accuracy on long-dependency tasks.

Abstract

As the demand for long-context large language models (LLMs) increases, models with context windows of up to 128K or 1M tokens are becoming increasingly prevalent. However, long-context LLM inference is challenging since the inference speed decreases significantly as the sequence length grows. This slowdown is primarily caused by loading a large KV cache during self-attention. Previous works have shown that a small portion of critical tokens will dominate the attention outcomes. However, we observe the criticality of a token highly depends on the query. To this end, we propose Quest, a query-aware KV cache selection algorithm. Quest keeps track of the minimal and maximal Key values in KV cache pages and estimates the criticality of a given page using Query vectors. By only loading the Top-K critical KV cache pages for attention, Quest significantly speeds up self-attention without sacrificing accuracy. We show that Quest can achieve up to 2.23x self-attention speedup, which reduces inference latency by 7.03x while performing well on tasks with long dependencies with negligible accuracy loss. Code is available at http://github.com/mit-han-lab/Quest .


1. Bibliographic Information

1.1. Title

The central topic of the paper is "Quest: Query-Aware Sparsity for Efficient Long-Context LLM Inference". This title indicates that the paper proposes a method named Quest, which leverages a concept called query-aware sparsity to improve the efficiency of inference for Large Language Models (LLMs) that handle long contexts.

1.2. Authors

The authors of the paper are Jiaming Tang, Yilong Zhao, Kan Zhu, Guangxuan Xiao, Baris Kasikci, and Song Han. Their affiliation markers, as numbered in the paper, are:

  • Jiaming Tang: 1, 2
  • Yilong Zhao: 1, 3
  • Kan Zhu: 3
  • Guangxuan Xiao: 2
  • Baris Kasikci: 3
  • Song Han: 2, 4

The institution names are not reproduced in this analysis. Based on the code repository link (github.com/mit-han-lab/Quest), affiliation 2 most likely corresponds to MIT (the MIT Han Lab).

1.3. Journal/Conference

The analyzed version was posted on arXiv, a preprint server widely used for the rapid dissemination of AI research, often ahead of formal peer-reviewed publication. The paper also appeared at ICML 2024.

1.4. Publication Year

The arXiv posting is timestamped 2024-06-16T01:33:02Z (UTC), so the paper was published in 2024.

1.5. Abstract

The paper addresses the challenge of efficient inference for long-context Large Language Models (LLMs), which increasingly feature context windows of up to 128K or 1M tokens. The primary bottleneck for inference speed in these models is identified as the loading of a large KV cache during self-attention. While previous work noted that only a small fraction of tokens are critical for attention outcomes, the authors observe that this criticality is highly dependent on the current query.

To address this, the paper proposes Quest, a query-aware KV cache selection algorithm. Quest maintains minimal and maximal Key values for each KV cache page as metadata. During inference, it uses the current Query vector along with this metadata to estimate the criticality of each page. By only loading the Top-K critical KV cache pages for self-attention, Quest aims to significantly accelerate self-attention without compromising accuracy.

The experimental results show that Quest achieves up to 7.03× self-attention speedup, which translates to up to a 2.23× reduction in end-to-end inference latency, while maintaining strong performance with negligible accuracy loss on tasks requiring long dependencies.

2. Executive Summary

2.1. Background & Motivation

The rapid advancement and adoption of Large Language Models (LLMs) have led to a growing demand for long-context capabilities, enabling applications like multi-round conversations and long document analysis. Modern LLMs can now process context windows ranging from 128K to even 1M tokens. However, this extended context window introduces a significant performance bottleneck during inference.

The core problem the paper aims to solve is the substantial slowdown in LLM inference speed as the sequence length grows. This slowdown is primarily attributed to the need to load a large Key-Value (KV) cache for every self-attention operation, especially during the decode stage where each new token generation requires attending to all previously generated tokens. For instance, a Llama 7B model with a 32K context length can have a KV cache occupying 16GB, and loading this cache can account for over 50% of the inference latency.
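The 16GB figure can be reproduced with back-of-the-envelope arithmetic. This sketch assumes Llama-7B-style dimensions (32 layers, hidden size 4096), one Key and one Value vector per token per layer, FP16 storage (2 bytes per element), and a single sequence; these dimensions are assumptions of the sketch rather than values stated above, and `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(seq_len: int, n_layers: int = 32, hidden: int = 4096,
                   bytes_per_elem: int = 2) -> int:
    """Bytes needed to cache Keys and Values for `seq_len` tokens.

    Assumes Llama-7B-like dimensions (32 layers, hidden size 4096)
    and FP16 storage (2 bytes per element); adjust for other models.
    """
    per_token = 2 * n_layers * hidden * bytes_per_elem  # one K and one V per layer
    return seq_len * per_token


size = kv_cache_bytes(32 * 1024)
print(size / 2**30, "GiB")  # prints: 16.0 GiB (0.5 MiB per token x 32K tokens)
```

Because every decode step must read this whole cache, the per-token cost of self-attention grows linearly with context length, which is the bottleneck Quest targets.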

Previous research has identified that not all tokens in the KV cache are equally important; only a small subset of critical tokens truly dictates the attention outcomes. This insight suggests a potential for sparsity in self-attention to reduce the computational and memory load. However, the paper introduces a crucial observation: the criticality of a token is not static but highly dependent on the query. A token considered non-critical for one query might become highly critical for another. This dynamic nature means that query-agnostic pruning or eviction strategies (which discard tokens based on historical importance or fixed windows) might prematurely remove essential information, leading to accuracy degradation, especially in tasks requiring long-range dependencies.

Therefore, the paper's key idea is a query-aware mechanism that identifies critical tokens (or KV cache pages) on the fly, so that only the most relevant parts of the KV cache are accessed for self-attention, accelerating inference while preserving accuracy.

2.2. Main Contributions / Findings

The primary contributions and key findings of the paper are:

  • Observation of Query-Aware Sparsity: The paper highlights and quantifies the observation that the criticality of KV cache tokens is highly dynamic and depends significantly on the current query token. This insight forms the foundation for their proposed method.

  • Quest Algorithm: The introduction of Quest, an efficient and accurate query-aware KV cache selection algorithm. Quest operates at page granularity, storing minimal and maximal Key values as metadata for each KV cache page. It then uses the Query vector and this metadata to estimate page criticality, selecting only the Top-K critical pages for self-attention.

  • Dedicated Operator Designs and Implementations: Quest is implemented with specialized CUDA kernels built upon FlashInfer, ensuring practical acceleration and demonstrating the feasibility of its approach.

  • Significant Efficiency Gains:

    • Achieves up to 7.03× self-attention latency reduction compared to FlashInfer at 32K context length with a 2048-token budget.
    • Leads to a 2.23× end-to-end inference speedup (decode stage) with 4-bit weight quantization and a 1.74× speedup with FP16 weights at 32K sequence length.
    • Reduces memory movement by loading only lightweight metadata for every page plus the full content of the selected pages.
  • Negligible Accuracy Loss: Quest maintains high accuracy across various long-context benchmarks, including PG19 language modeling, passkey retrieval (10K and 100K context lengths), and six datasets from LongBench, often matching or closely approaching full-KV-cache performance while using a much smaller KV cache budget.

  • Outperformance of Baselines: Compared to existing KV cache eviction algorithms such as H2O, TOVA, and StreamingLLM, Quest consistently achieves higher accuracy for a given KV cache budget and up to 4.5× self-attention latency reduction for comparable accuracy targets.

    In essence, Quest solves the problem of inefficient long-context LLM inference by intelligently and dynamically pruning the KV cache based on the current query, providing substantial speedups without sacrificing the model's ability to handle long-range dependencies.

3. Prerequisite Knowledge & Related Work

3.1. Foundational Concepts

To understand Quest, a beginner needs to grasp several core concepts related to Large Language Models (LLMs) and their inference mechanisms.

  • Large Language Models (LLMs): These are advanced artificial intelligence models, typically based on the Transformer architecture, trained on vast amounts of text data to understand, generate, and process human language. They are known for their ability to perform various Natural Language Processing (NLP) tasks, such as text generation, translation, summarization, and question answering. Examples include GPT-series, Llama, and Claude.

  • Transformer Architecture: The Transformer is a neural network architecture introduced by Vaswani et al. (2017) that revolutionized sequence modeling. Its key innovation is the self-attention mechanism, which allows the model to weigh the importance of different words in an input sequence when processing each word. It consists of encoder and decoder blocks, though many modern LLMs primarily use the decoder-only variant for generative tasks.

  • Self-Attention Mechanism: This is a core component of the Transformer architecture. It enables a model to consider different parts of the input sequence (or previous output sequence in a decoder) when producing an output for a specific position. For each token, self-attention calculates three vectors:

    • Query (Q): Represents the current token asking "what should I pay attention to?"
    • Key (K): Represents all tokens offering "what I have to offer."
    • Value (V): Represents the actual information content of all tokens.

    The attention score between a Query and a Key determines how much of the corresponding Value to "attend" to. The output of self-attention for a given query is a weighted sum of the Value vectors, with weights derived from the attention scores. The standard formula for attention is:

    $ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $

    Where:

    • $Q$ is the matrix of Query vectors; each row is the query for one token.
    • $K$ is the matrix of Key vectors; each row is the key for one token.
    • $V$ is the matrix of Value vectors; each row is the value for one token.
    • $QK^T$ computes the dot product (similarity) between each query and all keys.
    • $\sqrt{d_k}$ is a scaling factor, where $d_k$ is the dimension of the Key vectors. The scaling keeps the dot products from growing too large, which would push the softmax into regions with extremely small gradients.
    • $\mathrm{softmax}$ converts the scaled scores into probabilities, so the attention weights for each query sum to 1.
    • The result is a weighted sum of the Value vectors, where the weights indicate the importance of each Value to the query.
  • KV Cache (Key-Value Cache): In auto-regressive generation (token-by-token decoding) in LLMs, for each new token generated, the model needs to perform self-attention over all previously generated tokens. Recomputing the Key and Value vectors for all past tokens at each step is computationally expensive. To optimize this, LLMs typically store the Key and Value vectors of past tokens in a KV cache (a memory buffer). For a new token, its Query vector is computed and then used to attend to the Keys and Values stored in the KV cache. As more tokens are generated, the KV cache grows, leading to increased memory consumption and memory bandwidth requirements for loading it.

  • Context Window / Context Length: This refers to the maximum number of tokens an LLM can consider as input at any given time. A long-context LLM implies a large context window, allowing the model to process and understand longer texts or maintain longer conversations. The size of the KV cache is directly proportional to the context length.

  • Inference Latency & Speedup:

    • Inference Latency: The time it takes for an LLM to generate a response or a single token. Lower latency is better.
    • Speedup: A measure of how much faster a new method is than a baseline. If a method takes $T_{new}$ and the baseline takes $T_{baseline}$, the speedup is $T_{baseline} / T_{new}$. For example, a 2× speedup means the new method is twice as fast.
  • Perplexity (PPL): A common metric for evaluating language models. It measures how well a probability distribution (the language model) predicts a sample. Lower perplexity indicates a better model, meaning it assigns higher probabilities to the actual sequence of tokens. The perplexity of a sequence of $N$ tokens $W = (w_1, w_2, ..., w_N)$ is:

    $ \mathrm{PPL}(W) = P(w_1, w_2, ..., w_N)^{-\frac{1}{N}} = \sqrt[N]{\frac{1}{P(w_1, w_2, ..., w_N)}} $

    In practice, it is computed from log probabilities to avoid numerical underflow:

    $ \mathrm{PPL}(W) = \exp\left(-\frac{1}{N} \sum_{i=1}^N \log P(w_i | w_1, ..., w_{i-1})\right) $

    Where:

    • $W$ is the sequence of tokens.
    • $N$ is the total number of tokens in the sequence.
    • $P(w_1, ..., w_N)$ is the joint probability of the entire sequence.
    • $P(w_i | w_1, ..., w_{i-1})$ is the probability of token $w_i$ given all previous tokens $w_1, ..., w_{i-1}$, as predicted by the language model.
    • $\log$ denotes the natural logarithm.
    • $\exp$ is the exponential function ($e$ raised to the power of the argument).
  • FlashAttention: An optimized self-attention algorithm that reduces memory usage and improves speed by making the attention computation IO-aware (Input/Output-aware). It computes attention in blocks held in fast on-chip SRAM and avoids writing large intermediate matrices to High-Bandwidth Memory (HBM). This makes FlashAttention particularly beneficial for long sequences, but during decoding it still has to load the entire KV cache.

  • PageAttention (Paged KV Cache): A memory management technique for the KV cache, analogous to virtual-memory paging in operating systems and known as PagedAttention in the vLLM serving system. Instead of allocating one contiguous block of memory for the KV cache, it divides the cache into fixed-size pages that can be stored non-contiguously, allowing more flexible and efficient memory allocation, especially in multi-user or dynamic batching scenarios. Quest builds on this design by performing KV cache selection at page granularity.
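The decode-time interplay of attention, the KV cache, and perplexity described above can be sketched without any dependencies. This is a minimal illustration under simplifying assumptions (one attention head, one sequence, no batching or masking); the names `attend`, `KVCache`, and `perplexity` are hypothetical, not from any library.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    """Turn raw scores into weights that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """softmax(q K^T / sqrt(d_k)) V for a single query vector q."""
    d_k = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
    weights = softmax(scores)
    return [sum(w * v[j] for w, v in zip(weights, values))
            for j in range(len(values[0]))]


class KVCache:
    """Append-only cache of past Key/Value vectors for one attention head."""

    def __init__(self) -> None:
        self.keys: List[Vector] = []
        self.values: List[Vector] = []

    def decode_step(self, q: Vector, k: Vector, v: Vector) -> Vector:
        # Store the new token's K and V once; past tokens are never recomputed.
        self.keys.append(k)
        self.values.append(v)
        # The new query attends to every cached token, including itself,
        # so the data read per step grows linearly with the context length.
        return attend(q, self.keys, self.values)


def perplexity(token_log_probs: Vector) -> float:
    """exp(-(1/N) * sum_i log P(w_i | w_1..w_{i-1})), natural logs."""
    return math.exp(-sum(token_log_probs) / len(token_log_probs))
```

For example, a model that assigns probability 0.25 to each of four observed tokens has perplexity 4 (up to floating-point rounding), matching the $N$-th-root form of the definition.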

3.2. Previous Works

The paper contextualizes Quest by discussing existing approaches to long-context LLM inference and KV cache compression/eviction. These prior works largely focus on identifying and discarding less important parts of the KV cache to save memory and computation.

  • H2O (Heavy-Hitter Oracle): (Zhang et al., 2023b) This method aims to retain only a limited budget of important KV cache tokens. It determines importance based on the sum of historical attention scores. Tokens that have accumulated high attention scores over time are considered heavy-hitters and are kept.

    • Limitation: H2O relies on historical information. As highlighted by Quest, token criticality is dynamic and query-dependent: a token important in the past might not be important for the current query, and a historically unimportant token might become critical. Discarding tokens based on history can therefore lead to low recall of the tokens that are truly critical for the current query. Additionally, gathering historical scores often requires computing the full $O(N^2)$ attention map, which conflicts with optimizations like FlashAttention that never write this map to memory.
  • FastGen: (Ge et al., 2024) This work refines the idea of KV cache selection by applying a more sophisticated strategy for choosing which tokens to keep. It categorizes tokens into different types and uses an adaptive KV cache compression approach.

    • Limitation: Similar to H2O, if its selection strategy is not query-aware, it risks discarding tokens that might be critical for a future query.
  • TOVA (Transformers are Multi-State RNNs): (Oren et al., 2024) This method simplifies the KV cache eviction policy. It decides which tokens to permanently discard based solely on the current query.

    • Limitation: While query-aware to some extent in its eviction decision, if it permanently discards tokens, it still faces the risk of losing information that might become critical for future queries. Quest, in contrast, doesn't discard tokens from the total KV cache but rather selects a subset of pages to attend to.
  • StreamingLLM: (Xiao et al., 2023) This approach addresses infinitely long texts by using attention sinks and a finite KV cache. It essentially keeps a few initial tokens (sinks) and a sliding window of recent tokens, discarding older tokens outside this window.

    • Limitation: StreamingLLM can only focus on the most recent context window. If critical information or a passkey appears outside this window, it cannot provide the correct answer, thus failing long-dependency tasks where important information might be far in the past.
  • SparQ (SparQ Attention): (Ribar et al., 2023) This method aims to compute approximate attention scores by channel pruning and then selects important tokens based on these approximations.

    • Limitation: The paper notes that SparQ has not been widely validated for tasks with long dependencies, and channel-level sparsity might pose challenges for practical acceleration. Quest's approach of approximating the upper bound of attention scores at page granularity and then performing full attention on selected pages is potentially more robust for accuracy.
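The retention rules of two of these baselines can be sketched in a few lines to show why they fail on long-dependency tasks. These are simplified, hypothetical functions (the real H2O, for instance, also reserves part of its budget for recent tokens), and the sink, window, and budget values in the usage are illustrative. Both are eviction policies: a position outside the returned set is gone for good, even if a later query needs it.

```python
from typing import List, Set


def streaming_llm_keep(seq_len: int, n_sink: int, window: int) -> Set[int]:
    """Positions kept by a StreamingLLM-style policy: the first `n_sink`
    attention-sink tokens plus the most recent `window` tokens."""
    sinks = set(range(min(n_sink, seq_len)))
    recent = set(range(max(0, seq_len - window), seq_len))
    return sinks | recent


def h2o_keep(accumulated_scores: List[float], budget: int) -> Set[int]:
    """Positions kept by a simplified H2O-style policy: the `budget` tokens
    with the largest attention scores accumulated over past decode steps.
    (Simplification: ignores H2O's reserved window of recent tokens.)"""
    ranked = sorted(range(len(accumulated_scores)),
                    key=lambda i: accumulated_scores[i], reverse=True)
    return set(ranked[:budget])


# A passkey in the middle of a 10K-token context falls outside the kept set.
kept = streaming_llm_keep(10_000, n_sink=4, window=1024)
print(5_000 in kept)  # False: the passkey was evicted
```

The same happens with `h2o_keep` whenever the passkey received little attention before the question arrived, which is exactly the query dependence Quest exploits.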

3.3. Technological Evolution

The evolution of LLMs has seen a continuous push towards handling longer context windows.

  • Early LLMs (e.g., GPT-1/2, BERT): Often had context windows limited to a few hundred or a couple of thousand tokens (e.g., BERT 512 tokens, GPT-2 1024 tokens), primarily due to computational constraints of quadratic attention and memory limitations.

  • Scaling with Rotary Position Embeddings (RoPE): The introduction of techniques like Rotary Position Embeddings (RoPE) (Su et al., 2023) enabled better extrapolation to longer sequences. Fine-tuning models with RoPE using various scaling methods (e.g., LongChat to 32k, Yarn-Llama-2 to 128k) allowed for significant increases in context length.

  • Beyond 1M Tokens: Recent advancements have even pushed context windows beyond 1M tokens (e.g., Liu et al., 2024a), indicating a strong trend towards models that can process extremely long inputs.

  • Commercial Deployment: Large models like GPT-4 Turbo (128k context) and Claude-2 (200k context) are already deployed, highlighting the industrial demand for long-context capabilities.

    However, this technological evolution of expanding context windows has created a critical challenge for inference efficiency. As the context length grows, the KV cache size increases proportionally, making KV cache loading a dominant factor in inference latency. This is where Quest positions itself: to boost long-context inference by intelligently managing the KV cache to overcome this bottleneck, without sacrificing the long-dependency capabilities that these models are designed for.

3.4. Differentiation Analysis

Quest differentiates itself from previous KV cache management and sparsity methods primarily through its query-aware selection mechanism and its approach to KV cache retention.

  • Query-Aware vs. Query-Agnostic/Historical:

    • Previous methods (e.g., H2O, StreamingLLM): These methods are largely query-agnostic or rely on historical attention scores or fixed sliding windows. H2O prunes based on past importance, StreamingLLM keeps a fixed window and attention sinks. As Figure 2 clearly demonstrates, token criticality is dynamic and strongly dependent on the current query. These methods risk prematurely discarding tokens that might become critical for a future query.
    • Quest: Directly addresses this by explicitly considering the Query vector when estimating page criticality. This query-awareness allows Quest to dynamically identify truly relevant KV cache pages for the current generation step, as shown in Figure 4 where Quest maintains a much higher recall rate of critical tokens compared to H2O.
  • Selection vs. Eviction/Discarding:

    • Previous eviction algorithms (e.g., H2O, TOVA, StreamingLLM): These methods discard tokens or entire KV cache parts from memory, meaning they are permanently removed. This can lead to accuracy degradation on long-dependency tasks if a discarded token later becomes crucial.
    • Quest: Retains all of the KV cache (it doesn't permanently discard anything) but selects only a Top-K subset of KV cache pages for the self-attention computation at each step. This means potentially critical tokens are always available, even if they were not selected in previous steps. The memory movement is reduced by only loading the metadata for all pages and then the full content for the selected Top-K pages.
  • Granularity and Approximation:

    • Quest operates at page granularity (like PageAttention), which is efficient for memory management. It also uses a clever upper-bound approximation for attention scores based on min/max Key values per channel within a page, reducing the overhead of criticality estimation. This is a practical approach for quickly identifying potentially critical pages.

    • SparQ also uses approximation but at a channel-level which the authors note might have practical acceleration challenges and hasn't been widely validated for long dependencies.

      In summary, Quest's primary innovation lies in its query-aware dynamic selection of KV cache pages for self-attention, coupled with an efficient page-level approximation mechanism, allowing it to achieve significant inference speedups while maintaining high accuracy on long-context tasks by preventing the premature loss of critical information.
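The memory saving from loading only metadata plus the selected pages can be estimated by counting the vectors each approach reads per head per layer. A minimal sketch under stated assumptions: context length L, token budget B, page size S, and per-page metadata of one min and one max vector, each the size of a Key vector. Full attention reads 2L vectors (a Key and a Value per token); page selection reads 2⌈L/S⌉ metadata vectors plus 2B Key/Value vectors, roughly a fraction 1/S + B/L of the full traffic. The function name and example values (S = 16, B = 2048, L = 32K) are illustrative assumptions, and the count ignores compute, kernel overheads, and the Top-K step itself.

```python
def quest_traffic_fraction(context_len: int, budget: int, page_size: int) -> float:
    """Fraction of full-attention KV traffic read by page selection.

    Full attention reads a Key and a Value vector for every token: 2 * L.
    Page selection reads min/max metadata for every page, 2 * ceil(L / S),
    plus Keys and Values for the selected tokens, 2 * B.
    """
    full = 2 * context_len
    pages = -(-context_len // page_size)  # ceiling division
    selected = 2 * pages + 2 * min(budget, context_len)
    return selected / full


print(quest_traffic_fraction(32 * 1024, 2048, 16))  # prints: 0.125
```

With those values the idealized reduction is 8× in vectors moved. When the budget approaches the context length, the metadata becomes pure overhead and the fraction exceeds 1, so page selection only pays off when B is much smaller than L.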

4. Methodology

4.1. Principles

The core principle behind Quest stems from two key observations regarding long-context LLM inference:

  1. Sparsity in Self-Attention: Not all tokens in the Key-Value (KV) cache contribute equally to the self-attention outcome. A small subset of critical tokens often dominates the attention scores and is sufficient for accurate token generation. This suggests that self-attention can be made sparse, focusing computation and memory access only on these critical tokens.

  2. Query-Dependent Criticality: The criticality of a token is not static but dynamically changes based on the current query token. A token that was irrelevant for previous queries might become highly important for the current one, and vice-versa. This implies that query-agnostic KV cache eviction or pruning strategies are suboptimal and can lead to loss of accuracy, especially for long-range dependencies.

    Based on these principles, Quest proposes a query-aware KV cache selection algorithm. Instead of discarding KV cache entries, Quest maintains all KV cache pages but dynamically estimates the criticality of each page based on the current query and lightweight metadata. It then performs self-attention only on a Top-K subset of these estimated critical pages, thereby drastically reducing memory movement and computational overhead during the decode stage without losing potentially critical information. The intuition is to approximate which pages contain keys that would yield high attention scores with the current query, and only load those for the full attention computation.

4.2. Core Methodology In-depth (Layer by Layer)

The methodology of Quest can be broken down into an analysis of inference cost, the nature of self-attention sparsity, the query-dependent criticality observation, and then the detailed design of the Quest algorithm itself, including its criticality estimation, Top-K selection, and memory movement reduction.

4.2.1. Long-Context Inference Is Costly

LLM inference typically proceeds in two stages:

  • Prefill Stage: When a prompt (input sequence) is first fed into the model, all input tokens are processed in parallel. This involves transforming them into embeddings, computing their Query (Q), Key (K), and Value (V) vectors. The Key and Value vectors are stored in the KV cache. Self-attention and Feed-Forward Network (FFN) layers then produce the first response token. This stage happens only once per request.

  • Decode Stage: After the first token, subsequent tokens are generated one by one in an auto-regressive manner. For each new token, the model takes the previously generated token as input and computes its Q, K, and V vectors. The Q vector is multiplied with the K vectors of all previous tokens stored in the KV cache to compute attention scores. These scores are normalized with softmax into attention weights, where each $a_i$ is the weight between the $i$-th previous token and the current token. Finally, the self-attention layer outputs the weighted sum $\sum_i a_i \cdot V_i$ (using the Value vectors from the KV cache), which is passed to the FFN to predict the next token. This decode stage runs once for every token in the response.

    The decode stage typically dominates LLM inference time, especially when generating long responses, because it is iterative. For example, with a 16K-token prompt and a 512-token response, over 86% of the time can be spent in the decode stage. A major bottleneck in the decode stage for long-context LLMs is the KV cache itself: at a 32K context length, the KV cache can occupy 16GB of memory, and loading it for every self-attention operation during decoding can consume up to 53% of the time. This makes optimizing self-attention crucial for efficient long-context inference.

4.2.2. Self-Attention Operation Features High Sparsity

Despite the large size of the KV cache, empirical research has shown that self-attention is inherently sparse: only a small fraction of the tokens in the KV cache are critical and contribute significantly to the attention outcome. These critical tokens receive high attention scores and capture the most important inter-token relationships. For example, Figure 3 shows that for the LongChat-7B model, in every layer beyond the first two, over 90% of the KV cache tokens can be dropped (i.e., fewer than 10% are needed) while the perplexity on PG19 increases by less than 0.01. This high degree of sparsity is a significant opportunity: if critical tokens can be identified accurately, self-attention can be computed on those tokens alone, substantially reducing memory movement and improving efficiency.
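To make the sparsity argument concrete, here is a sketch of "oracle" Top-K attention: it keeps only the k tokens with the highest exact scores and renormalizes over them. It is an illustrative assumption, not the measurement procedure behind Figure 3, and it saves nothing by itself because it still computes every score; the open problem Quest addresses is finding those tokens without reading every Key. The function name and the toy vectors are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def oracle_top_k_attention(q: Vector, keys: List[Vector],
                           values: List[Vector], k: int) -> Vector:
    """Attention restricted to the k tokens with the highest exact scores.

    k = len(keys) reproduces ordinary full attention.
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, key) * scale for key in keys]
    top = sorted(range(len(keys)), key=lambda i: scores[i], reverse=True)[:k]
    m = max(scores[i] for i in top)  # stabilize exp()
    exps = {i: math.exp(scores[i] - m) for i in top}
    total = sum(exps.values())
    return [sum(exps[i] / total * values[i][j] for i in top)
            for j in range(len(values[0]))]


# A query that strongly matches two of five cached keys.
q = [4.0, 0.0]
keys = [[4.0, 0.0], [3.0, 0.0], [0.0, 1.0], [0.0, -1.0], [0.0, 0.5]]
values = [[1.0], [2.0], [3.0], [4.0], [5.0]]
full = oracle_top_k_attention(q, keys, values, k=len(keys))
sparse = oracle_top_k_attention(q, keys, values, k=2)
print(abs(full[0] - sparse[0]))  # well below 1e-3: dropped tokens carry tiny weight
```

When the softmax is this peaked, discarding the low-scoring tokens barely changes the output, which is the property the per-layer sparsity in Figure 3 relies on.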

The following figure (Figure 3 from the original paper) shows the query aware sparsity for each layer in LongChat-7B model:

Figure 3. The query aware sparsity for each layer in LongChat-7B model. We measure the sparsity by eliminating KV cache tokens while making sure the perplexity on PG19 increases less than 0.01. For the first two layers, the sparsity is below 10%, while for the rest of the layers, the sparsity is larger than 90%, showing great potential for optimization. Quest closely aligns with the oracle. (Plot: x-axis, Transformer layer; y-axis, query-aware sparsity; series, Oracle and Quest estimate.)

4.2.3. Critical Tokens Depend on the Query

A crucial observation made by the authors is that the criticality of tokens is not fixed but is highly dynamic and dependent on the query vector (Q). This insight is central to Quest. Figure 2 provides a clear example: The prompt "A is B. C is D. A is" is given.

  • When the query is "D", token "B" has a low attention score, indicating it is not critical.

  • However, when the query is the final "is" (whose expected continuation is "B"), token "B" receives a very high attention score, making it highly critical.

    This demonstrates that pre-determining token criticality from historical information or static rules is problematic, because a token's importance changes with the current context and the specific query. Methods that evict tokens based on historical attention scores (like H2O) or fixed sliding windows (StreamingLLM) risk discarding tokens that might become critical for a future query. Figure 4 quantifies this: Quest maintains a recall rate close to that of full attention (100%), whereas H2O suffers from low recall because critical tokens are often pruned in earlier iterations. This strong correlation between token criticality and the query token motivates query-aware sparsity, which estimates criticality dynamically by taking the Q vector into account.

The following figure (Figure 2 from the original paper) shows the attention map of prompt "A is B. C is D. A is":

Figure 2. The attention map of prompt "A is B. C is D. A is". Each row represents the attention scores of previous tokens queried by the tokens on the left. When queried with "D", token "B" has a low attention score, showing "B" is not critical for generation. However, the "is" strongly attends to "B". Therefore, the criticality of tokens strongly correlates with the current query token.

The following figure (Figure 4 from the original paper) shows the recall rate of tokens with Top-10 attention scores:

Figure 4. Recall rate of tokens with Top-10 attention scores. Results are profiled with LongChat-7b-v1.5-32k model in passkey retrieval test of 10K context length. Recall rate is the ratio of tokens selected by different attention methods to tokens selected by the full attention in each round of decoding. The average rate is shown in the figure, with various token budgets assigned. (Plot: x-axis, token budget of 64, 256, and 1024; y-axis, recall rate; methods, Full, Quest, and H2O.)
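One way to read the recall metric in the Figure 4 caption is sketched below. It assumes that, in each decoding round, a method's selection is a set of token positions and that the reference set is the Top-10 tokens by full-attention score; the function names are hypothetical and this is not the authors' evaluation script.

```python
from typing import List, Sequence, Set


def top_k_indices(scores: Sequence[float], k: int) -> Set[int]:
    """Positions of the k largest scores."""
    return set(sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k])


def top_k_recall(full_attention_scores: Sequence[float],
                 selected: Set[int], k: int = 10) -> float:
    """Fraction of the full-attention Top-k tokens that a method selected."""
    critical = top_k_indices(full_attention_scores, k)
    return len(critical & selected) / len(critical)


def average_recall(per_step_scores: List[Sequence[float]],
                   per_step_selected: List[Set[int]], k: int = 10) -> float:
    """Average of the per-round recall over all decoding rounds."""
    rates = [top_k_recall(s, sel, k)
             for s, sel in zip(per_step_scores, per_step_selected)]
    return sum(rates) / len(rates)
```

Full attention scores 100% by construction; an eviction method loses recall whenever a Top-10 token was discarded in an earlier round, no matter how large the current query's score for it would be.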

4.2.4. Dynamically Estimating Token Criticality with Quest

To efficiently and accurately estimate the criticality of KV cache tokens in a query-aware manner, Quest introduces a novel algorithm. The workflow of Quest is depicted in Figure 5.

The following figure (Figure 5 from the original paper) shows the workflow of Quest:

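Figure 5 (workflow of Quest; only an image description is available). The workflow has two stages. Stage 1 estimates critical pages from the current query and the reduced keys (per-page channel-wise min/max Key values) via an element-wise product, a per-channel maximum, and a sum. Stage 2 computes sparse attention over the selected pages.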

To manage the overhead associated with KV cache management, Quest adopts the PageAttention concept, processing and selecting the KV cache at page granularity. A page is a fixed-size block containing multiple KV pairs.

The core of Quest is an approximate calculation of attention scores before the actual attention operation, as detailed in Algorithm 1. The goal is to identify the pages most likely to contain tokens with high attention weights for the current query. Instead of calculating the exact attention scores of all tokens (which would defeat the purpose of sparsity), Quest computes, for each page, an upper bound on the attention scores of the tokens it contains.
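Why the page score bounds every exact score in the page, as a short derivation: assume each Key $k$ stored in the page satisfies $m_i \le k_i \le M_i$ in every channel $i$, which is what the per-page min/max metadata guarantees. For a fixed $q_i$, the product $q_i k_i$ is linear in $k_i$, so over the interval $[m_i, M_i]$ it is largest at an endpoint ($M_i$ if $q_i \ge 0$, $m_i$ otherwise). Summing over the $d$ channels gives the page score, written $U(q)$ here (a label introduced for this derivation):

```latex
m_i \le k_i \le M_i \;\Longrightarrow\; q_i k_i \le \max(q_i M_i,\, q_i m_i) \qquad (i = 1, \dots, d)

q \cdot k = \sum_{i=1}^{d} q_i k_i \;\le\; \sum_{i=1}^{d} \max(q_i M_i,\, q_i m_i) \;=\; U(q)
```

No token in the page can therefore have a score $q \cdot k$ above $U(q)$; the $1/\sqrt{d_k}$ scaling is a positive constant and does not change the ordering. The bound is generally loose, because the per-channel maxima can come from different tokens, so it serves as an estimate for ranking pages rather than as an exact score.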

Algorithm 1 (Token Criticality Estimation) has two phases: updating the per-page metadata when a new token is inserted, and estimating page criticality during self-attention.

Algorithm 1 Token Criticality Estimation

When inserting a new token into the KV cache:
Input: Key vector K, Dimension of hidden states dim, Current maximal vector M_i, Current minimal vector m_i
for i = 1 to dim do
    M_i = max(M_i, k_i)
    m_i = min(m_i, k_i)
end for

When performing self-attention:
Input: Query vector Q, Dimension of hidden states dim, Current maximal vector M_i, Current minimal vector m_i
Initialize score = 0.
for i = 1 to dim do
    score += MAX(q_i * M_i, q_i * m_i)
end for
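To make the two phases of Algorithm 1 concrete, here is a minimal single-page sketch in plain Python. It assumes one page, one attention head, and Keys and Queries given as lists of floats; the class name `PageMetadata` and its methods are hypothetical illustrations, not the API of the Quest repository, whose real implementation runs as CUDA kernels over batched pages.

```python
from typing import List


class PageMetadata:
    """Channel-wise min/max summary of the Key vectors stored in one KV cache page.

    Hypothetical helper for illustration; not part of the Quest codebase.
    """

    def __init__(self, dim: int) -> None:
        self.dim = dim
        # Identity elements of max/min, so the first inserted Key sets both vectors.
        self.max_vec: List[float] = [float("-inf")] * dim  # M_i
        self.min_vec: List[float] = [float("inf")] * dim   # m_i

    def insert_key(self, key: List[float]) -> None:
        """Phase 1 (token insertion): fold the new Key into M and m channel by channel."""
        for i in range(self.dim):
            self.max_vec[i] = max(self.max_vec[i], key[i])
            self.min_vec[i] = min(self.min_vec[i], key[i])

    def criticality(self, query: List[float]) -> float:
        """Phase 2 (self-attention): sum over channels of max(q_i * M_i, q_i * m_i).

        Only meaningful once at least one Key has been inserted; on an empty page
        a zero query channel would multiply 0 by infinity and produce NaN.
        """
        score = 0.0
        for i in range(self.dim):
            q = query[i]
            score += max(q * self.max_vec[i], q * self.min_vec[i])
        return score
```

For example, after inserting the Keys [1.0, -2.0] and [3.0, 0.5], the query [1.0, -1.0] gets a score of 5.0, while its best true dot product in the page is 3.0: the bound takes channel 1 from the second Key and channel 2 from the first, which is why it is conservative rather than exact.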

Let's break down Algorithm 1 and its implications:

A. Metadata Update (when inserting a new token into the KV cache)

  • Input:
    • Key vector K: The Key vector of the new token being added to the KV cache.
    • Dimension of hidden states dim: The dimensionality of the Key vectors.
    • Current maximal vector M_i: A vector (or array) storing the maximum Key value observed so far for each feature dimension i within the current KV cache page.
    • Current minimal vector m_i: A vector (or array) storing the minimum Key value observed so far for each feature dimension i within the current KV cache page.
  • Process: For each dimension i from 1 to dim of the new Key vector k:
    • The maximal value for that dimension, $M_i$, is updated to be the maximum of its current value and the new token's Key value $k_i$ for that dimension: $M_i = \max(M_i, k_i)$.
    • Similarly, the minimal value for that dimension, $m_i$, is updated to be the minimum of its current value and the new token's Key value $k_i$ for that dimension: $m_i = \min(m_i, k_i)$.
  • Purpose: This step efficiently maintains channel-wise minimum and maximum Key values for each KV cache page. This metadata ($m_i$ and $M_i$ for all dimensions $i$) is a lightweight summary of the range of Key values present in each page: two vectors per page instead of one Key vector per token.

B. Criticality Estimation (when performing self-attention)

  • Input:
    • Query vector Q: The Query vector of the current token for which self-attention is being computed.
    • Dimension of hidden states dim: The dimensionality of the Query and Key vectors.
    • Current maximal vector M_i: The maximal Key values for each dimension within the page, computed in the previous step.
    • Current minimal vector m_i: The minimal Key values for each dimension within the page, computed in the previous step.
  • Process:
    • Initialize $score = 0$. This score will accumulate the approximate criticality for the current page.
    • For each dimension i from 1 to dim of the Query vector q:
      • Calculate $U_i = \max(q_i \cdot M_i, q_i \cdot m_i)$. This $U_i$ is the largest possible contribution of the $i$-th dimension to the dot product with any Key vector in the page, regardless of the sign of $q_i$.
        • If $q_i > 0$, then $q_i \cdot M_i$ will be the maximum product.
        • If $q_i < 0$, then $q_i \cdot m_i$ (multiplying a negative $q_i$ by a smaller, potentially more negative, $m_i$) will result in a larger (less negative or more positive) product.
        • If $q_i = 0$, both terms are 0, and $U_i = 0$.
      • Add $U_i$ to the score: score += $U_i$.
  • Purpose: The accumulated score represents an upper bound on the potential attention score that any Key vector within that page could achieve when multiplied by the current Query vector Q. By summing the maximum possible contribution from each dimension, Quest obtains a conservative estimate of the page's criticality. This allows Quest to confidently identify pages that might contain a highly critical token without actually performing the full dot product for every token.
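The upper-bound claim can be checked in three steps. The derivation below is a sketch in the notation of Algorithm 1, where $k$ is any Key vector stored in the page:

```latex
\begin{aligned}
&m_i \le k_i \le M_i
  &&\text{for every channel } i \text{ (definition of the page metadata)},\\
&q_i k_i \le \max(q_i M_i,\; q_i m_i) = U_i
  &&\text{(} q_i k_i \text{ is linear in } k_i \text{, so it peaks at an endpoint)},\\
&q \cdot k = \sum_{i=1}^{\mathrm{dim}} q_i k_i \le \sum_{i=1}^{\mathrm{dim}} U_i = \mathrm{score}.
  &&
\end{aligned}
```

Equality holds only if a single token attains the per-channel maximum in every channel at once, so the score can overestimate a page. Because the softmax weight of a token grows monotonically with $q \cdot k$ for a fixed query, ranking pages by this score amounts to ranking them by an upper bound on the largest attention weight any of their tokens could receive.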

C. Top-K Page Selection and Approximate Self-Attention

After calculating these upper-bound criticality scores for all KV cache pages, Quest selects the Top-K pages with the highest scores, where $K$ is a predefined hyperparameter. The actual self-attention computation is then performed only on the Key and Value vectors within these selected pages. The tokens in the selected pages make up the "token budget", which equals $K$ times the page size; the experiments report this budget in tokens (e.g., 128, 256, 4096).
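Once every page has a score, selection and attention are straightforward. The sketch below is a plain-Python illustration under simplifying assumptions: one head, no batching, pages stored as Python lists, standard $1/\sqrt{d}$ scaling, and `heapq.nlargest` standing in for the GPU Top-K kernel (the kernel evaluation later in this analysis notes that the real system uses the RAFT library). The function names are hypothetical.

```python
import heapq
import math
from typing import List, Sequence

Vector = List[float]


def select_top_k_pages(page_scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k pages with the highest criticality scores, best first."""
    return heapq.nlargest(k, range(len(page_scores)), key=lambda p: page_scores[p])


def sparse_attention(query: Vector,
                     key_pages: List[List[Vector]],
                     value_pages: List[List[Vector]],
                     selected: List[int]) -> Vector:
    """Scaled dot-product attention over only the tokens in the selected pages."""
    keys = [k for p in selected for k in key_pages[p]]
    values = [v for p in selected for v in value_pages[p]]
    scale = 1.0 / math.sqrt(len(query))
    logits = [scale * sum(q * x for q, x in zip(query, k)) for k in keys]
    peak = max(logits)  # subtract the max logit for numerical stability
    weights = [math.exp(z - peak) for z in logits]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]
```

Selecting every page reproduces full attention up to floating-point rounding; shrinking K trades accuracy for less memory traffic, which is the knob the token-budget experiments vary.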

Quest does not apply sparsity to the first two Transformer layers of the model. As indicated by Figure 3, these initial layers show low sparsity ratios (below 10%), meaning most tokens are still important. Applying Quest only to later layers where sparsity is high (over 90%) helps preserve model accuracy while still achieving significant speedups. This design choice is orthogonal to the KV cache selection algorithm itself.

4.2.5. Quest Reduces the Memory Movement of Self-Attention

The main efficiency gain of Quest comes from significantly reducing the amount of data that needs to be loaded from High-Bandwidth Memory (HBM) (e.g., GPU global memory) to faster SRAM (e.g., GPU shared memory or L1/L2 cache) for self-attention computation.

Let's assume:

  • Each Key or Value vector is $M$ bytes.

  • The total KV cache contains $L$ tokens.

  • Each page contains $S$ KV pairs (tokens).

    The total size of the full KV cache (if all tokens were loaded) would be $2 \cdot M \cdot L$ bytes (since both Key and Value vectors are stored).

With Quest, the memory movement involves two parts:

  1. Criticality Estimation: For this, Quest only needs to load the minimal and maximal Key values (metadata) of each page. Each page has 2 vectors (min and max) with the same dimension $D_k$ as a Key vector. If the metadata is stored in the same precision as the Keys, each of these vectors is $M$ bytes, so the metadata costs $2M$ bytes per page and $2M \cdot L/S$ bytes across all $L/S$ pages.

  2. Approximate Self-Attention: After selecting the Top-K pages, Quest loads the full Key and Value vectors only for these $K$ selected pages. This amounts to $2 \cdot M \cdot K \cdot S$ bytes.

    Therefore, the total memory loaded by Quest is $2M \cdot L/S + 2M \cdot K \cdot S$ bytes. The fraction of the total KV cache loaded by Quest is: $ \frac{2M \cdot L/S + 2M \cdot K \cdot S}{2M \cdot L} = \frac{L/S + K \cdot S}{L} = \frac{1}{S} + \frac{K \cdot S}{L} $ This can also be expressed as: $ \frac{1}{\mathrm{Page \thinspace Size}} + \frac{K}{\mathrm{Page \thinspace Num}} $ where Page Size is $S$ and Page Num is $L/S$.

Example: The paper's example assumes "16 KV pairs per page, context length is 64K, and we choose the top 4K pages" and states that "Quest will reduce the memory load by 8x". Taken literally, this is inconsistent: with $S = 16$ and $L = 64K$ there are only $64K / 16 = 4K$ pages in total, so selecting 4K pages would mean loading every page, and the formula would give $\frac{1}{16} + 1 > 1$, i.e., more than the full cache.

The 8× figure is consistent if "4K" is read as a token budget of 4,096 tokens. That budget corresponds to $K = 4096 / 16 = 256$ pages out of $65{,}536 / 16 = 4{,}096$ pages in total, so the fraction loaded is $ \frac{1}{16} + \frac{256}{4096} = \frac{1}{16} + \frac{1}{16} = \frac{1}{8} $ In other words, the token budget is $1/16$ of the context, the metadata is another $1/16$, and Quest loads about $1/8$ of the KV cache, an $8\times$ reduction in memory movement. (Treating 64K as 64,000 tokens gives 4,000 pages and $\frac{1}{16} + \frac{256}{4000} = 0.1265$, still roughly $1/8$.) This memory load reduction is described as universal across models and compatible with existing quantization mechanisms.
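The fraction $\frac{1}{S} + \frac{K \cdot S}{L}$ is easy to evaluate for other settings. Below is a small sketch using exact rational arithmetic; the function name is hypothetical, and it assumes, as in the derivation above, that the min/max metadata is stored in the same precision as the Keys.

```python
from fractions import Fraction


def kv_load_fraction(page_size: int, context_len: int, top_k_pages: int) -> Fraction:
    """Fraction of the full KV cache read per decoding step: 1/S + K*S/L.

    Metadata:       2 vectors (min, max) per page -> 2*M*(L/S) bytes
    Selected pages: K pages x S Key/Value pairs   -> 2*M*K*S bytes
    Full cache:     L Key/Value pairs             -> 2*M*L bytes
    (M = bytes per vector; it cancels out of the ratio.)
    """
    num_pages = Fraction(context_len, page_size)  # L / S
    return Fraction(1, page_size) + Fraction(top_k_pages) / num_pages


token_budget = 4096
pages = token_budget // 16                  # 256 selected pages
print(kv_load_fraction(16, 65536, pages))   # prints 1/8
```

With `top_k_pages=4096` the same call returns `Fraction(17, 16)`, the literal reading that would load more than the full cache.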

4.3. Quest and Other Optimizations

  • CUDA Kernels: To demonstrate feasibility and achieve high performance, Quest is implemented with dedicated CUDA kernels, i.e., low-level code optimized for NVIDIA GPUs, built on top of the FlashInfer library.
  • Quantization Compatibility: The paper notes that Quest is compatible with quantization mechanisms, such as 4-bit weight quantization, further enhancing efficiency.

5. Experimental Setup

5.1. Datasets

Quest is evaluated on a range of long-context benchmarks to assess both its accuracy and efficiency.

  • PG19 (Rae et al., 2019):

    • Description: The PG19 test set, comprising 100 books with an average length of 70,000 tokens. It is used for language modeling tasks.
    • Usage: Evaluates perplexity on the LongChat-7b-v1.5-32k model, prompting with up to 32,000 tokens.
    • Why chosen: Suitable for testing language model performance on long sequences.
  • Passkey Retrieval Task (Peng et al., 2023):

    • Description: This synthetic task measures a model's ability to retrieve a specific passkey (a piece of information) embedded within a large amount of meaningless surrounding text. The passkey can be placed at varying depths within the text, testing the model's long-range dependency capabilities.
    • Usage:
      • 10K length test on LongChat-7b-v1.5-32k.
      • 100K length test on Yarn-Llama-2-7b-128k.
      • The input text containing the passkey and surrounding text is prefilled to the model. The question and instructions to retrieve the passkey are then fed token by token to simulate decoding and test the KV cache management.
    • Why chosen: Specifically designed to evaluate long-dependency handling, where KV cache eviction methods often struggle due to discarding critical information.
  • LongBench (Bai et al., 2023):

    • Description: A comprehensive bilingual, multitask benchmark for long-context understanding. It includes various tasks across different domains and types of long-context reasoning.
    • Usage: Quest and baselines are evaluated on six specific datasets from LongBench using LongChat-7b-v1.5-32k:
      1. NarrativeQA (Kočiský et al., 2018): Single-document Question Answering (QA).
      2. HotpotQA (Yang et al., 2018): Multi-document QA requiring multi-hop reasoning.
      3. Qasper (Dasigi et al., 2021): QA anchored in research papers.
      4. TriviaQA (Joshi et al., 2017): QA requiring knowledge retrieval.
      5. GovReport (Huang et al., 2021): Summarization task for government reports.
      6. MultifieldQA (Bai et al., 2023): Single-document QA.
    • Data Split and Inference Strategy: For LongBench tasks, the input is split into material and question/instruction. For the material part, FlashAttention with the full KV cache is used. For the question part, decoding is simulated by feeding tokens one by one to the model, testing the KV cache management.
    • Why chosen: Provides a diverse and general set of long-context tasks to validate the method's effectiveness across different NLP applications.
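For the passkey retrieval task described above, the split between prefilled material and a token-by-token question can be illustrated with a small prompt builder. The filler sentence, the needle wording, the 5-digit passkey, and the function name are all assumptions modeled on common passkey tests, not the exact prompt used in the paper.

```python
import random
from typing import Tuple

# Hypothetical filler text, modeled on common passkey tests.
FILLER = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. "


def build_passkey_prompt(num_filler: int, depth: float, seed: int = 0) -> Tuple[str, str, str]:
    """Return (material, question, passkey) for one synthetic passkey test.

    `depth` in [0, 1] sets where the passkey sentence sits among the
    `num_filler` filler blocks: 0.0 = start of the text, 1.0 = end.
    """
    if not 0.0 <= depth <= 1.0:
        raise ValueError("depth must be in [0, 1]")
    rng = random.Random(seed)
    passkey = str(rng.randint(10000, 99999))  # a random 5-digit key
    needle = f"The pass key is {passkey}. Remember it. {passkey} is the pass key. "
    blocks = [FILLER] * num_filler
    blocks.insert(round(depth * num_filler), needle)
    material = "".join(blocks)  # prefilled with full attention
    question = "What is the pass key? The pass key is"  # decoded token by token
    return material, question, passkey
```

The material would be prefilled with full attention and the question fed one token at a time during decoding, matching the evaluation protocol described for the passkey and LongBench tasks.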

5.2. Evaluation Metrics

The paper uses several metrics to evaluate the accuracy and efficiency of Quest.

  • Perplexity (PPL):

    • Conceptual Definition: Perplexity is a measure of how well a probability model predicts a sample. In language modeling, it quantifies how surprised the model is by a given sequence of words. A lower perplexity score indicates that the model is more confident and accurate in its predictions of the next token in a sequence. It's inversely related to the probability of the text sequence according to the model.
    • Mathematical Formula: $ \mathrm{PPL}(W) = \exp\left(-\frac{1}{N} \sum_{i=1}^N \log P(w_i | w_1, ..., w_{i-1})\right) $
    • Symbol Explanation:
      • $W = (w_1, w_2, ..., w_N)$: A sequence of $N$ tokens.
      • $N$: The total number of tokens in the sequence.
      • $P(w_i | w_1, ..., w_{i-1})$: The probability of token $w_i$ given the preceding tokens $w_1, ..., w_{i-1}$, as assigned by the language model.
      • $\log$: The natural logarithm.
      • $\exp$: The exponential function ($e^x$).
  • Accuracy (for Passkey Retrieval Task):

    • Conceptual Definition: In the passkey retrieval task, accuracy measures whether the model correctly identifies and outputs the embedded passkey. It's a binary metric: either the passkey is retrieved correctly, or it's not.
    • Mathematical Formula: $ \text{Accuracy} = \frac{\text{Number of correctly retrieved passkeys}}{\text{Total number of passkey retrieval attempts}} $
    • Symbol Explanation:
      • Number of correctly retrieved passkeys: Count of instances where the model's output matches the target passkey.
      • Total number of passkey retrieval attempts: Total number of times the model was asked to retrieve a passkey.
  • F1 Score (for LongBench QA tasks):

    • Conceptual Definition: The F1 score is the harmonic mean of precision and recall, combining them into a single metric. For QA tasks it is typically computed over the tokens of the predicted answer and the reference answer, so partially correct answers receive partial credit rather than a binary exact-match score.
      • Precision: The fraction of predicted items that are relevant (true positives out of everything the model predicted; for QA, the share of predicted-answer tokens that also appear in the reference).
      • Recall: The fraction of relevant items that are retrieved (true positives out of all actual positives; for QA, the share of reference-answer tokens that appear in the prediction).
    • Mathematical Formula: $ \text{F1 Score} = 2 \times \frac{\text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}} $ Where: $ \text{Precision} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Positives}} $ $ \text{Recall} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Negatives}} $
    • Symbol Explanation:
      • True Positives (TP): Correctly identified relevant items.
      • False Positives (FP): Incorrectly identified relevant items (model predicted relevant, but it was not).
      • False Negatives (FN): Relevant items that were missed (model predicted not relevant, but it was).
  • Latency / Speedup:

    • Conceptual Definition:
      • Latency: The time taken to complete an operation, such as generating a single token in the decode stage or executing a specific CUDA kernel. Measured in microseconds (μs) or milliseconds (ms). Lower latency is better.
      • Speedup: The factor by which a process is accelerated. If a new method takes $T_{new}$ time and a baseline takes $T_{baseline}$ time, the speedup is $T_{baseline} / T_{new}$. For example, a $2\times$ speedup means the new method is twice as fast.
    • Usage: Measured at both kernel-level (individual operations like criticality estimation, Top-K filtering, approximate attention) using NVBench and end-to-end (decode stage token generation) using PyTorch profiler.
  • Recall Rate (for Top-10 Attention Scores):

    • Conceptual Definition: In the context of critical tokens, recall rate measures how many of the truly critical tokens (identified by a full attention model) are successfully identified and selected by the proposed sparse attention method. A high recall rate indicates that the method effectively retains the most important tokens.
    • Mathematical Formula: $ \text{Recall Rate} = \frac{\text{Number of critical tokens selected by sparse method}}{\text{Number of critical tokens selected by full attention}} $
    • Symbol Explanation:
      • Number of critical tokens selected by sparse method: Count of tokens chosen by Quest (or other sparse methods) that also belong to the critical tokens set identified by full attention.
      • Number of critical tokens selected by full attention: Count of Top-10 attention score tokens identified when self-attention is computed over the entire KV cache.
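Three of the metrics above reduce to a few lines each. The sketch below is illustrative only: the function names are hypothetical, the F1 tokenizer (lowercased whitespace split) is a simplifying assumption (the official LongBench scorer may normalize text differently), and the recall function handles a single decoding round, whereas Figure 4 averages it over all rounds.

```python
import math
from collections import Counter
from typing import Iterable, Sequence


def perplexity(token_log_probs: Sequence[float]) -> float:
    """PPL = exp(-(1/N) * sum_i log P(w_i | w_<i)), using natural-log probabilities."""
    if not token_log_probs:
        raise ValueError("need at least one token")
    return math.exp(-sum(token_log_probs) / len(token_log_probs))


def token_f1(prediction: str, reference: str) -> float:
    """Token-overlap F1 between a predicted and a reference answer."""
    pred = prediction.lower().split()
    ref = reference.lower().split()
    # True positives: tokens shared by both answers, counted with multiplicity.
    tp = sum((Counter(pred) & Counter(ref)).values())
    if tp == 0:
        return 0.0
    precision = tp / len(pred)  # TP / (TP + FP)
    recall = tp / len(ref)      # TP / (TP + FN)
    return 2 * precision * recall / (precision + recall)


def recall_rate(full_scores: Sequence[float], selected: Iterable[int], k: int = 10) -> float:
    """Share of full attention's Top-k tokens that a sparse method also selected.

    `full_scores[j]` is the full-attention score of token j in one decoding round;
    `selected` holds the token positions the sparse method kept in that round.
    """
    ranked = sorted(range(len(full_scores)), key=lambda j: full_scores[j], reverse=True)
    critical = set(ranked[:k])  # stable sort: ties keep the earlier position first
    return len(critical & set(selected)) / len(critical)
```

For instance, a uniform probability of 1/4 per token gives a perplexity of 4, and the prediction "the cat sat" against the reference "the cat" scores precision 2/3, recall 1, and F1 0.8.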

5.3. Baselines

Quest is compared against several representative KV cache eviction or sparse attention methods, as well as FlashInfer for efficiency benchmarks.

  • H2O (Heavy-Hitter Oracle): (Zhang et al., 2023b) A KV cache eviction algorithm that prunes tokens based on their historical attention scores. It retains a limited budget of tokens that have been consistently "heavy-hitters."

    • Note for Evaluation: For long-context evaluation (e.g., 100K passkey retrieval), H2O typically requires computing the full $O(N^2)$ attention map to collect historical attention scores, which can be slow. To enable its comparison in long-context scenarios, FlashAttention is used during the context stage, and H2O only starts collecting historical attention scores during the decoding stage.
  • TOVA (Transformers are Multi-State RNNs): (Oren et al., 2024) A KV cache eviction method that decides which tokens to permanently discard based on the current query.

  • StreamingLLM: (Xiao et al., 2023) Handles infinitely long texts by maintaining a fixed-size KV cache that includes attention sinks (initial important tokens) and a sliding window of recent tokens, discarding older tokens outside this window.

  • FlashInfer (Ye et al., 2024): A highly optimized kernel library for LLM inference. Quest is implemented on top of FlashInfer, and FlashInfer's standard attention implementation serves as a strong efficiency baseline for Quest's kernel-level and end-to-end speedup comparisons. It represents the state-of-the-art for dense attention.

    Important Note on Baseline Application: For fair comparison and to preserve model accuracy, Quest and all baselines are not applied to the first two Transformer layers of the models. This is because the initial layers exhibit low sparsity ratios (as shown in Figure 3), meaning most tokens are still important. Using a full KV cache for these initial layers helps maintain overall model performance.
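The retention policy attributed to StreamingLLM above can be sketched in a few lines, which also shows why eviction fails on passkey retrieval: a token in the middle of a long context falls outside both the attention sinks and the recent window. The default of 4 sink tokens and a 1,020-token window (a 1,024-token budget) are assumptions chosen for illustration, and the function name is hypothetical.

```python
from typing import List


def streaming_llm_kept(seq_len: int, num_sinks: int = 4, window: int = 1020) -> List[int]:
    """Token positions retained by an attention-sink + sliding-window KV cache."""
    if seq_len <= num_sinks + window:
        return list(range(seq_len))  # everything still fits in the budget
    sinks = list(range(num_sinks))                   # first tokens ("attention sinks")
    recent = list(range(seq_len - window, seq_len))  # most recent tokens
    return sinks + recent
```

With a 10,000-token context, position 5,000 is not in the kept set, so a passkey placed there is gone before the question arrives; Quest instead keeps the whole cache and only chooses which pages to load.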

5.4. Models

Two widely used long-context LLMs are employed for evaluation:

  • LongChat-7b-v1.5-32k (Li et al., 2023): A 7-billion parameter LLM fine-tuned to support a 32K context length. Used for PG19, 10K passkey retrieval, and LongBench evaluations.
  • Yarn-Llama-2-7b-128k (Peng et al., 2023): A 7-billion parameter LLM extended to support a 128K context length using the Yarn method. Used for 100K passkey retrieval evaluation.

5.5. Hardware

The experiments are conducted on NVIDIA GPUs:

  • RTX4090: Used for kernel-level efficiency evaluation (Section 4.3.1).
  • Ada 6000 GPU (NVIDIA, 2023): Used for end-to-end evaluations for longer context lengths (Section 4.3.2).

6. Results & Analysis

6.1. Core Results Analysis

The experimental results demonstrate that Quest effectively achieves significant inference speedups while maintaining high accuracy across various long-context benchmarks.

6.1.1. Language Modeling on PG19

The PG19 dataset is used to evaluate language modeling perplexity. Quest is compared against H2O and TOVA (both with a full cache in the first two layers, denoted with *), and a Full cache baseline. The token budget for H2O, TOVA, and Quest is 4096 tokens (approximately 1/8 of the total 32K context length).

The following figure (Figure 6 from the original paper) shows the language modeling evaluation of Quest on PG19 dataset:

Figure 6. Language modeling evaluation of Quest on PG19 dataset. We prompt the model with 0 to 32000 tokens from the PG19 test set and measure the perplexity of output tokens. H2O* and TOVA* indicate that for the first two layers of models, we do not apply these two algorithms to prune the KV Cache, as analyzed in Sec 3.4, which better preserves the model performance. Quest also uses a full cache in the first two layers of the model. Quest can closely match the performance of the full cache model. (In the figure, the x-axis is the input length and the y-axis is the output perplexity; curves show H2O*, TOVA*, Full, and Quest, and an inset zooms in on input lengths near 32,000.)

As shown in Figure 6, Quest consistently maintains perplexity scores that closely match the oracle baseline (the Full cache model) across varying input lengths up to 32,000 tokens. This indicates that Quest's query-aware selection mechanism effectively identifies the critical KV cache pages needed to preserve the language model's predictive performance, even with a significantly reduced KV cache budget. In contrast, H2O* and TOVA*, despite also using a full cache in the initial layers, show noticeably higher perplexity as the input length increases, signifying a greater degradation in accuracy due to their KV cache pruning strategies.

6.1.2. Results on Long Text Passkey Retrieval Task

This task specifically tests the model's ability to handle long-distance dependencies. Quest is evaluated on 10K and 100K context lengths and compared against H2O, TOVA, and StreamingLLM.

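The following are the results from Table 1 of the original paper:

(i) 10K-length passkey retrieval test on LongChat-7b-v1.5-32k (accuracy by token budget)

Method / Budget    32     64     128    256    512
H2O                0%     1%     1%     1%     3%
TOVA               0%     1%     1%     3%     8%
StreamingLLM       1%     1%     1%     3%     5%
Quest (ours)       65%    99%    99%    99%    100%

(ii) 100K-length passkey retrieval test on Yarn-Llama-2-7b-128k (accuracy by token budget)

Method / Budget    256    512    1024   2048   4096
H2O                2%     2%     2%     2%     4%
TOVA               2%     2%     2%     2%     10%
StreamingLLM       1%     1%     1%     2%     4%
Quest (ours)       88%    92%    96%    100%   100%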

Table 1 Summary:

  • (i) 10k length passkey retrieval test on LongChat-7b-v1.5-32k: With a budget of 64 tokens, Quest achieves 99% accuracy, rising to 100% with 512 tokens. In contrast, H2O, TOVA, and StreamingLLM perform very poorly, reaching at most 8% accuracy even with a 512-token budget. This clearly shows that Quest effectively preserves the model's ability to retrieve long-distance information.

  • (ii) 100k length passkey retrieval test on Yarn-Llama-2-7b-128k: Similar trends are observed at a much longer context. With a budget of 1024 tokens, Quest reaches 96% accuracy, achieving 100% with 2048 tokens. The baselines again struggle, with H2O, TOVA, and StreamingLLM reaching at most 10% accuracy even with a 4096-token budget.

    The results highlight a critical limitation of KV cache eviction algorithms: when simulating decoding (feeding the question token by token), these methods mistakenly discard critical tokens (like the passkey) because they are not query-aware or their fixed windows don't cover the passkey. Quest, by dynamically identifying critical tokens based on the current query without discarding anything from the total cache, avoids this issue and demonstrates near-perfect accuracy with minimal token budgets (e.g., 64 tokens is about 0.6% of 10K, 1024 tokens is about 1% of 100K).

6.1.3. Results on LongBench

Quest is evaluated on six LongBench datasets to assess its generalizability across diverse long-context tasks.

The following figure (Figure 7 from the original paper) shows the F1 scores for the algorithms StreamingLLM, H2O, TOVA, Quest (our algorithm), and Full under different KV cache budgets across multiple experimental tasks:

[Figure 7 image: line plots of F1 score versus KV cache budget for StreamingLLM, H2O, TOVA, Quest (ours), and Full, with one panel per task: Qasper, HotpotQA, GovReport, TriviaQA, NarrativeQA, and MultifieldQA.]

Figure 7 shows that Quest consistently outperforms all baselines (StreamingLLM, H2O, TOVA) across all six LongBench datasets and various KV cache budgets. For instance, Quest with a budget of 1K tokens often achieves performance comparable to the Full cache model, whereas baselines show a significant accuracy gap even with larger budgets. Specifically, Quest achieves "lossless performance" (matching Full cache) on Qasper, HotpotQA, GovReport, TriviaQA, NarrativeQA, and MultifieldQA with very high KV cache sparsity ratios of 1/6, 1/6, 1/5, 1/10, 1/5, and 1/6 respectively (after considering the full cache used in the first two layers). This further validates that Quest maintains model capabilities across diverse long-context tasks by intelligently selecting critical pages, avoiding the generation of incorrect answers due to improper KV cache handling.

6.1.4. Efficiency Evaluation (Kernel Level)

Quest's efficiency is evaluated at the kernel level using NVBench on an RTX4090 GPU.

The following figure (Figure 8 from the original paper) shows Quest's kernel-level efficiency with different sequence lengths and page sizes:

[Figure 8 image: (a) Criticality estimation: latency normalized to FlashInfer versus sequence length for Quest-4, Quest-8, Quest-16, and Quest-32 (page sizes). (b) Approximate attention: latency versus sequence length for Quest-512, Quest-1024, Quest-2048, Quest-4096, and Quest-8192 (token budgets).]

  • Criticality Estimation (Figure 8a): The latency of criticality estimation is compared to FlashInfer's attention (normalized to 1). At short sequence lengths, Quest's estimation achieves lower memory bandwidth utilization because the total memory loaded is small. As sequence length increases, the normalized latency of estimation falls toward $1 / \text{Page Size}$, because estimation only reads metadata amounting to $1 / \text{Page Size}$ of the full KV cache. Quantization or larger page sizes can further reduce this overhead.

  • Top-K Filtering: The overhead of Top-K filtering (using the RAFT library) is minimal, typically 5-10 μs for sequence lengths up to 128K. This is because criticality estimation reduces each entire token to a single criticality score, leading to very limited memory movement for Top-K selection.

  • Approximate Attention (Figure 8b): Quest's approximate attention (performing attention only on selected pages) has a latency that is constant for a given token budget $B$, regardless of the total sequence length. It achieves similar latency to FlashInfer when the sequence length is equal to the token budget $B$.

    The overall self-attention mechanism of Quest (combining criticality estimation, Top-K filtering, and approximate attention) is profiled on Llama2-7B.

The following figure (Figure 9 from the original paper) shows the time breakdown of Quest on various sequence lengths:

[Figure 9 image: latency breakdown of Quest's self-attention into criticality estimation, Top-K filtering, and approximate attention at different sequence lengths.]

Figure 9 shows that Quest significantly reduces self-attention time. For a sequence length of 32K with a token budget of 2048, Quest achieves a $7.03\times$ speedup compared to FlashInfer. This substantial speedup is attributed to the reduced memory movement achieved by query-aware sparsity.

6.1.5. End-to-End Evaluation

To demonstrate practical benefits, Quest is deployed in real-world single-batch scenarios, and the average latency of generating one token in the decode stage is measured.

The following figure (Figure 10 from the original paper) shows Quest's end-to-end speedup:

[Figure 10 image: decode latency of FlashInfer and Quest across context lengths, for FP16 weights and 4-bit (AWQ) weights; at a context length of 32,768 the speedups are 1.74× and 2.23×, respectively.]

Figure 10 illustrates that Quest consistently outperforms FlashInfer (the full KV cache baseline) across all sequence lengths. Crucially, Quest's latency grows much more slowly than FlashInfer's as sequence length increases, because Quest keeps the same token budget regardless of context length.

  • At a sequence length of 32K and a token budget of 2048, Quest boosts inference speed by $1.74\times$ with FP16 weights.
  • With 4-bit quantized weights, the speedup further increases to $2.23\times$. This demonstrates that Quest provides substantial end-to-end latency reduction in practical LLM inference scenarios.

6.1.6. Comparison with Baselines (Efficiency vs. Accuracy)

The paper also provides a qualitative comparison of Quest's inference efficiency against baselines under the constraint of "lossless accuracy" (i.e., maintaining comparable accuracy to a full cache model) on the LongBench tasks.

The following figure (Figure 11 from the original paper) shows the inference latency of different attention methods for comparable accuracy:

[Figure 11 image: average context length and inference latency of Full, TOVA, and Quest on each LongBench task.]

  • Token Budgets for Lossless Accuracy (Figure 11a): This chart shows the token budgets required by different methods to achieve lossless accuracy on LongBench tasks. For example, on NarrativeQA (average context length 24K), TOVA requires a 14K token budget, while Quest needs only 5K. This indicates that Quest achieves much higher sparsity while maintaining accuracy.
  • Self-Attention Latency Comparison (Figure 11b): Given that baselines often lack kernel implementations, a qualitative analysis of self-attention efficiency is performed using FlashInfer's latency as a proxy, disregarding other runtime overheads of baselines. Quest, however, is evaluated in a practical setting considering all its operators. Figure 11b reveals that Quest significantly surpasses all baselines in self-attention latency due to its superior query-aware sparsity. For instance, on GovReport, Quest boosts inference by $3.82\times$ compared to TOVA for comparable accuracy. On TriviaQA, the speedup is $4.54\times$. This confirms Quest's ability to provide higher efficiency while simultaneously maintaining superior accuracy.

6.2. Data Presentation (Tables)

Table 1 of the original paper (the passkey retrieval results) is reproduced in Section 6.1.2. Its caption reads: "(i) Results of 10k length passkey retrieval test on LongChat-7b-v1.5-32k. (ii) Results of 100k length passkey retrieval test on Yarn-Llama-2-7b-128k. Quest can achieve nearly perfect accuracy with 64 and 1024 tokens KV cache budget, which is about 1% of the total sequence length, demonstrating that Quest can effectively preserve the model's ability to handle long-dependency tasks. However, KV cache eviction algorithms such as H2O, TOVA, and StreamingLLM incorrectly discard the KV cache of the answer before receiving the question, thus failing to achieve ideal accuracy."

6.3. Ablation Studies / Parameter Analysis

The paper does not present explicit, dedicated ablation studies in the traditional sense, such as removing a component of Quest to show its contribution. However, the token budgets varied across experiments (32 to 4096 tokens for passkey retrieval, 512 to 4096 for LongBench) implicitly serve as a parameter analysis of K in the Top-K selection.

The results (e.g., Table 1, Figure 7) consistently show that Quest maintains high accuracy even under very aggressive token budgets. This demonstrates the robustness of its criticality estimation and Top-K selection. For instance, in the 10k passkey retrieval test, Quest reaches 99% accuracy with a budget of just 64 tokens. Query-aware sparsity can therefore drastically reduce the amount of KV cache loaded at each decoding step without compromising performance. Quest still keeps the full KV cache in memory and only loads the selected pages.
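Quest selects whole pages, so a token budget corresponds to K = budget / page size pages. The sketch below assumes a page size of 16 purely for illustration, and the helper `budget_to_pages` is hypothetical. It also expresses each budget as a share of the sequence: 64 of 10k tokens and 1024 of 100k tokens are roughly the "about 1%" cited in Table 1's caption.

```python
def budget_to_pages(token_budget: int, page_size: int) -> int:
    """Hypothetical helper: number of whole pages K that fit in a token budget."""
    return token_budget // page_size

PAGE_SIZE = 16  # assumed page size, for illustration only

# (sequence length, token budget) pairs taken from the passkey results above
for seq_len, budget in [(10_000, 64), (100_000, 1024)]:
    k = budget_to_pages(budget, PAGE_SIZE)
    share = budget / seq_len
    print(f"{seq_len} tokens, budget {budget}: K = {k} pages, {share:.2%} of the sequence")
```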

Page size comes up only briefly, in the kernel evaluation. There, criticality estimation latency approaches 1/Page Size of the full-attention latency, which implies that larger pages further reduce estimation overhead. However, the paper gives no detailed analysis of how page size affects accuracy or overall speedup. The decision to exclude the first two Transformer layers from Quest, because of their low sparsity, can also be read as a design choice grounded in observed sparsity patterns. It is an important hyperparameter for practical deployment.
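A simplified memory-traffic model shows where the 1/Page Size behaviour comes from. The model assumes, for illustration, that both estimation and attention are bound by reading vectors of head dimension d. Full attention reads one Key and one Value vector for each of the P tokens in a page, while estimation reads only that page's min and max Key vectors. The function name and the default head dimension of 128 are assumptions. The model ignores the Top-K step and kernel launch overheads.

```python
def estimation_traffic_ratio(page_size: int, head_dim: int = 128) -> float:
    """Simplified model: elements read by criticality estimation relative to
    full attention, per page (hypothetical helper)."""
    # Full attention reads every Key and Value in the page: 2 * P * d elements.
    full_attention = 2 * page_size * head_dim
    # Estimation reads only the per-page min and max Key vectors: 2 * d elements.
    estimation = 2 * head_dim
    return estimation / full_attention

for p in (8, 16, 32, 64):
    print(f"page size {p}: estimation reads {estimation_traffic_ratio(p):.4f} of full attention")
```

Under this model the ratio is exactly 1/P regardless of d, so doubling the page size halves the estimation cost. The trade-off is that selection becomes coarser.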

7. Conclusion & Reflections

7.1. Conclusion Summary

This paper introduces Quest, a novel and efficient query-aware KV cache selection algorithm designed to accelerate long-context Large Language Model (LLM) inference. The core insight driving Quest is the observation that token criticality for self-attention is highly dynamic and depends on the current query. By leveraging this query-aware sparsity, Quest tracks minimal and maximal Key values as metadata for KV cache pages and uses the current Query vector to estimate each page's criticality. It then performs self-attention only on the Top-K critical KV cache pages, significantly reducing memory movement and computational load.
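The pipeline summarized above can be sketched for a single attention head in NumPy. This is a simplified illustration under stated assumptions, not the authors' CUDA implementation:

- The function names are hypothetical.
- Page metadata is rebuilt on every call rather than maintained incrementally.
- Pages are ranked with a full `argsort` rather than a dedicated Top-K kernel.

The page score follows the min/max idea: for each channel, take the larger of q·min and q·max, then sum over channels.

```python
import numpy as np

def build_page_metadata(keys: np.ndarray, page_size: int):
    """Per-page element-wise min and max of Key vectors; keys: (seq_len, d)."""
    n_pages = -(-keys.shape[0] // page_size)  # ceil division; last page may be partial
    pages = [keys[i * page_size:(i + 1) * page_size] for i in range(n_pages)]
    k_min = np.stack([p.min(axis=0) for p in pages])
    k_max = np.stack([p.max(axis=0) for p in pages])
    return k_min, k_max

def page_scores(q: np.ndarray, k_min: np.ndarray, k_max: np.ndarray) -> np.ndarray:
    # Per channel, q_i * k_i is maximized at one endpoint of [min_i, max_i];
    # summing those maxima gives an upper bound on q . k within each page.
    return np.maximum(q * k_min, q * k_max).sum(axis=1)

def quest_attention(q, keys, values, page_size, top_k_pages):
    """Attend only over the Top-K pages ranked by their upper-bound score."""
    k_min, k_max = build_page_metadata(keys, page_size)
    scores = page_scores(q, k_min, k_max)
    k = min(top_k_pages, len(scores))
    selected = np.argsort(-scores)[:k]
    n = keys.shape[0]
    idx = np.concatenate([np.arange(p * page_size, min((p + 1) * page_size, n))
                          for p in selected])
    logits = keys[idx] @ q / np.sqrt(q.shape[0])  # scaled dot-product scores
    weights = np.exp(logits - logits.max())       # numerically stable softmax
    weights /= weights.sum()
    return weights @ values[idx], np.sort(selected)

# Usage: plant one Key strongly aligned with the query (token 100 -> page 6).
rng = np.random.default_rng(0)
keys = rng.standard_normal((256, 64))
values = rng.standard_normal((256, 64))
q = rng.standard_normal(64)
keys[100] = 5 * q
out, pages = quest_attention(q, keys, values, page_size=16, top_k_pages=4)
print(pages)  # page 6 should be among the 4 selected pages
```

Selecting every page reduces the computation to ordinary softmax attention over the whole cache. This makes a convenient sanity check for a sketch like this one.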

Comprehensive evaluations demonstrate Quest's effectiveness across language modeling, passkey retrieval, and LongBench tasks. It achieves self-attention speedups of up to 7.03×, which translate into an end-to-end inference latency reduction of up to 2.23× (with 4-bit quantization). Crucially, Quest does this with negligible accuracy loss, often matching full-cache performance even under highly aggressive KV cache budgets. Compared with prior KV cache eviction baselines, Quest consistently achieves higher accuracy at a given budget. At equivalent accuracy, it achieves up to 4.5× lower self-attention latency. The dedicated CUDA kernel implementation further validates its practical feasibility and performance.
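The gap between the 7.03× self-attention speedup and the 2.23× end-to-end reduction is what Amdahl's law predicts when only part of decoding time is spent in self-attention. The sketch below uses hypothetical attention shares, not measurements from the paper. The two "up to" figures also need not come from the same configuration.

```python
def end_to_end_speedup(attn_fraction: float, attn_speedup: float) -> float:
    """Amdahl's law: overall speedup when only the self-attention share of the
    baseline latency (attn_fraction) is accelerated by attn_speedup."""
    return 1.0 / ((1.0 - attn_fraction) + attn_fraction / attn_speedup)

# Hypothetical attention shares of decode latency; not values from the paper.
for f in (0.3, 0.5, 0.7):
    print(f"attention = {f:.0%} of latency -> {end_to_end_speedup(f, 7.03):.2f}x end to end")
```

Under this model, anything that shrinks the non-attention part of latency raises attention's share. Quantization that reduces weight-loading time is one example, and it would increase the end-to-end benefit of a faster attention kernel. This is a plausible reading of why the end-to-end result is reported with 4-bit quantization, not the paper's own analysis.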

7.2. Limitations & Future Work

The paper implicitly highlights some limitations and potential areas for future work:

  • Initial Layer Sparsity: Quest and the other baselines are not applied to the first two Transformer layers because of their low sparsity ratio. This is a practical design choice to preserve accuracy. It also suggests that Quest's sparsity mechanism may not apply uniformly across all layers, or may need refinement for layers with less inherent sparsity.
  • Hyperparameter K and Page Size: The Top-K parameter (the number of critical pages/tokens) is a preset constant. Adjusting K dynamically based on the context, task, or observed sparsity could offer further optimization or adaptability. Similarly, the impact of page size is discussed only in terms of estimation overhead. There is no detailed analysis of its optimal value or its effect on accuracy and speedup.
  • Generalizability to Diverse Architectures: While evaluated on Llama-2 variants, Quest's applicability to other LLM architectures (e.g., Mixtral, encoder-decoder models, or novel attention mechanisms) would require further investigation.
  • Overhead of Metadata Management: Although the metadata is described as lightweight, continuously updating the min/max Key values of each page and running the criticality estimation still add some overhead. Further optimizing these operations, for example through hardware acceleration or more advanced data structures, could be a direction for future work.
  • Broader Task Coverage: Although LongBench is diverse, exploring Quest's performance on more specialized long-context tasks (e.g., code generation, scientific document understanding) could reveal further nuances.
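The metadata-overhead point above can be made concrete. When a new Key is appended during decoding, only the min/max vectors of the current (last) page change. The update therefore costs O(head_dim) per token per head. The class below is a hypothetical NumPy sketch of that bookkeeping, not the paper's data structure.

```python
import numpy as np

class PageMetadata:
    """Hypothetical per-head store of running min/max Key vectors per page."""

    def __init__(self, head_dim: int, page_size: int):
        self.head_dim = head_dim
        self.page_size = page_size
        self.num_tokens = 0
        self.k_min: list = []  # one (head_dim,) array per page
        self.k_max: list = []

    def append_key(self, key: np.ndarray) -> None:
        assert key.shape == (self.head_dim,)
        if self.num_tokens % self.page_size == 0:
            # First token of a new page: min and max both start at this Key.
            self.k_min.append(key.astype(float))
            self.k_max.append(key.astype(float))
        else:
            # O(head_dim) element-wise update of the current (last) page, in place.
            np.minimum(self.k_min[-1], key, out=self.k_min[-1])
            np.maximum(self.k_max[-1], key, out=self.k_max[-1])
        self.num_tokens += 1

# Usage: three Keys with page_size=2 fill page 0 and start page 1.
meta = PageMetadata(head_dim=4, page_size=2)
for key in np.array([[1., 5., 0., 2.], [3., -1., 0., 4.], [0., 0., 9., 1.]]):
    meta.append_key(key)
print(len(meta.k_min))  # -> 2
```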

7.3. Personal Insights & Critique

Quest presents a very elegant and practically impactful solution to the long-context LLM inference problem. The core insight that token criticality is query-dependent is powerful and intuitively makes sense, yet it's often overlooked by simpler eviction policies that rely on historical averages or fixed windows. The paper's empirical validation of this dynamic criticality (Figure 2 and Figure 4) is crucial and compelling.

Operating at page granularity, building on PagedAttention, is a smart choice. It aligns with modern LLM serving infrastructure and balances fine-grained control against computational overhead. The min/max Key approximation for criticality estimation is particularly clever. It gives a cheap upper bound on the largest q·k score any token in a page can produce. Irrelevant pages can therefore be pruned without computing a dot product for every token in every page. This approximation is key to Quest's efficiency. At each decoding step, full attention reads all N cached Keys and Values, which is O(N) per step. Quest instead reads two metadata vectors per page for estimation (O(N / Page Size)) and then attends only over the K selected pages.
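The upper-bound property described above can be checked numerically. For each channel i, q_i·k_i is linear in k_i, so over the interval [min_i, max_i] it is maximized at an endpoint. Summing the per-channel maxima therefore bounds q·k for every Key in the page. The sketch below, with a hypothetical function name and random data, verifies the bound on random pages.

```python
import numpy as np

def page_upper_bound(q: np.ndarray, k_min: np.ndarray, k_max: np.ndarray) -> float:
    # Per channel, the larger of q_i*min_i and q_i*max_i bounds q_i*k_i
    # for any Key in the page; summing over channels bounds q . k.
    return float(np.maximum(q * k_min, q * k_max).sum())

rng = np.random.default_rng(42)
for _ in range(1000):
    page = rng.standard_normal((16, 64))  # 16 Keys of head dimension 64
    q = rng.standard_normal(64)
    bound = page_upper_bound(q, page.min(axis=0), page.max(axis=0))
    true_max = float((page @ q).max())
    assert bound >= true_max - 1e-9  # the bound never underestimates
print("bound held in all trials")
```

The bound can be loose, because each channel may take its maximum from a different token. It only needs to rank pages well enough that the truly critical ones land in the Top-K.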

The strong speedup results, especially the end-to-end latency reduction with quantization, demonstrate Quest's immediate practical value for deploying long-context LLMs more efficiently. This could significantly lower the operational costs and improve the responsiveness of LLM-powered applications, making long-context models more accessible.

A potential area for deeper exploration is a more granular analysis of how the K parameter (the number of selected pages) affects accuracy and latency across different layers and query types. The paper shows aggregated results for different budgets, but understanding the optimal K per layer could yield even better performance. Another interesting question is whether the min/max approximation could be extended or combined with other lightweight statistics (e.g., mean, variance) to refine criticality estimation, provided its low overhead is preserved.

The concept of query-aware sparsity could also be transferred to other Transformer-based architectures or domains where attention mechanisms are a bottleneck, such as long-sequence vision transformers or time-series forecasting models. The principle of dynamically identifying and focusing on critical elements based on the current query is broadly applicable. Overall, Quest is a significant step towards enabling truly efficient and scalable long-context LLM inference.
