论文状态:已完成

Memformer: A Memory-Augmented Transformer for Sequence Modeling

发表:2020/10/14
原文链接PDF 下载
价格:0.100000
价格:0.100000
已有 3 人读过
本分析由 AI 生成,可能不完全准确,请以原文为准。

TL;DR 精炼摘要

Memformer是一种记忆增强型Transformer,针对标准Transformer在长序列建模中的效率问题,通过外部动态记忆提高了信息编码与检索的能力。该模型实现了线性时间复杂度与常数空间复杂度,并通过记忆回放反向传播优化了内存需求。实验结果表明,Memformer在推理时的内存减少8.1倍,速度提升3.2倍,同时性能相当。

摘要

Transformers have reached remarkable success in sequence modeling. However, these models have efficiency issues as they need to store all the history token-level representations as memory. We present Memformer, an efficient neural network for sequence modeling, that utilizes an external dynamic memory to encode and retrieve past information. Our model achieves linear time complexity and constant memory space complexity when processing long sequences. We also propose a new optimization scheme, memory replay back-propagation (MRBP), which promotes long-range back-propagation through time with a significantly reduced memory requirement. Experimental results show that Memformer has achieved comparable performance compared to the baselines by using 8.1x less memory space and 3.2x faster on inference. Analysis of the attention pattern shows that our external memory slots can encode and retain important information through timesteps.

思维导图

论文精读

中文精读

1. 论文基本信息

1.1. 标题

Memformer: 一种用于序列建模的记忆增强型 Transformer (Memformer: A Memory-Augmented Transformer for Sequence Modeling)

1.2. 作者

Qingyang Wu, Zhenzhong Lan, Kun Qian, Jing Gu, Alborz Geramifard, Zhou Yu

隶属机构包括:加州大学戴维斯分校 (University of California, Davis)、西湖大学 (Westlake University) 和 Facebook AI。

1.3. 发表期刊/会议

该论文以预印本形式发布在 arXiv 上。 发布时间 (UTC): 2020-10-14T00:00:00.000Z

1.4. 摘要

Transformer 模型在序列建模 (sequence modeling) 方面取得了显著成功。然而,这些模型存在效率问题,因为它们需要存储所有历史词元级 (token-level) 表示作为记忆。本文提出了 Memformer,一种高效的序列建模神经网络,它利用外部动态记忆 (external dynamic memory) 来编码和检索过去的信息。该模型在处理长序列时实现了线性时间复杂度 (linear time complexity) 和常数空间复杂度 (constant memory space complexity)。我们还提出了一种新的优化方案,记忆回放反向传播 (memory replay back-propagation, MRBP),它显著减少了内存需求,促进了通过时间的长距离反向传播 (long-range back-propagation through time)。实验结果表明,Memformer 在推理时使用的内存空间减少了 8.1 倍,速度提升了 3.2 倍,同时达到了与基线模型相当的性能。对注意力模式 (attention pattern) 的分析表明,我们的外部记忆槽 (external memory slots) 能够通过时间步编码并保留重要信息。

1.5. 原文链接

https://arxiv.org/abs/2010.06891 PDF 链接: https://arxiv.org/pdf/2010.06891.pdf

2. 整体概括

2.1. 研究背景与动机

论文试图解决的核心问题: Transformer 模型虽然在序列建模中表现出色,但其主要瓶颈在于处理长序列时的效率问题。具体来说:

  1. 二次计算复杂度 (Quadratic Computation Complexity): 标准 Transformer 的自注意力 (self-attention) 机制需要计算序列中所有词元对之间的依赖关系,导致计算复杂度为 O(N2)\mathcal{O}(N^2),其中 NN 是序列长度。这使得处理非常长的序列变得成本高昂。
  2. 线性内存空间复杂度 (Linear Memory Space Complexity): 大多数现有的高效 Transformer 变体虽然降低了计算复杂度,但仍然需要存储所有历史词元级表示,导致内存消耗随序列长度线性增长,这对于长序列来说是不可接受的。
  3. 记忆瓶颈 (Memory Bottleneck): 传统的循环神经网络 (RNN) 及其变体(如 LSTM 和 GRU)使用内部压缩状态向量作为记忆,但它们在保留长期信息方面存在记忆瓶颈。而早期的外部记忆网络 (如 NTM、DNC) 过于复杂,训练不稳定,未能广泛应用。
  4. Transformer-XL 和 Compressive Transformer 的局限性: 尽管这些模型引入了循环和记忆的概念来处理长上下文,但它们存储的是原始隐藏状态 (raw hidden states),不能有效压缩高层信息,导致需要巨大的内存才能表现良好,并且存在理论上的最大时间范围 (theoretical maximum temporal range),无法保证模型能够看到所有过去的词元。

为什么这个问题在当前领域是重要的? 随着深度学习模型在各种任务中处理的序列长度越来越长(例如,长文档理解、对话系统、图像序列生成等),现有 Transformer 的效率瓶颈变得愈发突出。解决这些效率问题对于扩展 Transformer 的应用范围和提高其在实际部署中的可行性至关重要。

这篇论文的切入点或创新思路是什么? Memformer 的创新思路在于结合了 Transformer 架构和外部动态记忆系统,以实现高效的长期序列建模。其核心切入点包括:

  1. 外部动态记忆 (External Dynamic Memory): 引入一个固定大小的外部记忆槽集合,用于存储和检索高层次的、压缩的历史信息,而非原始词元表示。这使得模型在理论上具有无限的时间记忆范围,并能实现常数内存空间复杂度。
  2. 记忆读取与写入机制 (Memory Reading and Writing Mechanisms): 设计了专门的模块来管理记忆的交互。记忆读取通过交叉注意力 (cross attention) 机制进行,而记忆写入则通过槽注意力 (slot attention) 和遗忘机制 (forgetting mechanism) 来更新和清理记忆。
  3. 记忆回放反向传播 (Memory Replay Back-Propagation, MRBP): 针对传统通过时间反向传播 (Back-Propagation Through Time, BPTT) 在处理大记忆系统时内存成本过高的问题,提出了一种新的优化方案,通过重放记忆来显著减少训练时的内存需求,同时保持梯度流的完整性。

2.2. 核心贡献/主要发现

论文最主要的贡献:

  1. 提出了 Memformer 模型: 一种新型的记忆增强型 Transformer,它通过固定大小的外部动态记忆,实现了序列建模的线性时间复杂度和常数内存空间复杂度,解决了传统 Transformer 在长序列上的效率瓶颈。
  2. 设计了记忆读取与写入模块: 引入了基于交叉注意力 (cross-attention) 的记忆读取和基于槽注意力 (slot attention) 与偏置记忆归一化 (Biased Memory Normalization, BMN) 遗忘机制的记忆写入模块,以有效地管理和更新外部记忆。
  3. 提出了记忆回放反向传播 (MRBP) 算法: 一种高效的训练方案,显著降低了训练带有大记忆表示的循环神经网络所需的内存成本,同时保持了接近标准 BPTT 的训练速度。

论文得出了哪些关键的结论或发现?

  1. 显著的效率提升: Memformer 在推理时使用的内存空间比基线模型减少了 8.1 倍,速度提升了 3.2 倍,显示出其在计算效率上的巨大优势。
  2. 竞争性的性能: 尽管效率大幅提升,Memformer 在自回归图像生成和语言建模任务上仍取得了与 Transformer 和 Transformer-XL 相当或更好的性能。
  3. 长期信息保留能力: 对注意力模式的分析表明,Memformer 的外部记忆槽能够有效编码和保留来自遥远过去的重要信息,验证了其动态记忆机制的有效性。
  4. 各组件的重要性: 消融实验 (ablation studies) 证明了记忆模块、遗忘机制、记忆写入温度 (memory writer temperature) 和多头注意力 (multi-head attention) 等组件对最终性能的贡献。

