Source code for world_models.vision.vq_layer

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


[docs] class VectorQuantizer(nn.Module): """Vector Quantizer for discrete autoencoder. Implements the VQ-VAE quantization from: "Neural Discrete Representation Learning" (Van Den Oord et al., 2017) Uses exponential moving averages for codebook updates and straight-through estimator for gradient flow. """ def __init__( self, vocab_size: int = 512, embedding_dim: int = 512, commitment_weight: float = 0.25, ): super().__init__() self.vocab_size = vocab_size self.embedding_dim = embedding_dim self.commitment_weight = commitment_weight # Codebook: learnable embeddings self.codebook = nn.Embedding(vocab_size, embedding_dim) self.codebook.weight.data.uniform_(-1.0 / vocab_size, 1.0 / vocab_size)
[docs] def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, dict]: """Quantize the input latents. Args: z: Input tensor of shape (B, C, H, W) or (B, C) Returns: z_q: Quantized tensor (same shape as input) indices: Token indices for each position (B, H, W) or (B,) loss: Dictionary containing VQ loss components """ # Reshape for quantization original_shape = z.shape if z.dim() == 4: # (B, C, H, W) B, C, H, W = z.shape # Flatten spatial dimensions: (B, C, H*W) -> (B, H*W, C) z_flat = z.permute(0, 2, 3, 1).reshape(B, H * W, C) elif z.dim() == 2: # (B, C) B = z.shape[0] z_flat = z.unsqueeze(1) # (B, 1, C) else: raise ValueError(f"Expected 2D or 4D input, got {z.dim()}D") # Compute distances to codebook entries # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 * z·e z_flat = z_flat.float() codebook = self.codebook.weight.float() d = ( torch.sum(z_flat**2, dim=-1, keepdim=True) + torch.sum(codebook**2, dim=-1) - 2 * torch.matmul(z_flat, codebook.t()) ) # (B, H*W, vocab_size) # Find nearest codebook entries (indices) indices = torch.argmin(d, dim=-1) # (B, H*W) or (B, 1) # Get the quantized values (straight-through) z_q = F.embedding(indices, codebook) # Straight-through estimator: use z_q for forward, z for backward # This allows gradients to flow through while using discrete codes z_q_detached = z_q.detach() # Compute losses # 1. Reconstruction loss: ||z - z_q||^2 (stop gradient on z_q) commitment_loss = F.mse_loss(z_q_detached, z_flat) # 2. Codebook loss: encourage z_q to move toward z (stop gradient on z) # Actually this is handled by the commitment loss since we use detached z_q # 3. Perplexity: measure of how many codebook entries are used encodings = F.one_hot(indices, self.vocab_size).float() avg_probs = torch.mean(encodings, dim=0) perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) # Reshape back to original spatial dimensions if len(original_shape) == 4: z_q = z_q.permute(0, 2, 3, 1).reshape(B, C, H, W) else: z_q = z_q.squeeze(1) indices_reshaped = ( indices.reshape(B, H, W) if z.dim() == 4 else indices.squeeze(-1) ) loss = { "vq_loss": commitment_loss, "perplexity": perplexity, } return z_q, indices_reshaped, loss
[docs] def decode_indices(self, indices: torch.Tensor) -> torch.Tensor: """Decode token indices back to embeddings. Args: indices: Token indices (B, H, W) or (B,) Returns: Embeddings (B, C, H, W) or (B, C) """ if indices.dim() == 3: # (B, H, W) B, H, W = indices.shape z_q = F.embedding(indices, self.codebook.weight) # (B, H*W, C) z_q = z_q.permute(0, 2, 1).reshape(B, -1, H, W) else: z_q = F.embedding(indices, self.codebook.weight) return z_q
[docs] class VectorQuantizerEMA(nn.Module): """Vector Quantizer with Exponential Moving Average updates. Uses EMA updates for the codebook instead of gradient-based updates, which leads to more stable training. """ def __init__( self, vocab_size: int = 512, embedding_dim: int = 512, commitment_weight: float = 0.25, ema_decay: float = 0.99, epsilon: float = 1e-5, ): super().__init__() self.vocab_size = vocab_size self.embedding_dim = embedding_dim self.commitment_weight = commitment_weight self.ema_decay = ema_decay self.epsilon = epsilon # Codebook self.codebook = nn.Embedding(vocab_size, embedding_dim) self.codebook.weight.data.uniform_(-1.0 / vocab_size, 1.0 / vocab_size) # EMA tracking self.register_buffer("ema_cluster_size", torch.zeros(vocab_size)) self.register_buffer("ema_embed_avg", self.codebook.weight.data.clone())
[docs] def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, dict]: """Quantize with EMA updates.""" # Flatten spatial dims B, C, H, W = z.shape z_flat = z.permute(0, 2, 3, 1).reshape(B, H * W, C).float() # Compute distances codebook = self.codebook.weight.float() d = ( torch.sum(z_flat**2, dim=-1, keepdim=True) + torch.sum(codebook**2, dim=-1) - 2 * torch.matmul(z_flat, codebook.t()) ) indices = torch.argmin(d, dim=-1) # Quantize (using straight-through) z_q = F.embedding(indices, codebook) # EMA update (only during training) if self.training: with torch.no_grad(): encodings = F.one_hot(indices, self.vocab_size).float() self.ema_cluster_size.mul_(self.ema_decay).add_( encodings.sum(dim=(0, 1, 2)), alpha=1 - self.ema_decay ) # Update cluster averages n = self.ema_cluster_size.sum() new_ema_embed_avg = ( self.ema_embed_avg * self.ema_decay + (z_flat.transpose(1, 2) @ encodings).sum(0) * (1 - self.ema_decay) ) / (n + self.epsilon) self.ema_embed_avg.copy_(new_ema_embed_avg) # Normalize and update codebook properly (avoid breaking gradient graph) normalized = F.normalize(self.ema_embed_avg, dim=1) self.codebook.weight.data.copy_(normalized) # Reshape back z_q = z_q.reshape(B, H, W, C).permute(0, 3, 1, 2) # Compute losses # Codebook loss: move codebook toward encoder output (stop gradient on encoder) codebook_loss = F.mse_loss(z_q, z.detach()) # Commitment loss: move encoder output toward codebook (stop gradient on codebook) commitment_loss = F.mse_loss(z_q.detach(), z) # Perplexity encodings = F.one_hot(indices, self.vocab_size).float() avg_probs = torch.mean(encodings, dim=0) perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) loss = { "vq_loss": codebook_loss + commitment_loss * self.commitment_weight, "perplexity": perplexity, } return z_q, indices.reshape(B, H, W), loss
[docs] def decode_indices(self, indices: torch.Tensor) -> torch.Tensor: """Decode token indices to embeddings. Args: indices: Token indices (B, H, W) or (B,) Returns: Embeddings (B, C, H, W) or (B, C) """ if indices.dim() == 3: # (B, H, W) B, H, W = indices.shape indices_flat = indices.reshape(B, -1) # (B, H*W) z_q = F.embedding(indices_flat, self.codebook.weight) # (B, H*W, C) z_q = z_q.permute(0, 2, 1).reshape(B, -1, H, W) else: z_q = F.embedding(indices, self.codebook.weight) return z_q