Source code for world_models.training.train_iris

import torch
import numpy as np
from collections import defaultdict
import os
from tqdm import tqdm
import random

from world_models.configs.iris_config import IRISConfig
from world_models.models.iris_agent import IRISAgent
from world_models.memory.iris_memory import IRISReplayBuffer
from world_models.envs.ale_atari_env import make_atari_env


class IRISTrainer:
    """Training loop for IRIS on Atari 100k benchmark."""

    def __init__(
        self,
        game: str = "ALE/Pong-v5",
        device: str = "cuda",
        seed: int = 42,
        config: IRISConfig = None,
    ):
        self.game = game
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.seed = seed

        # Set seeds
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        # Config
        self.config = config if config is not None else IRISConfig()

        # Create environment
        self.env = make_atari_env(
            game,
            obs_type="rgb",
            frameskip=4,
            max_episode_steps=27000,  # Standard Atari limit
        )

        # Get action space
        self.action_size = self.env.action_space.n

        # Create replay buffer
        self.replay_buffer = IRISReplayBuffer(
            size=100000,  # 100k buffer
            obs_shape=(3, 64, 64),  # Resize frames to 64x64
            action_size=self.action_size,
            seq_len=self.config.transformer_timesteps,
            batch_size=self.config.transformer_batch_size,
        )

        # Create agent
        self.agent = IRISAgent(
            config=self.config,
            action_size=self.action_size,
            device=self.device,
        )

        # Metrics
        self.metrics = defaultdict(list)

    def preprocess_frame(self, frame: np.ndarray) -> np.ndarray:
        """Preprocess frame: resize to 64x64 and normalize."""
        import cv2

        frame = cv2.resize(frame, (64, 64), interpolation=cv2.INTER_LINEAR)
        frame = frame.astype(np.float32) / 255.0
        return frame.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
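
    # Shape sanity check for preprocess_frame (a sketch; assumes the standard
    # 210x160 RGB observation that ALE emits):
    #   input:  (210, 160, 3) uint8   in [0, 255]
    #   output: (3, 64, 64)   float32 in [0.0, 1.0]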

    def collect_experience(
        self,
        num_steps: int,
        epsilon: float = 0.01,
    ) -> float:
        """Collect experience from environment.

        Args:
            num_steps: Number of steps to collect
            epsilon: Random action probability

        Returns:
            Mean episode return
        """
        obs, _ = self.env.reset()
        obs = self.preprocess_frame(obs)

        episode_returns = []
        current_return = 0
        steps_in_episode = 0

        for step in range(num_steps):
            # Choose action
            if np.random.random() < epsilon:
                action = self.env.action_space.sample()
            else:
                frame_tensor = (
                    torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(self.device)
                )
                action = self.agent.act(frame_tensor, epsilon=0.0).item()

            # Step environment
            next_obs, reward, terminated, truncated, info = self.env.step(action)
            done = terminated or truncated
            next_obs = self.preprocess_frame(next_obs)

            # Store in replay buffer
            action_one_hot = np.zeros(self.action_size, dtype=np.float32)
            action_one_hot[action] = 1.0
            self.replay_buffer.add(
                obs,
                action_one_hot,
                float(reward),
                done,
            )

            current_return += reward
            steps_in_episode += 1

            if done:
                episode_returns.append(current_return)
                current_return = 0
                steps_in_episode = 0
                obs, _ = self.env.reset()
                obs = self.preprocess_frame(obs)
            else:
                obs = next_obs

        return np.mean(episode_returns) if episode_returns else 0.0
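
    # One-hot storage example (illustrative; assumes the 6-action space of the
    # default ALE/Pong-v5 environment):
    #   action = 2  ->  action_one_hot = [0., 0., 1., 0., 0., 0.]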

    def train_epoch(self, epoch: int) -> dict:
        """Train for one epoch.

        Args:
            epoch: Current epoch number

        Returns:
            Dictionary of metrics
        """
        metrics = {}

        # Get epsilon with decay for better exploration early on
        epsilon = self.get_epsilon(epoch)

        # Phase 1: Collect experience
        mean_return = self.collect_experience(
            num_steps=self.config.env_steps_per_epoch,
            epsilon=epsilon,
        )
        metrics["collection_return"] = mean_return

        # Only update components after warm-start periods
        if epoch >= self.config.start_autoencoder_after:
            # Phase 2: Update autoencoder
            for _ in range(self.config.training_steps_per_epoch):
                # Sample random frames. Note: frames are already normalized to
                # [0, 1] by preprocess_frame before storage, so no further
                # scaling is applied here.
                indices = np.random.randint(
                    0, len(self.replay_buffer), size=self.config.autoencoder_batch_size
                )
                frames = torch.tensor(
                    np.array([self.replay_buffer.observations[i] for i in indices]),
                    dtype=torch.float32,
                ).to(self.device)

                ae_metrics = self.agent.update_autoencoder(frames)
                metrics["recon_loss"] = ae_metrics.get("recon_loss", 0)
                metrics["vq_loss"] = ae_metrics.get("vq_loss", 0)
                metrics["perplexity"] = ae_metrics.get("perplexity", 0)

        if (
            epoch >= self.config.start_transformer_after
            and len(self.replay_buffer) >= self.config.transformer_timesteps + 1
        ):
            # Phase 3: Update transformer. Sampled frames were stored
            # pre-normalized, so they are used as-is.
            obs, acts, rews, terms = self.replay_buffer.sample_sequence()
            obs_tensor = torch.tensor(obs, dtype=torch.float32).to(self.device)
            acts_tensor = torch.tensor(acts, dtype=torch.float32).to(self.device)
            rews_tensor = torch.tensor(rews, dtype=torch.float32).to(self.device)
            terms_tensor = torch.tensor(terms, dtype=torch.long).to(self.device)

            tf_metrics = self.agent.update_transformer(
                obs_tensor, acts_tensor, rews_tensor, terms_tensor
            )
            metrics["token_loss"] = tf_metrics.get("token_loss", 0)
            metrics["reward_loss"] = tf_metrics.get("reward_loss", 0)

        if (
            epoch >= self.config.start_actor_critic_after
            and len(self.replay_buffer) >= 50
        ):
            # Phase 4: Update actor-critic in imagination
            # Sample initial frames for imagination (stored pre-normalized)
            sample_size = self.config.actor_critic_batch_size
            indices = np.random.randint(0, len(self.replay_buffer), size=sample_size)
            initial_frames = torch.tensor(
                np.array([self.replay_buffer.observations[i] for i in indices]),
                dtype=torch.float32,
            ).to(self.device)

            # Generate imagined trajectories
            imagined = self.agent.imagine_rollout(
                initial_frame=initial_frames,
                horizon=self.config.imagination_horizon,
            )

            # Update policy
            ac_metrics = self.agent.update_actor_critic(imagined)
            metrics["actor_loss"] = ac_metrics.get("actor_loss", 0)
            metrics["value_loss"] = ac_metrics.get("value_loss", 0)
            metrics["entropy"] = ac_metrics.get("entropy", 0)

        self.agent.current_epoch = epoch
        self.agent.global_step += self.config.env_steps_per_epoch

        return metrics
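
    # Warm-start schedule sketch (the epoch thresholds come from IRISConfig;
    # the ordering below mirrors the guards above, not prescribed values):
    #   epoch >= start_autoencoder_after  -> begin VQ-VAE tokenizer updates
    #   epoch >= start_transformer_after  -> begin world-model transformer updates
    #   epoch >= start_actor_critic_after -> begin actor-critic updates in imagination
    # Experience collection (Phase 1) runs every epoch regardless.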

    def get_epsilon(self, epoch: int) -> float:
        """Get exploration epsilon with decay."""
        start_epsilon = 0.5
        min_epsilon = self.config.collect_epsilon
        decay_epochs = 50

        epsilon = max(
            min_epsilon,
            start_epsilon - (start_epsilon - min_epsilon) * epoch / decay_epochs,
        )
        return epsilon
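
    # Worked example of the linear decay (assuming config.collect_epsilon = 0.01):
    #   epoch 0:   max(0.01, 0.5 - 0.49 * 0 / 50)  = 0.5
    #   epoch 25:  max(0.01, 0.5 - 0.49 * 25 / 50) = 0.255
    #   epoch 50+: clamped at 0.01 by the max()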

    def evaluate(self, num_episodes: int = 100, render: bool = False):
        """Evaluate agent performance.

        Args:
            num_episodes: Number of evaluation episodes
            render: If True, also return video frames and per-step latent vectors

        Returns:
            If render is False (default): dict with evaluation metrics
            If render is True: tuple (episode_returns_array, videos_list, latents_array)
        """
        episode_returns = []
        videos: list[list[np.ndarray]] = []
        latents_all: list[np.ndarray] = []

        for _ in range(num_episodes):
            raw_obs, _ = self.env.reset()
            obs = self.preprocess_frame(raw_obs)
            episode_return = 0
            done = False
            frames: list[np.ndarray] = []

            while not done:
                # Prepare frame for policy (CHW, float32, 0-1)
                frame_tensor = (
                    torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(self.device)
                )
                action = self.agent.act(
                    frame_tensor, epsilon=0.0, temperature=self.config.eval_temperature
                ).item()

                next_raw, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated

                # Store raw frame for video (as HWC uint8 if possible)
                try:
                    frames.append(np.asarray(next_raw))
                except Exception:
                    # Fallback: convert processed obs back to HWC
                    proc = np.asarray(obs)
                    if proc.ndim == 3:  # CHW -> HWC
                        frames.append(proc.transpose(1, 2, 0))

                # Compute latent embedding via encoder (quantized embeddings)
                try:
                    proc_frame = self.preprocess_frame(next_raw) if not done else obs
                    with torch.no_grad():
                        ft = (
                            torch.tensor(proc_frame, dtype=torch.float32)
                            .unsqueeze(0)
                            .to(self.device)
                        )
                        z_q, _, _ = self.agent.encoder(ft)
                        # z_q: (B, C, H', W') -> average over the spatial dims
                        # to get a per-channel embedding vector of shape (C,)
                        latent = z_q.mean(dim=(2, 3)).squeeze(0).cpu().numpy()
                        latents_all.append(latent.astype(np.float32))
                except Exception:
                    # If encoder fails, skip latent for this step
                    pass

                episode_return += reward
                obs = self.preprocess_frame(next_raw) if not done else obs

            episode_returns.append(episode_return)
            videos.append(frames)

        if render:
            # Stack latents into (N, D) array if any
            if latents_all:
                latents_array = np.vstack(latents_all).astype(np.float32)
            else:
                latents_array = np.empty((0,), dtype=np.float32)
            return np.array(episode_returns), videos, latents_array

        # Non-render fallback: return simple metrics dict for compatibility
        return {
            "eval_mean_return": float(
                np.mean(episode_returns) if episode_returns else 0.0
            ),
            "eval_std_return": float(
                np.std(episode_returns) if episode_returns else 0.0
            ),
            "eval_max_return": float(
                np.max(episode_returns) if episode_returns else 0.0
            ),
            "eval_min_return": float(
                np.min(episode_returns) if episode_returns else 0.0
            ),
        }
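
    # Render-path usage sketch (shapes are illustrative; D is the encoder's
    # channel dimension, N the total number of environment steps):
    #   returns, videos, latents = trainer.evaluate(num_episodes=2, render=True)
    #   returns.shape -> (2,)      one return per episode
    #   len(videos)   -> 2         per-episode lists of raw frames
    #   latents.shape -> (N, D)    one embedding per environment step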

    def train(
        self,
        total_epochs: int = None,
        eval_interval: int = 50,
        save_dir: str = "checkpoints/iris",
    ):
        """Full training loop.

        Args:
            total_epochs: Total training epochs
            eval_interval: Evaluate every N epochs
            save_dir: Directory to save checkpoints
        """
        if total_epochs is None:
            total_epochs = self.config.total_epochs

        os.makedirs(save_dir, exist_ok=True)

        print(f"Starting training for {total_epochs} epochs on {self.game}")
        print(f"Action space: {self.action_size}")
        print(f"Device: {self.device}")

        best_eval_return = float("-inf")

        for epoch in tqdm(range(total_epochs), desc="Training"):
            # Train one epoch
            metrics = self.train_epoch(epoch)

            # Log metrics
            for key, value in metrics.items():
                self.metrics[key].append(value)

            # Print progress
            if epoch % 10 == 0:
                print(f"\nEpoch {epoch}:")
                for key, value in metrics.items():
                    print(f"  {key}: {value:.4f}")

            # Evaluate periodically
            if (
                epoch % eval_interval == 0
                and epoch >= self.config.start_actor_critic_after
            ):
                eval_metrics = self.evaluate(num_episodes=self.config.eval_episodes)
                print(f"\nEvaluation at epoch {epoch}:")
                print(
                    f"  Mean return: {eval_metrics['eval_mean_return']:.2f}"
                    f" +/- {eval_metrics['eval_std_return']:.2f}"
                )

                # Save best model
                if eval_metrics["eval_mean_return"] > best_eval_return:
                    best_eval_return = eval_metrics["eval_mean_return"]
                    save_path = os.path.join(
                        save_dir, f"best_{self.game.split('/')[-1]}.pt"
                    )
                    self.agent.save(save_path)
                    print(f"  Saved best model: {save_path}")

            # Checkpoint periodically
            if epoch % self.config.checkpoint_interval == 0:
                save_path = os.path.join(save_dir, f"checkpoint_{epoch}.pt")
                self.agent.save(save_path)

        print(f"\nTraining complete! Best eval return: {best_eval_return:.2f}")
        return self.metrics
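
    # Files written under save_dir (names follow the code above; example uses
    # the default game "ALE/Pong-v5"):
    #   best_Pong-v5.pt      best-scoring model seen so far during evaluation
    #   checkpoint_<N>.pt    periodic snapshot every checkpoint_interval epochs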


def main():
    """Run IRIS training on a single Atari game."""
    import argparse

    parser = argparse.ArgumentParser(description="Train IRIS on Atari")
    parser.add_argument("--game", type=str, default="ALE/Pong-v5", help="Atari game")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
    parser.add_argument("--epochs", type=int, default=600, help="Total epochs")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--save_dir", type=str, default="checkpoints/iris", help="Save directory"
    )
    args = parser.parse_args()

    trainer = IRISTrainer(
        game=args.game,
        device=args.device,
        seed=args.seed,
    )
    trainer.train(
        total_epochs=args.epochs,
        save_dir=args.save_dir,
    )

if __name__ == "__main__":
    main()
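
# Example invocation (flags as defined in main() above; assumes the
# world_models package is importable, e.g. on PYTHONPATH; game name is
# illustrative):
#   python -m world_models.training.train_iris --game ALE/Breakout-v5 --epochs 600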