3. 预备知识与相关工作

3.1. 基础概念

3.1.1. Transformer

Transformer 模型 (Vaswani et al., 2017) 是一种基于注意力机制的神经网络架构,彻底改变了自然语言处理 (NLP) 领域。它摒弃了传统的循环 (recurrence) 和卷积 (convolution) 结构,完全依赖自注意力 (self-attention) 机制来捕捉序列中的长距离依赖关系。

自注意力机制 (Self-Attention Mechanism): 自注意力是 Transformer 的核心组件,它允许模型在处理序列的某个词元时,能够同时关注序列中的所有其他词元,并根据它们之间的相关性分配不同的权重。其计算公式如下: Attention(Q,K,V)=softmax(QKTdk)V \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V 符号解释:

  • QQ (Query): 查询矩阵,由当前词元或一组词元的表示与权重矩阵 WQW_Q 相乘得到。
  • KK (Key): 键矩阵,由序列中所有词元的表示与权重矩阵 WKW_K 相乘得到。
  • VV (Value): 值矩阵,由序列中所有词元的表示与权重矩阵 WVW_V 相乘得到。
  • QKTQ K^T: 查询与键的点积,表示查询与每个键之间的相似度。
  • dk\sqrt{d_k}: 缩放因子,用于防止点积结果过大,导致 softmax 函数梯度过小。dkd_k 是键向量的维度。
  • softmax()\mathrm{softmax}(\cdot): 归一化指数函数,将相似度分数转换为注意力权重,使它们的和为 1。
  • VV: 值矩阵,注意力权重与值矩阵相乘,得到加权后的值表示。

多头注意力 (Multi-Head Attention): Transformer 进一步引入了多头注意力,它允许模型同时从不同的“表示子空间”中学习信息。具体来说,它将 Q, K, V 分别投影到 hh 个不同的子空间,然后并行地计算 hh 个注意力头,最后将这些头的输出拼接起来并再次投影,得到最终结果。

3.1.2. 循环神经网络 (Recurrent Neural Networks, RNN)

RNN 是一种专门用于处理序列数据的神经网络,通过在序列的每一步将前一步的隐藏状态 (hidden state) 作为输入,从而实现对时间依赖关系的建模。 LSTM (Long Short-Term Memory) 和 GRU (Gated Recurrent Unit): 为了解决标准 RNN 在处理长序列时出现的梯度消失 (vanishing gradient) 和梯度爆炸 (exploding gradient) 问题,LSTM 和 GRU 引入了门控机制 (gating mechanisms) 来更好地控制信息的流动,使其能够捕获更长距离的依赖关系。LSTM 具有输入门 (input gate)、遗忘门 (forget gate) 和输出门 (output gate),而 GRU 是一个简化版的 LSTM,仅有更新门 (update gate) 和重置门 (reset gate)。这些门通过学习的方式决定哪些信息应该被保留、更新或遗忘。

3.1.3. 通过时间反向传播 (Back-Propagation Through Time, BPTT)

BPTT 是一种用于训练循环神经网络的算法。它本质上是将循环网络“展开”成一个深度前馈网络,然后应用标准的反向传播算法来计算梯度。在每个时间步,BPTT 需要存储所有中间激活值 (intermediate activations) 以便在反向传播时使用。对于长序列,这会导致巨大的内存消耗,因为计算图 (computational graph) 会变得非常深。

3.1.4. 梯度检查点 (Gradient Checkpointing, GC)

梯度检查点是一种优化技术,用于减少深度神经网络训练时的内存消耗。它通过在正向传播 (forward pass) 过程中只存储计算图中的一部分关键激活值,而在反向传播 (backward pass) 过程中需要时重新计算其他激活值来节省内存。虽然它减少了内存需求,但代价是增加了计算时间,因为它需要重复一些正向传播的计算。

3.2. 前人工作

3.2.1. 稀疏注意力 (Sparse Attention)

为了解决标准 Transformer 的 O(N2)\mathcal{O}(N^2) 计算复杂度,稀疏注意力方法通过设计特定的注意力模式,使每个词元只关注序列中的一小部分词元,从而将复杂度降低到接近线性。

  • Sparse Transformer (Child et al., 2019) 使用块稀疏注意力模式,将复杂度降至 O(NN)\mathcal{O}(N\sqrt{N}),并能理论上覆盖所有过去的词元。
  • Longformer (Beltagy et al., 2020) 和 Big Bird (Zaheer et al., 2020) 进一步探索了更稀疏的注意力模式,引入了全局词元 (global tokens) 和局部注意力 (local attention) 来实现 O(N)\mathcal{O}(N) 复杂度。然而,这些方法在自回归解码 (autoregressive decoding) 设置下存在局限性,因为全局词元不能泄露未来信息,导致一个词元不能保证看到所有过去的词元。

3.2.2. 线性注意力 (Linear Attention)

线性注意力方法旨在改进自注意力中的 softmax 操作,以实现线性复杂度。

  • Linformer (Wang et al., 2020) 通过将整个序列投影到固定大小的键和值来实现 O(N)\mathcal{O}(N) 复杂度,但尚未应用于自回归解码。
  • Performer (Choromanski et al., 2020) 和 Linear Transformer (Katharopoulos et al., 2020) 使用核特征映射 (kernel feature maps) 的线性点积来替代 softmax。然而,Linear Transformer 在自回归设置下需要计算累积求和 (cumulative summation) 来聚合历史信息,这可能导致数值在长序列下变得非常大,引发溢出和梯度不稳定性。

3.2.3. 循环与记忆 (Recurrence and Memory)

这是与高效注意力方法正交的研究方向,通过引入循环机制和记忆系统来处理长上下文。

  • Transformer-XL (Dai et al., 2019) 引入了分段级循环机制 (segment-level recurrence) 和相对位置编码 (relative positional encoding),通过缓存前一层的隐藏状态作为记忆来扩展上下文。
  • Compressive Transformer (Rae et al., 2020) 在 Transformer-XL 的基础上,进一步通过压缩网络将过去的缓存隐藏状态压缩成更少的向量,以实现更长的上下文。 这些模型虽然使用了记忆,但它们存储的是原始的词元级隐藏状态,这意味着它们的记忆容量是有限的,存在理论上的最大时间范围,无法保证能看到所有遥远的过去信息。因此,在实践中,它们需要巨大的内存才能达到良好性能。

3.2.4. 动态记忆 (Dynamic Memorization)

Transformer-XL 存储词元级历史表示不同,动态记忆技术旨在实现没有理论上限的时间记忆范围。

  • Neural Turing Machine (NTM) (Graves et al., 2014) 和 Differential Neural Computer (DNC) (Graves et al., 2016) 是早期的模型,它们能够控制外部记忆资源以实现持久记忆。然而,这些模型的记忆机制复杂,导致训练速度慢且不稳定。本文提出的 Memformer 旨在提供一种更高效的动态记忆机制。

3.3. 技术演进

