Source code for world_models.models.diffusion.diamond_diffusion

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import math


class AdaptiveGroupNorm(nn.Module):
    """Group normalization modulated by a conditioning vector.

    The input is normalized with a parameter-free ``nn.GroupNorm``; a single
    linear layer then maps the conditioning vector to a per-channel scale and
    bias that are applied to the normalized activations.
    """

    def __init__(self, num_groups: int, num_channels: int, cond_dim: int):
        super().__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        # affine=False: the learned affine transform is replaced by the
        # conditioning-dependent scale/bias produced by ``self.linear``.
        self.norm = nn.GroupNorm(num_groups, num_channels, affine=False)
        # One projection emits both halves (scale first, then bias).
        self.linear = nn.Linear(cond_dim, 2 * num_channels)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Normalize ``x`` and apply the conditioning-derived scale and bias.

        Args:
            x: Input tensor [B, C, H, W]
            cond: Conditioning tensor [B, cond_dim]

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Append two singleton spatial axes so the modulation broadcasts
        # over H and W, then split the channel axis into scale and bias.
        modulation = self.linear(cond)[..., None, None]
        scale, bias = modulation.chunk(2, dim=-3)
        # (1 + scale) means a zero projection leaves the normalized input as is.
        return self.norm(x) * (1 + scale) + bias
class ResBlock(nn.Module):
    """Pre-activation residual block built on adaptive group normalization.

    Two ``norm -> SiLU -> 3x3 conv`` stages (dropout sits before the second
    convolution) are added to a shortcut of the input. The shortcut is a 1x1
    convolution when the channel count changes and an identity otherwise.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        cond_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # Both norms use 32 groups, so channel counts must be divisible by 32.
        self.norm1 = AdaptiveGroupNorm(32, in_channels, cond_dim)
        self.norm2 = AdaptiveGroupNorm(32, out_channels, cond_dim)
        self.dropout = nn.Dropout(dropout)
        # Annotated as a generic nn.Module because the shortcut is either a
        # Conv2d or an Identity; this keeps static type checkers satisfied.
        self.skip: nn.Module = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input feature map [B, in_channels, H, W]
            cond: Conditioning tensor forwarded to both adaptive norms

        Returns:
            Feature map [B, out_channels, H, W]
        """
        shortcut = self.skip(x)
        hidden = self.conv1(self.act(self.norm1(x, cond)))
        hidden = self.act(self.norm2(hidden, cond))
        hidden = self.conv2(self.dropout(hidden))
        return hidden + shortcut
class AttentionBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    Every spatial location of the (adaptively normalized) feature map is
    treated as one token with ``channels`` features.
    """

    def __init__(self, channels: int, cond_dim: int):
        super().__init__()
        self.channels = channels
        self.norm = AdaptiveGroupNorm(32, channels, cond_dim)
        # Fused projection: query, key and value come out of one matmul.
        self.qkv = nn.Linear(channels, channels * 3)
        self.proj = nn.Linear(channels, channels)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input feature map [B, C, H, W]
            cond: Conditioning tensor forwarded to the adaptive norm

        Returns:
            ``x`` plus the attention output, same shape as ``x``.
        """
        batch, channels, height, width = x.shape

        # [B, C, H, W] -> [B, H*W, C]: one token per pixel.
        tokens = self.norm(x, cond).flatten(2).transpose(1, 2)
        query, key, value = self.qkv(tokens).chunk(3, dim=-1)

        # Scaled dot-product attention over all pixel pairs.
        scores = (query @ key.transpose(-2, -1)) / math.sqrt(channels)
        weights = scores.softmax(dim=-1)
        mixed = self.proj(weights @ value)

        # Back to the channel-first image layout before the residual add.
        return x + mixed.transpose(1, 2).reshape(batch, channels, height, width)
