"""world_models.configs.diamond_config — DIAMOND world-model and training configuration."""

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ModelPreset:
    """Model architecture preset for different hardware tiers."""

    # Per-stage channel widths of the diffusion (world) model.
    diffusion_channels: List[int]
    # Residual blocks per stage in the diffusion model.
    diffusion_res_blocks: int
    # Width of the diffusion model's conditioning embedding.
    diffusion_cond_dim: int
    # Per-stage channel widths of the reward/termination model.
    reward_channels: List[int]
    # Hidden size of the reward model's LSTM.
    reward_lstm_dim: int
    # Per-stage channel widths of the actor-critic encoder.
    actor_channels: List[int]
    # Hidden size of the actor-critic LSTM.
    actor_lstm_dim: int


def _make_preset(width: int, res_blocks: int) -> ModelPreset:
    """Build a ModelPreset where every dimension scales off one base width.

    All three tiers follow the same pattern: the diffusion model uses
    ``width`` channels at every stage, the reward/actor networks use half
    of it, the conditioning dim is ``4 * width`` and both LSTMs ``8 * width``.
    """
    half = width // 2
    return ModelPreset(
        diffusion_channels=[width] * 4,
        diffusion_res_blocks=res_blocks,
        diffusion_cond_dim=width * 4,
        reward_channels=[half] * 4,
        reward_lstm_dim=width * 8,
        actor_channels=[half, half, width, width],
        actor_lstm_dim=width * 8,
    )


# Hardware-tier presets: base width 32 / 64 / 128 (large adds a residual block).
MODEL_PRESETS = {
    "small": _make_preset(32, 2),
    "medium": _make_preset(64, 2),
    "large": _make_preset(128, 3),
}


@dataclass
class DiamondConfig:
    """Hyperparameters for DIAMOND (diffusion world model) training on Atari 100k.

    If ``preset`` is set, the model-architecture fields (diffusion, reward and
    actor widths/dims) are overwritten in ``__post_init__`` from
    ``MODEL_PRESETS``; every other field keeps its declared default.
    """

    # Preset selection (overrides manual model config if set):
    # "small", "medium", "large", or None for the manual defaults below.
    preset: Optional[str] = None

    # Environment
    game: str = "Breakout-v5"
    obs_size: int = 64
    frameskip: int = 4
    max_noop: int = 30
    terminate_on_life_loss: bool = True
    reward_clip: List[int] = field(default_factory=lambda: [-1, 0, 1])

    # Frame stacking (observation conditioning)
    num_conditioning_frames: int = 4

    # Diffusion model (Dθ) - used if preset is None
    diffusion_channels: List[int] = field(default_factory=lambda: [64, 64, 64, 64])
    diffusion_res_blocks: int = 2
    diffusion_cond_dim: int = 256

    # EDM hyperparameters
    sigma_data: float = 0.5
    sigma_min: float = 0.002
    sigma_max: float = 80.0
    rho: int = 7
    p_mean: float = -0.4
    p_std: float = 1.2

    # Diffusion sampling
    sampling_method: str = "euler"
    num_sampling_steps: int = 3

    # Reward/Termination model (Rψ) - used if preset is None
    reward_channels: List[int] = field(default_factory=lambda: [32, 32, 32, 32])
    reward_res_blocks: int = 2
    reward_cond_dim: int = 128
    reward_lstm_dim: int = 512
    burn_in_length: int = 4

    # RL Agent (actor-critic) - used if preset is None
    actor_channels: List[int] = field(default_factory=lambda: [32, 32, 64, 64])
    actor_res_blocks: int = 1
    actor_lstm_dim: int = 512

    # Training
    num_epochs: int = 1000
    training_steps_per_epoch: int = 400
    batch_size: int = 32
    environment_steps_per_epoch: int = 100
    epsilon_greedy: float = 0.01

    # RL hyperparameters
    imagination_horizon: int = 15
    discount_factor: float = 0.985
    entropy_weight: float = 0.001
    lambda_returns: float = 0.95

    # Optimization
    learning_rate: float = 1e-4
    adam_epsilon: float = 1e-8
    weight_decay_diffusion: float = 1e-2
    weight_decay_reward: float = 1e-2
    weight_decay_actor: float = 0.0

    # Device
    device: str = "cuda"

    # Logging
    log_interval: int = 10
    eval_interval: int = 50
    save_interval: int = 100

    # Seeds
    num_seeds: int = 5
    seed: int = 0

    def __post_init__(self):
        """Apply a named model preset, if one was requested.

        Raises:
            ValueError: if ``preset`` is set but is not a key of
                ``MODEL_PRESETS``.  (The original silently ignored unknown
                names, so a typo fell back to the manual defaults.)
        """
        if self.preset is None:
            return
        if self.preset not in MODEL_PRESETS:
            raise ValueError(
                f"Unknown preset {self.preset!r}; expected one of "
                f"{sorted(MODEL_PRESETS)}"
            )
        p = MODEL_PRESETS[self.preset]
        # Copy every channel list so config instances never alias the shared
        # preset objects (the original copied reward/actor channels but
        # aliased diffusion_channels, so mutating one config's list would
        # corrupt the preset for every later config).
        self.diffusion_channels = list(p.diffusion_channels)
        self.diffusion_res_blocks = p.diffusion_res_blocks
        self.diffusion_cond_dim = p.diffusion_cond_dim
        self.reward_channels = list(p.reward_channels)
        self.reward_lstm_dim = p.reward_lstm_dim
        self.actor_channels = list(p.actor_channels)
        self.actor_lstm_dim = p.actor_lstm_dim
