扩散模型推理优化
ViperEkura Lv1

1. DDPM

1.1. 前向扩散过程

给定真实数据样本 $x_0 \sim q(x_0)$,DDPM 定义一个固定的、参数化的马尔可夫链,在 $T$ 步内将数据逐渐转化为标准高斯噪声:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t \mathbf{I}\big)$$

其中:

  • $\beta_1,\dots,\beta_T \in (0,1)$ 是预设的小方差(通常随 $t$ 缓慢增大)
  • 整个前向过程的联合分布为:

$$q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1})$$

定义累积量:

$$\alpha_t = 1 - \beta_t, \qquad \bar\alpha_t = \prod_{s=1}^{t} \alpha_s$$

则可以直接采样任意时刻 $t$ 的 $x_t$(重参数化技巧):

$$x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, \mathbf{I})$$

即:

$$q(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar\alpha_t}\, x_0,\ (1-\bar\alpha_t)\mathbf{I}\big)$$

1.2. 反向生成过程

目标是学习一个可学习的马尔可夫链 $p_\theta(x_{t-1} \mid x_t)$ 来逆转前向过程:

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t)\big)$$

在 DDPM 中,通常固定方差(例如 $\Sigma_\theta = \sigma_t^2 \mathbf{I}$,其中 $\sigma_t^2 = \beta_t$ 或 $\tilde\beta_t$),只学习均值 $\mu_\theta(x_t, t)$。

关键洞察:利用贝叶斯规则和高斯性质,可以推导出真实后验 $q(x_{t-1} \mid x_t, x_0)$ 也是高斯分布:

$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(x_{t-1};\ \tilde\mu_t(x_t, x_0),\ \tilde\beta_t \mathbf{I}\big)$$

其中:

$$\tilde\mu_t(x_t, x_0) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\, x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\, x_t, \qquad \tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t$$

由于 $x_0$ 未知,DDPM 训练神经网络 $\epsilon_\theta(x_t, t)$ 来预测前向过程中加入的噪声 $\epsilon$。

由 $x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$,可得:

$$\hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}\big(x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)\big)$$

代入 $\tilde\mu_t$,得到用 $\epsilon_\theta$ 表示的均值:

$$\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\Big)$$

(这是 DDPM 论文中常用的等价形式)

1.3. 训练目标

DDPM 最小化负对数似然的变分下界(ELBO),但通过重参数化可简化为噪声预测的均方误差:

$$L_{\text{simple}} = \mathbb{E}_{t,\, x_0,\, \epsilon}\Big[\big\|\, \epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\ t\big) \big\|^2\Big]$$

其中 $t \sim \mathrm{Uniform}\{1,\dots,T\}$,$\epsilon \sim \mathcal{N}(0, \mathbf{I})$:

这个损失函数非常简单且易于优化。

1.4. 采样算法

训练完成后,从纯噪声 $x_T \sim \mathcal{N}(0, \mathbf{I})$ 开始反向生成:

  1. 对 $t = T, \dots, 1$,采样 $z \sim \mathcal{N}(0, \mathbf{I})$(当 $t > 1$;否则 $z = 0$),计算

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\Big) + \sigma_t z$$

  2. 其中 $\sigma_t^2 = \tilde\beta_t$(或设为 $\beta_t$)

