Source code for world_models.models.iris_transformer

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple


class IRISTransformer(nn.Module):
    """Autoregressive Transformer for world modeling.

    Models the dynamics of the environment by predicting:
    - Next frame tokens (transition model)
    - Rewards
    - Episode termination

    The Transformer operates on sequences of interleaved frame tokens
    and actions: each timestep contributes K frame tokens followed by
    one action token.
    """

    def __init__(
        self,
        vocab_size: int = 512,
        tokens_per_frame: int = 16,
        action_size: int = 18,  # Number of Atari actions
        embed_dim: int = 256,
        num_layers: int = 10,
        num_heads: int = 4,
        dropout: float = 0.1,
        max_timesteps: int = 50,
    ):
        """Build embeddings, positional table, transformer stack and heads.

        Args:
            vocab_size: Size of the discrete token vocabulary.
            tokens_per_frame: Number of tokens (K) representing one frame.
            action_size: Number of discrete actions (e.g. 18 Atari actions).
            embed_dim: Transformer model width.
            num_layers: Number of transformer encoder layers.
            num_heads: Attention heads per layer.
            dropout: Dropout probability inside the transformer.
            max_timesteps: Maximum number of timesteps supported by the
                positional-embedding table (generalizes the previously
                hard-coded limit of 50; default preserves old behavior).
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.tokens_per_frame = tokens_per_frame
        self.action_size = action_size
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads

        # Token and action embeddings share the model width.
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.action_embedding = nn.Embedding(action_size, embed_dim)

        # Positional embeddings.
        # Each timestep occupies (tokens_per_frame + 1) sequence positions:
        # K frame tokens followed by 1 action token.
        max_tokens = tokens_per_frame + 1  # tokens + action
        max_seq_len = max_tokens * max_timesteps
        self.pos_embedding = nn.Parameter(torch.randn(1, max_seq_len, embed_dim) * 0.02)

        # Pre-norm GELU transformer encoder in batch-first layout.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output heads.
        self.layer_norm = nn.LayerNorm(embed_dim)

        # Token prediction head (for next frame tokens).
        self.token_head = nn.Linear(embed_dim, vocab_size)

        # Reward prediction head (one scalar per action position).
        self.reward_head = nn.Linear(embed_dim, 1)

        # Termination prediction head. Binary: 0=continue, 1=terminal.
        self.termination_head = nn.Linear(embed_dim, 2)

        self._init_weights()

    def _init_weights(self):
        """Initialize embedding weights and zero the output-head biases."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.action_embedding.weight, std=0.02)
        # Output heads start with zero bias.
        nn.init.zeros_(self.token_head.bias)
        nn.init.zeros_(self.reward_head.bias)
        nn.init.zeros_(self.termination_head.bias)

    def _build_sequence(
        self,
        tokens: torch.Tensor,
        actions: torch.Tensor,
    ) -> torch.Tensor:
        """Build the interleaved token-action sequence.

        Args:
            tokens: Frame tokens (B, T, K).
            actions: Actions (B, T).

        Returns:
            Embedded sequence (B, T*(K+1), embed_dim) with positional
            embeddings added; per-timestep layout is [K frame tokens, action].
        """
        B, T, K = tokens.shape
        token_embeds = self.token_embedding(tokens.reshape(B, T * K))
        action_embeds = self.action_embedding(actions)
        token_embeds = token_embeds.reshape(B, T, K, self.embed_dim)
        # Append the action embedding after the K frame tokens of each step.
        sequence = torch.cat([token_embeds, action_embeds.unsqueeze(2)], dim=2)
        sequence = sequence.reshape(B, T * (K + 1), self.embed_dim)
        return sequence + self.pos_embedding[:, : T * (K + 1), :]
