FlatQuant: Flatness Matters for LLM Quantization
TL;DR Summary
FlatQuant introduces a new post-training quantization method that optimizes the flatness of weights and activations, reducing quantization error significantly. It establishes a new benchmark for the LLaMA-3-70B model, achieving less than 1% accuracy drop under W4A4 quantization, together with up to 2.3x prefill speedup and 1.7x decoding speedup over FP16.
Abstract
Recently, quantization has been widely used for the compression and acceleration of large language models (LLMs). Due to the outliers in LLMs, it is crucial to flatten weights and activations to minimize quantization error with equally spaced quantization points. Prior research explores various pre-quantization transformations to suppress outliers, such as per-channel scaling and Hadamard transformation. However, we observe that these transformed weights and activations can still exhibit steep and dispersed distributions. In this paper, we propose FlatQuant (Fast and Learnable Affine Transformation), a new post-training quantization approach that enhances the flatness of weights and activations. Our approach identifies optimal affine transformations for each linear layer, calibrated in hours via a lightweight objective. To reduce runtime overhead of affine transformation, we apply Kronecker product with two lightweight matrices, and fuse all operations in FlatQuant into a single kernel. Extensive experiments demonstrate that FlatQuant establishes a new state-of-the-art benchmark for quantization. For example, it achieves less than 1% accuracy drop for W4A4 quantization on the LLaMA-3-70B model, surpassing SpinQuant by 7.5%. Additionally, it provides up to 2.3x prefill speedup and 1.7x decoding speedup compared to the FP16 model. Code is available at: https://github.com/ruikangliu/FlatQuant.
In-depth Reading
English Analysis
1. Bibliographic Information
1.1. Title
The central topic of the paper is FlatQuant: Flatness Matters for LLM Quantization. It proposes a novel approach to improve Large Language Model (LLM) quantization by enhancing the flatness of weights and activations.
1.2. Authors
The authors are Yuxuan Sun, Ruikang Liu, Haoli Bai, Han Bao, Kang Zhao, Yuening Li, Jiaxin Hu, Xianzhi Yu, Lu Hou, Chun Yuan, Xin Jiang, Wulong Liu, and Jun Yao. The provided text marks affiliations only with superscript numbers (e.g., Yuxuan Sun and Haoli Bai share affiliation '1', Ruikang Liu '2', and Yuening Li '3'), which typically denote different departments or organizations in a multi-institution academic collaboration; the specific organizations are not named in the provided text.
1.3. Journal/Conference
The paper is published as a preprint on arXiv (arXiv preprint arXiv:2410.09426v4). While arXiv is not a peer-reviewed journal or conference, it is a highly influential open-access repository for scholarly articles, particularly in physics, mathematics, computer science, quantitative biology, quantitative finance, statistics, electrical engineering and systems science, and economics. Papers on arXiv are widely circulated and serve as a crucial platform for disseminating cutting-edge research before or during the peer-review process, allowing for rapid sharing of findings and feedback from the scientific community.
1.4. Publication Year
The paper was published at 2024-10-12T08:10:28.000Z (UTC), which corresponds to October 12, 2024.
1.5. Abstract
This paper introduces FlatQuant (Fast and Learnable Affine Transformation), a new post-training quantization (PTQ) method for compressing and accelerating Large Language Models (LLMs). The core idea is that flatness (i.e., less steep and dispersed distributions) of weights and activations is crucial for minimizing quantization error, especially given the presence of outliers in LLMs. While prior methods use pre-quantization transformations like per-channel scaling and Hadamard transformations, the authors observe these can still result in sub-optimal distributions. FlatQuant proposes to learn optimal affine transformations for each linear layer, calibrated efficiently within hours using a lightweight objective function. To mitigate runtime overhead, it employs a Kronecker product with two lightweight matrices and fuses all operations into a single kernel. Extensive experiments demonstrate that FlatQuant achieves state-of-the-art results, such as less than 1% accuracy drop for W4A4 quantization on the LLaMA-3-70B model, outperforming SpinQuant by 7.5%. Furthermore, it provides up to 2.3x prefill speedup and 1.7x decoding speedup compared to the FP16 model.
1.6. Original Source Link
The official source link is: https://arxiv.org/abs/2410.09426v4
The PDF link is: https://arxiv.org/pdf/2410.09426v4.pdf
The paper is available as a preprint on arXiv.
2. Executive Summary
2.1. Background & Motivation
The rapid growth in the size of Large Language Models (LLMs) (e.g., billions of parameters) has led to significant computational and memory overhead during inference. This makes deploying LLMs on resource-constrained devices or in high-throughput scenarios challenging and expensive. Quantization is a highly effective technique to address this by reducing the precision of model parameters (weights) and intermediate computations (activations) from full-precision (e.g., FP16 or FP32) to lower-precision formats (e.g., INT4 or INT8).
The core problem in LLM quantization is minimizing quantization error. This error arises because low-bit quantization uses a limited number of discrete points to represent a continuous range of values. A critical challenge for LLMs, specifically, is the presence of outliers—extreme values in both weights and activations. These outliers can cause large quantization errors when mapped to the same low-bit integer points, leading to significant performance degradation.
Prior research has explored various pre-quantization transformations to suppress outliers and flatten distributions, such as per-channel scaling (e.g., SmoothQuant) and Hadamard transformation (e.g., QuaRot, SpinQuant). While these methods offer improvements, the authors observe that they often remain sub-optimal, with transformed weights and activations still exhibiting steep and dispersed distributions. This sub-optimality limits the achievable accuracy and inference speedup.
The paper's entry point and innovative idea is to explicitly focus on enhancing the flatness of weight and activation distributions through learnable affine transformations. By making these transformations fast and adaptive for each linear layer, FlatQuant aims to systematically mitigate outliers and create distributions that are inherently easier to quantize with minimal error. This adaptive, learnable approach, combined with efficient implementation strategies, seeks to overcome the limitations of fixed or less flexible pre-quantization methods.
2.2. Main Contributions / Findings
The paper makes several significant contributions:
- Highlighting the Significance of Flatness: It demonstrates that flatter distributions of weights and activations are crucial for effective LLM quantization, as flatness both reduces quantization error and limits its propagation across Transformer layers.
- Introducing FlatQuant: The paper proposes FlatQuant, a novel post-training quantization method that uses fast and learnable affine transformations optimized specifically for each linear layer. Empirically, FlatQuant is shown to significantly enhance the flatness of both weights and activations in LLMs.
- New State-of-the-Art Accuracy: FlatQuant establishes new state-of-the-art results for LLM quantization accuracy. Notably, it is presented as the first method to achieve less than 1% accuracy drop with simple round-to-nearest (RTN) W4A4 quantization on the LLaMA-3-70B model, a highly challenging benchmark, surpassing SpinQuant by a large margin (7.5%).
- Efficient Kernel Design and Speedup: The authors design an efficient kernel that fuses the affine transformations and quantization operations. This minimizes global memory access and kernel launch overhead, yielding up to 2.3x prefill speedup and 1.7x decoding speedup under W4A4 quantization compared to the FP16 baseline.
- Robustness and Versatility: FlatQuant is robust across different LLM families (LLaMA, Qwen, DeepSeek) and tasks (language modeling, QA, MT-Bench), and is compatible with different quantization settings (weight-only, KV cache quantization). Its calibration process is also efficient, taking only hours on a single GPU.
3. Prerequisite Knowledge & Related Work
3.1. Foundational Concepts
To understand FlatQuant, a novice reader should grasp the following foundational concepts:
- Large Language Models (LLMs): Advanced artificial intelligence models, typically based on the Transformer architecture, trained on vast amounts of text data. They are designed to understand, generate, and process human language, excelling at tasks like text generation, question answering, summarization, and translation. Their parameter count (weights and biases) can range from billions to trillions, making them computationally intensive and memory hungry.
- Quantization: A model compression technique that reduces the precision of the numbers used to represent a neural network's weights and activations. For instance, FP16 (16-bit floating-point) or FP32 (32-bit floating-point) values might be converted to INT4 (4-bit integer) or INT8 (8-bit integer). The primary goals are to:
  - Reduce Memory Footprint: Lower-precision numbers require less storage, allowing larger models to fit into available memory or run on devices with limited resources.
  - Accelerate Inference: Low-bit arithmetic operations are generally faster and consume less power than high-precision floating-point operations, leading to quicker model responses.
- Quantization Error: Mapping continuous high-precision values to discrete low-precision values inevitably introduces quantization error. The goal of quantization research is to minimize this error so that the low-precision model's performance (e.g., accuracy, perplexity) remains close to that of its full-precision counterpart.
- Weights and Activations:
  - Weights ($\mathbf{W}$): The learnable parameters within a neural network. They determine the strength of connections between neurons and are adjusted during training.
  - Activations ($\mathbf{X}$): The outputs of neurons or layers at each step of the network's computation. They represent the data flowing through the network. In an LLM, activations are typically the output of a linear layer or the attention mechanism.
- Outliers: In the context of LLMs, outliers are a small number of extremely large or small values in the weight matrices or activation tensors. They are problematic for quantization because a fixed-range, equally spaced quantization scheme allocates many of its limited discrete points to these few outliers, leaving fewer points for the majority of "normal" values. This leads to high quantization error for the non-outlier values and significantly degrades model performance.
- Flatness of Distributions: A flat distribution (also described as having low kurtosis) means that the values within a tensor (weights or activations) are relatively uniformly distributed across their range, without excessively sharp peaks or dispersed extreme values. Conversely, a steep or dispersed distribution implies the presence of many outliers or values concentrated in a narrow range with a long tail. A flatter distribution is highly desirable for quantization because:
  - It allows equally spaced quantization points to cover the value range more effectively.
  - It minimizes the impact of outliers, as values are more evenly spread out.
  - It directly leads to lower quantization error.
- Transformer Architecture: The foundational architecture for most modern LLMs. Key components include:
  - Linear Layers: Also known as fully connected layers, these perform a matrix multiplication (e.g., $\mathbf{Y} = \mathbf{X}\mathbf{W}^{\top}$). They are a primary target for quantization.
  - Self-Attention: A mechanism that allows the model to weigh the importance of different parts of the input sequence when processing each token. It involves Query (Q), Key (K), and Value (V) projections, which are typically implemented as linear layers.
  - Feed-Forward Networks (FFN): A stack of linear layers (often two or three) with non-linear activation functions, applied independently to each position in the sequence.
  - Layer Normalization (LayerNorm): A technique that stabilizes the training of deep neural networks by normalizing the inputs to each layer across the features.
  - Residual Connections: Add the input of a sub-layer to its output, helping to mitigate the vanishing gradient problem in deep networks.
  - KV Cache: During LLM inference, especially in the decoding stage (when generating one token at a time), the Key and Value tensors computed for previous tokens are stored in a cache (the KV Cache) to avoid recomputing them. Quantizing the KV Cache is crucial for reducing memory usage, especially for long sequences.
- Perplexity (PPL): A common intrinsic evaluation metric for language models. It quantifies how well a probability model predicts a sample. A lower PPL indicates that the model is better at predicting the next word in a sequence, implying higher-quality language generation.
- Accuracy: An extrinsic evaluation metric, typically used for classification tasks (like zero-shot QA). It measures the proportion of correctly predicted instances out of the total instances. Higher accuracy means better performance.
- Kronecker Product ($\otimes$): A matrix operation that combines two matrices of arbitrary size into a larger matrix. If $\mathbf{A}$ is an $m \times n$ matrix and $\mathbf{B}$ is a $p \times q$ matrix, their Kronecker product $\mathbf{A} \otimes \mathbf{B}$ is an $mp \times nq$ block matrix. FlatQuant uses this to construct a large affine transformation matrix from two smaller, lightweight matrices, significantly reducing computational and memory costs (see the sketch after this list).
- Singular Value Decomposition (SVD): A powerful matrix factorization technique. Any real matrix $\mathbf{A}$ can be decomposed as $\mathbf{A} = \mathbf{U}\mathbf{\Sigma}\mathbf{V}^{\top}$, where $\mathbf{U}$ and $\mathbf{V}$ are orthogonal matrices and $\mathbf{\Sigma}$ is a diagonal matrix of singular values. SVD is numerically stable and can be used to accurately compute the inverse of a square invertible matrix ($\mathbf{A}^{-1} = \mathbf{V}\mathbf{\Sigma}^{-1}\mathbf{U}^{\top}$), which is essential for FlatQuant's learnable transformations.
- Automatic Mixed Precision (AMP): A training technique that uses both FP16 and FP32 precision during model training. It performs most operations in FP16 for speed and memory efficiency, while keeping certain critical operations (e.g., loss calculation, weight updates) in FP32 to maintain numerical stability and prevent gradient underflow/overflow. This allows for faster training with lower memory consumption.
- CUDA, Triton, CUTLASS, FlashInfer: Technologies for GPU programming and high-performance deep learning:
  - CUDA: NVIDIA's parallel computing platform and programming model for GPUs.
  - Triton (OpenAI Triton): A domain-specific language and compiler for writing highly optimized GPU kernels. It simplifies achieving high performance for deep learning operations, often matching or outperforming manually optimized CUDA kernels. FlatQuant uses it for kernel fusion.
  - CUTLASS: A CUDA template library for highly optimized GEMM (General Matrix Multiply) operations on NVIDIA GPUs, providing high-performance routines for various data types, including low-precision integers (INT4). FlatQuant adopts CUTLASS for INT4 matrix multiplication.
  - FlashInfer: A kernel library designed for efficient LLM serving, including optimized KV cache operations and quantization. FlatQuant uses FlashInfer for KV cache quantization.
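The two linear-algebra tools above are easy to verify numerically. The short PyTorch sketch below is illustrative only (it is not taken from the paper or its code; sizes and the factor initialization are arbitrary): it builds a transformation as a Kronecker product of two small factors and inverts it through SVD, the same two ingredients FlatQuant relies on.

```python
import torch

torch.manual_seed(0)
# Two small, well-conditioned factors (double precision just for a clean numerical check).
P1 = torch.randn(4, 4, dtype=torch.float64) + 2 * torch.eye(4, dtype=torch.float64)
P2 = torch.randn(8, 8, dtype=torch.float64) + 2 * torch.eye(8, dtype=torch.float64)

P = torch.kron(P1, P2)                       # 32 x 32 transform built from 4x4 and 8x8 factors

# SVD-based inverse: A^-1 = V diag(1/s) U^T for a square invertible matrix.
U, S, Vh = torch.linalg.svd(P)
P_inv = Vh.T @ torch.diag(1.0 / S) @ U.T

print(torch.allclose(P @ P_inv, torch.eye(32, dtype=torch.float64), atol=1e-8))  # True
print(P1.numel() + P2.numel(), "parameters instead of", P.numel())               # 80 instead of 1024
```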
3.2. Previous Works
The paper discusses several categories of prior work related to LLM quantization:
- General LLM Quantization Methods:
  - GPTQ (Frantar et al., 2022): A widely used post-training quantization method for Generative Pre-trained Transformers. It performs weight-only quantization by sequentially quantizing weights layer by layer while minimizing the mean squared error (MSE) of each layer's output, using a Hessian-based approach for efficient weight updates. GPTQ is often combined with other techniques for better results.
  - AWQ (Activation-aware Weight Quantization) (Lin et al., 2023): Builds on the observation that only a small percentage of weights are salient (critical for performance) and sensitive to quantization. AWQ protects these salient weights from quantization error by scaling them, while quantizing the rest aggressively.
  - QUIK-4B (Ashkboos et al., 2023): An end-to-end 4-bit inference method for generative LLMs, aiming for high accuracy at very low precision.
  - QuIP (Chee et al., 2024) / QuIP# (Tseng et al., 2024): Methods for 2-bit quantization with guarantees, using techniques such as Hadamard incoherence and lattice codebooks to achieve high performance at extremely low bit-widths.
- Pre-Quantization Transformations for Outlier Suppression:
  - Per-channel Scaling (e.g., SmoothQuant by Xiao et al., 2023; OmniQuant by Shao et al., 2023; Outlier Suppression+ by Wei et al., 2023): A popular technique to mitigate the impact of outliers in activations. The core idea is to move the quantization difficulty from activations to weights by applying a channel-wise scaling factor.
    - The linear layer operation is $\mathbf{Y} = \mathbf{X}\mathbf{W}^{\top}$.
    - With per-channel scaling, this is transformed to $\mathbf{Y} = (\mathbf{X}\,\mathrm{diag}(\mathbf{c})^{-1}) \cdot (\mathrm{diag}(\mathbf{c})\,\mathbf{W}^{\top})$, where $\mathbf{c}$ is the channel-wise scaling factor.
    - The inverse scaling factor is applied to the activations, and the scaling factor is applied to the weights.
    - The scaled weights $\mathrm{diag}(\mathbf{c})\,\mathbf{W}^{\top}$ can be pre-computed and stored offline, so only the activation scaling incurs runtime overhead. SmoothQuant specifically aims to balance the magnitudes of activations and weights to make both easier to quantize; the scaling factor for channel $i$ is often determined by $c_i = \max(|\mathbf{X}_i|)^{\alpha} / \max(|\mathbf{W}_i|)^{1-\alpha}$, where $\alpha$ is a hyperparameter. OmniQuant (Shao et al., 2023) further treats both scaling and shifting factors as learnable parameters.
  - Hadamard Transformation (e.g., QuaRot by Ashkboos et al., 2024; SpinQuant by Liu et al., 2024c; Training Transformers with 4-bit Integers by Xi et al., 2023): These methods use Hadamard matrices to rotate the channels of both activations and weights.
    - A Hadamard matrix $\mathbf{H}$ is an orthogonal matrix (i.e., $\mathbf{H}\mathbf{H}^{\top} = \mathbf{I}$).
    - The core idea is that $\mathbf{Y} = \mathbf{X}\mathbf{W}^{\top} = (\mathbf{X}\mathbf{H})(\mathbf{H}^{\top}\mathbf{W}^{\top})$.
    - By multiplying the activations with $\mathbf{H}$ and the weights with $\mathbf{H}^{\top}$, outliers are redistributed more evenly across channels, leading to flatter distributions.
    - As with per-channel scaling, the transformed weight $\mathbf{H}^{\top}\mathbf{W}^{\top}$ can be pre-computed offline; the online operation is $\mathbf{X}\mathbf{H}$. QuaRot and SpinQuant build on this, often combining Hadamard rotations with other techniques such as learnable rotations and modifications to LayerNorm for efficiency. SpinQuant specifically uses learned rotations to alleviate outliers.
  - Affine Transformation Quantization (AffineQuant by Ma et al., 2024): This work also explores affine transformations to improve quantization, aiming to find optimal invertible matrices for each layer that minimize quantization error. FlatQuant can be seen as building on this idea, with a strong additional focus on efficiency (via the Kronecker product) and flatness.
3.3. Technological Evolution
The field of LLM quantization has rapidly evolved to address the growing size and computational demands of LLMs.
- Early Methods (e.g., RTN): Started with simple Round-To-Nearest quantization, which often led to significant accuracy drops, especially at lower bit-widths.
- Outlier-Aware Techniques: The discovery of outliers in LLM activations became a key focus. Methods like SmoothQuant introduced per-channel scaling to smooth distributions by shifting quantization difficulty from activations to weights.
- Rotation-Based Transformations: Recognizing the limitations of simple scaling, Hadamard and orthogonal transformations (e.g., QuaRot, SpinQuant) emerged. These methods rotate the data in the channel dimension to redistribute outliers and achieve better flatness, often combined with learnable components.
- Hardware-Aware Optimization: Alongside algorithmic improvements, significant effort has gone into optimizing quantized inference on hardware, including custom kernels (e.g., with CUTLASS and Triton) and specialized libraries (FlashInfer) that make low-bit operations truly fast by reducing memory-access overhead and maximizing GPU utilization.
- Learnable & Adaptive Transformations: The latest trend, exemplified by AffineQuant and FlatQuant, is to learn the optimal transformations rather than relying on fixed or heuristically derived ones, allowing greater adaptability to the unique characteristics of each LLM layer.

FlatQuant fits into this evolution by pushing the boundaries of learnable transformations. It explicitly targets flatness as the primary objective for reducing quantization error and combines this with highly efficient implementation strategies (Kronecker product, kernel fusion) to ensure practical speedups without sacrificing accuracy. It represents a significant step toward deploying highly accurate, ultra-low-bit LLMs in real-world applications.
3.4. Differentiation Analysis
Compared to the main methods in related work, FlatQuant introduces several core differences and innovations:
- Explicit Focus on Flatness as an Objective: While previous methods (e.g., SmoothQuant, Hadamard-based methods) implicitly aim for flatter distributions by mitigating outliers, FlatQuant explicitly identifies flatness as a crucial factor for minimizing quantization error and directly optimizes for it. This objective-driven approach allows for more targeted and effective transformations.
- Learnable Affine Transformations per Layer:
  - Unlike per-channel scaling (e.g., SmoothQuant), which applies a diagonal (scalar per-channel) transformation, FlatQuant uses full affine transformation matrices that are learnable and layer-specific. This offers significantly more expressiveness and adaptability to the unique distribution characteristics of each individual linear layer within the LLM, whereas SmoothQuant typically uses a fixed formula or a limited set of learnable scalars.
  - Compared to Hadamard transformations (e.g., QuaRot, SpinQuant), which use fixed orthogonal matrices or learned rotations, FlatQuant's affine transformations are more general. While Hadamard matrices are orthogonal, FlatQuant learns a broader class of invertible matrices, providing greater flexibility to shape distributions. Moreover, existing Hadamard-based methods often apply a single transformation across many layers or share transformations, whereas FlatQuant optimizes them per linear layer. AffineQuant also uses affine transformations, but FlatQuant differentiates itself through its efficiency design and kernel fusion.
- Efficiency through the Kronecker Product: A major innovation of FlatQuant is the use of the Kronecker product ($\otimes$) to construct the large affine transformation matrix from two much smaller matrices. This dramatically reduces the memory footprint and computational overhead of the transformation compared to a full-size matrix, addressing a key practical limitation of general affine transformations and making them viable for large LLMs.
- Optimized Kernel Fusion: FlatQuant integrates the affine transformation and quantization operations into a single fused kernel using OpenAI Triton. This kernel fusion is crucial for minimizing global memory access (a common bottleneck for memory-bound operations) and kernel launch overhead, which translates directly into superior end-to-end inference speedup. This hardware-aware optimization sets it apart from methods that perform transformations and quantization as separate, sequential operations.
- Compatibility and Robustness: FlatQuant is highly compatible with various quantization techniques (e.g., learnable clipping, RTN, GPTQ) and applies to different settings (e.g., weight-only, KV cache quantization). Its ability to achieve high accuracy with simple RTN quantization reduces calibration complexity and time compared to methods that rely heavily on GPTQ for good performance.

In essence, FlatQuant moves beyond fixed or simple transformations by learning more powerful affine transformations for each layer, while ensuring these transformations are computationally and memory-efficient through mathematical decomposition (Kronecker product) and hardware-optimized kernel design (kernel fusion). This combined approach enables superior accuracy-speed trade-offs for LLM quantization.
4. Methodology
4.1. Principles
The core principle behind FlatQuant is that flatness in the distributions of weights and activations is paramount for minimizing quantization error in Large Language Models (LLMs). When values are tightly clustered or contain extreme outliers, mapping them to a limited set of equally spaced quantization points (as is common in many quantization schemes) leads to significant information loss and performance degradation. FlatQuant's key insight is to systematically apply fast and learnable affine transformations to each linear layer within the LLM to explicitly enhance this flatness.
The theoretical basis is rooted in the understanding that transformations can reshape data distributions. By learning optimal affine matrices, the method can rotate, scale, and shift the data in a way that spreads out outliers and makes the overall distribution more uniform across the available quantization range. This reduces the mean squared error (MSE) between the original and quantized values. To ensure these powerful transformations do not introduce prohibitive runtime overhead, FlatQuant employs two crucial efficiency mechanisms:
- Kronecker Product Decomposition: Instead of learning a single large transformation matrix for each high-dimensional layer, FlatQuant decomposes this matrix into a Kronecker product of two much smaller, lightweight matrices. This significantly reduces the parameter count and computational cost of applying the transformation.
- Kernel Fusion: The transformation operations, which are memory-bound, are fused together with the quantization steps into a single custom kernel. This minimizes global memory accesses and kernel launch overhead, ensuring that the algorithmic benefits translate into practical inference speedups.

In essence, FlatQuant seeks to intelligently preprocess the numerical data within an LLM so that it is maximally amenable to low-bit quantization, balancing expressive power with computational efficiency.
4.2. Core Methodology In-depth (Layer by Layer)
4.2.1. Preliminaries on LLM Quantization
LLM inference generally consists of two stages:
- Prefill Stage: Processing the input sequence to build a Key-Value cache (KV Cache) layer by layer.
- Decoding Stage: Autoregressively generating new tokens based on the accumulated KV Cache.

Quantization is applied to reduce the precision of weights and activations in linear layers (which typically compute $\mathbf{Y} = \mathbf{X}\mathbf{W}^{\top}$), and optionally to the KV Cache.

For $b$-bit weight quantization, the process can be represented as: $ \mathcal{Q}(\mathbf{W}; b) = s \cdot \Pi_{\Omega(b)}\!\left(\frac{\mathbf{W}}{s}\right) $ Here:

- $\mathcal{Q}(\mathbf{W}; b)$ represents the quantized weight matrix produced by the general $b$-bit quantization function $\mathcal{Q}(\cdot)$.
- $s$ is the quantization step size (or scale), which determines the interval between quantization points and effectively sets the range of values that can be represented.
- $\Pi_{\Omega(b)}(\cdot)$ is the projection function that maps a real-valued number to the nearest integer quantization point. For round-to-nearest (RTN) quantization, it simply rounds to the closest integer.
- $\Omega(b)$ is the set of $b$-bit integer points; for example, $\Omega(4) = \{-8, -7, \ldots, 7\}$. For simplicity, $\mathcal{Q}(\cdot)$ is used to denote the general quantization function.
As LLMs are known to have persistent outliers in activations, two common pre-quantization transformations are discussed:
- Per-channel Scaling: This method addresses outliers in input activations by applying channel-wise scaling. The original matrix multiplication $\mathbf{Y} = \mathbf{X}\mathbf{W}^{\top}$ is reformulated as: $ \mathbf{Y} = (\mathbf{X} \mathrm{diag}(\mathbf{c})^{-1}) \cdot (\mathrm{diag}(\mathbf{c}) \mathbf{W}^{\top}) $ Here:
  - $\mathbf{c}$ is the channel-wise scaling factor vector.
  - $\mathrm{diag}(\mathbf{c})$ creates a diagonal matrix with the elements of $\mathbf{c}$ on its diagonal.
  - The term $\mathbf{X}\,\mathrm{diag}(\mathbf{c})^{-1}$ scales the activations by $1/c_i$ for each channel $i$.
  - The term $\mathrm{diag}(\mathbf{c})\,\mathbf{W}^{\top}$ scales the weights by $c_i$ for each channel $i$.
  - The scaling vector $\mathbf{c}$ is designed to smooth the activations by considering the magnitudes of both activations and weights, e.g., $c_i = \max(|\mathbf{X}_i|)^{\alpha} / \max(|\mathbf{W}_i|)^{1-\alpha}$. The scaled weights can be merged and pre-processed offline, eliminating runtime computation overhead for the weight scaling.
  - Variants such as Outlier Suppression+ (Wei et al., 2023) introduce channel-wise shifting, and OmniQuant (Shao et al., 2023) makes the scaling and shifting factors learnable parameters.
- Hadamard Transformation: This method uses Hadamard matrices to rotate the channels of both activations and weights, thereby redistributing outliers. Due to the orthogonality of Hadamard matrices ($\mathbf{H}\mathbf{H}^{\top} = \mathbf{I}$), the following equivalence holds: $ \mathbf{Y} = \mathbf{X} \mathbf{W}^{\top} = (\mathbf{X} \mathbf{H}) (\mathbf{H}^{\top} \mathbf{W}^{\top}) $
  - Activations are transformed by multiplying with $\mathbf{H}$ (i.e., $\mathbf{X}\mathbf{H}$).
  - Weights are transformed by multiplying with $\mathbf{H}^{\top}$ (i.e., $\mathbf{H}^{\top}\mathbf{W}^{\top}$).
  - The transformed weights can also be pre-processed offline. The rotation aims to eliminate outliers more effectively by spreading them across all channels.
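Both identities are easy to check numerically. The following PyTorch snippet is an illustrative sketch written for this analysis (not the paper's code): it verifies that per-channel scaling and a normalized Hadamard rotation leave the linear layer's full-precision output unchanged; the benefit only appears once quantization is inserted between the two factors.

```python
import torch

torch.manual_seed(0)
d = 8  # hidden dimension (a power of two, so Sylvester's Hadamard construction applies)

# Normalized Hadamard matrix via Sylvester's construction: H_{2n} = [[H, H], [H, -H]] / sqrt(2).
H = torch.ones(1, 1)
while H.shape[0] < d:
    H = torch.cat([torch.cat([H, H], dim=1),
                   torch.cat([H, -H], dim=1)], dim=0) / (2 ** 0.5)

X = torch.randn(4, d)       # activations (tokens x channels)
W = torch.randn(16, d)      # weight matrix of a linear layer (out_features x in_features)
Y = X @ W.T                 # original output

# Per-channel scaling: move activation magnitude into the weights (alpha = 1 here for simplicity).
c = X.abs().amax(dim=0).clamp(min=1e-5)
Y_scale = (X / c) @ (W * c).T                 # (X diag(c)^-1)(diag(c) W^T)

# Hadamard rotation: redistribute outliers across channels.
Y_hadamard = (X @ H) @ (H.T @ W.T)            # (X H)(H^T W^T)

print(torch.allclose(Y, Y_scale, atol=1e-5), torch.allclose(Y, Y_hadamard, atol=1e-5))  # True True
```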
4.2.2. The Flatness for Quantization
The paper motivates its approach by emphasizing the importance of flatness for quantization. Flat tensors (those with low kurtosis) are intuitively easier to quantize after outlier removal. Figure 1 visually demonstrates this:
As shown in Figure 1, FlatQuant (green lines) consistently achieves flatter distributions for both weights and activations compared to other methods. The original distributions (gray) are very steep and dispersed, indicating severe outliers. Per-channel scaling (orange) flattens activations but can lead to steeper weight envelopes. Hadamard transformation (blue) generally performs better than per-channel scaling but can still be unsatisfactory.
The following figure (Figure 1 from the original paper) shows distributions of weights and inputs from LLaMA-3-8B and LLaMA-3-70B:
Figure 1: Distributions of weights and inputs from LLaMA-3-8B and LLaMA-3-70B, sorted by channel magnitude (i.e., the Frobenius norm) in descending order. In a Transformer layer, the first pair of panels shows the weight matrix and input of the output projection layer in the self-attention layer, and the second pair shows the weight and input of the gated linear layer of the feed-forward network. More visualizations can be found in Appendix D.
Beyond individual layer flatness, FlatQuant also addresses the flatness of the quantization error landscape. Quantization error accumulates and propagates across Transformer layers. Figure 2 illustrates the mean squared error (MSE) landscape. Initially, massive quantization errors occur at pivot tokens (early tokens in a sequence), which are rich in outliers. Neither per-channel scaling nor Hadamard transformation effectively mitigates these initial errors (Figure 2a-2b). In contrast, FlatQuant (Figure 2c) shows significantly lower error at pivot tokens. Furthermore, while quantization error generally increases layer-wise, FlatQuant is shown to be most effective at controlling this error propagation (Figure 2d), followed by Hadamard transformation and then per-channel scaling.
The following figure (Figure 2 from the original paper) shows the mean squared quantization error landscape across Transformer layers and input sequence tokens for LLaMA-3-8B:
Figure 2: The mean squared quantization error (MSE) landscapes across Transformer layers and input sequence tokens for LLaMA-3-8B. The MSE is normalized with respect to the FP16 baseline. FlatQuant shows much lower error at these pivot tokens from Figure 2c. According to Figure 2d, FlatQuant is the best in controlling the error propagation.
4.2.3. Fast and Learnable Affine Transformation
FlatQuant's goal is to find the optimal affine transformation for each linear layer. For a linear layer computing $\mathbf{Y} = \mathbf{X}\mathbf{W}^{\top}$, one ideally seeks an optimal invertible matrix $\mathbf{P}^{*}$ by minimizing the Frobenius norm of the difference between the full-precision output and the quantized, transformed output:
$ \mathbf{P}^{*} = \arg\min_{\mathbf{P}} \left\| \mathbf{X}\mathbf{W}^{\top} - \mathcal{Q}(\mathbf{X}\mathbf{P})\, \mathcal{Q}(\mathbf{P}^{-1}\mathbf{W}^{\top}) \right\|_F^2 $
Here:
- $\mathbf{P}^{*}$ is the optimal invertible transformation matrix.
- $\mathbf{X}\mathbf{W}^{\top}$ is the original full-precision output of the linear layer.
- $\mathcal{Q}(\mathbf{X}\mathbf{P})$ is the quantized version of the activations transformed by $\mathbf{P}$.
- $\mathcal{Q}(\mathbf{P}^{-1}\mathbf{W}^{\top})$ is the quantized version of the weights transformed by the inverse of $\mathbf{P}$.
- $\|\cdot\|_F^2$ is the squared Frobenius norm, which measures the mean squared error (MSE). The transformed weights can be pre-computed offline. However, maintaining and applying an individual full-size matrix $\mathbf{P}$ for each layer would be expensive in both computation and memory.
Kronecker Product
FlatQuant addresses the cost of full-size affine transformations by using a Kronecker product of two lightweight matrices as an efficient substitute. Specifically, the transformation matrix is constructed as $\mathbf{P} = \mathbf{P}_1 \otimes \mathbf{P}_2$, where $\mathbf{P}_1 \in \mathbb{R}^{n_1 \times n_1}$ and $\mathbf{P}_2 \in \mathbb{R}^{n_2 \times n_2}$ are invertible matrices of smaller sizes, such that the original dimension $n = n_1 n_2$.
Leveraging the vectorization trick of the Kronecker product (i.e., $\mathbf{x}(\mathbf{P}_1 \otimes \mathbf{P}_2) = \mathrm{vec}(\mathbf{P}_1^{\top}\mathbf{X}'\mathbf{P}_2)$ for a row vector $\mathbf{x}$ reshaped into $\mathbf{X}' \in \mathbb{R}^{n_1 \times n_2}$), the matrix multiplication in Equation 2 can be rewritten so that only the two small factors are applied to reshaped operands:
Here:
- The activations and weights are reshaped from their original matrix forms into 3D tensors compatible with the Kronecker-product transformation: each row of $\mathbf{X}$ (and of $\mathbf{W}$) becomes an $n_1 \times n_2$ slice.
- Each reshaped activation slice is transformed as $\mathbf{P}_1^{\top}\mathbf{X}'\mathbf{P}_2$, and each reshaped weight slice is transformed with the corresponding inverse factors, so that the product of the two transformed operands recovers $\mathbf{Y} = \mathbf{X}\mathbf{W}^{\top}$ up to quantization.
- A final reshape and transpose restores the ordinary matrix multiplication structure.

This design significantly saves memory (up to $n/2$ times) and computation (up to $\sqrt{n}/2$ times) when $n_1 = n_2 = \sqrt{n}$. For instance, for $n = 4096$, the optimal configuration is $n_1 = n_2 = 64$. These affine transformations are therefore very lightweight, adding minimal FLOPs and memory.
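To make the reshaping trick concrete, the sketch below (illustrative only; the dimensions are toy-sized and the factor initialization is hypothetical) checks that applying the two small factors to a reshaped token equals multiplying by the full Kronecker-product matrix, while storing far fewer parameters:

```python
import torch

torch.manual_seed(0)
n1, n2 = 4, 8                                   # small stand-ins for e.g. 64 x 64 on a 4096-dim layer
n = n1 * n2
P1 = torch.randn(n1, n1) + torch.eye(n1)        # assumed invertible factors (hypothetical init)
P2 = torch.randn(n2, n2) + torch.eye(n2)
x = torch.randn(1, n)                           # one token's activation (row vector)

# Naive: materialize the full n x n transform P = P1 kron P2 and multiply.
full = x @ torch.kron(P1, P2)

# FlatQuant-style: reshape the row to (n1, n2) and apply the two small factors.
x_mat = x.reshape(n1, n2)
fast = (P1.T @ x_mat @ P2).reshape(1, n)

print(torch.allclose(full, fast, atol=1e-5))    # True: same result with O(n*(n1+n2)) work per token
print(P1.numel() + P2.numel(), "vs", n * n, "stored parameters")
```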
Per-channel Scaling
To further enhance outlier balancing, FlatQuant explicitly incorporates a learnable scaling vector before the pre-quantization transformation. This vector can be merged with preceding layer normalization or linear layers, incurring no additional inference overhead, similar to SmoothQuant.
Learnable Clipping Thresholds
FlatQuant also introduces learnable clipping thresholds (passed through a sigmoid function) for both weight and activation quantization in each linear layer, as well as for the KV cache. While grid search is commonly used to find such thresholds, FlatQuant learns them, which leads to better results. These parameters are layer-specific and optimized jointly with the transformation matrices $\mathbf{P}_1$, $\mathbf{P}_2$ and the scaling vector.
The Training Objective
FlatQuant uses a post-training quantization approach, minimizing the mean squared error (MSE) sequentially for each Transformer block over a small calibration dataset (e.g., 128 sentences). The objective for the $l$-th Transformer block is:
$ \min_{\theta_l} \left\| \mathcal{F}_l(\mathbf{X}_l) - \hat{\mathcal{F}}_l(\mathbf{X}_l; \theta_l) \right\|_F^2 $
Here:
- $\mathcal{F}_l(\mathbf{X}_l)$ denotes the original full-precision output of the $l$-th Transformer block given its input $\mathbf{X}_l$.
- $\hat{\mathcal{F}}_l(\mathbf{X}_l; \theta_l)$ denotes the output of the $l$-th quantized Transformer block with learnable parameters $\theta_l$.
- $\theta_l$ collects all learnable parameters within that block: the affine transformation matrices (implicitly $\mathbf{P}_1$ and $\mathbf{P}_2$ for each linear layer), the scaling vector, and the activation and weight clipping thresholds. The optimization is performed using singular value decomposition (SVD) for accurate and efficient matrix inversion and automatic mixed precision (AMP) for training stability and efficiency (details in Appendix B.1).
The overall framework of FlatQuant is depicted in Figure 3. It shows the integration of learnable affine transformations, per-channel scaling, and learnable clipping thresholds within the Transformer architecture.
The following figure (Figure 3 from the original paper) illustrates the overall framework of FlatQuant:
Figure 3: The overall framework of FlatQuant. (a) Necessary notations for FlatQuant; (b) the integration of FlatQuant with a conventional LLaMA layer, where merged parameters are grouped in red; (c) the online transformation for the down-projection layer, where the per-channel scaling vector is merged into the preceding weights in practice.
4.2.4. Integration with the Transformer Architecture
FlatQuant is integrated into a Transformer block (e.g., LLaMA-like architecture) as follows:
- Quantized Layers: All linear layers use low-bit matrix multiplications.
- FP16 Layers: Layer normalization, the pre-quantization transformations, RoPE embeddings, and attention scores remain in FP16 for precision.
Self-Attention
The self-attention module incorporates four transformation matrices:
- One applied to flatten the input activation for the query, key, and value projections.
- One that smooths the input activation for the output projection.
- One that transforms the key cache head by head.
- One that transforms the value cache head by head (see the sketch below for a head-wise transform).

Notably, only the first two are decomposed using the Kronecker product: the head size used in per-head quantization of the KV cache is already small, so full transformations for the key- and value-cache matrices are computationally cheap. The value-cache transformation is further fused with the output-projection transformation to reduce overhead, inspired by QuaRot.
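As a small illustration of a head-wise cache transformation (an analyst's sketch; the tensor layout and names are assumptions, not the paper's code), each head's vectors are multiplied by a full head_dim x head_dim matrix, which stays cheap because head_dim is small (e.g., 128 for LLaMA):

```python
import torch

tokens, n_heads, head_dim = 16, 8, 128
k_cache = torch.randn(tokens, n_heads, head_dim)                  # per-head key cache
P_k = torch.randn(head_dim, head_dim) + torch.eye(head_dim)       # full (non-decomposed) head-wise transform

# Apply the same head_dim x head_dim transform to every head of every token.
k_transformed = k_cache @ P_k                                     # broadcasts over (tokens, n_heads)
print(k_transformed.shape)                                        # torch.Size([16, 8, 128])
```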
Feed-forward Network (FFN)
The FFN utilizes two transformation matrices:
- One applied to flatten the input of the FFN after layer normalization.
- One that flattens the input of the down-projection layer.

Both are decomposed via the Kronecker product to minimize inference overhead. The per-channel scaling of the down-projection input is merged into the weight of the up-projection layer to avoid additional computational cost.
Layer Normalization
Unlike some methods (e.g., QuaRot, SpinQuant) that modify LayerNorm to RMSNorm and merge transformations into preceding layers, FlatQuant preserves the original LayerNorm. This allows for distinct fast and learnable affine transformations to be applied after LayerNorm for different layers, increasing the model's expressiveness and adaptive capacity.
4.2.5. Efficient Kernel Design
FlatQuant implements an efficient kernel that fuses the affine transformations and quantization into a single operation using OpenAI Triton. This design is motivated by:
- Memory-Bound Operations: The Kronecker-product transformations are low-computational-intensity operations, making them memory-bound. Quantization itself is also memory-bound.
- Minimizing Overhead: Fusing these operations into a single kernel eliminates redundant global memory accesses for intermediate results and reduces kernel launch overhead, which is critical for achieving substantial speedups.

The fusion process involves:

- Loading the entire $\mathbf{P}_1$ and $\mathbf{P}_2$ matrices into SRAM (on-chip memory, much faster than global memory).
- Having each thread block process a small tiling block of the reshaped activations.
- Performing the matrix multiplications and quantizing the results on the fly.
- Keeping all intermediate results in SRAM until the final quantized output is written back to global memory.

For INT4 matrix multiplication, FlatQuant follows QuaRot and adopts the CUTLASS kernel. For KV cache quantization, it uses FlashInfer. Details on handling corner cases where SRAM is not large enough for very large tensors are provided in Appendix B.3, involving strategies such as tiling the non-reduction dimensions.
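To clarify what the fused kernel computes (though not how it is scheduled on the GPU), here is a plain PyTorch reference written for this analysis; the real Triton kernel performs the same math but keeps the transformed tiles and quantization statistics in SRAM rather than materializing them in global memory.

```python
import torch

def fused_transform_quantize_reference(x, P1, P2, bits=4):
    """Reference of the fused op: Kronecker-factored transform followed by
    per-token symmetric quantization. Illustrative only, not the Triton kernel."""
    tokens, n = x.shape
    n1, n2 = P1.shape[0], P2.shape[0]
    assert n == n1 * n2
    # Kronecker-factored transform applied per token.
    z = (P1.T @ x.reshape(tokens, n1, n2) @ P2).reshape(tokens, n)
    # Per-token symmetric quantization to signed `bits`-bit integers.
    qmax = 2 ** (bits - 1) - 1
    scale = z.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(z / scale), -qmax - 1, qmax).to(torch.int8)
    return q, scale

q, s = fused_transform_quantize_reference(torch.randn(4, 32), torch.randn(4, 4), torch.randn(8, 8))
print(q.shape, s.shape)   # torch.Size([4, 32]) torch.Size([4, 1])
```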
5. Experimental Setup
5.1. Datasets
The experiments primarily evaluate FlatQuant on various Large Language Models (LLMs) from the LLaMA-2 (Touvron et al., 2023) and LLaMA-3 (Dubey et al., 2024) series. Additional results on Qwen (Yang et al., 2024) and DeepSeek families are provided in Appendix C.1.
The datasets used for evaluation are:
- Language Generation Tasks:
  - WikiText-2 (Merity et al., 2016): A dataset of over 100 million tokens extracted from "Good" and "Featured" articles on Wikipedia. It is commonly used for evaluating language models, especially for perplexity (PPL), due to its coherent and diverse text.
  - C4 (Colossal Clean Crawled Corpus) (Raffel et al., 2020): A massive, cleaned corpus of text derived from web pages, comprising hundreds of gigabytes of text. It is widely used for training and evaluating large language models, providing a broad and diverse range of text styles and topics.
- Commonsense Reasoning (Zero-shot QA) Tasks: Six popular benchmarks are used:
  - ARC-Challenge and ARC-Easy (Clark et al., 2018): Datasets designed to test a machine's ability to answer questions requiring commonsense reasoning. ARC-Challenge contains questions that are harder for algorithms to answer.
  - HellaSwag (Zellers et al., 2019): A dataset for commonsense natural language inference, where models must choose the most plausible ending to a given sentence. It is designed to be difficult for models that rely purely on statistical cues.
  - LAMBADA (Paperno et al., 2016): A word-prediction dataset that specifically requires broad discourse context. Models must predict the last word of a sentence, where the necessary context often spans multiple sentences.
  - PIQA (Bisk et al., 2020): Physical Interaction QA, a commonsense reasoning benchmark focused on physical commonsense, requiring models to choose the correct plan to achieve a goal.
  - WinoGrande (Sakaguchi et al., 2021): An adversarial Winograd Schema Challenge at scale, designed to test commonsense reasoning by resolving ambiguous pronouns in sentences.

These datasets are standard benchmarks in LLM research, providing a comprehensive assessment of both language generation quality and reasoning ability, which is critical for validating the effectiveness of quantization methods.
5.2. Evaluation Metrics
For every evaluation metric mentioned in the paper, here is a complete explanation:
5.2.1. Perplexity (PPL)
- Conceptual Definition: Perplexity (PPL) measures how well a probability model predicts a sample. In the context of language models, it quantifies how "surprised" the model is by a given text sequence. A lower PPL indicates that the model assigns a higher probability to the observed sequence, meaning it predicts the text more accurately and is a better language model. It is the inverse probability of the test set, normalized by the number of words.
- Mathematical Formula: $ PPL(W) = \exp\left(-\frac{1}{N} \sum_{i=1}^{N} \log P(w_i | w_{1...i-1})\right) $
- Symbol Explanation:
  - $W = (w_1, \ldots, w_N)$: The sequence of words (or tokens) in the test set.
  - $N$: The total number of words (or tokens) in the sequence $W$.
  - $P(w_i | w_{1...i-1})$: The probability of the $i$-th word given all preceding words, as predicted by the language model.
  - $\log$: The natural logarithm.
  - $\exp$: The exponential function (inverse of the natural logarithm).
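The metric is straightforward to compute from a model's next-token logits; a minimal sketch (assuming the logits are already aligned with their target tokens) is:

```python
import torch
import torch.nn.functional as F

def perplexity(logits, targets):
    """Perplexity from next-token logits of shape (N, vocab) and N target token ids."""
    nll = F.cross_entropy(logits, targets, reduction="mean")  # -1/N * sum log P(w_i | w_<i)
    return torch.exp(nll)

# Toy usage with random logits over a 100-token vocabulary.
logits = torch.randn(5, 100)
targets = torch.randint(0, 100, (5,))
print(perplexity(logits, targets))
```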
5.2.2. Accuracy
- Conceptual Definition: Accuracy is a common metric for classification tasks, such as zero-shot question answering (QA). It is the proportion of correct predictions made by the model out of the total number of predictions; a higher percentage indicates better performance.
- Mathematical Formula: $ Accuracy = \frac{\text{Number of Correct Predictions}}{\text{Total Number of Predictions}} $
- Symbol Explanation:
Number of Correct Predictions: The count of instances where the model's output (e.g., the chosen answer in a QA task) perfectly matches the true, ground-truth label.Total Number of Predictions: The total number of instances (e.g., questions) in the evaluation dataset.
5.3. Baselines
FlatQuant is compared against several popular INT4 post-training quantization (PTQ) methods, which are representative of the state-of-the-art in the field:
- SmoothQuant (Xiao et al., 2023): A widely recognized method that uses per-channel scaling to smooth activations by transferring quantization difficulty to weights.
- OmniQuant (Shao et al., 2023): An extension that makes scaling and shifting factors learnable parameters for omnidirectional calibration.
- AffineQuant (Ma et al., 2024): A recent method that also employs affine transformations for quantization, serving as a direct comparison point for FlatQuant's transformation approach.
- QUIK-4B (Ashkboos et al., 2023): A method focused on end-to-end 4-bit inference for generative LLMs.
- QuaRot (Ashkboos et al., 2024): A very recent state-of-the-art method that utilizes Hadamard transformations (rotations) to achieve outlier-free 4-bit inference.
- SpinQuant (Liu et al., 2024c): Another recent state-of-the-art method that uses learned rotations for LLM quantization.

These baselines represent the current leading approaches in LLM quantization, covering different strategies (scaling, fixed rotations, learned rotations, and general affine transformations), and thus provide a robust comparison for FlatQuant's performance.
5.4. Implementation Details
The implementation specifics are as follows:
- Frameworks: Huggingface Transformers (Wolf, 2019) and PyTorch (Paszke et al., 2019).
- Optimizer: AdamW.
- Learning Rates: An initial learning rate of 5e-3 for the main learnable parameters, and 5e-2 for the clipping thresholds.
- Learning Rate Schedule: Cosine annealing learning rate decay.
- Calibration: 15 epochs of training on a calibration set of 128 sentences sampled from WikiText-2, each processed with a sequence length of 2048 tokens, using a batch size of 4.
- Computational Resources: Calibration for LLaMA-3-8B requires approximately 26 GB of GPU memory and takes about 0.9 hours on a single GPU.
- Initialization: FlatQuant is robust to initialization, using random affine transformation matrices as a starting point.
- Matrix Inversion and Training: Singular Value Decomposition (SVD) is used for efficient and accurate matrix inversion (for $\mathbf{P}^{-1}$), combined with Automatic Mixed Precision (AMP) to reduce training time and memory usage while maintaining accuracy (Appendix B.1).
5.5. Quantization Scheme
- Weights and Activations: Per-channel symmetric quantization for weights and per-token symmetric quantization for activations.
- Weight Quantizer: For fair comparison, FlatQuant is evaluated with both round-to-nearest (RTN) and GPTQ as the weight quantizer. When GPTQ is used, it shares the same calibration data as FlatQuant for both its closed-form weight updates and training. The paper notes that FlatQuant with RTN is often sufficient and competitive.
- KV Cache Quantization: Group-wise asymmetric quantization is applied to the KV cache with a group size of 128. This choice matches the head dimension of LLaMA models and leverages the memory-bound characteristics of self-attention.
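For reference, group-wise asymmetric quantization can be sketched as follows (an illustrative implementation written for this analysis, not FlashInfer's actual kernel; it assumes the flattened cache length is divisible by the group size):

```python
import torch

def quantize_kv_groupwise(kv, bits=4, group_size=128):
    """Group-wise asymmetric quantization: one scale and zero-point per group of 128 values."""
    orig_shape = kv.shape
    g = kv.reshape(-1, group_size)                               # assumes numel % group_size == 0
    gmin = g.amin(dim=-1, keepdim=True)
    gmax = g.amax(dim=-1, keepdim=True)
    scale = (gmax - gmin).clamp(min=1e-8) / (2 ** bits - 1)
    zero = torch.round(-gmin / scale)
    q = torch.clamp(torch.round(g / scale) + zero, 0, 2 ** bits - 1)
    dequant = ((q - zero) * scale).reshape(orig_shape)           # what attention would consume
    return q.to(torch.uint8).reshape(orig_shape), scale, zero, dequant

kv = torch.randn(16, 8, 128)   # (tokens, heads, head_dim), head_dim matching the group size
q, scale, zero, kv_hat = quantize_kv_groupwise(kv)
print((kv - kv_hat).abs().max())   # small reconstruction error
```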
6. Results & Analysis
6.1. Main Results
The experiments demonstrate FlatQuant's superior performance in both accuracy and inference latency across various LLM families and tasks.
6.1.1. Results on Language Generation Tasks
The following are the results from Table 1 of the original paper:
| Method | W Quantizer | WikiText-2 | | | | | C4 | | | | |
| | | 2-7B | 2-13B | 2-70B | 3-8B | 3-70B | 2-7B | 2-13B | 2-70B | 3-8B | 3-70B |
| FP16 | - | 5.47 | 4.88 | 3.32 | 6.14 | 2.86 | 7.26 | 6.73 | 5.71 | 9.45 | 7.17 |
| SmoothQuant | RTN | 83.12 | 35.88 | 26.01 | 210.19 | 9.60 | 77.27 | 43.19 | 34.61 | 187.93 | 16.90 |
| OmniQuant | RTN | 14.74 | 12.28 | - | - | - | 21.40 | 16.24 | - | - | - |
| AffineQuant | RTN | 12.69 | 11.45 | - | - | - | 15.76 | 13.97 | - | - | - |
| QuaRot | RTN | 8.56 | 6.10 | 4.14 | 10.60 | 55.44 | 11.86 | 8.67 | 6.42 | 17.19 | 79.48 |
| SpinQuant | RTN | 6.14 | 5.44 | 3.82 | 7.96 | 7.58 | 9.19 | 8.11 | 6.26 | 13.45 | 15.39 |
| FlatQuant | RTN | 5.79 | 5.12 | 3.55 | 6.98 | 3.78 | 7.79 | 7.09 | 5.91 | 11.13 | 7.86 |
| QUIK-4B | GPTQ | 8.87 | 7.78 | 6.91 | - | - | - | - | - | - | - |
| QuaRot | GPTQ | 6.10 | 5.40 | 3.79 | 8.16 | 6.60 | 8.32 | 7.54 | 6.12 | 13.38 | 12.87 |
| SpinQuant | GPTQ | 5.96 | 5.24 | 3.70 | 7.39 | 6.21 | 8.28 | 7.48 | 6.07 | 12.19 | 12.82 |
| FlatQuant | GPTQ | 5.78 | 5.11 | 3.54 | 6.90 | 3.77 | 7.86 | 7.11 | 5.92 | 11.21 | 7.93 |
Table 1: WikiText-2 and C4 perplexity of 4-bit weight & activation quantized LLaMA models.
Table 1 presents the Perplexity (PPL) results for 4-bit weight & activation quantized LLaMA models on WikiText-2 and C4 datasets. Lower PPL indicates better performance.
- Superiority of FlatQuant (RTN): FlatQuant using round-to-nearest (RTN) as the weight quantizer consistently outperforms all previous state-of-the-art methods across all LLaMA models and both datasets. For instance, on LLaMA-2-70B, FlatQuant-RTN achieves a PPL of 3.55 on WikiText-2 and 5.91 on C4, very close to the FP16 baseline (3.32 and 5.71, respectively), implying minimal performance degradation.
- LLaMA-3 Models: FlatQuant shows strong performance on the newer LLaMA-3 models. For LLaMA-3-8B, FlatQuant-RTN achieves 6.98 PPL on WikiText-2, significantly better than SpinQuant-RTN's 7.96 and QuaRot-RTN's 10.60.
- GPTQ vs. RTN: Remarkably, FlatQuant-RTN's performance is often comparable to, and sometimes even surpasses, GPTQ-based baselines. For example, FlatQuant-RTN on LLaMA-3-70B on WikiText-2 (PPL 3.78) is better than QuaRot-GPTQ (PPL 6.60) and SpinQuant-GPTQ (PPL 6.21). This highlights that FlatQuant's learnable transformations are highly effective even with a simpler quantizer, reducing calibration time. FlatQuant-GPTQ provides a slight additional boost, but the difference is small, reinforcing the strength of the core method.

These results strongly validate that FlatQuant's approach of enhancing flatness through learnable transformations effectively mitigates quantization error, setting a new benchmark for low-bit LLM quantization.
6.1.2. Results on Zero-shot QA Tasks
The following are the results from Table 2 of the original paper:
| Model | Method | W Quantizer | ARC-C | ARC-E | HellaSwag | LAMBADA | PIQA | Winogrande | Avg |
| 2-7B | FP16 | - | 46.16 | 74.54 | 75.98 | 73.92 | 79.05 | 69.06 | 69.79 |
| | QuaRot | RTN | 36.60 | 61.41 | 65.07 | 48.06 | 72.20 | 63.06 | 57.73 |
| | SpinQuant | RTN | 39.42 | 65.32 | 71.45 | 66.16 | 75.30 | 63.46 | 63.52 |
| | FlatQuant | RTN | 43.26 | 72.05 | 73.64 | 72.04 | 77.26 | 69.53 | 67.96 |
| | QuaRot | GPTQ | 42.32 | 68.35 | 72.53 | 65.40 | 76.33 | 65.11 | 65.01 |
| | SpinQuant | GPTQ | 41.72 | 69.28 | 72.90 | 71.28 | 76.17 | 66.06 | 66.23 |
| | FlatQuant | GPTQ | 43.00 | 71.21 | 73.31 | 72.06 | 77.53 | 67.72 | 67.47 |
| 2-13B | FP16 | - | 49.15 | 77.44 | 79.39 | 76.73 | 80.47 | 72.14 | 72.55 |
| | QuaRot | RTN | 42.83 | 69.95 | 73.54 | 65.62 | 77.69 | 67.88 | 66.25 |
| | SpinQuant | RTN | 43.69 | 72.43 | 75.52 | 72.42 | 78.40 | 68.90 | 68.56 |
| | FlatQuant | RTN | 48.04 | 76.64 | 77.59 | 76.60 | 79.38 | 70.24 | 71.42 |
| | QuaRot | GPTQ | 45.48 | 73.27 | 76.03 | 69.01 | 79.05 | 70.64 | 68.91 |
| | SpinQuant | GPTQ | 49.15 | 77.19 | 76.86 | 73.86 | 78.67 | 69.85 | 70.93 |
| | FlatQuant | GPTQ | 48.38 | 76.94 | 77.88 | 76.40 | 79.65 | 70.56 | 71.64 |
| 2-70B | FP16 | - | 57.17 | 81.02 | 83.81 | 79.60 | 82.70 | 77.98 | 77.05 |
| | QuaRot | RTN | 52.22 | 76.60 | 79.96 | 74.61 | 81.12 | 76.32 | 73.47 |
| | SpinQuant | RTN | 55.03 | 79.17 | 81.76 | 78.87 | 81.45 | 74.27 | 75.09 |
| | FlatQuant | RTN | 56.14 | 80.30 | 83.01 | 79.60 | 82.75 | 77.90 | 76.62 |
| | QuaRot | GPTQ | 55.46 | 79.76 | 81.58 | 79.35 | 81.83 | 76.09 | 75.68 |
| | SpinQuant | GPTQ | 55.38 | 79.04 | 82.57 | 78.75 | 82.37 | 78.22 | 76.06 |
| | FlatQuant | GPTQ | 56.40 | 80.09 | 82.91 | 80.01 | 82.92 | 76.87 | 76.53 |
| 3-8B | FP16 | - | 53.50 | 77.57 | 79.12 | 75.51 | 80.74 | 72.93 | 73.23 |
| | QuaRot | RTN | 38.65 | 66.54 | 68.82 | 57.20 | 71.82 | 65.04 | 61.34 |
| | SpinQuant | RTN | 45.73 | 71.38 | 74.07 | 67.67 | 76.66 | 66.38 | 66.98 |
| | FlatQuant | RTN | 50.00 | 75.80 | 76.80 | 72.91 | 79.16 | 72.69 | 71.23 |
| | QuaRot | GPTQ | 45.73 | 70.83 | 72.97 | 62.70 | 75.35 | 67.17 | 65.79 |
| | SpinQuant | GPTQ | 47.27 | 74.20 | 74.55 | 70.29 | 77.37 | 68.51 | 68.70 |
| | FlatQuant | GPTQ | 50.51 | 75.88 | 76.49 | 73.20 | 79.00 | 72.93 | 71.33 |
| 3-70B | FP16 | - | 64.25 | 85.94 | 84.93 | 79.37 | 84.44 | 80.74 | 79.95 |
| | QuaRot | RTN | 22.18 | 34.30 | 32.15 | 13.35 | 57.67 | 52.49 | 35.36 |
| | SpinQuant | RTN | 44.03 | 69.07 | 74.57 | 63.34 | 76.99 | 65.98 | 65.66 |
| | FlatQuant | RTN | 62.12 | 84.97 | 83.95 | 78.73 | 84.28 | 80.03 | 79.01 |
| | QuaRot | GPTQ | 49.49 | 74.37 | 77.22 | 71.69 | 78.89 | 71.03 | 70.45 |
| | SpinQuant | GPTQ | 51.96 | 77.40 | 77.29 | 71.90 | 79.33 | 72.06 | 71.66 |
| | FlatQuant | GPTQ | 61.95 | 84.47 | 83.87 | 77.99 | 83.95 | 79.24 | 78.58 |
Table 2: Zero-shot QA task results of 4-bit weight & activation quantized LLaMA models.
Table 2 shows the zero-shot QA task results (accuracy) for 4-bit weight & activation quantized LLaMA models across six benchmarks. Higher accuracy indicates better performance.
-
Significant Accuracy Gains:
FlatQuantconsistently achieves higher average accuracy compared to otherquantization methods. For the challengingLLaMA-3models,FlatQuantsignificantly narrows the performance gap to theFP16baseline. ForLLaMA-3-8B, it achieves an average accuracy of71.23%(RTN) with an accuracy loss of only2.00%compared toFP16(73.23%). ForLLaMA-3-70B,FlatQuant-RTNachieves79.01%(note: there is a discrepancy in the table forLLaMA-3-70B FP16 avgwhere71.33and79.95are both present;FlatQuant-RTNis79.01which is a0.94%drop from79.95). -
Robustness of FlatQuant-RTN: Similar to the PPL results,
FlatQuant-RTNoften performs comparably to, or even better than, baselines usingGPTQ. ForLLaMA-3-8B,FlatQuant-RTN(71.23%) outperformsQuaRot-GPTQ(65.79%) andSpinQuant-GPTQ(68.70%). This further underscoresFlatQuant's ability to achieve high performance without the additional overhead ofGPTQ's weight updates. -
Challenging LLaMA-3 Quantization: The paper notes that
LLaMA-3models are particularly challenging for quantization, which is evident from the larger performance drops for other methods, especiallyQuaRot-RTNonLLaMA-3-70B(average35.36%vsFP1679.95%).FlatQuant's ability to maintain high accuracy here is a strong indicator of its effectiveness.The consistent outperformance on both language generation (PPL) and zero-shot QA (accuracy) tasks, particularly for the latest
LLaMA-3models and with simpleRTNquantization, establishesFlatQuantas a newstate-of-the-art.
6.1.3. Additional Experiments (Appendix C)
Results on Other LLM Architectures (Appendix C.1)
- LLaMA-3.1-8B-Instruct (Table 7): FlatQuant outperforms QuaRot on LLaMA-3.1-8B-Instruct on both WikiText-2 and C4 perplexity, and attains significantly higher average accuracy on zero-shot QA tasks: 72.03% versus QuaRot's 66.84% (FP16 is 73.69%).
- Qwen-2.5-Instruct (Table 8): FlatQuant demonstrates competitive performance on Qwen-2.5-Instruct models. For Qwen-2.5-32B, FlatQuant-RTN achieves an average QA score of 74.89%, only a 0.21% drop from the FP16 baseline (75.10%) and higher than QuaRot-GPTQ (72.25%).
- DeepSeek V3-Base and DeepSeek R1 (Table 9): FlatQuant-W4A4 shows strong results on large-scale Mixture-of-Experts (MoE) models such as DeepSeek V3-Base and DeepSeek R1, demonstrating its applicability beyond standard dense LLMs. For DeepSeek V3-Base, FlatQuant-W4A4 achieves 89.59 on C-Eval and 86.32 on MMLU, very close to FP8 (90.10 and 87.10, respectively).
Results on MT-Bench (Appendix C.2)
- LLaMA-3.1-8B-Instruct (Table 10): On MT-Bench, FlatQuant (average 6.94) significantly outperforms QuaRot (5.99), narrowing the gap to FP16 (7.60). In some categories, such as Math, FlatQuant (7.20) even surpasses FP16 (7.00).
Extension to More Quantization Settings (Appendix C.3)
- Weight-Only Quantization (Table 11): FlatQuant-RTN is competitive in weight-only settings (W4A16, W3A16). For W4A16 on WikiText-2, FlatQuant-RTN achieves 6.54 PPL, comparable to GPTQ-g128 (6.50) and QuIP (6.50), and significantly better than plain GPTQ (7.00).
- KV Cache Quantization (Tables 12, 13): FlatQuant effectively quantizes the KV cache to very low bit-widths. For LLaMA-3-8B with K4V4 (4-bit key, 4-bit value) quantization, FlatQuant achieves a WikiText-2 PPL of 6.20 and an average QA accuracy of 73.12%, very close to FP16. Even at K2V2 (2-bit key, 2-bit value), FlatQuant maintains significantly better performance than QuaRot (Table 13): for LLaMA-2-7B K2V2, FlatQuant has a PPL of 6.66 compared to QuaRot's 9.23.
- Extreme Low-Bit Quantization (Table 14): For W3A3KV3 quantization on LLaMA-3-8B, FlatQuant (10.82 WikiText-2 PPL, 58.45% average QA) vastly outperforms QuaRot (686.54 WikiText-2 PPL, 30.33% average QA), demonstrating its robustness in highly aggressive quantization scenarios.
- Flexible Quantization Settings (Table 15): The learnable transformations in FlatQuant are flexible. A single set of transformation matrices can be used for different quantization settings (e.g., W4, A4, KV4 independently) while maintaining high accuracy.

Overall, the results consistently demonstrate FlatQuant's ability to achieve state-of-the-art accuracy, even with simpler RTN quantization, across a wide range of LLMs, tasks, and quantization settings.
6.2. Inference Latency
All inference latency experiments were conducted on an RTX3090 GPU.
6.2.1. End-to-end Speedup
The following figure (Figure 4 from the original paper) shows the prefill and decoding speedup of FlatQuant across different batch sizes:
Figure 4: Prefill and decoding speedup of FlatQuant for LLaMA-2-7B on an RTX3090 GPU. We evaluate prefill on a sequence length of 2048 and decoding on 256 tokens. FlatQuant with kernel fusion (green line) achieves up to 2.3x prefill speedup and 1.76x decoding speedup compared to FP16.
Figure 4 illustrates the prefill and decoding speedup of FlatQuant compared to FP16, INT4, and QuaRot across various batch sizes (up to 64).
- Kernel Fusion Impact: Even without kernel fusion, FlatQuant achieves speedups comparable to QuaRot, validating the efficiency of the Kronecker product approach.
- Superior Speedup with Kernel Fusion: With kernel fusion, FlatQuant significantly outperforms QuaRot, achieving up to 2.30x prefill speedup and 1.76x decoding speedup at a batch size of 64. While there is a minor gap compared to vanilla INT4 quantization (which omits the transformations entirely), FlatQuant's speedup is substantial given its superior accuracy, making it highly practical for deploying INT4 LLMs.
6.2.2. Kronecker Product: Sizes and Perplexities
The following figure (Figure 5 from the original paper) examines the impact of different decomposed matrix sizes in Equation 3 on model performance and speedup:
(Image description: a chart showing, for the LLaMA-2-7B model, the effect of different decomposed matrix sizes on prefill speedup and WikiText-2 PPL. The x-axis is the decomposed size; the two y-axes show the speedup ratio (red stars) and the PPL (blue dots).)
Figure 5: Prefill speedup and WikiText2 PPL results of different decomposed matrix sizes on the LLaMA-2-7B model. The hidden dimension 4096 is decomposed into $d_1 \times d_2$, with one factor swept from 1 to 2048; a factor of 1 amounts to maintaining a full-size transformation matrix. More details can be found in Appendix C.6.

Figure 5 shows the impact of different decomposed matrix sizes ($d_1$ and $d_2$, with $d_1 \times d_2 = 4096$) used in the Kronecker product on prefill speedup and WikiText-2 PPL for the LLaMA-2-7B model.
- Optimal Speedup: The speedup peaks when $d_1$ and $d_2$ are of approximately equal size (i.e., $d_1 \approx d_2 = 64$ for the 4096-dimensional hidden size). This aligns with the theoretical analysis in Section 3.1, which predicts the largest computation saving under this condition.
- Impact on PPL: Importantly, variations in the decomposed matrix sizes have limited impact on perplexity, indicating the robustness of FlatQuant's transformation capability regardless of the exact decomposition, as long as the overall transformation is effective.
- Performance Drop: When $d_2$ (the second dimension of the decomposition) exceeds 64, the speedup decreases. This is attributed to irregular memory access patterns for activations when the dimensions become highly unbalanced.

These results validate FlatQuant's effectiveness in minimizing inference overhead through the Kronecker product while preserving quantization accuracy. (A brief sketch of how a Kronecker-decomposed transform is applied follows below.)
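To see why the decomposition helps, note that applying $P = P_1 \otimes P_2$ to a hidden vector of size $n = d_1 d_2$ can be done by reshaping the vector into a $d_1 \times d_2$ matrix and multiplying by the two small factors, which costs on the order of $n(d_1 + d_2)$ multiply-adds per token instead of $n^2$ for a full matrix; this cost is minimized when $d_1 \approx d_2 \approx \sqrt{n}$, matching the observed speedup peak. The sketch below illustrates the idea; the function name and reshape convention are assumptions, not FlatQuant's actual fused kernel.

```python
import torch

def kron_transform(x: torch.Tensor, p1: torch.Tensor, p2: torch.Tensor) -> torch.Tensor:
    """Apply an affine transform (P1 ⊗ P2) to the last dimension of x.

    x : (..., d1 * d2) activations
    p1: (d1, d1) small transform matrix
    p2: (d2, d2) small transform matrix
    Cost per token is O(d1*d2*(d1 + d2)) instead of O((d1*d2)**2)
    for an explicit full-size matrix.
    """
    d1, d2 = p1.shape[0], p2.shape[0]
    xm = x.reshape(*x.shape[:-1], d1, d2)   # view the hidden dim as a d1 x d2 matrix
    out = p1.T @ xm @ p2                    # equivalent to x @ (P1 ⊗ P2)
    return out.reshape(*x.shape[:-1], d1 * d2)

# Sanity check against the explicit Kronecker product on a small case.
d1, d2 = 4, 8
p1, p2 = torch.randn(d1, d1), torch.randn(d2, d2)
x = torch.randn(3, d1 * d2)
ref = x @ torch.kron(p1, p2)
assert torch.allclose(kron_transform(x, p1, p2), ref, atol=1e-4)
```

The equivalence follows from the standard Kronecker identity: flattening $P_1^\top X P_2$ row-wise gives the same result as multiplying the flattened $X$ by $P_1 \otimes P_2$.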
6.2.3. Overhead of Each Online Transformation
The following figure (Figure 6 from the original paper) investigates the impact of the five online transformations in FlatQuant on the overall speedup:
Figure 6: Prefill speedup of LLaMA-2-7B on a sequence length of 2048 under a batch size of 64 after applying different online transformations. We incorporate different online transformations sequentially to gauge their impact on the final speedup. Each point on the axis indicates adding a new online transformation.
Figure 6 analyzes the individual contribution of each of FlatQuant's five online transformations to the overall speedup.
- Minimal Overall Slowdown: Even with all five per-layer transformations, FlatQuant introduces only a 0.07x end-to-end slowdown compared to naive INT4, significantly outperforming QuaRot's 0.26x slowdown with just three Hadamard transformations.
- Specific Transformation Impacts:
  - The transformation before the down-projection layer causes the largest individual slowdown for FlatQuant (0.04x), due to the large FFN intermediate size. This is still much smaller than QuaRot's corresponding slowdown (0.17x).
  - The transformation before the output projection results in a 0.01x slowdown, again less than QuaRot's 0.1x.
  - The remaining transformations (for the query/key/value projections and the FFN input) have an insignificant impact (less than 0.01x).
- Efficiency without Kernel Fusion: Even without kernel fusion, the additional transformations in FlatQuant maintain competitive performance relative to QuaRot, primarily due to the efficiency gained from the Kronecker product decomposition.

This detailed analysis confirms that FlatQuant's transformations are lightweight and efficiently integrated, contributing minimal overhead while delivering significant accuracy benefits.
6.3. Ablation Studies / Parameter Analysis
6.3.1. Ablation Study (Main Components)
The following are the results from Table 3 of the original paper:
| LT | PS | LCT | WikiText-2 PPL | C4 PPL | Avg QA (%) |
| --- | --- | --- | --- | --- | --- |
|  |  |  | 1266.60 | 936.41 | 30.99 |
| ✓ |  |  | 8.50 | 13.51 | 66.82 |
| ✓ | ✓ |  | 7.95 | 12.74 | 67.08 |
| ✓ |  | ✓ | 7.11 | 11.47 | 70.72 |
| ✓ | ✓ | ✓ | 6.98 | 11.13 | 71.23 |

Table 3: Ablation study of FlatQuant's main components on LLaMA-3-8B.
Table 3 presents an ablation study of FlatQuant's main components on the LLaMA-3-8B model, starting from a round-to-nearest (RTN) baseline (no LT, PS, LCT).
- Learnable Transformation (LT): Enabling the Learnable Transformation (LT) alone drastically improves performance, reducing WikiText-2 PPL from 1266.60 to 8.50 and C4 PPL from 936.41 to 13.51, and lifting average QA accuracy from 30.99% to 66.82%. This demonstrates that LT is the most critical component, capable of adaptively flattening distributions and significantly enhancing model accuracy.
- Per-channel Scaling (PS): Adding Per-channel Scaling (PS) on top of LT further improves performance, lowering WikiText-2 PPL to 7.95 and C4 PPL to 12.74, with a slight increase in average QA accuracy to 67.08%. This indicates PS plays a complementary role in balancing outliers.
- Learnable Clipping Thresholds (LCT): Including Learnable Clipping Thresholds (LCT), either with just LT or with both LT and PS, yields substantial improvements. With LT and LCT, WikiText-2 PPL drops to 7.11 and C4 PPL to 11.47, with average QA accuracy reaching 70.72%.
- Full FlatQuant: The combination of all three components (LT + PS + LCT) achieves the best results, reaching 6.98 WikiText-2 PPL, 11.13 C4 PPL, and 71.23% average QA accuracy.

This ablation study clearly demonstrates the necessity and effectiveness of each component (Learnable Transformation, Per-channel Scaling, and Learnable Clipping Thresholds) in contributing to FlatQuant's overall superior performance by collectively enhancing flatness and mitigating quantization error. (A minimal sketch combining the three components is given below.)
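For intuition on how the three components interact, the following sketch chains per-channel scaling, an affine transform, and post-transform clipping before quantizing both operands of a linear layer. The variable names, the symmetric RTN quantizer, and the single-layer setup are illustrative assumptions that mirror the structure of the ablation, not the paper's actual implementation.

```python
import torch

def quantize_with_flat_components(x, w, s, p, alpha_x=0.9, alpha_w=0.9, n_bits=4):
    """Illustrative pipeline: per-channel scaling (PS) -> affine transform (LT)
    -> clipping applied after the transform (LCT) -> round-to-nearest quantization.

    x: (tokens, d) activations, w: (d_out, d) weights, s: (d,) channel scales,
    p: (d, d) invertible transform, alpha_*: clipping ratios in (0, 1].
    """
    def rtn(t, alpha, dim):
        qmax = 2 ** (n_bits - 1) - 1
        scale = alpha * t.abs().amax(dim=dim, keepdim=True).clamp(min=1e-8) / qmax
        return torch.clamp(torch.round(t / scale), -qmax - 1, qmax) * scale

    x_t = (x / s) @ p                       # PS then LT on activations
    w_t = torch.linalg.inv(p) @ (w * s).T   # the inverse transform is folded into the weights
    y = rtn(x_t, alpha_x, dim=-1) @ rtn(w_t, alpha_w, dim=0)
    return y                                # approximates x @ w.T

# Tiny usage example with an orthogonal transform.
d, d_out, tokens = 16, 8, 4
x, w = torch.randn(tokens, d), torch.randn(d_out, d)
s = x.abs().amax(dim=0) + 1e-6              # simple per-channel scale
p = torch.linalg.qr(torch.randn(d, d)).Q    # any invertible transform works here
print((quantize_with_flat_components(x, w, s, p) - x @ w.T).abs().mean())
```

In a calibration setup like FlatQuant's, `s`, `p`, and the clipping ratios would be the parameters optimized against the FP16 layer output.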
6.3.2. FlatQuant Leads to Flatness
The paper quantitatively evaluates flatness by analyzing channel-wise magnitude distributions. Each distribution is represented as a 1D vector $\mathbf{v} \in \mathbb{R}^n$. Flatness is measured by the Euclidean distance $\|\mathbf{v} - \mathbf{v}_{\text{flat}}\|_2$ between the observed distribution $\mathbf{v}$ and an idealized, perfectly flat distribution $\mathbf{v}_{\text{flat}}$. The idealized flat distribution is defined such that all channels have equal magnitudes and the same $\ell_2$ norm as $\mathbf{v}$, i.e., $\mathbf{v}_{\text{flat}} = \frac{\|\mathbf{v}\|_2}{\sqrt{n}} \mathbf{1}_n$, where $n$ is the number of channels and $\mathbf{1}_n$ is an $n$-dimensional vector of ones. A smaller Euclidean distance indicates greater flatness.
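This flatness measure is straightforward to compute; the sketch below assumes the 1D vector is formed from channel-wise maximum magnitudes, which is one reasonable reading of the description above.

```python
import torch

def flatness(v: torch.Tensor) -> torch.Tensor:
    """Euclidean distance between a channel-magnitude vector v and the
    idealized flat vector with equal entries and the same L2 norm.
    Smaller values mean a flatter distribution."""
    n = v.numel()
    target = v.norm() / n ** 0.5   # each entry of the ideal flat vector
    return (v - target).norm()

# Example: channel-wise max magnitudes of an activation tensor (tokens, channels).
act = torch.randn(128, 4096)
act[:, 0] *= 50                    # inject an outlier channel
print(flatness(act.abs().amax(dim=0)))
```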
The following figure (Figure 7 from the original paper) visualizes the evolution of flatness and the training objective across different Transformer blocks of LLaMA-3-8B during training:
(Image description: a schematic comparing the effects of different quantization methods, including (a) per-channel scaling, (b) Hadamard transformation, (c) FlatQuant, and (d) an overlaid view; each subfigure uses a 3D visualization of the distributions across different layers and numbers of tokens.)
Figure 7: Flatness and mean squared quantization error (MSE) for different Transformer blocks of LLaMA-3-8B during FlatQuant's training process. The metric of flatness is calculated as the sum of Euclidean distances for all weights and activations within a Transformer block.
Figure 7 shows that as the training progresses and the training loss (Equation 4) decreases, the channel distributions become increasingly flat. This direct correlation indicates that FlatQuant successfully learns transformations that yield flatter distributions, which in turn contributes to smaller quantization error. This empirically validates the core hypothesis that flatness matters for LLM quantization.
6.3.3. Additional Ablation Study and Discussions (Appendix C.4 and C.5)
- Detailed Ablation Study (Table 16): A more comprehensive ablation confirms the individual and combined effectiveness of LT, PS, and LCT on every individual QA task, reinforcing their necessity.
- Impact of Calibration Data (Table 17): FlatQuant is robust to the choice of calibration dataset. When calibrated on WikiText2, C4, or Pile, it maintains stable performance with only minor variations in PPL and QA accuracy, indicating good generalization.
- Effect of Learnable Clipping Thresholds (Table 18): Learnable Clipping Thresholds (LCT) are most effective when applied after the affine transformation. Applying LCT before the transformation yields worse results (7.37 WikiText-2 PPL) than applying it afterwards (6.98 PPL), and a QuaRot-style fixed threshold (7.25 PPL) is also less effective. This underscores both the placement and the learning of the clipping thresholds.
- Mixed-Precision Quantization (Table 19): FlatQuant can be combined with mixed-precision schemes. Selectively using 8-bit quantization for certain sensitive layers (e.g., the down_proj layers, or the top-5 most sensitive layers) can further improve accuracy with minimal impact on speed, indicating flexibility for deployment scenarios where slightly higher precision is acceptable for critical components (see the configuration sketch after this list).
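As a deployment-side illustration of such a mixed-precision scheme, the snippet below maps layer-name patterns to bit-widths and falls back to INT4 by default; the table contents and helper name are hypothetical, not part of FlatQuant's codebase.

```python
# Hypothetical per-layer precision table: keep most layers at INT4 and move
# the most quantization-sensitive ones (e.g. down_proj) to INT8.
LAYER_BITS = {"default": 4, "down_proj": 8}

def bits_for(layer_name: str) -> int:
    """Pick the bit-width for a layer by name; names are illustrative."""
    for pattern, bits in LAYER_BITS.items():
        if pattern != "default" and pattern in layer_name:
            return bits
    return LAYER_BITS["default"]

assert bits_for("model.layers.0.mlp.down_proj") == 8
assert bits_for("model.layers.0.self_attn.q_proj") == 4
```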
6.4. Inference Memory and Latency (Appendix C.7 and C.8)
- Inference Memory Consumption (Table 20): FlatQuant maintains the memory efficiency of INT4 quantization. For LLaMA-2-7B, it achieves a consistent memory reduction of over 3.3x compared to FP16 across various sequence lengths (batch size 1), with negligible additional memory overhead from its transformations.
- Additional Latency Analysis (Tables 21, 22):
  - Prefill Speedup (Table 21, Figure 9): FlatQuant achieves consistent prefill speedups (e.g., 2.12x at a sequence length of 2048 and 1.80x at 16384, batch size 1), comparable to INT4 and outperforming QuaRot.
  - Decoding Speedup (Table 22, Figure 10): For decoding (batch size 64), FlatQuant consistently surpasses QuaRot across all KV cache lengths and closely approaches the efficiency of INT4 quantization.

The following figure (Figure 9 from the original paper) shows the prefill speedup of LLaMA-2-7B on a sequence length of 2048:

Figure 9: Prefill speedup of LLaMA-2-7B on a sequence length of 2048. It shows the speedup of INT4, QuaRot, FlatQuant without kernel fusion, and FlatQuant with kernel fusion. FlatQuant with kernel fusion achieves the highest speedup, close to INT4.
The following figure (Figure 10 from the original paper) shows the decoding speedup on LLaMA-2-7B model:
Figure 10: Decoding speedup on LLaMA-2-7B model. We decode 256 tokens after the prefill sequence length 2048. FlatQuant with kernel fusion (green line) shows excellent performance, surpassing QuaRot and approaching INT4.
These results confirm that FlatQuant provides substantial speedup gains across various generation scenarios, including both short and long contexts, without sacrificing memory efficiency.
7. Conclusion & Reflections
7.1. Conclusion Summary
This paper effectively revisits and reinforces the critical role of flat weights and activations in achieving effective Large Language Model (LLM) quantization. It highlights that existing pre-quantization transformations often fall short, still yielding steep and outspread distributions that hinder performance. In response, the authors introduce FlatQuant, a novel post-training quantization method built upon fast and learnable affine transformations. These transformations are optimized for each linear layer to actively promote the flatness of weights and activations.
Through extensive experiments, FlatQuant demonstrates clear superiority, establishing a new state-of-the-art benchmark for LLM quantization. It achieves a remarkably low accuracy drop of less than 1% for W4A4 quantization on the highly challenging LLaMA-3-70B model, significantly outperforming competitors like SpinQuant by 7.5%. Furthermore, its efficient kernel fusion strategy, which integrates affine transformation and quantization operations, delivers substantial inference speedups—up to 2.3x for prefill and 1.7x for decoding compared to the FP16 baseline. This work marks a significant advancement toward the practical application of high-accuracy, low-bit quantization for LLMs.
7.2. Limitations & Future Work
The authors acknowledge a limitation regarding the current scope of FlatQuant, which focuses primarily on INT4 quantization. They point out that FlatQuant has not yet been rigorously applied to emerging data types such as MXFP (the microscaling floating-point formats). Exploring FlatQuant's compatibility and potential advantages with these data types represents a promising avenue for future research.
7.3. Personal Insights & Critique
FlatQuant offers several compelling insights and makes a strong case for the explicit optimization of distribution flatness in quantization. The primary innovation lies in its learnable affine transformations, which offer greater flexibility than fixed scaling or rotation methods, coupled with a highly engineered Kronecker product decomposition and kernel fusion for efficiency. This combination addresses the common trade-off between expressive power and computational cost in quantization.
Inspirations and Transferability:
- Explicit Flatness Objective: The idea of explicitly optimizing for flatness (quantified by the Euclidean distance to an ideal flat distribution) could be transferred to other compression domains beyond LLMs, such as vision transformers or other large neural networks where outliers pose a challenge.
- Learnable Transformations with Efficiency Constraints: The methodology of learning optimal transformations under strict computational and memory constraints (via the Kronecker product and kernel fusion) is a powerful paradigm. It could inspire similar approaches in other areas of model compression or hardware-aware neural network design, where custom, efficient operations are critical.
- Robustness to Quantizer and Models: FlatQuant's ability to perform well with simple RTN quantization, and across diverse LLM families and tasks, suggests a strong underlying principle that makes the models inherently "quantization-friendly." This indicates a fundamental improvement in how data is represented for low-bit precision.
Potential Issues, Unverified Assumptions, or Areas for Improvement:
- Generalizability of the Kronecker Product: While effective for the tested dimensions, the optimal choice of $d_1$ and $d_2$ and the generalizability of the Kronecker product approach to arbitrary layer dimensions or different Transformer variants could be explored further. The paper seeks optimal decompositions, but whether the factorization always maintains expressiveness for drastically different architectures or very small hidden dimensions (where the factors may be too small) is worth investigating.
- Complexity of Learnable Transformations: While FlatQuant boasts efficiency, learning these affine transformations adds a calibration step not present in zero-shot PTQ methods. The "hours" of calibration, while significantly less than Quantization-Aware Training (QAT), might still matter for extremely rapid deployment or compute-constrained scenarios. Further reducing calibration time or exploring even more lightweight learning strategies could be beneficial.
- Impact of Calibration Data Quality: Although the paper shows robustness to different calibration datasets (WikiText2, C4, Pile), the quality and domain relevance of the calibration data can still be crucial. An LLM fine-tuned for a highly specialized domain might require domain-specific calibration data for optimal FlatQuant performance.
- Extreme Low-Bit Sensitivity: While FlatQuant performs exceptionally well at W4A4KV4 and even W3A3KV3, the performance gap to FP16 inevitably widens at W2A2KV2. The "sweet spot" for practical deployment, in terms of the bit-width versus accuracy-speed trade-off, remains a nuanced decision. Future work could focus on pushing toward even lower bit-widths while maintaining the same level of relative accuracy.
- Hardware Dependency of Kernel Fusion: The kernel fusion benefits are highly dependent on the target GPU architecture (e.g., NVIDIA with CUDA/Triton). While Triton aims for portability, optimality might still vary across hardware platforms (e.g., AMD GPUs, custom AI accelerators). Ensuring broad hardware compatibility and performance for the custom kernels would be a continuous effort.

Overall, FlatQuant presents a robust and highly effective solution to a critical problem in LLM deployment. Its principled approach to flatness and meticulously optimized implementation make it a standout contribution in the field of quantization.