Source code for world_models.models.latent_action_model

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, Dict, Literal

from world_models.vision.vq_layer import VectorQuantizer
from world_models.blocks.st_transformer import STTransformer


class LatentActionModel(nn.Module):
    """Latent Action Model (LAM) for unsupervised action learning.

    Learns discrete latent actions from unlabeled video frames using a
    VQ-VAE based objective. The model infers latent actions between frames
    that encode the most meaningful changes for future frame prediction.

    Based on Genie paper - learns actions without action labels from
    Internet videos.

    Components:
        - Encoder: takes all previous frames x1:t and the next frame x_t+1
          and outputs one latent action per previous frame.
        - Decoder: takes the previous frames (all but the first are masked)
          plus latent actions and predicts the next frame.

    Args:
        num_frames: Nominal number of context frames. Positional embeddings
            and both transformers are sized for
            ``max_seq_len = max(num_frames + 1, 32)`` frames.
        image_size: Side length used to size the positional embeddings
            (``(image_size // patch_size) ** 2`` patches per frame).
        in_channels: Number of image channels.
        encoder_dim: Token width of the encoder.
        decoder_dim: Token width of the decoder.
        encoder_depth: ``depth`` passed to the encoder ``STTransformer``.
        decoder_depth: ``depth`` passed to the decoder ``STTransformer``.
        num_heads: ``num_heads`` passed to both ``STTransformer`` instances.
        patch_size: Kernel size and stride of the patch-embedding convs.
        vocab_size: Number of discrete latent actions (codebook size).
        embedding_dim: Width of the action embedding fed to the quantizer.
        commitment_weight: Forwarded to ``VectorQuantizer``.
        action_pooling: ``"mean"`` pools a frame's patch embeddings by
            averaging; ``"windowed_attention"`` first attends over the
            (frame t, frame t+1) pair of every patch, then averages.
        window_attention_heads: Heads of the length-2 window attention
            (only used with ``action_pooling="windowed_attention"``).

    Raises:
        ValueError: If ``action_pooling`` is unknown, or if the window
            attention head count is < 1 or does not divide ``embedding_dim``.
    """

    def __init__(
        self,
        num_frames: int = 16,
        image_size: int = 64,
        in_channels: int = 3,
        encoder_dim: int = 256,
        decoder_dim: int = 512,
        encoder_depth: int = 4,
        decoder_depth: int = 4,
        num_heads: int = 8,
        patch_size: int = 16,
        vocab_size: int = 8,
        embedding_dim: int = 32,
        commitment_weight: float = 1.0,
        action_pooling: Literal["mean", "windowed_attention"] = "mean",
        window_attention_heads: int = 1,
    ):
        super().__init__()

        if action_pooling not in {"mean", "windowed_attention"}:
            raise ValueError(
                "action_pooling must be either 'mean' or 'windowed_attention'"
            )
        if action_pooling == "windowed_attention":
            if window_attention_heads < 1:
                raise ValueError("window_attention_heads must be at least 1")
            if embedding_dim % window_attention_heads != 0:
                raise ValueError(
                    "embedding_dim must be divisible by window_attention_heads"
                )

        self.num_frames = num_frames
        self.image_size = image_size
        self.patch_size = patch_size
        self.vocab_size = vocab_size
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.embedding_dim = embedding_dim
        self.in_channels = in_channels
        self.action_pooling = action_pooling

        num_patches = (image_size // patch_size) ** 2

        # ===== ENCODER =====
        # Takes x1:t and x_t+1 (T + 1 frames) and yields T latent actions.
        self.patch_embed = nn.Conv2d(
            in_channels,
            encoder_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

        # Always leave room for at least 32 frames so that sequences longer
        # than ``num_frames`` can still be processed.
        self.max_seq_len = max(num_frames + 1, 32)
        self.encoder_pos_embed = nn.Parameter(
            torch.zeros(1, self.max_seq_len * num_patches, encoder_dim)
        )
        nn.init.trunc_normal_(self.encoder_pos_embed, std=0.02)

        self.encoder = STTransformer(
            num_frames=self.max_seq_len,
            num_patches_per_frame=num_patches,
            dim=encoder_dim,
            depth=encoder_depth,
            num_heads=num_heads,
            drop_rate=0.0,
            attn_drop_rate=0.0,
        )

        self.to_vq_embedding = nn.Linear(encoder_dim, embedding_dim)
        self.window_attention = (
            nn.MultiheadAttention(
                embed_dim=embedding_dim,
                num_heads=window_attention_heads,
                batch_first=True,
            )
            if action_pooling == "windowed_attention"
            else None
        )

        self.vq = VectorQuantizer(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            commitment_weight=commitment_weight,
        )

        # ===== ACTION EMBEDDING =====
        # Not used by the methods of this class; kept so that external users
        # and existing checkpoints keep finding the parameter.
        self.action_embedding = nn.Embedding(vocab_size, encoder_dim)

        # ===== DECODER =====
        # Sees the first frame plus masked later frames; the latent actions
        # are added to the masked frames so the decoder has to rely on them.
        self.decoder_patch_embed = nn.Conv2d(
            in_channels,
            decoder_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, self.max_seq_len * num_patches, decoder_dim)
        )
        nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)

        # Action embeddings act as additive conditioning on decoder tokens.
        self.decoder_action_proj = nn.Linear(embedding_dim, decoder_dim)

        # Built once, sized like the encoder (the original constructed a
        # throw-away decoder with ``num_frames`` first and then replaced it).
        self.decoder = STTransformer(
            num_frames=self.max_seq_len,
            num_patches_per_frame=num_patches,
            dim=decoder_dim,
            depth=decoder_depth,
            num_heads=num_heads,
            drop_rate=0.0,
            attn_drop_rate=0.0,
        )

        # Output projection back to the pixels of one patch.
        self.decoder_out = nn.Linear(
            decoder_dim, in_channels * patch_size * patch_size
        )
        nn.init.xavier_uniform_(self.decoder_out.weight)
        nn.init.zeros_(self.decoder_out.bias)

    @staticmethod
    def _patch_tokens(feature_map: torch.Tensor) -> torch.Tensor:
        """Turn a conv feature map (B', C, H', W') into tokens (B', H'*W', C).

        The channel axis has to be moved last *before* any reshape into
        tokens; reshaping the channel-first buffer directly would mix
        channels and patch positions inside every token.
        """
        return feature_map.flatten(2).transpose(1, 2)

    def _pool_windowed_attention(
        self, x_t_embed: torch.Tensor, x_next_embed: torch.Tensor
    ) -> torch.Tensor:
        """Apply length-2 temporal attention, then mean-pool the attended tokens.

        Args:
            x_t_embed: Current timestep embeddings with shape
                (B, N, embedding_dim).
            x_next_embed: Next timestep embeddings with shape
                (B, N, embedding_dim).

        Returns:
            Pooled action embeddings with shape (B, embedding_dim).

        Raises:
            RuntimeError: If the model was built without window attention.
        """
        if self.window_attention is None:
            raise RuntimeError("window attention module is not initialized")

        B, N, E = x_t_embed.shape
        # Every patch becomes its own 2-token sequence: (current, next).
        windows = torch.stack([x_t_embed, x_next_embed], dim=2)
        windows = windows.reshape(B * N, 2, E)
        attended, _ = self.window_attention(
            windows,
            windows,
            windows,
            need_weights=False,
        )
        # Average over patches and over the two window positions.
        return attended.reshape(B, N, 2, E).mean(dim=(1, 2))

    def _encode_with_aux(
        self, x_prev: torch.Tensor, x_next: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, list]:
        """Run the encoder and keep the by-products needed for the losses.

        Returns:
            latent_actions: Code indices, shape (B, T).
            z_q: Quantizer outputs, shape (B, T, embedding_dim).
            z_e: Pooled embeddings *before* quantization,
                shape (B, T, embedding_dim).
            vq_infos: The third return value of every ``self.vq`` call,
                one entry per timestep.
        """
        B, C, T, H, W = x_prev.shape
        total_frames = T + 1

        # Concatenate x1:t with x_{t+1}: (B, C, T+1, H, W).
        x_all = torch.cat([x_prev, x_next.unsqueeze(2)], dim=2)
        frames = x_all.permute(0, 2, 1, 3, 4).reshape(B * total_frames, C, H, W)

        tokens = self._patch_tokens(self.patch_embed(frames))  # (B*(T+1), N, C_enc)
        num_patches, enc_dim = tokens.shape[1], tokens.shape[2]

        # Flatten frames and patches into one sequence, as this class feeds
        # its transformers a (B, frames * patches, dim) tensor.
        tokens = tokens.reshape(B, total_frames * num_patches, enc_dim)
        tokens = tokens + self.encoder_pos_embed[:, : total_frames * num_patches, :]
        tokens = self.encoder(tokens)
        tokens = tokens.reshape(B, total_frames, num_patches, enc_dim)

        # Project every frame once (the window pooling reuses frame t + 1).
        embeds = self.to_vq_embedding(tokens)  # (B, T+1, N, embedding_dim)

        z_e_steps, z_q_steps, index_steps, vq_infos = [], [], [], []
        for t in range(T):
            if self.action_pooling == "windowed_attention":
                # Frame t + 1 always exists because x_next was appended, so
                # the last window really is (x_T, x_{T+1}) and not (x_T, x_T).
                pooled = self._pool_windowed_attention(
                    embeds[:, t], embeds[:, t + 1]
                )
            else:
                pooled = embeds[:, t].mean(dim=1)  # (B, embedding_dim)

            z_q_t, indices_t, vq_info = self.vq(pooled)
            z_e_steps.append(pooled)
            z_q_steps.append(z_q_t)
            index_steps.append(indices_t)
            vq_infos.append(vq_info)

        z_e = torch.stack(z_e_steps, dim=1)
        z_q = torch.stack(z_q_steps, dim=1)
        # reshape (instead of squeeze) gives (B, T) whether the quantizer
        # returns indices as (B,) or (B, 1), and also when T == 1.
        latent_actions = torch.stack(index_steps, dim=1).reshape(B, T)
        return latent_actions, z_q, z_e, vq_infos

    def encode(
        self, x_prev: torch.Tensor, x_next: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode frames to latent actions.

        Args:
            x_prev: Previous frames (B, C, T, H, W)
            x_next: Next frame (B, C, H, W)

        Returns:
            latent_actions: Discrete latent action indices (B, T)
            z_q: Quantized embeddings (B, T, embedding_dim)
        """
        latent_actions, z_q, _, _ = self._encode_with_aux(x_prev, x_next)
        return latent_actions, z_q

    def decode(
        self,
        x_prev: torch.Tensor,
        z_q: torch.Tensor,
    ) -> torch.Tensor:
        """Decode latent actions to predict next frame.

        Args:
            x_prev: Previous frames (B, C, T, H, W) - all but the first
                frame are zeroed before embedding.
            z_q: Quantized action embeddings (B, T-1, embedding_dim)

        Returns:
            predicted_next_frame: (B, C, H, W), squashed to (0, 1) by a
            sigmoid.
        """
        B, C, T, H, W = x_prev.shape

        # Mask all frames except the first - forces the decoder to use actions.
        x_masked = x_prev.clone()
        x_masked[:, :, 1:] = 0

        frames = x_masked.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
        feature_map = self.decoder_patch_embed(frames)
        H_dec, W_dec = feature_map.shape[-2:]

        tokens = self._patch_tokens(feature_map)  # (B*T, N, C_dec)
        num_patches, dec_dim = tokens.shape[1], tokens.shape[2]
        tokens = tokens.reshape(B, T, num_patches, dec_dim)

        # Frame 0 is the prompt; frames 1..T-1 receive the action embedding
        # on every patch. Adding directly (rather than through a
        # default-dtype zero buffer) keeps the dtype of the tokens.
        action_cond = self.decoder_action_proj(z_q).unsqueeze(2)  # (B, T-1, 1, C_dec)
        tokens = torch.cat([tokens[:, :1], tokens[:, 1:] + action_cond], dim=1)

        tokens = tokens.reshape(B, T * num_patches, dec_dim)
        tokens = tokens + self.decoder_pos_embed[:, : T * num_patches, :]
        tokens = self.decoder(tokens)  # (B, T*N, C_dec)

        # Only the last frame position is used as the prediction.
        last = tokens.reshape(B, T, num_patches, dec_dim)[:, -1]  # (B, N, C_dec)
        patches = self.decoder_out(last)  # (B, N, C * patch * patch)

        # Un-patchify: (B, H', W', C, p, p) -> (B, C, H' * p, W' * p).
        p = self.patch_size
        image = patches.reshape(B, H_dec, W_dec, C, p, p)
        image = image.permute(0, 3, 1, 4, 2, 5)
        image = image.reshape(B, C, H_dec * p, W_dec * p)

        # Happens when H or W is not a multiple of the patch size.
        if image.shape[2] != H or image.shape[3] != W:
            image = F.interpolate(
                image, size=(H, W), mode="bilinear", align_corners=False
            )

        return torch.sigmoid(image)

    def forward(
        self,
        x_prev: torch.Tensor,
        x_next: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Full forward pass: encode to get actions, decode to reconstruct.

        Args:
            x_prev: Previous frames (B, C, T, H, W)
            x_next: Next frame (B, C, H, W)

        Returns:
            Dictionary with keys ``latent_actions`` (B, T), ``z_q``
            (B, T, embedding_dim), ``reconstructed`` (B, C, H, W),
            ``recon_loss``, ``vq_loss`` and ``variance_loss``.
        """
        latent_actions, z_q, z_e, vq_infos = self._encode_with_aux(x_prev, x_next)

        # The decoder takes T - 1 actions for its T - 1 masked frames.
        # NOTE(review): this drops the last inferred action, and with T == 1
        # no action reaches the decoder at all - confirm this alignment is
        # intended.
        num_actions = z_q.shape[1] - 1
        z_q_for_decode = z_q[:, :num_actions, :] if num_actions > 0 else z_q

        reconstructed = self.decode(x_prev, z_q_for_decode)
        recon_loss = F.mse_loss(reconstructed, x_next)

        # Use the loss reported when the encoder outputs were quantized.
        # Re-quantizing z_q (as the previous implementation did) measures the
        # distance of already-quantized vectors to the codebook instead.
        vq_loss = torch.stack([info["vq_loss"] for info in vq_infos]).mean()

        # Variance loss: encourage batch-wise variance of the action
        # embeddings so the encoder does not collapse to a single action.
        # It is taken on the continuous pre-quantization embeddings, which
        # is the only place where it can send a gradient to the encoder.
        if z_e.shape[0] > 1:
            variance_loss = -z_e.var(dim=0).mean()
        else:
            # The unbiased variance of one sample is NaN; contribute nothing.
            variance_loss = z_e.sum() * 0.0

        return {
            "latent_actions": latent_actions,
            "z_q": z_q,
            "reconstructed": reconstructed,
            "recon_loss": recon_loss,
            "vq_loss": vq_loss,
            "variance_loss": variance_loss,
        }
def create_latent_action_model(
    num_frames: int = 16,
    image_size: int = 64,
    in_channels: int = 3,
    encoder_dim: int = 256,
    decoder_dim: int = 512,
    encoder_depth: int = 4,
    decoder_depth: int = 4,
    num_heads: int = 8,
    patch_size: int = 16,
    vocab_size: int = 8,
    embedding_dim: int = 32,
    action_pooling: Literal["mean", "windowed_attention"] = "mean",
    window_attention_heads: int = 1,
    commitment_weight: float = 1.0,
) -> LatentActionModel:
    """Factory function to create a Latent Action Model.

    Every argument is forwarded unchanged to ``LatentActionModel``; the
    defaults are the same as the constructor's.

    ``commitment_weight`` is placed last so that existing positional calls
    keep their meaning. Without it the factory could only ever build models
    with the constructor's default commitment weight.

    Returns:
        A newly constructed ``LatentActionModel``.
    """
    return LatentActionModel(
        num_frames=num_frames,
        image_size=image_size,
        in_channels=in_channels,
        encoder_dim=encoder_dim,
        decoder_dim=decoder_dim,
        encoder_depth=encoder_depth,
        decoder_depth=decoder_depth,
        num_heads=num_heads,
        patch_size=patch_size,
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        commitment_weight=commitment_weight,
        action_pooling=action_pooling,
        window_attention_heads=window_attention_heads,
    )