# The 26 games of the Atari 100k benchmark.
ATARI_100K_GAMES = [
    "Alien-v5",
    "Amidar-v5",
    "Assault-v5",
    "Asterix-v5",
    "BankHeist-v5",
    "BattleZone-v5",
    "Boxing-v5",
    "Breakout-v5",
    "ChopperCommand-v5",
    "CrazyClimber-v5",
    "DemonAttack-v5",
    "Freeway-v5",
    "Frostbite-v5",
    "Gopher-v5",
    "Hero-v5",
    "Jamesbond-v5",
    "Kangaroo-v5",
    "Krull-v5",
    "KungFuMaster-v5",
    "MsPacman-v5",
    "Pong-v5",
    "PrivateEye-v5",
    "Qbert-v5",
    "RoadRunner-v5",
    "Seaquest-v5",
    "UpNDown-v5",
]

# Human reference scores per game (for human-normalized evaluation).
HUMAN_SCORES = {
    "Alien-v5": 7127.7,
    "Amidar-v5": 1719.5,
    "Assault-v5": 742.0,
    "Asterix-v5": 8503.3,
    "BankHeist-v5": 753.1,
    "BattleZone-v5": 37187.5,
    "Boxing-v5": 12.1,
    "Breakout-v5": 30.5,
    "ChopperCommand-v5": 7387.8,
    "CrazyClimber-v5": 35829.4,
    "DemonAttack-v5": 1971.0,
    "Freeway-v5": 29.6,
    "Frostbite-v5": 4334.7,
    "Gopher-v5": 2412.5,
    "Hero-v5": 30826.4,
    "Jamesbond-v5": 302.8,
    "Kangaroo-v5": 3035.0,
    "Krull-v5": 2665.5,
    "KungFuMaster-v5": 22736.3,
    "MsPacman-v5": 6951.6,
    "Pong-v5": 14.6,
    "PrivateEye-v5": 69571.3,
    "Qbert-v5": 13455.0,
    "RoadRunner-v5": 7845.0,
    "Seaquest-v5": 42054.7,
    "UpNDown-v5": 11693.2,
}

# Random-policy reference scores per game (lower bound of normalization).
RANDOM_SCORES = {
    "Alien-v5": 227.8,
    "Amidar-v5": 5.8,
    "Assault-v5": 222.4,
    "Asterix-v5": 210.0,
    "BankHeist-v5": 14.2,
    "BattleZone-v5": 2360.0,
    "Boxing-v5": 0.1,
    "Breakout-v5": 1.7,
    "ChopperCommand-v5": 811.0,
    "CrazyClimber-v5": 10780.5,
    "DemonAttack-v5": 152.1,
    "Freeway-v5": 0.0,
    "Frostbite-v5": 65.2,
    "Gopher-v5": 257.6,
    "Hero-v5": 1027.0,
    "Jamesbond-v5": 29.0,
    "Kangaroo-v5": 52.0,
    "Krull-v5": 1598.0,
    "KungFuMaster-v5": 258.5,
    "MsPacman-v5": 307.3,
    "Pong-v5": -20.7,
    "PrivateEye-v5": 24.9,
    "Qbert-v5": 163.9,
    "RoadRunner-v5": 11.5,
    "Seaquest-v5": 68.4,
    "UpNDown-v5": 533.4,
}