1.5. 代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
class GaussianDiffusion:
    """Plain DDPM: a precomputed noise schedule plus forward/reverse helpers.

    Every per-timestep coefficient is computed once in ``__init__`` and
    looked up per batch element via ``_extract``.
    """

    def __init__(
        self,
        timesteps=1000,
        beta_schedule='linear'
    ):
        self.timesteps = timesteps

        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')
        self.betas = betas

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod
        # \bar{alpha}_{t-1}: shifted right by one, padded with 1 at t=0
        self.alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        # coefficients for the closed-form marginal q(x_t | x_0)
        self.sqrt_alphas_cumprod = alphas_cumprod.sqrt()
        self.sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()
        self.log_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).log()
        self.sqrt_recip_alphas_cumprod = (1.0 / alphas_cumprod).sqrt()
        self.sqrt_recipm1_alphas_cumprod = (1.0 / alphas_cumprod - 1).sqrt()

        # coefficients for the true posterior q(x_{t-1} | x_t, x_0)
        residual = 1.0 - alphas_cumprod
        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / residual
        # the posterior variance is exactly 0 at t=0, so clamp before log
        self.posterior_log_variance_clipped = self.posterior_variance.clamp(min=1e-20).log()
        self.posterior_mean_coef1 = betas * self.alphas_cumprod_prev.sqrt() / residual
        self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * alphas.sqrt() / residual

    def _extract(self, a: Tensor, t: Tensor, x_shape):
        """Index the schedule tensor ``a`` at timesteps ``t`` and reshape the
        result so it broadcasts over a batch tensor of shape ``x_shape``."""
        picked = a.to(t.device).gather(0, t).float()
        broadcast_shape = (t.shape[0],) + (1,) * (len(x_shape) - 1)
        return picked.reshape(broadcast_shape)

    def q_sample(self, x_start: Tensor, t: Tensor, noise=None):
        """Sample x_t ~ q(x_t | x_0) in one shot via the nice property."""
        if noise is None:
            noise = torch.randn_like(x_start)
        scale = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        std = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return scale * x_start + std * noise

    def q_mean_variance(self, x_start: Tensor, t: Tensor):
        """Return mean, variance and log-variance of q(x_t | x_0)."""
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_posterior_mean_variance(self, x_start: Tensor, x_t: Tensor, t: Tensor):
        """Return mean and (log-)variance of the posterior q(x_{t-1} | x_t, x_0)."""
        coef1 = self._extract(self.posterior_mean_coef1, t, x_t.shape)
        coef2 = self._extract(self.posterior_mean_coef2, t, x_t.shape)
        posterior_mean = coef1 * x_start + coef2 * x_t
        posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def predict_start_from_noise(self, x_t: Tensor, t: Tensor, noise: Tensor):
        """Invert ``q_sample``: recover x_0 from x_t and the (predicted) noise."""
        recip = self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
        recipm1 = self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        return recip * x_t - recipm1 * noise

    def p_mean_variance(self, model, x_t: Tensor, t: Tensor, clip_denoised=True):
        """Model-predicted mean and variance of p(x_{t-1} | x_t)."""
        pred_noise = model(x_t, t)
        # estimate x_0 first (differs from Algorithm 2 in the paper)
        x_recon = self.predict_start_from_noise(x_t, t, pred_noise)
        if clip_denoised:
            x_recon = torch.clamp(x_recon, min=-1., max=1.)
        return self.q_posterior_mean_variance(x_recon, x_t, t)

    @torch.no_grad()
    def p_sample(self, model, x_t: Tensor, t: Tensor, clip_denoised=True):
        """Single ancestral denoising step: draw x_{t-1} from x_t."""
        model_mean, _, model_log_variance = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised)
        noise = torch.randn_like(x_t)
        # zero out the noise term for batch elements at t == 0
        mask = (t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))
        return model_mean + mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def sample(
        self,
        model: nn.Module,
        image_size,
        batch_size=8,
        channels=3
    ):
        """Run the full reverse chain from x_T ~ N(0, I); returns every intermediate."""
        device = next(model.parameters()).device
        img = torch.randn((batch_size, channels, image_size, image_size), device=device)
        imgs = []
        for step in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps):
            t = torch.full((batch_size,), step, device=device, dtype=torch.long)
            img = self.p_sample(model, img, t)
            imgs.append(img)
        return imgs

    def train_losses(self, model, x_start: Tensor, t: Tensor):
        """Simple DDPM objective: MSE between true and predicted noise."""
        noise = torch.randn_like(x_start)        # epsilon ~ N(0, I)
        x_noisy = self.q_sample(x_start, t, noise=noise)
        predicted_noise = model(x_noisy, t)
        return F.mse_loss(noise, predicted_noise)

