import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class VectorQuantizer(nn.Module):
    """Vector Quantizer for a discrete autoencoder.

    Implements the VQ-VAE quantization layer from
    "Neural Discrete Representation Learning" (van den Oord et al., 2017).
    A straight-through estimator lets gradients flow through the discrete
    bottleneck back into the encoder; the codebook is trained with a codebook
    loss plus a weighted commitment loss.
    """

    def __init__(
        self,
        vocab_size: int = 512,
        embedding_dim: int = 512,
        commitment_weight: float = 0.25,
    ):
        """
        Args:
            vocab_size: Number of codebook entries.
            embedding_dim: Dimensionality of each codebook vector.
            commitment_weight: Weight (beta) on the commitment loss term.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.commitment_weight = commitment_weight
        # Codebook: learnable embeddings, uniformly initialized as in the paper.
        self.codebook = nn.Embedding(vocab_size, embedding_dim)
        self.codebook.weight.data.uniform_(-1.0 / vocab_size, 1.0 / vocab_size)

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Quantize the input latents.

        Args:
            z: Input tensor of shape (B, C, H, W) or (B, C).

        Returns:
            z_q: Quantized tensor (same shape as input). Forward value is the
                nearest codebook entry; backward gradient w.r.t. ``z`` is the
                identity (straight-through estimator).
            indices: Token index per position, shape (B, H, W) or (B,).
            loss: Dict with "vq_loss" (codebook loss + commitment_weight *
                commitment loss) and "perplexity" (codebook-usage diagnostic).

        Raises:
            ValueError: If ``z`` is neither 2D nor 4D.
        """
        ndim = z.dim()
        if ndim == 4:  # (B, C, H, W)
            B, C, H, W = z.shape
            # Flatten spatial dims: (B, C, H, W) -> (B, H*W, C)
            z_flat = z.permute(0, 2, 3, 1).reshape(B, H * W, C)
        elif ndim == 2:  # (B, C)
            B = z.shape[0]
            z_flat = z.unsqueeze(1)  # (B, 1, C)
        else:
            raise ValueError(f"Expected 2D or 4D input, got {z.dim()}D")
        # Distances are computed in float32 for numerical stability.
        z_flat = z_flat.float()
        codebook = self.codebook.weight.float()
        # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 * z.e
        d = (
            torch.sum(z_flat**2, dim=-1, keepdim=True)
            + torch.sum(codebook**2, dim=-1)
            - 2 * torch.matmul(z_flat, codebook.t())
        )  # (B, N, vocab_size), N = H*W or 1
        indices = torch.argmin(d, dim=-1)  # (B, N)
        z_q = F.embedding(indices, codebook)  # (B, N, C)
        # Codebook loss: pull codebook entries toward the (frozen) encoder output.
        codebook_loss = F.mse_loss(z_q, z_flat.detach())
        # Commitment loss: pull the encoder output toward the (frozen) codebook.
        commitment_loss = F.mse_loss(z_q.detach(), z_flat)
        vq_loss = codebook_loss + self.commitment_weight * commitment_loss
        # Straight-through estimator: z_q in the forward pass, identity
        # gradient to z_flat in the backward pass.
        z_q = z_flat + (z_q - z_flat).detach()
        # Perplexity over ALL positions in the batch: exp of code-usage entropy.
        encodings = F.one_hot(indices, self.vocab_size).float()
        avg_probs = encodings.reshape(-1, self.vocab_size).mean(dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        # Restore the original layout.
        if ndim == 4:
            z_q = z_q.reshape(B, H, W, C).permute(0, 3, 1, 2)
            indices_out = indices.reshape(B, H, W)
        else:
            z_q = z_q.squeeze(1)
            indices_out = indices.squeeze(-1)
        loss = {
            "vq_loss": vq_loss,
            "perplexity": perplexity,
        }
        return z_q, indices_out, loss

    def decode_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Decode token indices back to embeddings.

        Args:
            indices: Token indices of shape (B, H, W) or (B,).

        Returns:
            Embeddings of shape (B, C, H, W) or (B, C).
        """
        z_q = F.embedding(indices, self.codebook.weight)
        if indices.dim() == 3:  # (B, H, W, C) -> (B, C, H, W)
            z_q = z_q.permute(0, 3, 1, 2)
        return z_q
class VectorQuantizerEMA(nn.Module):
    """Vector Quantizer with Exponential Moving Average codebook updates.

    The codebook is updated from EMA statistics of the encoder outputs
    (Appendix A.1 of van den Oord et al., 2017) instead of gradients, which
    is typically more stable. Only the weighted commitment loss contributes
    to the returned "vq_loss"; the codebook receives no gradient signal.
    """

    def __init__(
        self,
        vocab_size: int = 512,
        embedding_dim: int = 512,
        commitment_weight: float = 0.25,
        ema_decay: float = 0.99,
        epsilon: float = 1e-5,
    ):
        """
        Args:
            vocab_size: Number of codebook entries.
            embedding_dim: Dimensionality of each codebook vector.
            commitment_weight: Weight (beta) on the commitment loss.
            ema_decay: Decay rate for the EMA statistics.
            epsilon: Laplace-smoothing constant keeping rarely-used codes
                from dividing by ~zero cluster sizes.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.commitment_weight = commitment_weight
        self.ema_decay = ema_decay
        self.epsilon = epsilon
        # Codebook
        self.codebook = nn.Embedding(vocab_size, embedding_dim)
        self.codebook.weight.data.uniform_(-1.0 / vocab_size, 1.0 / vocab_size)
        # EMA statistics: per-code assignment counts and running sums of
        # the encoder outputs assigned to each code.
        self.register_buffer("ema_cluster_size", torch.zeros(vocab_size))
        self.register_buffer("ema_embed_avg", self.codebook.weight.data.clone())

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Quantize (B, C, H, W) latents, updating the codebook via EMA in
        training mode.

        Args:
            z: Input tensor of shape (B, C, H, W).

        Returns:
            z_q: Quantized tensor (B, C, H, W), straight-through gradients.
            indices: Token indices (B, H, W).
            loss: Dict with "vq_loss" (weighted commitment loss) and
                "perplexity".
        """
        B, C, H, W = z.shape
        # Flatten spatial dims: (B, C, H, W) -> (B, H*W, C), float32 for stability.
        z_flat = z.permute(0, 2, 3, 1).reshape(B, H * W, C).float()
        codebook = self.codebook.weight.float()
        # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 * z.e
        d = (
            torch.sum(z_flat**2, dim=-1, keepdim=True)
            + torch.sum(codebook**2, dim=-1)
            - 2 * torch.matmul(z_flat, codebook.t())
        )  # (B, H*W, vocab_size)
        indices = torch.argmin(d, dim=-1)  # (B, H*W)
        z_q = F.embedding(indices, codebook)  # (B, H*W, C)
        encodings = F.one_hot(indices, self.vocab_size).float()  # (B, H*W, V)
        # EMA codebook update (training only; no gradients involved).
        if self.training:
            with torch.no_grad():
                # Per-code assignment counts for this batch: (vocab_size,)
                counts = encodings.sum(dim=(0, 1))
                self.ema_cluster_size.mul_(self.ema_decay).add_(
                    counts, alpha=1 - self.ema_decay
                )
                # Sum of encoder outputs assigned to each code: (vocab_size, C)
                embed_sum = encodings.reshape(-1, self.vocab_size).t() @ z_flat.reshape(-1, C)
                self.ema_embed_avg.mul_(self.ema_decay).add_(
                    embed_sum, alpha=1 - self.ema_decay
                )
                # Laplace-smooth the cluster sizes, then set each code vector
                # to the (smoothed) mean of its assigned encoder outputs.
                n = self.ema_cluster_size.sum()
                cluster_size = (
                    (self.ema_cluster_size + self.epsilon)
                    / (n + self.vocab_size * self.epsilon)
                    * n
                )
                self.codebook.weight.data.copy_(
                    self.ema_embed_avg / cluster_size.unsqueeze(1)
                )
        # Commitment loss only: the encoder is pulled toward the (frozen)
        # codebook; the codebook itself is trained by the EMA update above.
        commitment_loss = F.mse_loss(z_q.detach(), z_flat)
        # Straight-through estimator: z_q forward, identity gradient backward.
        z_q = z_flat + (z_q - z_flat).detach()
        z_q = z_q.reshape(B, H, W, C).permute(0, 3, 1, 2)
        # Perplexity over ALL positions in the batch.
        avg_probs = encodings.reshape(-1, self.vocab_size).mean(dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        loss = {
            "vq_loss": self.commitment_weight * commitment_loss,
            "perplexity": perplexity,
        }
        return z_q, indices.reshape(B, H, W), loss

    def decode_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Decode token indices to embeddings.

        Args:
            indices: Token indices (B, H, W) or (B,).

        Returns:
            Embeddings (B, C, H, W) or (B, C).
        """
        z_q = F.embedding(indices, self.codebook.weight)
        if indices.dim() == 3:  # (B, H, W, C) -> (B, C, H, W)
            z_q = z_q.permute(0, 3, 1, 2)
        return z_q