Source code for world_models.vision.video_tokenizer

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, Union

from world_models.vision.vq_layer import VectorQuantizer, VectorQuantizerEMA
from world_models.blocks.st_transformer import STTransformer


[docs] class VideoTokenizer(nn.Module): """Video Tokenizer using VQ-VAE with Spatiotemporal Transformer. This is a core component of Genie (Google DeepMind, 2024), used to compress raw video frames into discrete latent tokens that can be processed by downstream models like the LatentActionModel and DynamicsModel. The tokenizer uses Vector Quantized Variational Autoencoder (VQ-VAE) objective to learn a discrete codebook of video representations. Unlike standard VQ-VAE, this uses a Spatiotemporal (ST) Transformer in both encoder and decoder to better capture temporal dynamics in videos. Architecture: 1. Patch Embedding: Convert (B, C, T, H, W) video to patch tokens 2. Encoder ST-Transformer: Process spatial-temporal patches 3. Vector Quantization: Discretize continuous embeddings to codebook entries 4. Decoder ST-Transformer: Reconstruct video from quantized tokens 5. Patch Unembedding: Convert tokens back to video frames Key Features: - Causal processing: Each frame's encoding only uses previous frames - Discrete tokens: Enables autoregressive prediction with latent actions - Memory efficient: Uses ST-Transformer instead of full ViT to reduce O(n²) complexity Usage with Genie: tokenizer = VideoTokenizer( num_frames=16, image_size=64, patch_size=4, vocab_size=1024, embedding_dim=32 ) reconstructed, indices, loss_dict = tokenizer(video_frames) # For discrete token input to dynamics model: token_embeddings = tokenizer.decode_indices(indices) Training: The tokenizer is trained with VQ-VAE objective: - Reconstruction loss: MSE between input and reconstructed video - VQ loss: Commit to codebook embeddings (encourages learning useful codes) - Commitment loss: Penalizes encoder outputs drifting from codebook Reference: Genie: Generative Interactive Environments Bruce et al., Google DeepMind, 2024 - https://arxiv.org/abs/2402.15391 """ def __init__( self, num_frames: int = 16, image_size: int = 64, in_channels: int = 3, encoder_dim: int = 512, decoder_dim: int = 1024, encoder_depth: int = 12, decoder_depth: int = 20, num_heads: int = 16, patch_size: int = 4, vocab_size: int = 1024, embedding_dim: int = 32, commitment_weight: float = 0.25, use_ema: bool = False, ema_decay: float = 0.99, ): super().__init__() self.num_frames = num_frames self.image_size = image_size self.patch_size = patch_size self.vocab_size = vocab_size self.embedding_dim = embedding_dim num_patches = (image_size // patch_size) ** 2 self.encoder_dim = encoder_dim self.decoder_dim = decoder_dim self.patch_embed = nn.Conv2d( in_channels, encoder_dim, kernel_size=patch_size, stride=patch_size, ) self.encoder_pos_embed = nn.Parameter( torch.zeros(1, num_frames * num_patches, encoder_dim) ) nn.init.trunc_normal_(self.encoder_pos_embed, std=0.02) self.encoder = STTransformer( num_frames=num_frames, num_patches_per_frame=num_patches, dim=encoder_dim, depth=encoder_depth, num_heads=num_heads, drop_rate=0.0, attn_drop_rate=0.0, ) self.to_vq_embedding = nn.Linear(encoder_dim, embedding_dim) # vq can be either the EMA variant or the plain VectorQuantizer self.vq: Union[VectorQuantizer, VectorQuantizerEMA] if use_ema: self.vq = VectorQuantizerEMA( vocab_size=vocab_size, embedding_dim=embedding_dim, commitment_weight=commitment_weight, ema_decay=ema_decay, ) else: self.vq = VectorQuantizer( vocab_size=vocab_size, embedding_dim=embedding_dim, commitment_weight=commitment_weight, ) self.from_vq_embedding = nn.Linear(embedding_dim, decoder_dim) self.decoder_pos_embed = nn.Parameter( torch.zeros(1, num_frames * num_patches, decoder_dim) ) nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02) self.decoder = STTransformer( num_frames=num_frames, num_patches_per_frame=num_patches, dim=decoder_dim, depth=decoder_depth, num_heads=num_heads, drop_rate=0.0, attn_drop_rate=0.0, ) self.decoder_proj = nn.ConvTranspose2d( decoder_dim, in_channels, kernel_size=patch_size, stride=patch_size, )
[docs] def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, dict]: """Encode video to discrete tokens. Args: x: Video tensor (B, C, T, H, W) Returns: z_q: Quantized embeddings (B, T, H', W', embedding_dim) indices: Token indices (B, T, H', W') vq_loss: Dictionary with VQ loss components """ B, C, T, H, W = x.shape x = x.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W) x = self.patch_embed(x) _, C_enc, H_enc, W_enc = x.shape x = x.reshape(B, T, C_enc, H_enc, W_enc) x = x.permute(0, 1, 3, 4, 2).reshape(B, T * H_enc * W_enc, C_enc) seq_len = T * H_enc * W_enc pos_embed = self.encoder_pos_embed[:, :seq_len, :] x = x + pos_embed x = self.encoder(x) x = x.reshape(B, T, H_enc, W_enc, self.encoder_dim) z_all = [] indices_all = [] vq_loss_all: dict[str, list[torch.Tensor]] = {} for t in range(T): x_t = x[:, t, :, :, :] x_t = x_t.permute(0, 3, 1, 2).reshape(B, self.encoder_dim, H_enc * W_enc) x_t = x_t.permute(0, 2, 1) x_t_embed = self.to_vq_embedding(x_t) x_t_embed = x_t_embed.reshape(B, H_enc, W_enc, self.embedding_dim) x_t_embed = x_t_embed.permute(0, 3, 1, 2) z_q_t, indices_t, vq_loss_t = self.vq(x_t_embed) z_all.append(z_q_t.reshape(B, H_enc, W_enc, self.embedding_dim)) indices_all.append(indices_t.reshape(B, H_enc, W_enc)) for k, v in vq_loss_t.items(): if k not in vq_loss_all: vq_loss_all[k] = [] vq_loss_all[k].append(v) z_q = torch.stack(z_all, dim=1) indices = torch.stack(indices_all, dim=1) vq_loss = {} for k, v in vq_loss_all.items(): stacked = torch.stack(v) if k == "perplexity": vq_loss[k] = stacked.mean() else: vq_loss[k] = stacked.mean() return z_q, indices, vq_loss
[docs] def decode_indices(self, indices: torch.Tensor) -> torch.Tensor: """Decode token indices to embeddings for video frames. Args: indices: Token indices (B, T, H', W') or (B, T, N) where N = H'*W' Returns: z_q: Quantized embeddings (B, T, H', W', embedding_dim) """ if indices.dim() == 4: B, T, H_idx, W_idx = indices.shape indices_reshaped = indices elif indices.dim() == 3: B, T, N = indices.shape num_patches = int(N**0.5) if num_patches**2 == N: H_idx = W_idx = num_patches indices_reshaped = indices.reshape(B, T, H_idx, W_idx) else: indices_reshaped = indices H_idx = W_idx = 1 else: raise ValueError(f"Expected 3D or 4D indices, got {indices.dim()}D") z_q = F.embedding(indices_reshaped, self.vq.codebook.weight) z_q = z_q.reshape(B, T, H_idx, W_idx, self.embedding_dim) return z_q
[docs] def decode(self, z_q: torch.Tensor) -> torch.Tensor: """Decode discrete tokens to video frames. Args: z_q: Quantized embeddings (B, T, H', W', embedding_dim) Returns: Reconstructed video (B, C, T, H, W) """ B, T, H_dec, W_dec, _ = z_q.shape x = z_q.reshape(B, T * H_dec * W_dec, -1) x = self.from_vq_embedding(x) seq_len = T * H_dec * W_dec pos_embed = self.decoder_pos_embed[:, :seq_len, :] x = x + pos_embed x = self.decoder(x) x = x.reshape(B, T, H_dec, W_dec, self.decoder_dim) x = x.permute(0, 4, 1, 2, 3).reshape(B * T, self.decoder_dim, H_dec, W_dec) x = self.decoder_proj(x) x = x.reshape(B, T, x.shape[1], x.shape[2], x.shape[3]) x = x.permute(0, 2, 1, 3, 4) return x
[docs] def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, dict]: """Full forward pass with VQ-VAE objective. Args: x: Video tensor (B, C, T, H, W) Returns: reconstructed: Reconstructed video (B, C, T, H, W) indices: Token indices (B, T, H', W') loss_dict: Dictionary containing loss components """ z_q, indices, vq_loss = self.encode(x) reconstructed = self.decode(z_q) recon_loss = F.mse_loss(reconstructed, x) loss_dict = { "recon_loss": recon_loss, "vq_loss": vq_loss["vq_loss"], "perplexity": vq_loss["perplexity"], } return reconstructed, indices, loss_dict
[docs] def create_video_tokenizer( num_frames: int = 16, image_size: int = 64, in_channels: int = 3, encoder_dim: int = 512, decoder_dim: int = 1024, encoder_depth: int = 12, decoder_depth: int = 20, num_heads: int = 16, patch_size: int = 4, vocab_size: int = 1024, embedding_dim: int = 32, use_ema: bool = False, ) -> VideoTokenizer: """Factory function to create a Video Tokenizer.""" return VideoTokenizer( num_frames=num_frames, image_size=image_size, in_channels=in_channels, encoder_dim=encoder_dim, decoder_dim=decoder_dim, encoder_depth=encoder_depth, decoder_depth=decoder_depth, num_heads=num_heads, patch_size=patch_size, vocab_size=vocab_size, embedding_dim=embedding_dim, use_ema=use_ema, )