2. DDIM

传统DDPM(Denoising Diffusion Probabilistic Models)需要通过1000步以上的迭代进行采样,而DDIM(Denoising Diffusion Implicit Models)可压缩至 20–50 步而不显著损失质量,可以实现跳步采样进行加速。DDIM 之所以允许任意时间子序列采样(即跳步采样,如从 $x_t$ 直接跳到 $x_s$,$s < t-1$),其根本原因在于:DDIM 的反向过程不依赖马尔可夫性,而是直接建模任意两个时间步之间的确定性或可控随机映射,且该映射仅依赖于对原始数据 $x_0$ 的估计。

下面我们结合公式严格解释这一性质。

2.1. 核心前提:前向过程的“任意时刻可直达”性质

在扩散模型中,前向过程是可解析计算任意 时刻分布的:

这意味着,给定 ,我们可以直接生成任意 ,无需经过中间步骤
更重要的是,这个关系是双向可逆的(在已知 或能估计 的前提下)。

2.2. DDIM 的关键洞察:任意两步之间的“一致性路径”

DDIM 不再假设反向过程必须满足:

(这是马尔可夫假设)

而是考虑更一般的设定:给定当前状态 $x_t$,我们想一步跳到任意更早的时间步 $s$($s < t$),并希望这个跳跃仍然与原始数据分布一致。

为此,DDIM 利用如下事实:

如果我们知道(或能估计)真实的 $x_0$,那么 $x_t$ 与 $x_s$ 都是 $x_0$ 的带噪版本,它们之间存在一个确定性的几何关系

具体地,由前向过程:

$$x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, \qquad x_s = \sqrt{\bar\alpha_s}\,x_0 + \sqrt{1-\bar\alpha_s}\,\epsilon$$

注意:同一个 $\epsilon$ 被用于生成所有 $x_t$(这是重参数化的核心)。

于是,我们可以消去 $x_0$(或 $\epsilon$),得到 $x_s$ 关于 $x_t$ 的表达式:

但实际中 $x_0$ 未知,我们用神经网络预测噪声 $\epsilon_\theta(x_t, t)$,并由此估计:

$$\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}$$

然后代入 $\hat{x}_0$,就得到从 $x_t$ 一步跳到 $x_s$ 的更新规则:

$$x_s = \sqrt{\bar\alpha_s}\,\hat{x}_0 + \sqrt{1-\bar\alpha_s-\sigma_s^2}\,\epsilon_\theta(x_t, t) + \sigma_s z, \qquad z \sim \mathcal{N}(0, \mathbf{I})$$

其中 $\sigma_s$ 控制跳跃中的随机性(当 $\sigma_s = 0$ 时为完全确定性)。

关键点:这个公式只依赖于当前 $x_t$ 和目标时间步 $s$,**不依赖中间任何 $x_k\ (s < k < t)$**。因此,我们可以自由选择任意子序列 $\{\tau_1 < \tau_2 < \cdots < \tau_K\}$ 进行采样。

2.3. 为什么 DDPM 不能跳步?

DDPM 的反向过程被严格定义为马尔可夫链

其推导依赖于真实后验 的高斯形式,而该后验本身是基于相邻两步的转移(即 )通过贝叶斯法则得到的。

若试图从 $x_t$ 直接跳到 $x_{t-k}$($k > 1$),则:

  • 没有对应的 $p_\theta(x_{t-k} \mid x_t)$ 的简单闭式(除非重新推导)
  • 更重要的是,DDPM 的训练目标和方差设计(如 $\sigma_t^2 = \beta_t$ 或 $\tilde\beta_t$)只保证相邻步的 KL 散度最小化,跨步使用会导致分布偏移

因此,DDPM 的反向过程不具备跨步一致性