class TimestepEmbedding(nn.Module):
    """Learned MLP embedding of a scalar timestep / noise-level value.

    Despite the class name's usual association, no sinusoidal features are
    computed here: each scalar is fed directly through a two-layer MLP
    (``1 -> freq_dim -> dim``) with a SiLU in between.
    """

    def __init__(self, dim: int, freq_dim: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, freq_dim),
            nn.SiLU(),
            nn.Linear(freq_dim, dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Embed every scalar in ``t``.

        Args:
            t: Tensor of timestep values with any shape and any numeric dtype;
                it is flattened so that each element becomes one row.

        Returns:
            Embedding tensor of shape [t.numel(), dim].
        """
        # reshape (not view) so non-contiguous inputs are accepted, and cast
        # to the layer's dtype so integer or float64 timesteps do not make
        # nn.Linear raise a dtype-mismatch error.
        column = t.reshape(-1, 1).to(self.mlp[0].weight.dtype)
        return self.mlp(column)
class DownBlock(nn.Module):
    """Encoder stage of the U-Net: stacked ResBlocks plus optional attention.

    The block itself does not change the spatial resolution; any pooling is
    applied by the caller between stages.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        cond_dim: int,
        num_res_blocks: int = 2,
        attention: bool = False,
    ):
        super().__init__()
        # Only the first ResBlock sees ``in_channels``; the rest map
        # ``out_channels`` to ``out_channels``.
        self.res_blocks = nn.ModuleList(
            [
                ResBlock(
                    in_channels if index == 0 else out_channels,
                    out_channels,
                    cond_dim,
                )
                for index in range(num_res_blocks)
            ]
        )
        # Either an AttentionBlock or None when attention is disabled.
        self.attn: Optional[AttentionBlock] = (
            AttentionBlock(out_channels, cond_dim) if attention else None
        )

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input feature map [B, in_channels, H, W]
            cond: Conditioning tensor passed to every sub-block

        Returns:
            Feature map [B, out_channels, H, W]
        """
        for block in self.res_blocks:
            x = block(x, cond)
        if self.attn is None:
            return x
        return self.attn(x, cond)
class UpBlock(nn.Module):
    """Decoder stage of the U-Net: stacked ResBlocks plus optional attention.

    The block itself does not change the spatial resolution; any upsampling
    is applied by the caller before the stage is invoked.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        cond_dim: int,
        num_res_blocks: int = 2,
        attention: bool = False,
    ):
        super().__init__()
        # Only the first ResBlock sees ``in_channels``; the rest map
        # ``out_channels`` to ``out_channels``.
        self.res_blocks = nn.ModuleList(
            [
                ResBlock(
                    in_channels if index == 0 else out_channels,
                    out_channels,
                    cond_dim,
                )
                for index in range(num_res_blocks)
            ]
        )
        # Either an AttentionBlock or None when attention is disabled.
        self.attn: Optional[AttentionBlock] = (
            AttentionBlock(out_channels, cond_dim) if attention else None
        )

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input feature map [B, in_channels, H, W]
            cond: Conditioning tensor passed to every sub-block

        Returns:
            Feature map [B, out_channels, H, W]
        """
        for block in self.res_blocks:
            x = block(x, cond)
        if self.attn is None:
            return x
        return self.attn(x, cond)
class DiffusionUNet(nn.Module):
    """
    U-Net-style architecture for EDM diffusion world model.

    Uses frame stacking for observation conditioning (past frames are
    concatenated with the noised frame along the channel axis) and adaptive
    group norm for action / noise-level conditioning.

    NOTE(review): the encoder and decoder are chained sequentially; no
    encoder activations are fed into the decoder (there are no U-Net skip
    connections). Wiring them in would change what existing checkpoints
    compute, so that is left as a deliberate follow-up decision.
    """

    def __init__(
        self,
        obs_channels: int = 3,
        num_conditioning_frames: int = 4,
        base_channels: int = 64,
        channel_multipliers: Tuple[int, ...] = (1, 1, 1, 1),
        num_res_blocks: int = 2,
        cond_dim: int = 256,
        action_dim: int = 18,
    ):
        super().__init__()
        self.num_conditioning_frames = num_conditioning_frames
        self.obs_channels = obs_channels

        # Input is the noised frame plus ``num_conditioning_frames`` history
        # frames, all stacked on the channel axis.
        self.input_conv = nn.Conv2d(
            obs_channels * (num_conditioning_frames + 1), base_channels, 3, padding=1
        )

        self.time_embed = TimestepEmbedding(cond_dim)
        self.action_embed = nn.Embedding(action_dim, cond_dim)

        self.down_blocks = nn.ModuleList()
        in_ch = base_channels
        for mult in channel_multipliers:
            out_ch = base_channels * mult
            self.down_blocks.append(
                DownBlock(
                    in_ch,
                    out_ch,
                    cond_dim,
                    num_res_blocks,
                    attention=False,
                )
            )
            in_ch = out_ch

        self.middle_block = nn.ModuleList(
            [
                ResBlock(in_ch, in_ch, cond_dim),
                AttentionBlock(in_ch, cond_dim),
                ResBlock(in_ch, in_ch, cond_dim),
            ]
        )

        self.up_blocks = nn.ModuleList()
        for mult in reversed(channel_multipliers):
            out_ch = base_channels * mult
            self.up_blocks.append(
                UpBlock(
                    in_ch,
                    out_ch,
                    cond_dim,
                    num_res_blocks,
                    attention=False,
                )
            )
            in_ch = out_ch

        # The last up block emits ``base_channels * channel_multipliers[0]``
        # channels, so the head must consume ``in_ch`` rather than assuming
        # the first multiplier is 1. With the default multipliers this is
        # the same width as before.
        self.output_conv = nn.Sequential(
            nn.Conv2d(in_ch, base_channels, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(base_channels, obs_channels, 3, padding=1),
        )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        obs_history: torch.Tensor,
        actions: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the diffusion model.

        Args:
            x: Noised observation [B, C, H, W] (a channel-last [B, H, W, C]
                tensor is also accepted and permuted).
            t: Per-sample noise conditioning value, any shape with B elements.
            obs_history: Past observations for conditioning, [B, L, C, H, W]
                or channel-last [B, L, H, W, C]; other ranks are reshaped to
                [B, -1, H, W] using the last two axes as H and W.
            actions: Past action indices [B, L]; a 1-D [B] tensor is treated
                as a history of length one.

        Returns:
            Network output [B, obs_channels, H', W']. H' and W' equal H and W
            when they are divisible by ``2 ** len(channel_multipliers)``.

        Raises:
            RuntimeError: If the stacked input does not have the number of
                channels ``input_conv`` was built for.
        """
        batch = x.shape[0]

        # Bring the history to channel-first layout when it is recognisably
        # channel-last; anything else is passed through unchanged.
        history = obs_history
        if (
            history.ndim == 5
            and history.shape[2] != self.obs_channels
            and history.shape[-1] == self.obs_channels
        ):
            history = history.permute(0, 1, 4, 2, 3)

        # Fold the frame axis into channels: [B, L, C, H, W] -> [B, L*C, H, W].
        if history.ndim == 5:
            _, num_frames, channels, height, width = history.shape
            history_flat = history.reshape(batch, num_frames * channels, height, width)
        else:
            history_flat = history.reshape(
                batch, -1, history.shape[-2], history.shape[-1]
            )

        # Accept a channel-last noised frame as well.
        if (
            x.ndim == 4
            and x.shape[1] != self.obs_channels
            and x.shape[-1] == self.obs_channels
        ):
            x = x.permute(0, 3, 1, 2).contiguous()

        # Stacked channels should equal obs_channels * (L + 1).
        x = torch.cat([x, history_flat], dim=1)

        expected_in = self.input_conv.in_channels
        if x.shape[1] != expected_in:
            raise RuntimeError(
                f"DiffusionUNet input channel mismatch: got {x.shape[1]}, expected {expected_in}"
            )

        h = self.input_conv(x)

        t_emb = self.time_embed(t)
        # A 1-D action tensor would otherwise have its *embedding* axis summed
        # away below; treat it as a single-step history instead.
        if actions.ndim == 1:
            actions = actions.unsqueeze(1)
        action_emb = self.action_embed(actions.long())
        # Sum over the history axis so the conditioning is [B, cond_dim].
        cond = t_emb + action_emb.sum(dim=1)

        # Encoder: each stage is followed by 2x average pooling. Stage outputs
        # are intentionally not retained: nothing downstream consumes them.
        for down_block in self.down_blocks:
            h = down_block(h, cond)
            h = F.avg_pool2d(h, 2)

        # ResBlock and AttentionBlock share the (h, cond) call signature.
        for block in self.middle_block:
            h = block(h, cond)

        # Decoder: 2x nearest-neighbour upsampling before each stage.
        for up_block in self.up_blocks:
            h = F.interpolate(h, scale_factor=2, mode="nearest")
            h = up_block(h, cond)

        return self.output_conv(h)