从最早的循环神经网络 (RNN) 及其门控变体 (LSTM、GRU) 尝试通过内部状态捕获序列依赖,到外部记忆网络 (NTM、DNC) 引入更强大的记忆存储,再到 Transformer 完全基于注意力机制。然而,Transformer 在长序列上的效率问题促生了稀疏注意力、线性注意力等改进方向。同时,Transformer-XL 和 Compressive Transformer 重新引入了循环和记忆,但仍受限于其记忆表示形式。Memformer 借鉴了外部记忆网络的概念,并将其与 Transformer 架构高效结合,形成了一种新的动态记忆增强型 Transformer,旨在克服现有方法的记忆瓶颈和效率问题。

3.4. 差异化分析

Memformer 与相关工作的主要区别和创新点在于:

  • 与传统 Transformer 及其高效变体相比: Memformer 采用固定大小的外部动态记忆,实现了线性的时间复杂度和常数的内存空间复杂度,而大多数高效 Transformer 变体仍需要线性内存空间复杂度,并且在自回归设置下可能无法保证覆盖所有历史信息。
  • 与 Transformer-XL 和 Compressive Transformer 相比: Memformer 的外部记忆存储的是高层次的、压缩的信息,而非原始隐藏状态。这使得 Memformer 理论上具有无限的时间记忆范围,并且在实践中以更小的内存实现了更好的性能。Transformer-XL 和 Compressive Transformer 由于记忆存储方式的限制,需要更大的内存,并存在理论上的最大时间范围。
  • 与 NTM/DNC 等早期动态记忆网络相比: Memformer 的记忆交互机制(基于 Transformer 的注意力)更为高效和稳定,克服了 NTM/DNC 训练复杂和缓慢的问题,使其更容易应用于下游任务。
  • MRBP 优化方案: 提出了针对大记忆循环网络训练的 MRBP 算法,有效解决了传统 BPTT 在 Memformer 这种模型中内存消耗过大的问题,这是现有梯度检查点方法的进一步改进,在内存效率和速度之间取得了更好的平衡。

4. 方法论

Memformer 是一种分段级 (segment-level) 的序列建模模型,它利用一个外部动态记忆 (External Dynamic Memory, EDM) 来存储和检索历史信息。模型通过编码器与记忆交互,解码器则用于生成下一个分段。为了解决训练时的内存问题,还引入了记忆回放反向传播 (Memory Replay Back-Propagation, MRBP) 算法。

4.1. 方法原理

Memformer 的核心思想是将长序列分割成较短的分段 (segments)。在每个时间步,模型接收一个分段作为输入,并利用一个固定大小的外部动态记忆来存储和检索过去的信息。记忆不是存储原始的词元级表示,而是存储高层次的、压缩的表示。通过这种方式,Memformer 能够实现对理论上无限长序列的建模,同时保持线性的计算复杂度和常数的内存空间复杂度。模型的编码器负责将当前分段的信息注入记忆并从记忆中检索历史信息,而解码器则根据编码器输出和记忆来预测下一个分段的词元概率。此外,一个关键的遗忘机制被引入,以确保记忆能够动态更新并过滤掉不重要的信息。

4.2. 核心方法详解

4.2.1. 分段级序列建模 (Segment-level Sequence Modeling)

传统的语言模型 (language model) 学习序列的联合概率分布,通过将每个词元 xtx_t 的概率乘以其在给定先前词元 x<tx_{<t} 条件下的概率: P(x)=tP(xtx<t)P(x) = \prod_t P(x_t | x_{<t}) 当使用大型外部记忆系统存储历史信息时,无法在每个词元级别都与记忆进行交互。Memformer 的解决方案是在分段级别处理长序列。一个序列被分成 TT 个分段,每个分段包含 LL 个词元:st={xt,1,xt,2,...,xt,L}s_t = \{x_{t,1}, x_{t,2}, ..., x_{t,L}\}

Memformer 采用 Transformer 编码器-解码器 (encoder-decoder) 架构。

  • 编码器 (Encoder) 的作用是编码当前分段 sts_t,并将信息注入到记忆 MtM_t 中,同时它还会从前一时间步的记忆 Mt1M_{t-1} 中检索过去的信息。 Mt=Encoder(st,Mt1) M_t = \mathrm{Encoder}(s_t, M_{t-1})
  • 解码器 (Decoder) 接收编码器的最终输出,并利用交叉注意力层 (cross attention layers) 预测下一个时间步分段 st+1s_{t+1} 的词元概率。 P(sts<t)=n=1:LPDecoder(xt,nxt,<n,Mt1) P(s_t | s_{<t}) = \prod_{n=1:L} P_{\mathrm{Decoder}}(x_{t,n} | x_{t,<n}, M_{t-1}) 整个模型的联合概率表示为: P(x)=t=1:TPModel(sts<t) P(x) = \prod_{t=1:T} P_{\mathrm{Model}}(s_t | s_{<t}) 在每个时间步,给定一个分段作为输入,模型需要生成下一个文本分段,并且生成的这个分段会再次作为输入反馈给模型。由于记忆存储了所有过去的信息,模型可以自回归地生成序列中的所有词元分段,从而对整个长序列进行建模。

下图(原文 Figure 1)展示了 Memformer 的整体架构,包括编码器(左侧)和解码器(右侧):

Figure 1: Memformer overall architecture for the encoder (left) and decoder (right). Transformer encoder is responsible to interact with the memory. Sequence modeling is achieved by predicting the next segment conditioned to the current segment and memory. 该图像是Memformer模型的示意图,展示了编码器(左侧)和解码器(右侧)的整体架构。编码器通过记忆读取器与外部记忆互动,解码器则通过自注意力机制处理自回归输入和模型内存。模型的计算流程涉及FeedForward层、LayerNorm层以及Cross Attention机制。

4.2.2. 外部动态记忆槽 (External Dynamic Memory Slots)

外部动态记忆 (EDM) 是一种数据结构,用于存储过去输入的高层次表示。其“动态”特性意味着模型以循环方式交互地编码和检索记忆中的信息。 Memformer 分配固定数量的 kk 个向量作为外部动态记忆。在每个时间步 tt,记忆表示为 Mt=[mt0,mt1,,mtk]M_t = [m_t^0, m_t^1, \ldots, m_t^k]。批处理 (batch) 中的每个样本都有独立的记忆表示。因此,在推理时,无论输入序列多长,内存消耗都是常数,类似于 RNN。这些向量被称为记忆槽 (memory slots),因为每个槽独立工作以拥有不同的表示。

4.2.3. 记忆读取 (Memory Reading)

对于每个输入分段序列,模型需要读取记忆以检索相关的过去信息。这通过交叉注意力机制实现: Qx,KM,VM=xWQ,MtWK,MtWVAx,M=MHAttn(Qx,KM)Hx=Softmax(Ax,M)VM \begin{array}{c} Q_x, K_M, V_M = x W_Q, M_t W_K, M_t W_V \\ A_{x,M} = \mathbf{MHAttn}(Q_x, K_M) \\ H_x = \mathbf{Softmax}(A_{x,M}) V_M \end{array} 符号解释:

  • xx: 当前输入分段的表示。

  • MtM_t: 当前时间步的外部动态记忆。

  • WQ,WK,WVW_Q, W_K, W_V: 可学习的线性投影权重矩阵,用于将输入 xx 和记忆 MtM_t 投影到查询 (Query)、键 (Key) 和值 (Value) 空间。

  • QxQ_x: 输入序列的查询矩阵。

  • KMK_M: 记忆槽的键矩阵。

  • VMV_M: 记忆槽的值矩阵。

  • MHAttn(,)\mathbf{MHAttn}(\cdot, \cdot): 多头注意力函数。在这里,输入序列的查询 QxQ_x 与记忆槽的键 KMK_M 交互计算注意力分数。

  • Ax,MA_{x,M}: 经过多头注意力计算的注意力分数。

  • Softmax()\mathbf{Softmax}(\cdot): 归一化函数,将注意力分数转换为注意力权重。

  • HxH_x: 最终的隐藏状态,通过注意力权重对记忆槽的值 VMV_M 进行加权求和得到。

    这个过程允许模型学习记忆的复杂关联。记忆读取会在每个编码器层中发生多次,以确保更高概率地成功检索所需信息。