2.4. 代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
class DDIM(GaussianDiffusion):
    """
    Denoising Diffusion Implicit Models (DDIM) sampler.

    Inherits the noise schedule from GaussianDiffusion and adds a sampler
    whose update rule depends only on the current state and the *target*
    timestep, so it can skip steps (typically 20-50 instead of the full
    `timesteps` chain).
    """
    def __init__(self, timesteps=1000, beta_schedule='linear'):
        super().__init__(timesteps, beta_schedule)

    @torch.no_grad()
    def ddim_step(self, img, t, t_next, eta=0.0, clip_denoised=True, model=None, pred_noise=None):
        """
        Perform a single DDIM sampling step from time t to t_next.

        Args:
            img: current image tensor x_t (batch, channels, height, width)
            t: current timestep (int, larger than t_next)
            t_next: next timestep (int, smaller than t)
            eta: stochasticity parameter (0 -> deterministic, 1 -> DDPM-like)
            clip_denoised: whether to clip predicted x0 to [-1, 1]
            model: noise prediction model (optional, usually a UNet)
            pred_noise: predicted noise (optional, if not provided, it will be computed)

        Returns:
            img: updated image tensor x_{t_next}

        Raises:
            ValueError: if neither `model` nor `pred_noise` is given.
        """
        device = img.device
        batch_size = img.shape[0]

        # Batch tensors for schedule indexing
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
        t_next_batch = torch.full((batch_size,), t_next, device=device, dtype=torch.long)

        # 1. Predict noise with the model unless the caller supplied it
        if pred_noise is None:
            if model is None:
                raise ValueError('either `model` or `pred_noise` must be provided')
            pred_noise = model(img, t_batch)

        # 2. Cumulative alphas for the current and target steps
        alpha_t = self._extract(self.alphas_cumprod, t_batch, img.shape)        # ᾱ_t
        alpha_next = self._extract(self.alphas_cumprod, t_next_batch, img.shape)  # ᾱ_s

        # 3. σ = η·√((1-ᾱ_s)/(1-ᾱ_t))·√(1-ᾱ_t/ᾱ_s)
        #    For t_next < t every factor is non-negative, so the clamp the
        #    original code applied here was a no-op and is dropped.
        sigma = eta * torch.sqrt((1 - alpha_next) / (1 - alpha_t)) * torch.sqrt(1 - alpha_t / alpha_next)

        # 4. Estimate x_0 from x_t and the predicted noise
        pred_x0 = (img - torch.sqrt(1 - alpha_t) * pred_noise) / torch.sqrt(alpha_t)
        if clip_denoised:
            pred_x0 = torch.clamp(pred_x0, -1., 1.)
            # FIX: re-derive the noise from the *clipped* x_0 so that the
            # direction term below stays consistent with the x_0 estimate
            # actually used (same approach as diffusers' DDIMScheduler with
            # clip_sample=True; previously the stale pred_noise was reused).
            pred_noise = (img - torch.sqrt(alpha_t) * pred_x0) / torch.sqrt(1 - alpha_t)

        # 5. "Direction pointing to x_t" term; clamp guards tiny negative
        #    values from floating-point roundoff when eta is close to 1.
        dir_coeff = torch.sqrt(torch.clamp(1 - alpha_next - sigma**2, min=0.0))
        dir_xt = dir_coeff * pred_noise

        # 6. Fresh noise only for the stochastic (eta > 0) variant
        noise = torch.randn_like(img) if eta > 0 else 0.0

        # 7. x_s = √ᾱ_s·x̂_0 + direction + σ·z
        img = torch.sqrt(alpha_next) * pred_x0 + dir_xt + sigma * noise

        return img

    @torch.no_grad()
    def sample(
        self,
        model: nn.Module,
        image_size,
        batch_size=8,
        channels=3,
        n_steps=50,
        eta=0.0,
        clip_denoised=True
    ):
        """
        Sample using DDIM with a reduced number of steps.

        Args:
            model: noise prediction model (usually a UNet)
            image_size: spatial size of images (height = width)
            batch_size: number of images to sample
            channels: number of image channels
            n_steps: number of sampling steps (must be <= timesteps)
            eta: stochasticity parameter (0 -> deterministic, 1 -> DDPM-like)
            clip_denoised: whether to clip predicted x0 to [-1, 1]

        Returns:
            List of images at each DDIM step (including the final result).

        Raises:
            ValueError: if n_steps is outside [1, timesteps] (the original
            code silently produced duplicate timesteps in that case).
        """
        if not 1 <= n_steps <= self.timesteps:
            raise ValueError(f'n_steps must be in [1, {self.timesteps}], got {n_steps}')

        device = next(model.parameters()).device
        shape = (batch_size, channels, image_size, image_size)

        # Uniformly spaced sub-sequence of timesteps, reversed to run T -> 0
        step_indices = torch.linspace(0, self.timesteps - 1, n_steps, dtype=torch.long, device=device)
        timesteps = torch.flip(step_indices, dims=[0])

        # Start from pure noise
        img = torch.randn(shape, device=device)  # x_T
        imgs = [img]

        # Walk consecutive pairs of the sub-sequence down to t = 0
        for i in tqdm(range(len(timesteps) - 1), desc='DDIM sampling', total=len(timesteps) - 1):
            t = timesteps[i].item()
            t_next = timesteps[i + 1].item()
            img = self.ddim_step(img, t, t_next, eta, clip_denoised, model=model)
            imgs.append(img)

        return imgs

