The Devil in Linear Transformer
TL;DR Summary
This paper identifies unbounded gradients and attention dilution as key flaws in kernel-based linear transformers, then introduces TransNormer, which stabilizes training via normalized attention and enhances local focus with diagonal attention, achieving superior accuracy and efficiency.
Abstract
Linear transformers aim to reduce the quadratic space-time complexity of vanilla transformers. However, they usually suffer from degraded performances on various tasks and corpus. In this paper, we examine existing kernel-based linear transformers and identify two key issues that lead to such performance gaps: 1) unbounded gradients in the attention computation adversely impact the convergence of linear transformer models; 2) attention dilution which trivially distributes attention scores over long sequences while neglecting neighbouring structures. To address these issues, we first identify that the scaling of attention matrices is the devil in unbounded gradients, which turns out unnecessary in linear attention as we show theoretically and empirically. To this end, we propose a new linear attention that replaces the scaling operation with a normalization to stabilize gradients. For the issue of attention dilution, we leverage a diagonal attention to confine attention to only neighbouring tokens in early layers. Benefiting from the stable gradients and improved attention, our new linear transformer model, transNormer, demonstrates superior performance on text classification and language modeling tasks, as well as on the challenging Long-Range Arena benchmark, surpassing vanilla transformer and existing linear variants by a clear margin while being significantly more space-time efficient. The code is available at https://github.com/OpenNLPLab/Transnormer .
In-depth Reading
English Analysis
Bibliographic Information
- Title: The Devil in Linear Transformer
- Authors: Zhen Qin, Xiaodong Han, Weixuan Sun, Dongxu Li, Lingpeng Kong, Nick Barnes, Yiran Zhong.
- Affiliations: SenseTime Research, Australian National University, OPPO Research Institute, Shanghai AI Laboratory, The University of Hong Kong.
- Journal/Conference: This paper was published as a preprint on arXiv. arXiv is a popular open-access repository for academic papers, often used to share research before or during the formal peer-review process. Its presence here indicates it's a work-in-progress or a pre-publication version.
- Publication Year: 2022 (Published on arXiv on October 19, 2022).
- Abstract: The paper investigates why linear transformers, designed to be more efficient than standard transformers, often underperform. The authors identify two primary problems: 1) unbounded gradients during training, which destabilize the model's convergence, and 2) attention dilution, where the model fails to focus on important local context in long sequences. To solve these, they propose TransNormer, a new linear transformer model. It introduces NormAttention, which replaces the standard scaling operation with a normalization step to stabilize gradients, and DiagAttention, a block-based attention mechanism used in early layers to enforce a focus on neighboring tokens. The authors show that TransNormer surpasses both vanilla transformers and other linear variants in performance on text classification, language modeling, and the Long-Range Arena benchmark, while being much more efficient.
- Original Source Link: https://arxiv.org/abs/2210.10340
Executive Summary
- Background & Motivation (Why): The standard Transformer architecture, while powerful, has a major drawback: its self-attention mechanism has a computational and memory complexity that is quadratic ($O(n^2)$) with respect to the input sequence length $n$. This makes it prohibitively expensive for processing long documents, high-resolution images, or other lengthy data. Linear Transformers were proposed to solve this by reducing the complexity to linear ($O(n)$), but this efficiency often comes at the cost of a significant drop in model performance. The paper's motivation is to diagnose and fix the root causes of this performance degradation, bridging the gap between efficiency and effectiveness.
- Main Contributions / Findings (What):
- Identification of Two Core Problems: The paper is the first to systematically identify and analyze two specific "devils" in existing linear transformers:
- Unbounded Gradients: The authors theoretically prove that the scaling mechanism inherited from vanilla transformers causes gradients to become unstable and potentially infinite in linear attention, hindering model training.
- Attention Dilution: They empirically show that linear attention mechanisms tend to distribute attention scores too broadly and evenly across the entire sequence, losing the crucial ability to focus on local context, which is a key strength of the original Transformer.
- Novel Attention Mechanisms: To address these issues, the paper proposes two new components:
- NormAttention: A new linear attention formulation that removes the problematic scaling factor and instead applies a normalization layer (like RMSNorm) to the output. This simple change is shown to stabilize gradients and improve convergence.
- DiagAttention: A diagonal, block-based attention mechanism that forces early layers of the model to focus only on local, neighboring tokens, directly combating the attention dilution problem.
- The TransNormer Model: The paper combines these two ideas into a new hybrid architecture called TransNormer. It uses DiagAttention in its initial layers to build strong local representations and NormAttention in its later layers to efficiently integrate global information with stable training. This design achieves superior performance and efficiency, outperforming both standard Transformers and previous linear variants on a wide range of benchmarks.
Prerequisite Knowledge & Related Work
Foundational Concepts
To understand this paper, one must first be familiar with the following concepts:
- Transformer: A neural network architecture introduced in "Attention Is All You Need" (Vaswani et al., 2017). Its core innovation is the self-attention mechanism, which allows the model to weigh the importance of different words (or tokens) in a sequence when processing any given token.
- Self-Attention: This mechanism computes an output for each token by taking a weighted sum of all other tokens in the sequence. The weights, or "attention scores," are calculated based on the similarity between a token's Query (Q) vector and another token's Key (K) vector. The final output is the sum of Value (V) vectors weighted by these scores.
- Quadratic Complexity: In the standard self-attention mechanism, every token must be compared with every other token to compute the attention scores. For a sequence of length $n$, this requires calculating an $n \times n$ attention matrix, leading to a time and memory complexity of $O(n^2)$. This is the primary bottleneck for long sequences.
- Linear Transformer: A class of models that modify the self-attention mechanism to reduce its complexity to linear, i.e., $O(n)$. Most achieve this using the kernel trick. Instead of computing the full matrix softmax(QKᵀ)V, they replace the softmax with a kernel feature map $\phi$ and compute $(\phi(\mathbf{Q})\phi(\mathbf{K})^\top)\mathbf{V}$. By associativity, this can be reordered to $\phi(\mathbf{Q})(\phi(\mathbf{K})^\top\mathbf{V})$, avoiding the creation of the large $n \times n$ intermediate matrix and reducing complexity to $O(n)$ (see the sketch after this list).
- Gradients: In deep learning, gradients are derivatives of the loss function with respect to the model's parameters. They indicate the direction and magnitude of change needed to improve the model's performance during training (optimization). Unbounded gradients (or exploding gradients) are gradients that become excessively large, causing training to become unstable and fail to converge.
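To make the reordering concrete, here is a minimal NumPy sketch (my own illustration, not from the paper) contrasting quadratic attention with the kernel-trick formulation; the `elu(x) + 1` feature map is one common choice (the "1+elu" baseline), and all shapes and names are made up for the example.

```python
import numpy as np

def elu_feature_map(x):
    # phi(x) = elu(x) + 1, a common non-negative kernel feature map for linear attention
    return np.where(x > 0, x + 1.0, np.exp(x))

n, d = 1024, 64
Q, K, V = (np.random.randn(n, d) for _ in range(3))

# Vanilla attention: materializes an n x n matrix -> O(n^2) time and memory.
scores = Q @ K.T / np.sqrt(d)
A = np.exp(scores - scores.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)
out_vanilla = A @ V

# Linear attention: reorder (phi(Q) phi(K)^T) V as phi(Q) (phi(K)^T V) -> linear in n.
phi_Q, phi_K = elu_feature_map(Q), elu_feature_map(K)
kv = phi_K.T @ V                        # d x d summary, independent of n
z = phi_Q @ phi_K.sum(axis=0)           # row-wise normalizer
out_linear = (phi_Q @ kv) / z[:, None]

print(out_vanilla.shape, out_linear.shape)  # both (1024, 64); only the complexity is being compared
```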
Previous Works
The paper positions itself within the field of efficient transformers. It categorizes prior work into two main families:
- Pattern-based Methods: These methods sparsify the attention matrix by forcing each token to attend to only a subset of other tokens using predefined patterns. Examples include Longformer (using sliding windows and dilated attention) and BigBird (using random, window, and global attention). While effective, these patterns are often handcrafted and may not be universally optimal.
- Kernel-based Methods (Linear Transformers): This is the family this paper focuses on. These methods use kernel functions to approximate or replace the softmax function, enabling linear complexity. Examples include:
  - Performer: Uses random feature maps to approximate the softmax kernel.
  - Linear Attention: Uses a simple activation function, such as $1 + \mathrm{elu}(x)$, as the kernel.
  - cosFormer: Replaces softmax with a cosine-based attention mechanism.

The key insight of "The Devil in Linear Transformer" is that while previous works focused on designing better kernel functions to approximate softmax, they overlooked fundamental optimization and architectural issues (unbounded gradients and attention dilution) that were holding back performance.
Differentiation
This paper differentiates itself by not just proposing another new kernel, but by performing a deep diagnostic analysis of why existing kernel-based methods fail. Its novelty lies in:
- Problem Diagnosis: It provides both theoretical proof (for unbounded gradients) and empirical evidence (for attention dilution) of fundamental flaws in the standard linear transformer design.
- Targeted Solutions: Instead of a monolithic new attention, it proposes two distinct, complementary solutions (NormAttention and DiagAttention) that directly target the identified problems.
- Hybrid Architecture: It intelligently combines these solutions in a hybrid model (TransNormer) that leverages the strengths of both local and global attention at different stages of processing.
Methodology (Core Technology & Implementation Details)
The authors first identify two critical flaws in existing linear transformers and then propose solutions that form the TransNormer model.
Problem 1: Unbounded Gradients
The authors show that the scaling operation in traditional attention, while benign in vanilla transformers, causes gradient instability in linear transformers.
A general attention mechanism can be written as
$$\mathbf{o}_i = \sum_{j} a_{ij}\,\mathbf{v}_j, \qquad a_{ij} = \frac{f(s_{ij})}{\sum_{k} f(s_{ik})},$$
where $a_{ij}$ is the attention score from token $i$ to token $j$, and $s_{ij}$ is the similarity score between their query and key. The function $f$ differs between attention types:
- Vanilla Transformer: Uses the softmax function. Here, $s_{ij} = \mathbf{q}_i^\top \mathbf{k}_j / \sqrt{d}$ and $f(x) = \exp(x)$.
- Linear Transformer: Uses a kernelized dot product. Here, $s_{ij} = \phi(\mathbf{q}_i)^\top \phi(\mathbf{k}_j)$ and $f(x) = x$ (identity function).
The paper derives the gradient of an attention score $a_{ij}$ with respect to a similarity score $s_{ik}$:
$$\frac{\partial a_{ij}}{\partial s_{ik}} = \frac{f'(s_{ik})}{\sum_{l} f(s_{il})}\left(\mathbb{1}_{j=k} - a_{ij}\right),$$
where $\mathbb{1}_{j=k}$ is an indicator function (1 if $j = k$, 0 otherwise).
- For Vanilla Attention: Since $f(x) = \exp(x)$, its derivative is $f'(x) = \exp(x) = f(x)$. Thus, $\frac{f'(s_{ik})}{\sum_l f(s_{il})} = a_{ik}$. The gradient is $\frac{\partial a_{ij}}{\partial s_{ik}} = a_{ik}\left(\mathbb{1}_{j=k} - a_{ij}\right)$. Since $0 \le a_{ij}, a_{ik} \le 1$, the gradient is bounded: $\left|\frac{\partial a_{ij}}{\partial s_{ik}}\right| \le 1$. This leads to stable training.
- For Linear Attention: Since $f(x) = x$, its derivative is $f'(x) = 1$. The gradient becomes $\frac{\partial a_{ij}}{\partial s_{ik}} = \frac{\mathbb{1}_{j=k} - a_{ij}}{\sum_l s_{il}}$. The term $\frac{1}{\sum_l s_{il}}$ is problematic: since $\sum_l s_{il}$ can be very close to zero during training, the gradient can become arbitrarily large (unbounded). This theoretical result (Proposition 3.1) explains the optimization instability observed in many linear transformers. A small numeric sketch of this blow-up follows this list.
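To make the unboundedness concrete, here is a small PyTorch sketch (my own illustration, not from the paper): it compares the gradient of one attention weight with respect to the similarity scores under softmax normalization versus plain sum normalization when the row sum is tiny. The values and helper names are made up.

```python
import torch

def max_grad(scores, normalizer):
    # Gradient of the first attention weight a_0 with respect to all similarity scores s.
    s = scores.detach().clone().requires_grad_(True)
    a = normalizer(s)
    a[0].backward()
    return s.grad.abs().max().item()

# Non-negative similarities (as produced by kernels like 1+elu) whose row sum is tiny.
sims = torch.tensor([1e-4, 2e-4, 1e-4, 3e-4])

print("softmax:", max_grad(sims, lambda x: torch.softmax(x, dim=0)))  # bounded, about 0.19
print("linear :", max_grad(sims, lambda x: x / x.sum()))              # about 1.2e3, grows as sum -> 0
```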
Problem 2: Attention Dilution
The paper observes that while vanilla transformers naturally learn to focus on local context (neighboring tokens), linear transformers tend to distribute their attention scores more uniformly across the entire sequence. This "dilutes" the attention, preventing the model from capturing important local structures, especially in early layers.
To quantify this, they introduce the locally accumulated attention score. For a given query token, this score measures the sum of attention probabilities it assigns to tokens within a local neighborhood around it (a toy computation of this score is sketched after Figure 2 below).
Figure 2 in the paper provides a clear visualization. The curve for the vanilla transformer rises sharply, showing that a small local neighborhood captures a large portion of the total attention score. In contrast, the curve for a linear transformer is much flatter, indicating attention is spread out.
Figure 2: (a) Comparison of locally accumulated attention scores of different transformer variants. The x-axis denotes the ratio of the neighbourhood size relative to the input length; the y-axis denotes the accumulated attention score inside this neighbourhood for the centering token. The curve for the vanilla transformer increases more sharply, indicating that its attention scores are more concentrated, and the proposed model greatly alleviates the attention dilution issue of linear models. (b) Qualitative comparison of attention matrices in early model layers. The proposed TransNormer produces patterns more similar to the original vanilla transformer, while the standard linear transformer fails to focus on neighbouring tokens and gets distracted by distant tokens in early layers.
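Here is a hedged NumPy sketch of how such a locally accumulated attention score could be computed from an attention matrix (the paper's exact definition and window convention may differ; the matrices below are toy examples).

```python
import numpy as np

def local_accumulation(attn, ratio):
    """Average fraction of each query's attention mass that falls within a local window.

    attn:  (n, n) row-stochastic attention matrix (rows sum to 1)
    ratio: neighbourhood size relative to the sequence length (0..1)
    """
    n = attn.shape[0]
    radius = max(1, int(ratio * n) // 2)
    scores = []
    for i in range(n):
        lo, hi = max(0, i - radius), min(n, i + radius + 1)
        scores.append(attn[i, lo:hi].sum())
    return float(np.mean(scores))

# Toy comparison: a sharply local matrix vs. a uniform (diluted) one.
n = 256
uniform = np.full((n, n), 1.0 / n)
local = (np.eye(n) * 0.6
         + np.roll(np.eye(n), 1, axis=1) * 0.2
         + np.roll(np.eye(n), -1, axis=1) * 0.2)

for name, A in [("uniform", uniform), ("local", local)]:
    print(name, round(local_accumulation(A, ratio=0.1), 3))  # ~0.10 vs ~1.00
```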
Solution 1: NormAttention
To fix the unbounded gradients, the authors propose a simple yet effective solution: remove the denominator (scaling factor) from the linear attention computation and apply a standard normalization layer afterward.
The standard linear attention output is $\mathbf{O} = \boldsymbol{\Delta}^{-1}\,\mathbf{Q}(\mathbf{K}^\top \mathbf{V})$, where $\mathbf{Q}$ and $\mathbf{K}$ denote the kernelized queries and keys and $\boldsymbol{\Delta} = \mathrm{diag}\!\left(\mathbf{Q}(\mathbf{K}^\top \mathbf{1}_n)\right)$ is a diagonal matrix that normalizes each row's attention scores to sum to 1. This denominator is the source of the problematic $\frac{1}{\sum_l s_{il}}$ term in the gradient.
The proposed NormAttention first computes the unnormalized output:
$$\mathbf{O}' = \mathbf{Q}\left(\mathbf{K}^\top \mathbf{V}\right).$$
Then, it applies a normalization layer (e.g., LayerNorm or RMSNorm) to the result:
$$\mathbf{O} = \mathrm{XNorm}\!\left(\mathbf{Q}\left(\mathbf{K}^\top \mathbf{V}\right)\right).$$
This design achieves two goals: the forward pass output remains bounded due to the normalization layer, and the backward pass gradients are also proven to be bounded, leading to stable training. Table 2 shows that the gradient deviation of NormAttention is much lower than other linear methods and comparable to the vanilla transformer.
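Below is a minimal PyTorch sketch of the NormAttention idea as described above (non-causal, single-head; my own illustrative implementation, not the authors' code, which is at the GitHub link above). RMSNorm is written out explicitly, and the `1 + elu` feature map is an assumed choice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Normalize by the root-mean-square of the features, then rescale.
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return x / rms * self.scale

class NormAttention(nn.Module):
    """Linear attention without the row-sum denominator; an output norm keeps the result bounded."""
    def __init__(self, dim):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.norm = RMSNorm(dim)

    def forward(self, x):                          # x: (batch, seq_len, dim)
        # Non-negative kernel feature map (1 + elu), a common choice for linear attention.
        q = F.elu(self.q_proj(x)) + 1.0
        k = F.elu(self.k_proj(x)) + 1.0
        v = self.v_proj(x)
        kv = torch.einsum("bnd,bne->bde", k, v)    # (dim, dim) summary, linear in seq_len
        out = torch.einsum("bnd,bde->bne", q, kv)  # unnormalized linear attention Q(K^T V)
        return self.norm(out)                      # normalization replaces the scaling

x = torch.randn(2, 128, 64)
print(NormAttention(64)(x).shape)  # torch.Size([2, 128, 64])
```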
Solution 2: DiagAttention
To combat attention dilution, the authors propose DiagAttention. The idea is to explicitly force the model to focus on local context in its early layers.
This is implemented as a block-based attention. The input sequence of length $n$ is divided into non-overlapping blocks of size $w$. Self-attention is then computed only within each block. This prevents tokens from attending to distant tokens outside their block, forcing a local focus.
The computational complexity of this is $O(nwd)$: there are $n/w$ blocks, each costing $O(w^2 d)$. For a fixed block size $w$ and hidden dimension $d$, this is linear with respect to the sequence length $n$, preserving the efficiency of the model. A minimal sketch is shown below.
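A hedged PyTorch sketch of block-diagonal ("diag") attention as described: the sequence is chunked into blocks of size `w` and ordinary softmax attention runs independently inside each block (single head, no padding or masking logic; names and shapes are illustrative only).

```python
import torch
import torch.nn.functional as F

def diag_attention(q, k, v, block_size):
    """Block-diagonal attention: softmax attention computed independently per block.

    q, k, v: (batch, seq_len, dim); seq_len is assumed divisible by block_size here.
    """
    b, n, d = q.shape
    w = block_size
    # Reshape to (batch, num_blocks, block_size, dim) so attention stays inside a block.
    qb, kb, vb = (t.reshape(b, n // w, w, d) for t in (q, k, v))
    scores = torch.einsum("bgid,bgjd->bgij", qb, kb) / d ** 0.5   # (b, n/w, w, w)
    attn = F.softmax(scores, dim=-1)
    out = torch.einsum("bgij,bgjd->bgid", attn, vb)
    return out.reshape(b, n, d)

q = k = v = torch.randn(2, 256, 64)
print(diag_attention(q, k, v, block_size=64).shape)  # torch.Size([2, 256, 64])
```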
The TransNormer Architecture
TransNormer is a hybrid model that combines these two solutions. As shown in Figure 3, the architecture is split into two stages:
- Early Layers: Use DiagAttention. This encourages the model to first learn strong local features and representations, mimicking the behavior of vanilla transformers.
- Later Layers: Use NormAttention. After local features are extracted, these layers are responsible for integrating information globally across the entire sequence. NormAttention allows this to be done efficiently and with stable gradients.
Figure 3: Architecture overview of the proposed TransNormer. In the early stages, DiagAttention is used, where attention is only calculated inside blocks to enforce a neighbouring focus. In the late stages, NormAttention helps obtain more stable gradients at linear complexity.
The authors empirically find that a 50/50 split (e.g., 6 DiagAttention layers followed by 6 NormAttention layers in a 12-layer model) works best.
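Putting the pieces together, a simplified sketch of the layer assignment only, assuming DiagAttention and NormAttention modules like the sketches above (the real model also has feed-forward blocks, residual connections, multiple heads, etc.; the numbers below are illustrative).

```python
# Hypothetical layer layout for a 12-layer TransNormer-style model with a 50/50 split.
num_layers, dim, block_size = 12, 512, 64
layers = [
    ("DiagAttention", {"dim": dim, "block_size": block_size}) if i < num_layers // 2
    else ("NormAttention", {"dim": dim})
    for i in range(num_layers)
]
for i, (name, cfg) in enumerate(layers):
    print(f"layer {i:2d}: {name} {cfg}")
```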
Experimental Setup
Datasets
The authors evaluate TransNormer on a diverse set of tasks and datasets:
- WikiText-103: A large, high-quality language modeling dataset used for both autoregressive (predicting the next word) and bidirectional (masked language modeling) tasks.
- GLUE (General Language Understanding Evaluation) Benchmark: A collection of nine different natural language understanding tasks, including sentiment analysis (SST-2), question answering (QNLI), and textual similarity (MRPC). It is used to evaluate the fine-tuning performance of bidirectional models.
- Long-Range Arena (LRA) Benchmark: A benchmark specifically designed to test the ability of efficient transformers to handle very long sequences (up to 16K tokens) and model long-range dependencies across various data types (text, images, math).
Evaluation Metrics
The paper uses standard metrics for each task.
- Perplexity (PPL): Used for language modeling. It measures how well a probability model predicts a sample. A lower PPL indicates the model is less "surprised" by the test data and thus has a better understanding of the language.
  - Conceptual Definition: PPL is the exponentiated average negative log-likelihood of a sequence.
  - Mathematical Formula: For a sequence of $N$ tokens $x_1, \dots, x_N$,
    $$\mathrm{PPL} = \exp\!\left(-\frac{1}{N}\sum_{i=1}^{N} \log p(x_i \mid x_{<i})\right)$$
  - Symbol Explanation:
    - $N$: Total number of tokens in the sequence.
    - $p(x_i \mid x_{<i})$: The probability of the $i$-th token, given the preceding tokens, as predicted by the model.
- Accuracy: Used for classification tasks in GLUE (e.g., SST-2, QNLI).
  - Conceptual Definition: The proportion of correctly classified examples.
  - Mathematical Formula: $$\mathrm{Accuracy} = \frac{\text{Number of correct predictions}}{\text{Total number of predictions}}$$
- F1 Score: Used for MRPC in GLUE. It is the harmonic mean of precision and recall, providing a balanced measure for imbalanced datasets.
  - Conceptual Definition: A single score that balances the trade-off between precision (how many selected items are relevant) and recall (how many relevant items are selected).
  - Mathematical Formula: $$F_1 = \frac{2 \cdot \mathrm{Precision} \cdot \mathrm{Recall}}{\mathrm{Precision} + \mathrm{Recall}} = \frac{2\,TP}{2\,TP + FP + FN}$$
  - Symbol Explanation:
    - TP: True Positives, FP: False Positives, FN: False Negatives.
- Matthews Correlation Coefficient (MCC): Used for CoLA in GLUE. It is considered a very robust metric for binary classification, as it accounts for all four values in the confusion matrix (TP, TN, FP, FN).
  - Conceptual Definition: A correlation coefficient between the observed and predicted binary classifications. It returns a value between -1 and +1: +1 represents a perfect prediction, 0 no better than random, and -1 an inverse prediction.
  - Mathematical Formula: $$\mathrm{MCC} = \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}}$$
  - Symbol Explanation:
    - TP, TN, FP, FN are True/False Positives/Negatives. (A small numeric sketch of these metrics follows this list.)
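A small self-contained Python sketch (with made-up values) that evaluates these formulas on a toy confusion matrix and a toy sequence of per-token probabilities.

```python
import math

# Toy confusion matrix (illustrative numbers only).
TP, TN, FP, FN = 40, 45, 5, 10

accuracy = (TP + TN) / (TP + TN + FP + FN)
precision = TP / (TP + FP)
recall = TP / (TP + FN)
f1 = 2 * precision * recall / (precision + recall)
mcc = (TP * TN - FP * FN) / math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))

# Toy per-token probabilities p(x_i | x_<i) assigned by a language model.
token_probs = [0.2, 0.05, 0.5, 0.1]
ppl = math.exp(-sum(math.log(p) for p in token_probs) / len(token_probs))

print(f"accuracy={accuracy:.3f} f1={f1:.3f} mcc={mcc:.3f} ppl={ppl:.2f}")
```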
Baselines
The paper compares TransNormer against a comprehensive set of baselines, including:
- Vanilla Transformer: The original model with standard quadratic softmax attention.
- Recent high-performing linear transformers: FLASH, FLASH-quad, cosFormer.
- Classic efficient transformers: Performer, Linformer, Reformer.
- Pattern-based transformers: BigBird (on LRA).
- Other efficient variants: Transformer-LS, Nystromformer, Skyformer.
Results & Analysis
The paper presents compelling results across all experiments, demonstrating that TransNormer is both efficient and highly performant.
Core Results
- Autoregressive Language Modeling (Table 4): On WikiText-103, TransNormer T2 achieves a test perplexity of 31.01, matching the vanilla Transformer and significantly outperforming all other linear transformer baselines like FLASH (34.63) and Performer (77.65).

  (Manual transcription of Table 4)

  | Method | PPL (val) | PPL (test) | Params (M) |
  |---|---|---|---|
  | Vanilla | 29.63 | 31.01 | 156.00 |
  | LS | 32.37 | 32.59 | 159.46 |
  | FLASH-quad | 31.88 | 33.50 | 153.51 |
  | FLASH | 33.18 | 34.63 | 153.52 |
  | 1+elu | 32.63 | 34.25 | 156.00 |
  | Performer | 75.29 | 77.65 | 156.00 |
  | TransNormer T1 | 29.89 | 31.35 | 155.99 |
  | TransNormer T2 | 29.57 | 31.01 | 155.99 |
- Bidirectional Language Modeling (Table 5): On the GLUE benchmark, TransNormer T1 achieves the highest average score (79.38) among all models, even surpassing the vanilla Transformer (78.79). It shows particularly strong gains on the challenging CoLA task.

  (Manual transcription of Table 5)

  | Method | MNLI | QNLI | QQP | SST-2 | MRPC | CoLA | AVG | Params (M) |
  |---|---|---|---|---|---|---|---|---|
  | Vanilla | 79.37/79.07 | 87.79 | 88.04 | 90.25 | 88.35 | 38.63 | 78.79 | 124.70 |
  | FLASH-quad | 78.71/79.43 | 86.36 | 88.95 | 90.94 | 81.73 | 41.28 | 78.20 | 127.11 |
  | FLASH | 79.45/80.08 | 87.10 | 88.83 | 90.71 | 82.50 | 29.40 | 76.87 | 127.12 |
  | LS | 77.01/76.78 | 84.86 | 86.85 | 90.25 | 82.65 | 40.65 | 77.01 | 128.28 |
  | Performer | 58.85/59.52 | 63.44 | 79.10 | 81.42 | 82.11 | 19.41 | 63.41 | 124.70 |
  | 1+elu | 74.87/75.37 | 82.59 | 86.9 | 87.27 | 83.03 | - | 70.00 | 124.0 |
  | TransNormer T1 | 79.06/79.93 | 87.00 | 88.61 | 91.17 | 84.50 | 45.38 | 79.38 | 124.67 |
  | TransNormer T2 | 77.28/78.53 | 85.39 | 88.56 | 90.71 | 85.06 | 45.90 | 78.78 | 124.67 |
- Long-Range Arena (Table 6): TransNormer sets a new state-of-the-art on the LRA benchmark, with TransNormer T2 achieving an average score of 64.80, outperforming all other efficient transformers, including cosFormer (62.11) and the vanilla Transformer (57.37). This confirms its excellent ability to handle long-range dependencies.

  (Manual transcription of Table 6, showing selected rows for brevity)

  | Model | Text | ListOps | Retrieval | Pathfinder | Image | AVG |
  |---|---|---|---|---|---|---|
  | Transformer | 61.95 | 38.37 | 80.69 | 65.26 | 40.57 | 57.37 |
  | Performer | 64.19 | 38.02 | 80.04 | 66.30 | 41.43 | 58.00 |
  | cosFormer | 67.70 | 36.50 | 83.15 | 71.96 | 51.23 | 62.11 |
  | FLASH | 64.10 | 38.70 | 86.10 | 70.25 | 47.40 | 61.31 |
  | TransNormer T1 | 66.90 | 41.03 | 83.11 | 75.92 | 51.60 | 63.71 |
  | TransNormer T2 | 72.20 | 41.60 | 83.82 | 76.80 | 49.60 | 64.80 |
- Speed Comparison (Table 7): The paper shows TransNormer is significantly faster than both the vanilla Transformer and other competitive linear models. For a sequence of length 5K, TransNormer T2 achieves a training speed of 10.16 steps/sec, whereas FLASH runs at 6.93 steps/sec and the vanilla Transformer runs out of memory. This highlights the model's practical efficiency.

  (Manual transcription of Table 7, selected columns; speeds in steps/sec)

  | Model | Train Speed 1K | Train Speed 3K | Train Speed 5K |
  |---|---|---|---|
  | Transformer | 15.34 | - | - |
  | FLASH | 20.49 | 8.47 | 6.93 |
  | Performer | 28.41 | 12.02 | 9.06 |
  | TransNormer T2 | 29.41 | 12.95 | 10.16 |
Ablations / Parameter Sensitivity
The ablation studies robustly validate the design choices of TransNormer.
- Architecture Design (Tables 8 & 9): The hybrid structure is crucial. Using only DiagAttention or only NormAttention for all layers results in worse performance. Furthermore, using DiagAttention in early layers and NormAttention in later layers is significantly better than the reverse, confirming the hypothesis that models should first build local representations before integrating them globally.

  (Manual transcription of Table 9)

  | Early stage | Later stage | T1 ppl (val) | T2 ppl (val) |
  |---|---|---|---|
  | NormAtt | BlockAtt | 4.13 | 4.21 |
  | BlockAtt | NormAtt | 3.82 | 3.81 |
- NormAttention (Table 1): Simply removing the scaling from linear attention causes a catastrophic performance drop (PPL from 4.98 to 797.08). The proposed NormAttention with RMSNorm fixes this and even slightly improves performance (PPL 4.94), proving that a normalization strategy is essential.
- DiagAttention Block Size (Table 11): Performance improves with a larger block size for DiagAttention, as it allows the model to see more local context. However, this comes at a computational cost. The authors choose a block size of 64 as a good trade-off.
Conclusion & Personal Thoughts
- Conclusion Summary: The paper successfully identifies and provides solutions for two fundamental problems (unbounded gradients and attention dilution) that have plagued kernel-based linear transformers. By introducing NormAttention for stable training and DiagAttention for focused local processing, the resulting TransNormer model achieves a new state of the art in the trade-off between efficiency and performance. It consistently outperforms previous linear models and often matches or exceeds the performance of the far more expensive vanilla Transformer across a range of challenging benchmarks.
- Limitations & Future Work: The authors acknowledge that their work is focused entirely on Natural Language Processing (NLP). They state that a key area for future work is to investigate whether the same issues of unbounded gradients and attention dilution exist in vision transformers and to validate their proposed solutions in computer vision domains.
- Personal Insights & Critique:
- Strength in Diagnosis: The paper's primary strength lies not just in its proposed model, but in its clear and rigorous diagnosis of the underlying problems. The theoretical proof of unbounded gradients is a significant contribution that brings clarity to a field that was previously focused more on empirical trial-and-error with different kernels.
- Simplicity and Effectiveness: The proposed solutions, NormAttention and DiagAttention, are conceptually simple, easy to implement, and highly effective. This makes them likely to be widely adopted. The idea of replacing a problematic scaling factor with a standard normalization layer is particularly elegant.
- Hybrid Architecture Intuition: The design of the TransNormer architecture is well-motivated and intuitive. The idea that a model should first build an understanding of local patterns before reasoning about global relationships aligns with how humans process information and has proven effective in other domains like computer vision (e.g., small receptive fields in early CNN layers).
- Open Questions: While DiagAttention's fixed, non-overlapping blocks are effective and efficient, they are also rigid. This structure might struggle with documents where key related information falls just across a block boundary. Future work could explore more flexible or overlapping block strategies, though this would involve a trade-off with computational cost.

Overall, this paper is a high-quality piece of research that makes a tangible and impactful contribution to the field of efficient transformers.