Source code for world_models.models.diffusion.DDPM

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


[docs] class DDPM(nn.Module): """Utility module implementing forward and reverse DDPM diffusion steps. Precomputes diffusion schedule terms and exposes helpers for noising training inputs (`q_sample`) and iterative denoising sampling (`sample`). """ def __init__(self, timesteps: int, beta_start: float, beta_end: float) -> None: super().__init__() self.timesteps = timesteps self.betas: torch.Tensor = torch.linspace(beta_start, beta_end, timesteps) alphas = 1.0 - self.betas self.alphas: torch.Tensor = alphas self.alphas_cumprod: torch.Tensor = torch.cumprod(alphas, dim=0) self.alphas_cumprod_prev: torch.Tensor = F.pad( self.alphas_cumprod[:-1], (1, 0), value=1.0 ) self.register_buffer("betas", self.betas) self.register_buffer("alphas", self.alphas) self.register_buffer("alphas_cumprod", self.alphas_cumprod) self.register_buffer("alphas_cumprod_prev", self.alphas_cumprod_prev) self.sqrt_alphas_cumprod: torch.Tensor = torch.sqrt(self.alphas_cumprod) self.register_buffer("sqrt_alphas_cumprod", self.sqrt_alphas_cumprod) self.sqrt_one_minus_alphas_cumprod: torch.Tensor = torch.sqrt( 1.0 - self.alphas_cumprod ) self.register_buffer( "sqrt_one_minus_alphas_cumprod", self.sqrt_one_minus_alphas_cumprod ) self.posterior_variance: torch.Tensor = ( self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) self.register_buffer("posterior_variance", self.posterior_variance)
[docs] def q_sample( self, x_start: torch.Tensor, t: torch.Tensor, noise: torch.Tensor | None = None ) -> torch.Tensor: if noise is None: noise = torch.randn_like(x_start) s1 = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1) s2 = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1) return s1 * x_start + s2 * noise
[docs] def p_sample( self, model: nn.Module, x_t: torch.Tensor, t: torch.Tensor ) -> torch.Tensor: # Predict noise eps = model(x_t, t) # Compute x0_hat a_t = self.alphas[t].view(-1, 1, 1, 1) ac_t = self.alphas_cumprod[t].view(-1, 1, 1, 1) sqrt_one_minus_ac = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1) x0_hat = (x_t - sqrt_one_minus_ac * eps) / torch.sqrt(ac_t) # Compute Posterior mean beta_t = self.betas[t].view(-1, 1, 1, 1) ac_prev = self.alphas_cumprod_prev[t].view(-1, 1, 1, 1) coef1 = torch.sqrt(ac_prev) * beta_t / (1.0 - ac_t) coef2 = torch.sqrt(a_t) * (1.0 - ac_prev) / (1.0 - ac_t) mean = coef1 * x0_hat + coef2 * x_t # Add noise except for t == 0 var = self.posterior_variance[t].view(-1, 1, 1, 1) noise = torch.randn_like(x_t) if t[0] > 0 else torch.zeros_like(x_t) return mean + torch.sqrt(var) * noise
[docs] @torch.no_grad() def sample( self, model: nn.Module, n: int, img_size: int, channels: int ) -> torch.Tensor: x = torch.randn(n, channels, img_size, img_size).to( next(model.parameters()).device ) for i in reversed(range(self.timesteps)): t = torch.full((n,), i, dtype=torch.long).to(x.device) x = self.p_sample(model, x, t) return x.clamp(-1.0, 1.0)