3. DeepCache

DeepCache是一种无需训练的扩散模型采样加速方法,通过缓存和复用UNet中间层的特征来减少计算量。它利用扩散模型去噪过程中特征变化的时序平滑性,在相邻时间步之间共享深层特征,从而大幅提升采样速度。

3.1. 核心洞察

在扩散模型的反向去噪过程中,观察到以下现象:

  1. 特征变化的时间平滑性:相邻时间步的UNet特征高度相似,尤其是深层特征变化缓慢
  2. 计算冗余:标准采样中,每个时间步都完整计算整个UNet,但相邻步的特征重复度很高
  3. 层级差异
    • 浅层特征(高分辨率):变化剧烈,包含细节信息
    • 深层特征(低分辨率):变化平缓,包含语义信息

DeepCache的核心思想是:在部分时间步完整计算UNet并缓存深层特征,在后续时间步复用这些缓存特征,跳过部分层的计算

3.2. 方法原理

标准UNet结构

典型的扩散模型UNet包含三个主要阶段:

  • 下采样阶段(Down blocks):逐步降低分辨率,提取特征
  • 中间阶段(Middle block):瓶颈层,处理最抽象的特征
  • 上采样阶段(Up blocks):逐步恢复分辨率,通过跳跃连接融合下采样特征

缓存策略

DeepCache定义两种前向模式:

1. 完整计算步(Full step)

  • 每 $N$ 步执行一次($N$ 为缓存间隔)
  • 完整计算整个UNet
  • 缓存指定下采样层的输出特征

2. 缓存复用步(Cache step)

  • 其余 $N-1$ 步执行
  • 从缓存读取指定层的特征,跳过这些层的计算
  • 其他层正常计算

对于时间步 $t$,若 $t \bmod N \neq 0$(缓存步):复用上一次完整步缓存的深层特征,只计算其余层;

若 $t \bmod N = 0$(完整步):完整计算整个 UNet,并刷新缓存。

