1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
class GaussianDiffusion:
    """Gaussian diffusion process with precomputed per-timestep coefficients.

    The constructor builds a ``betas`` schedule of length ``timesteps`` and
    derives from it every lookup table the other methods index by timestep:
    cumulative products of ``alphas = 1 - betas``, their square roots and
    reciprocals, and the mean/variance coefficients of the posterior
    ``q(x_{t-1} | x_t, x_0)``.

    Methods prefixed with ``q_`` use only the schedule (forward/noising
    direction); methods prefixed with ``p_`` call a noise-predicting ``model``
    (reverse/denoising direction).
    """

    def __init__(
        self,
        timesteps=1000,
        beta_schedule='linear'
    ):
        """Build the beta schedule and all derived coefficient tables.

        Args:
            timesteps: Number of diffusion steps; passed to the schedule
                function and used as the length of the sampling loop.
            beta_schedule: ``'linear'`` or ``'cosine'`` to use the
                module-level ``linear_beta_schedule`` /
                ``cosine_beta_schedule`` helpers, or a callable taking
                ``timesteps`` and returning the betas tensor.

        Raises:
            ValueError: If ``beta_schedule`` is neither a callable nor one of
                the known schedule names.
        """
        self.timesteps = timesteps

        if callable(beta_schedule):
            # Custom schedule: the caller supplies the betas directly.
            betas = beta_schedule(timesteps)
        elif beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')
        self.betas = betas

        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Shift right by one and pad with 1.0 so index t holds the cumulative
        # product at step t-1 (and 1.0 for t == 0).
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.)

        # Coefficients used by q_sample / q_mean_variance / predict_start_from_noise.
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)

        # Coefficients of the posterior q(x_{t-1} | x_t, x_0).
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # The variance is exactly 0 at t == 0 (alphas_cumprod_prev[0] == 1),
        # so clamp before taking the log to avoid -inf.
        self.posterior_log_variance_clipped = torch.log(
            self.posterior_variance.clamp(min=1e-20)
        )
        self.posterior_mean_coef1 = (
            self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * torch.sqrt(self.alphas)
            / (1.0 - self.alphas_cumprod)
        )

    def _extract(self, a: Tensor, t: Tensor, x_shape) -> Tensor:
        """Look up ``a[t]`` per batch element, shaped to broadcast over ``x_shape``.

        Args:
            a: 1-D table indexed by timestep.
            t: 1-D integer tensor of timesteps, one per batch element.
            x_shape: Shape of the tensor the result will be multiplied with;
                only its number of dimensions is used.

        Returns:
            A float32 tensor of shape ``(len(t), 1, ..., 1)`` with
            ``len(x_shape)`` dimensions, on the same device as ``t``.
        """
        batch_size = t.shape[0]
        out = a.to(t.device).gather(0, t).float()
        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))

    def q_sample(self, x_start: Tensor, t: Tensor, noise=None) -> Tensor:
        """Noise ``x_start`` to timestep ``t``.

        Computes ``sqrt(alphas_cumprod[t]) * x_start
        + sqrt(1 - alphas_cumprod[t]) * noise``.

        Args:
            x_start: Clean input batch.
            t: Timestep index per batch element.
            noise: Optional noise tensor shaped like ``x_start``; standard
                normal noise is drawn when omitted.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def q_mean_variance(self, x_start: Tensor, t: Tensor):
        """Return ``(mean, variance, log_variance)`` of ``q(x_t | x_0)``.

        The mean is ``sqrt(alphas_cumprod[t]) * x_start``; the variance and
        log-variance are ``1 - alphas_cumprod[t]`` and its log, each shaped to
        broadcast against ``x_start``.
        """
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_posterior_mean_variance(self, x_start: Tensor, x_t: Tensor, t: Tensor):
        """Return mean, variance and clipped log-variance of ``q(x_{t-1} | x_t, x_0)``.

        The mean is ``posterior_mean_coef1[t] * x_start
        + posterior_mean_coef2[t] * x_t``; the variance terms are table
        lookups shaped to broadcast against ``x_t``.
        """
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = self._extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def predict_start_from_noise(self, x_t: Tensor, t: Tensor, noise: Tensor) -> Tensor:
        """Recover ``x_0`` from ``x_t`` and the noise that produced it.

        This is the algebraic inverse of :meth:`q_sample` for a given noise.
        """
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def p_mean_variance(self, model, x_t: Tensor, t: Tensor, clip_denoised=True):
        """Return mean, variance and log-variance of the model's reverse step.

        ``model(x_t, t)`` is treated as a noise prediction; the implied
        ``x_0`` estimate is fed to :meth:`q_posterior_mean_variance`.

        Args:
            model: Callable ``(x_t, t) -> predicted noise``.
            x_t: Current noisy batch.
            t: Timestep index per batch element.
            clip_denoised: If true, clamp the ``x_0`` estimate to ``[-1, 1]``
                (presumably the data range; confirm against the data pipeline).
        """
        pred_noise = model(x_t, t)
        x_recon = self.predict_start_from_noise(x_t, t, pred_noise)
        if clip_denoised:
            x_recon = torch.clamp(x_recon, min=-1., max=1.)
        model_mean, posterior_variance, posterior_log_variance = (
            self.q_posterior_mean_variance(x_recon, x_t, t)
        )
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, model, x_t: Tensor, t: Tensor, clip_denoised=True) -> Tensor:
        """Draw ``x_{t-1}`` from the model's reverse step at timestep ``t``.

        No noise is added for batch elements whose ``t`` is 0, so the final
        step returns the predicted mean.
        """
        model_mean, _, model_log_variance = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised
        )
        noise = torch.randn_like(x_t)
        # 1.0 where t != 0, 0.0 where t == 0; shaped to broadcast over x_t.
        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))
        # exp(0.5 * log_variance) is the standard deviation.
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def sample(
        self,
        model: nn.Module,
        image_size,
        batch_size=8,
        channels=3
    ):
        """Run the full reverse process starting from standard normal noise.

        Args:
            model: Noise-prediction network; its first parameter's device is
                used for sampling. For a model without parameters the device
                of ``self.betas`` is used instead.
            image_size: Height and width of the (square) samples.
            batch_size: Number of samples to draw.
            channels: Number of channels per sample.

        Returns:
            A list with one tensor of shape
            ``(batch_size, channels, image_size, image_size)`` per timestep,
            in sampling order; the last entry is the result of step ``t == 0``.
        """
        shape = (batch_size, channels, image_size, image_size)
        first_param = next(model.parameters(), None)
        # A parameter-free model would otherwise raise a bare StopIteration.
        device = first_param.device if first_param is not None else self.betas.device

        img = torch.randn(shape, device=device)
        imgs = []
        for i in tqdm(
            reversed(range(0, self.timesteps)),
            desc='sampling loop time step',
            total=self.timesteps,
        ):
            t = torch.full((batch_size,), i, device=device, dtype=torch.long)
            img = self.p_sample(model, img, t)
            imgs.append(img)
        return imgs

    def train_losses(self, model, x_start: Tensor, t: Tensor, noise=None) -> Tensor:
        """Return the MSE between the true noise and the model's prediction.

        ``x_start`` is noised to timestep ``t`` with :meth:`q_sample`, the
        model is evaluated on the result, and its output is compared with the
        noise that was added.

        Args:
            model: Callable ``(x_noisy, t) -> predicted noise``.
            x_start: Clean input batch.
            t: Timestep index per batch element.
            noise: Optional noise tensor shaped like ``x_start``; standard
                normal noise is drawn when omitted (the original behaviour).
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start, t, noise=noise)
        predicted_noise = model(x_noisy, t)
        return F.mse_loss(noise, predicted_noise)
|