# Source code for world_models.vision.iris_encoder

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

from world_models.vision.vq_layer import VectorQuantizerEMA


class IRISEncoder(nn.Module):
    """CNN encoder for the IRIS discrete autoencoder.

    Encodes image observations into latent features, which are then
    quantized into discrete tokens by a ``VectorQuantizerEMA``.

    Architecture:
        - 4 stride-2 convolutional layers, each followed by a ReLU
        - Self-attention after the 2nd and 3rd conv layers (16x16 and 8x8
          feature maps for 64x64 inputs)
        - Residual blocks at the final resolution
        - 1x1 convolution projecting to ``embedding_dim``
        - Vector quantization to produce discrete tokens

    Args:
        vocab_size: Codebook size forwarded to the quantizer.
        tokens_per_frame: Stored on the instance; not otherwise used by this
            class. For 64x64 inputs the encoder emits a 4x4 grid (16 tokens).
        embedding_dim: Channel count of the projected features and of the
            quantizer embeddings.
        in_channels: Number of channels of the input images.
        base_channels: Width of the first conv layer; the following layers
            use 2x, 4x and 8x this value.
        num_residual_blocks: Number of residual blocks applied after the
            last conv layer.
        frame_shape: Accepted for interface compatibility; this class does
            not read it (``expected_spatial_size`` is fixed at 4).
    """

    def __init__(
        self,
        vocab_size: int = 512,
        tokens_per_frame: int = 16,
        embedding_dim: int = 512,
        in_channels: int = 3,
        base_channels: int = 64,
        num_residual_blocks: int = 2,
        frame_shape: Tuple[int, int, int] = (3, 64, 64),
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.tokens_per_frame = tokens_per_frame
        self.embedding_dim = embedding_dim

        # Spatial size after the four stride-2 convs for a 64x64 input:
        # 64 -> 32 -> 16 -> 8 -> 4. tokens_per_frame is expected to equal
        # expected_spatial_size ** 2 (16 = 4 x 4); this is not enforced here.
        self.expected_spatial_size = 4

        # CNN encoder body: every block halves the spatial resolution.
        channels = [
            base_channels,
            base_channels * 2,
            base_channels * 4,
            base_channels * 8,
        ]
        self.conv_blocks = nn.ModuleList()
        in_ch = in_channels
        for out_ch in channels:
            self.conv_blocks.append(
                nn.Sequential(
                    nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
                    nn.ReLU(),
                )
            )
            in_ch = out_ch

        # Residual blocks at the final (lowest) resolution.
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(channels[-1]) for _ in range(num_residual_blocks)]
        )

        # Self-attention at intermediate resolutions.
        # NOTE: the attribute names do not match the feature-map sizes they
        # run at; they are kept unchanged so existing checkpoints
        # (state_dict keys) still load.
        self.attention_8 = SelfAttentionBlock(channels[1])  # after conv 2 (16x16)
        self.attention_4 = SelfAttentionBlock(channels[2])  # after conv 3 (8x8)

        # Project to the embedding dimension expected by the quantizer.
        self.projection = nn.Conv2d(channels[-1], embedding_dim, 1)

        # Vector quantizer
        self.quantizer = VectorQuantizerEMA(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            commitment_weight=0.25,
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Encode images to discrete tokens.

        Args:
            x: Input images (B, C, H, W) - should be 64x64.

        Returns:
            The three values returned by the quantizer, unchanged:
            z_q: Quantized features, documented as (B, C, H', W').
            indices: Token indices, documented as (B, H', W').
            vq_loss: Dictionary with VQ loss components.
            (Shapes are taken from the original documentation; confirm
            against ``VectorQuantizerEMA``.)
        """
        # Conv block 1 -> 32x32 (for a 64x64 input)
        h = self.conv_blocks[0](x)

        # Conv block 2 -> 16x16, apply attention at this resolution
        h = self.conv_blocks[1](h)
        h = self.attention_8(h)

        # Conv block 3 -> 8x8, apply attention at this resolution
        h = self.conv_blocks[2](h)
        h = self.attention_4(h)

        # Conv block 4 -> 4x4
        h = self.conv_blocks[3](h)

        # Residual refinement, then projection to the embedding dimension
        h = self.residual_blocks(h)
        h = self.projection(h)

        # Quantize
        z_q, indices, vq_loss = self.quantizer(h)
        return z_q, indices, vq_loss

    def encode_to_indices(self, x: torch.Tensor) -> torch.Tensor:
        """Encode directly to token indices (for world model).

        Runs the full forward pass with gradient tracking disabled and
        returns only the token indices.
        """
        with torch.no_grad():
            # Call the module rather than .forward() so registered hooks run.
            _, indices, _ = self(x)
        return indices

    def decode_from_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Decode token indices to embeddings (for decoder).

        Delegates to ``self.quantizer.decode_indices``.
        """
        return self.quantizer.decode_indices(indices)
class ResidualBlock(nn.Module):
    """Residual block for encoder.

    Computes ``x + block(x)`` where ``block`` is
    ReLU -> 3x3 conv (padding 1) -> ReLU -> 1x1 conv. Both convolutions keep
    the channel count, and the padding keeps the spatial size, so the output
    has the same shape as the input.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Layer order (and therefore the Sequential indices 1 and 3 used in
        # state_dict keys) must stay as is for checkpoint compatibility.
        self.block = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the input plus the residual branch applied to it."""
        return x + self.block(x)
class SelfAttentionBlock(nn.Module):
    """Self-attention block for encoder.

    Applies single-head spatial self-attention over all H*W positions of a
    feature map to capture long-range dependencies. The attention output is
    added to the input scaled by a learned scalar ``gamma`` that starts at
    zero, so the block is an identity mapping at initialization.

    Args:
        channels: Number of channels of the input (and output) feature map.
    """

    def __init__(self, channels: int):
        super().__init__()
        # 1x1 convolutions act as per-position linear projections.
        self.query = nn.Conv2d(channels, channels, 1)
        self.key = nn.Conv2d(channels, channels, 1)
        self.value = nn.Conv2d(channels, channels, 1)
        # Learned residual weight, initialized to 0.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to a (B, C, H, W) feature map.

        Args:
            x: Input feature map of shape (B, C, H, W).

        Returns:
            Tensor of the same shape: ``x + gamma * attention(x)``.
        """
        B, C, H, W = x.shape

        # Project and flatten the spatial grid: (B, C, H, W) -> (B, HW, C)
        q = self.query(x).reshape(B, C, H * W).permute(0, 2, 1)
        k = self.key(x).reshape(B, C, H * W).permute(0, 2, 1)
        v = self.value(x).reshape(B, C, H * W).permute(0, 2, 1)

        # Scaled dot-product scores between every pair of positions,
        # normalized over the key axis: (B, HW, HW)
        attn = torch.bmm(q, k.transpose(1, 2)) / (C**0.5)
        attn = F.softmax(attn, dim=-1)

        # Weighted sum of values, then restore the (B, C, H, W) layout
        out = torch.bmm(attn, v)  # (B, HW, C)
        out = out.permute(0, 2, 1).reshape(B, C, H, W)

        # Residual connection with learned weight
        return x + self.gamma * out