Source code for world_models.training.train_diamond

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import random
from typing import Dict, List, Optional, Tuple, Union
from tqdm import tqdm
import os
from pathlib import Path
import argparse

from world_models.configs.diamond_config import (
    DiamondConfig,
    HUMAN_SCORES,
    RANDOM_SCORES,
)
from world_models.envs.diamond_atari import make_diamond_atari_env
from gym.spaces import Discrete, Box
from world_models.datasets.diamond_dataset import ReplayBuffer, SequenceDataset
from world_models.models.diffusion.diamond_diffusion import (
    DiffusionUNet,
    EDMPreconditioner,
    EulerSampler,
)
from world_models.models.diffusion.reward_termination import (
    RewardTerminationModel,
    RewardTerminationLoss,
)
from world_models.models.diffusion.actor_critic import (
    ActorCriticNetwork,
    RLLoss,
)


class DiamondAgent:
    """
    DIAMOND: DIffusion As a Model Of eNvironment Dreams

    RL agent trained entirely within a diffusion world model.

    The agent owns the real environment, a replay buffer, the three learned
    components (diffusion world model, reward/termination model and
    actor-critic) and one AdamW optimizer per component.
    """

    def __init__(self, config: "DiamondConfig"):
        """Build the environment, models, optimizers and replay buffer.

        Args:
            config: Hyper-parameters and environment settings.

        Raises:
            TypeError: If the environment's action space is neither
                ``Discrete`` nor ``Box`` and exposes no ``n`` attribute.
        """
        self.config = config
        # Fall back to CPU whenever CUDA is not available, whatever was asked.
        self.device = torch.device(
            config.device if torch.cuda.is_available() else "cpu"
        )

        self.env = make_diamond_atari_env(
            game=config.game,
            frameskip=config.frameskip,
            max_noop=config.max_noop,
            terminate_on_life_loss=config.terminate_on_life_loss,
            reward_clip=True,
            resize=(config.obs_size, config.obs_size),
            seed=config.seed,
        )

        # action_space may be different Space types; prefer Discrete.n when
        # available. The attribute is declared up-front for static checkers.
        action_space = self.env.action_space
        self.action_dim: int = 0
        if isinstance(action_space, Discrete):
            self.action_dim = int(action_space.n)
        elif isinstance(action_space, Box):
            # continuous action space -> flatten dimensions
            shape = getattr(action_space, "shape", None)
            if shape is None:
                raise TypeError("Box action_space has no shape")
            self.action_dim = int(np.prod(tuple(shape)))
        elif hasattr(action_space, "n"):
            # fallback for other space types that still expose `n`
            self.action_dim = int(getattr(action_space, "n"))
        else:
            raise TypeError(f"Unsupported action_space type: {type(action_space)}")

        self._build_models()

        # Use the *resolved* device so the buffer agrees with the models when
        # CUDA was requested in the config but is not available on this host.
        self.replay_buffer = ReplayBuffer(
            capacity=100000,
            obs_shape=(config.obs_size, config.obs_size, 3),
            action_dim=1,
            device=str(self.device),
        )

        self.obs_history: List[np.ndarray] = []
        # keep a raw-uint8 history in parallel with the normalized history
        self.obs_history_raw: List[np.ndarray] = []
        self.action_history: List[int] = []
        self.total_steps = 0
        self.global_step = 0

        # last LSTM hidden states (saved for reproducible imagined rollouts)
        self.last_policy_hidden: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        self.last_reward_hidden: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

    def _build_models(self):
        """Initialize all DIAMOND models, loss helpers and optimizers."""
        cfg = self.config

        # cfg.diffusion_channels is a list of absolute channel sizes per
        # level. DiffusionUNet expects base_channels and channel_multipliers
        # (multipliers relative to base), so convert here.
        base_channels = cfg.diffusion_channels[0]
        self.diffusion_model = DiffusionUNet(
            obs_channels=3,
            num_conditioning_frames=cfg.num_conditioning_frames,
            base_channels=base_channels,
            channel_multipliers=tuple(
                int(c // base_channels) for c in cfg.diffusion_channels
            ),
            num_res_blocks=cfg.diffusion_res_blocks,
            cond_dim=cfg.diffusion_cond_dim,
            action_dim=self.action_dim,
        ).to(self.device)

        self.edm_precond = EDMPreconditioner(
            sigma_data=cfg.sigma_data,
            p_mean=cfg.p_mean,
            p_std=cfg.p_std,
        )
        self.sampler = EulerSampler(
            sigma_min=cfg.sigma_min,
            sigma_max=cfg.sigma_max,
            rho=cfg.rho,
            num_steps=cfg.num_sampling_steps,
            edm_precond=self.edm_precond,
        )

        self.reward_model = RewardTerminationModel(
            obs_channels=3,
            action_dim=self.action_dim,
            channels=tuple(cfg.reward_channels),
            lstm_dim=cfg.reward_lstm_dim,
            cond_dim=cfg.reward_cond_dim,
        ).to(self.device)
        self.reward_loss_fn = RewardTerminationLoss()

        self.actor_critic = ActorCriticNetwork(
            obs_channels=3,
            action_dim=self.action_dim,
            channels=tuple(cfg.actor_channels),
            lstm_dim=cfg.actor_lstm_dim,
        ).to(self.device)
        self.rl_loss_fn = RLLoss(
            discount_factor=cfg.discount_factor,
            lambda_returns=cfg.lambda_returns,
            entropy_weight=cfg.entropy_weight,
        )

        self.diffusion_opt = optim.AdamW(
            self.diffusion_model.parameters(),
            lr=cfg.learning_rate,
            eps=cfg.adam_epsilon,
            weight_decay=cfg.weight_decay_diffusion,
        )
        self.reward_opt = optim.AdamW(
            self.reward_model.parameters(),
            lr=cfg.learning_rate,
            eps=cfg.adam_epsilon,
            weight_decay=cfg.weight_decay_reward,
        )
        self.actor_opt = optim.AdamW(
            self.actor_critic.parameters(),
            lr=cfg.learning_rate,
            eps=cfg.adam_epsilon,
            weight_decay=cfg.weight_decay_actor,
        )

    @staticmethod
    def _restore_hidden(
        saved: Optional[Tuple[torch.Tensor, torch.Tensor]],
        batch_size: int,
        device: torch.device,
    ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Return a saved ``(h, c)`` pair fitted to ``batch_size``, or None.

        The batch size is read from dimension 1 of ``h``. A matching pair is
        moved to ``device`` unchanged, a batch-1 pair is repeated along
        dimension 1, and anything else (including a malformed value) yields
        None so the caller can build a fresh state.
        """
        if saved is None:
            return None
        try:
            h_saved, c_saved = saved
            saved_batch = h_saved.shape[1]
            if saved_batch == batch_size:
                return h_saved.to(device), c_saved.to(device)
            if saved_batch == 1:
                return (
                    h_saved.to(device).repeat(1, batch_size, 1),
                    c_saved.to(device).repeat(1, batch_size, 1),
                )
        except (AttributeError, IndexError, RuntimeError, TypeError, ValueError):
            # Best effort: an unusable saved state is simply ignored.
            return None
        return None

    @staticmethod
    def _hidden_to_cpu(hidden) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Return a detached CPU copy of an ``(h, c)`` pair, or None on failure."""
        try:
            return hidden[0].detach().cpu(), hidden[1].detach().cpu()
        except (AttributeError, IndexError, RuntimeError, TypeError):
            return None

    def _update_diffusion_model(self, batch: Dict[str, torch.Tensor]) -> float:
        """Update diffusion world model.

        Args:
            batch: Mapping providing "obs_seq", "action_seq" and "next_obs".

        Returns:
            The scalar MSE loss of this optimization step.
        """
        self.diffusion_model.train()

        obs_seq = batch["obs_seq"]
        action_seq = batch["action_seq"]
        next_obs = batch["next_obs"]
        # Unpacking five dimensions also rejects inputs that are not rank 5.
        B, _, _, _, _ = obs_seq.shape

        # use only the last `num_conditioning_frames` for conditioning
        n_cond = self.config.num_conditioning_frames
        obs_history = obs_seq[:, -n_cond:]
        target_obs = next_obs

        sigma = self.edm_precond.sample_noise_level(B, self.device)
        sigma = sigma.view(B, 1, 1, 1)
        noise = torch.randn_like(target_obs)
        noisy_target = target_obs + sigma * noise

        precond = self.edm_precond.get_preconditioners(sigma)
        model_input = precond["c_in"] * noisy_target
        # use c_noise (log-sigma transform) for time conditioning as in EDM
        t_cond = precond["c_noise"].squeeze(-1).squeeze(-1)

        # Debug aid: report unexpected ranks but keep going. Expected layout:
        # obs_seq [B, T, C, H, W], obs_history [B, L, C, H, W],
        # next_obs [B, C, H, W]. A plain `if` is used so the check is not
        # stripped under `python -O`.
        if obs_seq.ndim != 5 or obs_history.ndim != 5 or target_obs.ndim != 4:
            print(
                f"DEBUG SHAPES: obs_seq={getattr(obs_seq, 'shape', None)}, obs_history={getattr(obs_history, 'shape', None)}, target_obs={getattr(target_obs, 'shape', None)}"
            )

        model_output = self.diffusion_model(
            x=model_input,
            t=t_cond,
            obs_history=obs_history,
            actions=action_seq[:, -n_cond:],
        )

        target = (next_obs - precond["c_skip"] * noisy_target) / precond["c_out"]
        loss = F.mse_loss(model_output, target)

        self.diffusion_opt.zero_grad()
        loss.backward()
        self.diffusion_opt.step()
        return loss.item()

    def _update_reward_model(self, batch: Dict[str, torch.Tensor]) -> float:
        """Update reward/termination model.

        Args:
            batch: Mapping providing "obs_seq", "action_seq", "rewards" and
                "dones".

        Returns:
            The scalar total loss reported by ``RewardTerminationLoss``.
        """
        self.reward_model.train()

        obs_seq = batch["obs_seq"]
        action_seq = batch["action_seq"]
        rewards = batch["rewards"]
        dones = batch["dones"]

        reward_logits, term_logits, _ = self.reward_model(
            obs=obs_seq,
            actions=action_seq,
        )

        # The last time step of the logits is dropped, so the targets are
        # cut to the same length (logits[:, :-1] <-> targets[:, :T-1]).
        target_len = max(0, reward_logits.shape[1] - 1)
        total_loss, _, _ = self.reward_loss_fn(
            reward_logits=reward_logits[:, :-1],
            termination_logits=term_logits[:, :-1],
            rewards=rewards[:, :target_len],
            terminated=dones[:, :target_len],
        )

        self.reward_opt.zero_grad()
        total_loss.backward()
        self.reward_opt.step()
        return total_loss.item()

    def _update_actor_critic(
        self, batch: Dict[str, torch.Tensor]
    ) -> Tuple[float, float]:
        """Update actor-critic using imagined rollouts.

        This replaces using real dataset trajectories for policy/value updates
        and follows the paper: compute policy and value losses on imagined
        trajectories produced by the diffusion world model.

        Args:
            batch: Mapping providing "obs_seq" and, optionally, "action_seq"
                (or "actions"). Missing actions are replaced by zeros.

        Returns:
            ``(policy_loss, value_loss)`` as Python floats.

        Raises:
            ValueError: If the sequence is shorter than the burn-in length.
        """
        self.actor_critic.train()

        obs_seq = batch["obs_seq"]
        action_seq = batch.get("action_seq", batch.get("actions"))
        B, seq_T, _, _, _ = obs_seq.shape

        burn_in = self.config.burn_in_length
        if seq_T < burn_in:
            raise ValueError("Sequence shorter than burn-in")

        # initial conditioning sequences for imagination
        obs_history = obs_seq[:, :burn_in]
        if action_seq is None:
            action_history = torch.zeros(
                (B, burn_in), dtype=torch.long, device=self.device
            )
        else:
            action_history = action_seq[:, :burn_in]

        # Reward-model hidden state: prefer a restored last_reward_hidden
        # (broadcast if needed), otherwise start from a fresh state.
        reward_hidden = self._restore_hidden(self.last_reward_hidden, B, self.device)
        if reward_hidden is None:
            reward_hidden = self.reward_model.init_hidden(B, self.device)

        # Policy hidden state: prefer a saved last_policy_hidden so imagined
        # rollouts can be reproduced across runs; otherwise prime the policy
        # LSTM with the burn-in observations.
        # NOTE(review): once a state with a matching batch size has been
        # stored, the burn-in priming below is never run again — confirm this
        # is the intended trade-off.
        policy_hidden_init = self._restore_hidden(
            self.last_policy_hidden, B, self.device
        )
        if policy_hidden_init is None:
            with torch.no_grad():
                _, _, policy_hidden_init = self.actor_critic(obs_history)

        # store a CPU copy so checkpoints are device independent
        self.last_policy_hidden = self._hidden_to_cpu(policy_hidden_init)

        # Imagine a trajectory and obtain the policy actions taken in it.
        (
            obs_imag,
            rewards_imag,
            dones_imag,
            policy_actions_imag,
            reward_hidden,
        ) = self._imagine_trajectory(
            obs_history, action_history, reward_hidden, policy_hidden=policy_hidden_init
        )

        # store the final reward hidden state (CPU) for checkpointing / replay
        self.last_reward_hidden = self._hidden_to_cpu(reward_hidden)

        # Re-run the policy on the imagined observations from the same initial
        # hidden state used during imagination so the logits line up with the
        # sampled actions (this pass is the one that carries gradients).
        policy_logits, values, _ = self.actor_critic(obs_imag, policy_hidden_init)

        # values: [B, H, 1] -> squeeze to [B, H]
        values_squeezed = values.squeeze(-1)
        # append bootstrap last value to make [B, H+1]
        values_with_bootstrap = torch.cat(
            [values_squeezed, values_squeezed[:, -1:].detach()], dim=1
        )

        # lambda-returns on imagined rewards
        lambda_returns = self.rl_loss_fn.compute_lambda_returns(
            rewards=rewards_imag,
            values=values_with_bootstrap,
            dones=dones_imag,
        )
        policy_loss = self.rl_loss_fn.policy_loss(
            policy_logits=policy_logits,
            actions=policy_actions_imag,
            lambda_returns=lambda_returns,
            values=values_with_bootstrap,
        )
        value_loss = self.rl_loss_fn.value_loss(
            values=values_with_bootstrap.unsqueeze(-1),
            lambda_returns=lambda_returns,
        )
        total_loss = policy_loss + value_loss

        self.actor_opt.zero_grad()
        total_loss.backward()
        self.actor_opt.step()
        return policy_loss.item(), value_loss.item()

    def _start_episode(self) -> None:
        """Reset the environment and seed both frame histories with its first frame."""
        raw_obs, _ = self.env.reset()
        norm_obs = raw_obs.astype(np.float32) / 255.0
        n_cond = self.config.num_conditioning_frames
        # maintain both normalized and raw histories
        self.obs_history = [norm_obs] * n_cond
        self.obs_history_raw = [raw_obs] * n_cond

    def _collect_experience(self, num_steps: int) -> List[float]:
        """Collect experience from the real environment.

        Args:
            num_steps: Number of environment steps to take.

        Returns:
            The reward received at each step, in order.
        """
        rewards = []
        n_cond = self.config.num_conditioning_frames

        if len(self.obs_history) == 0:
            self._start_episode()

        for _ in range(num_steps):
            # build tensor [1, L, C, H, W]: frames are stored [H, W, C]
            obs_np = np.stack(self.obs_history[-n_cond:]).transpose(0, 3, 1, 2)
            obs_tensor = torch.from_numpy(obs_np).unsqueeze(0).to(self.device)

            # pass a batched single observation [1, C, H, W]; no autograd
            # graph is needed for acting in the environment
            with torch.no_grad():
                action, _ = self.actor_critic.get_action(
                    obs_tensor[:, -1],
                    None,
                    deterministic=False,
                )
            if random.random() < self.config.epsilon_greedy:
                action = self.env.action_space.sample()

            # NOTE(review): reset() is unpacked as (obs, info) while step() is
            # unpacked as a 4-tuple — confirm the env wrapper really returns
            # (obs, reward, done, info) here.
            next_obs_raw, reward, done, _ = self.env.step(action)
            next_obs = next_obs_raw.astype(np.float32) / 255.0

            # store raw uint8 frames in the replay buffer (avoid lossy casts)
            self.replay_buffer.add(
                obs=self.obs_history_raw[-1],
                action=action,
                reward=reward,
                done=done,
                next_obs=next_obs_raw,
            )
            rewards.append(reward)

            # update both normalized and raw histories
            self.obs_history.append(next_obs)
            self.obs_history_raw.append(next_obs_raw)
            self.action_history.append(action)

            if done:
                self._start_episode()
                self.action_history = []
            else:
                # Only the last `n_cond` frames are ever read back, so drop
                # older ones instead of growing for the whole episode.
                del self.obs_history[:-n_cond]
                del self.obs_history_raw[:-n_cond]

        return rewards

    @torch.no_grad()
    def _imagine_trajectory(
        self,
        obs_history: torch.Tensor,
        action_history: torch.Tensor,
        hidden_state: Tuple[torch.Tensor, torch.Tensor],
        policy_hidden: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        Tuple[torch.Tensor, torch.Tensor],
    ]:
        """
        Imagine a trajectory using the diffusion world model.

        Runs under ``torch.no_grad()``; ``T`` below is
        ``config.imagination_horizon``.

        Args:
            obs_history: Initial observations [B, L, C, H, W]
            action_history: Initial actions [B, L]
            hidden_state: Initial reward-model LSTM hidden state
            policy_hidden: Optional initial policy LSTM hidden state; a fresh
                one is created when omitted

        Returns:
            obs_trajectory: stacked sampled frames [B, T, C, H, W]
            rewards: per-step reward predictions, stacked on dim 1
            dones: per-step termination predictions, stacked on dim 1
            policy_actions: actions sampled by the policy, stacked on dim 1
            hidden_state: Updated reward-model hidden state
        """
        B = obs_history.shape[0]
        frame_shape = (B, 3, self.config.obs_size, self.config.obs_size)

        obs_current = obs_history
        actions_current = action_history
        # allow the caller to provide a policy hidden state primed by the
        # burn-in sequence; otherwise initialize a fresh one
        if policy_hidden is None:
            policy_hidden = self.actor_critic.init_hidden(B, self.device)

        obs_trajectory = []
        rewards_list = []
        dones_list = []
        policy_actions_list = []

        for _ in range(self.config.imagination_horizon):
            # sampler returns [B, C, H, W]
            sampled = self.sampler.sample(
                model=self.diffusion_model,
                shape=frame_shape,
                device=self.device,
                obs_history=obs_current,
                actions=actions_current,
            )

            # predict reward/termination from the sampled frame
            reward, done, hidden_state = self.reward_model.predict(
                obs=sampled,
                actions=actions_current[:, -1],
                hidden_state=hidden_state,
            )
            obs_trajectory.append(sampled)
            rewards_list.append(reward)
            dones_list.append(done)

            # slide the conditioning window: drop the oldest frame, add the new
            obs_current = torch.cat([obs_current[:, 1:], sampled.unsqueeze(1)], dim=1)

            # Query the policy for the whole batch at once, carrying its LSTM
            # state across imagined steps so it conditions on the trajectory.
            policy_actions, policy_hidden = self.actor_critic.get_actions(
                sampled, policy_hidden, deterministic=False
            )
            policy_actions_list.append(policy_actions)

            # policy_actions: [B] -> [B, 1] so it can join the action window
            actions_current = torch.cat(
                [actions_current[:, 1:], policy_actions.unsqueeze(-1)], dim=1
            )

        return (
            torch.stack(obs_trajectory, dim=1),
            torch.stack(rewards_list, dim=1),
            torch.stack(dones_list, dim=1),
            torch.stack(policy_actions_list, dim=1),
            hidden_state,
        )

    def train(self):
        """Main training loop following Algorithm 1.

        Each epoch collects real experience, then (once the replay buffer is
        ready) runs ``training_steps_per_epoch`` updates of the three models,
        with periodic logging, evaluation and checkpointing.
        """
        print(f"Training DIAMOND on {self.config.game}")
        print(f"Device: {self.device}")
        print(f"Action space: {self.action_dim}")

        for epoch in tqdm(range(self.config.num_epochs), desc="Training"):
            collected_rewards = self._collect_experience(
                self.config.environment_steps_per_epoch
            )
            if not self.replay_buffer.is_ready(self.config.batch_size):
                continue

            dataset = SequenceDataset(
                replay_buffer=self.replay_buffer,
                sequence_length=self.config.burn_in_length
                + self.config.imagination_horizon,
                burn_in=self.config.burn_in_length,
            )
            dataloader = DataLoader(
                dataset,
                batch_size=self.config.batch_size,
                shuffle=True,
                num_workers=0,
            )

            diffusion_losses = []
            reward_losses = []
            policy_losses = []
            value_losses = []

            # keep one iterator alive and only restart it when exhausted
            data_iter = iter(dataloader)
            for _ in range(self.config.training_steps_per_epoch):
                try:
                    batch = next(data_iter)
                except StopIteration:
                    data_iter = iter(dataloader)
                    batch = next(data_iter)

                diffusion_losses.append(self._update_diffusion_model(batch))
                reward_losses.append(self._update_reward_model(batch))
                policy_loss, value_loss = self._update_actor_critic(batch)
                policy_losses.append(policy_loss)
                value_losses.append(value_loss)

            if epoch % self.config.log_interval == 0:
                print(f"\nEpoch {epoch}:")
                print(f" Diffusion loss: {np.mean(diffusion_losses):.4f}")
                print(f" Reward loss: {np.mean(reward_losses):.4f}")
                print(f" Policy loss: {np.mean(policy_losses):.4f}")
                print(f" Value loss: {np.mean(value_losses):.4f}")
                print(f" Collected reward: {np.mean(collected_rewards):.2f}")

            if epoch % self.config.eval_interval == 0:
                eval_reward = self.evaluate()
                hns = self._compute_human_normalized_score(eval_reward)
                print(f" Eval reward: {eval_reward:.2f}, HNS: {hns:.3f}")

            if epoch % self.config.save_interval == 0:
                self.save_checkpoint(f"checkpoint_{epoch}.pt")

    @torch.no_grad()
    def evaluate(self, num_episodes: int = 1) -> float:
        """Evaluate the agent.

        Runs full episodes with deterministic actions on ``self.env`` — the
        same environment used for data collection — and therefore clears the
        collection histories afterwards so the next ``_collect_experience``
        call starts from a fresh ``reset()`` instead of stepping a finished
        episode with stale frames.

        Args:
            num_episodes: Number of episodes to average over (at least 1).

        Returns:
            Mean undiscounted episode return.

        Raises:
            ValueError: If ``num_episodes`` is smaller than 1.
        """
        if num_episodes < 1:
            raise ValueError("num_episodes must be at least 1")

        # The _update_* methods switch their model back to train mode.
        self.actor_critic.eval()
        self.diffusion_model.eval()
        self.reward_model.eval()

        n_cond = self.config.num_conditioning_frames
        total_reward = 0.0
        try:
            for _ in range(num_episodes):
                obs, _ = self.env.reset()
                obs = obs.astype(np.float32) / 255.0
                obs_history = [obs] * n_cond
                policy_hidden = self.actor_critic.init_hidden(1, self.device)

                done = False
                episode_reward = 0.0
                while not done:
                    obs_np = np.stack(obs_history[-n_cond:]).transpose(0, 3, 1, 2)
                    obs_tensor = torch.from_numpy(obs_np).unsqueeze(0).to(self.device)

                    # pass batched observation [1, C, H, W]
                    action, policy_hidden = self.actor_critic.get_action(
                        obs_tensor[:, -1],
                        policy_hidden,
                        deterministic=True,
                    )
                    next_obs, reward, done, _ = self.env.step(action)
                    next_obs = next_obs.astype(np.float32) / 255.0
                    episode_reward += reward

                    obs_history.append(next_obs)
                    # only the newest `n_cond` frames are read back
                    del obs_history[:-n_cond]

                total_reward += episode_reward
        finally:
            # The shared env no longer matches the collection histories.
            self.obs_history = []
            self.obs_history_raw = []
            self.action_history = []

        return total_reward / num_episodes

    def _compute_human_normalized_score(self, score: float) -> float:
        """Compute human-normalized score.

        Returns ``(score - random) / (human - random)`` using the per-game
        reference tables (defaults: human 1.0, random 0.0), or 0.0 when both
        references are equal.
        """
        game = self.config.game
        human_score = HUMAN_SCORES.get(game, 1.0)
        # not named `random`, which would shadow the imported module
        random_score = RANDOM_SCORES.get(game, 0.0)
        if human_score == random_score:
            return 0.0
        return (score - random_score) / (human_score - random_score)

    @staticmethod
    def _resolve_save_path(path: Optional[Union[str, os.PathLike]]) -> Path:
        """Map a user-supplied checkpoint path to the file that is written.

        ``None`` and bare filenames land in ``checkpoints/diamond``; a path
        with a directory component is returned unchanged.
        """
        default_dir = Path("checkpoints/diamond")
        if path is None:
            return default_dir / "checkpoint.pt"
        candidate = Path(path)
        if candidate.parent != Path(""):
            return candidate
        return default_dir / candidate

    def save_checkpoint(self, path: Optional[Union[str, os.PathLike]] = None):
        """Save model checkpoint.

        Model/optimizer state and the saved LSTM hidden states go into the
        torch checkpoint. The replay buffer and the raw observation history
        are written next to it as ``<stem>_replay.npz`` and ``<stem>_obs.npy``
        and their file names are recorded inside the checkpoint. Failures to
        write those side files are reported and otherwise ignored.

        Args:
            path: Optional path where to write the checkpoint. If `path` is
                None, `checkpoints/diamond/checkpoint.pt` is used. A bare
                filename is written into `checkpoints/diamond/<filename>`.
                A path with a directory component is used directly.
        """
        out_path = self._resolve_save_path(path)
        out_path.parent.mkdir(parents=True, exist_ok=True)

        # Trim replay buffer arrays to the filled part so the side file does
        # not carry the full-capacity arrays.
        # NOTE(review): "capacity" is recorded as the trimmed length — confirm
        # how ReplayBuffer.load_state_dict treats it.
        rb_state_trim = None
        try:
            rb_state = self.replay_buffer.state_dict()
            n = int(self.replay_buffer.size)
            if n > 0:
                rb_state_trim = {
                    "observations": rb_state["observations"][:n].copy(),
                    "next_observations": rb_state["next_observations"][:n].copy(),
                    "actions": rb_state["actions"][:n].copy(),
                    "rewards": rb_state["rewards"][:n].copy(),
                    "dones": rb_state["dones"][:n].copy(),
                    "position": int(self.replay_buffer.position),
                    "size": int(n),
                    "capacity": int(n),
                }
            else:
                rb_state_trim = rb_state
        except Exception as exc:  # best effort: never block saving the weights
            print(f"Warning: failed to snapshot replay buffer; it will not be saved ({exc})")
            rb_state_trim = None

        checkpoint = {
            "config": self.config.__dict__,
            "diffusion_model": self.diffusion_model.state_dict(),
            "reward_model": self.reward_model.state_dict(),
            "actor_critic": self.actor_critic.state_dict(),
            "diffusion_opt": self.diffusion_opt.state_dict(),
            "reward_opt": self.reward_opt.state_dict(),
            "actor_opt": self.actor_opt.state_dict(),
            # optional LSTM hidden states saved for reproducible imagination
            "last_policy_hidden": self.last_policy_hidden,
            "last_reward_hidden": self.last_reward_hidden,
        }

        # Side files share the checkpoint's path without its extension.
        base = os.path.splitext(str(out_path))[0]

        if rb_state_trim is not None:
            replay_file = base + "_replay.npz"
            try:
                # np.savez_compressed stores arrays under the provided keys
                np.savez_compressed(replay_file, **rb_state_trim)
                checkpoint["replay_buffer_file"] = replay_file
            except Exception as exc:
                # Do not embed large Python objects in the torch checkpoint;
                # simply omit the replay buffer and warn the caller.
                print(f"Warning: failed to write replay buffer file; skipping embedding ({exc})")

        if self.obs_history_raw is not None and len(self.obs_history_raw) > 0:
            obs_file = base + "_obs.npy"
            try:
                # Stack into a single array (N, H, W, C)
                np.save(obs_file, np.stack(self.obs_history_raw))
                checkpoint["obs_history_file"] = obs_file
            except Exception as exc:
                print(f"Warning: failed to write obs_history file; skipping embedding ({exc})")

        torch.save(checkpoint, out_path)

    @staticmethod
    def _resolve_load_path(path: Optional[str]) -> str:
        """Resolve a checkpoint path the same way ``save_checkpoint`` does.

        Raises:
            FileNotFoundError: If ``path`` exists neither as given nor inside
                ``checkpoints/diamond``.
        """
        default_dir = "checkpoints/diamond"
        if path is None:
            return os.path.join(default_dir, "checkpoint.pt")
        if os.path.exists(path):
            return path
        alt = os.path.join(default_dir, path)
        if os.path.exists(alt):
            return alt
        raise FileNotFoundError(f"Checkpoint not found at {path} or {alt}")

    @staticmethod
    def _locate_sidecar(recorded, checkpoint_path) -> str:
        """Find a side file recorded in a checkpoint.

        The recorded path is used when it exists. Otherwise a file with the
        same name next to the checkpoint is tried, so a checkpoint directory
        that was moved (or is loaded from another working directory) still
        restores. When neither exists the recorded path is returned so the
        caller's load reports the original location.
        """
        recorded = os.fspath(recorded)
        if os.path.exists(recorded):
            return recorded
        sibling = os.path.join(
            os.path.dirname(os.fspath(checkpoint_path)), os.path.basename(recorded)
        )
        if os.path.exists(sibling):
            return sibling
        return recorded

    def _set_obs_history(self, raw_frames) -> None:
        """Replace both frame histories from raw frames, all-or-nothing."""
        raw_list = list(raw_frames)
        # normalize first so a failure leaves the current histories untouched
        normalized = [o.astype(np.float32) / 255.0 for o in raw_list]
        self.obs_history_raw = raw_list
        self.obs_history = normalized

    def load_checkpoint(self, path: Optional[str] = None):
        """Load model checkpoint.

        Restores model and optimizer state and the saved hidden states, then
        restores the replay buffer and observation history on a best-effort
        basis (failures there are printed, not raised).

        Args:
            path: Optional path to checkpoint. If None, the default
                `checkpoints/diamond/checkpoint.pt` is loaded. If a bare
                filename is provided, we try `checkpoints/diamond/<filename>`;
                if a path with directory components is provided we use it
                directly.

        Raises:
            FileNotFoundError: If an explicit ``path`` cannot be found.
        """
        fpath = self._resolve_load_path(path)

        # SECURITY: weights_only=False performs a full pickle load, which can
        # execute arbitrary code. Only load checkpoints from trusted sources.
        checkpoint = torch.load(fpath, map_location=self.device, weights_only=False)

        self.diffusion_model.load_state_dict(checkpoint["diffusion_model"])
        self.reward_model.load_state_dict(checkpoint["reward_model"])
        self.actor_critic.load_state_dict(checkpoint["actor_critic"])
        self.diffusion_opt.load_state_dict(checkpoint["diffusion_opt"])
        self.reward_opt.load_state_dict(checkpoint["reward_opt"])
        self.actor_opt.load_state_dict(checkpoint["actor_opt"])

        # restore optional last hidden states (move to configured device)
        saved_policy = checkpoint.get("last_policy_hidden")
        if saved_policy is not None:
            h, c = saved_policy
            self.last_policy_hidden = (h.to(self.device), c.to(self.device))
        else:
            self.last_policy_hidden = None

        saved_reward = checkpoint.get("last_reward_hidden")
        if saved_reward is not None:
            h, c = saved_reward
            self.last_reward_hidden = (h.to(self.device), c.to(self.device))
        else:
            self.last_reward_hidden = None

        # restore replay buffer from separate npz file if provided; fall back
        # to an embedded dict when necessary.
        if "replay_buffer_file" in checkpoint:
            try:
                replay_file = self._locate_sidecar(
                    checkpoint["replay_buffer_file"], fpath
                )
                with np.load(replay_file, allow_pickle=False) as data:
                    rb_state = {k: data[k] for k in data.files}
                self.replay_buffer.load_state_dict(rb_state)
            except Exception:
                print(
                    "Warning: failed to load replay_buffer from file; trying embedded state"
                )
                try:
                    self.replay_buffer.load_state_dict(
                        checkpoint.get("replay_buffer", {})
                    )
                except Exception:
                    print("Warning: failed to load replay_buffer from checkpoint")
        elif checkpoint.get("replay_buffer") is not None:
            try:
                self.replay_buffer.load_state_dict(checkpoint["replay_buffer"])
            except Exception:
                print("Warning: failed to load replay_buffer from checkpoint")

        # restore obs_history_raw from separate .npy file if present
        # NOTE(review): the environment itself is not restored here — confirm
        # that stepping it with a restored history is intended.
        if "obs_history_file" in checkpoint:
            try:
                obs_file = self._locate_sidecar(checkpoint["obs_history_file"], fpath)
                self._set_obs_history(np.load(obs_file, allow_pickle=False))
            except Exception:
                print(
                    "Warning: failed to load obs_history from file; trying embedded state"
                )
                try:
                    self._set_obs_history(checkpoint.get("obs_history_raw", []))
                except Exception:
                    print("Warning: failed to load obs_history_raw from checkpoint")
        elif checkpoint.get("obs_history_raw") is not None:
            try:
                self._set_obs_history(checkpoint["obs_history_raw"])
            except Exception:
                print("Warning: failed to load obs_history_raw from checkpoint")
def train_diamond(
    game: str,
    seed: int = 0,
    preset: Optional[str] = None,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Train DIAMOND on a specific game.

    Builds a ``DiamondConfig`` from the arguments (an empty ``preset`` is
    passed on as ``None``), wraps it in a ``DiamondAgent`` and runs its
    training loop.
    """
    settings = {
        "game": game,
        "seed": seed,
        "preset": preset or None,
        "device": device,
    }
    DiamondAgent(DiamondConfig(**settings)).train()
if __name__ == "__main__":
    # Command-line entry point: parse the options and launch training.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    cli = argparse.ArgumentParser()
    for flag, options in (
        ("--game", {"type": str, "default": "Breakout-v5"}),
        ("--seed", {"type": int, "default": 0}),
        (
            "--preset",
            {"type": str, "default": None, "choices": ["small", "medium", "large"]},
        ),
        ("--device", {"type": str, "default": default_device}),
    ):
        cli.add_argument(flag, **options)
    opts = cli.parse_args()
    train_diamond(opts.game, opts.seed, opts.preset, opts.device)