论文状态:已完成

MemoryFormer : Minimize Transformer Computation by Removing Fully-Connected Layers

发表:2024/11/06
原文链接PDF 下载
价格:0.100000
价格:0.100000
已有 3 人读过
本分析由 AI 生成,可能不完全准确,请以原文为准。

TL;DR 精炼摘要

为了降低大型语言模型的计算复杂度,本文提出了一种新颖的Transformer架构MemoryFormer,通过移除大部分全连接层的计算,仅保留多头注意力操作所需的必要计算。利用内存查找表和哈希算法动态检索离散向量,从而显著减少了计算成本,并在多个基准测试中验证了模型的有效性。

摘要

In order to reduce the computational complexity of large language models, great efforts have been made to to improve the efficiency of transformer models such as linear attention and flash-attention. However, the model size and corresponding computational complexity are constantly scaled up in pursuit of higher performance. In this work, we present MemoryFormer, a novel transformer architecture which significantly reduces the computational complexity (FLOPs) from a new perspective. We eliminate nearly all the computations of the transformer model except for the necessary computation required by the multi-head attention operation. This is made possible by utilizing an alternative method for feature transformation to replace the linear projection of fully-connected layers. Specifically, we first construct a group of in-memory lookup tables that store a large amount of discrete vectors to replace the weight matrix used in linear projection. We then use a hash algorithm to retrieve a correlated subset of vectors dynamically based on the input embedding. The retrieved vectors combined together will form the output embedding, which provides an estimation of the result of matrix multiplication operation in a fully-connected layer. Compared to conducting matrix multiplication, retrieving data blocks from memory is a much cheaper operation which requires little computations. We train MemoryFormer from scratch and conduct extensive experiments on various benchmarks to demonstrate the effectiveness of the proposed model.

思维导图

论文精读

中文精读

1. 论文基本信息

1.1. 标题

MemoryFormer : Minimize Transformer Computation by Removing Fully-Connected Layers (MemoryFormer:通过移除全连接层最小化 Transformer 计算量)

1.2. 作者

Ning Ding, Yehui Tang, Haochen Qin, Zhenli Zhou, Chao Xu, Lin Li, Kai Han, Heng Liao, Yunhe Wang。 作者来自多个机构,包括:

  • State Key Lab of General AI, School of Intelligence Science and Technology, Peking University (北京大学通用人工智能国家重点实验室,智能科学与技术学院)
  • Huawei Noah's Ark Lab (华为诺亚方舟实验室)
  • Huawei HiSilicon (华为海思)

1.3. 发表期刊/会议

该论文发布于 OpenReview,发布日期为 2024-11-06T00:00:00.000Z。OpenReview 是一个在机器学习社区中广泛使用的预印本平台,其内容通常在正式会议或期刊发表前进行公开评审。

1.4. 发表年份

2024年

1.5. 摘要

为了降低大型语言模型 (LLMs) 的计算复杂度,研究人员在提高 Transformer 模型效率方面投入了大量精力,例如 线性注意力 (linear attention)FlashAttention。然而,为了追求更高的性能,模型规模和相应的计算复杂度仍在不断扩大。本文提出了 MemoryFormer,这是一种新颖的 Transformer 架构,它通过一个全新的视角显著降低了计算复杂度 (FLOPs)。除了 多头注意力 (multi-head attention) 操作所需的必要计算外,MemoryFormer 几乎消除了 Transformer 模型中的所有计算。这通过利用一种替代的特征转换方法来实现,该方法取代了 全连接层 (fully-connected layers) 中的线性投影。具体来说,首先构建了一组内存中的 查找表 (lookup tables),这些表存储了大量离散向量,以取代线性投影中使用的权重矩阵。然后,使用 哈希算法 (hash algorithm) 根据输入嵌入 (embedding) 动态检索出相关的向量子集。检索到的向量组合在一起形成输出嵌入,这提供了 全连接层 中矩阵乘法操作结果的估计。与执行矩阵乘法相比,从内存中检索数据块是一种计算成本低得多的操作,几乎不需要计算。本文从头开始训练 MemoryFormer,并在各种基准测试上进行了广泛实验,以证明所提出模型的有效性。

1.6. 原文链接

https://openreview.net/pdf?id=04EC4ZnZJj

2. 整体概括

2.1. 研究背景与动机

Transformer 模型在深度学习领域取得了显著成就,特别是在 自然语言处理 (NLP) 领域引发了革命,并推动了计算机视觉和语音识别等其他领域的模型架构创新。近年来,大型语言模型 (LLMs) 作为规模巨大的 Transformer 模型,受到了广泛关注。它们展现出的前所未有的 涌现能力 (emergent abilities)通用人工智能 (artificial general intelligence) 指明了潜在路径。

然而,模型规模的扩大不仅带来了更高的智能,也带来了巨大的计算资源消耗。日益增长的计算复杂度是当前阻碍 LLMs 应用和普及的主要障碍。研究社区为此付出了巨大努力,通过优化 Transformer 模型架构来提高效率。现有方法主要分为两类:

  1. 传统方法:模型剪枝 (model pruning)权重量化 (weight quantization),能在一定程度上降低 LLMs 的计算复杂度。

  2. Transformer 特有方法: 重新设计了作为序列建模关键的 自注意力机制 (self-attention mechanism)。这些方法通过使用 滑动窗口 (sliding-windows)核函数 (kernel function) 将计算复杂度从序列长度的二次方降低到亚二次方甚至线性,同时保持可比的性能。

    本文作者观察到以下两点,并以此作为主要动机:

  • 计算瓶颈: 在大多数应用场景中,多头注意力 (Multi-Head Attention, MHA) 操作只占计算复杂度的很小一部分,而 Transformer 模型中绝大部分计算量来自 全连接层 (Fully-Connected, FC)。具体而言,对于隐藏维度为 dd、序列长度为 ss 的标准 Transformer 模型,MHA 操作的浮点计算量是 2s2d2s^2d,而所有 FC 层 的计算量是 12sd212sd^2。只有当 s>6ds > 6d 时,MHA 的计算量才会占据主导地位。对于隐藏维度 d=4096d=4096LLM,这表示序列长度 ss 需要大于 24K。这表明 FC 层 在大多数实际场景中是更大的计算瓶颈。

  • 硬件资源利用不均衡: 目前深度神经网络的推理阶段主要依赖 图形处理单元 (GPU) 的并行计算核心,而计算机系统中的 中央处理器 (CPU)随机存取存储器 (RAM) 资源(通常达到太字节级别,CPU 核心数量众多)却几乎未被充分利用。例如,NVIDIA DGX A100 拥有 2TB 的 RAM 和 128 个 CPU 核心。此外,CPU 制造商也开始开发 张量核心 (tensor core) 以加速并行计算,这可能使得低延迟的 CPU 推理在未来变得可行。

    基于上述观察,本文旨在提出一种新的 Transformer 架构,即 MemoryFormer,从一个全新的角度最小化所需的计算复杂度,通过用内存操作替代计算密集型的 全连接层

2.2. 核心贡献/主要发现

