Progressive Distillation for Fast Sampling of Diffusion Models
TL;DR 精炼摘要
本文提出渐进式蒸馏方法,将高步数扩散采样器逐步蒸馏成低步数高效模型,实现采样步数从8192骤降至4步,显著提升采样速度且保持图像质量(如CIFAR-10上4步FID为3.0),为扩散模型加速提供实用方案。
摘要
Diffusion models have recently shown great promise for generative modeling, outperforming GANs on perceptual quality and autoregressive models at density estimation. A remaining downside is their slow sampling time: generating high quality samples takes many hundreds or thousands of model evaluations. Here we make two contributions to help eliminate this downside: First, we present new parameterizations of diffusion models that provide increased stability when using few sampling steps. Second, we present a method to distill a trained deterministic diffusion sampler, using many steps, into a new diffusion model that takes half as many sampling steps. We then keep progressively applying this distillation procedure to our model, halving the number of required sampling steps each time. On standard image generation benchmarks like CIFAR-10, ImageNet, and LSUN, we start out with state-of-the-art samplers taking as many as 8192 steps, and are able to distill down to models taking as few as 4 steps without losing much perceptual quality; achieving, for example, a FID of 3.0 on CIFAR-10 in 4 steps. Finally, we show that the full progressive distillation procedure does not take more time than it takes to train the original model, thus representing an efficient solution for generative modeling using diffusion at both train and test time.
思维导图
论文精读
中文精读
1. 论文基本信息 (Bibliographic Information)
- 标题 (Title): Progressive Distillation for Fast Sampling of Diffusion Models (通过渐进式蒸馏实现扩散模型的快速采样)
- 作者 (Authors): Tim Salimans & Jonathan Ho
- 隶属机构 (Affiliations): Google Research, Brain team
- 发表期刊/会议 (Journal/Conference): 本文为 arXiv 上的预印本。arXiv 是一个开放获取的学术论文预印本平台,在计算机科学等领域具有极高的影响力,许多顶会/顶刊的论文都会先在此发布。
- 发表年份 (Publication Year): 2022
- 摘要 (Abstract): 扩散模型在生成建模方面表现出色,但在感知质量上优于 GAN,在密度估计上优于自回归模型。其主要缺点是采样速度慢,通常需要成百上千次模型评估才能生成高质量样本。本文提出两点贡献来解决此问题:首先,提出了新的扩散模型参数化方法,以在使用少量采样步骤时提高稳定性;其次,提出了一种渐进式蒸馏方法,将一个训练好的、使用多步骤的确定性扩散采样器(教师模型)蒸馏成一个仅需一半采样步骤的新扩散模型(学生模型)。通过反复应用此蒸馏过程,每次将所需采样步骤减半,最终能够将一个需要多达 8192 步的先进采样器,蒸馏到仅需 4 步即可生成高质量图像,且感知质量损失很小(例如,在 CIFAR-10 数据集上 4 步即可达到 3.0 的 FID)。整个渐进式蒸馏过程的训练时间不超过训练原始模型的时间,为扩散模型在训练和测试时都提供了一个高效的解决方案。
- 原文链接 (Source Link):
-
ArXiv 链接: https://arxiv.org/abs/2202.00512
-
发布状态: 预印本 (Preprint)。
-
2. 整体概括 (Executive Summary)
-
研究背景与动机 (Background & Motivation - Why):
- 核心问题: 扩散模型虽然生成质量高,但其核心痛点在于采样速度极慢。生成一张高质量图像通常需要数百甚至数千次迭代(即模型前向传播),这极大地限制了其在实际应用中的部署和使用,尤其是在对延迟敏感的场景。
- 重要性与挑战: 在无条件或类别条件下的图像生成任务中,这个问题尤为突出。现有的加速方法要么效果有限,要么会显著牺牲生成质量。因此,如何在不显著降低样本质量的前提下,将采样步骤减少几个数量级,是推动扩散模型实用化的关键挑战 (Gap)。
- 切入点: 本文的创新思路不是去寻找一个更优的通用 ODE 求解器,而是将一个已经训练好的、精准但缓慢的采样过程(教师)的行为,“教会”给一个新的模型(学生),让学生模型学会用更少的步骤“跳得更远”,从而实现加速。这个过程被设计为渐进式的 (progressive),逐步减半采样步数,以保证稳定性和效果。
-
核心贡献/主要发现 (Main Contribution/Findings - What):
-
提出了新的扩散模型参数化方法: 论文发现标准的 -prediction 参数化方法在少步数采样时(即低信噪比区域)不稳定。为此,作者提出了三种新的、更稳定的参数化方案(直接预测 、预测 等)和新的损失权重函数,为少步数生成奠定了基础。
-
提出了渐进式蒸馏 (
Progressive Distillation) 算法: 这是本文最核心的贡献。该算法可以将一个 步的确定性采样器(教师)蒸馏成一个 步的采样器(学生)。其关键思想是:让学生模型学习一步完成教师模型两步的采样效果。这个过程可以被反复迭代,从一个非常慢(如 8192 步)的模型开始,逐步蒸馏到极快(如 4 步)的模型。 -
实现了SOTA级别的快速采样效果: 实验证明,该方法可以在极少的采样步数下(4-8步)保持非常高的样本质量。例如,在 CIFAR-10 上仅用 4 步就达到了 3.0 的 FID 分数,在当时远超其他快速采样方法。
-
蒸馏过程高效: 整个渐进式蒸馏过程的总计算成本不高于从头训练原始模型,使其成为一种兼具训练和推理效率的实用方案。
-
3. 预备知识与相关工作 (Prerequisite Knowledge & Related Work)
-
基础概念 (Foundational Concepts):
- 扩散模型 (Diffusion Models): 一类生成模型。其核心思想包含两个过程:
- 前向过程 (Forward Process): 对一张真实图像逐步、反复地添加高斯噪声,直到图像完全变成纯噪声。这个过程是固定的、无需学习的。
- 反向过程 (Reverse Process): 训练一个神经网络(通常是 U-Net 结构),学习从纯噪声开始,逐步地、迭代地去除噪声,最终恢复出一张清晰的图像。这个去噪过程就是生成过程。
- 信噪比 (Signal-to-Noise Ratio, SNR): 在扩散模型中,任意一个加噪步骤的中间状态 可以表示为 ,其中 是原始图像, 是噪声。 和 分别控制信号和噪声的比例。信噪比通常定义为 ,其对数形式 是衡量噪声水平的关键指标。 越大,噪声越多,SNR 越低。
- DDPM 与 DDIM 采样器:
- DDPM (Denoising Diffusion Probabilistic Models): 原始的扩散模型采样器,每一步去噪都引入新的随机噪声,因此是随机的 (stochastic)。它严格模拟了反向过程的马尔可夫链。
- DDIM (Denoising Diffusion Implicit Models): DDIM 提出了一种更广义的、确定性的 (deterministic) 采样方式。它在采样过程中不引入新的随机噪声,使得从同一个初始噪声出发总能得到相同的最终图像。这种确定性是本文蒸馏方法能够成立的基础,因为它为学生模型提供了一个确定的、可学习的目标。
- 概率流常微分方程 (Probability Flow ODE): 扩散模型的反向去噪过程可以被抽象为一个连续时间的常微分方程 (ODE)。从这个视角看,任何采样器(如 DDPM 或 DDIM)本质上都是在用一种数值方法(如欧拉法)来求解这个 ODE。DDIM 被证明是求解该 ODE 的一种高效积分方法。
- 知识蒸馏 (Knowledge Distillation): 一种模型压缩技术。其核心思想是训练一个轻量级的“学生模型” (student model) 来模仿一个复杂的、性能更优的“教师模型” (teacher model) 的行为,而不是直接从原始数据标签中学习。
- 扩散模型 (Diffusion Models): 一类生成模型。其核心思想包含两个过程:
-
前人工作 (Previous Works):
- DDIM (Song et al., 2021a): 提供了确定性采样,是本文方法的技术基石。
- 快速 SDE 求解器 (Jolicoeur-Martineau et al., 2021): 尝试使用更高级的数值方法来求解扩散过程的随机微分方程 (SDE),以减少采样步数。
- 调整时间步 (Nichol & Dhariwal, 2021; Kong & Ping, 2021): 通过重新设计或学习采样的时间步序列,来在更少的步数内完成生成,但效果提升有限。
- 直接蒸馏 (Luhman & Luhman, 2021): 这是与本文最相关的工作。他们也使用知识蒸馏,但他们的方法是“一步到位”,即直接训练一个单步学生模型来模仿一个多步(如1000步)教师模型。这种方法的缺点是:为了训练学生,需要先用教师模型生成大量“噪声-图像”对的数据集,当教师模型步数非常多时,这个过程的计算成本是极其昂贵的。
-
技术演进 (Technological Evolution): 该领域的技术演进脉络可以概括为:从追求“生成质量”到兼顾“生成效率”。
- 初代 (DDPM): 实现了高生成质量,但采样极慢(~1000步)。
- 改进采样器 (DDIM, SDE Solvers): 提出了更高效的数值求解方法,可以在一定程度上减少步数(如50-200步)而不过多损失质量。
- 模型蒸馏 (Luhman & Luhman, 本文): 引入知识蒸馏思想,不再局限于改进求解器,而是训练一个新模型来“学会”快速采样。
-
差异化分析 (Differentiation): 与最相关的
Luhman & Luhman (2021)相比,本文的核心创新在于“渐进式” (Progressive):-
计算效率高: 本文的方法每次只蒸馏“两步变一步”,而不需要完整运行教师模型的全部采样步骤(例如 8192 步)。因此,蒸馏成本与教师模型的总步数无关,总蒸馏时间仅呈对数级增长,非常高效。
-
学习过程更稳定: 相比于让学生模型一次性学习一个从纯噪声到清晰图像的巨大跳跃,渐进式蒸馏让学生模型每次只学习一个较小的、更易于掌握的“两步跳”,使得训练过程更稳定,最终效果也更好。
-
4. 方法论 (Methodology - Core Technology & Implementation Details)
本部分详细拆解论文的核心技术:渐进式蒸馏 和 稳定的模型参数化。
方法原理 (Methodology Principles):
核心思想是“步数减半,效果对齐”。我们有一个需要 2N 步的教师模型,目标是训练一个只需 步的学生模型。我们希望学生模型的一步能够精确地等价于教师模型的两步。通过反复进行这个“两步并一步”的训练过程,模型的采样步数就能以 的方式指数级下降。
上图展示了两轮渐进式蒸馏。第一轮,一个 4 步的采样器被蒸馏成一个 2 步的采样器。第二轮,这个 2 步的采样器作为新的教师,被蒸馏成一个单步的生成模型。整个过程可以理解为,让模型学会用更少的积分步数来求解概率流 ODE。
方法步骤与流程 (Steps & Procedures):
Algorithm 2 (Progressive distillation) 详细描述了此过程。以下是其分步解读:
-
初始化:
- 从一个训练好的、需要 步的教师模型 开始。
- 设定当前学生模型的目标步数为 。
- 初始化学生模型 的权重,使其与教师模型完全一致,即 。
-
蒸馏训练循环:
-
a. 准备输入: 从数据集中随机抽取一个真实样本 。随机选择一个离散的时间步 (其中 从
1到 中均匀选取)。然后,像标准扩散模型训练一样,生成带噪样本 。这个 将作为教师和学生的共同起点。 -
b. 教师执行两步采样: 使用教师模型 ,从 出发,执行两次 DDIM 采样步骤,得到最终位置 。
- 第一步: 从时间 采样到 。
- 第二步: 从时间 采样到 。
-
c. 计算学生的目标: 这是最关键的一步。我们现在有了学生模型的单步起点 和单步终点 。我们希望学生模型 能够预测一个 ,使得通过单步 DDIM 公式,可以直接从 跳到 。为此,我们逆向求解单步 DDIM 公式,计算出这个理想的预测目标 。
-
d. 计算损失并更新学生: 学生模型在输入 时会做出自己的预测 。损失函数就是学生预测 与我们计算出的理想目标 之间的加权均方误差。使用梯度下降更新学生模型的参数 。
-
-
迭代:
-
当学生模型训练收敛后,它就成了一个合格的 步采样器。
-
我们将这个学生模型设为新的教师模型()。
-
将目标采样步数再次减半(),然后重复步骤 1-3。
这个过程持续进行,直到采样步数减少到我们期望的数值(如 4 步、2 步或 1 步)。
-
数学公式与关键细节 (Mathematical Formulas & Key Details):
-
DDIM 采样公式 (DDIM Update Rule): 从时间 到 () 的单步 DDIM 更新规则为: 其中 是模型在时间 对原始清晰图像 的预测。
-
蒸馏目标 的计算: 学生模型的目标是学习一步从 到达 。根据 DDIM 公式,这一步可以写成: 这是一个关于 的线性方程。通过简单的代数变换,我们可以解出 : 这个 就是学生模型需要学习的、能够实现“两步并一步”效果的伪目标 (pseudo-target)。
-
稳定的模型参数化与损失函数 (Stable Parameterization and Loss): 标准的扩散模型让网络预测噪声 ,即 。当采样步数很少时,模型会在非常接近纯噪声()的区域进行评估。此时, 作为分母会导致预测极不稳定。 为了解决这个问题,论文提出了几种替代方案:
-
直接预测 (Predicting directly): 网络直接输出 。
-
预测 (Predicting ): 定义 ,让网络预测 。这种参数化在各种信噪比下都表现得非常稳定。
-
混合预测 (Predicting both and ): 网络输出两个头,分别预测 和 ,然后根据信噪比加权组合。
同时,论文也修改了损失的权重 ,以确保在低信噪比区域()损失权重不为零,例如使用 权重,即 。
*上图左侧展示了不同损失权重策略对重建损失的加权系数随 log-SNR 的变化。标准的SNR权重(即\epsilon-prediction 的等价权重)在 log-SNR 为负无穷时权重趋于0,不利于蒸馏。而Truncated SNR和SNR+1$ 在低信噪比区域给予了足够的权重。右图展示了结合余弦时间表后,实际训练中不同信噪比区间受到的总权重。* 该图像是论文中的图表,左图展示了不同训练损失权重函数在不考虑调度的情况下,重建误差 对应的对数权重与对数信噪比(log SNR)的关系;右图则展示考虑余弦调度 后的权重变化。
-
5. 实验设置 (Experimental Setup)
-
数据集 (Datasets): 实验在多个标准图像生成数据集上进行,涵盖不同分辨率和内容:
CIFAR-10: 32x32 分辨率,10个类别的物体图像,常用于快速验证和消融实验。ImageNet: 64x64 分辨率,类别条件生成,是更具挑战性的基准。LSUN Bedrooms: 128x128 分辨率,卧室场景图像。LSUN Church-Outdoor: 128x128 分辨率,户外教堂场景图像。 这些数据集的选择覆盖了从简单到复杂的各种场景,能够全面验证方法的有效性和泛化能力。
-
评估指标 (Evaluation Metrics):
-
FID (Fréchet Inception Distance):
- 概念定义: FID 是衡量生成图像质量和多样性的黄金标准。它通过比较真实图像集和生成图像集在 Inception-v3 网络某一中间层提取的特征向量的分布来计算得分。具体来说,它计算两个高斯分布(分别拟合真实和生成图像的特征)之间的 Fréchet 距离。FID 分数越低,表示生成图像的分布与真实图像的分布越接近,即生成图像的质量和多样性越好。
- 数学公式:
- 符号解释:
- 和 分别代表真实图像集和生成图像集。
- 和 是真实图像和生成图像特征向量的均值。
- 和 是真实图像和生成图像特征向量的协方差矩阵。
- 表示欧几里得距离的平方。
- 表示矩阵的迹(主对角线元素之和)。
-
IS (Inception Score):
- 概念定义: IS 主要衡量生成图像的两个方面:清晰度 (Clarity) 和 多样性 (Diversity)。清晰度通过衡量单个生成图像的类别预测分布的熵来评估(低熵意味着模型对图像内容很确定)。多样性通过衡量所有生成图像的平均类别预测分布的熵来评估(高熵意味着生成了多种类别的图像)。IS 分数越高越好。
- 数学公式:
- 符号解释:
- 是生成图像集。
- 是从 中采样的单个图像。
- 是 Inception 模型对图像 的类别预测概率分布。
- 是所有生成图像的平均类别预测概率分布。
- 是 KL 散度,用于衡量两个概率分布的差异。
-
-
对比基线 (Baselines):
-
Undistilled DDIM sampler: 使用未经蒸馏的原始教师模型,但采用与学生模型相同数量的采样步骤。这是最重要的基线,用于直接衡量蒸馏带来的提升。
-
Optimized stochastic sampler: 使用未经蒸馏的原始教师模型,但采用经过优化的随机采样策略,并设置与学生模型相同的步数。
-
其他文献中的快速采样方法: 包括
DDIM(Song et al., 2021a),FastDPM(Kong & Ping, 2021),LSGM(Vahdat et al., 2021) 等,用于在 CIFAR-10 上进行横向性能比较。
-
6. 实验结果与分析
消融实验/参数分析 (Ablation Studies / Parameter Analysis):
论文首先通过消融实验验证了新提出的模型参数化和损失权重函数的有效性。
-
转录的表格 1: 不同参数化和损失权重在无条件 CIFAR-10 上的表现 (FID/IS)
Network Output Loss Weighting Stochastic sampler (FID/IS) DDIM sampler (FID/IS) (x, ε) combined SNR 2.54/9.88 2.78/9.56 Truncated SNR 2.47/9.85 2.76/9.49 SNR+1 2.52/9.79 2.87/9.45 x SNR 2.65/9.80 2.75/9.56 Truncated SNR 2.53/9.92 2.51/9.58 SNR+1 2.56/9.84 2.65/9.52 ε SNR 2.59/9.84 2.91/9.52 Truncated SNR N/A N/A SNR+1 2.56/9.77 3.27/9.41 v SNR 2.65/9.86 3.05/9.56 Truncated SNR 2.45/9.80 2.75/9.52 SNR+1 2.49/9.77 2.87/9.43 -
分析:
- 从上表可以看出,除了 预测与
Truncated SNR权重组合(导致训练不稳定)外,所有新提出的稳定参数化方案(, , )结合新的损失权重(Truncated SNR, )都取得了非常优秀的性能,其 FID/IS 分数与标准的 预测相当甚至更优。 - 这证明了这些新参数化是有效且稳定的,为后续的渐进式蒸馏奠定了坚实的基础。其中,直接预测 在这个实验中略占优势,而预测 理论上最稳定。
- 从上表可以看出,除了 预测与
核心结果分析 (Core Results Analysis):
-
渐进式蒸馏的有效性 (图 4):
该图像是由四个子图组成的折线图,展示了不同采样步骤下蒸馏模型(Distilled)、DDIM采样器和随机采样器在CIFAR-10、64×64 ImageNet、128×128 LSUN卧室和128×128 LSUN教堂数据集上的FID指标表现,横轴为采样步骤数量,纵轴为FID值,展示了蒸馏模型在采样效率和质量上的优势。- 分析: 上图是本文最重要的结果之一。它清晰地展示了在所有四个数据集上,渐进式蒸馏 (
Distilled) 方法的性能曲线(蓝色)远优于两个基线:未经蒸馏的 DDIM (DDIM)(橙色)和随机采样器 (Stochastic)(绿色)。 - 当采样步数减少到 128 步以下时,基线方法的 FID 分数急剧上升,意味着图像质量严重恶化。
- 相比之下,经过渐进式蒸馏的模型即使在步数减少到 8 步甚至 4 步时,其 FID 分数仍然保持在非常低的水平,接近于原始模型使用数百步才能达到的效果。这强有力地证明了渐进式蒸馏在实现高质量快速采样方面的巨大优势。
- 分析: 上图是本文最重要的结果之一。它清晰地展示了在所有四个数据集上,渐进式蒸馏 (
-
与其他方法的比较 (表 2):
-
转录的表格 2: CIFAR-10 上的快速采样方法比较
Method Model evaluations FID Progressive Distillation (ours) 1 9.12 2 4.51 4 3.00 Knowledge distillation (Luhman & Luhman, 2021) 1 9.36 DDIM (Song et al., 2021a) 10 13.36 50 4.67 FastDPM (Kong & Ping, 2021) 10 9.90 50 3.20 LSGM (Vahdat et al., 2021) 138 2.10 -
分析:
- 在极少步数(1-4 步)的赛道上,本文的方法取得了当时的最佳性能。仅用 4 步就达到了 3.00 的 FID,而 DDIM 需要 50-100 步才能达到相近水平。
- 与同样使用蒸馏的
Luhman & Luhman相比,本文方法在 1 步设置下也取得了更好的 FID (9.12 vs 9.36),且计算成本远低于对方。
-
-
视觉效果 (图 3):
该图像是图3的示意图,展示了在不同采样步数(256步、4步、1步)下,利用蒸馏模型生成的经过固定随机种子条件下的64×64犬类图像样本,图像质量随采样步数减少而基本保持稳定。-
分析: 上图展示了在 ImageNet 数据集上,从 256 步蒸馏到 4 步、再到 1 步的生成样本。可以看出,即使采样步数从 256 骤降到 4,生成图像的整体结构、细节和质量都得到了很好的保持,几乎没有肉眼可见的质量损失。这直观地展示了渐进式蒸馏的强大效果。
-
7. 总结与思考 (Conclusion & Personal Thoughts)
-
结论总结 (Conclusion Summary): 本文提出了一种名为渐进式蒸馏的高效算法,能够将扩散模型的采样步数降低数个数量级(例如从 8192 步降至 4 步),同时基本不损失样本的感知质量。该方法通过迭代地训练学生模型“一步完成教师模型的两步”,实现了采样过程的指数级加速。结合新提出的稳定模型参数化方法,该工作成功解决了扩散模型采样速度慢的核心痛点,极大地提升了其在实际应用中的可行性。
-
局限性与未来工作 (Limitations & Future Work):
- 模型尺寸: 当前工作中的学生模型与教师模型具有相同的网络结构和参数量。未来的一个方向是探索将模型蒸馏到更小、更轻量级的学生模型中,从而在减少采样步数的同时,进一步降低单步计算的成本。
- 数据模态: 本文的工作主要集中在图像生成上。未来可以将渐进式蒸馏的思想应用到其他数据模态的扩散模型上,例如音频生成 (
WaveGrad) 或文本生成等。
-
个人启发与批判 (Personal Insights & Critique):
-
启发:
- “摊销计算”思想的威力: 本质上,渐进式蒸馏是将昂贵的、需要多次迭代的数值求解过程(ODE 积分),通过学习“摊销” (amortize) 到了神经网络的权重中。模型不再需要机械地走小步,而是“学会”了如何走更准、更远的大步。这种思想在机器学习中具有普遍的指导意义。
- 化繁为简的智慧: 相比于
Luhman & Luhman的“一步到位”式蒸馏,本文“渐进式”的策略堪称精妙。它将一个极其困难的学习任务(从纯噪声直接映射到图像)分解为一系列简单得多的小任务(两步并一步),从而大大降低了学习难度,提升了稳定性和最终效果。这体现了在解决复杂问题时,分而治之、逐步推进的工程智慧。
-
批判性思考:
- 误差累积问题: 渐进式蒸馏是一个链式过程,前一阶段蒸馏出的学生会成为下一阶段的教师。那么,在每一阶段蒸馏过程中不可避免会产生微小的误差,这些误差是否会在多轮迭代后被累积和放大?论文并未对此进行深入的理论或实验分析。
- 对确定性采样的依赖: 该方法高度依赖于 DDIM 这类确定性采样器,因为它需要一个固定的采样轨迹作为学习目标。这使得它难以直接应用于随机采样过程,虽然附录中也进行了尝试,但效果不如确定性采样。这在一定程度上限制了其通用性。
- 训练成本: 尽管论文声称总训练成本不超过原始模型,但它仍然需要一个完整的多阶段训练流程,对于资源有限的研究者来说,仍是一笔不小的开销。探索更轻量级的蒸馏方案,或者能够“在线”蒸馏的方法,可能是一个有价值的方向。
-
相似论文推荐
基于向量语义检索推荐的相关论文。