Source code for world_models.training.rl_harness

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from typing import Dict, Any, List
import logging

from world_models.envs.vector_env import TorchVectorizedEnv

logger = logging.getLogger(__name__)


[docs] class ActorCritic(nn.Module): """Simple actor-critic network for RL harness.""" def __init__(self, obs_shape: tuple, action_dim: int, hidden_dim: int = 256): super().__init__() self.obs_shape = obs_shape self.action_dim = action_dim # CNN for image observations self.cnn = nn.Sequential( nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4), nn.ReLU(), nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(), nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(), nn.Flatten(), ) # Compute CNN output size with torch.no_grad(): dummy = torch.zeros(1, *obs_shape) cnn_out = self.cnn(dummy).shape[1] self.actor = nn.Sequential( nn.Linear(cnn_out, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, action_dim), ) self.critic = nn.Sequential( nn.Linear(cnn_out, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1), )
[docs] def forward(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Forward pass through CNN, then actor and critic heads.""" features = self.cnn(obs) action_logits = self.actor(features) value = self.critic(features) return action_logits, value
[docs] def get_action( self, obs: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Sample action from policy.""" logits, value = self(obs) dist = Categorical(logits=logits) action = dist.sample() log_prob = dist.log_prob(action) return action, log_prob, value
[docs] class PPOTrainer: """Simple PPO trainer for testing vectorized environments.""" def __init__( self, vec_env: TorchVectorizedEnv, device: str = "cpu", lr: float = 3e-4, gamma: float = 0.99, gae_lambda: float = 0.95, clip_ratio: float = 0.2, num_epochs: int = 10, batch_size: int = 64, max_grad_norm: float = 0.5, entropy_coeff: float = 0.01, value_coeff: float = 0.5, ): self.vec_env = vec_env self.device = device self.gamma = gamma self.gae_lambda = gae_lambda self.clip_ratio = clip_ratio self.num_epochs = num_epochs self.batch_size = batch_size self.max_grad_norm = max_grad_norm self.entropy_coeff = entropy_coeff self.value_coeff = value_coeff # Get action dim from env if hasattr(vec_env.action_space, "n"): action_dim = vec_env.action_space.n else: action_dim = vec_env.action_space.shape[0] # Assume image obs for now obs_shape = vec_env.observation_space["image"].shape self.policy = ActorCritic(obs_shape, action_dim).to(device) self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
[docs] def collect_trajectories(self, num_steps: int) -> Dict[str, torch.Tensor]: """Collect trajectories using the vectorized environment.""" obs_batch: Any = self.vec_env.reset_batch() obs = obs_batch["obs"]["image"].to(self.device) # Collect lists of tensors, convert to stacked tensors before returning trajectories: Dict[str, List[torch.Tensor]] = { "obs": [], "actions": [], "log_probs": [], "rewards": [], "values": [], "dones": [], } for _ in range(num_steps // self.vec_env.total_envs): with torch.no_grad(): actions, log_probs, values = self.policy.get_action(obs) # Step environment actions_np = actions.cpu().numpy() step_result = self.vec_env.step_batch( torch.from_numpy(actions_np).to(self.device) ) next_obs = step_result["obs"]["image"].to(self.device) rewards = step_result["reward"].to(self.device) dones = step_result["done"].to(self.device) # Store trajectory data trajectories["obs"].append(obs) trajectories["actions"].append(actions) trajectories["log_probs"].append(log_probs) trajectories["rewards"].append(rewards) trajectories["values"].append(values.squeeze(-1)) trajectories["dones"].append(dones) obs = next_obs # Convert lists to stacked tensors for return result: Dict[str, torch.Tensor] = {} for key, lst in trajectories.items(): result[key] = torch.stack(lst) return result
[docs] def compute_gae( self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor ) -> torch.Tensor: """Compute Generalized Advantage Estimation.""" advantages = torch.zeros_like(rewards) # ensure last_gae is a tensor matching per-step value shape # values[t] has the same shape as rewards[t] last_gae = torch.zeros_like(values[0]) for t in reversed(range(len(rewards))): if t == len(rewards) - 1: next_value = torch.zeros_like(values[t]) else: next_value = values[t + 1] delta = ( rewards[t] + self.gamma * next_value * (1 - dones[t].float()) - values[t] ) last_gae = ( delta + self.gamma * self.gae_lambda * (1 - dones[t].float()) * last_gae ) advantages[t] = last_gae return advantages
[docs] def train_step(self, trajectories: Dict[str, torch.Tensor]): """Perform one training step using PPO.""" obs = trajectories["obs"].view( -1, *self.vec_env.observation_space["image"].shape ) actions = trajectories["actions"].view(-1) old_log_probs = trajectories["log_probs"].view(-1) rewards = trajectories["rewards"].view(-1) values = trajectories["values"].view(-1) dones = trajectories["dones"].view(-1) # Compute advantages advantages = self.compute_gae(rewards, values, dones) returns = advantages + values # Normalize advantages advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # PPO update for _ in range(self.num_epochs): indices = torch.randperm(len(obs)) for start in range(0, len(obs), self.batch_size): end = start + self.batch_size batch_indices = indices[start:end] batch_obs = obs[batch_indices] batch_actions = actions[batch_indices] batch_old_log_probs = old_log_probs[batch_indices] batch_advantages = advantages[batch_indices] batch_returns = returns[batch_indices] # Forward pass new_logits, new_values = self.policy(batch_obs) new_values = new_values.squeeze(-1) # Policy loss dist = Categorical(logits=new_logits) new_log_probs = dist.log_prob(batch_actions) ratio = torch.exp(new_log_probs - batch_old_log_probs) surr1 = ratio * batch_advantages surr2 = ( torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * batch_advantages ) policy_loss = -torch.min(surr1, surr2).mean() # Value loss value_loss = F.mse_loss(new_values, batch_returns) # Entropy bonus entropy = dist.entropy().mean() # Total loss loss = ( policy_loss + self.value_coeff * value_loss - self.entropy_coeff * entropy ) # Update self.optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_( self.policy.parameters(), self.max_grad_norm ) self.optimizer.step()
[docs] def train(self, total_timesteps: int, log_interval: int = 1000): """Main training loop.""" logger.info(f"Starting training for {total_timesteps} timesteps") timestep = 0 while timestep < total_timesteps: # Collect trajectories trajectories = self.collect_trajectories( min(2048, total_timesteps - timestep) ) # trajectories["obs"] is a tensor; ensure int arithmetic for timestep timestep += int(len(trajectories["obs"])) * self.vec_env.total_envs # Train self.train_step(trajectories) if timestep % log_interval == 0: logger.info(f"Timestep {timestep}: Training step completed")
[docs] def create_rl_harness_example(): """ Example function to create and run the RL harness. Usage: Call this with your environment factory. """ from world_models.envs.gym_env import make_gym_env def env_factory(): return make_gym_env("CartPole-v1", size=(64, 64)) vec_env = TorchVectorizedEnv( env_factory=env_factory, num_workers=2, envs_per_worker=4, seed=42 ) # Create trainer trainer = PPOTrainer(vec_env, device="cpu") # Train trainer.train(total_timesteps=10000) # Cleanup vec_env.close()
if __name__ == "__main__": create_rl_harness_example()