# Source code for world_models.vision.iris_decoder

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

from world_models.vision.iris_encoder import IRISEncoder


class IRISDecoder(nn.Module):
    """CNN decoder for the IRIS discrete autoencoder.

    Decodes a grid of token embeddings back into an image. The input goes
    through a 1x1 projection, two residual blocks and four upsampling blocks
    (each one doubles the resolution with nearest-neighbour upsampling followed
    by convolutions, e.g. 4x4 -> 64x64), then a 3x3 output convolution. If the
    spatial size still differs from ``frame_shape`` the result is resized
    bilinearly.

    Args:
        vocab_size: Number of entries in the index -> embedding lookup table
            used by :meth:`decode_from_indices`.
        embedding_dim: Channel dimension of the token embeddings fed to
            :meth:`forward`.
        base_channels: Width unit for the upsampling blocks.
        out_channels: Number of channels of the reconstructed image.
        frame_shape: Target image shape ``(C, H, W)``; ``forward`` reads only
            ``H`` and ``W`` from it.
    """

    def __init__(
        self,
        vocab_size: int = 512,
        embedding_dim: int = 512,
        base_channels: int = 32,
        out_channels: int = 3,
        frame_shape: Tuple[int, int, int] = (3, 64, 64),
    ):
        super().__init__()
        self.embedding_dim = embedding_dim

        # Lookup table from token indices to embeddings, so the decoder can be
        # driven directly by discrete tokens (e.g. when decoding sampled or
        # predicted tokens).
        self.vocab_size = vocab_size
        self.index_to_embedding = nn.Embedding(self.vocab_size, self.embedding_dim)

        self.frame_shape = frame_shape
        self.out_channels = out_channels

        # Input projection
        self.input_proj = nn.Conv2d(embedding_dim, embedding_dim, 1)

        # Residual blocks before upsampling
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(embedding_dim) for _ in range(2)]
        )

        # (in, mid, out) channels per upsampling stage; every stage doubles
        # the spatial resolution (4 -> 8 -> 16 -> 32 -> 64 for a 4x4 input).
        stage_channels = (
            (embedding_dim, base_channels * 8, base_channels * 4),
            (base_channels * 4, base_channels * 4, base_channels * 2),
            (base_channels * 2, base_channels * 2, base_channels),
            (base_channels, base_channels, base_channels),
        )
        self.upsample_blocks = nn.ModuleList(
            [UpsampleBlock(c_in, c_mid, c_out) for c_in, c_mid, c_out in stage_channels]
        )

        # Final output projection
        self.output_proj = nn.Conv2d(base_channels, out_channels, 3, padding=1)

    @staticmethod
    def _square_side(num_tokens: int) -> int:
        """Return the side length of a square grid holding ``num_tokens`` cells.

        Raises:
            ValueError: If ``num_tokens`` is not a perfect square.
        """
        side = round(num_tokens**0.5)
        if side * side != num_tokens:
            raise ValueError(
                f"cannot arrange {num_tokens} tokens on a square grid; "
                "pass a tensor with explicit H and W dimensions instead"
            )
        return side

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decode token embeddings to images.

        Args:
            z: Token embeddings (B, C, H, W) with ``C == embedding_dim``,
                e.g. (B, 512, 4, 4).

        Returns:
            Reconstructed images (B, out_channels, H', W') where ``(H', W')``
            is taken from ``frame_shape``, e.g. (B, 3, 64, 64).
        """
        h = self.input_proj(z)
        h = self.residual_blocks(h)

        for upsample_block in self.upsample_blocks:
            h = upsample_block(h)

        h = self.output_proj(h)

        # Resize only when the upsampling stack did not land on the target size.
        _, _, out_h, out_w = h.shape
        target_h, target_w = self.frame_shape[1], self.frame_shape[2]
        if out_h != target_h or out_w != target_w:
            h = F.interpolate(
                h, size=(target_h, target_w), mode="bilinear", align_corners=False
            )

        return h

    def decode_from_embeddings(self, z_flat: torch.Tensor) -> torch.Tensor:
        """Decode token embeddings, flattened or already spatial, to images.

        Args:
            z_flat: Either flattened tokens (B, H*W, C) laid out on a square
                grid, or spatial embeddings (B, C, H, W) which are passed
                through unchanged.

        Returns:
            Reconstructed images.

        Raises:
            ValueError: If a 3-D input has a token count that is not a perfect
                square.
        """
        if z_flat.dim() != 3:
            return self.forward(z_flat)

        batch, num_tokens, channels = z_flat.shape
        side = self._square_side(num_tokens)
        # (B, H*W, C) -> (B, C, H*W) -> (B, C, H, W)
        z = z_flat.permute(0, 2, 1).reshape(batch, channels, side, side)
        return self.forward(z)

    def decode_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Decode discrete token indices into images.

        Args:
            indices: Integer tensor of shape (B, H, W) or (B, H*W) with values
                in ``[0, vocab_size)``. A 2-D input is assumed to describe a
                square grid.

        Returns:
            Reconstructed images (B, C, H, W).

        Raises:
            ValueError: If ``indices`` is neither 2-D nor 3-D, or if a 2-D
                input has a token count that is not a perfect square.
        """
        if indices.dim() == 3:
            batch, height, width = indices.shape
            # reshape (not view) so non-contiguous index tensors, e.g. the
            # result of a permute, are accepted as well.
            flat = indices.reshape(batch, -1)
        elif indices.dim() == 2:
            batch, num_tokens = indices.shape
            height = width = self._square_side(num_tokens)
            flat = indices
        else:
            raise ValueError("indices must be shape (B, H, W) or (B, H*W)")

        # (B, HW, C)
        emb = self.index_to_embedding(flat)
        # (B, HW, C) -> (B, C, HW) -> (B, C, H, W)
        emb = emb.permute(0, 2, 1).reshape(batch, self.embedding_dim, height, width)
        return self.forward(emb)
