import torch
import numpy as np
from collections import defaultdict
import os
from tqdm import tqdm
import random
from typing import Optional
from world_models.configs.iris_config import IRISConfig
from world_models.models.iris_agent import IRISAgent
from world_models.memory.iris_memory import IRISReplayBuffer
from world_models.envs.ale_atari_env import make_atari_env
class IRISTrainer:
"""Training loop for IRIS on Atari 100k benchmark."""
def __init__(
self,
game: str = "ALE/Pong-v5",
device: str = "cuda",
seed: int = 42,
config: Optional[IRISConfig] = None,
):
self.game = game
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
self.seed = seed
# Set seeds
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Config
self.config = config if config is not None else IRISConfig()
# Create environment
self.env = make_atari_env(
game,
obs_type="rgb",
frameskip=4,
max_episode_steps=27000, # 27,000 agent steps at frameskip 4 = 108k frames (standard Atari limit)
)
# Get action space
self.action_size = self.env.action_space.n
# Create replay buffer
self.replay_buffer = IRISReplayBuffer(
size=100000, # 100k buffer
obs_shape=(3, 64, 64), # Resize frames to 64x64
action_size=self.action_size,
seq_len=self.config.transformer_timesteps,
batch_size=self.config.transformer_batch_size,
)
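# Note: observations pushed into this buffer are the float32, [0, 1], CHW frames
# produced by preprocess_frame below, so sampling later does not rescale them again.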
# Create agent
self.agent = IRISAgent(
config=self.config,
action_size=self.action_size,
device=self.device,
)
# Metrics
self.metrics = defaultdict(list)
def preprocess_frame(self, frame: np.ndarray) -> np.ndarray:
"""Preprocess frame: resize to 64x64 and normalize."""
import cv2
frame = cv2.resize(frame, (64, 64), interpolation=cv2.INTER_LINEAR)
frame = frame.astype(np.float32) / 255.0
return frame.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
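# Rough shape/dtype sketch for the method above (assuming the ALE env emits HWC uint8
# RGB frames, e.g. 210x160x3):
#   (210, 160, 3) uint8 -> resize -> (64, 64, 3) float32 in [0, 1] -> transpose -> (3, 64, 64)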
def collect_experience(
self,
num_steps: int,
epsilon: float = 0.01,
) -> float:
"""Collect experience from environment.
Args:
num_steps: Number of steps to collect
epsilon: Random action probability
Returns:
Mean episode return
"""
obs, _ = self.env.reset()
obs = self.preprocess_frame(obs)
episode_returns = []
current_return = 0
steps_in_episode = 0
for step in range(num_steps):
# Choose action
if np.random.random() < epsilon:
action = self.env.action_space.sample()
else:
frame_tensor = (
torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(self.device)
)
action = self.agent.act(frame_tensor, epsilon=0.0).item()
# Step environment
next_obs, reward, terminated, truncated, info = self.env.step(action)
done = terminated or truncated
next_obs = self.preprocess_frame(next_obs)
# Store in replay buffer
action_one_hot = np.zeros(self.action_size, dtype=np.float32)
action_one_hot[action] = 1.0
self.replay_buffer.add(
obs,
action_one_hot,
float(reward),
done,
)
current_return += reward
steps_in_episode += 1
if done:
episode_returns.append(current_return)
current_return = 0
steps_in_episode = 0
obs, _ = self.env.reset()
obs = self.preprocess_frame(obs)
else:
obs = next_obs
return np.mean(episode_returns) if episode_returns else 0.0
def train_epoch(self, epoch: int) -> dict:
"""Train for one epoch.
Args:
epoch: Current epoch number
Returns:
Dictionary of metrics
"""
metrics = {}
# Get epsilon with decay for better exploration early on
epsilon = self.get_epsilon(epoch)
# Phase 1: Collect experience
mean_return = self.collect_experience(
num_steps=self.config.env_steps_per_epoch,
epsilon=epsilon,
)
metrics["collection_return"] = mean_return
# Only update components after warm-start periods
if epoch >= self.config.start_autoencoder_after:
# Phase 2: Update autoencoder
for _ in range(self.config.training_steps_per_epoch):
# Sample random frames
indices = np.random.randint(
0, len(self.replay_buffer), size=self.config.autoencoder_batch_size
)
# Observations are stored already scaled to [0, 1] by preprocess_frame,
# so no further rescaling is applied here.
frames = torch.tensor(
np.array([self.replay_buffer.observations[i] for i in indices]),
dtype=torch.float32,
).to(self.device)
ae_metrics = self.agent.update_autoencoder(frames)
metrics["recon_loss"] = ae_metrics.get("recon_loss", 0)
metrics["vq_loss"] = ae_metrics.get("vq_loss", 0)
metrics["perplexity"] = ae_metrics.get("perplexity", 0)
if (
epoch >= self.config.start_transformer_after
and len(self.replay_buffer) >= self.config.transformer_timesteps + 1
):
# Phase 3: Update transformer
obs, acts, rews, terms = self.replay_buffer.sample_sequence()
# Sampled observations are already normalized to [0, 1] by preprocess_frame
obs_tensor = torch.tensor(obs, dtype=torch.float32).to(self.device)
acts_tensor = torch.tensor(acts, dtype=torch.float32).to(self.device)
rews_tensor = torch.tensor(rews, dtype=torch.float32).to(self.device)
terms_tensor = torch.tensor(terms, dtype=torch.long).to(self.device)
tf_metrics = self.agent.update_transformer(
obs_tensor, acts_tensor, rews_tensor, terms_tensor
)
metrics["token_loss"] = tf_metrics.get("token_loss", 0)
metrics["reward_loss"] = tf_metrics.get("reward_loss", 0)
if (
epoch >= self.config.start_actor_critic_after
and len(self.replay_buffer) >= 50
):
# Phase 4: Update actor-critic in imagination
# Sample initial frames for imagination
sample_size = self.config.actor_critic_batch_size
indices = np.random.randint(0, len(self.replay_buffer), size=sample_size)
# Stored observations are already in [0, 1]; no rescaling needed
initial_frames = torch.tensor(
np.array([self.replay_buffer.observations[i] for i in indices]),
dtype=torch.float32,
).to(self.device)
# Generate imagined trajectories
imagined = self.agent.imagine_rollout(
initial_frame=initial_frames,
horizon=self.config.imagination_horizon,
)
# Update policy
ac_metrics = self.agent.update_actor_critic(imagined)
metrics["actor_loss"] = ac_metrics.get("actor_loss", 0)
metrics["value_loss"] = ac_metrics.get("value_loss", 0)
metrics["entropy"] = ac_metrics.get("entropy", 0)
self.agent.current_epoch = epoch
self.agent.global_step += self.config.env_steps_per_epoch
return metrics
def get_epsilon(self, epoch: int) -> float:
"""Get exploration epsilon with decay."""
start_epsilon = 0.5
min_epsilon = self.config.collect_epsilon
decay_epochs = 50
epsilon = max(
min_epsilon,
start_epsilon - (start_epsilon - min_epsilon) * epoch / decay_epochs,
)
return epsilon
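# Worked example of the linear schedule above (the floor comes from
# config.collect_epsilon; 0.05 is just an illustrative value): with a floor of 0.05,
# epsilon is 0.5 at epoch 0, 0.275 at epoch 25, and stays at 0.05 from epoch 50 on.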
def evaluate(self, num_episodes: int = 100, render: bool = False):
"""Evaluate agent performance.
Args:
num_episodes: Number of evaluation episodes
render: If True, also return video frames and per-step latent vectors
Returns:
If render is False (default): dict with evaluation metrics
If render is True: tuple (episode_returns_array, videos_list, latents_array)
"""
episode_returns = []
videos: list[list[np.ndarray]] = []
latents_all: list[np.ndarray] = []
for _ in range(num_episodes):
raw_obs, _ = self.env.reset()
obs = self.preprocess_frame(raw_obs)
episode_return = 0
done = False
frames: list[np.ndarray] = []
while not done:
# Prepare frame for policy (CHW, float32, 0-1)
frame_tensor = (
torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(self.device)
)
action = self.agent.act(
frame_tensor, epsilon=0.0, temperature=self.config.eval_temperature
).item()
next_raw, reward, terminated, truncated, _ = self.env.step(action)
done = terminated or truncated
# Store raw frame for video (as HWC uint8 if possible)
try:
frames.append(np.asarray(next_raw))
except Exception:
# Fallback: convert the normalized CHW obs back to HWC uint8
proc = np.asarray(obs)
if proc.ndim == 3:
frames.append((proc.transpose(1, 2, 0) * 255).astype(np.uint8))
# Compute latent embedding via encoder (quantized embeddings)
try:
proc_frame = self.preprocess_frame(next_raw) if not done else obs
with torch.no_grad():
ft = (
torch.tensor(proc_frame, dtype=torch.float32)
.unsqueeze(0)
.to(self.device)
)
z_q, _, _ = self.agent.encoder(ft)
# z_q: (B, C, H', W') -> average over the spatial dims to get a per-channel latent vector
latent = z_q.mean(dim=(2, 3)).squeeze(0).cpu().numpy()
latents_all.append(latent.astype(np.float32))
except Exception:
# If encoder fails, skip latent for this step
pass
episode_return += reward
obs = self.preprocess_frame(next_raw) if not done else obs
episode_returns.append(episode_return)
videos.append(frames)
if render:
# Stack latents into (N, D) array if any
if latents_all:
latents_array = np.vstack(latents_all).astype(np.float32)
else:
latents_array = np.empty((0,), dtype=np.float32)
return np.array(episode_returns), videos, latents_array
# Non-render fallback: return simple metrics dict for compatibility
return {
"eval_mean_return": float(
np.mean(episode_returns) if episode_returns else 0.0
),
"eval_std_return": float(
np.std(episode_returns) if episode_returns else 0.0
),
"eval_max_return": float(
np.max(episode_returns) if episode_returns else 0.0
),
"eval_min_return": float(
np.min(episode_returns) if episode_returns else 0.0
),
}
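# Usage sketch for the two return modes (illustrative call sites, not part of main()):
#   metrics = trainer.evaluate(num_episodes=10)                  # dict of eval_* floats
#   returns, videos, latents = trainer.evaluate(5, render=True)  # per-episode data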
def train(
self,
total_epochs: Optional[int] = None,
eval_interval: int = 50,
save_dir: str = "checkpoints/iris",
):
"""Full training loop.
Args:
total_epochs: Total training epochs
eval_interval: Evaluate every N epochs
save_dir: Directory to save checkpoints
"""
if total_epochs is None:
total_epochs = self.config.total_epochs
os.makedirs(save_dir, exist_ok=True)
print(f"Starting training for {total_epochs} epochs on {self.game}")
print(f"Action space: {self.action_size}")
print(f"Device: {self.device}")
best_eval_return = float("-inf")
for epoch in tqdm(range(total_epochs), desc="Training"):
# Train one epoch
metrics = self.train_epoch(epoch)
# Log metrics
for key, value in metrics.items():
self.metrics[key].append(value)
# Print progress
if epoch % 10 == 0:
print(f"\nEpoch {epoch}:")
for key, value in metrics.items():
print(f" {key}: {value:.4f}")
# Evaluate periodically
if (
epoch % eval_interval == 0
and epoch >= self.config.start_actor_critic_after
):
eval_metrics = self.evaluate(num_episodes=self.config.eval_episodes)
print(f"\nEvaluation at epoch {epoch}:")
print(
f" Mean return: {eval_metrics['eval_mean_return']:.2f} +/- {eval_metrics['eval_std_return']:.2f}"
)
# Save best model
if eval_metrics["eval_mean_return"] > best_eval_return:
best_eval_return = eval_metrics["eval_mean_return"]
save_path = os.path.join(
save_dir, f"best_{self.game.split('/')[-1]}.pt"
)
self.agent.save(save_path)
print(f" Saved best model: {save_path}")
# Checkpoint periodically
if epoch % self.config.checkpoint_interval == 0:
save_path = os.path.join(save_dir, f"checkpoint_{epoch}.pt")
self.agent.save(save_path)
print(f"\nTraining complete! Best eval return: {best_eval_return:.2f}")
return self.metrics
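# Minimal sketch of driving the trainer with a custom config (field values are
# illustrative; IRISConfig defaults are used when no config is passed):
#   config = IRISConfig()
#   config.total_epochs = 300
#   trainer = IRISTrainer(game="ALE/Breakout-v5", config=config)
#   trainer.train(eval_interval=25)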
def main():
"""Run IRIS training on a single Atari game."""
import argparse
parser = argparse.ArgumentParser(description="Train IRIS on Atari")
parser.add_argument("--game", type=str, default="ALE/Pong-v5", help="Atari game")
parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
parser.add_argument("--epochs", type=int, default=600, help="Total epochs")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument(
"--save_dir", type=str, default="checkpoints/iris", help="Save directory"
)
args = parser.parse_args()
trainer = IRISTrainer(
game=args.game,
device=args.device,
seed=args.seed,
)
trainer.train(
total_epochs=args.epochs,
save_dir=args.save_dir,
)
if __name__ == "__main__":
main()