Source code for world_models.vision.iris_decoder

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

from world_models.vision.iris_encoder import IRISEncoder


[docs] class IRISDecoder(nn.Module): """CNN Decoder for IRIS discrete autoencoder. Decodes discrete tokens back into image observations. Uses transposed convolutions to upsample from 4x4 to 64x64. """ def __init__( self, vocab_size: int = 512, embedding_dim: int = 512, base_channels: int = 32, out_channels: int = 3, frame_shape: Tuple[int, int, int] = (3, 64, 64), ): super().__init__() self.embedding_dim = embedding_dim self.frame_shape = frame_shape self.out_channels = out_channels # Input projection self.input_proj = nn.Conv2d(embedding_dim, embedding_dim, 1) # Residual blocks before upsampling self.residual_blocks = nn.Sequential( *[ResidualBlock(embedding_dim) for _ in range(2)] ) # Upsampling blocks (4 -> 8 -> 16 -> 32 -> 64) self.upsample_blocks = nn.ModuleList() # Block 1: 4x4 -> 8x8 self.upsample_blocks.append( UpsampleBlock(embedding_dim, base_channels * 8, base_channels * 4) ) # Block 2: 8x8 -> 16x16 self.upsample_blocks.append( UpsampleBlock(base_channels * 4, base_channels * 4, base_channels * 2) ) # Block 3: 16x16 -> 32x32 self.upsample_blocks.append( UpsampleBlock(base_channels * 2, base_channels * 2, base_channels) ) # Block 4: 32x32 -> 64x64 self.upsample_blocks.append( UpsampleBlock(base_channels, base_channels, base_channels) ) # Final output projection self.output_proj = nn.Conv2d(base_channels, out_channels, 3, padding=1)
[docs] def forward(self, z: torch.Tensor) -> torch.Tensor: """Decode tokens to images. Args: z: Token embeddings (B, C, H, W) - e.g., (B, 512, 4, 4) Returns: reconstructed: Reconstructed images (B, C, H, W) - e.g., (B, 3, 64, 64) """ # Project input h = self.input_proj(z) # Residual blocks h = self.residual_blocks(h) # Upsampling for upsample_block in self.upsample_blocks: h = upsample_block(h) # Final output h = self.output_proj(h) # Ensure output matches frame shape _, _, out_h, out_w = h.shape target_h, target_w = self.frame_shape[1], self.frame_shape[2] if out_h != target_h or out_w != target_w: h = F.interpolate( h, size=(target_h, target_w), mode="bilinear", align_corners=False ) return h
[docs] def decode_from_embeddings(self, z_flat: torch.Tensor) -> torch.Tensor: """Decode flattened token embeddings to images. Args: z_flat: Flattened tokens (B, H*W, C) or (B, C, H, W) Returns: Reconstructed images """ if z_flat.dim() == 3: # (B, H*W, C) B, HW, C = z_flat.shape H = W = int(HW**0.5) z = z_flat.permute(0, 2, 1).reshape(B, C, H, W) else: z = z_flat return self.forward(z)
[docs] class UpsampleBlock(nn.Module): """Upsampling block with optional residual connection.""" def __init__(self, in_channels: int, mid_channels: int, out_channels: int): super().__init__() self.block = nn.Sequential( nn.ReLU(), nn.Conv2d(in_channels, mid_channels, 3, padding=1), nn.ReLU(), nn.Conv2d(mid_channels, out_channels, 3, padding=1), ) # Skip connection projection if needed self.skip = nn.Identity() if in_channels != out_channels: self.skip = nn.Conv2d(in_channels, out_channels, 1) # Upsample self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.upsample(x) return self.skip(x) + self.block(x)
[docs] class ResidualBlock(nn.Module): """Residual block for decoder.""" def __init__(self, channels: int): super().__init__() self.block = nn.Sequential( nn.ReLU(), nn.Conv2d(channels, channels, 3, padding=1), nn.ReLU(), nn.Conv2d(channels, channels, 1), )
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: return x + self.block(x)
[docs] class DiscreteAutoencoder(nn.Module): """Complete Discrete Autoencoder combining encoder and decoder. Used for training the VQVAE component of IRIS. """ def __init__( self, vocab_size: int = 512, tokens_per_frame: int = 16, embedding_dim: int = 512, base_channels: int = 64, frame_shape: Tuple[int, int, int] = (3, 64, 64), ): super().__init__() self.encoder = IRISEncoder( vocab_size=vocab_size, tokens_per_frame=tokens_per_frame, embedding_dim=embedding_dim, in_channels=frame_shape[0], base_channels=base_channels, frame_shape=frame_shape, ) self.decoder = IRISDecoder( vocab_size=vocab_size, embedding_dim=embedding_dim, base_channels=32, # decoder uses smaller channels out_channels=frame_shape[0], frame_shape=frame_shape, )
[docs] def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, dict]: """Full encode-decode forward pass. Args: x: Input images (B, C, H, W) Returns: reconstruction: Reconstructed images indices: Token indices (B, H', W') loss_dict: Dictionary with loss components """ z_q, indices, vq_loss = self.encoder(x) # Decode (use detached z_q to stop gradient through decoder for VQ loss) self.decoder( z_q.detach() + z_q - z_q ) # identity with gradient stop for z_q part # Actually, we want gradients to flow through reconstruction path reconstruction_st = self.decoder(z_q) # Compute reconstruction loss recon_loss = F.l1_loss(reconstruction_st, x) # Combine losses loss = recon_loss + vq_loss["vq_loss"] loss_dict = { "reconstruction": recon_loss, "vq": vq_loss["vq_loss"], "perplexity": vq_loss["perplexity"], "total": loss, } return reconstruction_st, indices, loss_dict
[docs] def encode(self, x: torch.Tensor) -> torch.Tensor: """Encode to token indices.""" return self.encoder.encode_to_indices(x)
[docs] def decode(self, indices: torch.Tensor) -> torch.Tensor: """Decode token indices to images.""" embeddings = self.encoder.decode_from_indices(indices) return self.decoder(embeddings)