
Progressive Distillation for Fast Sampling of Diffusion Models

Published: 02/02/2022
This analysis is AI-generated and may not be fully accurate. Please refer to the original paper.

TL;DR Summary

This work introduces progressive distillation, halving diffusion model sampling steps successively from 8192 to 4, enabling fast sampling with minimal quality loss, exemplified by a 4-step FID of 3.0 on CIFAR-10, significantly accelerating diffusion model inference.

Abstract

Diffusion models have recently shown great promise for generative modeling, outperforming GANs on perceptual quality and autoregressive models at density estimation. A remaining downside is their slow sampling time: generating high quality samples takes many hundreds or thousands of model evaluations. Here we make two contributions to help eliminate this downside: First, we present new parameterizations of diffusion models that provide increased stability when using few sampling steps. Second, we present a method to distill a trained deterministic diffusion sampler, using many steps, into a new diffusion model that takes half as many sampling steps. We then keep progressively applying this distillation procedure to our model, halving the number of required sampling steps each time. On standard image generation benchmarks like CIFAR-10, ImageNet, and LSUN, we start out with state-of-the-art samplers taking as many as 8192 steps, and are able to distill down to models taking as few as 4 steps without losing much perceptual quality; achieving, for example, a FID of 3.0 on CIFAR-10 in 4 steps. Finally, we show that the full progressive distillation procedure does not take more time than it takes to train the original model, thus representing an efficient solution for generative modeling using diffusion at both train and test time.


In-depth Reading

English Analysis

1. Bibliographic Information

  • Title: Progressive Distillation for Fast Sampling of Diffusion Models
  • Authors: Tim Salimans & Jonathan Ho (Google Research, Brain team)
  • Journal/Conference: The paper was submitted to the International Conference on Learning Representations (ICLR) and published in 2022. ICLR is a top-tier, highly competitive conference in the field of deep learning, known for publishing foundational and impactful research.
  • Publication Year: 2022
  • Abstract: The paper addresses the significant drawback of diffusion models: their slow sampling speed, which requires hundreds or thousands of function evaluations. The authors introduce two key contributions. First, they propose new model parameterizations that improve stability when using a small number of sampling steps. Second, they present progressive distillation, an iterative method where a trained, multi-step deterministic sampler (the "teacher") is distilled into a new diffusion model (the "student") that requires half the number of steps. By repeatedly applying this process, they successfully reduce the sampling steps from as many as 8192 down to as few as 4, while maintaining high perceptual quality on benchmarks like CIFAR-10, ImageNet, and LSUN. For example, they achieve a Fréchet Inception Distance (FID) of 3.0 on CIFAR-10 in just 4 steps. The authors also show that the entire distillation process is computationally efficient, taking no more time than the initial training of the original model.
  • Original Source Link: https://arxiv.org/abs/2202.00512
  • PDF Link: http://arxiv.org/pdf/2202.00512v2

2. Executive Summary

  • Background & Motivation (Why):

    • Core Problem: Diffusion models, despite achieving state-of-the-art results in generative modeling (outperforming GANs on image quality and autoregressive models on density estimation), are notoriously slow at generating samples. This is because they traditionally require an iterative denoising process involving hundreds or thousands of steps, where a neural network is evaluated at each step. This computational cost is a major barrier to their practical application.
    • Importance: Fast sample generation is crucial for real-world use cases of generative models, such as interactive content creation, real-time synthesis, and large-scale data augmentation. The slowness of diffusion models limits their deployment in resource-constrained environments or time-sensitive tasks.
    • Fresh Angle: Instead of merely optimizing the sampling process of a fixed model, this paper proposes to create entirely new, faster models through distillation. The core idea is to "bake" the complex, multi-step sampling logic of a slow, high-quality "teacher" model into a "student" model that learns to take larger, more effective steps. The "progressive" nature of this distillation allows for a dramatic and controlled reduction in sampling time.
  • Main Contributions / Findings (What):

    1. New Stable Parameterizations: The authors identify that the standard parameterization of diffusion models (predicting the noise ε) becomes unstable with few sampling steps, especially at low signal-to-noise ratios. They propose and validate alternative parameterizations (e.g., predicting the clean image x directly, or a velocity term v) that are more robust and suitable for few-step sampling.
    2. Progressive Distillation Algorithm: This is the central contribution. It is an iterative algorithm that distills a diffusion model sampler with N steps into a new model that achieves similar quality in N/2 steps. The process works by training a "student" model to replicate the output of two "teacher" steps in a single, larger step. This procedure can be repeated, halving the number of steps at each stage (e.g., 1024 → 512 → 256 → ... → 4).
    3. State-of-the-Art Fast Sampling: The method achieves remarkable results. On CIFAR-10, it generates high-quality images with an FID of 3.0 in just 4 steps, and on class-conditional ImageNet 64x64, it produces excellent samples in as few as 4-8 steps. This represents an orders-of-magnitude speedup over previous diffusion model samplers.
    4. Efficient Training: The entire progressive distillation pipeline is shown to be computationally efficient, with a total training time comparable to or less than that of the original, slow model.

3. Prerequisite Knowledge & Related Work

  • Foundational Concepts:

    • Diffusion Models: These are generative models that learn to create data by reversing a gradual noising process.
      • Forward Process: A clean data sample (e.g., an image x) is progressively corrupted by adding Gaussian noise over a series of timesteps t. At the final timestep t = 1, the data becomes pure Gaussian noise. This process is fixed and does not involve learning.
      • Reverse Process: A neural network is trained to reverse this process. Given a noisy image z_t at timestep t, the model learns to predict a slightly less noisy version z_s (where s < t) or, more commonly, to predict the original clean image x or the noise ε that was added.
      • Sampling: To generate a new image, one starts with pure noise z_1 and iteratively applies the trained network to denoise it step by step until a clean image z_0 is obtained.
    • Denoising Diffusion Implicit Models (DDIM): A type of sampler for diffusion models introduced by Song et al. (2021a). Unlike traditional ancestral samplers, which are stochastic (they add noise at each step), DDIM provides a deterministic sampling path: for a given starting noise, it always produces the same final image. This determinism is crucial for the distillation method proposed in this paper. The paper shows that the DDIM update rule can be interpreted as a numerical solver for an Ordinary Differential Equation (ODE) known as the probability flow ODE.
    • Knowledge Distillation: A machine learning technique where a large, complex model (the "teacher") is used to train a smaller, faster model (the "student"). The student is trained not just on the ground-truth labels, but on the "soft" outputs or internal representations of the teacher. In this paper's context, the teacher is a slow, multi-step DDIM sampler, and the student is a new diffusion model that learns to emulate the teacher's behavior in fewer steps. Here the student has the same architecture as the teacher, so the speedup comes from taking fewer steps rather than from a smaller network.
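The forward process, the reverse-process prediction of x, and the deterministic DDIM sampler described above can be sketched in a few lines of Python. This is a toy, not the paper's setup: it assumes scalar data drawn from a 1-D Gaussian N(0, data_std²), a cosine schedule α_t = cos(πt/2), σ_t = sin(πt/2), and a hypothetical `toy_denoiser` that returns the exact E[x | z_t] for that Gaussian in place of a trained network. Because the toy denoiser is exact, the only remaining error is DDIM's discretization error, which shrinks as the number of steps grows.

```python
import math


def alpha_sigma(t: float) -> tuple[float, float]:
    """Assumed cosine schedule: alpha_t = cos(pi*t/2), sigma_t = sin(pi*t/2), so alpha^2 + sigma^2 = 1."""
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)


def forward_noise(x: float, eps: float, t: float) -> float:
    """Forward process: z_t = alpha_t * x + sigma_t * eps (fixed, nothing is learned)."""
    a, s = alpha_sigma(t)
    return a * x + s * eps


def toy_denoiser(z: float, t: float, data_std: float) -> float:
    """Hypothetical stand-in for a trained network: exact E[x | z_t] for 1-D data x ~ N(0, data_std^2)."""
    a, s = alpha_sigma(t)
    return a * data_std**2 * z / (a**2 * data_std**2 + s**2)


def ddim_step(z_t: float, x_hat: float, t: float, s: float) -> float:
    """Deterministic DDIM update from time t to an earlier time s < t (requires t > 0)."""
    a_t, sig_t = alpha_sigma(t)
    a_s, sig_s = alpha_sigma(s)
    return a_s * x_hat + (sig_s / sig_t) * (z_t - a_t * x_hat)


def ddim_sample(eps: float, num_steps: int, data_std: float) -> float:
    """Map noise z_1 = eps to a sample z_0 using num_steps equally spaced DDIM steps."""
    z = eps  # alpha_1 = cos(pi/2) is ~0, so z_1 is pure noise
    for i in range(num_steps, 0, -1):
        t, s = i / num_steps, (i - 1) / num_steps
        z = ddim_step(z, toy_denoiser(z, t, data_std), t, s)
    return z


if __name__ == "__main__":
    data_std, eps = 3.0, 0.7
    exact = data_std * eps  # where the probability flow ODE maps eps in this toy
    for n in (1, 4, 64, 1024):
        print(n, ddim_sample(eps, n, data_std), exact)
```

With many steps the sampler approaches the ODE solution; with a single step it collapses to the toy denoiser's prediction at t = 1, which is 0 (the mean of the data). That gap at small step counts is what the distillation procedure is designed to close.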
  • Previous Works & Differentiation:

    • Slow Samplers (DDPM, SDE-based): Early diffusion models like DDPM (Ho et al., 2020) and score-based models (Song et al., 2021c) required thousands of steps for high-quality generation.
    • Faster Samplers (DDIM): DDIM (Song et al., 2021a) was a major step forward, enabling good quality with fewer steps (e.g., 50-100). However, quality still degrades significantly with very few steps (< 20). The proposed method builds on DDIM but fundamentally changes the model itself to accommodate even faster sampling.
    • Sampler Optimization: Other works focused on optimizing the sampling schedule (e.g., choosing better timesteps) or using more advanced ODE solvers (Jolicoeur-Martineau et al., 2021) for a fixed model. This paper instead retrains the model to be inherently faster.
    • One-Step Distillation (Luhman & Luhman, 2021): This closely related earlier work also used distillation, but distilled a many-step teacher directly into a one-step student. This required pre-generating a large dataset of (noise, image) pairs by running the full teacher sampler, so the cost of building each training target scales linearly with the number of teacher steps, which is prohibitive for very slow teachers (e.g., >1000 steps). In contrast, each round of progressive distillation needs only two teacher steps per training target, independent of the teacher's total step count, and the number of rounds grows only logarithmically with it.
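The cost comparison above can be made concrete with a little arithmetic. The sketch assumes, as the text describes, that direct one-step distillation needs a full teacher sampling run per training target, while each round of progressive distillation needs two teacher DDIM steps per target. It counts only these teacher evaluations and ignores how many training updates each round uses; the function names are hypothetical.

```python
def progressive_schedule(teacher_steps: int, final_steps: int) -> list[int]:
    """Step counts visited by progressive distillation: halve until final_steps is reached."""
    ratio, rem = divmod(teacher_steps, final_steps)
    if rem or ratio < 1 or ratio & (ratio - 1):
        raise ValueError("teacher_steps / final_steps must be a power of two")
    steps = [teacher_steps]
    while steps[-1] > final_steps:
        steps.append(steps[-1] // 2)
    return steps


def teacher_evals_per_target(teacher_steps: int, method: str) -> int:
    """Teacher network evaluations needed to build ONE training target (training-update counts ignored)."""
    if method == "one_step":     # direct distillation: a full teacher sampling run per (noise, image) pair
        return teacher_steps
    if method == "progressive":  # two teacher DDIM steps per student target, in every round
        return 2
    raise ValueError(method)


if __name__ == "__main__":
    sched = progressive_schedule(8192, 4)
    print("rounds:", len(sched) - 1, "schedule:", sched)
    print("per-target teacher evals, one-step:", teacher_evals_per_target(8192, "one_step"))
    print("per-target teacher evals, progressive:", teacher_evals_per_target(8192, "progressive"))
```

For the 8192-step CIFAR-10 teacher mentioned in the abstract, reaching 4 steps takes 11 halving rounds, each building targets from just two teacher evaluations.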

4. Methodology (Core Technology & Implementation)

The paper's methodology has two pillars: improving model parameterization for stability and the progressive distillation algorithm itself.

  • Principles & Model Parameterization: The standard approach for diffusion models is to train a network ε̂_θ(z_t) to predict the noise ε that was added to create the noisy image z_t = α_t x + σ_t ε. The clean-image prediction x̂_θ is then derived as:

      $$\hat{\mathbf{x}}_{\theta}(\mathbf{z}_t) = \frac{1}{\alpha_t}\left(\mathbf{z}_t - \sigma_t \hat{\epsilon}_{\theta}(\mathbf{z}_t)\right)$$

    The Problem: As the number of sampling steps decreases, each step spans a wider range of noise levels, and predictions made at very low signal-to-noise ratio (SNR), where α_t approaches 0, have a larger effect on the final sample. In this regime, dividing by α_t amplifies any small error in the noise prediction ε̂_θ: an error δ in ε̂_θ becomes an error of (σ_t/α_t)δ in x̂_θ. When distilling down to a single step, the input is pure noise (α_t = 0), and the formula breaks down completely.

    The Solutions (New Parameterizations): To solve this, the authors propose parameterizations that remain stable as α_t → 0.

    1. Predicting x directly: The network directly outputs the predicted clean image x̂_θ(z_t). This is the simplest and most direct solution.
    2. Predicting both x and ε: The network has two output channels, x̃_θ(z_t) and ε̃_θ(z_t), which are merged as x̂_θ(z_t) = σ_t² x̃_θ(z_t) + α_t(z_t − σ_t ε̃_θ(z_t)). This smoothly interpolates between predicting x directly (as α_t → 0) and the ε-parameterization (as σ_t → 0).
    3. Predicting v: The network predicts v = α_t ε − σ_t x, and the clean-image prediction is x̂ = α_t z_t − σ_t v̂_θ(z_t) (using α_t² + σ_t² = 1). Writing φ_t = arctan(σ_t/α_t), the DDIM update in this parameterization becomes a rotation, z_{φ_s} = cos(φ_s − φ_t) z_{φ_t} + sin(φ_s − φ_t) v̂_θ(z_{φ_t}), and the authors report that it leads to stable training and sampling.
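The parameterizations above differ in how x̂ is recovered from the network output, and that determines how output errors propagate at low SNR. Below is a minimal scalar sketch, assuming the variance-preserving cosine schedule α_t = cos(πt/2), σ_t = sin(πt/2), so that φ_t = πt/2. The merge rule for the combined parameterization is the formula given in item 2; all function names are illustrative, not from the paper's code.

```python
import math


def alpha_sigma(t: float) -> tuple[float, float]:
    """Assumed variance-preserving cosine schedule (alpha^2 + sigma^2 = 1)."""
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)


def x_from_eps(z: float, eps_hat: float, t: float) -> float:
    """epsilon-parameterization: x_hat = (z - sigma * eps_hat) / alpha (divides by alpha)."""
    a, s = alpha_sigma(t)
    return (z - s * eps_hat) / a


def x_from_x_and_eps(z: float, x_tilde: float, eps_tilde: float, t: float) -> float:
    """Combined parameterization: sigma^2 * x_tilde + alpha * (z - sigma * eps_tilde)."""
    a, s = alpha_sigma(t)
    return s**2 * x_tilde + a * (z - s * eps_tilde)


def v_from_x_eps(x: float, eps: float, t: float) -> float:
    """Velocity target: v = alpha * eps - sigma * x."""
    a, s = alpha_sigma(t)
    return a * eps - s * x


def x_from_v(z: float, v_hat: float, t: float) -> float:
    """v-parameterization: x_hat = alpha * z - sigma * v_hat (no division)."""
    a, s = alpha_sigma(t)
    return a * z - s * v_hat


def ddim_step_v(z_t: float, v_hat: float, t: float, s: float) -> float:
    """DDIM step as a rotation by phi_s - phi_t, where phi = pi*t/2 under this schedule."""
    dphi = 0.5 * math.pi * (s - t)
    return math.cos(dphi) * z_t + math.sin(dphi) * v_hat


if __name__ == "__main__":
    x, eps, delta = 0.5, -1.2, 0.01  # delta: a small error added to the network output
    t = 0.999                        # very low SNR: alpha_t is about 0.0016
    a, s = alpha_sigma(t)
    z = a * x + s * eps
    print("x error, eps-param:", abs(x_from_eps(z, eps + delta, t) - x))
    print("x error, v-param:  ", abs(x_from_v(z, v_from_x_eps(x, eps, t) + delta, t) - x))
```

At t = 0.999, an output error of 0.01 turns into an x̂ error of roughly 6.4 under ε-prediction but stays below 0.01 under v-prediction, which is the division-by-α_t instability described above.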
  • Steps & Procedures: Progressive Distillation

    The core of the paper is Algorithm 2, Progressive Distillation, an iterative process. Assume we start with a trained teacher model x̂_η whose sampler uses 2N steps. The goal is to train a student model x̂_θ that achieves similar quality in N steps.

    Figure 1: A visualization of two iterations of our proposed progressive distillation algorithm. A sampler f(z; η), mapping random noise ε to samples …

    Figure 1 from the paper illustrates this concept. A 4-step sampler is distilled into a 2-step sampler, which is then distilled into a 1-step sampler. The distillation process (yellow arrows) trains a new model to take larger, more efficient steps (green arrows).

    The procedure for one iteration of distillation is as follows:

    1. Initialization: The student model x̂_θ is initialized with the weights of the teacher model x̂_η (i.e., θ ← η).

    2. Training Loop:

      • Sample a clean image xx from the dataset.
      • Sample a discrete timestep t from {1/N, 2/N, ..., N/N}.
      • Add noise to x to get z_t = α_t x + σ_t ε, with ε ~ N(0, I).
      • Crucial Step: Calculate the Distillation Target. Instead of training the student to predict the original x, we train it to predict a target x̃ that makes a single student step equivalent to two teacher steps. This target is calculated as follows:

        a. Define two smaller timesteps for the teacher: t' = t − 0.5/N and t'' = t − 1/N.

        b. Perform the first teacher DDIM step: starting from z_t, calculate z_{t'} using the teacher prediction x̂_η(z_t):

            $$\mathbf{z}_{t'} = \alpha_{t'} \hat{\mathbf{x}}_{\eta}(\mathbf{z}_t) + \frac{\sigma_{t'}}{\sigma_t}\left(\mathbf{z}_t - \alpha_t \hat{\mathbf{x}}_{\eta}(\mathbf{z}_t)\right)$$

        c. Perform the second teacher DDIM step: starting from z_{t'}, calculate z_{t''} using the teacher prediction x̂_η(z_{t'}):

            $$\mathbf{z}_{t''} = \alpha_{t''} \hat{\mathbf{x}}_{\eta}(\mathbf{z}_{t'}) + \frac{\sigma_{t''}}{\sigma_{t'}}\left(\mathbf{z}_{t'} - \alpha_{t'} \hat{\mathbf{x}}_{\eta}(\mathbf{z}_{t'})\right)$$

        d. z_{t''} is the result of two teacher steps. We want the student to reach z_{t''} from z_t in a single DDIM step, so we find the required prediction x̃ by solving z_{t''} = α_{t''} x̃ + (σ_{t''}/σ_t)(z_t − α_t x̃) for x̃:

            $$\tilde{\mathbf{x}} = \frac{\mathbf{z}_{t''} - (\sigma_{t''}/\sigma_t)\,\mathbf{z}_t}{\alpha_{t''} - (\sigma_{t''}/\sigma_t)\,\alpha_t}$$

      • Loss Calculation: The student model x̂_θ is then trained to predict this target x̃ from the input z_t. The loss is a weighted mean squared error:

            $$L_{\theta} = w(\lambda_t)\, \lVert \tilde{\mathbf{x}} - \hat{\mathbf{x}}_{\theta}(\mathbf{z}_t) \rVert_2^2$$

        where λ_t = log(α_t²/σ_t²) is the log-SNR and w(λ_t) is a weighting function. The paper uses weightings that do not vanish at zero SNR: truncated SNR, w = max(α_t²/σ_t², 1), or SNR+1, w = α_t²/σ_t² + 1. By contrast, the weighting implied by the standard ε-prediction loss is α_t²/σ_t² itself, which is zero when α_t = 0.
    3. Progression: After the student model is trained, it becomes the new teacher (η ← θ), and the number of steps is halved (N ← N/2). The process repeats until the desired number of sampling steps (e.g., 4, 2, or 1) is reached.
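The target computation in steps a to d, together with the weighted loss, can be checked numerically. The sketch below is a scalar toy under stated assumptions: a cosine schedule α_t = cos(πt/2), σ_t = sin(πt/2), and a hypothetical `toy_teacher` that is the exact denoiser for 1-D N(0, 3²) data rather than a trained network. It verifies the defining property of x̃: one student DDIM step from z_t using x̃ lands exactly on z_{t''}. The gradient updates on θ are not shown.

```python
import math
from typing import Callable


def alpha_sigma(t: float) -> tuple[float, float]:
    """Assumed cosine schedule (alpha^2 + sigma^2 = 1)."""
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)


def ddim_step(z_t: float, x_hat: float, t: float, s: float) -> float:
    """Deterministic DDIM update from t to s < t."""
    a_t, sig_t = alpha_sigma(t)
    a_s, sig_s = alpha_sigma(s)
    return a_s * x_hat + (sig_s / sig_t) * (z_t - a_t * x_hat)


def toy_teacher(z: float, t: float) -> float:
    """Hypothetical teacher x_hat_eta: exact E[x | z_t] for 1-D data x ~ N(0, 3^2)."""
    a, s = alpha_sigma(t)
    return 9.0 * a * z / (9.0 * a * a + s * s)


def distillation_target(z_t: float, t: float, n_student: int,
                        teacher: Callable[[float, float], float]) -> tuple[float, float]:
    """Two teacher DDIM steps t -> t' -> t'', then invert a single DDIM step t -> t'' to get x_tilde."""
    t1, t2 = t - 0.5 / n_student, t - 1.0 / n_student
    z_t1 = ddim_step(z_t, teacher(z_t, t), t, t1)
    z_t2 = ddim_step(z_t1, teacher(z_t1, t1), t1, t2)
    a_t, sig_t = alpha_sigma(t)
    a_t2, sig_t2 = alpha_sigma(t2)
    ratio = sig_t2 / sig_t
    x_tilde = (z_t2 - ratio * z_t) / (a_t2 - ratio * a_t)
    return x_tilde, z_t2


def loss_weight(t: float, kind: str) -> float:
    """Weightings that stay positive at zero SNR: truncated SNR and SNR + 1."""
    a, s = alpha_sigma(t)
    snr = (a * a) / (s * s)
    if kind == "truncated_snr":
        return max(snr, 1.0)
    if kind == "snr_plus_one":
        return snr + 1.0
    raise ValueError(kind)


def distillation_loss(x_tilde: float, x_student: float, t: float, kind: str = "truncated_snr") -> float:
    """Weighted squared error between the distillation target and the student's x prediction."""
    return loss_weight(t, kind) * (x_tilde - x_student) ** 2


if __name__ == "__main__":
    n_student, x, eps = 4, 1.5, -0.3
    for i in range(1, n_student + 1):
        t = i / n_student
        a, s = alpha_sigma(t)
        z_t = a * x + s * eps
        x_tilde, z_t2 = distillation_target(z_t, t, n_student, toy_teacher)
        # A student that predicts x_tilde lands where two teacher steps land.
        print(t, x_tilde, ddim_step(z_t, x_tilde, t, t - 1.0 / n_student), z_t2)
```

At t = 1 (α_t = 0) both weightings evaluate to 1, whereas the SNR weighting implied by ε-prediction would be 0 there, so the student would receive no training signal at the very step a few-step sampler depends on most.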

5. Experimental Setup

  • Datasets:

    • CIFAR-10: 32x32 images, 10 classes. Used for unconditional generation and the main ablation study.
    • ImageNet: Downsampled to 64x64, 1000 classes. Used for class-conditional generation.
    • LSUN Bedrooms: 128x128 images. Used for unconditional generation.
    • LSUN Church-Outdoor: 128x128 images. Used for unconditional generation. These are standard and challenging benchmarks for generative models, allowing for comparison with prior work.
  • Evaluation Metrics:

    1. Fréchet Inception Distance (FID):

      • Conceptual Definition: FID measures the similarity between two sets of images (e.g., generated vs. real). It computes the distance between the feature distributions of the two sets, where features are extracted from a pre-trained InceptionV3 network. A lower FID score indicates that the distribution of generated images is closer to the distribution of real images, implying better quality and diversity. It is the most common metric for evaluating modern generative models.
      • Mathematical Formula:

            $$\mathrm{FID}(x, g) = \lVert \mu_x - \mu_g \rVert_2^2 + \mathrm{Tr}\left(\Sigma_x + \Sigma_g - 2(\Sigma_x \Sigma_g)^{1/2}\right)$$

      • Symbol Explanation:
        • μ_x and μ_g: The mean vectors of the Inception features for real (x) and generated (g) images, respectively.
        • Σ_x and Σ_g: The covariance matrices of the Inception features for real and generated images.
        • Tr(·): The trace of a matrix (sum of diagonal elements).
    2. Inception Score (IS):

      • Conceptual Definition: IS aims to measure both the "quality" (clarity and recognizability) and "diversity" of generated images. It uses a pre-trained Inception network to classify generated images. A high score is achieved if (1) the classifier is very confident about the class of each individual image (low entropy of p(y|x)), and (2) the distribution of classes across all generated images is uniform (high entropy of p(y)). Higher is better.
      • Mathematical Formula:

            $$\mathrm{IS}(G) = \exp\left(\mathbb{E}_{x \sim G}\left[D_{KL}\left(p(y|x) \,\|\, p(y)\right)\right]\right)$$

      • Symbol Explanation:
        • G: The set of generated images.
        • x ~ G: An image sampled from the generator.
        • p(y|x): The conditional class distribution (classifier's output probabilities) for image x.
        • p(y): The marginal class distribution, averaged over all generated images.
        • D_KL(· || ·): The Kullback-Leibler (KL) divergence between two distributions.
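The two metric formulas can be illustrated with a small sketch. To stay dependency-free, the hypothetical `fid_diagonal` assumes diagonal covariance matrices, for which the matrix square root reduces to element-wise square roots; real FID uses full covariances of 2048-dimensional InceptionV3 features and a proper matrix square root (for example SciPy's `scipy.linalg.sqrtm`). The hypothetical `inception_score` takes already-computed class probabilities p(y|x) rather than running an Inception network.

```python
import math


def fid_diagonal(mu_x: list[float], var_x: list[float], mu_g: list[float], var_g: list[float]) -> float:
    """FID between two Gaussians, ASSUMING diagonal covariances.

    For diagonal covariances, (Sigma_x Sigma_g)^(1/2) is diagonal with entries sqrt(var_x * var_g),
    so the trace term reduces to a per-dimension sum.
    """
    mean_term = sum((a - b) ** 2 for a, b in zip(mu_x, mu_g))
    trace_term = sum(vx + vg - 2.0 * math.sqrt(vx * vg) for vx, vg in zip(var_x, var_g))
    return mean_term + trace_term


def inception_score(probs: list[list[float]]) -> float:
    """IS = exp(E_x[KL(p(y|x) || p(y))]) from a list of per-image class distributions p(y|x)."""
    n, k = len(probs), len(probs[0])
    p_y = [sum(p[j] for p in probs) / n for j in range(k)]  # marginal class distribution p(y)
    mean_kl = sum(
        sum(pj * math.log(pj / p_y[j]) for j, pj in enumerate(p) if pj > 0.0)  # treat 0 * log 0 as 0
        for p in probs
    ) / n
    return math.exp(mean_kl)


if __name__ == "__main__":
    print(fid_diagonal([0.0, 0.0], [1.0, 1.0], [1.0, 0.0], [1.0, 4.0]))
    print(inception_score([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))
```

For three images, each classified with certainty into a different one of three classes, the score is 3 (the maximum possible with three classes); if every image receives the same class distribution, the score is 1.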
  • Baselines:

    • Undistilled DDIM sampler: The standard DDIM sampler applied to the original, undistilled teacher model, evaluated with varying numbers of steps.
    • Optimized stochastic sampler: An ancestral sampler (which adds noise at each step) applied to the original model. The amount of noise is tuned for optimal performance at each step count. This serves as a strong, non-deterministic baseline.
    • Other Literature Methods: The paper compares its final CIFAR-10 results against other fast sampling techniques like FastDPM, LSGM, and the original DDIM few-step results from their respective papers.

6. Results & Analysis

  • Core Results (Progressive Distillation):

    Figure 8: Random samples from our distilled 64 × 64 ImageNet models, conditioned on the 'coral reef' class, for fixed random seed and for varying number of sampling steps.

    Figure 4 from the paper shows the FID score versus the number of sampling steps across four datasets. The blue line (Distilled) represents the proposed method. The orange (DDIM) and green (Stochastic) lines are baselines using the original model.

    The key takeaway from Figure 4 is the dramatic performance difference.

    • The Distilled model (blue line) maintains a very low (good) FID score even as the number of steps is reduced to 8 or 4. The quality degradation is graceful.

    • In contrast, the baseline DDIM and Stochastic samplers see their FID scores explode (worsen dramatically) when the number of steps falls below ~128.

    • This demonstrates that progressive distillation successfully creates a model that is fundamentally better at taking large sampling steps, rather than just being a slightly better sampler for a fixed model.

      Figure 7: Random samples from our distilled CIFAR-10 models, for fixed random seed and for varying number of sampling steps.

      Figure 3 from the paper shows samples of a 'malamute' from the ImageNet model. The quality remains remarkably consistent from 256 steps down to 4 and even 1 step, showing that the method preserves the mapping from noise to image.

    The following is a manual transcription of Table 2, which provides a direct comparison on CIFAR-10.

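    Method                                              Model evaluations   FID
    Progressive Distillation (ours)                     1                   9.12
    Progressive Distillation (ours)                     2                   4.51
    Progressive Distillation (ours)                     4                   3.00
    Knowledge distillation (Luhman & Luhman, 2021)      1                   9.36
    DDIM (Song et al., 2021a)                           10                  13.36
    DDIM (Song et al., 2021a)                           20                  6.84
    DDIM (Song et al., 2021a)                           50                  4.67
    DDIM (Song et al., 2021a)                           100                 4.16
    FastDPM (Kong & Ping, 2021)                         10                  9.90
    FastDPM (Kong & Ping, 2021)                         20                  5.05
    FastDPM (Kong & Ping, 2021)                         50                  3.20
    FastDPM (Kong & Ping, 2021)                         100                 2.86
    Improved DDPM respacing (Nichol & Dhariwal, 2021)   25                  7.53
    Improved DDPM respacing (Nichol & Dhariwal, 2021)   50                  4.99
    LSGM (Vahdat et al., 2021)                          138                 2.10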

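    This table shows that progressive distillation is highly competitive. It reaches an FID of 3.00 with just 4 model evaluations. In the table, FastDPM needs between 50 and 100 evaluations to reach a comparable FID (3.20 at 50, 2.86 at 100), and DDIM does not reach it even at 100 evaluations (4.16), so the speedup at similar quality is more than 10x (4 vs. 50 or more evaluations). LSGM attains a lower FID (2.10) but needs 138 evaluations.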

  • Ablations / Parameter Sensitivity (Model Parameterization):

    The following is a manual transcription of Table 1, which ablates the model parameterization and loss weighting on unconditional CIFAR-10 (using the original, undistilled model).

    Network output    Loss weighting    Stochastic sampler (FID/IS)    DDIM sampler (FID/IS)
    (x, ε) combined   SNR               2.54/9.88                      2.78/9.56
    (x, ε) combined   Truncated SNR     2.47/9.85                      2.76/9.49
    (x, ε) combined   SNR+1             2.52/9.79                      2.87/9.45
    x                 SNR               2.65/9.80                      2.75/9.56
    x                 Truncated SNR     2.53/9.92                      2.51/9.58
    x                 SNR+1             2.56/9.84                      2.65/9.52
    ε                 SNR               2.59/9.84                      2.91/9.52
    ε                 Truncated SNR     N/A                            N/A
    ε                 SNR+1             2.56/9.77                      3.27/9.41
    v                 SNR               2.65/9.86                      3.05/9.56
    v                 Truncated SNR     2.45/9.80                      2.75/9.52
    v                 SNR+1             2.49/9.77                      2.87/9.43

    (Metrics are reported as FID/IS. Lower FID is better, higher IS is better.)

    This ablation study confirms the paper's claims:

    • The proposed stable parameterizations (predicting x, v, or a combination of x and ε) all achieve excellent performance, comparable to or better than the standard ε-prediction.
    • No single choice wins everywhere: predicting x directly with Truncated SNR weighting gives the best DDIM-sampler FID (2.51), while v-prediction with Truncated SNR weighting gives the best stochastic-sampler FID (2.45).
    • The standard ε-prediction combined with Truncated SNR weighting is reported as N/A, consistent with the instability of ε-prediction at low SNR discussed in the paper.
    • The Truncated SNR and SNR+1 weightings, which do not ignore the low-SNR regime, perform well here; the paper relies on them for distillation, where the student must make predictions down to zero SNR.

7. Conclusion & Reflections

  • Conclusion Summary: The paper successfully introduces progressive distillation, a novel and highly effective method for dramatically accelerating the sampling process of diffusion models. By combining this algorithm with more stable model parameterizations, the authors demonstrate the ability to reduce sampling steps by orders of magnitude (from thousands to as few as four) with minimal loss in image quality. This work effectively bridges a major gap between the high performance of diffusion models and their practical usability, making them a much more viable solution for real-world generative tasks.

  • Limitations & Future Work:

    • The authors note that the student model in their experiments always had the same architecture and size as the teacher. A potential future direction is to explore distilling into a smaller student model, which would provide further gains in computational efficiency at test time.
    • The work is focused on image generation. The authors suggest that progressive distillation could be applied to diffusion models for other data modalities, such as audio or text.
    • While quality is high at 4 steps, there is a noticeable drop-off at 1-2 steps (on CIFAR-10, FID 9.12 at 1 step and 4.51 at 2 steps, versus 3.00 at 4). Further research could investigate ways to improve performance in this extreme few-step sampling regime.
  • Personal Insights & Critique:

    • The core idea of "amortizing" the multi-step integration of the ODE into the network's parameters via distillation is elegant and powerful. It reframes the problem from finding a better solver to creating a better (simpler) problem to solve.
    • The "progressive" aspect of the algorithm is a key innovation. It allows the method to scale to very slow teacher models with thousands of steps without incurring a prohibitive cost, a significant advantage over prior distillation work.
    • The paper is a strong example of rigorous academic research: it clearly identifies a critical problem, proposes a well-motivated solution, validates it with thorough experiments and ablations, and provides clear, reproducible details.
    • This work had a major impact on the field, making diffusion models practical for many more applications. The ability to generate high-fidelity images in a handful of steps was a game-changer and spurred further research into few-step and even single-step generation, which is now a very active area. The techniques presented here have become a standard part of the toolbox for anyone working on efficient diffusion models.
