"""Spatiotemporal transformer blocks (module: world_models.blocks.st_transformer)."""

import torch
import torch.nn as nn
from typing import Optional


class STSpatialAttention(nn.Module):
    """Spatial attention layer for spatiotemporal transformer.

    Processes video tokens by attending over spatial positions (H*W) within
    each time step independently. Captures within-frame spatial relationships.

    Input: (B, T, N, C) - B batches, T time steps, N spatial positions (H*W),
        C channels
    Output: (B, T, N, C) - Same shape, spatially attended features

    Architecture:
        QKV projection: Linear(dim, dim*3)
        Reshape to multi-head attention format
        Per-head LayerNorm on Q and K (QK normalization, no affine params)
        Attention: softmax(Q @ K^T * scale) @ V, scale = 1/sqrt(d_k) by default
        Output projection

    Usage in ST-Transformer:
        Applied to video tokens of shape (B, T, N, C) to capture within-frame
        spatial structure (e.g., object positions).

    Args:
        dim: Channel dimension C; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the QKV projection has a bias term.
        qk_scale: Explicit attention scale; ``None`` means ``head_dim ** -0.5``.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        # Fail at construction instead of with an opaque reshape error in forward.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Explicit None check so any caller-supplied scale is honored.
        self.scale = qk_scale if qk_scale is not None else head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        # QK Normalization (as per Genie paper - improves stability at large scale)
        # Always apply LayerNorm to Q and K per-head (no runtime toggle).
        self.q_norm = nn.LayerNorm(head_dim, elementwise_affine=False, eps=1e-6)
        self.k_norm = nn.LayerNorm(head_dim, elementwise_affine=False, eps=1e-6)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the N spatial tokens of each frame independently.

        Args:
            x: (B, T, N, C) where T is temporal dim, N is spatial dim (H*W)

        Returns:
            (B, T, N, C)
        """
        B, T, N, C = x.shape
        head_dim = C // self.num_heads

        qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, head_dim)
        qkv = qkv.permute(3, 0, 4, 1, 2, 5)  # (3, B, heads, T, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Apply QK normalization (as per Genie paper)
        q = self.q_norm(q)
        k = self.k_norm(k)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, T, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, heads, T, N, head_dim) -> (B, T, N, heads, head_dim) -> (B, T, N, C).
        # The heads axis must sit next to head_dim before merging; only swapping
        # T and N would interleave values from different heads and frames.
        x = (attn @ v).permute(0, 2, 3, 1, 4).reshape(B, T, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class STTemporalAttention(nn.Module):
    """Temporal attention layer with causal masking for spatiotemporal transformer.

    Processes video tokens by attending over time steps (T) independently for
    each spatial position. With causal masking (the default), each frame only
    attends to itself and previous frames (important for autoregressive video
    generation).

    Input: (B, T, N, C) - B batches, T time steps, N spatial positions,
        C channels
    Output: (B, T, N, C) - Same shape, temporally attended features

    Key Feature: Causal masking
        - Frame t can only attend to frames 0...t (itself and earlier frames)
        - Prevents information leakage from future frames
        - Essential for autoregressive video generation models

    Usage in Genie VideoTokenizer:
        Applied after STSpatialAttention to model temporal dynamics.
        The causal mask ensures generation is autoregressive.

    Args:
        dim: Channel dimension C; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the QKV projection has a bias term.
        qk_scale: Explicit attention scale; ``None`` means ``head_dim ** -0.5``.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        # Fail at construction instead of with an opaque reshape error in forward.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Explicit None check so any caller-supplied scale is honored.
        self.scale = qk_scale if qk_scale is not None else head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        # QK Normalization (as per Genie paper - improves stability at large scale)
        # Always apply LayerNorm to Q and K per-head (no runtime toggle).
        self.q_norm = nn.LayerNorm(head_dim, elementwise_affine=False, eps=1e-6)
        self.k_norm = nn.LayerNorm(head_dim, elementwise_affine=False, eps=1e-6)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, causal: bool = True) -> torch.Tensor:
        """Attend over the T frames at each spatial position independently.

        Args:
            x: (B, T, N, C) where T is temporal dim, N is spatial dim (H*W)
            causal: whether to apply causal masking (frame t sees frames 0..t)

        Returns:
            (B, T, N, C)
        """
        B, T, N, C = x.shape
        head_dim = C // self.num_heads

        qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, head_dim)
        qkv = qkv.permute(3, 0, 4, 2, 1, 5)  # (3, B, heads, N, T, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Apply QK normalization (as per Genie paper)
        q = self.q_norm(q)
        k = self.k_norm(k)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, T, T)

        if causal:
            # Strictly-upper-triangular entries are future frames; the diagonal
            # stays unmasked, so every row keeps at least one finite logit.
            mask = torch.triu(
                torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
            )
            attn = attn.masked_fill(mask, float("-inf"))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, heads, N, T, head_dim) -> (B, T, N, heads, head_dim) -> (B, T, N, C)
        x = (attn @ v).permute(0, 3, 2, 1, 4).reshape(B, T, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class STMLP(nn.Module):
    """Two-layer feed-forward network used inside the ST-Transformer block.

    Computes ``drop(fc2(drop(act(fc1(x)))))``; one dropout module is shared
    between the hidden activation and the output.

    Args:
        in_features: Size of the input features.
        hidden_features: Hidden width; falls back to ``in_features`` if falsy.
        out_features: Output width; falls back to ``in_features`` if falsy.
        act_layer: Activation class, instantiated with no arguments.
        drop: Dropout probability.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: type[nn.Module] = nn.GELU,
        drop: float = 0.0,
    ):
        super().__init__()
        width_out = in_features if not out_features else out_features
        width_hidden = in_features if not hidden_features else hidden_features

        # Submodules are created in a fixed order so parameter initialization
        # and state_dict key order stay stable.
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc1 -> activation -> dropout -> fc2 -> dropout."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class STTransformerBlock(nn.Module):
    """Combined spatiotemporal transformer block with interleaved attention.

    A single block applies, each as a pre-norm residual branch:
        1. Spatial attention (within each time frame)
        2. Temporal attention (across frames with causal mask)
        3. MLP projection

    The order is: x -> + SpatialAttn -> + TemporalAttn -> + MLP -> x

    This interleaved design captures both spatial structure and temporal
    dynamics efficiently, used in Genie's VideoTokenizer and DynamicsModel.

    Args:
        dim: Feature dimension (must match patch embedding dimension)
        num_heads: Number of attention heads
        mlp_ratio: MLP hidden dim = dim * mlp_ratio
        qkv_bias: Whether the attention QKV projections use a bias
        qk_scale: Optional override of the attention scale
        drop, attn_drop: Dropout rates
        drop_path: Stochastic depth rate for drop path regularization
        act_layer: Activation class used in the MLP (default: nn.GELU)
        norm_layer: Normalization layer class (default: nn.LayerNorm)
        num_frames: Number of frames T assumed when ``forward`` receives a
            flattened (B, T*N, C) input (default: 16). Unused for 4-D input.

    Usage in Genie (blocks are normally stacked through STTransformer):
        # VideoTokenizer encoder (12 layers)
        encoder = STTransformer(
            num_frames=16,
            num_patches_per_frame=256,  # 16x16 for 64x64 images with patch_size=4
            dim=512,
            depth=12,
            num_heads=16
        )
        encoded = encoder(tokens)  # (B, T*N, C)

        # Dynamics model decoder (24 layers)
        decoder = STTransformer(
            num_frames=16,
            num_patches_per_frame=256,
            dim=1024,
            depth=24,
            num_heads=16
        )
        decoded = decoder(tokens)
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: type[nn.Module] = nn.GELU,
        norm_layer: type[nn.Module] = nn.LayerNorm,
        num_frames: int = 16,
    ):
        super().__init__()
        # Frame count used only to unflatten (B, T*N, C) inputs in forward.
        self.num_frames = num_frames

        self.norm1_spatial = norm_layer(dim)
        self.norm1_temporal = norm_layer(dim)
        self.attn_spatial = STSpatialAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.attn_temporal = STTemporalAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        mlp_hidden_dim = int(dim * mlp_ratio)
        self.norm2 = norm_layer(dim)
        self.mlp = STMLP(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run spatial attention, causal temporal attention and the MLP.

        Args:
            x: (B, T, N, C) or (B, T*N, C). For the flattened form, T is taken
                from ``self.num_frames`` and N is inferred.

        Returns:
            Same shape as input.

        Raises:
            ValueError: If a flattened input's token count is not a multiple
                of ``self.num_frames``.
        """
        flat_input = x.dim() == 3
        if flat_input:
            B, T_N, C = x.shape
            T = self.num_frames
            if T_N % T != 0:
                raise ValueError(
                    f"Token count {T_N} is not divisible by num_frames ({T})"
                )
            x = x.reshape(B, T, T_N // T, C)

        B, T, N, C = x.shape

        # Spatial attention (within each time step)
        x = x + self.drop_path(self.attn_spatial(self.norm1_spatial(x)))

        # Temporal attention (across time steps, with causal mask)
        x = x + self.drop_path(self.attn_temporal(self.norm1_temporal(x)))

        # MLP (single FFW after both spatial and temporal, as per paper)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        # Honor the "same shape as input" contract for flattened inputs.
        if flat_input:
            x = x.reshape(B, T * N, C)
        return x
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample.

    In training mode, each sample along dim 0 is zeroed entirely with
    probability ``drop_prob``; surviving samples are scaled by
    ``1 / (1 - drop_prob)``. In eval mode, or when ``drop_prob == 0``, the
    input is returned unchanged (same tensor object).

    Args:
        drop_prob: Probability of dropping a sample's path.
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Randomly zero whole samples of ``x`` during training."""
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One mask value per sample, broadcast over all remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # 1 with probability keep_prob, else 0
        if keep_prob > 0.0:
            return x.div(keep_prob) * random_tensor
        # keep_prob == 0: rescaling would compute x / 0 * 0 -> NaN; every
        # sample is dropped, so just apply the (all-zero) mask.
        return x * random_tensor
class STTransformer(nn.Module):
    """Spatiotemporal Transformer for video modeling.

    Contains ``depth`` spatiotemporal blocks with interleaved spatial and
    temporal attention, followed by a final normalization layer. Stochastic
    depth rates increase linearly from 0 to ``drop_path_rate`` across blocks.

    Note: ``num_frames`` is stored on the module, but ``forward`` infers the
    frame count from the input length and ``num_patches_per_frame``.
    """

    def __init__(
        self,
        num_frames: int = 16,
        num_patches_per_frame: int = 256,
        dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: type[nn.Module] = nn.LayerNorm,
    ):
        super().__init__()
        self.num_frames = num_frames
        self.num_patches_per_frame = num_patches_per_frame

        # Linearly increasing stochastic-depth rate per block.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks = nn.ModuleList(
            [
                STTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply all blocks and the final norm to a flattened token sequence.

        Args:
            x: (B, T*N, C) where N is num_patches_per_frame and T is inferred.

        Returns:
            (B, T*N, C)

        Raises:
            ValueError: If the token count is not a multiple of
                ``num_patches_per_frame``.
        """
        B, T_N, C = x.shape
        N = self.num_patches_per_frame
        if T_N % N != 0:
            raise ValueError(
                f"Token count {T_N} is not a multiple of "
                f"num_patches_per_frame ({N})"
            )
        T = T_N // N

        # Reshape to (B, T, N, C) for ST-attention
        x = x.reshape(B, T, N, C)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        # Reshape back to (B, T*N, C)
        return x.reshape(B, T * N, C)
def create_st_transformer(
    num_frames: int = 16,
    patch_size: int = 4,
    img_size: int = 64,
    dim: int = 768,
    depth: int = 12,
    num_heads: int = 12,
    mlp_ratio: float = 4.0,
    qkv_bias: bool = True,
    drop_rate: float = 0.0,
    attn_drop_rate: float = 0.0,
    drop_path_rate: float = 0.0,
) -> STTransformer:
    """Build an STTransformer for square frames cut into square patches.

    The per-frame token count is ``(img_size // patch_size) ** 2``, computed
    with floor division. Every other argument is forwarded unchanged to
    ``STTransformer``.
    """
    patches_per_side = img_size // patch_size
    tokens_per_frame = patches_per_side**2
    return STTransformer(
        num_frames=num_frames,
        num_patches_per_frame=tokens_per_frame,
        dim=dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        qkv_bias=qkv_bias,
        drop_rate=drop_rate,
        attn_drop_rate=attn_drop_rate,
        drop_path_rate=drop_path_rate,
    )