Progressive Distillation for Fast Sampling of Diffusion Models
TL;DR Summary
This work introduces progressive distillation, halving diffusion model sampling steps successively from 8192 to 4, enabling fast sampling with minimal quality loss, exemplified by a 4-step FID of 3.0 on CIFAR-10, significantly accelerating diffusion model inference.
Abstract
Diffusion models have recently shown great promise for generative modeling, outperforming GANs on perceptual quality and autoregressive models at density estimation. A remaining downside is their slow sampling time: generating high quality samples takes many hundreds or thousands of model evaluations. Here we make two contributions to help eliminate this downside: First, we present new parameterizations of diffusion models that provide increased stability when using few sampling steps. Second, we present a method to distill a trained deterministic diffusion sampler, using many steps, into a new diffusion model that takes half as many sampling steps. We then keep progressively applying this distillation procedure to our model, halving the number of required sampling steps each time. On standard image generation benchmarks like CIFAR-10, ImageNet, and LSUN, we start out with state-of-the-art samplers taking as many as 8192 steps, and are able to distill down to models taking as few as 4 steps without losing much perceptual quality; achieving, for example, a FID of 3.0 on CIFAR-10 in 4 steps. Finally, we show that the full progressive distillation procedure does not take more time than it takes to train the original model, thus representing an efficient solution for generative modeling using diffusion at both train and test time.
In-depth Reading
English Analysis
1. Bibliographic Information
- Title: Progressive Distillation for Fast Sampling of Diffusion Models
- Authors: Tim Salimans & Jonathan Ho (Google Research, Brain team)
- Journal/Conference: The paper was submitted to the International Conference on Learning Representations (ICLR) and published in 2022. ICLR is a top-tier, highly competitive conference in the field of deep learning, known for publishing foundational and impactful research.
- Publication Year: 2022
- Abstract: The paper addresses the significant drawback of diffusion models: their slow sampling speed, which requires hundreds or thousands of function evaluations. The authors introduce two key contributions. First, they propose new model parameterizations that improve stability when using a small number of sampling steps. Second, they present progressive distillation, an iterative method where a trained, multi-step deterministic sampler (the "teacher") is distilled into a new diffusion model (the "student") that requires half the number of steps. By repeatedly applying this process, they successfully reduce the sampling steps from as many as 8192 down to as few as 4, while maintaining high perceptual quality on benchmarks like CIFAR-10, ImageNet, and LSUN. For example, they achieve a Fréchet Inception Distance (FID) of 3.0 on CIFAR-10 in just 4 steps. The authors also show that the entire distillation process is computationally efficient, taking no more time than the initial training of the original model.
- Original Source Link: https://arxiv.org/abs/2202.00512
- PDF Link: http://arxiv.org/pdf/2202.00512v2
2. Executive Summary
- Background & Motivation (Why):
- Core Problem: Diffusion models, despite achieving state-of-the-art results in generative modeling (outperforming GANs on image quality and autoregressive models on density estimation), are notoriously slow at generating samples. This is because they traditionally require an iterative denoising process involving hundreds or thousands of steps, where a neural network is evaluated at each step. This computational cost is a major barrier to their practical application.
- Importance: Fast sample generation is crucial for real-world use cases of generative models, such as interactive content creation, real-time synthesis, and large-scale data augmentation. The slowness of diffusion models limits their deployment in resource-constrained environments or time-sensitive tasks.
- Fresh Angle: Instead of merely optimizing the sampling process of a fixed model, this paper proposes to create entirely new, faster models through distillation. The core idea is to "bake" the complex, multi-step sampling logic of a slow, high-quality "teacher" model into a "student" model that learns to take larger, more effective steps. The "progressive" nature of this distillation allows for a dramatic and controlled reduction in sampling time.
- Main Contributions / Findings (What):
- New Stable Parameterizations: The authors identify that the standard parameterization of diffusion models (predicting the noise $\epsilon$) becomes unstable with few sampling steps, especially at low signal-to-noise ratios. They propose and validate alternative parameterizations (e.g., predicting the clean image $x$ directly, or a velocity term $v = \alpha_t \epsilon - \sigma_t x$) that are more robust and suitable for few-step sampling.
- Progressive Distillation Algorithm: This is the central contribution. It is an iterative algorithm that distills a diffusion model sampler taking $N$ steps into a new model that achieves similar quality in $N/2$ steps. The process works by training a "student" model to replicate the output of two "teacher" steps in a single, larger step. This procedure can be repeated, halving the number of steps at each stage (e.g., 1024 → 512 → 256 → ... → 4).
- State-of-the-Art Fast Sampling: The method achieves remarkable results. On CIFAR-10, it generates high-quality images with an FID of 3.0 in just 4 steps, and on class-conditional ImageNet 64x64, it produces excellent samples in as few as 4-8 steps. This represents an orders-of-magnitude speedup over previous diffusion model samplers.
- Efficient Training: The entire progressive distillation pipeline is shown to be computationally efficient, with a total training time comparable to or less than that of the original, slow model.
3. Prerequisite Knowledge & Related Work
- Foundational Concepts:
- Diffusion Models: These are generative models that learn to create data by reversing a gradual noising process.
- Forward Process: A clean data sample (e.g., an image $x$) is progressively corrupted by adding Gaussian noise over a series of timesteps $t$. At the final timestep, the data becomes essentially pure Gaussian noise. This process is fixed and does not involve learning.
- Reverse Process: A neural network is trained to reverse this process. Given a noisy image $z_t$ at timestep $t$, the model learns to predict a slightly less noisy version $z_s$ (where $s < t$) or, more commonly, to predict the original clean image $x$ or the noise $\epsilon$ that was added.
- Sampling: To generate a new image, one starts with pure noise and iteratively applies the trained network to denoise it step-by-step until a clean image is obtained.
- Denoising Diffusion Implicit Models (DDIM): A type of sampler for diffusion models introduced by Song et al. (2021a). Unlike traditional ancestral samplers, which are stochastic (they add fresh noise at each step), DDIM follows a deterministic sampling path: for a given starting noise it always produces the same final image. This deterministic nature is crucial for the distillation method proposed in this paper. The paper shows that the DDIM update rule can be interpreted as a numerical integrator for an ordinary differential equation (ODE) known as the probability flow ODE (a minimal sketch of such a sampling loop appears after this list).
- Knowledge Distillation: A machine learning technique where a large, complex model (the "teacher") is used to train a smaller, faster model (the "student"). The student is trained not just on the ground-truth labels, but on the "soft" outputs or internal representations of the teacher. In this paper's context, the teacher is a slow, multi-step DDIM sampler, and the student is a new diffusion model that learns to emulate the teacher's behavior in fewer steps.
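To make the deterministic nature of DDIM concrete, here is a minimal PyTorch sketch of a DDIM sampling loop. It is an illustration under assumptions, not the authors' code: `model(z, t)` is a hypothetical network returning a clean-image prediction, and `alpha(t)`, `sigma(t)` are hypothetical schedule callables with $\alpha_t^2 + \sigma_t^2 = 1$.

```python
import torch

@torch.no_grad()
def ddim_sample(model, z, num_steps, alpha, sigma):
    # Deterministic DDIM sampling: no noise is injected after the start,
    # so the same initial z always maps to the same final image.
    ts = torch.linspace(1.0, 0.0, num_steps + 1)     # t = 1 (pure noise) ... t = 0 (clean)
    for t, s in zip(ts[:-1], ts[1:]):
        x_hat = model(z, t)                          # predicted clean image at time t
        eps_hat = (z - alpha(t) * x_hat) / sigma(t)  # implied noise prediction
        z = alpha(s) * x_hat + sigma(s) * eps_hat    # DDIM update z_t -> z_s
    return z
```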
- Previous Works & Differentiation:
- Slow Samplers (DDPM, SDE-based): Early diffusion models like DDPM (Ho et al., 2020) and score-based models (Song et al., 2021c) required thousands of steps for high-quality generation.
- Faster Samplers (DDIM): DDIM (Song et al., 2021a) was a major step forward, enabling good quality with fewer steps (e.g., 50-100). However, quality still degrades significantly with very few steps (< 20). The proposed method builds on DDIM but fundamentally changes the model itself to accommodate even faster sampling.
- Sampler Optimization: Other works focused on optimizing the sampling schedule (e.g., choosing better timesteps) or using more advanced ODE solvers (Jolicoeur-Martineau et al., 2021) for a fixed model. This paper instead retrains the model to be inherently faster.
- One-Step Distillation (Luhman & Luhman, 2021): A concurrent work also used distillation. However, their method distilled a many-step teacher directly into a one-step student, which requires pre-generating a large dataset of (noise, image) pairs from the teacher. The distillation cost therefore scales linearly with the number of teacher steps, which is prohibitive for very slow teachers (e.g., >1000 steps). In contrast, the cost of each progressive distillation phase is independent of the number of teacher steps, so the total cost scales only logarithmically: distilling from 1024 steps down to 1 step, for example, takes log2(1024) = 10 halving phases.
4. Methodology (Core Technology & Implementation)
The paper's methodology has two pillars: improving model parameterization for stability and the progressive distillation algorithm itself.
- Principles & Model Parameterization: The standard approach for diffusion models is to train a network $\hat{\epsilon}_\theta$ to predict the noise $\epsilon$ that was added to create the noisy image $z_t = \alpha_t x + \sigma_t \epsilon$. The clean image prediction $\hat{x}$ is then derived as:
  $$\hat{x} = \frac{z_t - \sigma_t \hat{\epsilon}_\theta(z_t)}{\alpha_t}$$
  The Problem: As the number of sampling steps decreases, the model must operate at very low signal-to-noise ratios (SNR), where $\alpha_t$ approaches 0. In this regime, dividing by $\alpha_t$ amplifies any small error in the noise prediction $\hat{\epsilon}_\theta(z_t)$, leading to instability. When distilling down to a single step, the input is pure noise ($\alpha_t = 0$, zero SNR), and the formula breaks down completely.
  The Solutions (New Parameterizations): To solve this, the authors propose parameterizations that remain stable as the SNR goes to zero (see the sketch after this list).
  - Predicting $x$ directly: The network directly outputs the predicted clean image, $\hat{x} = \hat{x}_\theta(z_t)$. This is the simplest and most direct solution.
  - Predicting both $x$ and $\epsilon$: The network has two output heads, one for $\tilde{x}_\theta(z_t)$ and one for $\tilde{\epsilon}_\theta(z_t)$. The final prediction, $\hat{x} = \sigma_t^2 \tilde{x}_\theta(z_t) + \alpha_t (z_t - \sigma_t \tilde{\epsilon}_\theta(z_t))$, is a weighted combination that smoothly interpolates between the two approaches based on the noise level.
  - Predicting $v$: The network predicts the velocity $v \equiv \alpha_t \epsilon - \sigma_t x$. This $v$-prediction formulation has the nice property of making the DDIM step size independent of the SNR, leading to more stable training and sampling. The clean image prediction is then derived as $\hat{x} = \alpha_t z_t - \sigma_t \hat{v}_\theta(z_t)$.
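Below is a minimal sketch of how the clean-image estimate is recovered under each parameterization. It is an illustration under assumptions (hypothetical `model`, `alpha`, `sigma` callables with $\alpha_t^2 + \sigma_t^2 = 1$), not the paper's code.

```python
import torch

def predict_clean_image(model, z_t, t, alpha, sigma, mode):
    # Recover x_hat from the network output under different parameterizations.
    a, s = alpha(t), sigma(t)
    out = model(z_t, t)
    if mode == "eps":  # standard: divide by alpha_t, unstable as alpha_t -> 0
        return (z_t - s * out) / a
    if mode == "x":    # direct clean-image prediction, stable at zero SNR
        return out
    if mode == "v":    # v = alpha_t*eps - sigma_t*x  =>  x_hat = alpha_t*z_t - sigma_t*v_hat
        return a * z_t - s * out
    raise ValueError(f"unknown mode: {mode}")
```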
- Steps & Procedures: Progressive Distillation
The core of the paper is Algorithm 2, Progressive Distillation, an iterative process. Assume we start with a trained teacher model that uses $2N$ sampling steps; the goal is to train a student model that achieves similar quality in $N$ steps.
Figure 1 from the paper illustrates this concept. A 4-step sampler is distilled into a 2-step sampler, which is then distilled into a 1-step sampler. The distillation process (yellow arrows) trains a new model to take larger, more efficient steps (green arrows). The procedure for one iteration of distillation is as follows:
- Initialization: The student model is initialized with the weights of the teacher model (i.e., the student parameters $\eta$ start as a copy of the teacher parameters $\theta$).
- Training Loop:
- Sample a clean image $x$ from the dataset.
- Sample a discrete timestep $t = i/N$, with $i$ drawn uniformly from $\{1, \dots, N\}$.
- Add noise to $x$ to get $z_t = \alpha_t x + \sigma_t \epsilon$, with $\epsilon \sim \mathcal{N}(0, I)$.
- Crucial Step: Calculate the Distillation Target. Instead of training the student to predict the original $x$, we train it to predict a target $\tilde{x}$ that makes a single student step equivalent to two teacher steps (see the code sketch after this list). This target is calculated as follows:
  a. Define two intermediate timesteps for the teacher: $t' = t - 0.5/N$ and $t'' = t - 1/N$.
  b. Perform the first teacher DDIM step: starting from $z_t$, calculate $z_{t'}$ using the teacher model $\hat{x}_\theta$.
  c. Perform the second teacher DDIM step: starting from $z_{t'}$, calculate $z_{t''}$ using the teacher model.
  d. Now $z_{t''}$ is the result of applying the teacher model twice. We want the student model to reach this point from $z_t$ in a single step, so we find the required $\tilde{x}$ prediction by inverting the single-step DDIM update rule. The target is:
  $$\tilde{x} = \frac{z_{t''} - (\sigma_{t''}/\sigma_t)\, z_t}{\alpha_{t''} - (\sigma_{t''}/\sigma_t)\, \alpha_t}$$
- Loss Calculation: The student model $\hat{x}_\eta$ is then trained to predict this target $\tilde{x}$ from the input $z_t$. The loss is a weighted mean squared error,
  $$L_\eta = w(\lambda_t)\, \lVert \tilde{x} - \hat{x}_\eta(z_t) \rVert_2^2,$$
  where $\lambda_t = \log(\alpha_t^2/\sigma_t^2)$ is the log-SNR and $w(\lambda_t)$ is a weighting function. The paper finds that a weighting of SNR+1, $w = 1 + \alpha_t^2/\sigma_t^2$, or Truncated SNR, $w = \max(\alpha_t^2/\sigma_t^2, 1)$, works well, as neither vanishes at zero SNR.
- Progression: After the student model is trained, it becomes the new teacher ($\theta \leftarrow \eta$), and the number of sampling steps is halved ($N \leftarrow N/2$). The process repeats until the desired number of sampling steps (e.g., 4, 2, or 1) is reached.
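The following is a compact PyTorch sketch of one progressive-distillation training step, under stated assumptions: a cosine schedule, teacher and student networks with the hypothetical signature `net(z_t, t)` returning clean-image ($x$) predictions, and the truncated-SNR weighting. It follows the target and loss described above but is not the authors' implementation.

```python
import math
import torch

def cosine_schedule(t):
    # alpha_t, sigma_t for continuous time t in [0, 1]; reshaped to broadcast
    # over image tensors of shape (B, C, H, W).
    a = torch.cos(0.5 * math.pi * t)
    s = torch.sin(0.5 * math.pi * t)
    return a.view(-1, 1, 1, 1), s.view(-1, 1, 1, 1)

def ddim_step(x_pred, z_t, a_t, s_t, a_s, s_s):
    # Deterministic DDIM update z_t -> z_s, given the model's x-prediction.
    eps_pred = (z_t - a_t * x_pred) / s_t
    return a_s * x_pred + s_s * eps_pred

def progressive_distillation_loss(student, teacher, x, N):
    """One training step: the student (N steps) mimics two teacher DDIM steps."""
    i = torch.randint(1, N + 1, (x.shape[0],), device=x.device)
    t = i.float() / N
    a_t, s_t = cosine_schedule(t)
    z_t = a_t * x + s_t * torch.randn_like(x)

    with torch.no_grad():                         # the teacher is frozen
        t1, t2 = t - 0.5 / N, t - 1.0 / N
        a_1, s_1 = cosine_schedule(t1)
        a_2, s_2 = cosine_schedule(t2)
        z_1 = ddim_step(teacher(z_t, t), z_t, a_t, s_t, a_1, s_1)
        z_2 = ddim_step(teacher(z_1, t1), z_1, a_1, s_1, a_2, s_2)
        # Invert the one-step DDIM update from z_t to z_2 to get the x-target.
        x_target = (z_2 - (s_2 / s_t) * z_t) / (a_2 - (s_2 / s_t) * a_t)

    snr = (a_t / s_t) ** 2
    weight = torch.clamp(snr, min=1.0)            # "truncated SNR" weighting
    return (weight * (x_target - student(z_t, t)) ** 2).mean()
```

In the outer loop, once the student has converged it becomes the next teacher and `N` is halved, repeating until the desired step count is reached.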
5. Experimental Setup
- Datasets:
- CIFAR-10: 32x32 images, 10 classes. Used for unconditional generation and the main ablation study.
- ImageNet: Downsampled to 64x64, 1000 classes. Used for class-conditional generation.
- LSUN Bedrooms: 128x128 images. Used for unconditional generation.
- LSUN Church-Outdoor: 128x128 images. Used for unconditional generation. These are standard and challenging benchmarks for generative models, allowing for comparison with prior work.
- Evaluation Metrics:
- Fréchet Inception Distance (FID):
- Conceptual Definition: FID measures the similarity between two sets of images (e.g., generated vs. real). It computes the distance between the feature distributions of the two sets, where features are extracted from a pre-trained InceptionV3 network. A lower FID score indicates that the distribution of generated images is closer to the distribution of real images, implying better quality and diversity. It is the most common metric for evaluating modern generative models.
- Mathematical Formula:
  $$\text{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \mathrm{Tr}\!\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right)$$
- Symbol Explanation:
- $\mu_r$ and $\mu_g$: The mean vectors of the Inception features for real ($r$) and generated ($g$) images, respectively.
- $\Sigma_r$ and $\Sigma_g$: The covariance matrices of the Inception features for real and generated images.
- $\mathrm{Tr}$: The trace of a matrix (the sum of its diagonal elements).
- Inception Score (IS):
- Conceptual Definition: IS aims to measure both the "quality" (clarity and recognizability) and "diversity" of generated images. It uses a pre-trained Inception network to classify generated images. A high score is achieved if (1) the classifier is very confident about the class of each individual image (low entropy of $p(y|x)$), and (2) the distribution of classes across all generated images is uniform (high entropy of $p(y)$). Higher is better.
- Mathematical Formula:
  $$\text{IS} = \exp\!\left(\mathbb{E}_{x \sim p_g}\!\left[ D_{\mathrm{KL}}\big(p(y|x)\,\Vert\,p(y)\big) \right]\right)$$
- Symbol Explanation:
- $p_g$: The distribution of generated images (in practice, the set of generated samples).
- $x$: An image sampled from the generator.
- $p(y|x)$: The conditional class distribution (the classifier's output probabilities) for image $x$.
- $p(y)$: The marginal class distribution, averaged over all generated images.
- $D_{\mathrm{KL}}$: The Kullback-Leibler (KL) divergence between two distributions. (A short code sketch of both metrics follows below.)
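As an illustration of the two metrics defined above (my own sketch, not a reference implementation), the snippet below computes FID from pre-extracted Inception feature matrices and IS from pre-computed classifier probabilities; extracting the features and probabilities with an InceptionV3 network is assumed to happen elsewhere.

```python
import numpy as np
from scipy import linalg

def fid(feats_real, feats_gen):
    # feats_*: (N, D) arrays of Inception features for real / generated images.
    mu_r, mu_g = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    cov_r = np.cov(feats_real, rowvar=False)
    cov_g = np.cov(feats_gen, rowvar=False)
    covmean = linalg.sqrtm(cov_r @ cov_g).real   # matrix square root (drop numerical imaginary parts)
    return float(np.sum((mu_r - mu_g) ** 2) + np.trace(cov_r + cov_g - 2.0 * covmean))

def inception_score(probs, eps=1e-12):
    # probs: (N, K) array of classifier probabilities p(y|x) for generated images.
    p_y = probs.mean(axis=0, keepdims=True)      # marginal class distribution p(y)
    kl = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))              # IS = exp(E_x[ KL(p(y|x) || p(y)) ])
```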
- Baselines:
- Undistilled DDIM sampler: The standard DDIM sampler applied to the original, undistilled teacher model, evaluated with varying numbers of steps.
- Optimized stochastic sampler: An ancestral sampler (which adds noise at each step) applied to the original model. The amount of noise is tuned for optimal performance at each step count. This serves as a strong, non-deterministic baseline.
- Other Literature Methods: The paper compares its final CIFAR-10 results against other fast sampling techniques such as FastDPM, LSGM, and the original few-step DDIM results from their respective papers.
6. Results & Analysis
- Core Results (Progressive Distillation):
Figure 4 from the paper shows the FID score versus the number of sampling steps across four datasets. The blue line (Distilled) represents the proposed method. The orange (DDIM) and green (Stochastic) lines are baselines using the original model. The key takeaway from Figure 4 is the dramatic performance difference.
- The Distilled model (blue line) maintains a very low (good) FID score even as the number of steps is reduced to 8 or 4; the quality degradation is graceful.
- In contrast, the baseline DDIM and Stochastic samplers see their FID scores explode (worsen dramatically) when the number of steps falls below ~128.
- This demonstrates that progressive distillation successfully creates a model that is fundamentally better at taking large sampling steps, rather than just being a slightly better sampler for a fixed model.
Figure 3 from the paper shows samples of a 'malamute' from the ImageNet model. The quality remains remarkably consistent from 256 steps down to 4 and even 1 step, showing the method preserves the mapping from noise to image.
The following is a manual transcription of Table 2, which provides a direct comparison on CIFAR-10.
| Method | Model evaluations | FID |
| --- | --- | --- |
| Progressive Distillation (ours) | 1 | 9.12 |
| Progressive Distillation (ours) | 2 | 4.51 |
| Progressive Distillation (ours) | 4 | 3.00 |
| Knowledge distillation (Luhman & Luhman, 2021) | 1 | 9.36 |
| DDIM (Song et al., 2021a) | 10 | 13.36 |
| DDIM (Song et al., 2021a) | 20 | 6.84 |
| DDIM (Song et al., 2021a) | 50 | 4.67 |
| DDIM (Song et al., 2021a) | 100 | 4.16 |
| FastDPM (Kong & Ping, 2021) | 10 | 9.90 |
| FastDPM (Kong & Ping, 2021) | 20 | 5.05 |
| FastDPM (Kong & Ping, 2021) | 50 | 3.20 |
| FastDPM (Kong & Ping, 2021) | 100 | 2.86 |
| Improved DDPM respacing (Nichol & Dhariwal, 2021) | 25 | 7.53 |
| Improved DDPM respacing (Nichol & Dhariwal, 2021) | 50 | 4.99 |
| LSGM (Vahdat et al., 2021) | 138 | 2.10 |

This table shows that progressive distillation is highly competitive. It achieves an FID of 3.00 in just 4 steps, a result that other methods like DDIM or FastDPM need 50+ steps to approach. This is a >10x speedup for similar quality.
- Ablations / Parameter Sensitivity (Model Parameterization):
The following is a manual transcription of Table 1, which ablates the model parameterization and loss weighting on unconditional CIFAR-10 (using the original, undistilled model).
| Network Output | Loss Weighting | Stochastic sampler | DDIM sampler |
| --- | --- | --- | --- |
| (x, ε) combined | SNR | 2.54/9.88 | 2.78/9.56 |
| (x, ε) combined | Truncated SNR | 2.47/9.85 | 2.76/9.49 |
| (x, ε) combined | SNR+1 | 2.52/9.79 | 2.87/9.45 |
| x | SNR | 2.65/9.80 | 2.75/9.56 |
| x | Truncated SNR | 2.53/9.92 | 2.51/9.58 |
| x | SNR+1 | 2.56/9.84 | 2.65/9.52 |
| ε | SNR | 2.59/9.84 | 2.91/9.52 |
| ε | Truncated SNR | N/A | N/A |
| ε | SNR+1 | 2.56/9.77 | 3.27/9.41 |
| v | SNR | 2.65/9.86 | 3.05/9.56 |
| v | Truncated SNR | 2.45/9.80 | 2.75/9.52 |
| v | SNR+1 | 2.49/9.77 | 2.87/9.43 |

(Metrics are reported as FID/IS. Lower FID is better, higher IS is better.)
This ablation study confirms the paper's claims:
- The proposed stable parameterizations (predicting $x$, $v$, or the combined $(x, \epsilon)$ output) all achieve excellent performance, comparable to or even better than the standard $\epsilon$-prediction.
- Predicting $x$ directly performed slightly better empirically in this specific study.
- The standard $\epsilon$-prediction combined with the Truncated SNR loss weighting was unstable and diverged (N/A), highlighting the instability issues discussed in the paper.
- The Truncated SNR and SNR+1 loss weightings, which do not ignore the low-SNR regime, are effective and crucial for the distillation process (a small sketch of these weightings follows below).
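To make the weighting comparison concrete, here is a tiny sketch (my own illustration, with hypothetical names) of the three weightings as functions of the SNR. Only the plain SNR weighting vanishes at zero SNR, which is exactly the regime few-step distillation must learn from.

```python
def loss_weight(snr, kind):
    # snr = alpha_t**2 / sigma_t**2, the signal-to-noise ratio at timestep t.
    if kind == "SNR":            # vanishes as snr -> 0: no training signal at low SNR
        return snr
    if kind == "Truncated SNR":  # max(snr, 1): stays at 1 even at zero SNR
        return max(snr, 1.0)
    if kind == "SNR+1":          # snr + 1: also bounded away from zero
        return snr + 1.0
    raise ValueError(f"unknown weighting: {kind}")

# At snr = 0:  SNR -> 0.0,  Truncated SNR -> 1.0,  SNR+1 -> 1.0
```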
7. Conclusion & Reflections
- Conclusion Summary: The paper successfully introduces progressive distillation, a novel and highly effective method for dramatically accelerating the sampling process of diffusion models. By combining this algorithm with more stable model parameterizations, the authors demonstrate the ability to reduce sampling steps by orders of magnitude (from thousands to as few as four) with minimal loss in image quality. This work effectively bridges a major gap between the high performance of diffusion models and their practical usability, making them a much more viable solution for real-world generative tasks.
- Limitations & Future Work:
- The authors note that the student model in their experiments always had the same architecture and size as the teacher. A potential future direction is to explore distilling into a smaller student model, which would provide further gains in computational efficiency at test time.
- The work is focused on image generation. The authors suggest that progressive distillation could be applied to diffusion models for other data modalities, such as audio or text.
- While quality is high at 4 steps, there is a noticeable drop-off at 1-2 steps. Further research could investigate ways to improve performance in this extreme few-step sampling regime.
- Personal Insights & Critique:
- The core idea of "amortizing" the multi-step integration of the ODE into the network's parameters via distillation is elegant and powerful. It reframes the problem from finding a better solver to creating a better (simpler) problem to solve.
- The "progressive" aspect of the algorithm is a key innovation. It allows the method to scale to very slow teacher models with thousands of steps without incurring a prohibitive cost, a significant advantage over prior distillation work.
- The paper is a strong example of rigorous academic research: it clearly identifies a critical problem, proposes a well-motivated solution, validates it with thorough experiments and ablations, and provides clear, reproducible details.
- This work had a major impact on the field, making diffusion models practical for many more applications. The ability to generate high-fidelity images in a handful of steps was a game-changer and spurred further research into few-step and even single-step generation, which is now a very active area. The techniques presented here have become a standard part of the toolbox for anyone working on efficient diffusion models.