class UpsampleBlock(nn.Module):
    """Double the spatial resolution, then refine with a residual conv stack.

    The input is first upsampled by a factor of two (nearest neighbour). The
    upsampled tensor is then fed both through a ReLU-Conv-ReLU-Conv stack and
    through a skip path, and the two results are summed. The skip path is the
    identity when ``in_channels == out_channels`` and a 1x1 convolution
    otherwise.

    Args:
        in_channels: Channels of the incoming feature map.
        mid_channels: Channels between the two 3x3 convolutions.
        out_channels: Channels of the returned feature map.
    """

    def __init__(self, in_channels: int, mid_channels: int, out_channels: int):
        super().__init__()

        conv_stack = [
            nn.ReLU(),
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
        ]
        self.block = nn.Sequential(*conv_stack)

        # Annotated with the broad Module type: the skip path is either an
        # Identity or a channel-matching 1x1 Conv2d.
        channels_match = in_channels == out_channels
        self.skip: nn.Module = (
            nn.Identity()
            if channels_match
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` by 2x and return skip(x_up) + block(x_up)."""
        upsampled = self.upsample(x)
        shortcut = self.skip(upsampled)
        return shortcut + self.block(upsampled)
class ResidualBlock(nn.Module):
    """Channel-preserving residual block used ahead of the upsampling stages.

    The residual branch is ReLU -> 3x3 conv (padding 1) -> ReLU -> 1x1 conv,
    so the output has the same shape as the input.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels: int):
        super().__init__()
        spatial_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        pointwise_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.block = nn.Sequential(
            nn.ReLU(),
            spatial_conv,
            nn.ReLU(),
            pointwise_conv,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the residual branch."""
        residual = self.block(x)
        return x + residual
class DiscreteAutoencoder(nn.Module):
    """Complete discrete autoencoder combining an encoder and a decoder.

    Used for training the VQVAE component of IRIS.

    Args:
        vocab_size: Size of the token vocabulary, forwarded to the encoder.
        tokens_per_frame: Number of tokens per frame, forwarded to the encoder.
        embedding_dim: Token embedding dimension shared by encoder and decoder.
        base_channels: Channel width unit for the encoder (the decoder uses a
            fixed width of 32).
        frame_shape: Image shape ``(C, H, W)``; ``C`` is used as the encoder's
            input channels and the decoder's output channels.
    """

    def __init__(
        self,
        vocab_size: int = 512,
        tokens_per_frame: int = 16,
        embedding_dim: int = 512,
        base_channels: int = 64,
        frame_shape: Tuple[int, int, int] = (3, 64, 64),
    ):
        super().__init__()

        self.encoder = IRISEncoder(
            vocab_size=vocab_size,
            tokens_per_frame=tokens_per_frame,
            embedding_dim=embedding_dim,
            in_channels=frame_shape[0],
            base_channels=base_channels,
            frame_shape=frame_shape,
        )

        # NOTE(review): vocab_size is not forwarded here, so the decoder's own
        # index -> embedding table keeps its default size. ``decode`` below
        # looks embeddings up through the encoder, so this path is unaffected;
        # forwarding vocab_size would change parameter shapes for non-default
        # vocabularies and therefore existing checkpoints.
        self.decoder = IRISDecoder(
            embedding_dim=embedding_dim,
            base_channels=32,  # decoder uses smaller channels
            out_channels=frame_shape[0],
            frame_shape=frame_shape,
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Run a full encode-decode pass and compute the training losses.

        Args:
            x: Input images (B, C, H, W).

        Returns:
            A tuple ``(reconstruction, indices, loss_dict)``:

            * reconstruction: Decoder output for the quantized embeddings.
            * indices: Token indices as returned by the encoder.
            * loss_dict: ``"reconstruction"`` (L1 between reconstruction and
              ``x``), ``"vq"`` and ``"perplexity"`` (taken from the encoder's
              loss dictionary) and ``"total"`` (reconstruction + vq).
        """
        z_q, indices, vq_loss = self.encoder(x)

        # Decode once, straight from z_q, so reconstruction gradients flow back
        # through the quantized embeddings.
        reconstruction = self.decoder(z_q)

        recon_loss = F.l1_loss(reconstruction, x)
        total_loss = recon_loss + vq_loss["vq_loss"]

        loss_dict = {
            "reconstruction": recon_loss,
            "vq": vq_loss["vq_loss"],
            "perplexity": vq_loss["perplexity"],
            "total": total_loss,
        }

        return reconstruction, indices, loss_dict

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode images to token indices using the encoder."""
        return self.encoder.encode_to_indices(x)

    def decode(self, indices: torch.Tensor) -> torch.Tensor:
        """Decode token indices to images.

        The indices are turned into embeddings by the encoder's
        ``decode_from_indices`` and then passed through the decoder.
        """
        embeddings = self.encoder.decode_from_indices(indices)
        return self.decoder(embeddings)