Diffusion probabilistic models (DPMs) are emerging powerful generative models. Despite their high-quality generation performance, DPMs still suffer from their slow sampling as they generally need hundreds or thousands of sequential function evaluations (steps) of large neural networks to draw a sample. Sampling from DPMs can be viewed alternatively as solving the corresponding diffusion ordinary differential equations (ODEs). In this work, we propose an exact formulation of the solution of diffusion ODEs. The formulation analytically computes the linear part of the solution, rather than leaving all terms to black-box ODE solvers as adopted in previous works. By applying change-of-variable, the solution can be equivalently simplified to an exponentially weighted integral of the neural network. Based on our formulation, we propose DPM-Solver, a fast dedicated high-order solver for diffusion ODEs with the convergence order guarantee. DPM-Solver is suitable for both discrete-time and continuous-time DPMs without any further training. Experimental results show that DPM-Solver can generate high-quality samples in only 10 to 20 function evaluations on various datasets. We achieve 4.70 FID in 10 function evaluations and 2.87 FID in 20 function evaluations on the CIFAR10 dataset, and a 4∼16× speedup compared with previous state-of-the-art training-free samplers on various datasets.
Source: DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps (2022-06-02). See: paper link.
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
class DiffusionModel():
def __init__(self, T, model, device):
self.T = T
self.function_approximator = model.to(device)
self.device = device
self.beta = torch.linspace(1e-4, 0.02, T, device=device)
self.alpha = 1.0 - self.beta
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
self.alphas = torch.sqrt(self.alpha_bar)
self.sigmas = torch.sqrt(1.0 - self.alpha_bar)
self.lambdas = torch.log(self.alphas / self.sigmas)
def training(self, batch_size, optimizer):
pass # See https://github.com/MaximeVandegar/Papers-in-100-Lines-of-Code/blob/main/Denoising_Diffusion_Probabilistic_Models/diffusion_models.py#L31
@torch.no_grad()
def sampling(self, n_samples=1, image_channels=1, img_size=(32, 32), use_tqdm=True):
pass # See https://github.com/MaximeVandegar/Papers-in-100-Lines-of-Code/blob/main/Denoising_Diffusion_Probabilistic_Models/diffusion_models.py#L54
@torch.no_grad()
def dpm_solver_sampling(self, n_samples=1, image_channels=1, img_size=(32, 32), n_steps=10, use_tqdm=True):
"""
DPM-Solver-2 (Algorithm 1 from https://arxiv.org/pdf/2206.00927).
"""
step_size = self.T // n_steps
# start from Gaussian noise x_T
xT = torch.randn((n_samples, image_channels, img_size[0], img_size[1]), device=self.device)
x_tilde = xT
for i in tqdm(range(n_steps), desc="DPM-Solver", disable=not use_tqdm):
t_prev = self.T - i * step_size
t_cur = max(t_prev - step_size, 1)
# midpoint in λ-space
lam_mid = (self.lambdas[t_prev - 1] + self.lambdas[t_cur - 1]) / 2.
# invert λ→t via nearest neighbor lookup
s_i = torch.argmin(torch.abs(self.lambdas - lam_mid)).item() + 1
# λ-step size
h = self.lambdas[t_cur - 1] - self.lambdas[t_prev - 1]
# model evaluation at t_prev
t_prev_tensor = torch.full((n_samples,), t_prev, dtype=torch.long, device=self.device)
u_i = (self.alphas[s_i - 1] / self.alphas[t_prev - 1]) * x_tilde - self.sigmas[s_i - 1] * (
torch.exp(h * 0.5) - 1) * self.function_approximator(x_tilde, t_prev_tensor - 1)
t_s_tensor = torch.full((n_samples,), s_i, dtype=torch.long, device=self.device)
x_tilde = (self.alphas[t_cur - 1] / self.alphas[t_prev - 1]) * x_tilde - self.sigmas[t_cur - 1] * (
torch.exp(h) - 1) * self.function_approximator(u_i, t_s_tensor - 1)
return x_tilde
if __name__ == "__main__":
model = torch.load('model_ddpm_mnist')
diffusion = DiffusionModel(1000, model, 'cuda')
nb_images = 81
samples = diffusion.dpm_solver_sampling(n_samples=nb_images, n_steps=10, use_tqdm=True)
plt.figure(figsize=(17, 17))
for i in range(nb_images):
plt.subplot(9, 9, 1+i)
plt.axis('off')
img = samples[i].squeeze().clamp(0, 1).cpu().numpy()
plt.imshow(img, cmap='gray')
plt.savefig('Imgs/dpm2_samples.png')
plt.close()
python implementation DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps in 100 lines
2022-06-02