import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict, Literal
from world_models.vision.video_tokenizer import VideoTokenizer
from world_models.models.latent_action_model import LatentActionModel
from world_models.models.dynamics_model import DynamicsModel, MaskGITSampler
class Genie(nn.Module):
    """Genie: Generative Interactive Environment.

    A generative model trained from video-only data that can be used as an
    interactive environment. Contains three key components:

    1. Video Tokenizer: Converts raw video frames into discrete tokens
    2. Latent Action Model (LAM): Infers latent actions between frames
    3. Dynamics Model: Predicts future frames given past frames and latent
       actions

    Based on the "Genie: Generative Interactive Environments" paper
    (arXiv:2402.15391).

    Training follows two phases as per the paper:

    1. Train the video tokenizer first
    2. Co-train the LAM (from pixels) and the dynamics model (on video tokens)

    The LAM uses VQ-VAE training with:

    - Encoder: Takes x1:t and x_{t+1} -> outputs latent actions
    - Decoder: Takes x1:t-1 (masked) + actions -> reconstructs x_t
    - Auxiliary variance loss to prevent action collapse

    In :meth:`forward` the dynamics model receives the LAM's integer action
    indices; a detached copy of the quantised action embeddings is returned
    alongside them as ``z_q_for_dynamics``.
    """

    def __init__(
        self,
        num_frames: int = 16,
        image_size: int = 64,
        in_channels: int = 3,
        tokenizer_vocab_size: int = 1024,
        tokenizer_embedding_dim: int = 32,
        tokenizer_encoder_dim: int = 512,
        tokenizer_decoder_dim: int = 1024,
        action_vocab_size: int = 8,
        action_embedding_dim: int = 32,
        action_encoder_dim: int = 1024,
        action_decoder_dim: int = 1024,
        dynamics_dim: int = 5120,
        dynamics_depth: int = 48,
        dynamics_num_heads: int = 36,
        encoder_depth: int = 12,
        decoder_depth: int = 20,
        latent_action_depth: int = 20,
        use_bfloat16: bool = False,
        action_pooling: Literal["mean", "windowed_attention"] = "mean",
        window_attention_heads: int = 1,
    ):
        """Build the tokenizer, latent action model, dynamics model and sampler.

        ``encoder_depth`` / ``decoder_depth`` configure the video tokenizer,
        ``latent_action_depth`` is used for both the LAM encoder and decoder,
        and the ``dynamics_*`` arguments configure the dynamics model. All
        remaining arguments are forwarded to the matching sub-module.
        """
        super().__init__()
        self.num_frames = num_frames
        self.image_size = image_size
        self.tokenizer_vocab_size = tokenizer_vocab_size
        self.action_vocab_size = action_vocab_size
        self.use_bfloat16 = use_bfloat16

        # Video Tokenizer (VQ-VAE with ST-Transformer)
        self.video_tokenizer = VideoTokenizer(
            num_frames=num_frames,
            image_size=image_size,
            in_channels=in_channels,
            encoder_dim=tokenizer_encoder_dim,
            decoder_dim=tokenizer_decoder_dim,
            encoder_depth=encoder_depth,
            decoder_depth=decoder_depth,
            num_heads=16,
            patch_size=4,
            vocab_size=tokenizer_vocab_size,
            embedding_dim=tokenizer_embedding_dim,
            use_ema=False,
        )

        # Latent Action Model (VQ-VAE with ST-Transformer decoder)
        self.latent_action_model = LatentActionModel(
            num_frames=num_frames,
            image_size=image_size,
            in_channels=in_channels,
            encoder_dim=action_encoder_dim,
            decoder_dim=action_decoder_dim,
            encoder_depth=latent_action_depth,
            decoder_depth=latent_action_depth,
            num_heads=16,
            patch_size=16,
            vocab_size=action_vocab_size,
            embedding_dim=action_embedding_dim,
            commitment_weight=1.0,
            action_pooling=action_pooling,
            window_attention_heads=window_attention_heads,
        )

        # Dynamics Model (MaskGIT transformer); patch_size matches the
        # tokenizer's so both operate on the same token grid.
        self.dynamics_model = DynamicsModel(
            num_frames=num_frames,
            image_size=image_size,
            vocab_size=tokenizer_vocab_size,
            embedding_dim=tokenizer_embedding_dim,
            action_vocab_size=action_vocab_size,
            dim=dynamics_dim,
            depth=dynamics_depth,
            num_heads=dynamics_num_heads,
            patch_size=4,
        )

        # MaskGIT sampler for inference
        self.sampler = MaskGITSampler(num_steps=25, temperature=2.0)

    def forward(
        self,
        video: torch.Tensor,
        mask_prob: float = 0.5,
        training_phase: str = "all",
    ) -> Dict[str, torch.Tensor]:
        """Full forward pass through all components.

        Args:
            video: (B, C, T, H, W) input video
            mask_prob: Probability for random masking in dynamics (0.5-1.0)
            training_phase: "all", "tokenizer", or "lam_dynamics". Any value
                other than "tokenizer" or "lam_dynamics" is treated as "all".

        Returns:
            Dictionary containing losses and predictions. In the "tokenizer"
            phase only the tokenizer outputs and losses are returned.
        """
        B, _, T, _, _ = video.shape

        if training_phase == "tokenizer":
            # Phase 1: Train only video tokenizer
            recon_video, video_indices, tokenizer_loss_dict = self.video_tokenizer(
                video
            )
            return {
                "reconstructed_video": recon_video,
                "video_indices": video_indices,
                "tokenizer_loss": tokenizer_loss_dict,
                "vq_loss": tokenizer_loss_dict["vq_loss"],
                "total_loss": tokenizer_loss_dict["recon_loss"]
                + tokenizer_loss_dict["vq_loss"],
            }

        # Phase 2 or 3: Get video tokens first (frozen or training).
        # The tokenizer is frozen in the "lam_dynamics" phase. In every other
        # phase the caller's grad mode is inherited instead of being forced
        # on, so that an enclosing torch.no_grad() (e.g. validation) is
        # honoured.
        tokenizer_grad = torch.is_grad_enabled() and training_phase != "lam_dynamics"
        with torch.set_grad_enabled(tokenizer_grad):
            recon_video, video_indices, tokenizer_loss_dict = self.video_tokenizer(
                video
            )
        video_tokens = video_indices.reshape(B, T, -1)  # (B, T, H*W)

        # ===== LATENT ACTION MODEL =====
        # Train LAM from pixels - includes encoder + decoder losses
        lam_output = self.latent_action_model(
            video[:, :, :-1],  # x1:T-1
            video[:, :, -1],  # x_T
        )

        # Get latent actions - apply stopgrad for dynamics (as per paper)
        latent_actions = lam_output["latent_actions"]  # (B, T-1)
        z_q = lam_output["z_q"]  # (B, T-1, embedding_dim)
        # Detached copy of the quantised embeddings; it is only returned to
        # the caller, the dynamics model below consumes the integer indices.
        z_q_for_dynamics = z_q.detach()
        actions_for_dynamics = latent_actions[:, : T - 1]

        # ===== DYNAMICS MODEL =====
        # Predict next frame tokens given past tokens and latent actions
        # Input: video_tokens[:, :-1] (past frames), actions_for_dynamics
        # Target: video_tokens[:, 1:] (next frames)
        target_tokens = video_tokens[:, 1:, :]  # (B, T-1, H*W)
        dynamics_logits = self.dynamics_model(
            video_tokens[:, :-1, :],
            actions_for_dynamics,
            mask_prob=mask_prob,
        )

        # Compute dynamics loss: token-level cross-entropy over all positions
        B_pred, T_pred, N, V = dynamics_logits.shape
        target_flat = target_tokens.reshape(B_pred * T_pred * N)
        logits_flat = dynamics_logits.reshape(B_pred * T_pred * N, V)
        dynamics_loss = F.cross_entropy(logits_flat, target_flat)

        # ===== TOTAL LOSS =====
        # According to paper: co-train LAM and dynamics
        # LAM losses: recon_loss (from decoder) + vq_loss + variance_loss
        # Dynamics loss: cross-entropy on token prediction
        # The tokenizer losses are always added; when the tokenizer ran
        # without grad they contribute to the value but not to gradients.
        lam_recon_loss = lam_output["recon_loss"]
        lam_vq_loss = lam_output["vq_loss"]
        lam_variance_loss = lam_output["variance_loss"]
        total_loss = (
            lam_recon_loss
            + lam_vq_loss
            + lam_variance_loss
            + dynamics_loss
            + tokenizer_loss_dict["recon_loss"]
            + tokenizer_loss_dict["vq_loss"]
        )

        return {
            "reconstructed_video": recon_video,
            "video_indices": video_indices,
            "latent_actions": latent_actions,
            "lam_reconstructed": lam_output["reconstructed"],
            "dynamics_logits": dynamics_logits,
            "tokenizer_loss": tokenizer_loss_dict,
            "vq_loss": tokenizer_loss_dict["vq_loss"],
            "recon_loss": tokenizer_loss_dict["recon_loss"],
            "lam_recon_loss": lam_recon_loss,
            "lam_vq_loss": lam_vq_loss,
            "lam_variance_loss": lam_variance_loss,
            "dynamics_loss": dynamics_loss,
            "z_q_for_dynamics": z_q_for_dynamics,
            "total_loss": total_loss,
        }

    def training_step(
        self,
        video: torch.Tensor,
        mask_prob: float = 0.5,
        training_phase: str = "all",
    ) -> Dict[str, torch.Tensor]:
        """Single training step computing all losses.

        Runs :meth:`forward`, under CUDA bfloat16 autocast when the model was
        built with ``use_bfloat16=True``.

        Args:
            video: (B, C, T, H, W) input video
            mask_prob: Probability for random masking in dynamics
            training_phase: "all", "tokenizer", or "lam_dynamics"

        Returns:
            Dictionary containing all losses for backpropagation
        """
        if self.use_bfloat16:
            # torch.autocast("cuda", ...) is the non-deprecated spelling of
            # torch.cuda.amp.autocast(...).
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                return self.forward(video, mask_prob, training_phase)
        return self.forward(video, mask_prob, training_phase)

    def encode_video(self, video: torch.Tensor) -> torch.Tensor:
        """Encode video to discrete tokens.

        Args:
            video: (B, C, T, H, W)

        Returns:
            video_tokens: (B, T, H*W)
        """
        _, video_indices, _ = self.video_tokenizer(video)
        # Keep the first two axes and flatten everything after them.
        return video_indices.reshape(video_indices.shape[0], video_indices.shape[1], -1)

    def infer_actions(self, frames: torch.Tensor) -> torch.Tensor:
        """Infer latent actions from a sequence of frames.

        Args:
            frames: (B, C, T, H, W) video frames

        Returns:
            latent_actions: (B, T-1) inferred latent action indices
        """
        lam_output = self.latent_action_model(
            frames[:, :, :-1],  # all frames but the last
            frames[:, :, -1],  # the last frame
        )
        return lam_output["latent_actions"]

    def generate(
        self,
        prompt_frame: torch.Tensor,
        num_frames: int = 16,
        actions: Optional[torch.Tensor] = None,
        use_maskgit: bool = True,
    ) -> torch.Tensor:
        """Generate video frames given a prompt frame and actions.

        Args:
            prompt_frame: (B, C, H, W) initial frame
            num_frames: Total number of frames to generate
            actions: (B, num_frames-1) latent action indices, or None for random
            use_maskgit: Whether to use MaskGIT sampling

        Returns:
            generated_video: (B, C, num_frames, H, W)
        """
        B, _, _, _ = prompt_frame.shape

        # Tokenize prompt frame: repeat it along a new time axis so the
        # tokenizer receives a clip of num_frames frames.
        prompt_frame_expanded = prompt_frame.unsqueeze(2).expand(
            -1, -1, num_frames, -1, -1
        )
        _, prompt_indices, _ = self.video_tokenizer(prompt_frame_expanded)
        # Use first frame tokens as prompt -> (B, 1, num_patches)
        prompt_tokens = prompt_indices[:, 0, :, :].reshape(B, -1).unsqueeze(1)

        # Sample random actions if not provided
        if actions is None:
            actions = torch.randint(
                0,
                self.action_vocab_size,
                (B, num_frames - 1),
                device=prompt_frame.device,
            )

        # Generate
        if use_maskgit and hasattr(self, "sampler"):
            generated_tokens = self._generate_maskgit(
                prompt_tokens, actions, num_frames
            )
        else:
            generated_tokens = self.dynamics_model.autoregressive_sample(
                prompt_tokens[:, :1, :],
                actions[:, :1],
                num_frames,
                temperature=2.0,
            )

        # Decode tokens to video
        z_generated = self.video_tokenizer.decode_indices(generated_tokens)
        generated_video = self.video_tokenizer.decode(z_generated)
        return generated_video

    def _generate_maskgit(
        self,
        prompt_tokens: torch.Tensor,
        actions: torch.Tensor,
        num_frames: int,
    ) -> torch.Tensor:
        """Generate tokens using MaskGIT sampling.

        Args:
            prompt_tokens: (B, 1, num_patches) tokens of the prompt frame
            actions: latent action indices passed to the dynamics model
            num_frames: number of frames in the generated token grid

        Returns:
            Token tensor as returned by the sampler after the last step.
        """
        B = prompt_tokens.shape[0]
        num_patches = prompt_tokens.shape[2]

        # Token canvas: frame 0 holds the prompt, every later frame starts at 0.
        target_tokens = torch.zeros(
            (B, num_frames, num_patches),
            dtype=torch.long,
            device=prompt_tokens.device,
        )
        target_tokens[:, 0, :] = prompt_tokens[:, 0, :]

        # True = still to be generated; the prompt frame is marked as known.
        mask = torch.ones(
            B, num_frames, num_patches, dtype=torch.bool, device=prompt_tokens.device
        )
        mask[:, :1, :] = False

        # At most 3 refinement steps, even if the sampler is configured for more.
        max_steps = min(self.sampler.num_steps, 3)
        for step in range(max_steps):
            logits = self.dynamics_model(
                target_tokens[:, :-1, :],
                actions,
                mask_prob=0.0,
            )
            # NOTE(review): the dynamics model is fed num_frames-1 frames, yet
            # its logits are reshaped to num_frames here. If it returns one
            # frame of logits per input frame (as forward() assumes), this
            # reshape reinterprets the vocab axis - confirm against
            # DynamicsModel / MaskGITSampler.
            target_tokens, mask = self.sampler.sample(
                logits.reshape(B, num_frames, num_patches, -1),
                mask,
                step,
            )
        return target_tokens

    def play(
        self,
        current_frame: torch.Tensor,
        action: torch.Tensor,
        current_frames: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Play step - generate next frame given current frame and action.

        Args:
            current_frame: (B, C, H, W) current frame
            action: (B,) latent action indices; a scalar (or single-element)
                action is broadcast to the whole batch
            current_frames: (B, C, T, H, W) history frames, or None for first frame

        Returns:
            next_frame: (B, C, H, W)
        """
        B, _, _, _ = current_frame.shape

        if not isinstance(action, torch.Tensor):
            action = torch.tensor(action, device=current_frame.device)
        action = action.to(current_frame.device)

        if current_frames is None:
            # No history supplied: the current frame is the whole clip (T=1).
            current_frames = current_frame.unsqueeze(2)
        T_history = current_frames.shape[2]

        # Tokenize current frames
        _, prompt_indices, _ = self.video_tokenizer(current_frames)
        prompt_tokens = prompt_indices.reshape(B, T_history, -1)

        if action.dim() == 0:
            action = action.unsqueeze(0)
        # Repeat the action for every history frame. Expanding to B (rather
        # than -1) also broadcasts a scalar action across the batch.
        action_expanded = action.unsqueeze(1).expand(B, T_history)

        # Predict next frame
        next_frame_logits = self.dynamics_model(
            prompt_tokens,
            action_expanded,
            mask_prob=0.0,
        )
        # Keep only the logits of the last time step and decode greedily.
        next_frame_logits = next_frame_logits[:, -1, :, :]
        next_token_ids = torch.argmax(next_frame_logits, dim=-1)

        # Restore the square patch grid expected by decode_indices.
        num_patches_per_side = int(next_token_ids.shape[1] ** 0.5)
        next_token_ids_reshaped = next_token_ids.reshape(
            B, num_patches_per_side, num_patches_per_side
        )
        z_next = self.video_tokenizer.decode_indices(
            next_token_ids_reshaped.unsqueeze(1)
        )
        next_frame = self.video_tokenizer.decode(z_next)
        # Drop the singleton time axis.
        return next_frame.squeeze(2)

    def get_num_parameters(self) -> int:
        """Return total number of parameters."""
        return sum(p.numel() for p in self.parameters())
def create_genie(
    num_frames: int = 16,
    image_size: int = 64,
    in_channels: int = 3,
    tokenizer_vocab_size: int = 1024,
    tokenizer_embedding_dim: int = 32,
    action_vocab_size: int = 8,
    action_embedding_dim: int = 32,
    dynamics_dim: int = 5120,
    dynamics_depth: int = 48,
    dynamics_num_heads: int = 36,
    use_bfloat16: bool = False,
    action_pooling: Literal["mean", "windowed_attention"] = "mean",
    window_attention_heads: int = 1,
) -> Genie:
    """Factory function to create a Genie model.

    Every argument is forwarded unchanged, by keyword, to the ``Genie``
    constructor; constructor parameters that are not exposed here keep the
    constructor's own defaults.

    Returns:
        A newly constructed ``Genie`` instance.
    """
    # Collect the pass-through settings first, then hand them over in one go.
    settings = {
        "num_frames": num_frames,
        "image_size": image_size,
        "in_channels": in_channels,
        "tokenizer_vocab_size": tokenizer_vocab_size,
        "tokenizer_embedding_dim": tokenizer_embedding_dim,
        "action_vocab_size": action_vocab_size,
        "action_embedding_dim": action_embedding_dim,
        "dynamics_dim": dynamics_dim,
        "dynamics_depth": dynamics_depth,
        "dynamics_num_heads": dynamics_num_heads,
        "use_bfloat16": use_bfloat16,
        "action_pooling": action_pooling,
        "window_attention_heads": window_attention_heads,
    }
    return Genie(**settings)
def create_genie_small(
    num_frames: int = 16,
    image_size: int = 64,
    use_bfloat16: bool = False,
    action_pooling: Literal["mean", "windowed_attention"] = "mean",
    window_attention_heads: int = 1,
) -> Genie:
    """Create a smaller Genie model for development/testing.

    The architecture sizes are fixed by this preset; only the clip length,
    image size, bfloat16 flag and action-pooling options come from the caller.

    Returns:
        A newly constructed ``Genie`` instance with the reduced widths and
        depths listed below.
    """
    # Fixed widths/depths of the small preset.
    preset = {
        "tokenizer_vocab_size": 1024,
        "tokenizer_embedding_dim": 32,
        "tokenizer_encoder_dim": 256,
        "tokenizer_decoder_dim": 512,
        "action_vocab_size": 8,
        "action_embedding_dim": 32,
        "action_encoder_dim": 512,
        "action_decoder_dim": 512,
        "dynamics_dim": 512,
        "dynamics_depth": 8,
        "dynamics_num_heads": 8,
        "encoder_depth": 4,
        "decoder_depth": 8,
        "latent_action_depth": 8,
    }
    model = Genie(
        num_frames=num_frames,
        image_size=image_size,
        use_bfloat16=use_bfloat16,
        action_pooling=action_pooling,
        window_attention_heads=window_attention_heads,
        **preset,
    )
    return model
def create_genie_large(
    num_frames: int = 16,
    image_size: int = 64,
    use_bfloat16: bool = True,
    action_pooling: Literal["mean", "windowed_attention"] = "mean",
    window_attention_heads: int = 1,
) -> Genie:
    """Create the full 11B parameter Genie model (approximate).

    The architecture sizes are fixed by this preset; only the clip length,
    image size, bfloat16 flag (on by default here) and action-pooling options
    come from the caller.

    Returns:
        A newly constructed ``Genie`` instance with the widths and depths
        listed below.
    """
    # Fixed widths/depths of the large preset.
    preset = {
        "tokenizer_vocab_size": 1024,
        "tokenizer_embedding_dim": 32,
        "tokenizer_encoder_dim": 512,
        "tokenizer_decoder_dim": 1024,
        "action_vocab_size": 8,
        "action_embedding_dim": 32,
        "action_encoder_dim": 1024,
        "action_decoder_dim": 1024,
        "dynamics_dim": 5120,
        "dynamics_depth": 48,
        "dynamics_num_heads": 36,
        "encoder_depth": 12,
        "decoder_depth": 20,
        "latent_action_depth": 20,
    }
    model = Genie(
        num_frames=num_frames,
        image_size=image_size,
        use_bfloat16=use_bfloat16,
        action_pooling=action_pooling,
        window_attention_heads=window_attention_heads,
        **preset,
    )
    return model