
MemoryFormer : Minimize Transformer Computation by Removing Fully-Connected Layers

Published:11/06/2024
This analysis is AI-generated and may not be fully accurate. Please refer to the original paper.

TL;DR Summary

MemoryFormer is a novel transformer architecture that reduces computational complexity by eliminating most fully-connected layers while retaining necessary multi-head attention operations, utilizing in-memory lookup tables and hash algorithms for dynamic vector retrieval, validat

Abstract

In order to reduce the computational complexity of large language models, great efforts have been made to improve the efficiency of transformer models such as linear attention and flash-attention. However, the model size and corresponding computational complexity are constantly scaled up in pursuit of higher performance. In this work, we present MemoryFormer, a novel transformer architecture which significantly reduces the computational complexity (FLOPs) from a new perspective. We eliminate nearly all the computations of the transformer model except for the necessary computation required by the multi-head attention operation. This is made possible by utilizing an alternative method for feature transformation to replace the linear projection of fully-connected layers. Specifically, we first construct a group of in-memory lookup tables that store a large amount of discrete vectors to replace the weight matrix used in linear projection. We then use a hash algorithm to retrieve a correlated subset of vectors dynamically based on the input embedding. The retrieved vectors combined together will form the output embedding, which provides an estimation of the result of matrix multiplication operation in a fully-connected layer. Compared to conducting matrix multiplication, retrieving data blocks from memory is a much cheaper operation which requires little computations. We train MemoryFormer from scratch and conduct extensive experiments on various benchmarks to demonstrate the effectiveness of the proposed model.

In-depth Reading

English Analysis

1. Bibliographic Information

1.1. Title

The central topic of the paper is "MemoryFormer: Minimize Transformer Computation by Removing Fully-Connected Layers". This title directly indicates the paper's main contribution: proposing a new transformer architecture designed to reduce computational complexity by rethinking the role of fully-connected (FC) layers.

1.2. Authors

The authors are Ning Ding, Yehui Tang, Haochen Qin, Zhenli Zhou, Chao Xu, Lin Li, Kai Han, Heng Liao, and Yunhe Wang. Their affiliations include State Key Lab of General AI, School of Intelligence Science and Technology, Peking University, Huawei Noah's Ark Lab, and Huawei HiSilicon. This diverse authorship, including both academic institutions and industry labs (Huawei), suggests a strong blend of theoretical research and practical application expertise.

1.3. Journal/Conference

The paper is hosted on OpenReview (openreview.net), with a listed publication date of 2024-11-06. The presence of a NeurIPS Paper Checklist strongly suggests it was submitted to or accepted by the Conference on Neural Information Processing Systems (NeurIPS), one of the most prestigious and influential conferences in artificial intelligence and machine learning. Acceptance at NeurIPS would indicate that the work passed a selective peer-review process.

1.4. Publication Year

The publication year is 2024.

1.5. Abstract

The paper addresses the ever-increasing computational complexity of large language models (LLMs) based on the Transformer architecture. While previous efforts have focused on optimizing the multi-head attention (MHA) mechanism (e.g., linear attention, flash-attention), this work introduces MemoryFormer, a novel architecture that targets the fully-connected (FC) layers for significant FLOPs reduction.

The core methodology involves replacing the linear projection of FC layers with an alternative feature transformation method. Specifically, it constructs in-memory lookup tables storing discrete vectors. A hash algorithm is then used to dynamically retrieve a correlated subset of vectors based on the input embedding, which are then combined to form the output embedding. This process provides an estimation of matrix multiplication but with substantially lower computational cost, as retrieving data blocks from memory is much cheaper than performing matrix multiplication.

The MemoryFormer is trained from scratch, and extensive experiments on various benchmarks demonstrate its effectiveness, showcasing comparable performance with significantly less computation.

The original source link is https://openreview.net/pdf?id=04EC4ZnZJj. This appears to be the official link to the paper on the OpenReview platform, likely indicating its status as a conference submission.

2. Executive Summary

2.1. Background & Motivation

The Transformer model has revolutionized deep learning, particularly in natural language processing (NLP) with the advent of large language models (LLMs). However, the pursuit of higher performance has led to an exponential increase in model size and, consequently, computational complexity. This escalating demand for computing resources is a major obstacle to the broader application and popularization of LLMs.

Previous research has largely focused on optimizing the multi-head attention (MHA) mechanism, which is computationally expensive for long sequences. Linear attention reduces the complexity of MHA from quadratic to linear in sequence length, while FlashAttention keeps the quadratic cost but makes it faster through more efficient memory access. However, the authors observe that in most practical scenarios, the majority of computational complexity (measured in FLOPs) in a standard Transformer model comes from the fully-connected (FC) layers, not MHA, unless the sequence length is extremely long. For instance, MHA FLOPs are $2 s^2 d$ while FC layer FLOPs are $12 s d^2$. MHA only becomes dominant when $2 s^2 d > 12 s d^2$, that is, when the sequence length $s$ is greater than $6d$ (six times the hidden size). For a typical LLM with hidden size $d=4096$, this means $s$ would need to exceed 24K tokens, which is often not the case in many applications.
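The crossover point quoted above follows directly from the two FLOPs estimates. The following is a minimal sketch of that arithmetic; the function names are illustrative and not taken from the paper, and the constants 2 and 12 are simply the ones stated in the text.

```python
def mha_flops(s: int, d: int) -> int:
    """Attention FLOPs per block, using the 2*s^2*d estimate from the text."""
    return 2 * s * s * d


def fc_flops(s: int, d: int) -> int:
    """FC-layer FLOPs per block, using the 12*s*d^2 estimate from the text."""
    return 12 * s * d * d


def crossover_length(d: int) -> int:
    """Sequence length where both estimates match: 2*s^2*d = 12*s*d^2, so s = 6*d."""
    return 6 * d


d = 4096
for s in (2048, 8192, crossover_length(d), 65536):
    ratio = mha_flops(s, d) / fc_flops(s, d)
    print(f"s={s:>6}  MHA/FC FLOPs ratio = {ratio:.3f}")
```

The ratio equals $s / (6d)$, so it stays below 1 for every sequence shorter than $6d = 24576$ tokens when $d = 4096$.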

Another key observation is the underutilization of CPU and RAM resources during deep neural network inference on GPUs. While GPUs handle parallel computation, CPUs and RAM (which can reach terabytes) often remain idle. This suggests an opportunity to leverage these underutilized resources.

The paper's entry point is to fundamentally rethink the fully-connected (FC) layers, which are responsible for the bulk of computation in Transformers for typical sequence lengths. Instead of optimizing MHA, MemoryFormer proposes to eliminate the matrix multiplication operations in FC layers by replacing them with a memory-centric retrieval mechanism.

2.2. Main Contributions / Findings

The primary contributions of this paper are:

  • Novel Architecture (MemoryFormer): Introduction of MemoryFormer, a new Transformer architecture that significantly reduces computational complexity (FLOPs) by replacing fully-connected (FC) layers with Memory Layers. This shifts the paradigm from computation-heavy matrix multiplication to memory-efficient data retrieval.
  • The Memory Layer: Development of the Memory Layer as an alternative for feature transformation. This layer constructs in-memory lookup tables of discrete vectors. It uses a hash algorithm (specifically, a simplified locality-sensitive hashing (LSH)-like function) to dynamically retrieve a correlated subset of vectors based on the input embedding. These retrieved vectors are then aggregated to estimate the output of a traditional FC layer.
  • Differentiable Memory Layer: Design of a method to make the vectors stored in the hash tables learnable via back-propagation, allowing the MemoryFormer to be trained from scratch in an end-to-end manner. This involves using a scaled cosine similarity and a softmax-like probability distribution to introduce differentiability for the hashing operation.
  • Significant FLOPs Reduction: The MemoryFormer successfully eliminates nearly all computations of the Transformer model except for those required by the multi-head attention (MHA) operation. The paper claims a MemoryFormer block requires only $\sim 19\%$ of the FLOPs of a baseline Transformer block for a sequence length of $s=2048$ and hidden size $d=2048$, with the reduction becoming more significant as model size scales up.
  • Comparable or Improved Performance: Extensive experiments on multiple NLP benchmarks (PIQA, WinoGrande, WSC, ARC-E, ARC-C, LogiQA) demonstrate that MemoryFormer achieves performance comparable to, and in some cases even better than, baseline Pythia models, while significantly reducing computation.
  • New Optimization Perspective: This work introduces a new FLOPs-reduction strategy that differs from existing approaches primarily focused on self-attention optimization, opening new avenues for LLM efficiency and providing potential guidance for future hardware design (e.g., larger bus width, higher cache hit rate).

3. Prerequisite Knowledge & Related Work

3.1. Foundational Concepts

To understand MemoryFormer, a basic grasp of the following concepts is essential:

  • Transformer Model: The Transformer is a neural network architecture introduced in 2017 by Vaswani et al. in "Attention Is All You Need." It primarily relies on self-attention mechanisms to process sequential data, rather than recurrent or convolutional layers. It consists of an encoder and a decoder (though many LLMs are decoder-only Transformers). Each encoder or decoder block typically contains a multi-head attention (MHA) sub-layer and a feed-forward network (FFN) sub-layer, with residual connections and layer normalization applied around them.

    • Input Embedding: Text tokens (words, subwords) are first converted into dense numerical vectors called embeddings.
    • Positional Encoding: Since Transformers do not inherently process sequences in order, positional encodings are added to embeddings to provide information about the token's position in the sequence.
  • Multi-Head Attention (MHA): MHA is the core mechanism of the Transformer. It allows the model to weigh the importance of different parts of the input sequence when processing each token. It involves projecting the input embeddings into three different matrices: Query (Q), Key (K), and Value (V). The attention score for each Query against all Keys determines how much attention should be paid to the corresponding Value. Multi-head means this process is done multiple times in parallel with different linear projections, and the results are concatenated and linearly transformed. The standard Scaled Dot-Product Attention is calculated as: $ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $ Where:

    • $Q$ is the query matrix.
    • $K$ is the key matrix.
    • $V$ is the value matrix.
    • $d_k$ is the dimension of the key vectors.
    • $\mathrm{softmax}$ normalizes the scores. Multi-head attention runs this operation multiple times with different learned linear projections for Q, K, V, and concatenates the results.
  • Feed-Forward Network (FFN) / Fully-Connected (FC) Layers: After the MHA sub-layer, each position in the Transformer block passes through a position-wise feed-forward network. This FFN consists of two fully-connected (FC) layers with a non-linear activation function (like ReLU or GELU) in between. Importantly, this FFN is applied identically and independently to each position's embedding. It is typically designed to expand the dimensionality in the first FC layer (e.g., by 4 times the hidden size) and then reduce it back in the second FC layer. The core operation within an FC layer is a linear projection, which involves matrix multiplication of the input vector by a weight matrix.

  • Computational Complexity (FLOPs): FLOPs here stands for Floating-Point Operations, a count of arithmetic operations. It should not be confused with FLOPS (Floating-Point Operations per Second), which measures hardware throughput. FLOPs is a common metric used to quantify the computational cost of a model. A higher FLOPs count generally means more computation is required, leading to longer inference/training times and higher energy consumption. Reducing FLOPs is crucial for making LLMs more efficient and accessible.

  • Locality-Sensitive Hashing (LSH): LSH is a technique used for efficiently finding approximate nearest neighbors in high-dimensional spaces. Unlike traditional cryptographic hash functions, which aim to minimize collisions, LSH functions are designed to maximize the probability of collision for similar items and minimize it for dissimilar items. This means that two data points that are "close" in the original space are likely to be mapped to the same "bucket" by the LSH function. This property makes it useful for tasks like similarity search, clustering, and data deduplication.

  • Back-propagation: Back-propagation is the algorithm used to train artificial neural networks. It computes the gradient of the loss function with respect to the weights of the network. These gradients indicate how much each weight should be adjusted to reduce the loss. The process involves a forward pass (computing predictions) and a backward pass (computing and propagating gradients from the output layer back through the network). For a component to be trainable with back-propagation, its operations must be differentiable.
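To make the scaled dot-product attention formula from the list above concrete, here is a minimal single-head sketch in plain Python. It assumes no masking and no learned projections, and the names `softmax` and `attention` are illustrative only.

```python
import math


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for lists of row vectors (single head, no mask)."""
    d_k = len(K[0])
    out = []
    for q in Q:
        # One score per key: dot product scaled by sqrt(d_k).
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        # Weighted sum of the value rows.
        out.append([sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))])
    return out


# A zero query scores every key equally, so the output is the mean of the value rows.
print(attention([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[2.0, 4.0], [6.0, 8.0]]))
```

The two nested loops over queries and keys are where the $s^2$ term in the attention cost comes from.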

3.2. Previous Works

The paper contextualizes MemoryFormer within two main lines of research:

  • Efficient Transformers (Focus on Attention): Many prior works have focused on making the multi-head attention (MHA) mechanism more efficient, primarily by reducing its quadratic complexity with respect to sequence length $s$.

    • Reformer [17] and YOSO [28]: These models explicitly use LSH to approximate the self-attention mechanism, grouping similar queries and keys into buckets to reduce the number of comparisons.
    • CosFormer [22], Performer [8], Linformer [26]: These methods propose different approximations or modifications to the softmax operation in attention or use linear projections to achieve sub-quadratic or linear complexity. For example, Linformer uses a linear projection to reduce the sequence length of Key and Value matrices before attention calculation.
    • Sliding Windows [14]: This approach constrains the attention map to a local range, allowing each token to attend only to its neighbors within a fixed window, thereby reducing complexity.
    • FlashAttention [10]: A practical engineering method that optimizes the self-attention mechanism by reordering operations and using GPU memory hierarchies more efficiently to reduce memory I/O operations, leading to significant speedups without changing the theoretical complexity.
  • Optimizing Feed-Forward Networks (FFN) / Fully-Connected (FC) Layers: A smaller body of work has addressed the FFN part of Transformers:

    • Sparsity Exploitation [13, 20, 27]: These research efforts exploit the sparsity of intermediate activations in the FFN module to reduce computation. For instance, some activations might be zero, meaning their subsequent operations can be skipped.
    • LookupFFN [29]: This work specifically applies LSH to accelerate the inference speed of the Feed-Forward Network, similar in spirit to MemoryFormer but with potentially different design choices and theoretical underpinnings for the lookup mechanism.
    • SLIDE [7] and MONGOOSE [6]: These works utilize LSH to improve the convergence speed of neural network training processes, often by efficiently sampling gradients from vast datasets or model parameters.

3.3. Technological Evolution

The evolution of Transformer models has been characterized by a relentless drive for scale, leading to LLMs with billions or even trillions of parameters. This scaling, while unlocking unprecedented capabilities, has also pushed the boundaries of available computational resources. Initially, efficiency efforts focused on MHA because its quadratic complexity $O(s^2 d)$ made it a bottleneck for long sequences. However, as hidden dimensions $d$ also grew and many practical applications didn't always involve extremely long sequences (where $s > 6d$), the FC layers with their $O(s d^2)$ complexity started becoming the dominant computational burden.

This paper's work (MemoryFormer) represents a new frontier in Transformer efficiency research. Instead of incrementally optimizing existing matrix multiplication operations, it fundamentally questions the necessity of matrix multiplication for feature transformation in FC layers. By proposing a memory-centric retrieval approach, it leverages an underutilized resource (RAM and CPU) and shifts the computational burden from FLOPs to memory access, fitting within the broader trend of hardware-aware model design and efficiency.

3.4. Differentiation Analysis

Compared to the main methods in related work, the core differences and innovations of MemoryFormer are:

  • Target of Optimization: Most efficient Transformer works (e.g., Linformer, Cosformer, Performer, FlashAttention) primarily focus on optimizing the multi-head attention (MHA) mechanism. MemoryFormer, however, directly targets the fully-connected (FC) layers within the Transformer block, which the authors argue constitute the majority of FLOPs in typical LLM applications. This is a crucial distinction as it addresses a different, yet often dominant, computational bottleneck.

  • Nature of Optimization:

    • Existing methods: Generally aim to make matrix multiplications in MHA faster or less complex (e.g., by approximating softmax or reducing sequence length through projection).
    • MemoryFormer: Replaces the matrix multiplication operation entirely with a memory lookup and aggregation process. This is a qualitative shift from computation to retrieval, leveraging locality-sensitive hashing (LSH).
  • Resource Utilization: MemoryFormer is explicitly designed to trade memory resources (potentially vast RAM on a system) for reduced computational complexity (FLOPs). This contrasts with GPU-centric optimizations that focus on maximizing GPU compute efficiency. The authors highlight the underutilization of CPU and RAM in current GPU-dominated inference, proposing to harness these.

  • Gradient Flow for LSH: While LookupFFN also uses LSH for FFN acceleration, MemoryFormer introduces a specific, differentiable mechanism (scaled cosine similarity and softmax-like probability $p(\mathbf{z}_k)$) to allow end-to-end training of the lookup tables via back-propagation, which is a key innovation for making the Memory Layer trainable from scratch.

  • Compute-less claim: The paper emphasizes that retrieving data blocks from memory is a much cheaper operation which requires little computations compared to matrix multiplication. This forms the basis of its "minimize computation" claim, pushing the primary FLOPs burden almost entirely to the MHA part of the Transformer.

4. Methodology

4.1. Principles

The core principle of MemoryFormer is to replace the computationally intensive linear projection operations performed by fully-connected (FC) layers in a standard Transformer with a significantly cheaper memory retrieval process. This is achieved by storing pre-learned "responses" (vectors) in in-memory lookup tables and using a locality-sensitive hashing (LSH)-like mechanism to dynamically retrieve the most relevant vectors based on the input. The retrieved vectors are then aggregated to form the output, effectively estimating the result of matrix multiplication with minimal FLOPs. The design also ensures that these lookup tables are fully end-to-end trainable using back-propagation.

4.2. Core Methodology In-depth (Layer by Layer)

4.2.1. Standard Fully-Connected Layer

In a standard Transformer block, a fully-connected layer takes an input token embedding $\mathbf{x}$ and transforms it into an output token embedding $\mathbf{y}$ via matrix multiplication. Given an input row vector $\mathbf{x} \in \mathbb{R}^d$ representing a token embedding, and a weight matrix $\mathbf{W} \in \mathbb{R}^{d \times h}$, the operation is formulated as: $ \mathbf{y} = \mathbf{x} \mathbf{W} $ Where $\mathbf{y} \in \mathbb{R}^h$ is the output token embedding. For a sequence of $s$ tokens, represented by a matrix $\mathbf{X} \in \mathbb{R}^{s \times d}$, this becomes: $ \mathbf{Y} = \mathbf{X} \mathbf{W} $ with $\mathbf{Y} \in \mathbb{R}^{s \times h}$. The computational complexity of this operation is $\mathcal{O}(s d h)$. The paper's goal is to find an alternative mapping function that achieves similar properties (e.g., similar inputs yielding similar outputs) but with much lower computational complexity.

4.2.2. Compute-less Locality-Sensitive Hashing

The MemoryFormer's Memory Layer aims to replace the FC layer using locality-sensitive hashing (LSH). The idea is that if an input vector $\mathbf{x}$ is similar to other vectors, they should ideally map to the same hash bucket, from which a pre-stored vector $\hat{\mathbf{y}}$ (approximating $\mathbf{x}\mathbf{W}$) can be retrieved.

Initial LSH Formulation: The paper starts by proposing a simple LSH function to generate a hash code for an input vector $\mathbf{x} \in \mathbb{R}^d$. This hash code is then used as an index to retrieve a vector from a hash table $\mathbf{T}$. The process is formulated as: $ \begin{aligned} h(\mathbf{x}) &= \mathrm{integer}(\mathrm{sign}(\mathbf{x})), \\ \mathrm{sign}([\mathbf{x}]_i) &= \begin{cases} -1, & \text{if } [\mathbf{x}]_i < 0, \\ 1, & \text{if } [\mathbf{x}]_i \ge 0, \end{cases} \\ \mathrm{integer}(\mathbf{s}) &= \sum_{i=0}^{d-1} \frac{[\mathbf{s}]_i + 1}{2} \cdot 2^i, \\ \hat{\mathbf{y}} &= [\mathbf{T}]_{h(\mathbf{x})}, \end{aligned} $ Where:

  • $\mathbf{x} \in \mathbb{R}^d$ is the input vector.
  • $\mathrm{sign}(\cdot)$ is an element-wise function that returns -1 if the input element is negative and 1 if it's non-negative. This converts the real-valued vector $\mathbf{x}$ into a binary representation (hash code) $\mathbf{s} \in \{-1, 1\}^d$.
  • $\mathrm{integer}(\mathbf{s})$ converts this binary representation $\mathbf{s}$ into a non-negative integer. Each element $[\mathbf{s}]_i$ (which is either -1 or 1) is first transformed to 0 or 1 (by $\frac{[\mathbf{s}]_i + 1}{2}$) and then combined as bits in a binary number, weighted by powers of 2. This integer $h(\mathbf{x})$ serves as the index for the hash table.
  • $h(\mathbf{x}) \in \{0, 1, 2, \dots, 2^d - 1\}$ is the computed hash index.
  • $\hat{\mathbf{y}}$ is the retrieved vector from the hash table $\mathbf{T}$, specifically the row indexed by $h(\mathbf{x})$.
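The sign-then-integer hash defined above can be sketched in a few lines of plain Python. The helper names (`sign_bits`, `to_integer`, `hash_index`) are illustrative and not from the paper; element $i$ of the vector becomes bit $i$ of the index, following the $2^i$ weighting in the formula.

```python
from typing import List, Sequence


def sign_bits(x: Sequence[float]) -> List[int]:
    """Element-wise sign: -1 for negative entries, +1 for non-negative ones."""
    return [1 if v >= 0 else -1 for v in x]


def to_integer(s: Sequence[int]) -> int:
    """Map a {-1, +1} vector to an integer: bit i is (s_i + 1) / 2, weighted by 2**i."""
    return sum(((b + 1) // 2) << i for i, b in enumerate(s))


def hash_index(x: Sequence[float]) -> int:
    """h(x) = integer(sign(x)), the row index into the hash table."""
    return to_integer(sign_bits(x))


a = [0.3, -1.2, 0.0, 2.5]
b = [0.4, -1.0, 0.1, 2.4]   # close to a, same sign pattern
c = [-0.3, -1.2, 0.0, 2.5]  # first sign flipped
print(hash_index(a), hash_index(b), hash_index(c))
```

Vectors with the same sign pattern share a bucket, which is the locality-sensitive property the layer relies on; the only arithmetic involved is a sign test per element.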

Addressing Space Complexity: The space complexity of such a single hash table $\mathbf{T} \in \mathbb{R}^{2^d \times h}$ is $\mathcal{O}(2^d h)$. For typical Transformer hidden sizes (e.g., $d=512$), $2^{512}$ is an astronomically large number, making this approach impractical.

To solve this, the paper proposes splitting the input vector $\mathbf{x}$ into $K$ non-overlapping chunks: $ \mathbf{z}_k = \mathrm{split}(\mathbf{x}, \mathrm{num\_chunk} = K), \quad k = 1, 2, \dots, K $ Where:

  • $\mathbf{z}_k \in \mathbb{R}^{\tau}$ is the $k$-th sub-vector.
  • $\tau = \frac{d}{K}$ is the dimension of each sub-vector.
  • $d$ must be evenly divisible by $K$.

For each sub-vector $\mathbf{z}_k$, a separate hash table $\mathbf{T}_k \in \mathbb{R}^{2^\tau \times h}$ is constructed. The output $\hat{\mathbf{y}}$ is then obtained by summing the retrieved vectors from each table: $ \hat{\mathbf{y}} = \sum_{k=1}^K [\mathbf{T}_k]_{h(\mathbf{z}_k)} $ The space complexity of this chunked approach becomes $\mathcal{O}(K 2^\tau h)$. By setting $\tau$ to a small number (e.g., $\tau=8$ for $d=512, K=64$), the memory requirement becomes manageable (e.g., $\sim 16$ MB for float16 data type). Figure 2 (from the original paper) visually demonstrates this chunking and hashing process, showing different sub-vectors $\mathbf{z}_k$ mapping to different buckets in their respective tables $\mathbf{T}_k$.

Figure 2: A demonstration with $\tau = 2$ and $K = 3$, where $\mathbf{z}_1$ is hashed to bucket 2 of $\mathbf{T}_1$, $\mathbf{z}_2$ is hashed to bucket 1 of $\mathbf{T}_2$, and $\mathbf{z}_3$ is hashed to bucket 2 of $\mathbf{T}_3$.
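The chunked lookup and its memory footprint can be sketched as follows, using the same small setting as Figure 2 ($\tau = 2$, $K = 3$). All names are illustrative. The size check assumes an output dimension of $h = 512$, which the text does not state explicitly but which reproduces the quoted figure of about 16 MB.

```python
def split(x, K):
    """Split x into K contiguous, non-overlapping chunks of length tau = d / K."""
    d = len(x)
    assert d % K == 0, "d must be evenly divisible by K"
    tau = d // K
    return [x[k * tau:(k + 1) * tau] for k in range(K)]


def hash_index(z):
    """Sign hash: element i contributes bit i, set when z[i] >= 0."""
    return sum((1 if v >= 0 else 0) << i for i, v in enumerate(z))


def chunked_lookup(x, tables):
    """y_hat = sum over k of T_k[h(z_k)], with one table per chunk."""
    h = len(tables[0][0])
    y = [0.0] * h
    for z, table in zip(split(x, len(tables)), tables):
        row = table[hash_index(z)]
        for j in range(h):
            y[j] += row[j]
    return y


def table_bytes(d, K, h, bytes_per_value=2):
    """Storage for K tables of shape (2**tau, h); 2 bytes per value for float16."""
    tau = d // K
    return K * (2 ** tau) * h * bytes_per_value


# Toy tables for tau=2, K=3, h=2: row i of table k holds [k + 1, i].
tables = [[[float(k + 1), float(i)] for i in range(4)] for k in range(3)]
x = [0.5, -0.1, -0.2, -0.3, 0.1, 0.9]
print(chunked_lookup(x, tables))           # retrieves rows 1, 0 and 3
print(table_bytes(512, 64, 512) / 2**20)   # size in MiB, assuming h = 512
```

With $\tau = 8$ each table has only $2^8 = 256$ rows, which is what turns the $2^{512}$-row table of the unchunked version into something that fits in memory.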

4.2.3. Memory Layer with Differentiability

The chunked summation in the previous section allows for a forward pass and enables updating the hash table values via back-propagation. However, the hashing operation $h(\mathbf{z}_k)$ itself (converting a real-valued vector to an integer index) is non-differentiable. This means the input vector $\mathbf{x}$ (and thus $\mathbf{z}_k$) would not receive gradients, preventing end-to-end training of the Transformer's upstream layers.

To address this, the authors introduce a differentiable weighting mechanism for each retrieved item. Instead of simply summing, they add a coefficient $p(\mathbf{z}_k)$ to weight each retrieved item based on its relevance to the input sub-vector: $ \hat{\mathbf{y}} = \sum_{k=1}^K p(\mathbf{z}_k) \cdot [\mathbf{T}_k]_{h(\mathbf{z}_k)} $ The coefficient $p(\mathbf{z}_k)$ is a function of $\mathbf{z}_k$ that measures the relevance between $\mathbf{z}_k$ and its corresponding hash bucket $h(\mathbf{z}_k)$.

Scaled Cosine Similarity: They use a scaled cosine similarity to measure this relevance, considering both the direction and amplitude of $\mathbf{z}_k$. This is defined as the inner product between $\mathbf{z}_k$ and its binarized sign vector: $ \mathrm{sim}(\mathbf{z}_k, h(\mathbf{z}_k)) = \|\mathbf{z}_k\|_2 \cdot \|\mathrm{sign}(\mathbf{z}_k)\|_2 \cdot \cos(\mathbf{z}_k, \mathrm{sign}(\mathbf{z}_k)) = \langle \mathbf{z}_k, \mathrm{sign}(\mathbf{z}_k) \rangle $ Where:

  • $\langle \cdot, \cdot \rangle$ computes the inner product of two vectors.
  • $\mathrm{sign}(\mathbf{z}_k)$ is the binarized vector (each entry is -1 or 1) corresponding to the hash bucket $h(\mathbf{z}_k)$.

Probability Distribution $p(\mathbf{z}_k)$: To make the process differentiable for the input $\mathbf{z}_k$, they define $p(\mathbf{z}_k)$ as a softmax-like probability that $\mathbf{z}_k$ is mapped to its specific hash bucket $h(\mathbf{z}_k)$, considering all possible $2^\tau$ buckets simultaneously: $ p(\mathbf{z}_k) = \frac{\exp[\mathrm{sim}(\mathbf{z}_k, h(\mathbf{z}_k))/t]}{\sum_{i=0}^{2^\tau-1} \exp[\mathrm{sim}(\mathbf{z}_k, i)/t]} = \frac{\exp[\langle \mathbf{z}_k, \mathrm{sign}(\mathbf{z}_k) \rangle / t]}{\sum_{i=0}^{2^\tau-1} \exp[\langle \mathbf{z}_k, \mathrm{integer}_\tau^{-1}(i) \rangle / t]} $ Where:

  • $t$ is the temperature hyper-parameter, controlling the sharpness of the distribution.
  • $\mathrm{integer}_\tau^{-1}(i) \in \{-1, 1\}^\tau$ is a function that maps a non-negative integer $i$ (from 0 to $2^\tau-1$) to its corresponding $\tau$-bit binary representation (with bits represented as -1 or 1). This represents the sign vector for the $i$-th bucket.

Simplification of $p(\mathbf{z}_k)$: The denominator $\sum_{i=0}^{2^\tau-1} \exp[\langle \mathbf{z}_k, \mathrm{integer}_\tau^{-1}(i) \rangle / t]$ can be simplified. First, the numerator's inner product can be rewritten: $ \langle \mathbf{z}_k, \mathrm{sign}(\mathbf{z}_k) \rangle = \sum_{j=0}^{\tau-1} |[\mathbf{z}_k]_j| $ This is because if $[\mathbf{z}_k]_j \ge 0$, then $[\mathrm{sign}(\mathbf{z}_k)]_j = 1$, and $[\mathbf{z}_k]_j \cdot 1 = |[\mathbf{z}_k]_j|$. If $[\mathbf{z}_k]_j < 0$, then $[\mathrm{sign}(\mathbf{z}_k)]_j = -1$, and $[\mathbf{z}_k]_j \cdot (-1) = |[\mathbf{z}_k]_j|$. So, the sum is simply the sum of the absolute values of the elements of $\mathbf{z}_k$.

The denominator can be expanded: $ \sum_{i=0}^{2^\tau-1} \exp[\langle \mathbf{z}_k, \mathrm{integer}_\tau^{-1}(i) \rangle] = \prod_{j=0}^{\tau-1} [\exp([\mathbf{z}_k]_j) + \exp(-[\mathbf{z}_k]_j)] $ This is a standard identity for summing over all $2^\tau$ combinations of signs. When division by $t$ is included in the exponent: $ \sum_{i=0}^{2^\tau-1} \exp[\langle \mathbf{z}_k, \mathrm{integer}_\tau^{-1}(i) \rangle / t] = \prod_{j=0}^{\tau-1} [\exp([\mathbf{z}_k]_j / t) + \exp(-[\mathbf{z}_k]_j / t)] $ Using these simplifications, $p(\mathbf{z}_k)$ can be further simplified to: $ p(\mathbf{z}_k) = \frac{\exp(\sum_{j=0}^{\tau-1} |[\mathbf{z}_k]_j| / t)}{\prod_{j=0}^{\tau-1} [\exp([\mathbf{z}_k]_j / t) + \exp(-[\mathbf{z}_k]_j / t)]} = \frac{1}{\prod_{j=0}^{\tau-1} [1 + \exp(-2 |[\mathbf{z}_k]_j| / t)]} $ This simplified form for $p(\mathbf{z}_k)$ ensures differentiability and allows gradients to flow back to $\mathbf{z}_k$.
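The simplification above can be checked numerically by comparing the softmax over all $2^\tau$ sign patterns with the closed form. This is a small sanity-check sketch with illustrative function names; the test vector and temperature are arbitrary.

```python
import math
from itertools import product


def p_bruteforce(z, t=1.0):
    """Softmax-like probability, summing explicitly over all 2**tau sign vectors."""
    numerator = math.exp(sum(abs(v) for v in z) / t)
    denominator = sum(
        math.exp(sum(v * s for v, s in zip(z, signs)) / t)
        for signs in product((-1, 1), repeat=len(z))
    )
    return numerator / denominator


def p_closed_form(z, t=1.0):
    """Simplified form: 1 / prod_j (1 + exp(-2 * |z_j| / t))."""
    out = 1.0
    for v in z:
        out /= 1.0 + math.exp(-2.0 * abs(v) / t)
    return out


z = [0.7, -1.3, 0.0, 2.1]
print(p_bruteforce(z, t=0.5), p_closed_form(z, t=0.5))
```

The closed form needs $\tau$ exponentials instead of $2^\tau$ terms. It also shows the range of the weight: it equals $2^{-\tau}$ when every entry is zero (all buckets equally likely) and approaches 1 as the entries move away from zero.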

Final Memory Layer Formulation: Combining this probability with the retrieved vectors, the final Memory Layer output is: $ \mathbf{y} = \sum_{k=1}^K p(\mathbf{z}_k) \cdot [\mathbf{T}_k]_{h(\mathbf{z}_k)} = \sum_{k=1}^K \frac{[\mathbf{T}_k]_{h(\mathbf{z}_k)}}{\prod_{j=0}^{\tau-1} [1 + \exp(-2 |[\mathbf{z}_k]_j| / t)]} $ The left part of Figure 3 (from the original paper) illustrates the schematic of this Memory Layer.
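Putting the pieces together, a forward pass of one Memory Layer for a single token might look like the sketch below. It is written in plain Python for readability, not speed; `memory_layer` and its arguments are hypothetical names, and a real implementation would presumably operate on batched tensors.

```python
import math


def memory_layer(x, tables, t=1.0):
    """y = sum_k p(z_k) * T_k[h(z_k)] for one token x and a list of K tables."""
    K = len(tables)
    assert len(x) % K == 0, "d must be evenly divisible by K"
    tau = len(x) // K
    h = len(tables[0][0])
    y = [0.0] * h
    for k in range(K):
        z = x[k * tau:(k + 1) * tau]
        assert len(tables[k]) == 2 ** tau, "each table needs 2**tau rows"
        # Hash: element j sets bit j when z[j] >= 0.
        idx = sum((1 if v >= 0 else 0) << j for j, v in enumerate(z))
        # Weight: closed-form p(z_k) = 1 / prod_j (1 + exp(-2 * |z_j| / t)).
        p = 1.0
        for v in z:
            p /= 1.0 + math.exp(-2.0 * abs(v) / t)
        for j, value in enumerate(tables[k][idx]):
            y[j] += p * value
    return y


# Toy setup with tau=2, K=2, h=1: row i of table k stores [i + 4 * k].
tables = [[[float(i + 4 * k)] for i in range(4)] for k in range(2)]
print(memory_layer([0.0, 0.0, 0.0, 0.0], tables))      # p = 0.25 for each chunk
print(memory_layer([50.0, -50.0, 50.0, 50.0], tables))  # p is ~1 for each chunk
```

In the second call the entries are far from zero, so each weight saturates at 1 and the output reduces to the plain sum of the retrieved rows.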

Figure 3: Left: The schematic diagram of the Memory Layer. Right: One building block of the MemoryFormer.

The computational complexity of a Memory Layer for a sequence of $s$ tokens, given output dimension $h$, is approximately $\mathcal{O}(s (\tau + h) K)$, which simplifies to $\mathcal{O}(\frac{s d h}{\tau})$ when considering $d = K\tau$. This is roughly a factor of $\tau$ smaller than the $\mathcal{O}(s d h)$ of a fully-connected layer.
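The ratio between the two complexity expressions can be worked out directly. The sketch below counts operations exactly as the big-O expressions are written, with constant factors omitted, so the numbers are relative costs and not measured FLOPs; the function names are illustrative.

```python
def fc_cost(s, d, h):
    """Operation count of a fully-connected layer, following O(s * d * h)."""
    return s * d * h


def memory_layer_cost(s, d, h, tau):
    """Operation count of a Memory Layer, following O(s * (tau + h) * K) with K = d / tau."""
    assert d % tau == 0, "d must be evenly divisible by tau"
    K = d // tau
    return s * (tau + h) * K


s, d, h, tau = 2048, 2048, 2048, 8
ratio = memory_layer_cost(s, d, h, tau) / fc_cost(s, d, h)
print(f"Memory Layer / FC cost ratio: {ratio:.4f}")
```

Algebraically the ratio is $\frac{(\tau + h)K}{dh} = \frac{1}{\tau} + \frac{1}{h}$, so for $h \gg \tau$ the saving is governed almost entirely by the chunk size $\tau$.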

Gradient Calculation: The Memory Layer is designed to be fully differentiable. The gradients of the loss function $L$ with respect to the hash tables and the input vector are: $ \begin{aligned} \frac{\partial L}{\partial [\mathbf{T}_k]_i} &= \begin{cases} p(\mathbf{z}_k) \frac{\partial L}{\partial \mathbf{y}}, & \text{if } h(\mathbf{z}_k) = i, \\ 0, & \text{if } h(\mathbf{z}_k) \neq i, \end{cases} \quad i \in \{0, 1, \cdots, 2^\tau - 1\}, \\ \frac{\partial L}{\partial \mathbf{x}} &= \mathrm{concat}\left( \left[ \dots \frac{\partial L}{\partial \mathbf{y}} [\mathbf{T}_k]_{h(\mathbf{z}_k)}^\top \frac{\partial p(\mathbf{z}_k)}{\partial \mathbf{z}_k} \dots \ \mathrm{for}\ k\ \mathrm{in}\ \mathrm{range}(1, K+1) \right] \right). \end{aligned} $ Where:

  • $\frac{\partial L}{\partial [\mathbf{T}_k]_i}$ represents the gradient for the $i$-th entry in the $k$-th hash table. It is non-zero only for the bucket that $\mathbf{z}_k$ was hashed to, where it equals the output gradient $\frac{\partial L}{\partial \mathbf{y}}$ scaled by $p(\mathbf{z}_k)$. This indicates that gradients to the hash tables are sparse.
  • $\frac{\partial L}{\partial \mathbf{x}}$ represents the gradient for the input vector. It is computed by concatenating the gradients contributed by each sub-vector $\mathbf{z}_k$, which involves the gradient of $p(\mathbf{z}_k)$ with respect to $\mathbf{z}_k$. This ensures end-to-end differentiability.
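The sparsity of the table gradients can be made concrete with a small sketch. The function below is a hypothetical helper, not code from the paper: it takes the bucket indices $h(\mathbf{z}_k)$ and weights $p(\mathbf{z}_k)$ as given inputs (it does not compute them) and applies the first gradient formula directly.

```python
def memory_layer_table_grads(bucket_ids, weights, grad_y, num_buckets):
    """Return dL/dT_k for every table k as a [num_buckets][h] nested list.

    bucket_ids[k]: h(z_k), the bucket that sub-vector k hashed to
    weights[k]:    p(z_k), the scalar weight of that bucket (taken as given)
    grad_y:        dL/dy, the upstream gradient of length h
    """
    h = len(grad_y)
    grads = []
    for bucket, p in zip(bucket_ids, weights):
        table_grad = [[0.0] * h for _ in range(num_buckets)]
        # Only the retrieved row receives gradient: p(z_k) * dL/dy.
        table_grad[bucket] = [p * g for g in grad_y]
        grads.append(table_grad)
    return grads


# Toy setting: K = 2 tables, tau = 2 bits (4 buckets), output size h = 3.
grads = memory_layer_table_grads(
    bucket_ids=[1, 3], weights=[0.5, 0.25], grad_y=[1.0, -2.0, 4.0], num_buckets=4
)
nonzero_rows = [
    (k, i)
    for k, table in enumerate(grads)
    for i, row in enumerate(table)
    if any(v != 0.0 for v in row)
]
print(nonzero_rows)  # exactly one (table, bucket) pair per table
```

In each table only one of the $2^\tau$ rows is updated per token, which is the reason a larger learning rate is explored later in the ablations.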

4.2.4. Architecture of MemoryFormer

MemoryFormer follows the standard Transformer design of $N$ stacked blocks. The right part of Figure 3 (from the original paper) depicts one building block.

Multi-Head Attention (MHA) in MemoryFormer:

  • Input sequence $\mathbf{X} = (\mathbf{x}_1, \mathbf{x}_2, \dots, \mathbf{x}_s)^\top \in \mathbb{R}^{s \times d}$ is first normalized by a Norm() layer.
  • Instead of traditional linear projections (which are FC layers), three Memory Layers (denoted $\mathrm{MemoryLayer}_Q$, $\mathrm{MemoryLayer}_K$, $\mathrm{MemoryLayer}_V$) are used to transform the normalized $\mathbf{X}$ into Query (Q), Key (K), and Value (V) matrices, respectively.
  • The tokens in $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$ are then split into multiple sub-vectors for multi-head processing, as in standard MHA.
  • The actual calculation of multi-head attention itself remains unchanged from the original Transformer architecture. This means any existing efficient self-attention techniques (e.g., FlashAttention, Linear Attention, KV-Cache) can be seamlessly integrated.
  • The output of the MHA is added to the input $\mathbf{X}$ via a residual connection, and the result is normalized before it enters the Memory Block.

Memory Block (Replacing Feed-Forward Network):

  • In MemoryFormer, the Feed-Forward Network (FFN) is replaced by a Memory Block.
  • A Memory Block consists of two consecutive Memory Layers.
  • Each Memory Layer is preceded by a Norm() layer. The Norm() layer is crucial here because it sets the input embedding to have a zero-mean distribution. This helps the sign() function (Eq. (3)) generate $-1$ and $+1$ values evenly, leading to a more uniform distribution of hash bucket retrievals, which is beneficial for model capacity.
  • No Intermediate Activation Function: Unlike standard FFNs which have an activation function (e.g., ReLU, GELU) between their two FC layers, MemoryFormer's Memory Block omits this. The authors argue that the hashing operation itself is non-linear, making an extra non-linear function redundant. Experiments confirm this has no adverse effect on performance.
  • Dimensionality Expansion: To maintain compatibility with the common FFN design pattern of expanding dimensionality, the first Memory Layer in the Memory Block expands its output dimensionality. If the input hidden size is $d = \tau K$, the output dimensionality of the first Memory Layer is set to $(\tau+2) \cdot K$. This means the hash tables in this layer, $\mathbf{T}_k^1$, are of size $\mathbb{R}^{2^\tau \times (\tau+2)K}$.
  • Consequently, the sub-vectors $\mathbf{z}_k$ feeding into the second Memory Layer of the Memory Block will have a bit width of $\tau' = \tau+2$ (2 bits larger). The hash tables in the second layer, $\mathbf{T}_k^2$, are of size $\mathbb{R}^{2^{(\tau+2)} \times d}$, restoring the output dimensionality back to $d$. This expansion increases the capacity of the second Memory Layer by a factor of $2^2 = 4$.

Overall Computational Flow for One MemoryFormer Block:

$ \mathbf{X} = \mathrm{Norm}(\mathbf{X}) $

$ \mathbf{Q} = \mathrm{MemoryLayer}_Q(\mathbf{X}), \quad \mathbf{K} = \mathrm{MemoryLayer}_K(\mathbf{X}), \quad \mathbf{V} = \mathrm{MemoryLayer}_V(\mathbf{X}) $

$ \mathbf{Z} = \mathbf{X} + \mathrm{MultiHeadAttention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) $

$ \mathbf{Y} = \mathbf{Z} + \mathrm{MemoryLayer}_2(\mathrm{Norm}(\mathrm{MemoryLayer}_1(\mathrm{Norm}(\mathbf{Z})))) $

This represents the typical residual connections and layer normalization around the MHA and Memory Block.
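The order of operations in the four equations above can be sketched as a plain function that only wires components together. This is an illustrative skeleton, not the authors' implementation: every component (`norm`, the Memory Layers, `attention`) is a hypothetical callable supplied by the caller, and scalar stand-ins are used so the data flow can be traced by hand.

```python
def memoryformer_block(x, norm, mem_q, mem_k, mem_v, attention, mem_1, mem_2):
    """Wire up one block following the four equations of the block's flow.

    Every argument after `x` is a caller-supplied callable; this function
    only fixes the order of operations and the two residual additions.
    """
    x = norm(x)                              # X = Norm(X)
    q, k, v = mem_q(x), mem_k(x), mem_v(x)   # three Memory Layers replace the Q/K/V projections
    z = x + attention(q, k, v)               # residual connection around attention
    y = z + mem_2(norm(mem_1(norm(z))))      # residual connection around the Memory Block
    return y


# Scalar stand-ins: identity norm, "layers" that double, attention that sums.
y = memoryformer_block(
    1.0,
    norm=lambda t: t,
    mem_q=lambda t: 2 * t,
    mem_k=lambda t: 2 * t,
    mem_v=lambda t: 2 * t,
    attention=lambda q, k, v: q + k + v,
    mem_1=lambda t: 2 * t,
    mem_2=lambda t: 2 * t,
)
print(y)  # 1 + (2 + 2 + 2) = 7, then 7 + 2 * (2 * 7) = 35
```

Note that, as the equations are written, the residual in the attention step adds the normalized input, and no activation function sits between the two Memory Layers.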

Computational Complexity Comparison:

  • Standard Transformer Block: The total FLOPs are approximately $2s^2d + 12sd^2$. (Here, $2s^2d$ is for MHA and $12sd^2$ is for all FC layers, including the MHA projections and the FFN.)
  • MemoryFormer Block: The total FLOPs are approximately $2s^2d + \frac{6}{\tau}sd^2 = 2s^2d + 6Ksd$. (Here, $2s^2d$ is for MHA, and $\frac{6}{\tau}sd^2$, or equivalently $6Ksd$, is for all Memory Layers.) The computation originating from the FC layers of a standard Transformer is thus reduced by roughly an order of magnitude, so the vast majority of the remaining workload comes from the MultiHeadAttention.
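The two formulas can be evaluated directly to see the size of the gap. The sketch below assumes $s = 2048$ and $\tau = 8$ (the sequence length and bit width reported elsewhere in this analysis); the function names are ours.

```python
def standard_block_flops(s, d):
    """Approximate FLOPs of a standard block: 2*s^2*d + 12*s*d^2."""
    return 2 * s**2 * d + 12 * s * d**2


def memoryformer_block_flops(s, d, tau):
    """Approximate FLOPs of a MemoryFormer block: 2*s^2*d + (6/tau)*s*d^2."""
    return 2 * s**2 * d + 6 * s * d**2 / tau


s, tau = 2048, 8  # assumed sequence length and bit width
for d in (512, 1024):
    std = standard_block_flops(s, d) / 1e9
    mf = memoryformer_block_flops(s, d, tau) / 1e9
    print(f"d={d}: standard {std:.1f} G, MemoryFormer {mf:.1f} G")
```

For $d = 512$ and $d = 1024$ these evaluate to 10.7 G vs 4.7 G and 34.4 G vs 10.2 G, which agrees with the Total FLOPs reported in Table 1 for the corresponding model pairs. The formulas are approximations, so other configurations may differ slightly from the reported figures.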

Figure 1 (from the original paper) visually illustrates this FLOPs reduction. The blue lines represent a traditional Transformer, while the red stars represent MemoryFormer. It clearly shows MemoryFormer having significantly lower FLOPs for the same hidden size and sequence length.

Figure 1: FLOPs with different model hidden size and sequence lengths. The chart plots per-block inference FLOPs (vertical axis) against model hidden size (horizontal axis); blue curves denote the traditional Transformer and red stars denote MemoryFormer.


5. Experimental Setup

5.1. Datasets

The authors use The Pile dataset for training and evaluate MemoryFormer on six widely-used NLP benchmarks.

  • Training Data:

    • The Pile [12]: This is a large, diverse corpus used for language modeling. It comprises 825 GiB of text from 22 distinct high-quality subsets, covering professional and academic fields. Its diversity is crucial for training robust large language models.
  • Evaluation Benchmarks (Zero-shot evaluation): These tasks are chosen to provide a comprehensive assessment of the LLM's capabilities, ranging from knowledge-based questions to reasoning.

    • PIQA [4]: A physical commonsense reasoning dataset. It asks common sense questions about everyday situations, requiring the model to choose the more plausible of two possible solutions.
    • WinoGrande [23, 24]: An adversarial Winograd Schema Challenge. It's a large-scale dataset of commonsense reasoning problems, designed to be difficult for models that rely on statistical biases rather than true understanding.
    • WSC [24]: Winograd Schema Challenge. Similar to WinoGrande, it focuses on coreference resolution problems that require commonsense reasoning.
    • ARC-E (AI2 Reasoning Challenge - Easy) & ARC-C (AI2 Reasoning Challenge - Challenging) [9]: These datasets test a model's ability to answer science questions, categorizing them by difficulty. They require knowledge retrieval and reasoning.
    • LogiQA [19]: A machine reading comprehension dataset with logical reasoning requirements. It focuses on testing a model's ability to perform various types of logical inferences.

5.2. Evaluation Metrics

The paper reports FLOPs as its measure of computational efficiency, and zero-shot accuracy (or validation perplexity, Val. PPL, in the ablations) as its measure of model quality.

  • Floating-Point Operations (FLOPs):

    • Conceptual Definition: FLOPs are a measure of the total number of floating-point arithmetic operations (like additions, subtractions, multiplications, divisions) performed by a computational model or algorithm. The count quantifies computational cost and serves as a hardware-independent proxy for energy consumption and execution time, although it does not capture memory access costs. In this paper, FLOPs are reported for one Transformer block with a given sequence length and hidden size, allowing for a direct comparison of the computational efficiency of different architectures.
    • Mathematical Formula: FLOPs are calculated by summing the operations of all individual components (e.g., matrix multiplications, element-wise operations). There is no single universal formula; the count depends on the operations in the network. For example, multiplying a matrix $A$ ($m \times n$) by $B$ ($n \times p$) involves $m \cdot p \cdot (2n - 1)$ FLOPs (approximately $2mnp$ for large $n$). The block-level totals quoted in Section 4.2.4, such as $12sd^2$, correspond to counting each multiply-add pair as one operation, i.e. $mnp$ per matrix product.
  • Accuracy:

    • Conceptual Definition: Accuracy is a common metric for classification tasks. It measures the proportion of total predictions that were correct. It provides a straightforward indication of how well the model is performing on a given dataset, particularly when classes are balanced. In the context of NLP benchmarks like PIQA, WinoGrande, WSC, ARC-E, ARC-C, and LogiQA, it typically refers to the percentage of correctly answered questions or correctly identified instances.
    • Mathematical Formula: $ \text{Accuracy} = \frac{\text{Number of Correct Predictions}}{\text{Total Number of Predictions}} $
    • Symbol Explanation:
      • Number of Correct Predictions: The count of instances where the model's output matches the ground truth label.
      • Total Number of Predictions: The total number of instances evaluated by the model.
  • Perplexity (PPL):

    • Conceptual Definition: Perplexity is a widely used metric for evaluating language models. It quantifies how well a probability model predicts a sample. A lower perplexity score indicates that the model is better at predicting the next word in a sequence, implying a better language model. It can be thought of as the inverse probability of the test set, normalized by the number of words.
    • Mathematical Formula: $ PPL(W) = \exp\left(-\frac{1}{N} \sum_{i=1}^N \log P(w_i | w_1, \dots, w_{i-1})\right) $
    • Symbol Explanation:
      • $W$: A sequence of words $w_1, w_2, \dots, w_N$.
      • $N$: The total number of words in the sequence.
      • $P(w_i | w_1, \dots, w_{i-1})$: The probability of the $i$-th word given all preceding words, as predicted by the language model.
      • $\log$: Natural logarithm.
      • $\exp$: The exponential function.
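The perplexity formula translates into a few lines of code. The sketch below is a minimal illustration that takes the per-token probabilities as a plain list; in practice these would come from a language model's output distribution, and the function name is ours.

```python
import math


def perplexity(token_probs):
    """Perplexity from the model's probability for each observed token.

    token_probs[i] is P(w_i | w_1, ..., w_{i-1}); every value must be > 0.
    """
    n = len(token_probs)
    avg_neg_log_likelihood = -sum(math.log(p) for p in token_probs) / n
    return math.exp(avg_neg_log_likelihood)


# A model that assigns probability 1/4 to every token is as uncertain
# as a uniform choice among 4 options, so its perplexity is about 4.
print(perplexity([0.25, 0.25, 0.25, 0.25]))
```

This gives an intuition for the Val. PPL values in the ablations: a perplexity near 19 corresponds to the model being, on average, about as uncertain as a uniform choice among 19 tokens.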

5.3. Baselines

The MemoryFormer models are compared against two main categories of baselines:

  • Pythia Models: These are a suite of open-source large language models developed by EleutherAI, designed for analyzing LLMs across training and scaling. The authors use Pythia-70M, Pythia-160M, and Pythia-410M as direct baselines. MemoryFormer models (e.g., MF-tiny, MF-small, MF-base) are built upon these Pythia models, maintaining the same hidden size and number of layers for fair comparison. The Pythia framework also provides fully public training data and detailed model hyper-parameters, which aids reproducibility.

  • Efficient Transformer Methods: To demonstrate the superiority of MemoryFormer over other FLOPs-reduction strategies, the MemoryFormer-base model is compared against Pythia-410M whose multi-head attention module has been replaced by:

    • Linformer [26]: An efficient Transformer that reduces the quadratic complexity of self-attention to linear complexity by projecting the key and value matrices to a lower dimension.

    • Cosformer [22]: A Transformer variant that rethinks the softmax in attention using cosine functions to achieve linear complexity.

    • Performer [8]: Another efficient Transformer that approximates the softmax attention kernel using positive orthogonal random features, also leading to linear complexity.

All models are trained from scratch using the same optimizer, scheduler, and hyper-parameters as the Pythia settings, except that MemoryFormer uses an adjusted learning rate to account for its sparse gradients. The only fully-connected layer remaining in MemoryFormer is the classifier head for the final output.

6. Results & Analysis

6.1. Core Results Analysis

The experimental results demonstrate that MemoryFormer significantly reduces FLOPs while maintaining or improving average accuracy across the evaluated NLP benchmarks.

Evaluation Across Different Scales (Table 1): The paper compares Pythia baselines with MemoryFormer variants (MF-tiny, MF-small, MF-base) that have equivalent hidden sizes and number of layers. FLOPs are measured for one Transformer block with a sequence length of 2048.

The following are the results from Table 1 of the original paper:

| Model | Pythia-70M | MF-tiny | Pythia-160M | MF-small | Pythia-410M | MF-base |
| --- | --- | --- | --- | --- | --- | --- |
| Layers | 6 | 6 | 12 | 12 | 24 | 24 |
| Hidden Size | 512 | 512 | 768 | 768 | 1024 | 1024 |
| FLOPs w/o Attn. | 6.4 G | 0.4 G | 14.5 G | 1.0 G | 25.8 G | 1.6 G |
| Total FLOPs | 10.7 G | 4.7 G | 20.9 G | 7.4 G | 34.4 G | 10.2 G |
| PIQA | 0.585 | 0.602 | 0.618 | 0.642 | 0.675 | 0.698 |
| WinoGrande | 0.511 | 0.522 | 0.497 | 0.523 | 0.534 | 0.546 |
| WSC | 0.365 | 0.375 | 0.365 | 0.394 | 0.471 | 0.385 |
| ARC-E | 0.380 | 0.437 | 0.440 | 0.461 | 0.517 | 0.585 |
| ARC-C | 0.177 | 0.228 | 0.201 | 0.247 | 0.202 | 0.259 |
| LogiQA | 0.232 | 0.260 | 0.210 | 0.272 | 0.209 | 0.272 |
| Avg. | 0.375 | 0.404 | 0.389 | 0.423 | 0.435 | 0.458 |

Analysis of Table 1:

  • FLOPs Reduction: MemoryFormer models consistently achieve drastically lower FLOPs compared to their Pythia counterparts. For instance, Pythia-70M has Total FLOPs of 10.7 G, while MF-tiny has 4.7 G (a 56% reduction). Pythia-410M has 34.4 G Total FLOPs, while MF-base has 10.2 G (a 70% reduction). The FLOPs w/o Attn. (which corresponds to the FC layers being replaced by Memory Layers) shows an even more dramatic reduction, from 6.4 G to 0.4 G for the smallest model, highlighting the effectiveness of the proposed Memory Layer.
  • Performance: Across all three model scales, MemoryFormer models achieve better average accuracy on the benchmarks than the baseline Pythia models. MF-tiny (0.404 Avg.) outperforms Pythia-70M (0.375 Avg.), MF-small (0.423 Avg.) outperforms Pythia-160M (0.389 Avg.), and MF-base (0.458 Avg.) outperforms Pythia-410M (0.435 Avg.). The gains hold on nearly every individual task; the exception is WSC at the largest scale, where MF-base (0.385) trails Pythia-410M (0.471). Overall, MemoryFormer reduces computation while also improving average accuracy, suggesting that the Memory Layer can be an effective feature transformation mechanism for these tasks.

Comparison with Efficient Transformers (Table 2): The paper compares MemoryFormer-base (based on Pythia-410M) with other efficient Transformer methods that primarily optimize attention.

The following are the results from Table 2 of the original paper:

| Model | Pythia-410M | Linformer | cosFormer | Performer | MemoryFormer-base |
| --- | --- | --- | --- | --- | --- |
| FLOPs | 34.4 G | 26.1 G | 30.0 G | 26.7 G | 10.2 G |
| PIQA | 0.675 | 0.527 | 0.522 | 0.643 | 0.698 |
| WinoGrande | 0.534 | 0.511 | 0.506 | 0.496 | 0.546 |
| WSC | 0.471 | 0.635 | 0.605 | 0.433 | 0.385 |
| ARC-E | 0.517 | 0.265 | 0.267 | 0.470 | 0.585 |
| ARC-C | 0.202 | 0.244 | 0.263 | 0.231 | 0.259 |
| LogiQA | 0.209 | 0.207 | 0.264 | 0.236 | 0.272 |
| Avg. | 0.435 | 0.398 | 0.405 | 0.418 | 0.458 |

Analysis of Table 2:

  • FLOPs Comparison: MemoryFormer-base achieves by far the lowest FLOPs (10.2 G) compared to all other methods, including the original Pythia-410M (34.4 G) and other efficient attention models (Linformer 26.1 G, cosFormer 30.0 G, Performer 26.7 G). This strongly supports the paper's claim that FC layers are the dominant source of computation and their replacement offers the most significant FLOPs reduction.

  • Performance Comparison: While other efficient attention methods manage to reduce FLOPs, they generally suffer from a considerable performance degradation compared to the baseline Pythia-410M (e.g., Linformer Avg. 0.398 vs Pythia Avg. 0.435). In contrast, MemoryFormer-base not only achieves the lowest FLOPs but also the highest average accuracy (0.458), even surpassing the original Pythia-410M. This highlights a key advantage: MemoryFormer offers both efficiency and improved performance.

These results support the core hypothesis: by replacing the FC layers with Memory Layers, MemoryFormer minimizes Transformer computation without compromising average model performance, and in these experiments it often improves it.

6.2. Ablation Study

The paper conducts several ablation studies to understand the impact of key hyper-parameters and design choices on MemoryFormer's performance and efficiency.

Tradeoff between $\tau$ and $K$ (Table 3): This study investigates how the choice of $\tau$ (bit width of sub-vectors) and $K$ (number of hash tables) affects model performance (Val. PPL), FLOPs, and Memory Size. The hidden size is kept at $d = 512$ (or as close as $d = \tau K$ allows).

The following are the results from Table 3 of the original paper:

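| $d$ | $\tau$ | $K$ | Val. PPL ↓ | FLOPs | Memory Size |
| --- | --- | --- | --- | --- | --- |
| 512 | 4 | 128 | 19.01 | 0.14 G | 2.1 MB |
| 512 | 8 | 64 | 18.82 | 0.07 G | 16.8 MB |
| 510 | 10 | 51 | 18.67 | 0.06 G | 53.5 MB |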

Analysis of Table 3:

  • As $\tau$ increases (and $K$ decreases proportionally to keep $d \approx \tau K$), the Val. PPL decreases, indicating improved performance (19.01 to 18.67). This is because each hash table has $2^\tau$ buckets, so a larger $\tau$ exponentially increases the capacity and expressiveness of the tables.
  • However, Memory Size (the storage required by the Memory Layer Q) increases drastically with $\tau$ (2.1 MB to 53.5 MB). This is due to the $2^\tau$ factor in the space complexity of each hash table.
  • FLOPs decrease as $\tau$ grows, with diminishing returns. The Memory Layer costs $\mathcal{O}(\frac{sdh}{\tau})$, so a larger $\tau$ lowers FLOPs, as the progression from 0.14 G to 0.07 G to 0.06 G shows.
  • The paper concludes that $\tau = 8$ offers a good trade-off, balancing performance gains and FLOPs savings against a manageable memory footprint.
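The opposing trends in FLOPs and memory can be reproduced with a back-of-the-envelope calculation. The sketch below makes two assumptions that are ours rather than statements from the source: a sequence length of $s = 2048$ (as used for Table 1) and 2 bytes per stored value (16-bit storage). Under those assumptions the numbers line up with the first two rows of Table 3.

```python
def memory_layer_cost(s, tau, K, h, bytes_per_entry=2):
    """Rough FLOPs and table storage for one Memory Layer.

    The layer holds K tables, each with 2**tau buckets of h values.
    bytes_per_entry=2 (16-bit storage) is an assumption of this sketch.
    """
    flops = s * (tau + h) * K
    table_bytes = K * (2 ** tau) * h * bytes_per_entry
    return flops, table_bytes


s, h = 2048, 512  # assumed sequence length; output size equal to the hidden size
for tau, K in ((4, 128), (8, 64)):
    flops, table_bytes = memory_layer_cost(s, tau, K, h)
    print(f"tau={tau}, K={K}: {flops / 1e9:.2f} G FLOPs, {table_bytes / 1e6:.1f} MB")
```

Halving $K$ while doubling $\tau$ roughly halves the FLOPs, but the $2^\tau$ bucket count grows the storage eightfold, which is the trade-off the table illustrates.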

Larger Learning Rate (Table 4): This study investigates the effect of the learning rate (LR) on MemoryFormer-tiny's performance, specifically due to the sparse nature of gradients for the hash tables. Val. PPL is reported after 8000 training steps.

The following are the results from Table 4 of the original paper:

| LR | Val. PPL ↓ |
| --- | --- |
| 1e-3 | 19.86 |
| 2e-3 | 19.07 |
| 3e-3 | 18.82 |
| 4e-3 | 18.84 |

Analysis of Table 4:

  • The baseline learning rate for Pythia-70M is 1e-3.
  • Increasing the LR from 1e-3 to 3e-3 (3 times the baseline) significantly improves performance, reducing Val. PPL from 19.86 to 18.82.
  • However, further increasing the LR to 4e-3 slightly degrades performance (PPL increases to 18.84).
  • This supports the authors' conjecture that a larger learning rate helps to compensate for the sparsity of gradients in the hash tables, as many buckets might not be updated in every training step. 3e-3 is chosen as the optimal LR.

Expanding Bits in the Memory Block (Table 5): This study explores the impact of dimensionality expansion within the Memory Block (analogous to the expansion factor in FFNs). MemoryFormer-tiny with $d = 512$, $\tau = 8$, $K = 64$ serves as the baseline. $\tau'$ denotes the bit width of sub-vectors in the second Memory Layer after expansion.

The following are the results from Table 5 of the original paper:

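| #Expanding Bit | $\tau'$ | Val. PPL ↓ | Size of Hash Tables | Memory Size |
| --- | --- | --- | --- | --- |
| 0 | 8 | 19.89 | $\mathbf{T}_k^{M_1} \in \mathbb{R}^{256 \times 512}$, $\mathbf{T}_k^{M_2} \in \mathbb{R}^{256 \times 512}$ | 33.6 MB |
| 1 | 9 | 19.26 | $\mathbf{T}_k^{M_1} \in \mathbb{R}^{256 \times 576}$, $\mathbf{T}_k^{M_2} \in \mathbb{R}^{512 \times 512}$ | 52.4 MB |
| 2 | 10 | 18.82 | $\mathbf{T}_k^{M_1} \in \mathbb{R}^{256 \times 640}$, $\mathbf{T}_k^{M_2} \in \mathbb{R}^{1024 \times 512}$ | 88.1 MB |
| 3 | 11 | 18.54 | $\mathbf{T}_k^{M_1} \in \mathbb{R}^{256 \times 704}$, $\mathbf{T}_k^{M_2} \in \mathbb{R}^{2048 \times 512}$ | 157.3 MB |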

Analysis of Table 5:

  • As the number of expanding bits (and thus $\tau'$) increases, the Val. PPL consistently decreases (from 19.89 to 18.54), indicating improved model capacity and performance. This is consistent with the idea of expanding dimensionality in FFNs to increase expressiveness.
  • However, the Memory Size required by the Memory Block (specifically, the hash tables for the second Memory Layer, which scale with $2^{\tau'}$) increases exponentially (from 33.6 MB to 157.3 MB).
  • The authors choose 2 as the number of expanding bits (meaning $\tau' = 10$) as a trade-off, balancing performance gains with the exponential growth in memory consumption. This choice leads to a Val. PPL of 18.82.
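The table shapes and memory sizes in Table 5 follow from the expansion rule described for the Memory Block. The sketch below derives them for $d = 512$, $\tau = 8$; the 2 bytes per stored value is our assumption (16-bit storage), chosen because it reproduces the reported sizes, and the function name is hypothetical.

```python
def memory_block_tables(d, tau, expand_bits, bytes_per_entry=2):
    """Table shapes and total storage for a two-layer Memory Block.

    bytes_per_entry=2 (16-bit storage) is an assumption of this sketch.
    """
    K = d // tau                                   # number of tables per layer
    shape_1 = (2 ** tau, (tau + expand_bits) * K)  # first layer widens its output
    shape_2 = (2 ** (tau + expand_bits), d)        # second layer maps back to d
    entries = K * (shape_1[0] * shape_1[1] + shape_2[0] * shape_2[1])
    return shape_1, shape_2, entries * bytes_per_entry


for bits in range(4):
    shape_1, shape_2, size = memory_block_tables(d=512, tau=8, expand_bits=bits)
    print(bits, shape_1, shape_2, f"{size / 1e6:.1f} MB")
```

Each extra expanding bit doubles the number of rows in the second layer's tables, so that layer dominates the total as the expansion grows.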

Removing Non-linearity in the Memory Block (Table 6): This study verifies the design choice of omitting activation functions (like GeLU) between the two Memory Layers in a Memory Block. A MemoryFormer-tiny is trained with and without an extra GeLU layer.

The following are the results from Table 6 of the original paper:

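| Model | PIQA | WinoGrande | WSC | ARC-E | ARC-C | LogiQA | Avg. |
| --- | --- | --- | --- | --- | --- | --- | --- |
| MemoryFormer-tiny w/o GeLU | 0.602 | 0.522 | 0.375 | 0.437 | 0.228 | 0.260 | 0.404 |
| MemoryFormer-tiny w/ GeLU | 0.605 | 0.521 | 0.375 | 0.435 | 0.229 | 0.260 | 0.404 |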

Analysis of Table 6:

  • Adding an extra GeLU layer between the Memory Layers in the Memory Block results in nearly identical performance to the MemoryFormer without GeLU (average scores are both 0.404, with minor task-specific differences that are negligible).
  • This confirms the authors' hypothesis that the hashing operation itself provides sufficient non-linearity, making explicit activation functions redundant in the Memory Block. This simplifies the architecture without performance loss.

6.3. Visualization

Distribution of Hash Bucket (Figure 4): To ensure that the Memory Layer effectively utilizes its hash tables and provides diverse outputs, it is important that input sub-vectors $\mathbf{z}_k$ are distributed somewhat uniformly across the available hash buckets. A highly skewed distribution where only a few buckets are frequently accessed would limit the model's capacity. The authors visualize the frequency at which each bucket is retrieved within selected hash tables from the first building block of a MemoryFormer-tiny model, using 2048 sequences of 1024 tokens each. They examine the first and last tables of the Q, K, V projection layers and the two Memory Layers in the FFN module.

Figure 4: The frequency at which each bucket in the hash table is retrieved.

The visualization in Figure 4 shows that the number of times each bucket in the selected hash tables is retrieved by $\mathbf{z}_k$ is generally uniform. This indicates that the design, particularly the Norm() layer before hashing and the softmax-like weighting $p(\mathbf{z}_k)$, successfully encourages an even distribution of hash bucket accesses. This uniform distribution is crucial for maximizing the effective capacity of the Memory Layer and ensuring that MemoryFormer can learn a rich set of feature transformations.
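The link between zero-mean inputs and even bucket usage can be illustrated with synthetic data. The sketch below is not the paper's experiment: it uses random Gaussian vectors in place of real embeddings, and the bit ordering of the sign-based hash and its treatment of exact zeros are assumptions made for the example.

```python
import random


def sign_hash(sub_vector):
    """Map a sub-vector to a bucket index using the sign of each element.

    Non-negative elements contribute a 1 bit, negative elements a 0 bit.
    The bit ordering and the handling of exact zeros are assumptions.
    """
    index = 0
    for value in sub_vector:
        index = (index << 1) | (1 if value >= 0 else 0)
    return index


def bucket_counts(vectors, tau):
    """Count how often each of the 2**tau buckets is hit by the first sub-vector."""
    counts = [0] * (2 ** tau)
    for vec in vectors:
        counts[sign_hash(vec[:tau])] += 1
    return counts


rng = random.Random(0)
tau, n = 4, 16000
# Zero-mean inputs stand in for normalized embeddings.
zero_mean = [[rng.gauss(0.0, 1.0) for _ in range(tau)] for _ in range(n)]
# Shifting the mean makes positive signs, and hence one bucket, dominate.
shifted = [[v + 1.0 for v in vec] for vec in zero_mean]

counts = bucket_counts(zero_mean, tau)
shifted_counts = bucket_counts(shifted, tau)
print(min(counts), max(counts))
print(min(shifted_counts), max(shifted_counts))
```

With zero-mean inputs the 16 buckets are hit at similar rates, while the shifted inputs concentrate on the all-positive bucket, which is the kind of skew the Norm() layer is meant to prevent.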

7. Conclusion & Reflections

7.1. Conclusion Summary

This paper introduces MemoryFormer, a novel Transformer architecture that addresses the escalating computational complexity of LLMs by fundamentally rethinking the role of fully-connected (FC) layers. Unlike previous efforts that primarily focused on optimizing multi-head attention (MHA), MemoryFormer targets the FC layers, which are identified as the dominant source of FLOPs in most practical scenarios.

The core innovation is the Memory Layer, which replaces traditional matrix multiplication with a memory retrieval process based on locality-sensitive hashing (LSH). This layer constructs in-memory lookup tables of discrete vectors and uses a differentiable hash algorithm to retrieve and aggregate correlated vectors based on the input embedding. This process effectively estimates the result of matrix multiplication with significantly fewer FLOPs.

The MemoryFormer can be trained end-to-end from scratch. Extensive experiments on multiple NLP benchmarks demonstrate that MemoryFormer achieves substantial reductions in FLOPs (e.g., up to 70% reduction in total FLOPs for MemoryFormer-base compared to Pythia-410M), while consistently delivering comparable or even superior performance. This work presents a new and effective paradigm for Transformer efficiency, shifting computation from FLOPs to memory access.

7.2. Limitations & Future Work

The paper implicitly discusses some challenges and implications, which can be interpreted as limitations or considerations for future work:

  • Memory-Computation Trade-off: While MemoryFormer significantly reduces FLOPs, it does so by increasing memory footprint (for the hash tables). The exponential growth of Memory Size with the bit width $\tau$ (as seen in Table 3 and Table 5) is a critical trade-off that needs careful management. Choosing optimal $\tau$ and $K$ values is essential to balance performance and memory constraints.
  • Sparsity of Gradients: The authors explicitly mention that gradients to the hash tables are sparse (Eq. (12)), meaning some buckets might not get updated in every training step. This required increasing the learning rate (ablation in Table 4) to ensure effective training. Further research could explore more sophisticated gradient accumulation or optimization strategies for sparse hash table updates.
  • Hardware Implications: The paper points out that MemoryFormer could provide "guiding significance for the hardware design (e.g. bigger bus width and higher cache hit rate) of the next-generation parallel-computing platform." This implies that current hardware, optimized for matrix multiplication on GPUs, might not be ideally suited for MemoryFormer's memory-centric operations. Fully realizing MemoryFormer's potential might require specialized hardware or memory architectures.
  • Generalizability to other modalities: While demonstrated on NLP tasks, the applicability and optimal configuration of MemoryFormer for other modalities like computer vision or speech recognition would need further investigation.

7.3. Personal Insights & Critique

MemoryFormer offers a fresh and impactful perspective on Transformer efficiency. Most Transformer optimization efforts have focused on making attention cheaper or matrix multiplication faster. By targeting the fully-connected layers and replacing computation with intelligent memory retrieval, the paper introduces a fundamentally different approach.

Key Strengths:

  • Novelty: The core idea of replacing FC layers with differentiable LSH-based Memory Layers is highly innovative. It shifts the problem from "how to compute matrix products faster" to "how to retrieve appropriate outputs from memory effectively."
  • Significant FLOPs Reduction: The experimental results clearly show a massive reduction in FLOPs, which is crucial for the sustainability and accessibility of LLMs.
  • Performance Improvement: The fact that MemoryFormer often improves performance while drastically cutting FLOPs is particularly impressive. This suggests that the Memory Layer is not just an approximation but potentially a more robust or regularizing form of feature transformation for these tasks.
  • Differentiability and End-to-End Training: The design of $p(\mathbf{z}_k)$ to enable gradient flow to the input vector $\mathbf{x}$ is critical for practical usability and allows for full end-to-end training, which is a notable engineering and theoretical achievement for LSH-based methods.

Potential Issues/Areas for Improvement:

  • Actual Latency vs. FLOPs: While FLOPs are a good proxy for computation, they don't always directly translate to wall-clock latency. MemoryFormer might be memory-bound, meaning its performance could be limited by memory bandwidth and access times, especially with large hash tables and frequent lookups. Future work could benchmark actual inference speeds on various hardware platforms.

  • Memory Bandwidth Requirements: Shifting from FLOPs to memory access implies a greater reliance on memory bandwidth. Each token reads only $K$ rows per Memory Layer, but the tables themselves are much larger than the weight matrix they replace (Table 3 reports 16.8 MB for a single Memory Layer at $d = 512$, $\tau = 8$), and the frequent, scattered accesses across $K$ tables could strain memory bandwidth and lead to expensive cache misses. The trade-off is complex and depends heavily on hardware characteristics.

  • Temperature Parameter $t$: The temperature $t$ in $p(\mathbf{z}_k)$ is a hyper-parameter that could significantly influence the gradient flow and exploration of the hash tables. Its tuning and sensitivity analysis could be explored further.

  • Interpretation of Memory Layer: While FC layers have a clear interpretation as linear transformations followed by non-linearity, the Memory Layer functions more like a content-addressable memory. Deeper theoretical analysis or visualizations could shed more light on what the Memory Layer learns and how it represents features compared to traditional FC layers.

  • Scalability of $K$ and $\tau$ to Larger Models: The current ablation studies are on a MemoryFormer-tiny model. As models scale to hundreds of billions or trillions of parameters, finding the optimal $K$ and $\tau$ becomes even more crucial, and the memory constraints might become more challenging.

This paper is highly inspiring. It demonstrates that innovation in neural network architecture does not always have to be about new mathematical operations but can also involve rethinking how existing operations are performed and leveraging different hardware resources. The concept of memory-centric computation could be applied to other deep learning components beyond Transformers and FC layers, potentially leading to a new wave of efficient AI models that depend less on raw FLOPs and more on intelligent data management.