3.3. 代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
class UnetWrapper(nn.Module):
    """Base wrapper that alternates full UNet passes with cached passes.

    Every `cache_interval` forward calls run the full UNet; the other calls
    dispatch to `_forward_with_cache`, which subclasses implement.
    """

    def __init__(self, unet_model: "UNet", cache_interval=0, init_fwd_step=0):
        super().__init__()
        self.unet = unet_model
        # interval between full passes; values < 1 mean "always full"
        self.cache_interval = cache_interval
        # running forward-call counter, used to decide full vs. cached pass
        self.fwd_step = init_fwd_step

    def _forward_full(self, x: Tensor, timesteps: Tensor):
        """Run the complete UNet: down blocks, middle block, then up blocks
        with skip connections popped off the `hs` stack."""
        hs = []

        # down stage
        h = x
        t = self.unet.time_embed(timestep_embedding(timesteps, self.unet.model_channels))
        for module in self.unet.down_blocks:
            h = module(h, t)
            hs.append(h)

        # middle stage
        h = self.unet.middle_block(h, t)

        # up stage
        for module in self.unet.up_blocks:
            cat_in = torch.cat([h, hs.pop()], dim=1)
            h = module(cat_in, t)

        return self.unet.out(h)

    def _forward_with_cache(self, x: Tensor, timesteps: Tensor):
        """Hook for subclasses that reuse cached features; abstract here."""
        raise NotImplementedError

    def forward(self, x: Tensor, timesteps: Tensor):
        # FIX: the original computed `fwd_step % cache_interval` directly,
        # which raised ZeroDivisionError for the default cache_interval=0.
        # Treat any interval < 1 as "never use the cache".
        if self.cache_interval < 1 or self.fwd_step % self.cache_interval == 0:
            res = self._forward_full(x, timesteps)
        else:
            res = self._forward_with_cache(x, timesteps)

        self.fwd_step += 1
        return res


class DeepCacheWrapper(UnetWrapper):
    """DeepCache UNet wrapper.

    On full steps it runs the whole UNet and snapshots the outputs of the
    configured down blocks; on cache steps it substitutes those snapshots
    for the blocks' computation.
    """

    def __init__(
        self,
        unet_model: "UNet",
        cache_interval: int = 0,
        init_fwd_step: int = 0,
        cache_layer_list: Optional[List] = None
    ):
        super().__init__(unet_model, cache_interval, init_fwd_step)
        self.cache_layer_list = cache_layer_list
        # most recent full-step output of each cached down block, keyed by index
        self.cache_features: Dict[int, Tensor] = {}
        self._init_cache_ids()

    def _init_cache_ids(self):
        """Normalize `cache_layer_list`; default to every down block."""
        if self.cache_layer_list:
            # defensive copy of the caller-supplied list
            self.cache_layer_list = list(self.cache_layer_list)
        else:
            # all downsample layers
            self.cache_layer_list = list(range(len(self.unet.down_blocks)))

    def _forward_full(self, x: Tensor, timesteps: Tensor) -> Tensor:
        """Full UNet pass that also refreshes the feature cache."""
        hs = []

        # time embedding
        t = self.unet.time_embed(timestep_embedding(timesteps, self.unet.model_channels))

        # down stage: compute everything, snapshot the cached layers
        h = x
        for idx, module in enumerate(self.unet.down_blocks):
            h = module(h, t)
            hs.append(h)

            if idx in self.cache_layer_list:
                self.cache_features[idx] = h.clone()

        # middle stage
        h = self.unet.middle_block(h, t)

        # up stage
        for module in self.unet.up_blocks:
            cat_in = torch.cat([h, hs.pop()], dim=1)
            h = module(cat_in, t)

        return self.unet.out(h)

    def _forward_with_cache(self, x: Tensor, timesteps: Tensor):
        """Partial UNet pass reusing cached down-block features.

        FIX: on a cache miss (e.g. `init_fwd_step` made the very first call
        a cache step, so no full pass has populated `cache_features` yet)
        the original raised KeyError; we now fall back to computing.
        """
        hs = []

        # down stage
        h = x
        t = self.unet.time_embed(timestep_embedding(timesteps, self.unet.model_channels))

        for idx, module in enumerate(self.unet.down_blocks):
            cached = self.cache_features.get(idx) if idx in self.cache_layer_list else None
            if cached is not None:
                # reuse the feature saved on the last full step
                h = cached
            else:
                h = module(h, t)

            hs.append(h)

        # middle stage
        h = self.unet.middle_block(h, t)

        # up stage
        for module in self.unet.up_blocks:
            cat_in = torch.cat([h, hs.pop()], dim=1)
            h = module(cat_in, t)

        return self.unet.out(h)
 REWARD AUTHOR