Large Language Diffusion Models
TL;DR 精炼摘要
本文提出LLaDA,一种基于扩散模型的大型语言模型,通过前向掩码和逆向生成用Transformer预测词元,优化似然下界,实现概率推断。在多任务与上下文学习中表现优异,突破自回归模型限制,展现扩散模型在大规模语言建模中的潜力。
摘要
The capabilities of large language models (LLMs) are widely regarded as relying on autoregressive models (ARMs). We challenge this notion by introducing LLaDA, a diffusion model trained from scratch under the pre-training and supervised fine-tuning (SFT) paradigm. LLaDA employs a forward data masking process and a reverse generation process, parameterized by a Transformer to predict masked tokens. It provides a principled generative approach for probabilistic inference by optimizing a likelihood lower bound. Across extensive benchmarks on general tasks, math, code, and so on, LLaDA demonstrates strong scalability and performs comparably to our self-constructed ARM baselines. Remarkably, LLaDA 8B is competitive with strong LLMs like LLaMA3 8B in in-context learning and, after SFT, exhibits impressive instruction-following abilities in case studies such as multi-turn dialogue. Moreover, LLaDA addresses the reversal curse, surpassing GPT-4o in a reversal poem completion task. Our findings show the promise of diffusion models for language modeling at scale and challenge the common assumption that core LLM capabilities discussed above inherently depend on ARMs. Project page and codes: https://ml-gsai.github.io/LLaDA-demo/.
思维导图
论文精读
中文精读
1. 论文基本信息
1.1. 标题
大型语言扩散模型 (Large Language Diffusion Models)
1.2. 作者
Shen Nie, Fengqi Zhu, Zebin You, Xiaolu Zhang, Jingyang Ou, Jun Hu, Jun Zhou, Yankai Lin, Ji-Rong Wen, Chongxuan Li 等。 主要作者隶属于中国人民大学高瓴人工智能学院 (Gaoling School of Artificial Intelligence, Renmin University of China) 和教育部下一代智能搜索与推荐工程研究中心 (Engineering Research Center of Next-Generation Intelligent Search and Recommendation, MOE)。部分作者来自蚂蚁集团 (Ant Group)。
1.3. 发表期刊/会议
本文作为预印本 (preprint) 发布在 arXiv 上。
1.4. 发表年份
2025年。根据原文链接,发布于 UTC 时间 2025年2月14日。
1.5. 摘要
当前大型语言模型 (Large Language Models, LLMs) 的能力普遍被认为依赖于自回归模型 (Autoregressive Models, ARMs)。本文通过引入 LLaDA (Large Language Diffusion with mAsking) 对这一观点提出了挑战。LLaDA 是一种扩散模型 (Diffusion Model),它在预训练 (Pre-training) 和监督微调 (Supervised Fine-Tuning, SFT) 范式下从头开始训练。LLaDA 采用前向数据掩码过程 (forward data masking process) 和逆向生成过程 (reverse generation process),并通过一个 Transformer 模型来预测被掩码的词元 (token)。它通过优化似然下界 (likelihood lower bound) 提供了一种基于原理的概率推理生成方法。
在通用任务、数学、代码等广泛基准测试中,LLaDA 展示了强大的可扩展性 (scalability),并与作者自行构建的 ARM 基线模型 (baselines) 表现相当。值得注意的是,LLaDA 8B 在上下文学习 (in-context learning) 方面与强大的 LLM,如 LLaMA3 8B 具有竞争力,并且在 SFT 后,在多轮对话 (multi-turn dialogue) 等案例研究中展现出令人印象深刻的指令遵循 (instruction-following) 能力。此外,LLaDA 解决了反转诅咒 (reversal curse) 问题,在反转诗歌补全任务中超越了 GPT-4o。研究结果表明,扩散模型在大规模语言建模方面具有巨大潜力,并挑战了 LLM 核心能力本质上依赖于 ARM 的普遍假设。
1.6. 原文链接
https://arxiv.org/abs/2502.09992 PDF 链接: https://arxiv.org/pdf/2502.09992v3.pdf
2. 整体概括
2.1. 研究背景与动机
论文试图解决的核心问题是什么?
论文的核心问题是:大型语言模型 (LLMs) 的核心能力(如可扩展性、上下文学习和指令遵循)是否必然依赖于自回归模型 (Autoregressive Models, ARMs)?目前,LLMs 几乎完全基于 ARMs 的“下一词元预测 (next-token prediction)”范式。
为什么这个问题在当前领域是重要的?现有研究存在哪些具体的挑战或空白(Gap)?
- 单一范式限制: 尽管 ARMs 取得了巨大成功,但过度依赖单一范式可能限制了对 LLM 潜力的全面探索。
- ARM 的固有限制: ARMs 的左到右生成特性导致了一些固有限制,例如在处理反转推理任务 (reversal reasoning tasks) 时的“反转诅咒 (reversal curse)”现象。它们难以对序列进行双向建模。
- 扩散模型的潜力: 扩散模型 (Diffusion Models) 在视觉领域取得了显著成功,但在大规模语言建模方面尚未得到充分验证。现有的离散扩散模型在可扩展性和理论基础方面仍有不足,或者未能达到与主流 LLMs 相当的性能。
- 理论空白: 缺乏一个基于原理的、可扩展的离散扩散模型,能够像 ARMs 一样通过最大似然估计 (maximum likelihood estimation) 来优化语言分布,并展现出类似的 LLM 核心能力。
这篇论文的切入点或创新思路是什么?
本文的创新思路是:通过构建和训练一个大规模的离散掩码扩散模型 (Masked Diffusion Model, MDM),即 LLaDA,来证明 LLMs 的核心能力并非 ARMs 所独有,而是根植于更普遍的生成建模原则。LLaDA 的切入点在于:
- 理论基础: 采用基于最大似然下界优化的离散扩散模型,使其具有坚实的生成建模理论基础,而非启发式方法。
- 双向建模: 利用掩码扩散的固有特性,实现对序列的双向依赖建模,以克服 ARM 的单向限制。
- 大规模验证: 将离散扩散模型首次扩展到 80 亿参数的规模,并在与主流 LLMs 相似的数据和计算预算下进行预训练和监督微调,以直接与 ARMs 进行性能比较。
2.2. 核心贡献/主要发现
论文最主要的贡献是什么?
- 引入 LLaDA: 提出了 LLaDA,一个从零开始训练的 80 亿参数级别的掩码扩散模型,它遵循与主流 LLMs 相同的预训练和监督微调范式。
- 挑战 ARM 霸权: 实验证明,扩散模型可以在大规模语言建模中实现与自回归模型相当甚至在某些方面更优的性能,从而挑战了 LLM 核心能力必然依赖于 ARM 的普遍假设。
- 强大的可扩展性与性能:
- LLaDA 展示了强大的可扩展性,在通用任务、数学、代码等多个基准测试中,其性能与作者自建的 ARM 基线模型相当。
- LLaDA 8B 在上下文学习方面与 LLaMA3 8B 具有竞争力。
- 在监督微调后,LLaDA 展现了出色的指令遵循能力,特别是在多轮对话中。
- 解决反转诅咒: LLaDA 有效解决了传统 LLMs 的“反转诅咒”问题,在反转诗歌补全任务中表现出色,甚至超越了 GPT-4o。
论文得出了哪些关键的结论或发现?这些发现解决了什么具体问题?
- 扩散模型的可行性与潜力: 扩散模型不仅在图像生成领域表现出色,在语言建模领域,尤其是在大规模设置下,同样具有巨大的潜力。这为语言建模提供了新的范式选择。
- LLM 能力的普适性: LLM 的核心能力(可扩展性、上下文学习、指令遵循)并非自回归模型所独有,而是生成建模原理的普遍结果。这解决了“这些能力是否必须通过自回归方式实现”的问题。
- 双向建模的优势: LLaDA 的双向建模特性使其在处理需要双向推理的任务(如反转任务)时表现出卓越的鲁棒性和一致性,克服了 ARM 在这些任务上的局限。
- 理论与实践结合: 基于似然下界优化的掩码扩散模型,在理论上具有生成建模的坚实基础,并在实践中证明了其在大规模语言任务上的竞争力。
3. 预备知识与相关工作
3.1. 基础概念
为了理解本文,需要掌握以下核心概念:
- 大型语言模型 (Large Language Models, LLMs): 这是一类参数量巨大(通常数十亿甚至数千亿)、在海量文本数据上训练的深度学习模型。它们能够理解、生成和处理人类语言,执行多种自然语言处理 (Natural Language Processing, NLP) 任务,如问答、翻译、摘要和代码生成等。LLMs 通常基于 Transformer 架构。
- 自回归模型 (Autoregressive Models, ARMs): 这是目前主流 LLMs 所采用的生成范式,也被称为“下一词元预测 (next-token prediction)”。一个自回归模型在生成序列中的第 个词元 (token) 时,会以前面已经生成的
1到i-1个词元作为条件。其核心特点是单向(通常是左到右)生成,并且一旦生成就不可更改。- 数学表示: 对于一个长度为 的序列 ,ARMs 将其联合概率分布建模为条件概率的乘积:
其中, 是模型参数为 时生成序列 的概率, 是第一个词元的概率, 是在已知前
i-1个词元的情况下生成第 个词元的条件概率。
- 数学表示: 对于一个长度为 的序列 ,ARMs 将其联合概率分布建模为条件概率的乘积:
其中, 是模型参数为 时生成序列 的概率, 是第一个词元的概率, 是在已知前
- 扩散模型 (Diffusion Models): 这是一类生成模型,最初在计算机视觉领域取得巨大成功。其核心思想是定义一个前向扩散过程 (forward diffusion process),逐步向数据中添加噪声(或进行掩码、扰动),直到数据完全变成噪声(或完全掩码);然后训练一个逆向去噪(或去掩码)过程 (reverse denoising/demasking process) 来学习如何从噪声(或掩码)中逐步恢复原始数据。通过逆向过程,模型可以从随机噪声(或完全掩码)生成全新的数据。
- 变分下界 (Variational Lower Bound, VLB): 在概率生成模型中,尤其是当直接优化数据的对数似然 (log-likelihood) 变得困难时,通常会转而优化一个其对数似然的下界。这个下界被称为变分下界 (VLB) 或证据下界 (Evidence Lower Bound, ELBO)。优化 VLB 能够间接地提高模型的对数似然,从而使模型更好地拟合数据分布。在扩散模型中,损失函数通常被设计为 VLB 的一个代理 (proxy)。
- 监督微调 (Supervised Fine-Tuning, SFT): 这是一种后训练 (post-training) 方法,用于增强预训练语言模型在特定任务上的表现,特别是使其更好地遵循人类指令。通过在高质量的指令-响应对数据集上进行训练,模型学习如何以有用的、无害的、诚实的方式响应用户提示。
- 上下文学习 (In-Context Learning, ICL): 指 LLM 在不更新模型参数的情况下,仅通过在输入提示 (prompt) 中提供少量示例就能学习并执行新任务的能力。模型从提示中的示例中推断出任务模式和所需的输出格式,并据此生成新的响应。
- 指令遵循 (Instruction-Following): 指模型理解并准确执行用户提供的自然语言指令的能力。这是 LLM 在实际应用中非常关键的一项能力,例如生成特定格式的文本、回答特定类型的问题等。
- 反转诅咒 (Reversal Curse): 这是 LLMs 中的一种现象,指模型在学习了“A是B”这样的事实后,往往难以回答“B是什么”这样的反转问题。例如,如果模型被训练成知道“Xavier 是 LLaMA3 团队的成员”,它可能无法正确回答“LLaMA3 团队的成员是谁?”。这凸显了 ARMs 在推理和泛化能力上的一个特定弱点。
- 词元 (token): 在自然语言处理中,词元是文本的最小有意义单元。它可以是一个词、一个子词 (subword)、一个字符,甚至是一个标点符号。LLMs 在内部操作时,将文本分解为词元序列。
- 掩码词元 (Mask Token, M): 在掩码语言模型 (Masked Language Models, MLMs) 或掩码扩散模型中,
[M]或[MASK]是一个特殊的词元,用于替换原始文本中的某个词元,表示该位置的词元是未知的,需要模型进行预测。
3.2. 前人工作
- BERT (Devlin et al., 2018):
BERT是最早提出并广泛使用掩码语言模型 (Masked Language Model) 的模型之一。它通过预测文本中被随机掩盖的词元来学习双向上下文表示。与LLaDA不同,BERT使用固定比例的掩码率。 - 早期离散扩散模型 (Austin et al., 2021; Lou et al., 2023):
Austin et al. [16]率先将离散扩散模型引入语言建模,证明了其可行性。Lou et al. [17]展示了掩码扩散作为离散扩散的特例,在GPT-2规模上可以达到或超越ARM的困惑度 (perplexity) 性能。这些工作为LLaDA奠定了基础。 - 掩码扩散的理论基础 (Shi et al., 2024; Sahoo et al., 2024; Ou et al., 2024):
Shi et al. [18]、Sahoo et al. [19]和Ou et al. [20]等研究建立了掩码扩散模型的理论基础,证明了其损失函数是负对数似然的一个上界,使得掩码扩散模型能够作为原理性的生成模型进行优化。这些理论成果是LLaDA模型设计、训练和推理的直接动机。 - 掩码扩散模型的扩展 (Nie et al., 2024):
Nie et al. [27]探索了语言建模中MDM的扩展定律 (scaling laws),并研究了MDM如何应用于问答等语言任务,但其规模仍停留在GPT-2级别。 MaskGIT(Chang et al., 2022):MaskGIT是图像生成领域的一个掩码生成模型,它使用启发式训练目标,并缺少与最大似然的理论联系。LLaDA明确指出MaskGIT损失函数中缺少 项,这使得LLaDA在理论上更具生成性。- Transformer (Vaswani et al., 2017):
Transformer架构是所有现代LLM和LLaDA的基础,通过自注意力 (self-attention) 机制实现了并行处理和长距离依赖建模。 - 扩展定律 (Scaling Laws) (Kaplan et al., 2020; Hoffmann et al., 2022): 这些研究揭示了模型性能、参数量、数据量和计算量之间的幂律关系,指导了
LLM的大规模训练。LLaDA的可扩展性分析也遵循了这些定律。 - 无分类器引导 (Classifier-Free Guidance, CFG) (Ho & Salimans, 2022):
CFG是一种在扩散模型中普遍使用的技术,通过结合有条件和无条件预测来提高生成质量和样本对提示的对齐度。LLaDA证明了其模型与CFG兼容。 - 块扩散 (Block Diffusion) (Arriola et al., 2025):
Block Diffusion是一种结合了自回归和扩散模型特性的采样策略,LLaDA也探索了这种采样方式。
3.3. 技术演进
领域的技术演变可以概括为从单向自回归到双向掩码理解,再到双向掩码生成,并最终实现大规模生成建模。
- 早期语言模型 (如 N-gram 模型): 基于统计学原理,通过计算词语序列的概率来预测下一个词。
- 循环神经网络 (Recurrent Neural Networks, RNNs) 及其变体 (如 LSTM, GRU): 引入了“记忆”机制,能够处理序列数据,但存在长距离依赖问题和并行化困难。它们通常以自回归方式工作。
- Transformer 架构 (2017): 通过引入自注意力 (self-attention) 机制,解决了
RNN的长距离依赖和并行化问题,成为现代语言模型的基石。- 自回归 Transformer (如 GPT 系列): 采用因果掩码 (causal mask),只能看到当前词元之前的信息,进行左到右的下一词元预测。这是当前主流 LLMs 的核心范式。
- 编码器-解码器 Transformer (如 T5): 编码器可以处理双向信息,解码器通常是自回归的。
- 双向 Transformer (如 BERT):
BERT的编码器部分通过掩码语言模型任务学习双向上下文信息,但它主要用于理解任务,而非生成任务。
- 扩散模型 (Diffusion Models) 的兴起 (2015年起,Sohl-Dickstein et al. 2015; Ho et al., 2020): 最初在计算机视觉领域取得突破,通过逐步加噪和去噪来生成数据。
- 连续扩散模型在文本上的尝试 (Li et al., 2022; Gong et al., 2022 等): 试图将连续扩散模型应用于文本,但文本的离散性带来了挑战,可扩展性受限。
- 离散扩散模型专门化 (Austin et al., 2021; Lou et al., 2023 等): 针对离散数据(如文本)设计了特殊的扩散过程(如掩码、替换),使得扩散模型更适用于语言建模。
- 大规模离散扩散模型的探索 (Nie et al., 2024): 开始探索离散扩散模型的扩展定律,但仍停留在较小规模。
- LLaDA 的突破: 本文将离散掩码扩散模型首次扩展到 80 亿参数这一前所未有的规模,并在与主流
ARM LLMs相当的条件下进行训练和评估,证明了扩散模型在大规模语言建模中能够展现出ARM LLMs的核心能力,并解决ARM的某些固有局限(如反转诅咒)。这代表了语言生成建模领域的一个重要进展,从单一的自回归范式,成功拓展到并验证了双向的扩散生成范式在大规模模型上的有效性。
3.4. 差异化分析
LLaDA 的方法与相关工作中的主要方法相比,核心区别和创新点在于:
-
与自回归模型 (ARMs) 的差异:
- 生成范式:
ARMs采用单向(通常是左到右)的“下一词元预测”范式,一次生成一个词元,并以之前生成的词元为条件。LLaDA采用基于掩码扩散的双向生成范式,通过逐步恢复被掩码的词元来生成整个序列,可以同时预测多个词元。 - 建模能力:
ARMs的单向特性导致它们在处理反转推理任务时存在“反转诅咒”等泛化限制。LLaDA的双向建模能力(可以同时考虑上下文的左右信息)使其在这些任务上表现出更强的鲁棒性。 - 并行性: 在训练阶段,
LLaDA的掩码预测任务允许并行处理序列中所有被掩码的词元,这与BERT类似。在推理阶段,扩散采样可以同时预测多个词元,提供了生成质量与速度之间的灵活权衡,而ARMs通常是严格串行生成。 - 理论基础: 尽管两者都旨在优化对数似然,但
ARMs直接优化对数似然,而LLaDA优化的是对数似然的变分下界,这在数学上是严谨的。
- 生成范式:
-
与早期离散扩散模型 (MDMs) 的差异:
- 规模:
LLaDA将掩码扩散模型扩展到前所未有的 80 亿参数规模,并在 2.3 万亿词元数据上进行训练,这远超现有离散扩散模型(如Nie et al. [27]仅达到GPT-2规模)。这是其最核心的差异点,证明了MDM在大规模下的可扩展性。 - 理论严谨性:
LLaDA严格遵循了Shi et al. [18]、Sahoo et al. [19]和Ou et al. [20]建立的理论基础,使用包含 项的损失函数,确保其优化的是最大似然的变分下界,而非启发式目标(如MaskGIT [23])。这种原理性设计对于其在大规模上的成功至关重要。 - 全面评估:
LLaDA在广泛的基准任务(通用、数学、代码、中文)上进行了系统评估,并直接与主流ARM LLMs进行比较,提供了更全面的性能验证。 - 能力展示:
LLaDA明确展示了大规模MDM能够具备ARM LLMs的核心能力,包括可扩展性、上下文学习和指令遵循,并解决了ARM的特定问题(反转诅咒),这是早期MDM所未能充分验证的。
- 规模:
4. 方法论
4.1. 方法原理
LLaDA 的核心思想是利用掩码扩散模型 (Masked Diffusion Model, MDM) 作为大型语言模型的基础。其原理可概括为:
-
生成建模原则:
LLaDA坚持最大化数据分布的对数似然(或等价地,最小化KL散度)的生成建模原则,而不是像自回归模型那样特指的“下一词元预测”形式。 -
前向掩码过程: 定义一个前向过程,逐步且独立地对输入序列中的词元进行掩码,直到序列完全被掩码。掩码的概率 随着时间从 0 增加到 1。
-
逆向去掩码过程: 训练一个“掩码预测器 (mask predictor)”(一个
Transformer模型),它学习如何从部分掩码的序列中,同时预测所有被掩码的原始词元。这个预测器用于模拟逆向过程,从完全掩码的序列逐步恢复出原始数据。 -
优化似然下界:
LLaDA的训练目标是优化数据分布对数似然的一个变分下界,这使得它是一个原理性的生成模型,能够进行概率推理。 -
双向建模: 由于掩码扩散模型在预测被掩码词元时可以同时考虑序列中所有未被掩码的词元(无论是左侧还是右侧),这使得
LLaDA天然具备双向建模的能力,克服了自回归模型单向生成带来的限制。这种设计使得
LLaDA能够像LLM一样实现可扩展性、上下文学习和指令遵循等能力,同时还带来了处理反转任务等特定优势。
4.2. 核心方法详解
4.2.1. 概率公式
LLaDA 定义模型分布 的方式不同于自回归模型 (ARMs) 的公式 (2)。它通过一个前向过程 (forward process) 和一个逆向过程 (reverse process) 来实现。
前向过程:
前向过程 由时间 索引。这个过程会逐渐且独立地掩码序列 x _ { 0 } 中的所有词元。在时间 时,x _ { 0 } 是完全观测的,没有掩码。当 时,x _ { t } 表示具有不同掩码比率的潜在变量。
形式上,给定 x _ { 0 } 的 x _ { t } 条件分布定义为一个完全分解的形式:
其中, 是序列长度,对于每个词元 ,其条件分布为:
这里:
x _ { 0 }代表原始的、未掩码的序列。x _ { t }代表在时间步 时被部分掩码的序列。- 是原始序列中第 个词元。
- 是时间步 时序列中第 个词元。
- 表示掩码词元。
- 是一个连续随机变量,均匀采样自
[0, 1]。它代表每个词元被掩码的概率。 这个公式的直觉是:每个词元要么保持不变 (概率为1-t),要么被掩码 (概率为 )。随着 从 0 趋近 1,词元被掩码的概率线性增加。在 时,所有词元都保证被掩码。这种线性掩码概率与连续扩散模型中的噪声调度类似但有所不同。
逆向过程: 前向过程不仅是可逆的,而且对应着一个在所有词元上完全分解的逆向过程。逆向过程从时间 到 0,从完全掩码的序列生成新数据。对于 ,逆向过程的条件分布分解为: 其中,对于每个词元 ,其条件分布为: 这里:
x _ { s }代表在时间步 时被部分掩码的序列。- 是给定 中第 个词元,预测 中第 个词元的条件概率。 这个公式描述了从一个更高掩码比率 () 的状态转移到一个更低掩码比率 () 的状态。关键在于如果 ,我们需要知道它可能是原始词元 ,这通过 来估计。
因此,需要估计的关键函数是条件分布 ,它预测如果输入 x _ { t } 中被掩码,原始词元会是什么。这类似于连续扩散模型中的数据预测形式。
根据 [20] 的证明,一个等价但与时间无关的参数化形式可以推导出来:
这里:
- 表示在
x _ { t }中未被掩码的词元集合。 - 是给定 中未掩码词元,预测原始数据 中第 个词元的概率。 这意味着估计数据预测函数等同于估计干净数据上的条件分布,而后者是时间不变的。因此,时间 不需要作为参数模型的输入。
损失函数:
掩码预测器是一个参数模型 (例如,一个没有因果掩码的 Transformer 模型),它以 x _ { t } 作为输入,并同时预测所有被掩码的词元。这个模型使用带有掩码的交叉熵损失进行训练:
其中:
-
x _ { 0 }从训练数据中采样。 -
均匀采样自
[0, 1]。 -
x _ { t }从 中采样。 -
是指示函数,当第 个词元 是掩码词元 时,其值为 1,否则为 0。它确保交叉熵损失只在被掩码的词元上计算。
-
是模型预测的原始词元 的概率,给定当前被掩码的序列
x _ { t }。这个损失函数 已被证明是模型分布负对数似然的一个上界: 这使得优化 成为一个原理性的生成模型训练目标。
4.2.2. 预训练 (Pre-training)
LLaDA 使用 Transformer 模型作为掩码预测器,与现有 LLM 类似,但不使用因果掩码,因为它允许模型看到整个输入进行预测。
-
模型架构: 训练了 10 亿 (1B) 和 80 亿 (8B) 参数的
LLaDA变体。LLaDA 8B的模型架构与LLaMA3 8B相似,但在多头注意力机制 (multi-head attention) 上使用普通的多头注意力而非分组查询注意力 (grouped query attention),因为它与键值缓存 (KV caching) 不兼容。这导致注意力层参数更多,因此相应减小了前馈网络 (FFN) 的维度以保持模型规模相当。词汇表大小也因使用适配数据的词元分析器 (tokenizer) 而有所不同。 以下是模型架构概述: 以下是原文 Table 5 的结果:Our ARM Baseline 1B LLaDA 1B Our ARM Baseline 7B LLaDA 8B LLaMA3 8B Layers 22 22 28 32 32 Model dimension 2048 2048 4096 4096 4096 Attention heads 32 32 32 32 32 Vocabulary size 126,464 126,464 126,464 126,464 128,000 FFN dimension 5634 5634 13,440 12,288 14,336 Key/Value heads 4 4 8 32 8 Total parameters 1.49 B 1.49 B 6.83 B 8.02 B 8.03 B Non-embedding parameters 0.97 B 0.97 B 5.80 B 6.98 B 6.98 B -
训练数据:
LLaDA模型在一个包含 2.3 万亿 (T) 词元的数据集上进行预训练。数据协议与现有LLM([25, 26]) 保持一致,未采用特殊技术。数据源自在线语料库,通过人工设计规则和LLM方法过滤低质量内容。数据集包含高质量代码、数学和多语言数据。数据混合方式参考了小规模ARM。 -
序列长度: 预训练过程使用 4096 词元的固定序列长度。为了增强模型处理变长数据的能力,1% 的预训练数据被设置为在 [1, 4096] 范围内均匀采样的随机长度。
-
计算成本:
LLaDA 8B的预训练总计算成本为 0.13 百万 H800 GPU 小时,与相同规模和数据集大小的ARM相似。 -
训练过程: 对于每个训练序列
x _ { 0 },随机采样 ,并以相同的概率 独立掩码每个词元以获得x _ { t }(如原文 Figure 2 (a) 所示),然后通过蒙特卡洛方法估计公式 (3) 进行随机梯度下降训练。 -
优化器与调度器: 采用
AdamW优化器 ([29]),权重衰减为 0.1,批次大小为 1280,每个GPU的局部批次大小为 4。使用Warmup-Stable-Decay([28]) 学习率调度器。学习率先线性增加到 ,在处理 1.2T 词元后衰减到 ,再在最后 0.3T 词元中线性减小到 。 -
超参数: 8B 实验进行了一次,没有进行超参数调优。
以下是原文 Algorithm 1 的预训练算法:
Require: mask predictor , data distribution
1: repeat
2: # with a probability of , the sequence length of `x _ { 0 }` follows U[1, 4096]
3:
4: # is defined in Eq. (7)
5: Calculate # is the sequence length of `x _ { 0 }`
6: Calculate and run optimizer.
7: until Converged
8: Return
4.2.3. 监督微调 (Supervised Fine-Tuning, SFT)
LLaDA 通过监督微调 (SFT) 增强其指令遵循能力,使用配对数据 ( p _ { 0 } , r _ { 0 } ),其中 p _ { 0 } 是提示 (prompt),r _ { 0 } 是响应 (response)。
-
目标: 建模条件分布 。
-
实现: 提示
p _ { 0 }保持不变,响应r _ { 0 }中的词元被独立掩码,得到r _ { t }。然后,将p _ { 0 }和r _ { t }输入到预训练的掩码预测器中,计算SFT损失 (如原文 Figure 2 (b) 所示): - \mathbb { E } _ { t , p _ { 0 } , r _ { 0 } , r _ { t } } \left[ \frac { 1 } { t } \sum _ { i = 1 } ^ { L ^ { \prime } } \mathbf { 1 } [ r _ { t } ^ { i } = \mathbf { M } ] \log p _ { \theta } ( r _ _ { 0 } ^ { i } | p _ { 0 } , r _ { t } ) \right] 其中 是响应r _ { 0 }的长度。 -
数据:
LLaDA 8B模型在包含 450 万对数据的数据集上进行SFT,数据准备和训练遵循现有LLM的SFT协议。数据集涵盖代码、数学和指令遵循等多个领域。 -
多轮对话: 将 轮对话 视为 个单轮对话对,例如 。
-
词元: 在每个小批次中较短的对话对末尾添加 (end-of-sequence) 词元以确保等长。 在训练时被视为正常词元并参与损失计算,在采样时移除,使
LLaDA自动控制响应长度。 -
优化器与调度器: 学习率先从 0 线性增加到 ,然后保持恒定,最后 10% 迭代线性减小到 。权重衰减设置为 0.1,全局批次大小为 256,每个
GPU的局部批次大小为 2。以下是原文 Algorithm 2 的监督微调算法:
Require: mask predictor , pair data distribution
1: repeat
2: # please refer to Appendix B.1 for details about the SFT data
3:
4: # is defined in Eq. (7)
5: Calculate # is the sequence length of `r _ { 0 }`
6: Calculate and run optimizer.
7: until Converged
8: Return
4.2.4. 推理 (Inference)
LLaDA 作为生成模型,可以以扩散方式生成新文本并评估候选文本的似然,而不是左到右的自回归方式。
生成过程:
从完全掩码的响应开始,LLaDA 离散化逆向过程,从模型分布 中采样 (如原文 Figure 2 (c) 所示)。
-
采样步数: 采样步数是一个超参数,提供了效率和样本质量之间的权衡。默认情况下使用均匀分布的时间步。
-
生成长度: 也是一个超参数,指定采样开始时完全掩码序列的长度。生成后, 词元后的所有词元被丢弃。
-
预测与重掩码: 在从时间 到 的中间步骤中,将
p _ { 0 }和r _ { t }输入掩码预测器,同时预测所有被掩码的词元。随后,根据预测结果,预期将 的预测词元重新掩码,得到r _ { s },以确保逆向过程的转换与前向过程对齐。 -
低置信度重掩码 (Low-confidence Remasking): 原则上重掩码策略应纯随机。但受
LLM采样中的退火技巧 ([4, 30]) 启发,LLaDA采用低置信度重掩码策略,即根据预测,将 的预测词元中置信度最低的词元重新掩码。 -
灵活采样:
LLaDA支持自回归和块扩散 ([31]) 采样,无需额外修改或训练,但扩散采样(逆向生成过程)在本文中默认采用并表现最佳。以下是原文 Algorithm 4 的随机重掩码生成算法:
Require: mask predictor , prompt `p _ { 0 }` , answer length , sampling steps
1: Set `r _ { 1 }` is a fully masked sequence of length .
2: for down to step do
3:
4: # we employ greedy sampling when predicting masked tokens
5: for to do
6: if then
7:
8: else
9: with probability , is set to M
10: end if
11: end for
12: `r _ { s } = r _ { 0 }`
13: end for
14: Return `r _ { 0 }`
以下是原文 Algorithm 5 的低置信度重掩码生成算法:
Require: mask predictor , prompt `p _ { 0 }` , answer length , sampling steps
1: Set `r _ { 1 }` is a fully masked sequence of length .
2: for down to step do
3:
4: for to do
5: if then
6:
7: else
8:
9:
10: end if
11: end for
12: # the number of unmasked tokens is `n _ { u n }` in timestep
13: for to do
14: if then
15: # the `n _ { u n }` positions with the least confidence are selected for remasking.
16: end if
17: end for
18: `r _ { s } = r _ { 0 }`
19: end for
20: Return `r _ { 0 }`
条件似然评估 (Conditional Likelihood Evaluation):
对于条件似然评估,可以自然地利用公式 (5) 中的上界。然而,作者发现以下等价形式 ([20]) 具有更低的方差和更高的稳定性:
其中:
- 是
r _ { 0 }的序列长度。 - 均匀采样自 。
r _ { l }是通过从r _ { 0 }中不放回地均匀采样 个词元进行掩码得到的。 虽然公式 (12) 和公式 (14) 共享相同的期望,但它们的方差不同。公式 (14) 中掩码词元的比例 是确定性的,而公式 (12) 中的随机性可能导致偏差。
以下是原文 Algorithm 3 的条件对数似然评估算法:
Require: mask predictor , prompt `p _ { 0 }` , response `r _ { 0 }` , the number of Monte Carlo estimations `n _ { m c }`
1: log_likelihood
2: for to `n _ { m c }` do
3: # is the sequence length of `r _ { 0 }`
4: Obtain `r _ { l }` by uniformly sampling tokens from `r _ { 0 }` without replacement for masking
5: log_likelihood
6: end for
7: log_likelihood
8: Return log_likelihood
与任意顺序自回归模型 (Any-order Autoregressive Models, AO-ARM) 的联系:
AO-ARM ([59, 101, 102]) 针对 个变量的所有可能顺序 自回归地刻画联合分布。为了学习这种分布,AO-ARM 使用共享权重的神经网络来建模所有单变量条件,并使用掩码词元来表示缺失变量。训练时,它最小化所有顺序均匀分布 上的负对数似然期望:
这里:
x _ { 0 }是原始数据。- 是词元排列的顺序,从所有可能的顺序中均匀采样。
- 是根据排列顺序 的第 个词元。
- 是根据排列顺序 的前
i-1个词元。 - 是给定前
i-1个词元和排列顺序 ,预测第 个词元的条件概率。 直观上, 可以看作是x _ { t }的未掩码部分,而索引在 中的词元被掩码。可以进一步证明,公式 (15) 等价于公式 (12)。这种联系解释了LLaDA的双向推理能力,尽管在推理过程中并未明确使用。
无分类器引导 (Classifier-Free Guidance, CFG):
CFG ([37, 27]) 是一种在扩散模型中广泛使用的技术,用于平衡与提示的对齐和文本多样性。它通过以下修改后的掩码预测器进行推理:
其中:
- 是有条件预测 (conditional prediction),即给定提示
p _ { 0 }和掩码响应r _ { t }。 - 是无条件预测 (unconditional prediction),即给定掩码序列 (与
p _ { 0 }长度相同) 和掩码响应r _ { t }。 - 是一个可调超参数,控制
p _ { 0 }的引导强度。 为了与ARM进行公平比较,本文主文本中没有对LLaDA应用CFG,但附录中证明了LLaDA与CFG兼容,并且应用CFG后性能会一致提升。
5. 实验设置
5.1. 数据集
预训练数据集 (Pre-training Dataset)
LLaDA 的预训练语料库由多种公开来源构建,包括:
- 来源: 网页数据、书籍、学术论文、社交媒体、百科全书、数学资料和代码。
- 语言构成: 大约 11% 的中文数据,61% 的英文数据,以及 28% 的代码数据。
- 数据清洗:
PDF文本提取。- 重复数据删除 (deduplication)。
- 有害内容过滤。
- 使用微调过的
BERT模型进行自动化数据质量标注,以选择高质量样本。
- 数据处理: 收集到的文档被拼接起来,然后文本被分割成固定长度的序列,按照预定义的序列长度进行处理。
监督微调数据集 (Supervised Fine-Tuning, SFT Dataset)
- 构成:
SFT数据集包含 450 万对数据。- 100 万对是人工标注的样本。
- 350 万对是使用类似
Xu et al. [103]和Wei et al. [104]提出的方法生成的合成样本。
- 领域: 数据集涵盖了代码、数学和指令遵循等多个领域。
- 数据处理:
- 采用动态序列长度策略,在每个
mini-batch中,较短的对话对末尾会附加 (end-of-sequence) 词元,以确保所有样本的序列长度统一。 - 词元在训练期间被视为正常词元,并被掩码并包含在训练目标中。在采样时, 词元会被从生成的输出中移除。这种策略使得模型能够通过生成 来学习控制响应的长度。
- 对于 轮对话 ,它被视为 个单轮对话对,例如 。这种划分策略使
LLaDA具备多轮对话能力。
- 采用动态序列长度策略,在每个
若原文提供了数据集中的具体样本示例,请务必一并展示,以帮助读者直观理解数据形态。
原文未提供预训练或 SFT 数据的具体样本示例,但在 Appendix B.8 提供了 iGSM 数据集的示例,这有助于理解数学问题的形式。
iGSM 数据集示例:
一个解决方案步骤为 4 的问题示例如下:
(问题) North Star Elementary 的每个文化研究教室的数量等于 1。Westridge Elementary 的每个舞蹈工作室的数量是 North Star Elementary 的每个教室和每个文化研究教室的总和的 3 倍。Westridge Elementary 有多少个舞蹈工作室?
(解决方案) 定义 North Star Elementary 的文化研究教室为 x;所以 。定义 North Star Elementary 的教室为 m;所以 。定义 Westridge Elementary 的舞蹈工作室为 n; : #### 1
此外,在 Appendix B.9 提供了诗歌补全任务的例子:
诗歌补全任务示例:
- 示例 1: 提示: -? 回答:
- 示例 2: 提示: —? 回答: #
为什么选择这些数据集进行实验?它们是否能有效地验证方法的性能?
- 多样性与规模: 选择的预训练数据集规模巨大 (2.3T 词元) 且来源多样(网页、书籍、论文、代码、数学、多语言),这与主流
LLM的训练数据保持一致,旨在验证LLaDA在通用语言理解和生成方面的可扩展性和泛化能力。 - 任务特异性:
SFT数据集专注于代码、数学和指令遵循,这些都是LLM评估其专业能力和与人类意图对齐的关键领域。 - 与
LLM保持一致: 遵循现有LLM的数据协议和清洗流程,确保了与ARM LLMs进行公平比较的基础,证明了LLaDA可以在相似的数据环境下取得竞争力。 - 验证核心能力: 这些数据集和处理策略能有效地验证
LLaDA是否具备LLM的核心能力,如可扩展性、上下文学习、指令遵循,以及是否能解决ARM的固有问题(如反转诅咒)。
5.2. 评估指标
论文使用了广泛且标准的基准测试来评估 LLaDA 的性能,涵盖了通用能力、数学与科学、代码生成和中文理解。对于所有评估指标,如果论文未提供具体数学公式,将提供其通用定义和公式。
-
准确率 (Accuracy):
- 概念定义 (Conceptual Definition): 准确率是最常见的分类指标,衡量模型正确预测的样本数量占总样本数量的比例。它直观地反映了模型在给定任务上的正确性。
- 数学公式 (Mathematical Formula):
- 符号解释 (Symbol Explanation):
Number of Correct Predictions: 模型做出正确预测的样本数量。Total Number of Samples: 所有评估样本的总数量。
-
Pass@k (通过率@k):
- 概念定义 (Conceptual Definition):
Pass@k是一种用于评估代码生成模型性能的指标。它表示模型在生成 个候选解决方案中,至少有一个是正确(通过所有测试用例)的概率。通常使用Pass@1来衡量模型的单次生成能力。 - 数学公式 (Mathematical Formula): 论文未直接给出公式,但根据定义,
Pass@k的计算通常涉及多次尝试和二项分布的原理。对于 个问题,每个问题生成 个样本,如果模型在 次尝试中成功通过了 个问题,那么Pass@k可以近似为: 或更精确地,考虑每次尝试的成功概率: 其中 是第 次尝试成功的概率。在实际应用中,通常通过重复采样来估计。如果对每个问题生成 个样本,并统计通过测试的样本数量 ,那么Pass@k的无偏估计为: 或者,如果每个问题只生成一个样本来计算Pass@1,它就退化为准确率: - 符号解释 (Symbol Explanation):
- : 问题的总数量。
- : 为每个问题生成的候选解决方案的数量。
- : 指示函数,如果条件为真则为 1,否则为 0。
- : 第 次尝试成功的概率。
- : 对于第 个问题,通过测试用例的样本数量。
- 概念定义 (Conceptual Definition):
具体评估任务和使用的指标:
-
通用能力 (General Tasks):
- MMLU (Massive Multitask Language Understanding): 准确率。衡量模型在大学/专业级别知识和常识推理上的能力(多选问答)。
- BBH (BIG-bench Hard): 准确率。一系列具有挑战性的任务,衡量模型在复杂推理和多步问题解决上的能力。
- ARC-C (AI2 Reasoning Challenge - Challenge Set): 准确率。科学领域的多选问答任务,需要常识推理。
- HellaSwag: 准确率。常识推理任务,要求模型选择最合理的句子完成给定上下文。
- TruthfulQA: 准确率。衡量模型生成真实而非虚假或误导性信息的能力。
- WinoGrande: 准确率。常识推理任务,解决指代消歧问题。
- PIQA (Physical Interaction Question Answering): 准确率。常识推理任务,关于物理世界中的操作和交互。
-
数学与科学能力 (Mathematics & Science):
- GSM8K (Grade School Math 8K): 准确率(通常是精确匹配)。小学数学文字问题,需要多步推理。
- Math: 准确率(通常是精确匹配)。更复杂的数学问题,涵盖代数、几何等。
- GPQA (Graduate-level Google-Proof Q&A): 准确率。研究生级别的问答基准,问题难以通过直接搜索获得答案。
-
代码生成 (Code Generation):
- HumanEval:
Pass@1。评估模型生成可执行代码以解决编程问题的能力。 - HumanEval-FIM (Fill-in-the-Middle):
Pass@1。代码补全任务,要求模型在给定代码片段的中间部分生成缺失的代码。 - MBPP (Mostly Basic Python Problems):
Pass@1。一组 Python 编程问题,衡量模型生成可运行代码的能力。
- HumanEval:
-
中文理解 (Chinese):
- CMMLU (Chinese Massive Multitask Language Understanding): 准确率。中文版
MMLU,涵盖多学科知识和推理。 - C-Eval: 准确率。多级别、多学科的中文评估套件。
- CMMLU (Chinese Massive Multitask Language Understanding): 准确率。中文版
评估流程:
- 基础模型 (Base Model): 对于
MMLU、CMMLU、C-Eval、ARC-C、Hellaswag、TruthfulQA、WinoGrande、PIQA和GPQA,使用条件似然估计 (conditional likelihood estimation)。模型计算每个候选答案的条件似然,选择似然最高的作为答案。对于其他基准,使用条件生成 (conditional generation)。 - 指令模型 (Instruct Model): 所有基准都使用条件生成进行评估。
- 似然估计:
LLaDA使用蒙特卡洛估计来近似公式 (6)。对于只需要单个词元似然的基准(如MMLU),一次蒙特卡洛估计就足够。其他基准使用 128 个蒙特卡洛样本以获得稳定结果。 - 生成:
LLaDA Base和LLaDA Instruct都使用纯扩散采样 (pure diffusion sampling) 和低置信度重掩码策略 (low-confidence remasking strategy)。LLaDA Base: 生成长度和采样步数均设为 1024。LLaDA Instruct: 采样步数等于答案长度,根据任务不同而配置(例如MMLU/HellaSwag为 3,GPQA为 64,MBPP/MMLU-Pro为 256,HumanEval/GSM8K/Math/ARC-C为 512)。
5.3. 对比基线
论文将 LLaDA 方法与以下模型进行了比较:
-
自建 ARM 基线 (Our Self-constructed ARM Baselines):
- 为了公平比较,作者训练了 1B 和 7B 参数的自回归模型,确保它们与
LLaDA在架构、数据和训练配置上尽可能一致(在 1B 规模上完全一致,在更大规模上由于资源限制略有差异)。这对于评估LLaDA的可扩展性至关重要。
- 为了公平比较,作者训练了 1B 和 7B 参数的自回归模型,确保它们与
-
主流开源 LLMs:
- LLaMA2 7B Base: Meta AI 发布的开源自回归模型。
- LLaMA3 8B Base / Instruct: Meta AI 发布的最新一代开源自回归模型。
- Qwen2 7B / Qwen2.5 7B: 阿里巴巴通义千问系列模型,强大的开源自回归模型。
- Mistral 7B: Mistral AI 发布的开源自回归模型,以其小尺寸高性能而闻名。
- Deepseek 7B: Deepseek AI 发布的开源自回归模型。
- Gemma2 9B: Google 发布的开源自回归模型。
-
GPT-4o: OpenAI 发布的最新多模态大模型。在反转诗歌补全任务中作为对比。
这些基线模型都是当前
LLM领域具有代表性的模型,其中LLaMA系列是开源LLM的标杆,Qwen、Mistral、Deepseek等也代表了各自领域的领先水平。通过与这些模型比较,LLaDA的性能和潜力得到了充分验证。
6. 实验结果与分析
6.1. 核心结果分析
LLaDA 的可扩展性 (Scalability)
原文 Figure 3 展示了 LLaDA 和自建 ARM 基线在增加预训练计算 FLOPs 时的性能趋势。
从图中可以看出,LLaDA 展示了令人印象深刻的可扩展性,其整体性能趋势与 ARM 高度竞争力。特别是在 MMLU 和 GSM8K 等任务上,LLaDA 展现出更强的可扩展性。即使在 PIQA 等相对较弱的任务上,随着规模的增加,与 ARM 的性能差距也在缩小。这表明 LLaDA 能够像 ARM 一样从更多的计算资源中受益,挑战了 LLM 可扩展性独属于 ARM 的观念。作者假设 LLaDA 在某些基准上的性能优势可能源于其架构差异,即 LLaDA 训练时考虑了多个条件方向,而 ARM 只优化左到右的条件概率,这可能提供更大的灵活性和更好的泛化能力。
以下是原文 Figure 3 的结果:
该图像是图表,展示了图3中LLaDA和自回归基线模型在六个任务上,随预训练计算FLOPs增加的性能表现。图中横轴为FLOPs,纵轴为不同任务的评价指标,LLaDA表现出强劲的可扩展性,整体性能与自回归模型相当。
基准测试结果:预训练模型 (Pre-trained LLMs)
原文 Table 1 比较了 LLaDA 8B Base 与现有 LLM 在零样本/少样本学习任务上的性能。
以下是原文 Table 1 的结果:
| LLaDA 8B* | LLaMA3 8B* | LLaMA2 7B* | Qwen2 7B† | Qwen2.5 7B† | Mistral 7B† | Deepseek 7B1 | |
| Model | Diffusion | AR | AR | AR | AR | AR | AR |
| Training tokens | 2.3T | 15T | 2T | 7T | 18T | - | 2T |
| General Tasks | |||||||
| MMLU | 65.9 (5) | 65.4 (5) | 45.9 (5) | 70.3 (5) | 74.2 (5) | 64.2 (5) | 48.2 (5) |
| BBH | 49.7 (3) | 62.1 (3) | 39.4 (3) | 62.3 (3) | 70.4 (3) | 56.1 (3) | 39.5 (3) |
| ARC-C | 45.9 (0) | 53.1 (0) | 46.3 (0) | 60.6 (25) | 63.7 (25) | 60.0 (25) | 48.1 (0) |
| Hellaswag | 70.5 (0) | 79.1 (0) | 76.0 (0) | 80.7 (10) | 80.2 (10) | 83.3 (10) | 75.4 (0) |
| TruthfulQA | 46.1 (0) | 44.0 (0) | 39.0 (0) | 54.2 (0) | 56.4 (0) | 42.2 (0) | - |
| WinoGrande PIQA | 74.8 (5) 73.6 (0) | 77.3 (5) 80.6 (0) | 72.5 (5) 79.1 (0) | 77.0 (5) - | 75.9 (5) - | 78.4 (5) | 70.5 (0) 79.2 (0) |
| Mathematics & Science | |||||||
| GSM8K | 70.3 (4) | 48.7 (4) | 13.1 (4) | 80.2 (4) | 85.4 (4) | 36.2 (4) | 17.4 (8) |
| Math | 31.4 (4) | 16.0 (4) | 4.3 (4) | 43.5 (4) | 49.8 (4) | 10.2 (4) | 6.0 (4) |
| GPQA | 25.2 (5) | 25.9 (5) | 25.7 (5) | 30.8 (5) | 36.4 (5) | 24.7 (5) | - |
| Code | |||||||
| HumanEval | 35.4 (0) | 34.8 (0) | 12.8 (0) | 51.2 (0) | 57.9 (0) | 29.3 (0) | 26.2 (0) |
| HumanEval-FIM | 73.8 (2) | 73.3 (2) | 26.9 (2) | ||||
| MBPP | 40.0 (4) | 48.8 (4) | 23.2 (4) | 64.2 (0) | 74.9 (0) | 51.1 (0) | 39.0 (3) |
| Chinese | |||||||
| CMMLU | 69.9 (5) | 50.7 (5) | 32.5 (5) | 83.9 (5) | 47.2 (5) | ||
| C-Eval | 70.5 (5) | 51.7 (5) | 34.0 (5) | 83.2 (5) | 45.0 (5) | ||
- 与
LLaMA2 7B Base对比:LLaDA 8B Base在几乎所有任务上都超越了LLaMA2 7B Base。 - 与
LLaMA3 8B Base对比:LLaDA 8B Base总体上与LLaMA3 8B Base具有竞争力。 - 优势领域:
LLaDA在数学任务(GSM8K70.3 vs.LLaMA348.7;Math31.4 vs.LLaMA316.0)和中文任务(CMMLU69.9 vs.LLaMA350.7;C-Eval70.5 vs.LLaMA351.7)上表现出显著优势。 - 上下文学习能力: 结果表明
LLaDA具备有效的上下文学习能力。 - 数据泄露排除: 为排除数据泄露的可能性,作者在
GSM8K这一例子上进行了验证,结果显示LLaDA表现优于ARM基线,并且在未见过的GSM8K类似任务iGSM([34]) 上结论依然成立 (参见Appendix B.8)。
基准测试结果:后训练模型 (Post-trained LLMs)
原文 Table 2 比较了 LLaDA 8B Instruct 与现有 LLM 在 SFT 后的性能。
以下是原文 Table 2 的结果:
| LLaDA 8B* | LLaMA3 8B* | LLaMA2 7B* | Qwen2 7B† | Qwen2.5 7B† | Gemma2 9B† | Deepseek 7B1 | |
| Model | Diffusion | AR | AR | AR | AR | AR | AR |
| Training tokens | 2.3T | 15T | 2T | 7T | 18T | 8T | 2T |
| Post-training Alignment pairs | SFT 4.5M | SFT+RL | SFT+RL | SFT+RL 0.5M+- | SFT+RL 1M + 0.15M | SFT+RL | SFT+RL 1.5M+- |
| - - General Tasks | |||||||
| MMLU | 65.5 (5) | 68.4 (5) | 44.1 (5) | - | 49.4 (0) | ||
| MMLU-pro | 37.0 (0) | 41.9 (0) | 4.6 (0) | 44.1 (5) | 56.3 (5) | 52.1 (5) | |
| Hellaswag | 74.6 (0) | 75.5 (0) | 51.5 (0) | 68.5 (-) | |||
| ARC-C | 88.5 (0) | 82.4 (0) | 57.3 (0) | - | - | 49.4 (-) | |
| Mathematics & Science | |||||||
| GSM8K | 69.4 (4) | 78.3 (4) | 29.0 (4) | 85.7 (0) | 91.6 (0) | 76.7 (0) | 63.0 (0) |
| Math | 31.9 (0) | 29.6 (0) | 3.8 (0) | 529 (0) | 75.5 (0) | 44.3 (0) | 15.8 (0) |
| GPQA | 33.3 (5) | 31.9 (5) | 28.4 (5) | 34.3 (0) | 36.4 (0) | 32.8 (0) | |
| Code | |||||||
| HumanEval | 49.4 (0) | 59.8 (0) | 16.5 (0) | 79.9 (0) | 84.8 (0) | 68.9 (0) | 48.2 (-) |
| MBPP | 41.0 (4) | 57.6 (4) | 20.6 (4) | 67.2 () | 79.2 (0) | 74.9 (0) | 35.2 (-) |
SFT效果:SFT提升了LLaDA在大多数下游任务上的性能。少数指标(如MMLU)出现下降,可能与SFT数据质量次优有关。- 与
LLaMA3 8B Instruct对比: 尽管LLaDA只进行了SFT而未进行强化学习 (RL) 对齐,其结果略低于LLaMA3 8B Instruct(后者进行了 对齐),但在许多指标上差距较小。 - 指令遵循能力: 仅通过
SFT,LLaDA就展现了令人印象深刻的指令遵循能力(参见Sec. 3.4的案例研究)。
反转推理 (Reversal Reasoning)
原文 Table 4 比较了 LLaDA、GPT-4o 和 Qwen2.5-7B Instruct 在诗歌补全任务中的反转推理能力。
以下是原文 Table 4 的结果:
| Forward | Reversal | |
| GPT-4o (2024-08-06) | 82.7 | 34.3 |
| Qwen2.5-7B Instruct | 75.9 | 38.0 |
| LLaDA-8B Instruct | 51.8 | 45.6 |
- 解决反转诅咒:
LLaDA有效解决了“反转诅咒”问题,在正向 (Forward) 和反向 (Reversal) 任务上表现出一致的零样本性能。 - 超越
GPT-4o: 尽管LLaDA在正向生成上不如GPT-4o和Qwen2.5(这可能与训练数据和计算资源规模的差异有关),但在反向任务上,LLaDA以显著优势超越了两者 (LLaDA45.6 vs.GPT-4o34.3,Qwen2.538.0)。 - 原因分析: 作者认为
LLaDA在处理词元时没有归纳偏置 (inductive bias),这导致了其在正向和反向任务上表现均衡。
采样效率分析 (Sampling Efficiency Analysis)
原文 Figure 5 展示了 LLaDA 在不同采样步数下,生成质量与速度之间的灵活权衡。
以下是原文 Figure 5 的结果:
该图像是多子图的折线图,展示了LLaDA 8B与LLaMA3 8B在不同吞吐率下的GSM8K、Math、HumanEval和MBPP任务性能对比,体现了采样效率与生成速度的权衡。
-
LLaDA允许通过调整采样步数来权衡生成质量和速度。 -
在
GSM8K和Math数据集上,LLaDA 8B Base在与LLaMA3 8B Base性能相当的情况下,吞吐量 (throughput) 高出 1.5 到 1.8 倍,即使LLaMA3使用了KV Cache而LLaDA未使用推理优化技术。 -
在
HumanEval上,当吞吐量匹配时,两者性能相当。 -
在
MBPP上,LLaDA略逊于LLaMA3。原文 Table 11 比较了
LLaDA 8B Base和LLaMA3 8B Base的内存消耗。 以下是原文 Table 11 的结果:
| Input Length | Output Length | LLaDA 8B | LLaMA3 8B w/o KV-Cache | LLaMA3 8B w/ KV-Cache |
| 512 | 512 | 17.03 | 16.70 | 16.32 |
| 1024 | 17.53 | 17.49 | 16.43 | |
| 2048 | 18.52 | 20.00 | 16.73 | |
| 1024 | 512 | 17.53 | 17.16 | 16.36 |
| 1024 | 18.01 | 18.26 | 16.41 | |
| 2048 | 19.02 | 21.39 | 16.74 |
LLaDA的内存使用量与不带KV Cache的LLaMA3 8B Base相当,但略高于带KV Cache的LLaMA3。LLaDA的内存使用量不随采样步数变化。- 强调: 本研究的目标并非提出比
ARM更快的模型,而是证明扩散模型在大规模语言建模中的潜力,并挑战核心LLM能力依赖于ARM的假设。
iGSM 数据集评估
原文 Table 12 比较了 LLaDA 8B Base 和 LLaMA3 8B Base 在 iGSM 数据集上的性能。
以下是原文 Table 12 的结果:
| 4 steps | 5 steps | 6 steps | |
| LLaMA3 8B Base | 38.0 | 35.0 | 34.0 |
| LLaDA 8B Base | 64.0 | 41.0 | 44.0 |
LLaDA 8B Base 在未见过的数学问题(iGSM)上展现出显著且持续的优势,这与 Table 1 中 LLaDA 在数学任务上的表现一致。
6.2. 数据呈现
预训练模型基准测试结果
以下是原文 Table 1 的结果:
| LLaDA 8B* | LLaMA3 8B* | LLaMA2 7B* | Qwen2 7B† | Qwen2.5 7B† | Mistral 7B† | Deepseek 7B1 | |
| Model | Diffusion | AR | AR | AR | AR | AR | AR |
| Training tokens | 2.3T | 15T | 2T | 7T | 18T | - | 2T |
| General Tasks | |||||||
| MMLU | 65.9 (5) | 65.4 (5) | 45.9 (5) | 70.3 (5) | 74.2 (5) | 64.2 (5) | 48.2 (5) |
| BBH | 49.7 (3) | 62.1 (3) | 39.4 (3) | 62.3 (3) | 70.4 (3) | 56.1 (3) | 39.5 (3) |
| ARC-C | 45.9 (0) | 53.1 (0) | 46.3 (0) | 60.6 (25) | 63.7 (25) | 60.0 (25) | 48.1 (0) |
| Hellaswag | 70.5 (0) | 79.1 (0) | 76.0 (0) | 80.7 (10) | 80.2 (10) | 83.3 (10) | 75.4 (0) |
| TruthfulQA | 46.1 (0) | 44.0 (0) | 39.0 (0) | 54.2 (0) | 56.4 (0) | 42.2 (0) | - |
| WinoGrande PIQA | 74.8 (5) 73.6 (0) | 77.3 (5) 80.6 (0) | 72.5 (5) 79.1 (0) | 77.0 (5) - | 75.9 (5) - | 78.4 (5) | 70.5 (0) 79.2 (0) |
| Mathematics & Science | |||||||
| GSM8K | 70.3 (4) | 48.7 (4) | 13.1 (4) | 80.2 (4) | 85.4 (4) | 36.2 (4) | 17.4 (8) |
| Math | 31.4 (4) | 16.0 (4) | 4.3 (4) | 43.5 (4) | 49.8 (4) | 10.2 (4) | 6.0 (4) |
| GPQA | 25.2 (5) | 25.9 (5) | 25.7 (5) | 30.8 (5) | 36.4 (5) | 24.7 (5) | - |
| Code | |||||||
| HumanEval | 35.4 (0) | 34.8 (0) | 12.8 (0) | 51.2 (0) | 57.9 (0) | 29.3 (0) | 26.2 (0) |
| HumanEval-FIM | 73.8 (2) | 73.3 (2) | 26.9 (2) | ||||
| MBPP | 40.0 (4) | 48.8 (4) | 23.2 (4) | 64.2 (0) | 74.9 (0) | 51.1 (0) | 39.0 (3) |
| Chinese | |||||||
| CMMLU | 69.9 (5) | 50.7 (5) | 32.5 (5) | 83.9 (5) | 47.2 (5) | ||
| C-Eval | 70.5 (5) | 51.7 (5) | 34.0 (5) | 83.2 (5) | 45.0 (5) | ||
后训练模型基准测试结果
以下是原文 Table 2 的结果:
| LLaDA 8B* | LLaMA3 8B* | LLaMA2 7B* | Qwen2 7B† | Qwen2.5 7B† | Gemma2 9B† | Deepseek 7B1 | |
| Model | Diffusion | AR | AR | AR | AR | AR | AR |
| Training tokens | 2.3T | 15T | 2T | 7T | 18T | 8T | 2T |
| Post-training Alignment pairs | SFT 4.5M | SFT+RL | SFT+RL | SFT+RL 0.5M+- | SFT+RL 1M + 0.15M | SFT+RL | SFT+RL 1.5M+- |
| - - General Tasks | |||||||
| MMLU | 65.5 (5) | 68.4 (5) | 44.1 (5) | - | 49.4 (0) | ||
| MMLU-pro | 37.0 (0) | 41.9 (0) | 4.6 (0) | 44.1 (5) | 56.3 (5) | 52.1 (5) | |
| Hellaswag | 74.6 (0) | 75.5 (0) | 51.5 (0) | 68.5 (-) | |||
| ARC-C | 88.5 (0) | 82.4 (0) | 57.3 (0) | - | - | 49.4 (-) | |
| Mathematics & Science | |||||||
| GSM8K | 69.4 (4) | 78.3 (4) | 29.0 (4) | 85.7 (0) | 91.6 (0) | 76.7 (0) | 63.0 (0) |
| Math | 31.9 (0) | 29.6 (0) | 3.8 (0) | 529 (0) | 75.5 (0) | 44.3 (0) | 15.8 (0) |
| GPQA | 33.3 (5) | 31.9 (5) | 28.4 (5) | 34.3 (0) | 36.4 (0) | 32.8 (0) | |
| Code | |||||||
| HumanEval | 49.4 (0) | 59.8 (0) | 16.5 (0) | 79.9 (0) | 84.8 (0) | 68.9 (0) | 48.2 (-) |
| MBPP | 41.0 (4) | 57.6 (4) | 20.6 (4) | 67.2 () | 79.2 (0) | 74.9 (0) | 35.2 (-) |
采样过程可视化与多轮对话案例
原文 Table 3 展示了 LLaDA 的采样过程和生成的对话。
以下是原文 Table 3 的结果:
| Sampling Process | |||||
| User LLaDA | Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours? In 4 hours Lily runs 12 * 48 kilometers After 4 she runs 6 * 4 24 kilometers In total she runs 48 | ||||
| hours + 24 Multi-round Dialogue Case | 72 kilometers in 8 hours Result | ||||
| User Could you please share with me what the initial two lines of the renowned poem 'The Road Not Taken'? | |||||
| LLaDA | Certainly! The initial two lines of the renowned poem "The Road Not Taken" by Robert Frost are: "Two roads diverged in a yellow wood, And sorry I could not travel both." These lines set the stage for the poem, introducing the idea of a speaker facing a choice between | ||||
| User LLaDA | two paths. Please help me translate into Chinese. Sure! The first two lines of "The Road Not Taken" by Robert Frost can be translated into | ||||
| User LLaDA | Chinese as:"," Please translate into German. Sure! The first two lines of "The Road Not Taken" by Robert Frost can be translated into German as: "Zwei Wege trennten sich im gelben Wald, und leider konnte ich nicht beide | ||||
| User | tun." Please also write a poem about life choices. I require 5 sentences, and the first word of each sentence starts with C as the first letter. | ||||
| LLaDA | Certainly! Here's a poem about life choices, with each sentence starting with the letter "C": Choices cascade through life's journey, Creating connections with each decision, Careful consideration leads to better paths, Courageous choices shape our destiny, Choosing wisely | ||||
- 采样过程可视化: 第一部分展示了一个数学问题的采样过程。响应中颜色较深的词元表示在采样后期预测,颜色较浅的词元表示在采样早期预测,这直观地展示了扩散模型的迭代生成过程。
- 多轮对话案例: 第二部分展示了
LLaDA在多轮对话中的表现,包括诗歌引用、多语言翻译(中、德)以及基于指令的诗歌创作。这强调了模型在保持上下文一致性、理解和执行复杂指令方面的能力。
诗歌补全任务结果
以下是原文 Table 4 的结果:
| Forward | Reversal | |
| GPT-4o (2024-08-06) | 82.7 | 34.3 |
| Qwen2.5-7B Instruct | 75.9 | 38.0 |
| LLaDA-8B Instruct | 51.8 | 45.6 |
模型架构
以下是原文 Table 5 的结果:
| Our ARM Baseline 1B | LLaDA 1B | Our ARM Baseline 7B | LLaDA 8B | LLaMA3 8B | |
| Layers | 22 | 22 | 28 | 32 | 32 |
| Model dimension | 2048 | 2048 | 4096 | 4096 | 4096 |
| Attention heads | 32 | 32 | 32 | 32 | 32 |
| Vocabulary size | 126,464 | 126,464 | 126,464 | 126,464 | 128,000 |
| FFN dimension | 5634 | 5634 | 13,440 | 12,288 | 14,336 |
| Key/Value heads | 4 | 4 | 8 | 32 | 8 |
| Total parameters | 1.49 B | 1.49 B | 6.83 B | 8.02 B | 8.03 B |
| Non-embedding parameters | 0.97 B | 0.97 B | 5.80 B | 6.98 B | 6.98 B |
无分类器引导 (CFG) 消融实验
以下是原文 Table 6 的结果:
| ARC-C | Hellaswag | TruthfulQA | WinoGrande | GPQA | PIQA | |
| w/o CFG | 45.9 | 70.5 | 46.1 | 74.8 | 25.2 | 73.6 |
| w/ CFG | 47.9 | 72.5 | 46.4 | 74.8 | 26.1 | 74.4 |
采样策略消融实验
以下是原文 Table 7 的结果:
| BBH | GSM8K | Math | HumanEval | MBPP | ||
| Autoregressive | 38.1 | 63.1 | 23.6 | 18.3 | 33.4 | |
| Block Difusion | L′= 2 | 37.3 | 62.6 | 25.2 | 14.6 | 33.6 |
| L' = 4 | 40.0 | 65.7 | 26.6 | 15.9 | 36.0 | |
| L'′ = 8 | 42.0 | 68.2 | 27.7 | 19.5 | 39.2 | |
| L′ = 32 | 45.7 | 68.6 | 29.7 | 29.9 | 37.4 | |
| Block Diffusion LLaDA | L′= 2 | 48.0 | 70.0 | 30.8 | 26.2 | 40.0 |
| L′ = 4 | 48.5 | 70.3 | 31.3 | 27.4 | 38.8 | |
| L'′ = 8 | 48.6 | 70.2 | 30.9 | 31.1 | 39.0 | |
| L' = 32 | 48.3 | 70.3 | 31.2 | 32.3 | 40.0 | |
| Pure Diffusion | 49.7 | 70.3 | 31.4 | 35.4 | 40.0 | |
以下是原文 Table 8 的结果:
| GSM8K | Math | HumanEval | MBPP | GPQA | MMLU-Pro | ARC-C | |
| Autoregressive | 0 | 9.5 | 0 | 0 | 0 | 0 | 84.4 |
| Block Diffusion | 24.6 | 23.5 | 17.1 | 21.2 | 29.3 | 32.5 | 88.1 |
| Block Difusion LLaDA | 77.5 | 42.2 | 46.3 | 34.2 | 31.3 | 34.8 | 85.4 |
| Pure Diffusion | 69.4 | 31.9 | 49.4 | 41.0 | 33.3 | 37.0 | 88.5 |
随机重掩码与低置信度重掩码策略分析
以下是原文 Table 9 的结果:
| Length | BBH | GSM8K | Math | HumanEval | MBPP |
| Random Remasking | 32.1 | 21.3 | 9.2 | 11.6 | 21.0 |
| Low-confidence Remasking | 45.0 | 70.0 | 30.3 | 32.9 | 40.2 |
生成长度消融实验
以下是原文 Table 10 的结果:
| Length | BBH | GSM8K | Math | HumanEval | MBPP |
| 256 | 45.0 | 70.0 | 30.3 | 32.9 | 40.2 |
| 512 | 50.4 | 70.8 | 30.9 | 32.9 | 39.2 |
| 1024 | 49.7 | 70.3 | 31.4 | 35.4 | 40.0 |
采样效率分析
以下是原文 Table 11 的结果:
| Input Length | Output Length | LLaDA 8B | LLaMA3 8B w/o KV-Cache | LLaMA3 8B w/ KV-Cache |
| 512 | 512 | 17.03 | 16.70 | 16.32 |
| 1024 | 17.53 | 17.49 | 16.43 | |
| 2048 | 18.52 | 20.00 | 16.73 | |
| 1024 | 512 | 17.53 | 17.16 | 16.36 |
| 1024 | 18.01 | 18.26 | 16.41 | |
| 2048 | 19.02 | 21.39 | 16.74 |
iGSM 数据集比较
以下是原文 Table 12 的结果:
| 4 steps | 5 steps | 6 steps | |
| LLaMA3 8B Base | 38.0 | 35.0 | 34.0 |
| LLaDA 8B Base | 64.0 | 41.0 | 44.0 |
块扩散 LLaDA 采样过程可视化
以下是原文 Table 13 的结果:
该图像是图表,展示了LLaDA模型对“定期锻炼对身心健康的益处”问题的采样生成过程。颜色深浅表示预测阶段,浅色为早期预测,深色为后期预测,体现了模型的逐步生成机制。
该图展示了 LLaDA 模型对“定期锻炼对身心健康的益处”问题的采样生成过程。颜色深浅表示词元预测的阶段,浅色为早期预测,深色为后期预测,体现了模型的逐步生成机制。
多轮对话可视化
以下是原文 Table 14 的结果:
该图像是论文中展示多轮对话的示意图,使用随机重掩码策略。LLaDA模型生成的响应中,颜色越深表示该词在采样后期预测,颜色越浅表示先期预测。
该图展示了 LLaDA 模型使用随机重掩码策略进行多轮对话的生成过程。与 Table 13 类似,词元颜色深浅表示预测阶段,浅色词元在采样早期被预测,深色词元在后期被预测。这表明模型在多轮对话中,可以迭代地预测和完善响应中的词元。
单轮对话案例
以下是原文 Table 15 的结果:
| Daily life User | |
| LLaDA | What's a piece of advice you've received that you've found helpful? One helpful piece of advice I've received is to always be open to learning and growth. I've found that being willing to explore new ideas and experiences has helped me become more resilient and adaptable. By embracing a growth mindset, I've been able to identify |
| Daily life User | |
| LLaDA | Please recommend me three famous movies Certainly! Here are three famous movies that have achieved widespread popularity and critical acclaim: 1. **The Shawshank Redemption** (1994) - Inspired by a novella and directed by Frank Darabont, this film tells the story of Andy Dufresne, a man convicted of murdering his |
| wife, dealing with his harsh life in Shawshank Prison and forming an unlikely friendship with a fellow inmate. 2. **The Godfather** (1972) - Directed by Francis Ford Coppola, this film is based on Mario Puzo's novel of the same name. It follows the story of the Corleone family, | |
| particularly Michael Corleone, who rises to power and takes over the family's criminal empire. 3. **The Empire Strikes Back** (1980) - Directed by George Lucas, this film continues the story of a young Skywalker in the Star Wars saga. It follows Luke Skywalker as he helps the Rebel Alliance fight against the Galactic Empire and destroy the powerful | |
| weapon, the Death Star. All of these films are celebrated for their cinematic achievements, powerful storytelling, and memorable characters. | |
| Translation User LLaDA | |
| Translation User | Please help me translate into Chinese: 'What is now proved was once only imagined' E |
| LLaDA Math | Please help me translate into German: 'What is now proved was once only imagined' Wer heute bestätigt ist, war einst nur immaginiert. Emma decides to run 4 sprints 5 times a week. She runs 75 meters each sprint. How many |
| User | |
| LLaDA | total meters does she run a week? Emma runs 4 sprints each time, so she runs a total of 4 x 5 = 20 sprints a week. Each sprint is 75 meters long, so Emma runs a total of 20 x 75 = 1500 meters a week. The final result is 1500 |
| Code User | I am writing python code, I have an array, please help me sort it from largest to smallest |
| LLaDA | Certainly! You can sort an array from the largest to smallest in Python using the 'sort()' method or the 'sorted()' function. Here's an example: ### Using 'sort()' method: my_array = [3, 1, 4, 1, 5, 2] |
这部分展示了 LLaDA 8B Instruct 在各种单轮对话任务中的能力,包括生活建议、电影推荐、多语言翻译、数学问题解决和代码生成。模型能够生成连贯、流畅且遵循指令的响应。
多轮对话案例
以下是原文 Table 16 的结果:
| Daily life User LLaDA | How do I prepare for a job interview? To prepare for a job interview, you should research the company and the job, practice |
| User LLaDA | common interview questions, and dress appropriately. What are some common interview questions? Some common interview questions include "Tell me about yourself" and "Why do you |
| Daily life User | want to work for us"? How do I start learning programming? |
| LLaDA | To learn programming, you can start by choosing a programming language that you are interested in. Some popular languages include Python, Java, and C++. |
| User LLaDA | Which programming language should I learn first? It depends on your interests and goals. Python is a great language for beginners for its simplicity and versatility. |
| User LLaDA User | How long does it take to learn Python? It can take anywhere from a few weeks to several months to learn Python. What are some good Python projects for beginners? |
| LLaDA | Some good Python projects for beginners include building a simple calculator, a game, and a web application. |
| Math User | Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. |
| LLaDA | How many clips did Natalia sell altogether in April and May? Natalia sold clips to 48 friends in April. She sold half as many clips in May, so she sold |
| 48/2 = 24 clips in May. Altogether, Natalia sold 48 + 24 = 72 clips in April and May. The final result is 72. | |
| User | Natalia sold three times as many items in June as she did in May. How many clips did |
| LLaDA | Natalia sell altogether in these three months? Natalia sold clips to 48 friends in April. She sold half as many clips in May, so she sold 48/2 = 24 clips in May. Natalia sold three times as many items in June as she did in May, |
这部分展示了 LLaDA 8B Instruct 在多轮对话中的能力,包括连续的日常问题交流和多步数学问题解决。模型能够有效地保留对话历史,并生成与上下文相关且准确的响应。
诗歌反转补全案例:LLaDA 成功而 GPT-4o 失败
以下是原文 Table 17 的结果:
| Question1 '-? | |
| User GPT-40 | - |
| LLaDA Question2 | ""A-EUT |
| User | 'HE-EH? |
| GPT-40 | -" |
| LLaDA | # "-" |
这部分展示了 LLaDA 8B Instruct 在诗歌反转补全任务中取得成功,而 GPT-4o 失败的案例。这有力地支持了 LLaDA 能够有效解决“反转诅咒”的发现。
LLaDA 可扩展性详细结果
以下是原文 Table 18 的结果:
| Model | Training Tokens | FLOPs | MMLU | CMMLU | ARC-C | PIQA | GSM8K | HumanEval |
| LLaDA 1B | 37.75B | 2.20e20 | 25.52 | 25.95 | 25.17 | 59.41 | 1.82 | 0.00 |
| LLaDA 1B | 88.08B | 5.13e20 | 27.11 | 26.52 | 26.96 | 61.86 | 3.03 | 1.83 |
| LLaDA 1B | 138.41B | 8.06e20 | 29.32 | 27.13 | 30.20 | 63.38 | 2.35 | 0.00 |
| LLaDA 1B | 239.08B | 1.39e21 | 31.48 | 30.77 | 27.99 | 63.11 | 3.26 | 1.22 |
| LLaDA 1B | 352.32B | 2.05e21 | 35.86 | 34.35 | 31.31 | 65.34 | 3.64 | 3.05 |
| LLaDA 1B | 461.37B | 2.69e21 | 31.86 | 30.98 | 30.12 | 65.51 | 2.35 | 0.61 |
| LLaDA 8B | 62.91B | 2.63e21 | 32.22 | 28.5 | 30.20 | 63.82 | 3.87 | 2.44 |
| LLaDA 8B | 125.83B | 5.27e21 | 33.39 | 33.9 | 34.64 | 66.54 | 8.72 | 3.66 |
| LLaDA 8B | 251.66B | 1.05e22 | 42.84 | 40.59 | 40.10 | 69.04 | 15.31 | 3.66 |
| LLaDA 8B | 377.49B | 1.58e22 | 45.11 | 43.99 | 39.25 | 68.61 | 25.40 | 9.76 |
| LLaDA 8B | 503.32B | 2.11e22 | 43.57 | 41.38 | 42.06 | 70.24 | 27.52 | 9.76 |
| LLaDA 8B | 629.14B | 2.63e22 | 48.80 | 47.13 | 42.24 | 72.09 | 30.10 | 12.80 |
| LLaDA 8B | 679.48B | 2.85e22 | 49.61 | 48.19 | 41.30 | 70.84 | 26.31 | 8.54 |
| LLaDA 8B | 792.72B | 3.31e22 | 50.88 | 49.01 | 42.58 | 70.51 | 31.99 | 6.10 |
| LLaDA 8B | 981.47B | 4.11e22 | 49.47 | 48.10 | 40.27 | 71.38 | - | 6.10 |
| LLaDA 8B | 1107.30B | 4.64e22 | 51.13 | 47.57 | 41.13 | 69.26 | 36.69 | 10.37 |
| LLaDA 8B | 1233.13B | 5.16e22 | 50.52 | 49.72 | 45.05 | 71.49 | 38.97 | 9.76 |
| LLaDA 8B | 1358.95B | 5.69e22 | 54.61 | 53.97 | 49.40 | 74.05 | 48.14 | 17.68 |
| LLaDA 8B | 1547.70B | 6.48e22 | 57.38 | 56.04 | 49.49 | 74.59 | 53.30 | 20.73 |
| LLaDA 8B | 1975.52B | 8.27e22 | 58.52 | 57.87 | 50.68 | 75.35 | - | 19.51 |
自回归基线详细结果
以下是原文 Table 19 的结果:
| Model | Training Tokens | FLOPs | MMLU | CMMLU | ARC-C | PIQA | GSM8K | HumanEval |
| ARM 1B | 37.75B | 2.20e20 | 25.47 | 25.38 | 30.20 | 67.36 | 2.20 | 4.88 |
| ARM 1B | 88.08B | 5.13e20 | 24.67 | 25.23 | 33.96 | 70.02 | 7.51 | 10.37 |
| ARM 1B | 138.41B | 8.06e20 | 29.25 | 27.48 | 33.45 | 70.29 | 8.34 | 9.76 |
| ARM 7B | 17.30B | 6.02e20 | 26.92 | 25.18 | 21.02 | 57.18 | 1.29 | 1.22 |
| ARM 7B | 34.60B | 1.20e21 | 25.83 | 25.38 | 24.07 | 62.84 | 1.59 | 2.44 |
| ARM 7B | 86.50B | 3.01e21 | 24.41 | 24.90 | 25.42 | 71.11 | 2.88 | 7.93 |
| ARM 7B | 173.02B | 6.02e21 | 26.20 | 24.78 | 26.10 | 74.27 | 6.67 | 9.15 |
| ARM 7B | 207.62B | 7.23e21 | 30.36 | 28.86 | 31.86 | 74.48 | 8.57 | 12.80 |
| ARM 7B | 224.92B | 7.83e21 | 29.49 | 32.26 | 31.19 | 74.37 | 8.95 | 8.54 |
| ARM 7B | 242.22B | 8.43e21 | 33.62 | 31.38 | 34.92 | 75.41 | 10.84 | 9.15 |
| ARM 7B | 259.52B | 9.03e21 | 34.11 | 34.20 | 32.88 | 75.19 | 9.33 | 10.98 |
| ARM 7B | 311.43B | 1.08e22 | 35.66 | 35.49 | 36.61 | 75.14 | 11.30 | 10.37 |
| ARM 7B | 363.33B | 1.26e22 | 34.54 | 37.67 | 34.58 | 76.55 | 12.28 | 14.02 |
| ARM 7B | 415.24B | 1.45e22 | 35.37 | 38.37 | 35.25 | 76.39 | 14.40 | 12.80 |
| ARM 7B | 449.84B | 1.57e22 | 39.51 | 39.24 | 34.92 | 76.82 | 14.94 | 14.63 |
| ARM 7B | 519.09B | 1.81e22 | 40.30 | 40.69 | 37.29 | 77.15 | 14.03 | 14.63 |
| ARM 7B | 778.57B | 2.71e22 | 43.33 | 43.50 | 38.31 | 77.53 | 17.59 | 14.63 |
| ARM 7B | 1038.09B | 3.61e22 | 45.06 | 46.12 | 41.69 | 77.86 | 20.02 | 15.85 |
| ARM 7B | 1384.12B | 4.82e22 | 47.63 | 48.18 | 47.80 | 76.93 | 22.82 | 15.24 |
| ARM 7B | 2076.18B | 7.23e22 | 47.68 | 50.85 | 44.07 | 77.37 | 24.79 | 14.63 |
| ARM 7B | 2214.59B | 7.71e22 | 49.26 | 52.08 | 53.56 | 77.69 | 27.37 | 17.07 |
6.3. 消融实验/参数分析
无分类器引导 (Classifier-Free Guidance, CFG)
- CFG 效果: 原文 Table 6 显示,应用
CFG始终能提升LLaDA 8B Base在ARC-C、Hellaswag、TruthfulQA、GPQA和PIQA上的性能。例如,ARC-C从 45.9 提升到 47.9,Hellaswag从 70.5 提升到 72.5。这表明LLaDA与CFG完全兼容,并且能够从中受益,进一步增强生成质量。为了公平比较,主文本中的结果未应用CFG。
采样策略
- 不同采样策略的性能: 原文 Table 7 (
LLaDA 8B Base) 和 Table 8 (LLaDA 8B Instruct) 比较了自回归采样、块扩散采样和纯扩散采样。- 块扩散优于自回归: 块扩散采样始终优于自回归采样。例如,在
LLaDA 8B Base的BBH上,自回归为 38.1,块扩散 (L'=32) 为 45.7。 - 块扩散 LLaDA 进一步提升: 块扩散
LLaDA采样在标准块扩散采样基础上进一步提升。 - 纯扩散采样最优: 纯扩散采样通常实现最佳的整体性能。例如,
LLaDA 8B Base在BBH上纯扩散为 49.7,优于所有块扩散变体。 - 块长度影响: 块扩散采样的性能随着块长度的增加而提高 (Table 7),例如
BBH从 的 37.3 提升到 的 45.7。
- 块扩散优于自回归: 块扩散采样始终优于自回归采样。例如,在
- 低置信度重掩码 vs. 随机重掩码: 原文 Table 9 比较了两种重掩码策略。
- 低置信度重掩码的优势: 低置信度重掩码策略始终优于随机重掩码策略。例如,在
BBH上,随机重掩码为 32.1,低置信度重掩码为 45.0。作者假设低置信度重掩码类似于ARM中的退火采样,通过减少生成句子的多样性来提高准确性。
- 低置信度重掩码的优势: 低置信度重掩码策略始终优于随机重掩码策略。例如,在
LLaDA Instruct的特殊处理:LLaDA 8B Instruct在纯扩散采样时,由于SFT数据中 词元的大量填充,导致输出中 比例过高,生成过短。为了缓解此问题,对于HumanEval、MBPP、GSM8K、Math和GPQA,采样时将 词元的置信度分数设为零,以维持合适的 词元比例。
生成长度
- 对长度超参数不敏感: 原文 Table 10 显示,
LLaDA 8B Base在不同生成长度 (256, 512, 1024) 下的性能结果并不敏感。例如,GSM8K分别为 70.0、70.8 和 70.3。这表明LLaDA在处理可变长度输出方面具有良好的鲁棒性,尽管生成长度仍是一个用户指定的超参数。
7. 总结与思考
7.1. 结论总结
本文引入了 LLaDA,一个从头开始训练的 80 亿参数级别的扩散语言模型。该研究的核心贡献在于挑战了大型语言模型 (LLMs) 核心能力(如可扩展性、上下文学习和指令遵循)必然依赖于自回归模型 (ARMs) 的普遍假设。
通过严格遵循预训练和监督微调 (SFT) 范式,LLaDA 在广泛的基准测试(包括通用任务、数学、代码和中文)中展现了强大的可扩展性,并与作者自建的 ARM 基线模型表现相当。特别是,LLaDA 8B 在上下文学习方面与 LLaMA3 8B 具有竞争力,并且在 SFT 后在多轮对话等案例研究中展现出令人印象深刻的指令遵循能力。
更值得注意的是,LLaDA 有效地解决了 ARM 的一个固有局限——“反转诅咒”问题,在反转诗歌补全任务中超越了 GPT-4o。这得益于其基于掩码预测的双向建模能力,使得模型在处理词元时不带有方向性偏置。
这些发现不仅揭示了扩散模型在大规模语言建模方面的巨大潜力,而且为自然语言处理领域开辟了新的范式,展示了高度的科学创新。
7.2. 局限性与未来工作
本文的作者也指出了 LLaDA 当前的局限性及其未来可能的研究方向:
- 生成长度: 目前生成长度是一个用户指定的超参数。尽管
LLaDA对此不敏感,但未来的工作可以探索自适应的生成长度机制,以提供更高效的解决方案。 - 计算资源限制: 由于计算限制,
LLaDA与ARM基线在完全相同数据集和架构下的直接比较仅限于较低的计算预算。未能将ARM基线扩展到与LLaDA相同的最大规模,使得某些比较可能不够全面。 - 架构优化:
LLaDA未设计专门的注意力机制或位置嵌入,也没有应用系统级架构优化(如KV Cache)。未来的工作可以探索这些优化,以进一步提升性能和效率。 - 推理效率和可控性: 在推理方面,更高效和可控的采样算法仍处于初步阶段。例如,需要改进采样速度和生成质量之间的权衡。
- 强化学习对齐:
LLaDA尚未进行基于强化学习 (RL) 的对齐,这对于提升LLM性能和与人类意图对齐至关重要。将其集成到RL框架中是未来的重要方向。 - 模型和数据规模:
LLaDA的模型规模和训练数据量仍小于领先的ARM模型(如LLaMA3、Qwen2.5等)。进一步的规模扩展将有助于全面评估其能力。 - 多模态能力:
LLaDA处理多模态数据的能力尚未探索。 - 应用集成: 其对提示词工程 (prompt tuning) 技术的影响以及集成到基于智能体 (agent-based) 系统中的潜力仍未完全理解。
- 后训练: 需要对
LLaDA进行系统性的后训练研究(例如O1类系统),以进一步释放扩散语言模型的潜力。
7.3. 个人启发与批判
个人启发
这篇论文极具启发性,它在 LLM 领域开辟了一条充满希望的新路径。
- 范式突破: 最重要的启发是,它有力地证明了**扩散模型完全有能力在大规模语言建模中扮演核心角色,并且能够展现出与自回归模型相同的核心
LLM能力。**这打破了业界对LLM必须是ARM的固有认知,开启了探索非自回归LLM的新篇章。 - 双向建模的价值:
LLaDA在解决“反转诅咒”上的卓越表现,凸显了双向建模在某些推理任务中的显著优势。对于需要更深层次理解和更灵活生成顺序的任务,扩散模型可能提供更鲁棒的解决方案。 - 理论与实践结合: 论文严格遵循了掩码扩散模型的理论基础,优化似然下界,而非使用启发式方法。这种理论严谨性是其在大规模成功扩展的关键,也为未来的扩散
LLM研究奠定了坚实基础。 - 工程实践意义:
LLaDA在与ARM相似的计算预算下从头训练 80 亿参数模型,并取得可比结果,展示了其在实际部署中的潜力。尤其是在某些任务上(如数学和中文)的优势,以及采样效率的灵活权衡,预示着扩散LLM可能在特定应用场景中具有独特价值。 - 推动领域发展: 这项工作无疑会激励更多的研究者投入到扩散
LLM的研究中,加速该领域的创新,并可能带来更高效、更通用的语言智能系统。
批判
尽管 LLaDA 取得了显著成就,但仍有一些方面值得批判性思考:
-
比较的公平性与绝对性能:
- 训练数据差异: 论文指出
LLaDA的训练词元 (2.3T) 远少于LLaMA3(15T) 和Qwen2.5(18T)。尽管LLaDA表现出竞争力,但这种数据量的差距可能意味着LLaDA在数据效率上更高,但也可能意味着在绝对性能上限上,目前的LLaDA仍有追赶空间。如果LLaDA能在与LLaMA3相同的数据规模上训练,其性能是否会超越LLaMA3仍有待验证。 - 计算资源差异: 论文也承认,由于计算限制,未能将
ARM基线扩展到与LLaDA相同的最大规模进行严格的同架构比较。这使得在最高计算预算下的“可扩展性”结论可能存在轻微的保留。
- 训练数据差异: 论文指出
-
效率挑战:
- 推理速度与
KV Cache: 论文承认LLaDA尚未使用KV Cache等ARM中普遍的推理优化技术。虽然LLaDA在某些任务上展示了比无KV Cache的LLaMA3更高的吞吐量,但与带有KV Cache的高度优化的ARM相比,其推理速度仍是一个挑战。扩散模型的迭代采样过程通常比ARM的单步生成更慢。如何进一步提升LLaDA的推理效率,使其在通用场景下更具竞争力,是关键问题。 - 长序列生成: 扩散模型的迭代性质在生成非常长的序列时,可能会带来累积误差和效率下降问题。
- 推理速度与
-
SFT数据质量: 论文提到SFT后MMLU性能下降,可能与SFT数据质量次优有关。这表明SFT数据的质量对扩散LLM同样关键,需要精心设计和筛选。 -
模型复杂性与可解释性: 尽管扩散模型理论严谨,但其迭代生成过程在某种程度上可能比
ARM的左到右生成更难直观理解和调试。 -
未探索的能力: 许多
LLM的关键能力(如多模态、工具使用、长上下文窗口)尚未在LLaDA中得到充分探索,这为未来工作留下了广阔空间,但也意味着目前LLaDA的通用性仍有待提升。总而言之,
LLaDA是一项开创性的工作,它为LLM领域注入了新的活力,证明了扩散模型的强大潜力。尽管存在一些局限性,但它无疑为未来的非自回归LLM研究指明了方向,并有望在特定应用场景中带来突破。
相似论文推荐
基于向量语义检索推荐的相关论文。