class EDMPreconditioner:
    """EDM preconditioning coefficients and noise-level sampling.

    Follows the parameterization of Karras et al. (2022): the network input
    and output are rescaled as a function of the noise level ``sigma`` and the
    assumed data standard deviation ``sigma_data``.
    """

    def __init__(
        self,
        sigma_data: float = 0.5,
        p_mean: float = -0.4,
        p_std: float = 1.2,
    ):
        self.sigma_data = sigma_data
        # Mean / std of log(sigma) used by ``sample_noise_level``.
        self.p_mean = p_mean
        self.p_std = p_std

    def get_preconditioners(self, sigma: torch.Tensor) -> dict:
        """
        Compute EDM preconditioners for given noise levels.

        Args:
            sigma: Noise levels (any shape; outputs share that shape).

        Returns:
            Dictionary with c_skip, c_out, c_in, c_noise
        """
        data_var = self.sigma_data**2
        total_var = sigma**2 + data_var
        total_std = total_var.sqrt()
        return {
            "c_skip": data_var / total_var,
            "c_out": sigma * self.sigma_data / total_std,
            "c_in": 1 / total_std,
            "c_noise": torch.log(sigma) / 4.0,
        }

    def sample_noise_level(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Sample noise level from log-normal distribution."""
        gaussian = torch.randn(batch_size, device=device)
        return gaussian.mul(self.p_std).add(self.p_mean).exp()

    def denoise(
        self, model, x: torch.Tensor, sigma: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """
        Apply EDM denoising with preconditioners.

        Args:
            model: Diffusion model, called as ``model(scaled_x, c_noise, **kwargs)``
            x: Noised input [B, C, H, W]
            sigma: Noise level [B]
            **kwargs: Additional conditioning (obs_history, actions)

        Returns:
            Denoised prediction [B, C, H, W]
        """
        # Shape [B, 1, 1, 1] so every coefficient broadcasts over C, H, W.
        coeffs = self.get_preconditioners(sigma.view(-1, 1, 1, 1))

        # The network is conditioned on c_noise (the log-sigma transform),
        # handed over as a [B, 1] tensor.
        noise_cond = coeffs["c_noise"].squeeze(-1).squeeze(-1)
        raw_output = model(coeffs["c_in"] * x, noise_cond, **kwargs)

        return coeffs["c_skip"] * x + coeffs["c_out"] * raw_output
class EulerSampler:
    """Euler method sampler for reverse diffusion (EDM probability-flow ODE).

    The noise schedule follows Karras et al. (2022): ``num_steps`` levels
    spaced uniformly in ``sigma ** (1 / rho)``, running from ``sigma_max``
    down to ``sigma_min``, with a final step to ``sigma = 0``.

    Attributes:
        t_steps: Noise level at the start of each step (descending).
        t_next: Noise level at the end of each step; last entry is 0.
        edm_precond: Preconditioner whose ``denoise`` wraps the model.
    """

    def __init__(
        self,
        sigma_min: float = 0.002,
        sigma_max: float = 80.0,
        rho: int = 7,
        num_steps: int = 3,
        edm_precond: Optional["EDMPreconditioner"] = None,
    ):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho
        self.num_steps = num_steps

        self.step_indices = torch.arange(num_steps)
        max_root = sigma_max ** (1 / rho)
        min_root = sigma_min ** (1 / rho)
        # Guard the denominator so a single-step schedule is [sigma_max]
        # instead of 0/0 = NaN.
        ramp = self.step_indices / max(num_steps - 1, 1)
        # ramp == 0 -> sigma_max, ramp == 1 -> sigma_min. The schedule must
        # stay in this descending order: sampling starts from pure noise at
        # the highest level and walks down towards sigma = 0.
        self.t_steps = (max_root + ramp * (min_root - max_root)) ** rho
        self.t_next = torch.cat([self.t_steps[1:], self.t_steps.new_zeros(1)])

        # attach EDM preconditioner instance (use provided or default)
        self.edm_precond = (
            edm_precond if edm_precond is not None else EDMPreconditioner()
        )

    @torch.no_grad()
    def sample(
        self,
        model: nn.Module,
        shape: Tuple[int, ...],
        device: torch.device,
        obs_history: Optional[torch.Tensor] = None,
        actions: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Generate samples using Euler method.

        Args:
            model: Diffusion model
            shape: Output shape [B, C, H, W]
            device: Device to run on
            obs_history: Conditioning observations [B, L, C, H, W]
            actions: Conditioning actions [B, L]

        Returns:
            Generated samples [B, C, H, W], clamped to [0, 1]
        """
        batch = shape[0]

        # The schedules are built on CPU in __init__; move/cast them so they
        # match the tensors produced below.
        dtype = torch.get_default_dtype()
        t_steps = self.t_steps.to(device=device, dtype=dtype)
        t_next = self.t_next.to(device=device, dtype=dtype)

        # Initial sample: Gaussian noise at the highest noise level.
        x = torch.randn(shape, device=device) * t_steps[0]

        for sigma_cur, sigma_nxt in zip(t_steps, t_next):
            # All preconditioning (c_in / c_out / c_skip / c_noise) is
            # centralized in the preconditioner so training and sampling
            # share one implementation.
            denoised = self.edm_precond.denoise(
                model,
                x,
                sigma_cur.expand(batch),
                obs_history=obs_history,
                actions=actions,
            )

            # Probability-flow ODE: dx/dsigma = (x - D(x; sigma)) / sigma.
            # One explicit Euler step from sigma_cur to sigma_nxt; stepping
            # to sigma_nxt == 0 lands exactly on the denoised estimate.
            d_cur = (x - denoised) / sigma_cur
            x = x + (sigma_nxt - sigma_cur) * d_cur

        return x.clamp(0, 1)