本文最主要的贡献在于提出了一种新颖的 Transformer 架构 MemoryFormer,它通过用基于内存查找的 Memory Layer 替代传统的 全连接层,从根本上降低了模型的计算复杂度。具体的核心贡献和发现包括:

  • 提出 MemoryFormer 架构: 引入了 Memory Layer 来替代 Transformer 中的 全连接层,包括 MHA 前的线性投影和 前馈网络 (FFN) 中的 FC 层
  • 基于 Locality-Sensitive Hashing (LSH) 的特征转换: Memory Layer 通过构建一组内存中的 查找表 来存储大量离散向量。它使用一种简化的 LSH 算法(基于输入向量的符号化)动态检索相关的向量子集,并将这些向量加权聚合以形成输出嵌入,从而近似 矩阵乘法 (matrix multiplication) 的结果。
  • 可学习的内存表: 设计了一种方法,使得存储在 查找表 中的向量可以通过 反向传播 (back-propagation) 学习,从而实现端到端训练 MemoryFormer
  • 显著降低 FLOPs MemoryFormer 几乎消除了 Transformer 模型中除 自注意力 (self-attention) 操作外的所有计算。与传统 Transformer 相比,当序列长度 s=2048s=2048 和隐藏维度 d=2048d=2048 时,一个 MemoryFormer 块所需的 FLOPs 仅为基线 Transformer 块的约 19%。随着模型规模的扩大,FLOPs 减少效果将更加显著。
  • 性能与效率的权衡: 实验结果表明,MemoryFormer 在大幅减少计算量的同时,在多个 NLP 基准测试上取得了与基线 Transformer 模型相当甚至更好的平均准确率。这表明该方法能够在不牺牲性能的情况下提高效率。
  • 硬件设计指导意义: 本工作不仅提出了一种新的 FLOPs 减少策略,也为下一代并行计算平台的硬件设计(例如,更大的总线宽度和更高的缓存命中率)提供了指导意义。
  • 超越现有高效 Transformer 方法:LinformerCosformerPerformer 等主要关注优化 自注意力机制 的高效 Transformer 方法相比,MemoryFormer 通过解决 FC 层 的计算瓶颈,在 FLOPs 减少和性能之间取得了更优的平衡。

3. 预备知识与相关工作

3.1. 基础概念

3.1.1. Transformer 模型 (Transformer Model)

Transformer 是一种基于 自注意力机制 (self-attention mechanism) 的深度学习模型架构,由 Vaswani 等人于 2017 年提出。它彻底改变了 自然语言处理 (NLP) 领域,并被广泛应用于 大型语言模型 (LLMs)、计算机视觉和语音识别等任务。Transformer 模型的核心在于其能够并行处理序列数据,而不是像 循环神经网络 (RNN) 那样顺序处理,这极大地提高了训练效率和模型处理长距离依赖的能力。

Transformer 的基本结构由编码器 (Encoder) 和解码器 (Decoder) 组成,每个都包含多个相同的层。每个层又主要由两个子层构成:

  • 多头注意力 (Multi-Head Attention, MHA): 这是 Transformer 的核心组件。它允许模型同时关注输入序列的不同部分,并提取各种类型的信息。MHA 机制通过将输入查询 (Query, QQ)、键 (Key, KK) 和值 (Value, VV) 投影到不同的子空间,然后并行执行多个 注意力头 (attention heads),最后将它们的输出拼接 (concatenate) 起来。 缩放点积注意力 (Scaled Dot-Product Attention) 的计算公式如下: Attention(Q,K,V)=softmax(QKTdk)V \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V 其中,QRs×dkQ \in \mathbb{R}^{s \times d_k} 是查询矩阵,KRs×dkK \in \mathbb{R}^{s \times d_k} 是键矩阵,VRs×dvV \in \mathbb{R}^{s \times d_v} 是值矩阵,ss 是序列长度,dkd_k 是键和查询的维度,dvd_v 是值的维度。dk\sqrt{d_k} 用于缩放点积结果,以防止在 dkd_k 很大时 softmax 函数的梯度过小。

  • 前馈网络 (Feed-Forward Network, FFN): 这是一个简单的 全连接层 (Fully-Connected Layer, FC),它对 MHA 的输出进行非线性变换。FFN 通常由两个线性变换和一个 激活函数 (activation function) (如 ReLUGELU) 组成。在标准 Transformer 中,FFN 通常将隐藏维度 dd 扩展到 4d,然后再映射回 dd

    这两个子层都伴随着 残差连接 (residual connection)层归一化 (layer normalization)

3.1.2. 全连接层 (Fully-Connected Layer, FC)

在神经网络中,全连接层(也称为 密集层 (dense layer)线性层 (linear layer))是一种最基本的层。它将前一层的所有神经元的输出作为输入,并通过学习到的权重矩阵和偏置向量进行线性变换,然后通常会通过一个 激活函数。 对于一个输入向量 xRd\mathbf{x} \in \mathbb{R}^d全连接层 的输出 yRh\mathbf{y} \in \mathbb{R}^h 可以表示为: y=xW+b \mathbf{y} = \mathbf{x}\mathbf{W} + \mathbf{b} 其中,WRd×h\mathbf{W} \in \mathbb{R}^{d \times h} 是权重矩阵,bRh\mathbf{b} \in \mathbb{R}^h 是偏置向量。 在 Transformer 模型中,全连接层 广泛用于:

  • MHA 机制中将 Q, K, V 投影到不同子空间。

  • FFN 中的两个线性变换。

  • 模型输出层(如分类头)。

    全连接层 的主要计算开销来自于 矩阵乘法 xW\mathbf{x}\mathbf{W}。对于输入序列 XRs×d\mathbf{X} \in \mathbb{R}^{s \times d} 和权重矩阵 WRd×h\mathbf{W} \in \mathbb{R}^{d \times h},矩阵乘法 Y=XW\mathbf{Y} = \mathbf{X}\mathbf{W} 的计算复杂度为 O(sdh)\mathcal{O}(sdh)

3.1.3. 浮点运算 (Floating Point Operations, FLOPs)

FLOPs 是衡量模型计算复杂度的指标,指模型执行的浮点运算次数。通常用于评估模型的计算效率和推理速度。较低的 FLOPs 意味着更快的计算速度和更少的能源消耗。

3.1.4. 局部敏感哈希 (Locality-Sensitive Hashing, LSH)

局部敏感哈希 (LSH) 是一种将高维数据点映射到低维哈希值的技术,其核心思想是使相似的数据点以高概率映射到相同的哈希桶 (hash bucket),而不相似的数据点以低概率映射到相同的哈希桶。与传统的哈希函数旨在最小化冲突不同,LSH 故意利用冲突来识别相似项。 LSH 通常用于大规模数据集中的相似性搜索、聚类和近邻查找等任务。例如,在文本或图像检索系统中,LSH 可以将语义相似的文本或图像哈希到同一个桶中,从而提高检索效率。

3.2. 前人工作

3.2.1. 传统模型优化方法

  • 模型剪枝 (Model Pruning): 通过移除模型中不重要或冗余的权重、神经元或层来减小模型大小和计算量。例如,Optimal Brain Damage
  • 权重量化 (Weight Quantization): 将模型的浮点权重转换为低位宽的定点表示(如 8 比特整数),从而减少模型大小、内存占用和计算需求。例如,Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference

3.2.2. 高效自注意力机制 (Efficient Self-Attention Mechanisms)

这些方法主要关注降低 TransformerMHA 操作的二次方计算复杂度 O(s2d)\mathcal{O}(s^2d),使其变为亚二次方、线性或更高效:

  • Reformer [17]: 结合了 LSH 来近似 注意力 计算,以降低序列长度的二次方依赖。它将相似的 QueryKey 分组到相同的桶中,只对桶内的元素计算 注意力
  • YOSO [28]: 提出了一种基于伯努利采样的线性成本 自注意力 机制,通过只对部分 Key 进行采样来减少计算量。
  • Linformer [26]: 提出对 KeyValue 矩阵进行低秩分解,从而将 注意力 复杂度从二次方降低到线性。
  • CosFormer [22]: 通过用余弦函数替换 softmax 函数,并利用其性质来线性化 注意力 计算。
  • Performer [8]: 使用 正交随机特征 (orthogonal random features) 来近似 softmax 注意力,使其计算复杂度变为线性。
  • 滑动窗口 (Sliding Window) [14]: 限制每个 Token 只关注其局部窗口内的其他 Token,从而将注意力计算限制在一个局部范围内,降低全局计算量。
  • FlashAttention [10]: 是一种 I/O感知型 (I/O-aware) 的精确 注意力 算法,通过优化 GPU 内存访问模式(减少 HBM 读写,最大化 SRAM 使用),显著加速 自注意力 计算并减少内存占用,但其理论计算复杂度仍然是二次方的。

3.2.3. 基于 FC 层 稀疏性优化

一些研究利用 FFN (MLP) 模块中间激活的稀疏性来减少计算:

  • [13][20] 等工作探索了 FFN 模块中激活的稀疏性,以实现计算优化。
  • [27] 提出了 HiRE,利用高召回率近似 Top-K 估计来提高 LLM 推理的效率。
  • LookupFFN [29] 利用 LSH 来加速 前馈网络 的推理速度。