[docs] def forward( self, tokens: torch.Tensor, # (B, T, K) - frame tokens actions: torch.Tensor, # (B, T) - actions mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Forward pass through the Transformer world model. Args: tokens: Frame tokens (B, T, K) where T is timesteps actions: Actions (B, T) mask: Optional attention mask Returns: token_logits: Next token predictions (B, T, K, vocab_size) rewards: Predicted rewards (B, T) terminations: Predicted terminations (B, T, 2) """ B, T, K = tokens.shape # Flatten tokens: (B, T, K) -> (B, T*K) tokens_flat = tokens.reshape(B, T * K) # Embed tokens token_embeds = self.token_embedding(tokens_flat) # (B, T*K, embed_dim) # Embed actions: (B, T) -> (B, T, embed_dim) action_embeds = self.action_embedding(actions) # (B, T, embed_dim) # Reshape token embeddings to (B, T, K, embed_dim) token_embeds = token_embeds.reshape(B, T, K, self.embed_dim) # Interleave: for each timestep, concat token embeddings with action embedding # sequence: [tokens_t, action_t] for each t # Result: (B, T, K+1, embed_dim) action_embeds_expanded = action_embeds.unsqueeze(2) # (B, T, 1, embed_dim) sequence = torch.cat( [token_embeds, action_embeds_expanded], dim=2 ) # (B, T, K+1, embed_dim) # Flatten: (B, T, K+1, embed_dim) -> (B, T*(K+1), embed_dim) sequence = sequence.reshape(B, T * (K + 1), self.embed_dim) # Create position ids (torch.arange(T * (K + 1), device=tokens.device).unsqueeze(0).expand(B, -1)) # Add positional embeddings sequence = sequence + self.pos_embedding[:, : T * (K + 1), :] # Apply transformer hidden = self.transformer(sequence, mask=mask) hidden = self.layer_norm(hidden) # Reshape hidden states back to per-timestep structure # Each timestep has K tokens + 1 action = K+1 positions # hidden[:, i*(K+1):i*(K+1)+K, :] = frame token predictions for step i # hidden[:, i*(K+1)+K, :] = action token predictions for step i # Reshape to (B, T, K+1, embed_dim) hidden = hidden.reshape(B, T, K + 1, 
self.embed_dim) # Extract predictions for each timestep token_hidden = hidden[:, :, :K, :] # (B, T, K, embed_dim) action_hidden = hidden[:, :, K, :] # (B, T, embed_dim) # Token predictions (next frame) token_logits = self.token_head(token_hidden) # (B, T, K, vocab_size) # Reward predictions (from action position) rewards = self.reward_head(action_hidden).squeeze(-1) # (B, T) # Termination predictions terminations = self.termination_head(action_hidden) # (B, T, 2) return token_logits, rewards, terminations
[docs] def predict_next_tokens( self, tokens: torch.Tensor, actions: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Predict the next frame tokens autoregressively. Used during imagination rollouts. Args: tokens: Current frame tokens (B, K) actions: Actions taken (B,) Returns: token_logits: Next frame token predictions (B, K, vocab_size) action_hidden: Hidden states for reward prediction (B, embed_dim) """ # Handle token shapes: (B, H, W) -> (B, K) -> (B, 1, K) if tokens.dim() == 3: # tokens is (B, H, W) grid of tokens, flatten to (B, K) B_grid, H, W = tokens.shape tokens = tokens.reshape(B_grid, H * W) if tokens.dim() == 2: tokens = tokens.unsqueeze(1) # (B, K) -> (B, 1, K) if actions.dim() == 1: actions = actions.unsqueeze(1) # (B,) -> (B, 1) token_logits, _, _ = self.forward(tokens, actions) # Get action hidden states for reward prediction B, T, K, embed_dim = token_logits.shape hidden = self.layer_norm( self.transformer(self._build_sequence(tokens, actions), mask=None) ) hidden = hidden.reshape(B, T, K + 1, self.embed_dim) action_hidden = hidden[:, -1, K, :] # (B, embed_dim) return ( token_logits[:, -1, :, :], action_hidden, ) # Return last timestep predictions
[docs] def sample_next_tokens( self, tokens: torch.Tensor, actions: torch.Tensor, temperature: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Sample next tokens from the distribution. Args: tokens: Current frame tokens (B, K) actions: Actions taken (B,) temperature: Sampling temperature (higher = more random) Returns: sampled_tokens: Sampled token indices (B, K) log_probs: Log probabilities of sampled tokens (B, K) """ token_logits, _ = self.predict_next_tokens(tokens, actions) # Apply temperature token_logits = token_logits / temperature # Sample from categorical probs = F.softmax(token_logits, dim=-1) sampled_indices = torch.multinomial(probs.reshape(-1, self.vocab_size), 1) sampled_indices = sampled_indices.reshape_as(tokens) # Compute log probabilities log_probs = F.log_softmax(token_logits, dim=-1) log_probs = torch.gather(log_probs, -1, sampled_indices.unsqueeze(-1)).squeeze( -1 ) return sampled_indices, log_probs
[docs] class IRISWorldModel(nn.Module): """Complete IRIS World Model combining autoencoder and transformer. This is the core component that learns environment dynamics entirely in the "imaginary" latent space. """ def __init__( self, encoder: nn.Module, decoder: nn.Module, transformer: IRISTransformer, ): super().__init__() self.encoder = encoder self.decoder = decoder self.transformer = transformer
[docs] def forward( self, observations: torch.Tensor, # (B, T+1, C, H, W) actions: torch.Tensor, # (B, T) ) -> Tuple[dict, dict]: """Full world model forward pass. Args: observations: Image sequence (B, T+1, C, H, W) actions: Actions (B, T) Returns: predictions: Dictionary with predicted tokens, rewards, terminations losses: Dictionary with loss components """ B, T_plus_1, C, H, W = observations.shape T = T_plus_1 - 1 # Encode each frame to tokens tokens_list = [] for t in range(T_plus_1): obs_t = observations[:, t] # (B, C, H, W) _, indices_t, _ = self.encoder(obs_t) tokens_list.append(indices_t) # Stack tokens: (B, T+1, K) tokens = torch.stack(tokens_list, dim=1) # Get transformer predictions token_logits, rewards_pred, terminations_pred = self.transformer( tokens[:, :-1], # (B, T, K) - all frames except last actions, # (B, T) ) # Decode predictions to images (for visualization) # For each timestep, decode the predicted tokens decoded_frames = [] for t in range(T): next_tokens_pred = token_logits[:, t, :, :].argmax(dim=-1) # Greedy decoded_frames.append(self.decoder.decode_from_embeddings(next_tokens_pred)) decoded_frames = torch.stack(decoded_frames, dim=1) if decoded_frames else None # Get actual next tokens for loss computation next_tokens = tokens[:, 1:] # (B, T, K) # Compute losses token_loss = F.cross_entropy( token_logits.reshape(-1, self.transformer.vocab_size), next_tokens.reshape(-1), reduction="mean", ) # Reward and termination losses would be computed with actual labels # (These are computed in the training loop) predictions = { "token_logits": token_logits, "rewards": rewards_pred, "terminations": terminations_pred, "decoded_frames": decoded_frames, } losses = { "token_loss": token_loss, } return predictions, losses
[docs] def imagine( self, initial_tokens: torch.Tensor, # (B, K) policy: nn.Module, horizon: int = 20, temperature: float = 1.0, ) -> dict: """Generate imagined trajectories. Args: initial_tokens: Initial frame tokens (B, K) policy: Policy network to sample actions horizon: Number of steps to imagine temperature: Sampling temperature for token prediction Returns: imagined: Dictionary with imagined trajectories """ # Lists to store trajectory tokens_history = [initial_tokens] actions_history = [] rewards_history = [] terminations_history = [] # Get initial reconstruction for policy input current_tokens = initial_tokens for step in range(horizon): # Get action from policy (using decoded frame) with torch.no_grad(): decoded_frame = self.decoder.decode_from_embeddings(current_tokens) action = policy(decoded_frame) actions_history.append(action) # Predict next tokens sampled_tokens, log_probs = self.transformer.sample_next_tokens( current_tokens, action.squeeze(-1) if action.dim() > 1 else action, temperature, ) # Get reward and termination predictions with torch.no_grad(): token_logits, action_hidden = self.transformer.predict_next_tokens( current_tokens, action ) reward = self.transformer.reward_head(action_hidden).mean() termination_logits = self.transformer.termination_head(action_hidden) termination = torch.softmax(termination_logits, dim=-1)[:, 1] tokens_history.append(sampled_tokens) rewards_history.append(reward) terminations_history.append(termination) # Update current tokens current_tokens = sampled_tokens # Early stopping if terminal if termination.mean() > 0.5: break return { "tokens": torch.stack(tokens_history, dim=1), # (B, H+1, K) "actions": torch.stack(actions_history, dim=1) if actions_history else None, "rewards": torch.stack(rewards_history, dim=1) if rewards_history else None, "terminations": ( torch.stack(terminations_history, dim=1) if terminations_history else None ), }