下图(原文 Figure 2)展示了记忆读取过程:

Figure 2: Memory Reading. The input sequence \(x\) attends over all the memory slots to retrieve the history information. 该图像是一个示意图,展示了在Memformer中,输入序列xx如何通过交叉注意力机制读取外部动态记忆mt中的信息,以及注意力权重的映射过程。该模型通过注意力槽编码历史信息,有助于减少内存消耗。

4.2.4. 记忆写入 (Memory Writing)

记忆写入发生在编码器的最后一层,旨在将当前分段的高层上下文表示存储到记忆中。这个过程包括通过槽注意力 (slot attention) 更新记忆信息和通过遗忘机制清理不重要的信息。为了更好地提取序列表示,实践中会在输入序列中附加一些分类词元 (classification tokens)。

4.2.4.1. 通过记忆槽注意力更新 (Update via Memory Slot Attention) 记忆更新通过槽注意力模块完成。每个记忆槽 mim^i 独立地被投影为查询 QmiQ_{m^i} 和键 KmiK_{m^i}。当前分段词元表示 xx 被投影为键 KxK_x 和值 VxV_x。槽注意力意味着每个记忆槽只能关注它自己和词元表示,而不能直接关注其他记忆槽,以避免记忆槽之间的直接干扰。 Qmi,Kmi=miWQ,miWKKx,Vx=xWK,xWVAmi=MHAttn(Qmi,[Kmi;Kx]) \begin{array}{rcl} Q_{m^i}, K_{m^i} & = & m^i W_Q, m^i W_K \\ K_x, V_x & = & x W_K, x W_V \\ A_{m^i}^{\prime} & = & \mathbf{MHAttn}(Q_{m^i}, [K_{m^i}; K_x]) \end{array} 符号解释:

  • mim^i: 记忆 MtM_t 中的第 ii 个记忆槽向量。

  • xx: 编码器输出的当前分段词元表示。

  • WQ,WK,WVW_Q, W_K, W_V: 可学习的线性投影权重矩阵。

  • Qmi,KmiQ_{m^i}, K_{m^i}: 第 ii 个记忆槽的查询和键。

  • Kx,VxK_x, V_x: 输入序列的键和值。

  • [Kmi;Kx][K_{m^i}; K_x]: 将第 ii 个记忆槽的键与输入序列的键进行拼接 (concatenation)。

  • MHAttn(,)\mathbf{MHAttn}(\cdot, \cdot): 多头注意力函数。这里,每个记忆槽的查询 QmiQ_{m^i} 关注它自己的键 KmiK_{m^i} 和输入序列的键 KxK_x

  • AmiA_{m^i}^{\prime}: 初始的注意力对数 (attention logits)。

    在计算最终注意力分数时,原始注意力对数 AmiA_{m^i}^{\prime} 会除以一个温度参数 τ\tau (τ<1\tau < 1)。这个操作会使注意力分布变得更“尖锐” (sharper),从而使写入过程更专注于少数槽或词元输出。 Ami=exp(Ami/τ)jexp(Aj/τ) A_{m^i} = \frac{\exp(A_{m^i}^{\prime} / \tau)}{\sum_j \exp(A_{j}^{\prime} / \tau)} 符号解释:

  • τ\tau: 温度参数,用于调整注意力分布的锐度。当 τ<1\tau < 1 时,注意力分布变得更集中;当 τ=1\tau = 1 时,为标准 softmax;当 τ>1\tau > 1 时,注意力分布变得更平缓。

  • exp()\exp(\cdot): 指数函数。

  • jexp(Aj/τ)\sum_j \exp(A_j^{\prime} / \tau): 归一化因子,对所有可能的注意力对数进行求和。

  • AmiA_{m^i}: 经过温度缩放和 softmax 归一化后的注意力权重。

    最后,通过注意力机制收集信息,得到下一个时间步的记忆槽 mt+1im_{t+1}^i: mt+1i=Softmax(Ax,M)[mti;Vx] m_{t+1}^{i}{'} = \mathrm{Softmax}(A_{x,M}) \left[m_t^i; V_x\right] 符号解释:

  • mt+1im_{t+1}^{i}{'}: 更新后的第 ii 个记忆槽的候选值。

  • Softmax(Ax,M)\mathrm{Softmax}(A_{x,M}): 注意力权重,用于对记忆槽 mtim_t^i 和输入序列的值 VxV_x 进行加权。

  • [mti;Vx][m_t^i; V_x]: 将前一时间步的第 ii 个记忆槽 mtim_t^i 与输入序列的值 VxV_x 进行拼接。

    这个注意力机制帮助每个记忆槽选择是保留其旧信息还是用新信息进行更新。 下图(原文 Figure 3)展示了记忆写入过程:

Figure 3: Memory Writing. Each memory slot attends over itself and the input sequence representations to produce the next timestep's memory slot.

4.2.4.2. 记忆写入器实现 (Implementation of Memory Writer) 由于每个记忆槽独立存储信息,作者设计了一种特殊的稀疏注意力模式。记忆中的每个槽只能关注它自己和编码器输出。这旨在更长时间地保留每个槽中的信息。当一个槽在写入时只关注自己时,其信息在下一个时间步将不会改变。

4.2.4.3. 遗忘机制 (Forgetting Mechanism) 遗忘对于学习至关重要,因为它有助于过滤掉不重要的临时信息,从而记忆更重要的信息。Memformer 引入了一种称为偏置记忆归一化 (Biased Memory Normalization, BMN) 的遗忘机制,专门为槽记忆表示设计。 在每个时间步,对记忆槽进行归一化,以防止记忆权重无限增长并保持梯度在长时间步内的稳定性。为了帮助遗忘旧信息,模型添加了一个可学习的偏置向量 vbiasv_{\mathrm{bias}} 到记忆槽中。此外,初始状态 m0im_0^i 自然地在归一化后等于 vbiasiv_{\mathrm{bias}}^i 的归一化。 mt+1imt+1i+vbiasimt+1imt+1imt+1im0ivbiasivbiasi \begin{array}{r} m_{t+1}^i \leftarrow m_{t+1}^{i}{'} + v_{\mathrm{bias}}^i \\ m_{t+1}^i \leftarrow \frac{m_{t+1}^i}{||m_{t+1}^i||} \\ m_0^i \leftarrow \frac{v_{\mathrm{bias}}^i}{||v_{\mathrm{bias}}^i||} \end{array} 符号解释:

  • mt+1im_{t+1}^{i}{'}: 从记忆槽注意力更新模块得到的第 ii 个记忆槽的候选值。

  • vbiasiv_{\mathrm{bias}}^i: 针对第 ii 个记忆槽的可学习偏置向量。

  • || \cdot ||: L2 范数,用于向量归一化。

  • mt+1im_{t+1}^i: 经过偏置添加和归一化后的第 ii 个记忆槽在时间步 t+1t+1 的最终表示。

  • m0im_0^i: 第 ii 个记忆槽的初始状态,由其可学习偏置向量的归一化版本决定。

    由于归一化,所有记忆槽都会被投影到球体分布上。vbiasv_{\mathrm{bias}} 控制着遗忘的速度和方向。当将 vbiasv_{\mathrm{bias}} 添加到记忆槽时,它会导致记忆在球体上移动并遗忘部分信息。如果一个记忆槽长时间未更新,它最终会达到“终态” (terminal state) TT,除非注入新信息。这个终态也是初始状态,并且是可学习的。遗忘的速度由 vbiasv_{\mathrm{bias}} 的大小以及 mt+1im_{t+1}^{i}{'}vbiasiv_{\mathrm{bias}}^i 之间的余弦距离控制。例如,如果 mbm_b 几乎与终态相反,那么它会很难遗忘其信息;而 mam_a 离终态更近,因此更容易遗忘。

下图(原文 Figure 4)展示了遗忘机制:

Figure 4: Illustration of forgetting. Memory slot `m _ { a }` is easy to be forgotten, while `m _ { b }` is hard to be forgotten. 该图像是示意图,展示了记忆槽的遗忘现象。图中标示了两个记忆状态,mam_ambm_b,其中 mam_a 容易被遗忘,而 mbm_b 则难以被遗忘。终态 TT 和其他状态 mm 通过连线展示了记忆之间的关系和遗忘程度。箭头表示记忆状态向终态的偏移,提示这些状态对于长时间维持信息的能力。

4.2.5. 记忆回放反向传播 (Memory Replay Back-Propagation, MRBP)

Memformer 依赖外部记忆处理序列,推理时由于固定大小的记忆设计,没有额外的内存成本。然而,训练时需要通过时间反向传播 (BPTT) 来训练记忆写入网络以保留长期信息。传统 BPTT 的问题在于它在正向传播过程中展开整个计算图并存储所有中间激活值,导致 Memformer 这种模型出现不切实际的巨大内存消耗。

为了解决这个问题,作者提出了一种更高效的梯度检查点 (gradient checkpointing) 变体:记忆回放反向传播 (MRBP)。MRBP 通过在每个时间步重放记忆来完成长距离反向传播。

算法 1: 记忆回放反向传播

Input: rollout = [x_t, x_{t+1}, ..., x_T]: 一个包含先前输入的列表
       Ω = [M_t, M_{t+1}, ..., M_T]: 先前的记忆 (如果已经计算)

-> 初始化一个用于反向传播的列表
1 replayBuffer = [M_t]

-> 正向传播 & 无梯度
2 for t = t, t+1, ..., T-1 do
3   M_{t+1}, _ = Model(x_t, M_t)
4   replayBuffer.append(M_{t+1})
5 end

-> 反向传播 & 有梯度
6 ∇M_{t+1} = 0  // 初始化下一时间步的记忆梯度
7 for t = T, T-1, ..., t+1, t do
8   M_{t+1}, O_t = Model(x_t, M_t)  // 重新计算当前时间步的模型输出和记忆
9   loss = f_loss(O_t)              // 计算当前时间步的损失
10  loss.backward()                 // 对当前时间步的损失进行反向传播
11  M_{t+1}.backward(∇M_{t+1})      // 对下一时间步的记忆梯度进行反向传播
12  ∇M_{t+1} = ∇M_t                 // 更新记忆梯度,准备传给前一时间步
13 end

-> 更新并弹出最旧的记忆
14 memories = replayBuffer
15 memories.pop()

符号解释:

  • rollout: 当前批次中,从时间步 ttTT 的输入分段序列列表。

  • ΩΩ: 存储的从时间步 ttTT 的记忆列表。

  • replayBuffer: 存储每个时间步计算出的记忆 Mt,Mt+1,,MTM_t, M_{t+1}, \dots, M_T 的列表。

  • Model(xt,Mt)Model(x_t, M_t): Memformer 模型在给定输入 xtx_t 和前一时间步记忆 MtM_t 的情况下,输出当前时间步的记忆 Mt+1M_{t+1} 和模型输出 OtO_t

  • _: 表示模型输出中不需要存储的部分。

  • Mt+1∇M_{t+1}: 记忆 Mt+1M_{t+1} 的梯度。

  • floss(Ot)f_loss(O_t): 损失函数,根据模型输出 OtO_t 计算损失。

  • loss.backward(): 计算当前时间步的损失相对于模型参数和输入 OtO_t 的梯度。

  • Mt+1.backward(Mt+1)M_{t+1}.backward(∇M_{t+1}): 将下一时间步的记忆梯度 Mt+1∇M_{t+1} 反向传播到当前时间步的记忆 MtM_t

    MRBP 算法在正向传播时只遍历计算图的关键路径,并存储每个时间步的记忆到 replayBuffer 中。在反向传播时,它会为每个局部时间步重新计算部分计算图,然后利用存储的记忆进行梯度回传。这种方法显著减少了训练时的内存消耗,同时保持了梯度流,允许在长序列上进行有效的学习。

5. 实验设置

5.1. 数据集

5.1.1. MNIST

  • 来源、规模、特点和领域: MNIST (LeCun and Cortes, 2010) 是一个广泛使用的手写数字图像数据集。它包含 60,000 张训练图像和 10,000 张测试图像。每张图像是 28×2828 \times 28 像素的灰度图像,每个像素的灰度值在 0 到 255 之间(8 位)。
  • 数据形态: 在本实验中,每张 28×2828 \times 28 的图像被重塑为 784 个词元 (tokens) 的序列。8 位灰度值被转换为 256 个词汇大小 (vocabulary size) 的离散值。
  • 选择原因: 尽管 MNIST 是一个相对简单的数据集,但它常被用于自回归图像生成任务,可以有效地验证模型在序列建模方面(将图像视为长序列)的性能和效率。

5.1.2. WikiText-103

  • 来源、规模、特点和领域: WikiText-103 (Merity et al., 2017) 是一个用于长距离语言建模 (long-range language modeling) 的基准数据集。它包含来自维基百科的 28,000 篇文章,每篇文章的平均长度为 3,600 个词元。
  • 选择原因: 该数据集因其长文本特性而被选中,非常适合评估模型在处理长距离依赖方面的能力,以及在真实语言建模任务中的性能。

5.2. 评估指标

5.2.1. 困惑度 (Perplexity, PPL)

  1. 概念定义: 困惑度是评估语言模型性能的常用指标,它量化了模型对一个给定序列的预测不确定性。困惑度越低,表示模型对序列的预测能力越强,模型认为该序列出现的可能性越大。直观上,困惑度可以理解为模型在预测下一个词时“平均需要猜测多少个词”。
  2. 数学公式: PPL(W)=P(w1,w2,,wN)1N=(i=1NP(wiw1,,wi1))1N \mathrm{PPL}(W) = P(w_1, w_2, \ldots, w_N)^{-\frac{1}{N}} = \left( \prod_{i=1}^N P(w_i | w_1, \ldots, w_{i-1}) \right)^{-\frac{1}{N}}
  3. 符号解释:
    • PPL(W)\mathrm{PPL}(W): 衡量整个序列 WW 的困惑度。
    • WW: 一个包含 NN 个词元 (tokens) 的序列,W=(w1,w2,,wN)W = (w_1, w_2, \ldots, w_N)
    • NN: 序列 WW 中的词元总数。
    • P(w1,w2,,wN)P(w_1, w_2, \ldots, w_N): 序列 WW 出现的联合概率。
    • P(wiw1,,wi1)P(w_i | w_1, \ldots, w_{i-1}): 在给定前 i-1 个词元的情况下,第 ii 个词元 wiw_i 出现的条件概率。

5.2.2. FLOPs (Floating-point Operations)

  1. 概念定义: FLOPs (Floating-point Operations) 是指浮点运算次数,用于衡量一个模型的计算复杂度或计算量。它表示模型在执行推理或训练过程中所需的浮点数学运算的总数。FLOPs 越少,通常意味着模型的计算效率越高,推理或训练速度越快。
  2. 数学公式: FLOPs 没有一个统一的、简单的数学公式,因为它取决于具体的模型架构和操作。它通常通过计算模型中所有矩阵乘法、卷积、激活函数等操作涉及的浮点乘法和加法次数来累加。
  3. 符号解释:
    • 通常以“GFLOPs” (Giga-FLOPs, 十亿次浮点运算) 或“TFLOPs” (Tera-FLOPs, 万亿次浮点运算) 为单位报告。
    • 在论文中,作者报告的是“BFLOPs” (Billion-FLOPs, 十亿次浮点运算)。

5.2.3. GPU 内存 (GPU Memory)

  1. 概念定义: GPU 内存是指模型在运行时(无论是训练还是推理)在图形处理单元 (GPU) 上消耗的显存量。它包括模型参数、中间激活值、梯度以及其他数据结构所需的存储空间。内存消耗越少,模型在资源受限的环境下(例如,单 GPU 或较小显存的 GPU)运行的可能性越大。
  2. 数学公式: 没有一个通用的数学公式,因为它取决于模型参数量、激活值大小(与批次大小、序列长度、隐藏维度等相关)、优化器状态等。
  3. 符号解释:
    • 通常以兆字节 (MB) 或千兆字节 (GB) 为单位报告。

5.3. 对比基线

论文将 Memformer 的性能与以下模型进行了比较:

5.3.1. 自回归图像生成任务 (MNIST)

  • LSTM: 具有 4 层和 512 隐藏层大小。这是经典的循环神经网络,作为序列建模的基础基线。
  • Transformer Decoder: 具有 8 层,可以直接将所有 784 个词元作为输入。这是标准 Transformer 架构在解码任务中的应用,可以作为一个强基线来比较非循环 Transformer 的性能。
  • Transformer-XL: 具有 8 层,并使用不同的记忆大小 (memory=56, 224, 784)。这是 Transformer-XL,一个具有循环和记忆机制的模型,旨在处理长上下文,作为 Memformer 的直接竞争对手。

5.3.2. 语言建模任务 (WikiText-103)

  • Transformer-XL base: 具有 16 层,512 隐藏层大小,2048 前馈层大小,64 头大小,8 头注意力。在不同的记忆大小下进行测试 (memory=32, 128, 256, 512, 1024, 1600)。这是一种强大的、专为长文本建模设计的基线。

  • Compressive Transformer: 记忆长度为 512,压缩记忆长度为 512,压缩比为 4。这是 Transformer-XL 的改进版,旨在通过压缩记忆进一步延长上下文。

    所有模型在各自任务中都保持了相同的隐藏层大小 (128 for MNIST, 512 for WikiText-103)、注意力头数 (4 for MNIST, 8 for WikiText-103)、头大小 (32 for MNIST, 64 for WikiText-103) 和前馈层大小 (256 for MNIST, 2048 for WikiText-103),以确保公平比较。

6. 实验结果与分析

6.1. 核心结果分析

6.1.1. 计算和内存成本比较

下图(原文 Figure 5)比较了 Vanilla Transformer、Transformer-XL 和 Memformer 在 FLOPs 数量和 GPU 内存消耗方面的表现:

Figure 5: Comparison of the number of FLOPs and GPU memory consumption for Vanilla Transformer Transformer-XL, and Memformer. 该图像是图表,展示了Vanilla Transformer、Transformer XL和Memformer在随着序列长度增加时的计算量(FLOPs)和GPU内存消耗的比较。左侧图表显示计算量随着序列长度的变化趋势,右侧图表则展示不同内存大小下的GPU内存消耗对比。Memformer在内存消耗上表现出明显优势。

  • 计算成本 (FLOPs): 左图展示了 FLOPs 随序列长度(从 128 增加到 8192)的变化。
    • Vanilla Transformer: 随着序列长度的增加,FLOPs 呈二次方增长,成本最高。这是由于其 O(N2)\mathcal{O}(N^2) 的自注意力复杂度。
    • Transformer-XL 和 Memformer: 两者都实现了 FLOPs 随序列长度的线性增长 (O(N)\mathcal{O}(N)),因为它们使用记忆存储历史信息,输入序列长度保持常数。Memformer 在这里展现出比 Transformer-XL 更好的效率。
  • GPU 内存消耗: 右图展示了 GPU 内存消耗随记忆大小(从 64 增加到 2048)的变化,批次大小为 16。
    • Transformer-XL: 内存消耗随记忆大小迅速增长,因为 Transformer-XL 存储所有层的过去隐藏状态作为记忆,其成本为 O(K×L)\mathcal{O}(K \times L)KK 是记忆大小,LL 是层数)。

    • Memformer: 内存消耗增长速度慢得多,因为它只存储 KK 个向量作为记忆,成本为 O(K)\mathcal{O}(K)。在大的记忆设置下,Memformer 使用的内存空间比 Transformer-XL 少 8.1 倍。

      这些结果强有力地验证了 Memformer 在处理长序列时,无论是在计算速度还是内存效率方面,都具有显著的优势。

6.1.2. 自回归图像生成结果

以下是原文 Table 1 中自回归图像生成的结果:

Model #FLOPs (B) Perplexity ↓
LSTM 52.5 1.698
Transformer Decoder 41.3 1.569
Transformer-XL
memory=56 5.6 1.650
memory=224 15.6 1.618
memory=784 49.1 1.611
Memformer 4 encoder+8 decoder 5.0 1.555
Memformer Ablation 2 encoder+6 decoder
memory=64 3.9 1.594
memory=32 3.9 1.600
memory=16 3.9 1.604
memory=1 3.9 1.627
4 encoder+4 decoder 3.6 1.628
w/o memory 1.8 1.745
temperature=1.0 3.9 1.612
w/o forgetting 3.9 1.630
w/o multi-head 3.9 1.626

核心结果分析:

  • 最佳性能与效率: Memformer (4 编码器 + 8 解码器) 达到了最佳的困惑度 (1.555),同时仅使用了 5.0 B 的 FLOPs。这比最佳的 Transformer-XL (memory=784) 的性能 (1.611) 更好,而 FLOPs 仅为其 10% 左右 (5.0 vs 49.1 B)。这表明 Memformer 在图像生成任务中具有卓越的效率和竞争力。
  • 超越 Transformer Decoder: Memformer 甚至超越了处理整个输入序列的 Transformer Decoder (1.569 vs 1.555),尽管 Transformer Decoder 的 FLOPs (41.3 B) 远高于 Memformer。这可能归因于 Memformer 额外的编码器层,能够更好地提取信息。
  • 记忆的重要性: w/o memory (无记忆) 的 Memformer 性能急剧下降 (1.745),甚至比 LSTM 还差。这强调了外部动态记忆在 Memformer 模型中的关键作用。

6.1.3. 语言建模结果

以下是原文 Table 2 中语言建模的结果:

Model #FLOPs (B) PPL ↓
Transformer-XL base
memory=1600 250 23.95
memory=1024 168 23.67
memory=512 94 23.94
memory=256 58 25.39
memory=128 39 25.60
memory=32 Compressive Transformer 26 27.22
memory= 512 compress=512 Memformer 172 23.23
4 encoder + 16 decoder 54 22.74
Memformer Ablation
4 encoder + 12 decoder 48 23.91
memory=512 35 23.30
w/o memory 31 25.57

核心结果分析:

  • 最佳性能与效率: Memformer (4 编码器 + 16 解码器) 在 WikiText-103 上取得了最佳的困惑度 (22.74),同时仅使用了 54 B 的 FLOPs。这不仅性能优于所有 Transformer-XL 基线,而且计算成本远低于最佳 Transformer-XL (memory=1024, PPL=23.67, FLOPs=168 B) 和 Compressive Transformer (PPL=23.23, FLOPs=172 B)。这表明 Memformer 在语言建模任务中也能提供更高效的记忆表示。
  • Transformer-XL 的记忆与性能: 随着 Transformer-XL 记忆大小的增加,困惑度通常会下降,但 FLOPs 迅速增加。当记忆大小增加到 1600 时,性能没有进一步提升,这可能因为过大的记忆带来了噪声,或者已经达到了数据集的平均文章长度 (3600 词元) 所能受益的最大记忆范围。
  • Compressive Transformer 比较: Compressive Transformer (PPL=23.23) 性能略优于 Transformer-XL (memory=1024),但 FLOPs (172 B) 仍然很高。Memformer 以显著更低的 FLOPs (54 B) 取得了更好的性能 (PPL=22.74)。

6.2. 消融实验/参数分析

6.2.1. 自回归图像生成任务的消融实验

从 Table 1 的 Memformer Ablation 部分可以看出:

  • 编码器-解码器层数: Memformer (4 编码器 + 4 解码器) 的性能 (1.628) 比 (4 编码器 + 8 解码器) 的性能 (1.555) 有所下降,但 FLOPs 减少 (3.6 B vs 5.0 B)。这表明解码器层数对最终性能很重要。
  • 记忆大小: 随着记忆大小从 64 减少到 1 (在 2 编码器 + 6 解码器配置下),困惑度从 1.594 逐渐恶化到 1.627。当 memory=1memory=1 时,性能明显下降。这再次证明了记忆大小对模型性能的关键影响,更大的记忆有助于保留更多信息。
  • 遗忘机制: w/o forgetting (无遗忘机制) 的性能 (1.630) 比默认配置 (1.594) 差,表明遗忘机制对模型的有效性有所贡献,有助于过滤不重要信息。
  • 记忆写入温度:temperature=1.0temperature=1.0 (默认是 0.25) 时,性能下降到 1.612,表明较低的温度可以使注意力分布更尖锐,从而更好地聚焦于写入信息。
  • 多头注意力: w/o multi-head (无多头注意力) 的性能 (1.626) 同样低于默认配置,强调了多头注意力在记忆写入中的重要性。

6.2.2. 语言建模任务的消融实验

从 Table 2 的 Memformer Ablation 部分可以看出:

  • 编码器-解码器层数: Memformer (4 编码器 + 12 解码器) 的性能 (23.91) 略低于 (4 编码器 + 16 解码器) 的性能 (22.74),但 FLOPs 更低 (48 B vs 54 B)。这进一步支持了解码器层数对性能的重要性。
  • 记忆大小: 当记忆大小从 1024 减少到 512 时 (在 4 编码器 + 12 解码器配置下),性能从 23.91 提升到 23.30。这似乎与图像生成的结果(记忆越大越好)略有不同。这可能是因为在语言建模任务中,适当的记忆大小能够捕获关键信息,而过大的记忆也可能引入噪声,需要权衡。
  • 无记忆: 当完全移除记忆模块 (w/o memory) 时,困惑度急剧增加到 25.57,与 Transformer-XL 记忆大小为 128 时的性能相似。这再次凸显了记忆模块在长距离语言建模中的核心作用。

6.2.3. MRBP 效率测试

以下是原文 Table 3 中记忆回放反向传播 (MRBP) 的性能比较:

Method GPU Memory (MB) Speed (relative)
BPTT 16,177 x1.00
GC 9,885 x0.48
MRBP 7,229 x0.90

分析:

  • 内存效率: MRBP (7,229 MB) 比标准 BPTT (16,177 MB) 节省了大量的 GPU 内存,甚至比梯度检查点 (GC) (9,885 MB) 节省更多。
  • 速度: 尽管梯度检查点 (GC) 节省了内存,但其速度只有 BPTT 的 0.48 倍。相比之下,MRBP 在内存节省的同时,速度退化很小 (0.90 倍),这使其成为训练带有大记忆的循环神经网络的有效解决方案。

6.2.4. 时间范围和记忆大小的影响

下图(原文 Figure 8)展示了回传时间范围和记忆大小对困惑度的影响:

该图像是一个图表,展示了不同时间跨度和内存大小对困惑度的影响。左图显示了在不同回传时间跨度下困惑度的变化,右图则展示了不同内存大小对困惑度的影响。 该图像是一个图表,展示了不同时间跨度和内存大小对困惑度的影响。左图显示了在不同回传时间跨度下困惑度的变化,右图则展示了不同内存大小对困惑度的影响。

  • 回传时间范围: 左图展示了回传时间范围从 1 到 32 对困惑度的影响。当时间范围为 1 时,性能最差,因为梯度无法通过记忆传递到前一个时间步。随着时间范围的增加,模型的困惑度降低,性能提升。当时间范围增加到 32 时,困惑度的边际改进趋于消失。这表明存在一个最佳的回传时间范围,以平衡性能和计算成本。
  • 记忆大小: 右图展示了记忆大小从 1 到 64 对困惑度的影响。将记忆大小从 1 增加到 8 带来了显著的性能提升。然而,进一步增加记忆大小对性能的改善效果变小,这可能是由于模型规模的限制,或者在特定任务中,某个大小的记忆已经足够捕获主要信息。

6.3. 记忆写入器分析 (Memory Writer Analysis)

作者分析了记忆写入器 (memory writer) 的注意力输出,并将记忆槽大致分为三种类型。下图(原文 Figure 6)可视化了三种类型的记忆槽的归一化注意力值:

Figure 6: Visualization of three types of memory slots. 该图像是一个热图,展示了三种类型的记忆槽(m250m^{250}m300m^{300}m355m^{355})与一系列词汇之间的关联强度。每个单元格内的数值表明特定记忆槽与对应词汇的相关性,例如,记忆槽m300m^{300}与词汇“the”之间的关联强度为0.92,显示出其显著性。通过不同颜色的深浅,图像有效地直观展示了信息存储与检索的模式。

  • 第一种类型 (例如 m300m^{300}): 在处理文档的中途,大约 60% 到 80% 的记忆槽属于这一类型。它们的注意力集中在自身,这意味着它们在当前时间步没有被更新。这表明这些记忆槽能够携带来自遥远过去的信息。例如,m300m^{300} 对“the”的注意力达到 0.92。

  • 第二种类型 (例如 m250m^{250}): 这种记忆槽对自己保留部分注意力,其余注意力则分布在其他词元上。这种类型的记忆槽是从第一种类型转换而来,在当前时间步聚合了来自其他词元的信息。例如,m250m^{250} 对“the”的注意力为 0.28,对“that”为 0.22。

  • 第三种类型 (例如 m355m^{355}): 这种记忆槽完全关注输入词元。在开始时,几乎所有记忆槽都属于这种类型,但后来仅占总记忆槽的 5% 到 10%。作者还发现 m355m^{355} 的遗忘向量偏置 (bias) 具有更大的幅度 (3.20),而其他一些槽的幅度较小 (1.15),这表明该记忆槽中的信息变化迅速。

    下图(原文 Figure 7)可视化了记忆槽 m355m^{355} 在一个示例输入序列上的注意力模式:

    Figure 7: Visualization of the memory writer's attention. 该图像是一个展示记忆写入者注意力的示意图,强调了文本中不同词汇的关注程度,突出显示了关键内容如‘volunteer’和‘quitting’等。这些高亮部分表明系统在处理序列时所吸引的注意力。

该图显示 m355m^{355} 通过关注一些命名实体 (named entities) 和动词 (verbs) 来学习句子的压缩表示,这与人类认知过程一致。这进一步证明了 Memformer 的外部记忆槽不仅能够存储信息,而且能够智能地更新和聚焦于关键信息。

6.4. 训练细节

以下是原文 Table 4 中训练细节:

Image Generation Language Modeling
batch size 256 128
warm-up steps 1,000 10,000
learning rate 1e-3 1e-3
dropout 0.1 0.1
memory length 8 1,024
temperature 0.25 0.125
time horizon 8 8
weight decay 0.01 0.01
max gradient norm 1.0 1.0
training steps 10,000 150,000

训练在 NVIDIA V100 16GB 和 2080Ti 11GB GPU 上进行。图像生成训练大约需要一天时间在一块 GPU 上完成。语言建模训练大约需要四天时间在四块 GPU 上完成。

7. 总结与思考

7.1. 结论总结

本文提出了 Memformer,一种利用外部动态记忆高效处理长序列的自回归模型。Memformer 实现了线性时间复杂度和常数内存空间复杂度,有效解决了传统 Transformer 在处理长序列时的效率瓶颈。模型的核心在于其外部动态记忆槽、基于注意力机制的记忆读取和写入模块,以及一个独特的偏置记忆归一化 (BMN) 遗忘机制。为了实现高效训练,本文还引入了记忆回放反向传播 (MRBP) 优化方案,显著降低了训练大型记忆型循环神经网络所需的内存成本,同时保持了接近标准 BPTT 的训练速度。

实验结果表明,Memformer 在自回归图像生成和语言建模任务上取得了与当前最先进基线模型(如 Transformer-XL)相当甚至更优的性能,同时在推理时实现了 8.1 倍的内存节省和 3.2 倍的速度提升。对记忆写入器注意力模式的分析证实,Memformer 的外部记忆槽能够有效地编码和保留来自遥远过去的重要信息,显示出其强大的长期记忆能力。

7.2. 局限性与未来工作

局限性:

  1. 分段处理的粒度: 尽管分段级序列建模提高了效率,但如何最优地划分分段以及分段长度 LL 的选择可能会影响模型捕获局部和全局依赖的能力。
  2. 记忆槽数量 kk 的选择: 外部动态记忆槽的数量 kk 是一个重要的超参数。实验表明增加 kk 可以提升性能,但边际效益会递减,且对 kk 的最优选择可能依赖于具体的任务和数据集。
  3. MRBP 的计算开销: 尽管 MRBP 显著减少了内存,但通过重计算来获得梯度仍然会带来一定的计算开销(相对标准 BPTT 速度为 0.90 倍),这在追求极致训练速度的场景下仍需考虑。
  4. 遗忘机制的复杂性: 偏置记忆归一化 (BMN) 是一种启发式方法,其学习到的偏置向量 vbiasv_{\mathrm{bias}} 如何更精细地控制遗忘过程,以及其在更复杂数据集上的通用性,仍有待深入研究。

未来工作: 作者相信 Memformer 增强的记忆能力可以启发依赖循环和自回归建模的有趣工作,特别是在对话系统 (dialog systems) 和交互系统 (interactive systems) 等任务中。这些任务需要模型能够长时间地记住上下文信息并进行连贯的交互,而 Memformer 的高效长期记忆机制有望在此类任务中发挥重要作用。

7.3. 个人启发与批判

个人启发:

  1. 记忆与效率的平衡: Memformer 的设计思路提供了一个极佳的范例,即如何在保持 Transformer 强大建模能力的同时,通过引入外部记忆系统来突破其在长序列处理上的效率瓶颈。这启发我们,对于计算成本高昂的模型,可以从“如何存储和复用信息”的角度寻找优化方案。
  2. 高层次记忆的潜力: 相比 Transformer-XL 存储原始隐藏状态,Memformer 存储高层次的、压缩的记忆,这体现了人类记忆的某些特性:我们记忆的是抽象的概念和关键事件,而非所有细节。这种高层次记忆表示在效率和信息保留之间找到了更好的平衡。
  3. MRBP 的普适性: 记忆回放反向传播 (MRBP) 算法不仅适用于 Memformer,对于任何需要通过时间反向传播且涉及巨大内存消耗的循环网络或带有记忆的深度学习模型,都具有普适的参考价值。它在内存和速度之间的权衡是实践中非常宝贵的经验。
  4. 组件设计的精妙: 记忆读取的交叉注意力、记忆写入的槽注意力以及偏置记忆归一化遗忘机制,每一个组件都经过精心设计,共同构成了高效且有效的记忆管理系统。特别是温度参数在注意力中的应用,可以精细控制信息写入的聚焦程度。

批判:

  1. 记忆槽内容的解释性: 尽管论文通过可视化分析了记忆槽的注意力模式,并将其分为三类,但记忆槽内部具体编码了哪些语义信息(例如,它们是否形成了某种概念集群或主题表示)的深层解释性仍有待加强。更深入的分析方法(如探究性分析、线性探针等)可能会揭示更多信息。
  2. 对噪声和冗余记忆的处理: 论文提到 WikiText-103 上记忆大小过大(如 1600)可能带来噪声,导致性能不升反降。虽然遗忘机制旨在过滤不重要信息,但如何更鲁棒地处理冗余或冲突的记忆,防止其对模型性能产生负面影响,是未来可以探索的方向。
  3. 通用性与复杂性权衡: Memformer 引入了编码器-解码器架构、外部记忆、读取写入模块、遗忘机制以及 MRBP 训练方案,模型的整体复杂性相比标准 Transformer 有所增加。虽然其效率优势显著,但在特定简单任务上,这种增加的复杂性是否总是必要的,以及如何简化或自动化组件设计,值得进一步思考。例如,记忆槽的数量和温度参数等超参数的调优成本。
  4. 长序列建模的极限: 尽管 Memformer 理论上具有无限记忆范围,但实际的“时间范围” (time horizon) 参数仍然限制了梯度回传的长度。在处理极其长的序列(例如,数百万词元)时,记忆的压缩能力和其在“忘记”旧信息方面是否能保持足够的精确性,仍是一个挑战。

相似论文推荐

基于向量语义检索推荐的相关论文。

暂时没有找到相似论文。