3.3. 技术演进与差异化分析

3.3.1. 技术演进

Transformer 模型的高效化研究主要沿着两个方向发展:一是优化 自注意力机制,二是优化 前馈网络 (FFN)。早期的工作大多集中在 自注意力,因为其二次方的计算复杂度在长序列上是显而易见的瓶颈。从 Reformer 引入 LSHLinformerPerformerCosFormer 等提出各种线性 注意力 变体,再到 FlashAttention 优化 硬件 I/O,这些都旨在解决 MHA 的效率问题。

然而,随着模型规模的不断扩大和实际应用中序列长度并非总是极端长(如本文分析的 s>6ds > 6d 的情况),全连接层 的计算开销逐渐凸显。一些工作开始关注 FFN 的优化,例如利用稀疏性或 LSH (LookupFFN)。

3.3.2. 差异化分析

本文提出的 MemoryFormer 与现有工作的主要区别和创新点在于其解决计算瓶颈的视角:

  • 关注点不同:
    • 现有高效 Transformer 方法: 大多数(如 LinformerCosformerPerformer 等)致力于优化 多头注意力 (MHA) 操作,将其复杂度从二次方降至线性或亚二次方。FlashAttention 则通过硬件优化提高 MHA 的实际运行效率。
    • MemoryFormer 专注于移除或替换 全连接层 (FC) 的计算。作者观察到,在大多数实际场景中,FC 层 的计算量是 Transformer 模型中最大的瓶颈(除非序列长度 ss 远大于隐藏维度 dd)。
  • 方法论不同:
    • 现有 MHA 优化: 通常通过改变 注意力 的计算方式(如 核函数低秩近似局部窗口LSH 采样)或优化 硬件 I/O 来实现。
    • MemoryFormer 提出用基于内存查找的 Memory Layer近似 全连接层矩阵乘法。它利用 LSH 将输入嵌入映射到内存中的离散向量,通过检索和聚合这些向量来产生输出。这种方法将计算密集型的 矩阵乘法 转换为内存访问和简单的加权求和,从而大幅降低了 FLOPs
  • 资源利用不同:
    • 现有方法: 主要关注 GPU 内部的计算和内存优化。

    • MemoryFormer 旨在利用通常未被充分利用的 CPURAM 资源,通过 内存查找表 来分担 GPU 的计算压力,从而实现计算效率的提升。这还为未来的 硬件设计 提供了新的思路。

      简而言之,虽然 ReformerLookupFFN 也使用了 LSH,但 Reformer 是用于优化 注意力,而 LookupFFN 虽然优化 FFN,但 MemoryFormerMemory Layer 设计更加深入,它不仅仅是加速查找,而是通过替代性的内存操作来完全规避了 全连接层矩阵乘法MemoryFormer 提供了解决 Transformer 效率问题的一个全新且互补的视角,与之前主要关注 MHA 的工作形成对比。

4. 方法论

本文的核心思想是提出 MemoryFormer 架构,通过用 Memory Layer 替代 Transformer 模型中的所有 全连接层 (FC layers),从而显著降低计算复杂度。Memory Layer 利用 局部敏感哈希 (LSH) 和内存中的 查找表 (lookup tables) 来近似 FC 层矩阵乘法 操作。

4.1. 方法原理

在标准的 Transformer 模型中,多头注意力 (MHA)全连接层 (FC layers) 是两种主要的特征转换操作。FC 层,无论是在 MHA 前的投影(Q, K, V 投影)还是在 前馈网络 (FFN) 中,都通过 矩阵乘法 实现特征的线性投影。 对于一个输入 token 嵌入 xRd\mathbf{x} \in \mathbb{R}^d,一个由权重矩阵 WRd×h\mathbf{W} \in \mathbb{R}^{d \times h} 参数化的 FC 层 会计算输出 y=xW\mathbf{y} = \mathbf{x}\mathbf{W},其中 yRh\mathbf{y} \in \mathbb{R}^h 是输出 token 嵌入。对于长度为 ss 的序列,这成为矩阵乘法 Y=XW\mathbf{Y} = \mathbf{X}\mathbf{W},计算复杂度为 O(sdh)\mathcal{O}(sdh)

作者观察到 FC 层 是一个连续线性算子:对于相邻的输入特征向量 x1\mathbf{x}_1x2\mathbf{x}_2,其投影结果 y1=x1W\mathbf{y}_1 = \mathbf{x}_1\mathbf{W}y2=x2W\mathbf{y}_2 = \mathbf{x}_2\mathbf{W} 也很可能是相似的。基于这一特性,MemoryFormer 旨在寻找一种替代的映射函数,其计算复杂度远低于 O(sdh)\mathcal{O}(sdh),但仍能大致符合线性投影的性质。如果能实现这一点,就可以用这种替代方法替换所有 FC 层,从而减少计算量。

4.2. 核心方法详解

4.2.1. Compute-less Locality-Sensitive Hashing

传统的 LSH 函数旨在将相似的项映射到相同的哈希桶。MemoryFormer 在嵌入空间中应用 LSH 来编码输入特征向量。其核心思想是,如果输入向量 x\mathbf{x}LSH 函数哈希到一个存储向量 y^\hat{\mathbf{y}} 的特定桶中,那么 x\mathbf{x} 的相邻向量也应该被哈希到相同或相似的桶,并检索到相似的 y^\hat{\mathbf{y}}。如果这个检索到的 y^\hat{\mathbf{y}} 是对线性操作 y=xW\mathbf{y} = \mathbf{x}\mathbf{W} 结果的近似,那么就可以用这种哈希查找的方法来替换 全连接层,而只消耗极少的 FLOPs 用于哈希操作。

