# Source code for world_models.training.train_iris

import torch
import numpy as np
from collections import defaultdict
import os
from tqdm import tqdm
import random
from typing import Optional, cast
from gym.spaces import Discrete, Box
import argparse

from types import ModuleType
from typing import Optional as _Optional

# Optional OpenCV import at module scope (avoid function-local imports)
cv2: _Optional[ModuleType] = None
try:
    import cv2 as _cv2

    cv2 = _cv2
except Exception:
    cv2 = None

from world_models.configs.iris_config import IRISConfig
from world_models.models.iris_agent import IRISAgent
from world_models.memory.iris_memory import IRISReplayBuffer
from world_models.envs.ale_atari_env import make_atari_env


[docs] class IRISTrainer: """Training loop for IRIS on Atari 100k benchmark.""" def __init__( self, game: str = "ALE/Pong-v5", device: str = "cuda", seed: int = 42, config: Optional[IRISConfig] = None, ): self.game = game self.device = torch.device(device if torch.cuda.is_available() else "cpu") self.seed = seed # Set seeds random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) # Config self.config = config if config is not None else IRISConfig() # Create environment self.env = make_atari_env( game, obs_type="rgb", frameskip=4, max_episode_steps=27000, # Standard Atari limit ) # Get action space robustly (Discrete or Box) # Declare attribute type for static checkers self.action_size: int = 0 if isinstance(self.env.action_space, Discrete): self.action_size = int(self.env.action_space.n) elif isinstance(self.env.action_space, Box): shape = getattr(self.env.action_space, "shape", None) if shape is None: raise TypeError("Box action_space has no shape") self.action_size = int(np.prod(tuple(shape))) else: if hasattr(self.env.action_space, "n"): self.action_size = int(getattr(self.env.action_space, "n")) else: raise TypeError( f"Unsupported action_space type: {type(self.env.action_space)}" ) # Create replay buffer self.replay_buffer = IRISReplayBuffer( size=100000, # 100k buffer obs_shape=(3, 64, 64), # Resize frames to 64x64 action_size=self.action_size, seq_len=self.config.transformer_timesteps, batch_size=self.config.transformer_batch_size, ) # Create agent self.agent = IRISAgent( config=self.config, action_size=self.action_size, device=self.device, ) # Metrics: map metric name -> series of numeric values self.metrics: defaultdict[str, list[float]] = defaultdict(list)
[docs] def preprocess_frame(self, frame: np.ndarray) -> np.ndarray: """Preprocess frame: resize to 64x64 and normalize.""" if cv2 is None: raise ImportError( "cv2 is required for frame preprocessing. Install opencv-python" ) frame = cv2.resize(frame, (64, 64), interpolation=cv2.INTER_LINEAR) frame = frame.astype(np.float32) / 255.0 return frame.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
[docs] def collect_experience( self, num_steps: int, epsilon: float = 0.01, ) -> float: """Collect experience from environment. Args: num_steps: Number of steps to collect epsilon: Random action probability Returns: Mean episode return """ obs, _ = self.env.reset() obs = self.preprocess_frame(obs) episode_returns = [] current_return: float = 0.0 steps_in_episode = 0 for step in range(num_steps): # Choose action if np.random.random() < epsilon: action = self.env.action_space.sample() else: frame_tensor = ( torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(self.device) ) action = self.agent.act(frame_tensor, epsilon=0.0).item() # Step environment next_obs, reward, terminated, truncated, info = self.env.step(action) done = terminated or truncated next_obs = self.preprocess_frame(next_obs) # Store in replay buffer action_one_hot: np.ndarray = np.zeros(self.action_size, dtype=np.float32) action_one_hot[action] = 1.0 self.replay_buffer.add( obs, action_one_hot, float(reward), done, ) current_return += float(reward) steps_in_episode += 1 if done: episode_returns.append(current_return) current_return = 0.0 steps_in_episode = 0 obs, _ = self.env.reset() obs = self.preprocess_frame(obs) else: obs = next_obs return float(np.mean(episode_returns)) if episode_returns else 0.0
[docs] def train_epoch(self, epoch: int) -> dict: """Train for one epoch. Args: epoch: Current epoch number Returns: Dictionary of metrics """ metrics = {} # Get epsilon with decay for better exploration early on epsilon = self.get_epsilon(epoch) # Phase 1: Collect experience mean_return = self.collect_experience( num_steps=self.config.env_steps_per_epoch, epsilon=epsilon, ) metrics["collection_return"] = mean_return # Only update components after warm-start periods if epoch >= self.config.start_autoencoder_after: # Phase 2: Update autoencoder ae_metrics: dict = {} for _ in range(self.config.training_steps_per_epoch): # Sample random frames indices = np.random.randint( 0, len(self.replay_buffer), size=self.config.autoencoder_batch_size ) frames = ( torch.tensor( np.array([self.replay_buffer.observations[i] for i in indices]), dtype=torch.float32, ).to(self.device) / 255.0 ) ae_metrics = self.agent.update_autoencoder(frames) metrics["recon_loss"] = ae_metrics.get("recon_loss", 0) metrics["vq_loss"] = ae_metrics.get("vq_loss", 0) metrics["perplexity"] = ae_metrics.get("perplexity", 0) if ( epoch >= self.config.start_transformer_after and len(self.replay_buffer) >= self.config.transformer_timesteps + 1 ): # Phase 3: Update transformer obs, acts, rews, terms = self.replay_buffer.sample_sequence() obs_tensor = torch.tensor(obs, dtype=torch.float32).to(self.device) / 255.0 acts_tensor = torch.tensor(acts, dtype=torch.float32).to(self.device) rews_tensor = torch.tensor(rews, dtype=torch.float32).to(self.device) terms_tensor = torch.tensor(terms, dtype=torch.long).to(self.device) tf_metrics = self.agent.update_transformer( obs_tensor, acts_tensor, rews_tensor, terms_tensor ) metrics["token_loss"] = tf_metrics.get("token_loss", 0) metrics["reward_loss"] = tf_metrics.get("reward_loss", 0) if ( epoch >= self.config.start_actor_critic_after and len(self.replay_buffer) >= 50 ): # Phase 4: Update actor-critic in imagination # Sample initial frames for imagination sample_size = 
self.config.actor_critic_batch_size indices = np.random.randint(0, len(self.replay_buffer), size=sample_size) initial_frames = ( torch.tensor( np.array([self.replay_buffer.observations[i] for i in indices]), dtype=torch.float32, ).to(self.device) / 255.0 ) # Generate imagined trajectories imagined = self.agent.imagine_rollout( initial_frame=initial_frames, horizon=self.config.imagination_horizon, ) # Update policy ac_metrics = self.agent.update_actor_critic(imagined) metrics["actor_loss"] = ac_metrics.get("actor_loss", 0) metrics["value_loss"] = ac_metrics.get("value_loss", 0) metrics["entropy"] = ac_metrics.get("entropy", 0) self.agent.current_epoch = epoch self.agent.global_step += self.config.env_steps_per_epoch return metrics
[docs] def get_epsilon(self, epoch: int) -> float: """Get exploration epsilon with decay.""" start_epsilon = 0.5 min_epsilon = self.config.collect_epsilon decay_epochs = 50 epsilon = max( min_epsilon, start_epsilon - (start_epsilon - min_epsilon) * epoch / decay_epochs, ) return epsilon
[docs] def evaluate(self, num_episodes: int = 100, render: bool = False): """Evaluate agent performance. Args: num_episodes: Number of evaluation episodes render: If True, also return video frames and per-step latent vectors Returns: If render is False (default): dict with evaluation metrics If render is True: tuple (episode_returns_array, videos_list, latents_array) """ episode_returns = [] videos: list[list[np.ndarray]] = [] latents_all: list[np.ndarray] = [] for _ in range(num_episodes): raw_obs, _ = self.env.reset() obs = self.preprocess_frame(raw_obs) episode_return: float = 0.0 done = False frames: list[np.ndarray] = [] while not done: # Prepare frame for policy (CHW, float32, 0-1) frame_tensor = ( torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(self.device) ) action = self.agent.act( frame_tensor, epsilon=0.0, temperature=self.config.eval_temperature ).item() next_raw, reward, terminated, truncated, _ = self.env.step(action) done = terminated or truncated # Store raw frame for video (as HWC uint8 if possible) try: frames.append(np.asarray(next_raw)) except Exception: # Fallback: convert processed obs back to HWC proc = np.asarray(obs) if proc.ndim == 3: # CHW -> HWC frames.append(proc.transpose(1, 2, 0)) # Compute latent embedding via encoder (quantized embeddings) try: proc_frame = self.preprocess_frame(next_raw) if not done else obs with torch.no_grad(): ft = ( torch.tensor(proc_frame, dtype=torch.float32) .unsqueeze(0) .to(self.device) ) z_q, _, _ = self.agent.encoder(ft) # z_q: (B, C, H', W') -> reduce spatial dims and take mean over channels latent = z_q.mean(dim=(2, 3)).squeeze(0).cpu().numpy() latents_all.append(latent.astype(np.float32)) except Exception: # If encoder fails, skip latent for this step pass episode_return += float(reward) obs = self.preprocess_frame(next_raw) if not done else obs episode_returns.append(episode_return) videos.append(frames) if render: # Stack latents into (N, D) array if any if latents_all: latents_array = 
np.vstack(latents_all).astype(np.float32) else: latents_array = np.empty((0,), dtype=np.float32) return np.array(episode_returns), videos, latents_array # Non-render fallback: return simple metrics dict for compatibility return { "eval_mean_return": float( np.mean(episode_returns) if episode_returns else 0.0 ), "eval_std_return": float( np.std(episode_returns) if episode_returns else 0.0 ), "eval_max_return": float( np.max(episode_returns) if episode_returns else 0.0 ), "eval_min_return": float( np.min(episode_returns) if episode_returns else 0.0 ), }
[docs] def train( self, total_epochs: Optional[int] = None, eval_interval: int = 50, save_dir: str = "checkpoints/iris", ): """Full training loop. Args: total_epochs: Total training epochs eval_interval: Evaluate every N epochs save_dir: Directory to save checkpoints """ if total_epochs is None: total_epochs = self.config.total_epochs os.makedirs(save_dir, exist_ok=True) print(f"Starting training for {total_epochs} epochs on {self.game}") print(f"Action space: {self.action_size}") print(f"Device: {self.device}") best_eval_return = float("-inf") for epoch in tqdm(range(total_epochs), desc="Training"): # Train one epoch metrics = self.train_epoch(epoch) # Log metrics for key, value in metrics.items(): self.metrics[key].append(value) # Print progress if epoch % 10 == 0: print(f"\nEpoch {epoch}:") for key, value in metrics.items(): print(f" {key}: {value:.4f}") # Evaluate periodically if ( epoch % eval_interval == 0 and epoch >= self.config.start_actor_critic_after ): eval_metrics = cast( dict, self.evaluate(num_episodes=self.config.eval_episodes) ) print(f"\nEvaluation at epoch {epoch}:") print( f" Mean return: {eval_metrics['eval_mean_return']:.2f} +/- {eval_metrics['eval_std_return']:.2f}" ) # Save best model if eval_metrics["eval_mean_return"] > best_eval_return: best_eval_return = eval_metrics["eval_mean_return"] save_path = os.path.join( save_dir, f"best_{self.game.split('/')[-1]}.pt" ) self.agent.save(save_path) print(f" Saved best model: {save_path}") # Checkpoint periodically if epoch % self.config.checkpoint_interval == 0: save_path = os.path.join(save_dir, f"checkpoint_{epoch}.pt") self.agent.save(save_path) print(f"\nTraining complete! Best eval return: {best_eval_return:.2f}") return self.metrics
def main(argv: Optional[list[str]] = None):
    """Run IRIS training on a single Atari game.

    Args:
        argv: Command-line arguments to parse. When None (the default),
            argparse reads them from ``sys.argv``, so the script behaves as
            before when run directly.
    """
    parser = argparse.ArgumentParser(description="Train IRIS on Atari")
    parser.add_argument("--game", type=str, default="ALE/Pong-v5", help="Atari game")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
    parser.add_argument("--epochs", type=int, default=600, help="Total epochs")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--save_dir", type=str, default="checkpoints/iris", help="Save directory"
    )
    args = parser.parse_args(argv)

    trainer = IRISTrainer(
        game=args.game,
        device=args.device,
        seed=args.seed,
    )
    trainer.train(
        total_epochs=args.epochs,
        save_dir=args.save_dir,
    )
# Script entry point: start training only when executed directly.
if __name__ == "__main__":
    main()