import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from typing import Optional, Dict, Tuple, Literal
import numpy as np
from dataclasses import dataclass
from world_models.models.genie import Genie
[docs]
@dataclass
class GenieConfig:
"""Configuration for Genie training."""
num_frames: int = 16
image_size: int = 64
in_channels: int = 3
tokenizer_vocab_size: int = 1024
tokenizer_embedding_dim: int = 32
tokenizer_encoder_dim: int = 512
tokenizer_decoder_dim: int = 1024
tokenizer_encoder_depth: int = 12
tokenizer_decoder_depth: int = 20
action_vocab_size: int = 8
action_embedding_dim: int = 32
action_encoder_dim: int = 1024
action_encoder_depth: int = 20
action_pooling: Literal["mean", "windowed_attention"] = "mean"
window_attention_heads: int = 1
dynamics_dim: int = 512
dynamics_depth: int = 8
dynamics_num_heads: int = 8
batch_size: int = 4
learning_rate: float = 3e-5
weight_decay: float = 1e-4
warmup_steps: int = 5000
max_steps: int = 125000
mask_prob_min: float = 0.5
mask_prob_max: float = 1.0
sample_temperature: float = 2.0
maskgit_steps: int = 25
[docs]
class VideoDataset(Dataset):
"""Dataset for video data."""
def __init__(self, video_paths: list, num_frames: int = 16, image_size: int = 64):
self.video_paths = video_paths
self.num_frames = num_frames
self.image_size = image_size
def __len__(self) -> int:
return len(self.video_paths)
def __getitem__(self, idx: int) -> torch.Tensor:
raise NotImplementedError("Implement loading video from paths")
[docs]
class GenieTrainer:
"""Trainer for Genie model."""
def __init__(
self,
model: nn.Module,
config: GenieConfig,
device: Optional[torch.device] = None,
):
self.model = model
self.config = config
if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = device
self.model.to(self.device)
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=config.learning_rate,
weight_decay=config.weight_decay,
)
self.scheduler = self._create_scheduler()
self.global_step = 0
def _create_scheduler(self):
"""Create learning rate scheduler with warmup and cosine decay."""
warmup_steps = self.config.warmup_steps
max_steps = self.config.max_steps
def lr_lambda(step):
if step < warmup_steps:
return step / warmup_steps
else:
progress = (step - warmup_steps) / (max_steps - warmup_steps)
return 0.5 * (1.0 + np.cos(np.pi * progress))
return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
[docs]
def train_step(self, batch: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Single training step.
Args:
batch: (B, C, T, H, W) video batch
Returns:
Dictionary of losses
"""
self.model.train()
B, C, T, H, W = batch.shape
mask_prob = (
torch.rand(1).item()
* (self.config.mask_prob_max - self.config.mask_prob_min)
+ self.config.mask_prob_min
)
outputs = self.model(batch, mask_prob=mask_prob)
recon_loss = outputs.get("recon_loss", 0.0)
vq_loss = outputs.get("vq_loss", 0.0)
dynamics_loss = outputs.get(
"dynamics_loss", torch.tensor(0.0, device=self.device)
)
z_q_for_dynamics = outputs.get("z_q_for_dynamics", None)
z_q_for_dynamics_mean = (
z_q_for_dynamics.mean().item()
if isinstance(z_q_for_dynamics, torch.Tensor)
else None
)
total_loss = outputs["total_loss"]
self.optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
self.scheduler.step()
self.global_step += 1
# Normalize all returned metrics to torch.Tensor for consistent typing
def as_tensor(x):
if isinstance(x, torch.Tensor):
return x.detach().cpu()
try:
return torch.tensor(float(x))
except Exception:
return torch.tensor(float("nan"))
return {
"total_loss": as_tensor(total_loss),
"recon_loss": as_tensor(recon_loss),
"vq_loss": as_tensor(vq_loss),
"dynamics_loss": as_tensor(dynamics_loss),
"learning_rate": torch.tensor(float(self.scheduler.get_last_lr()[0])),
"z_q_for_dynamics_mean": as_tensor(z_q_for_dynamics_mean),
}
[docs]
def validate(self, val_batch: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Validation step.
Args:
val_batch: (B, C, T, H, W) validation video batch
Returns:
Dictionary of validation metrics
"""
self.model.eval()
with torch.no_grad():
outputs = self.model(val_batch, mask_prob=0.0)
recon_loss = outputs["tokenizer_loss"].get("recon_loss", 0.0)
return {
"val_recon_loss": recon_loss.detach().cpu()
if isinstance(recon_loss, torch.Tensor)
else torch.tensor(float(recon_loss)),
}
[docs]
def train(
self,
train_dataloader: DataLoader,
val_dataloader: Optional[DataLoader] = None,
num_steps: Optional[int] = None,
log_interval: int = 100,
val_interval: int = 1000,
):
"""Full training loop.
Args:
train_dataloader: Training data loader
val_dataloader: Validation data loader (optional)
num_steps: Number of training steps (uses config.max_steps if None)
log_interval: Logging frequency
val_interval: Validation frequency
"""
if num_steps is None:
num_steps = self.config.max_steps
train_iter = iter(train_dataloader)
while self.global_step < num_steps:
try:
batch = next(train_iter)
except StopIteration:
train_iter = iter(train_dataloader)
batch = next(train_iter)
batch = batch.to(self.device)
losses = self.train_step(batch)
if self.global_step % log_interval == 0:
print(
f"Step {self.global_step}/{num_steps} | "
f"Loss: {losses['total_loss']:.4f} | "
f"Recon: {losses['recon_loss']:.4f} | "
f"VQ: {losses['vq_loss']:.4f} | "
f"Dynamics: {losses['dynamics_loss']:.4f} | "
f"LR: {losses['learning_rate']:.6f}"
)
if val_dataloader is not None and self.global_step % val_interval == 0:
val_iter = iter(val_dataloader)
val_batch = next(val_iter)
val_batch = val_batch.to(self.device)
val_metrics = self.validate(val_batch)
print(f"Validation: {val_metrics}")
print("Training complete!")
[docs]
def save_checkpoint(self, path: str):
"""Save model checkpoint."""
torch.save(
{
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"scheduler_state_dict": self.scheduler.state_dict(),
"global_step": self.global_step,
},
path,
)
[docs]
def load_checkpoint(self, path: str):
"""Load model checkpoint."""
checkpoint = torch.load(
path,
map_location=self.device,
weights_only=False,
)
self.model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
self.global_step = checkpoint["global_step"]
[docs]
def create_genie_trainer(
config: Optional[GenieConfig] = None,
device: Optional[torch.device] = None,
) -> Tuple[GenieTrainer, nn.Module]:
"""Factory function to create Genie trainer and model."""
if config is None:
config = GenieConfig()
model = Genie(
num_frames=config.num_frames,
image_size=config.image_size,
in_channels=config.in_channels,
tokenizer_vocab_size=config.tokenizer_vocab_size,
tokenizer_embedding_dim=config.tokenizer_embedding_dim,
action_vocab_size=config.action_vocab_size,
action_embedding_dim=config.action_embedding_dim,
dynamics_dim=config.dynamics_dim,
dynamics_depth=config.dynamics_depth,
dynamics_num_heads=config.dynamics_num_heads,
encoder_depth=config.tokenizer_encoder_depth,
decoder_depth=config.tokenizer_decoder_depth,
latent_action_depth=config.action_encoder_depth,
action_pooling=config.action_pooling,
window_attention_heads=config.window_attention_heads,
)
trainer = GenieTrainer(model, config, device)
return trainer, model