1. 简化的 LSH 函数与哈希表构建 首先,构建一个内存中的哈希表 TR2d×h\mathbf{T} \in \mathbb{R}^{2^d \times h},存储 2d2^d 个向量 [T]iRh[\mathbf{T}]_i \in \mathbb{R}^h。为了避免不必要的计算,本文使用一个非常简单的 LSH 函数来生成哈希码 h(x)h(\mathbf{x})。 具体过程如下: 输入向量 xRd\mathbf{x} \in \mathbb{R}^d 经过 sign 函数进行二值化,生成一个二进制表示(哈希码)s{1,1}d\mathbf{s} \in \{-1, 1\}^d。然后,integer 函数将这个二进制表示转换为一个非负整数,作为哈希表的索引。 数学公式如下: h(x)=integer(sign(x))sign([x]i)={1,if[x]i<0,1,if[x]i0,integer(s)=i=0d1[s]i+122iy^=[T]h(x) \begin{array}{l} h ( \mathbf { x } ) = \mathrm { i n t e g e r } ( \mathrm { s i g n } ( \mathbf { x } ) ) \\ \mathrm { s i g n } ( [ \mathbf { x } ] _ { i } ) = \left\{ \begin{array} { l } { { \displaystyle - 1 , \mathrm { i f } [ \mathbf { x } ] _ { i } < 0 , } } \\ { { \displaystyle 1 , \mathrm { i f } [ \mathbf { x } ] _ { i } \ge 0 , } } \end{array} \right. \\ \mathrm { i n t e g e r } ( \mathbf { s } ) = \sum _ { i = 0 } ^ { d - 1 } \frac { [ \mathbf { s } ] _ { i } + 1 } { 2 } \cdot 2 ^ { i } \\ \hat { \mathbf { y } } = [ \mathbf { T } ] _ { h ( \mathbf { x } ) } \end{array} 其中:

  • h(x)h(\mathbf{x}):将输入向量 x\mathbf{x} 映射到的哈希索引。
  • sign([x]i)\mathrm{sign}([\mathbf{x}]_i):对输入向量 x\mathbf{x} 的每个分量 [x]i[ \mathbf{x} ] _ { i } 进行符号判断,小于 0 则为 -1,否则为 1。这生成了一个由 -1 和 1 组成的向量 s\mathbf{s}
  • integer(s)\mathrm{integer}(\mathbf{s}):将 s\mathbf{s} 转换为哈希表的索引。具体方法是,将 s\mathbf{s} 的每个分量 [s]i[ \mathbf { s } ] _ { i } 映射到 0 或 1 (通过 [s]i+12\frac{ [ \mathbf { s } ] _ { i } + 1 } { 2 }),然后将其视为一个二进制数的位,并计算其对应的整数值。
  • y^\hat{\mathbf{y}}:从哈希表 T\mathbf{T} 中检索到的输出向量,索引为 h(x)h(\mathbf{x})

2. 解决哈希表内存爆炸问题 上述哈希表的空间复杂度为 O(2dh)\mathcal{O}(2^d h)。当 d=512d=512 (隐藏维度) 且使用 float16 数据类型时,所需的内存空间将达到约 1014510^{145} TB,这在实际中是不可行的。

为了解决这个问题,本文提出将输入向量 xRd\mathbf{x} \in \mathbb{R}^d 均匀地分成 KK 个不重叠的块: zk=split(x,num_chunk=K),k=1,2,,K \mathbf { z } _ { k } = \mathrm { s p l i t } ( \mathbf { x } , \mathrm { n u m \_ c h u n k } = K ) , k = 1 , 2 , \cdots , K 其中,zkRτ\mathbf{z}_k \in \mathbb{R}^{\tau},且 τ=dK\tau = \frac{d}{K},其中 dd 必须能被 KK 整除。 对于每个子向量 zk\mathbf{z}_k,都对应一个独立的哈希表 TkR2τ×h\mathbf{T}_k \in \mathbb{R}^{2^\tau \times h}。 因此,最终的输出结果 y^\hat{\mathbf{y}} 是所有 KK 个子向量对应哈希表检索结果的简单求和: y^=k=1K[Tk]h(zk) \hat { \mathbf { y } } = \sum _ { k = 1 } ^ { K } \left[ \mathbf { T } _ { k } \right] _ { h ( \mathbf { z } _ { k } ) } 通过这种分块策略,哈希表的空间复杂度降低为 O(K2τh)\mathcal{O}(K 2^\tau h)。例如,当 d=h=512,τ=8,K=64d=h=512, \tau=8, K=64 且使用 float16 数据类型时,所有 KK 个哈希表所需的总存储空间约为 16 MB,这大大降低了内存需求。

4.2.2. Memory Layer

上述公式 (7) 能够模拟 全连接层 的前向传播,并且哈希表中存储的值可以通过 反向传播 进行更新。然而,输入向量 x\mathbf{x} 无法直接获得梯度,因为它被哈希为多个整数 h(zk)h(\mathbf{z}_k) 作为索引进行检索,这是一个不可微分的操作。

为了解决梯度传播问题并引入对输入向量的依赖,本文将公式 (7) 重新表述为: y=k=1Kp(zk)[Tk]h(zk) \mathbf { y } = \sum _ { k = 1 } ^ { K } p ( \mathbf { z } _ { k } ) \cdot [ \mathbf { T } _ { k } ] _ { h ( \mathbf { z } _ { k } ) } 这里引入了一个系数 p(zk)p(\mathbf{z}_k) 来加权每个检索到的项,其中 p(zk)p(\mathbf{z}_k) 是变量 zk\mathbf{z}_k 的函数。

1. 相似度量与概率计算 Memory Layer 观察到,即使符号相同,许多子向量 zk\mathbf{z}_k 仍然可以哈希到同一个桶中,但这些子向量与桶的代表性二进制向量(每个分量为 1 或 -1)之间的角度是不同的,这可以用余弦值 cos(zk,sign(zk))\cos(\mathbf{z}_k, \mathrm{sign}(\mathbf{z}_k)) 定义。为了衡量 zk\mathbf{z}_k 与其对应的哈希桶 h(zk)h(\mathbf{z}_k) 之间的相关性,本文使用一个考虑方向和幅度的 缩放余弦相似度 (scaled cosine similarity)sim(zk,h(zk))=zk2sign(zk)2cos(zk,sign(zk))=zk,sign(zk) \mathrm { s i m } ( \mathbf { z } _ { k } , h ( \mathbf { z } _ { k } ) ) = \| \mathbf { z } _ { k } \| _ { 2 } \cdot \| \mathrm { s i g n } ( \mathbf { z } _ { k } ) \| _ { 2 } \cdot \mathrm { c o s } ( \mathbf { z } _ { k } , \mathrm { s i g n } ( \mathbf { z } _ { k } ) ) = \langle \mathbf { z } _ { k } , \mathrm { s i g n } ( \mathbf { z } _ { k } ) \rangle 其中,,\langle \cdot, \cdot \rangle 表示两个向量的内积。

然后,考虑到 zk\mathbf{z}_k 在查找表 Tk\mathbf{T}_k 中可能哈希到的所有 2τ2^\tau 个桶,本文定义 zk\mathbf{z}_k 特定映射到第 h(zk)h(\mathbf{z}_k) 个哈希桶的概率为: p(zk)=exp[sim(zk,h(zk))/t]i=02τ1exp[sim(zk,i)/t]=exp[zk,sign(zk)/t]i=02τ1exp[zk,integerτ1(i)/t] p ( \mathbf { z } _ { k } ) = \frac { \exp [ \mathrm { sim } ( \mathbf { z } _ { k } , h ( \mathbf { z } _ { k } ) ) / t ] } { \sum _ { i = 0 } ^ { 2 ^ { \tau } - 1 } \exp [ \mathrm { sim } ( \mathbf { z } _ { k } , i ) / t ] } = \frac { \exp [ \langle \mathbf { z } _ { k } , \mathrm { s i g n } ( \mathbf { z } _ { k } ) \rangle / t ] } { \sum _ { i = 0 } ^ { 2 ^ { \tau } - 1 } \exp [ \langle \mathbf { z } _ { k } , \mathrm { i n t e g e r } _ { \tau } ^ { - 1 } ( i ) \rangle / t ] } 其中:

  • tt 是温度 (temperature) 超参数。

  • integerτ1(i){1,1}τ\mathrm{integer}_{\tau}^{-1}(i) \in \{-1, 1\}^\tau 是一个函数,将非负整数 0i<2τ0 \leq i < 2^\tau 映射到其对应的 τ\tau 比特二进制表示(其中 0 映射为 -1,1 映射为 1)。

  • ,integerτ1(i)\langle \cdot , \mathrm { i n t e g e r } _ { \tau } ^ { - 1 } ( i ) \rangle 运算符表示对 τ\tau 维向量进行元素级选择性符号翻转后的求和。

    通过推导,可以得到更简化的形式: zk,sign(zk)=i=0τ1[zk]i,i=0τ1exp[zk,integerτ1(i)]=i=0τ1[exp([zk]i)+exp([zk]i)],p(zk)=exp(i=0τ1[zk]i/t)i=0τ1[exp([zk]i/t)+exp([zk]i/t)]=1i=0τ1[1+exp(2[zk]i/t)]. \begin{array} { r l r } { \langle \mathbf { z } _ { k } , \mathrm { s i g n } ( \mathbf { z } _ { k } ) \rangle = \sum _ { i = 0 } ^ { \tau - 1 } | [ \mathbf { z } _ { k } ] _ { i } | , } \\ & { \sum _ { i = 0 } ^ { \tau - 1 } \exp [ \langle \mathbf { z } _ { k } , \mathrm { i n t e g e r } _ { \tau } ^ { - 1 } ( i ) \rangle ] = \prod _ { i = 0 } ^ { \tau - 1 } [ \exp ( [ \mathbf { z } _ { k } ] _ { i } ) + \exp ( - [ \mathbf { z } _ { k } ] _ { i } ) ] , } \\ & { p ( \mathbf { z } _ { k } ) = \frac { \exp ( \sum _ { i = 0 } ^ { \tau - 1 } | [ \mathbf { z } _ { k } ] _ { i } | / t ) } { \prod _ { i = 0 } ^ { \tau - 1 } [ \exp ( [ \mathbf { z } _ { k } ] _ { i } / t ) + \exp ( - [ \mathbf { z } _ { k } ] _ { i } / t ) ] } = \frac { 1 } { \prod _ { i = 0 } ^ { \tau - 1 } [ 1 + \exp ( - 2 | [ \mathbf { z } _ { k } ] _ { i } | / t ) ] } . } \end{array} 因此,Memory Layer 的最终公式为: y=k=1Kp(zk)[Tk]h(zk)=k=1K[Tk]h(zk)i=0τ1[1+exp(2[zk]i/t)] \mathbf { y } = \sum _ { k = 1 } ^ { K } p ( \mathbf { z } _ { k } ) \cdot [ \mathbf { T } _ { k } ] _ { h ( \mathbf { z } _ { k } ) } = \sum _ { k = 1 } ^ { K } { \frac { [ \mathbf { T } _ { k } ] _ { h ( \mathbf { z } _ { k } ) } } { \prod _ { i = 0 } ^ { \tau - 1 } [ 1 + \exp ( - 2 | [ \mathbf { z } _ { k } ] _ { i } | / t ) ] } }

2. 梯度反向传播 利用上述公式 (11),可以计算损失函数 LL 对哈希表和输入向量的导数,从而实现端到端训练: L[Tk]i={p(zk)Ly, if h(zk)=i,0, if h(zk)i, i{0,1,,2τ1},Lx=concat([.Ly[Tk]h(zk)p(zk)zkforkinrange(1,K+1)]). \begin{array} { r l } & { \displaystyle \frac { \partial L } { \partial [ \mathbf { T } _ { k } ] _ { i } } = \left\{ \begin{array} { c } { p ( \mathbf { z } _ { k } ) \frac { \partial L } { \partial \mathbf { y } } , \mathrm { ~ i f ~ } h ( \mathbf { z } _ { k } ) = i , } \\ { 0 , \mathrm { ~ i f ~ } h ( \mathbf { z } _ { k } ) \neq i , } \end{array} \right. \ i \in \{ 0 , 1 , \cdots , 2 ^ { \tau } - 1 \} , } \\ & { \displaystyle \frac { \partial L } { \partial \mathbf { x } } = \mathrm { c o n c a t } ( \left[ \mathbf { \cdot } \mathbf { . } \cdot \frac { \partial L } { \partial \mathbf { y } } [ \mathbf { T } _ { k } ] _ { h ( \mathbf { z } _ { k } ) } ^ { \top } \frac { \partial p ( \mathbf { z } _ { k } ) } { \partial \mathbf { z } _ { k } } \dots \mathrm { f o r } k \mathrm { i n } \mathrm { r a n g e } ( 1 , K + 1 ) \right] ) . } \end{array} 这使得 Memory Layer 中的哈希表项和输入向量都能够通过 反向传播 进行学习和更新。

4.2.3. MemoryFormer 架构

MemoryFormer 遵循标准 Transformer 架构的通用设计范式,通过堆叠 NN 个块构建。下图(原文 Figure 3)展示了一个 MemoryFormer 构建块:

Figure 3: Left: The schematic diagram of the Memory Layer. Right: One building block of the MemoryFormer. 该图像是示意图,左侧展示了memory layer的结构,而右侧展示了MemoryFormer的一个模块。左侧的计算流程包括三个处理单元,输入和输出经过加权和计算,并结合记忆块的操作,相关公式为 σ(QKT)Vσ(QK^T)V

4.2.3.1. 多头注意力 (Multi-Head Attention)

  • 输入处理: 给定输入序列 X=(x1,x2,,xs)Rs×d\mathbf{X} = ( \mathbf{x}_1, \mathbf{x}_2, \cdots, \mathbf{x}_s )^\top \in \mathbb{R}^{s \times d},首先通过一个 Norm(·) 层进行归一化。
  • 投影层: 三个 Memory Layer 分别将归一化的 X\mathbf{X} 转换为 Q,K,VRs×dQ, K, V \in \mathbb{R}^{s \times d}。这意味着传统的 FC 层 在这里被 Memory Layer 替代。
  • 注意力计算: Q, K, V 中的 token 随后被均匀分割成多个子向量,用于 多头 机制。多头注意力 的计算方式保持不变,与 Vaswani 等人 [25] 提出的方法一致。
  • 兼容性: 任何其他高效的 自注意力 技术,如 FlashAttention [10]、Linear Attention [16] 和 KV-Cache,都可以无缝集成到 MemoryFormer 中,以进一步提高前向传播效率。

4.2.3.2. Memory Block

  • 替代 FFN: MemoryFormer 使用 Memory Block 替代标准 Transformer 中的 前馈网络 (Feed-Forward Network, FFN)
  • 结构: Memory Block 由两个连续的 Memory Layer 组成,每个 Memory Layer 前面都有一个 Norm(·) 层。
  • 归一化的重要性: Norm(·) 层至关重要,它确保哈希操作前的输入嵌入具有零均值分布。这样,sign 函数(公式 (3))可以均匀地生成 -1 和 +1,从而使得公式 (4) 的输出具有均匀分布,确保哈希表中的每个桶都能以近似相等的概率被检索,从而使 Memory Layer 的输出空间足够多样化。
  • 非线性函数: Memory Block 中省略了中间的 激活函数 (如 ReLU, GELU)。这是因为哈希操作本身就是非线性的,额外的非线性函数是冗余的。实验也证实,在两个 Memory Layer 之间去除非线性函数对性能没有影响。
  • 维度扩展与恢复: 为了与传统 FFN 模块增加模型容量和性能的设计模式对齐(通常将第一个 FC 层 的输出维度扩展 4 倍,然后通过第二个 FC 层 恢复到隐藏维度 dd),MemoryFormer 做了类似的设计:
    • 第一个 Memory Layer 的输出维度设置为 (τ+2)K(\tau+2) \cdot K。回想 d=τKd=\tau K,这意味着每个哈希表的大小为 Tk1R2τ×(τ+2)K\mathbf{T}_k^1 \in \mathbb{R}^{2^\tau \times (\tau+2) \cdot K}
    • 因此,第二个 Memory Layer 中每个子向量 zk\mathbf{z}_k 的比特宽度将比第一层中的子向量多 2 比特。
    • 第二个 Memory Layer 中的哈希表大小为 Tk2R2(τ+2)×d\mathbf{T}_k^2 \in \mathbb{R}^{2^{(\tau+2)} \times d},这使其容量比第一层大 4 倍,同时将输出嵌入的维度恢复到 dd

4.2.3.3. 计算流程概览

一个 MemoryFormer 块的计算过程如下: X=Norm(X)Q=MemoryLayerQ(X),K=MemoryLayerK(X),V=MemoryLayerV(X)Z=X+MultiHeadAttention(Q,K,V)Y=Z+MemoryLayer2(Norm(MemoryLayer1(Norm(X)))). \mathbf { X } = \mathrm { N o r m } ( \mathbf { X } ) \\ \mathrm { \mathbf { Q } } = \mathrm { MemoryLayer } _ { Q } ( \mathbf { X } ) , \mathrm { \mathbf { K } } = \mathrm { MemoryLayer } _ { K } ( \mathbf { X } ) , \mathrm { \mathbf { V } } = \mathrm { MemoryLayer } _ { V } ( \mathbf { X } ) \\ \mathbf { Z } = \mathbf { X } + \mathrm { MultiHeadAttention } ( \mathbf { Q } , \mathbf { K } , \mathbf { V } ) \\ \mathbf { Y } = \mathbf { Z } + \mathrm { MemoryLayer } _ { 2 } ( \mathrm { Norm } ( \mathrm { MemoryLayer } _ { 1 } ( \mathrm { Norm } ( \mathbf { X } ) ) ) ) .

4.2.3.4. 计算复杂度

一个标准 Transformer 块的浮点计算量是 2s2d+12sd22s^2d + 12sd^2。 而一个 MemoryFormer 块的计算量约为 2s2d+6τsd2=2s2d+6Ksd2s^2d + \frac{6}{\tau}sd^2 = 2s^2d + 6Ksd。 这表明 MemoryFormer 消除了 FC 层 产生的计算量,将其降低了一个数量级。现在,绝大部分计算负载来自于 多头注意力 操作。

5. 实验设置

5.1. 数据集

为了进行公平比较,本文使用 Pythia [3],一个开发完善的 LLM 训练框架,它提供了完全可用的数据集和详细的模型超参数。

  • 训练数据: The Pile [12] 数据集。

    • 规模: 包含 825 GiB (千兆字节) 的语料库。
    • 特点: 由 22 个多样化的高质量子集组成,这些子集要么是预先存在的,要么是从专业和学术领域构建的。
    • 选择原因: 该数据集全面且多样化,适合训练 LLM,且 Pythia 框架基于此训练,方便进行基线对比。
  • 评估任务: 选择了六个广泛使用的评估任务,涵盖了知识和推理能力,形成了评估 LLM 综合能力的全面基准。

    • PIQA [4]: Physics Intuition Question Answering (物理直觉问答)。评估模型对物理常识的推理能力。
    • WinoGrande [23, 24]: 旨在解决 Winograd Schema Challenge 的大规模对抗性数据集。评估模型在指代消解和常识推理方面的能力。
    • WSC [24]: Winograd Schema Challenge。一个经典的常识推理任务,需要模型理解歧义句子中的代词指代。
    • ARC-E (ARC-Easy) [9]: AI2 Reasoning Challenge - Easy (AI2 推理挑战 - 简单)。科学问答数据集,测试模型对科学知识的理解和推理。
    • ARC-C (ARC-Challenge) [9]: AI2 Reasoning Challenge - Challenge (AI2 推理挑战 - 挑战)。与 ARC-E 类似,但问题更具挑战性,通常需要更深层次的推理。
    • LogiQA [19]: 一个具有逻辑推理能力的机器阅读理解挑战数据集。评估模型在逻辑推理方面的能力。

5.2. 评估指标

本文主要关注模型在不同任务上的准确率以及 FLOPs

  • 准确率 (Accuracy):

    • 概念定义: 准确率是衡量分类模型性能的常用指标,表示模型正确预测的样本数量占总样本数量的比例。它量化了模型在所有类别上分类的正确性。
    • 数学公式: Accuracy=Number of Correct PredictionsTotal Number of Predictions \text{Accuracy} = \frac{\text{Number of Correct Predictions}}{\text{Total Number of Predictions}}
    • 符号解释:
      • Number of Correct Predictions: 模型做出正确预测的样本数量。
      • Total Number of Predictions: 模型进行预测的总样本数量。
  • FLOPs (Floating Point Operations):

    • 概念定义: FLOPs 是衡量计算复杂度的指标,指模型在执行一次前向传播或特定操作时所需的浮点运算次数。它量化了模型在计算上的开销,通常用于比较不同模型的计算效率。较低的 FLOPs 意味着模型运行更快、能耗更低。
    • 数学公式: FLOPs 的计算通常针对特定操作或模型层,例如:
      • 矩阵乘法 XW\mathbf{X} \mathbf{W} (其中 XRs×d,WRd×h\mathbf{X} \in \mathbb{R}^{s \times d}, \mathbf{W} \in \mathbb{R}^{d \times h}) 的 FLOPs 约为 2×s×d×h2 \times s \times d \times h (每次乘法和加法算作 2 次浮点运算)。
      • Transformer 块的 FLOPs 是由其组成部分的 FLOPs 之和。
    • 符号解释:
      • ss: 序列长度。
      • dd: 隐藏维度 (hidden size)。
      • hh: 输出维度。
      • 具体数值根据计算类型(乘法、加法等)和实现细节而定。
  • PPL (Perplexity):

    • 概念定义: 困惑度是评估语言模型性能的常用指标,尤其是在语言建模任务中。它量化了模型预测下一个 token 的不确定性。一个较低的困惑度表示模型对测试数据有更高的置信度,即模型生成与测试数据相似的文本的能力更强。困惑度可以理解为模型在预测序列中每个词时平均需要猜测多少个词。
    • 数学公式: PPL=exp(1Ni=1NlogP(wiw1,,wi1)) \text{PPL} = \exp\left(-\frac{1}{N} \sum_{i=1}^{N} \log P(w_i | w_1, \dots, w_{i-1})\right) 或更普遍地,对于一个给定长度为 NN 的序列 W=(w1,w2,,wN)W = (w_1, w_2, \dots, w_N),其困惑度计算为: PPL(W)=(i=1N1P(wiw1,,wi1))1N=exp(1NlogP(W)) \text{PPL}(W) = \left( \prod_{i=1}^{N} \frac{1}{P(w_i | w_1, \dots, w_{i-1})} \right)^{\frac{1}{N}} = \exp\left( - \frac{1}{N} \log P(W) \right) 其中,P(W) = \prod_{i=1}^{N} P(w_i | w_1, \dots, w_{i-1})
    • 符号解释:
      • NN: 序列中 token 的总数。
      • wiw_i: 序列中的第 iitoken
      • P(wiw1,,wi1)P(w_i | w_1, \dots, w_{i-1}): 在给定前 i-1token 的情况下,模型预测第 iitoken 的概率。
      • logP(W)\log P(W): 序列 WW 的对数概率。

5.3. 对比基线

本文主要将 MemoryFormer 与以下模型进行比较:

  • Pythia 系列模型 [3]:
    • Pythia-70M (70 Million 参数)
    • Pythia-160M (160 Million 参数)
    • Pythia-410M (410 Million 参数) 这些是完全开源的 LLM,提供了详细的训练信息和检查点,便于进行公平的复现和对比。MemoryFormer 在构建时,保持了与相应 Pythia 模型相同的隐藏维度 (hidden size) 和层数。
  • 其他高效 Transformer 方法:
    • Linformer [26]
    • Cosformer [22]
    • Performer [8] 这些方法主要通过修改 自注意力机制 来提高效率。本文将 Pythia-410M多头注意力 模块替换为这些高效 注意力 模块的实现,以进行对比。

超参数设置:

  • Memory Layer 超参数:
    • τ\tau (每个子向量的比特宽度) 固定为 8。
    • KK (哈希表的数量) 根据模型大小而变化:
      • MemoryFormer-tiny (Pythia-70M 对应): K=64K=64
      • MemoryFormer-small (Pythia-160M 对应): K=96K=96
      • MemoryFormer-base (Pythia-410M 对应): K=128K=128
  • 学习率: 考虑到哈希表梯度的稀疏性,MemoryFormer 的学习率设置为对应 Pythia 模型基线学习率的 3 倍。
  • 唯一 FC 层 MemoryFormer 中唯一保留的 全连接层 是分类器头 (classifier head)。

6. 实验结果与分析

6.1. 核心结果分析

6.1.1. 不同规模模型下的性能与效率

本文在不同规模的 Pythia 模型(70M、160M、410M)上构建了相应的 MemoryFormer 模型(MF-tinyMF-smallMF-base),并在六个 NLP 基准任务上进行了零样本评估。

以下是原文 Table 1 的结果: 以下是原文 Table 1 的结果:

ModelPythia-70MMF-tinyPythia-160MMF-smallPythia-410MMF-base
Layers6612122424
Hidden Size51251276876810241024
FLOPs w/o Attn.6.4 G0.4 G14.5 G1.0 G25.8 G1.6 G
Total FLOPs10.7 G4.7 G20.9 G7.4 G34.4 G10.2 G
PIQA0.5850.6020.6180.6420.6750.698
WinoGrande0.5110.5220.4970.5230.5340.546
WSC0.3650.3750.3650.3940.4710.385
ARC-E0.3800.4370.4400.4610.5170.585
ARC-C0.1770.2280.2010.2470.2020.259
LogiQA0.2320.2600.2100.2720.2090.272
Avg.0.3750.4040.3890.4230.4350.458

分析:

  • 计算量显著减少:
    • FLOPs w/o Attn. (不含 注意力FLOPs): MemoryFormer 模型(MF-tinyMF-smallMF-base)相比对应的 Pythia 基线模型,这一部分的 FLOPs 减少了一个数量级。例如,Pythia-70M 为 6.4 G FLOPs,而 MF-tiny 仅为 0.4 G FLOPs。这直接证明了 Memory Layer 替代 全连接层 的巨大效率提升。
    • Total FLOPs (总 FLOPs): MemoryFormer 的总 FLOPs 也大幅下降。例如,Pythia-70M 为 10.7 G FLOPsMF-tiny 为 4.7 G FLOPs,减少了约 56%。Pythia-410M 为 34.4 G FLOPsMF-base 为 10.2 G FLOPs,减少了约 70%。这表明 MemoryFormer 成功地将 全连接层 的计算瓶颈转化为内存查找,使得 MHA 成为主要的计算开销。
  • 性能提升: 在所有三种模型规模下,MemoryFormer 在所有基准测试任务上的平均准确率都优于 Pythia 基线。例如,Pythia-70M 的平均准确率为 0.375,而 MF-tiny 为 0.404。Pythia-410M 为 0.435,MF-base 为 0.458。这表明 MemoryFormer 在大幅降低计算量的同时,不仅没有牺牲性能,反而获得了更好的性能。

6.1.2. 与现有高效 Transformer 方法的比较

本文将 MemoryFormerLinformerCosformerPerformer 等现有高效 Transformer 方法进行了比较,这些方法主要侧重于优化 自注意力机制

以下是原文 Table 2 的结果: 以下是原文 Table 2 的结果:

ModelPythia-410MLinformercosFormerPerformerMemoryFormer-base
FLOPs34.4 G26.1G30.0 G26.710.2 G
PIQA0.6750.5270.5220.6430.698
WinoGrande0.5340.5110.5060.4960.546
WSC0.4710.6350.6050.4330.385
ARC-E0.5170.2650.2670.4700.585
ARC-C0.2020.2440.2630.2310.259
LogiQA0.2090.2070.2640.2360.272
Avg.0.4350.3980.4050.4180.458

分析:

  • FLOPs 优势: MemoryFormer-baseFLOPs (10.2 G) 远低于 Pythia-410M (34.4 G) 和其他高效 Transformer 方法 (Linformer 26.1 G, cosFormer 30.0 G, Performer 26.7 G)。这再次强调了 MemoryFormer 在计算效率上的显著优势。
  • 性能优势: MemoryFormer-base 的平均准确率 (0.458) 不仅高于 Pythia-410M (0.435),也高于所有比较的高效 Transformer 方法。这些高效 注意力 方法虽然实现了 FLOPs 减少,但通常伴随着显著的性能下降(例如,LinformercosFormer 的平均准确率都低于基线 Pythia-410M)。MemoryFormer 则能够在大幅减少计算的同时,保持甚至提升性能。
  • 新的优化方向: 实验结果证实,即使现有方法能够减少 自注意力 的计算成本,但在大多数实际场景中,全连接层 仍然是主要的计算瓶颈。MemoryFormer 通过在嵌入空间中利用 Memory Layer 替代 FC 层,提供了一个最小化 LLM FLOPs 的新解决方案。

6.1.3. 哈希桶分布可视化

下图(原文 Figure 4)展示了哈希表中的每个桶被检索的频率分布。

Figure 4: The frequency at which each bucket in the hash table is retrieved. 分析:

理想情况下,在 Memory Layer 中,希望所有子向量 zk\mathbf{z}_k 都能均匀地哈希到各个桶中,而不是集中到少数热门桶,这样可以最大化 Memory Layer 的输出空间多样性和模型容量。图中展示了 MemoryFormer-tiny 模型中,Q, K, V 投影层以及 FFN 模块中两个层的哈希桶检索频率分布。结果显示,每个桶被哈希到的次数通常是均匀的,这表明 MemoryFormer 能够有效地利用哈希表的容量。

6.2. 消融实验/参数分析

6.2.1. τ\tauKK 的权衡

本文研究了不同 (τ,K)(\tau, K) 组合对模型性能和效率的影响,其中 τ\tau 是每个子向量的比特宽度,KK 是哈希表的数量。模型隐藏维度保持不变。

以下是原文 Table 3 的结果: 以下是原文 Table 3 的结果:

dτKVal. PPL↓FLOPsMemory Size
512412819.010.14 G2.1 MB
51286418.820.07 G16.8 MB
510105118.670.06 G53.5 MB

分析:

  • τ\tau 越大,性能越好: 随着 τ\tau(子向量比特宽度)的增加,哈希表的容量呈指数级增长 (2τ2^\tau),模型性能持续提升,困惑度 (Val. PPL) 持续下降。例如,从 τ=4\tau=4 (PPL 19.01) 到 τ=8\tau=8 (PPL 18.82) 再到 τ=10\tau=10 (PPL 18.67)。
  • 内存消耗与计算复杂度的权衡: 增大 τ\tau 会急剧增加 Memory Layer 的内存消耗。同时,FLOPs 也会受到影响。例如,当 τ=4,K=128\tau=4, K=128 时,内存大小为 2.1 MB,FLOPs 为 0.14 G。当 τ=8,K=64\tau=8, K=64 时,内存大小增至 16.8 MB,但 FLOPs 降至 0.07 G。当 τ=10,K=51\tau=10, K=51 时,内存大小进一步增至 53.5 MB,FLOPs 降至 0.06 G。
  • 最佳选择: 为了在效率和内存使用之间取得平衡,本文认为 τ=8\tau=8 是一个较好的选择,因为它在合理的内存开销下实现了显著的性能提升和 FLOPs 降低。

6.2.2. 更大的学习率

本文推测由于哈希表梯度的稀疏性,一些桶在训练过程中可能无法及时更新,因此更大的学习率可能有助于弥补梯度不足的情况。

以下是原文 Table 4 的结果: 以下是原文 Table 4 的结果:

LRVal. PPL ↓
1e-319.86
2e-319.07
3e-318.82
4e-318.84

分析:

  • 学习率对性能的影响: 实验结果表明,使用 MemoryFormer-tiny 进行 8000 步训练,当学习率增加到基线学习率 (1e-3) 的 3 倍 (3e-3) 时,困惑度 (Val. PPL) 达到最佳值 18.82。过高的学习率 (如 4e-3) 则会导致性能略微下降。
  • 验证假设: 这证实了作者关于稀疏梯度需要更大学习率的假设,较大的学习率有助于哈希表的有效更新。

6.2.3. Memory Block 中的扩展比特

为了增加模型容量,Memory Block 的第一个 Memory Layer 的输出嵌入维度被扩展,这导致第二个 Memory Layer 中的子向量 zk\mathbf{z}_k 的比特宽度增加。

以下是原文 Table 5 的结果: 以下是原文 Table 5 的结果:

#Expanding Bitτ′Val. PPL ↓Size of Hash TablesMemory Size
0819.89TM 1 R256×512 TM 2 R256×51233.6 MB
1919.26k k R256×576' R512×51252.4 MB
21018.82k E R256×640' TM 2 R1024×512 k88.1 MB
31118.54R256×704' T M R2048×512 , k157.3 MB

分析:

  • 扩展比特与性能: 随着 Memory Block 中扩展比特数 (#Expanding Bit) 的增加,第二个 Memory Layer 的子向量比特宽度 τ\tau' 增加,模型的验证困惑度 (Val. PPL) 持续下降,表明模型容量和性能得到提升。例如,从 0 扩展比特 (PPL 19.89) 到 3 扩展比特 (PPL 18.54)。
  • 内存消耗: 尽管性能提升,哈希表消耗的内存空间呈指数级增长。从 0 扩展比特的 33.6 MB 增加到 3 扩展比特的 157.3 MB。
  • 权衡选择: 考虑到空间复杂度和性能之间的权衡,本文选择 2 作为 Memory Block 中的扩展比特数。

6.2.4. 移除 Memory Block 中的非线性函数

本文在方法论中提到 MemoryFormer 移除了所有的 激活函数。通过实验验证在 Memory Block 的两个连续 Memory Layer 之间插入 GeLU 层的效果。

以下是原文 Table 6 的结果: 以下是原文 Table 6 的结果:

ModelPIQAWinoGrandeWSCARC-EARC-CLogiQAAvg.
MF-tiny0.6020.5220.3750.4370.2280.2600.404
MF-tiny w/ GeLU0.6010.5230.3750.4360.2270.2600.404

分析:

  • 非线性的冗余性: 实验结果显示,在 Memory Block 的两个 Memory Layer 之间添加 GeLU 层,几乎得到了与基线 MemoryFormer 相同的测试分数。
  • 验证设计: 这证实了哈希操作本身已经引入了足够的非线性,额外的非线性 激活函数 是冗余的,并且移除它们并不会影响模型的性能。

7. 总结与思考

7.1. 结论总结

本文提出了 MemoryFormer,一种新颖的 Transformer 架构,旨在从一个全新的视角显著降低模型的计算复杂度 (FLOPs)。与现有主要优化 多头注意力 (MHA) 操作的方法不同,MemoryFormer 关注于移除 Transformer 模型中计算量最大的 全连接层 (FC layers)。通过引入 Memory Layer,它利用 局部敏感哈希 (LSH) 和内存中的 查找表 来替代计算密集型的 矩阵乘法 操作。实验结果表明,MemoryFormer 能够在显著减少计算量(一个 MemoryFormer 块的 FLOPs 相比基线 Transformer 减少高达 70%)的同时,在多个 NLP 基准测试上保持甚至超越基线 Transformer 的性能。这一成果不仅为 LLM 的效率优化提供了一个新的方向,也为未来并行计算平台的硬件设计提供了潜在的指导意义,即更好地利用 CPURAM 资源来分担 GPU 的计算负担。

7.2. 局限性与未来工作

论文中并没有明确指出“局限性”部分,但可以从其设计和实验中推断一些潜在的局限性和未来工作:

  • 内存消耗: 尽管通过分块策略将哈希表的内存消耗从 O(2dh)O(2^d h) 降低到 O(K2τh)O(K 2^\tau h),但当 τ\tau 增大时,内存消耗仍然呈指数级增长。例如,在消融实验中,从 τ=8\tau=8τ=11\tau=11,内存消耗从 16.8 MB 增加到 157.3 MB。对于更大的模型或更精细的粒度(更大的 τ\tau),内存管理仍可能是一个挑战。如何进一步优化内存效率或引入更复杂的 LSH 变体以在更小的 τ\tau 下实现更好性能,是可能的未来方向。
  • 哈希表的训练稳定性: 论文提到哈希表的梯度是稀疏的,因此需要更大的学习率。这表明训练过程可能对超参数(如学习率)比较敏感。未来的工作可以探索更鲁棒的训练策略或梯度传播机制,以确保哈希表参数的有效和稳定更新。
  • 泛化性与复杂哈希函数: 本文使用了非常简化的 LSH 函数(基于 sign 函数)。虽然这种简单性有助于降低计算量,但其捕获输入向量之间相似性的能力可能有限。更复杂的 LSH 函数(如果能在不增加显著计算开销的情况下实现)可能带来更好的性能,但这也需要仔细权衡。
  • 推理延迟与硬件适配: 尽管 FLOPs 显著降低,但实际的推理延迟可能受限于内存访问速度和 CPU/GPU 之间的 I/O 带宽。论文提到这为硬件设计提供了指导意义(如更大的总线宽度和更高的缓存命中率),这暗示了当前硬件可能并非为这种计算模式最优设计。未来的工作可能需要更深入地研究 MemoryFormer 在不同硬件上的实际性能瓶颈,并推动定制化硬件或更优化的软件实现。
  • 温度超参数 tt 的选择: 概率 p(zk)p(\mathbf{z}_k) 的计算依赖于温度超参数 tt。论文中并未详细探讨 tt 的选择及其对模型性能的影响。未来的工作可以对 tt 进行更深入的消融研究或设计自适应的 tt
  • 仅替换 FC Layer 尽管本文的创新点在于替换 FC 层,但 MHA 仍然是计算负载的主要来源。将 MemoryFormer 的思想与现有的高效 MHA 技术(如 FlashAttentionLinear Attention 等)结合,可能会带来进一步的效率提升。

7.3. 个人启发与批判

7.3.1. 个人启发

这篇论文提供了一个非常新颖且有前瞻性的视角来优化 Transformer 模型,它不仅仅局限于 自注意力机制 的优化,而是将目光投向了长期被视为“理所当然”但实际计算量巨大的 全连接层

  • 计算瓶颈的重新思考: 论文的第一个核心观察——全连接层 在大多数实际场景下是 Transformer 的主要计算瓶颈——是一个深刻的见解。这提醒研究者不要被传统的思维定势所限制,要不断审视模型的各个组成部分,找到真正的效率瓶颈。
  • 内存-计算权衡的极致利用: MemoryFormer 极致地利用了内存和计算之间的权衡。通过将计算密集型的 矩阵乘法 转换为成本更低的内存查找和聚合,它为“以空间换时间”提供了一个非常优雅且高效的实例。这启发我们,在设计深度学习模型时,可以更大胆地探索如何将计算任务转化为内存操作,尤其是在现代硬件 RAM 容量不断增大的背景下。
  • 非线性操作的创新: 传统的神经网络依赖于线性层和激活函数交替实现非线性。MemoryFormer 证明了 LSH 这样的离散查找操作本身就能提供足够的非线性,甚至可以替代传统的激活函数。这为设计更高效、更具解释性的神经网络结构开辟了新的道路。
  • 硬件-算法协同设计的潜力: 论文提到其工作对硬件设计具有指导意义,这强调了算法创新与硬件发展之间的紧密联系。未来的深度学习系统可能不仅仅是提高 GPU 的计算能力,而是会更深入地集成 CPURAM 等多模态计算资源,并为 MemoryFormer 这种内存密集型模型定制更优化的数据路径和内存访问机制。
  • 从稀疏性到性能的洞察: 梯度稀疏性导致需要更大学习率的发现,也提供了一个实用的训练技巧,这对于处理具有离散或稀疏更新机制的模型具有通用指导意义。

7.3.2. 批判

  • 实际推理延迟的衡量: 尽管 FLOPs 大幅降低,但 FLOPs 只是理论计算量。实际的推理速度(latency)会受到内存带宽、缓存命中率、CPU-GPU 协同效率等因素的严重影响。论文虽然讨论了 FLOPs,但缺乏对实际 吞吐量 (throughput)延迟 的直接衡量,特别是在不同硬件配置下。例如,哈希表访问可能导致缓存未命中,从而引入显著的延迟。一个全面的评估应该包括在主流推理硬件(如 GPUCPU)上的实际 端到端 (end-to-end) 速度测试。

  • “计算少”与“实现简单”的差异: 论文强调“检索数据块是一个便宜的操作”。这在理论上是正确的,但实际实现时,尤其是在 GPU 上,如何高效地进行哈希计算、如何管理和访问 KK 个大小为 2τ×h2^\tau \times h 的哈希表,如何处理 Memory Layer 的梯度反向传播(涉及 concat 和复杂的导数链式法则),都可能比一个高度优化的 cuBLAS 矩阵乘法库更具挑战性。其代码实现和运行时效率需要进一步的验证和优化。

  • 哈希冲突的鲁棒性: LSH 的本质是允许相似项发生冲突。然而,如果哈希冲突过于频繁或将不相似的项映射到一起,可能会损害模型的表达能力。虽然论文通过 p(zk)p(\mathbf{z}_k) 进行加权,但在极端情况下,哈希函数的“粗粒度”是否会成为模型学习复杂模式的限制?

  • 超参数 KKτ\tau 的选取: 论文中的消融实验表明 KKτ\tau 对性能和内存有显著影响,但其选取仍然是基于启发式和实验权衡。是否存在更系统或自适应的方法来确定这些参数,尤其是在模型规模进一步扩大时,是一个值得探讨的问题。

  • 可解释性: 尽管 LSH 方法本身具有一定的可解释性(通过查看哈希桶内容),但 MemoryFormer 模型的整体可解释性如何?Memory Layer 如何学习并近似 全连接层 的复杂映射?这需要更深入的分析工具来理解。

    总的来说,MemoryFormer 是一项富有想象力的工作,它挑战了 Transformer 架构中的传统假设,并为 LLM 的效率问题提供了一个大胆的新解决方案。其核心思想具有巨大的潜力,但也需要在实际部署、软件优化和理论分析方面进行更多深入的研究。

相似论文推荐

基于向量语义检索推荐的相关论文。

暂时没有找到相似论文。