"""Dreamer configuration module (``world_models.configs.dreamer_config``)."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

from world_models.configs.serialization import SerializableConfigMixin


@dataclass
class DreamerConfig(SerializableConfigMixin):
    """Configuration container for Dreamer training, evaluation, and environment setup.

    This class centralizes environment backend selection
    (DMC/DMLab/Gym/MuJoCo/Robotics/Procgen/Unity/Brax), model dimensions,
    replay and optimization settings, logging cadence, and checkpoint options
    consumed by `DreamerAgent`.

    Field declaration order is significant: it fixes the positional order of
    the generated ``__init__`` and the order reported by
    ``dataclasses.fields()``. Append new fields instead of inserting them.

    ``wandb_api_key`` is excluded from the generated ``__repr__`` so that
    printing or logging a config does not leak the credential. It remains a
    normal ``__init__`` and equality field.
    """

    # Environment selection.
    # dmc: DeepMind Control Suite
    # dmlab: DeepMind Lab 3D navigation tasks
    # gym: generic Gym/Gymnasium env IDs or prebuilt env instances
    # mujoco: Gymnasium MuJoCo task IDs or native MuJoCo XML/MJB
    # robotics: Gymnasium Robotics env IDs (including legacy MuJoCo v2/v3)
    # procgen: Procgen procedurally generated benchmark games
    # unity_mlagents: Unity ML-Agents executable
    # brax: JAX/Brax continuous-control environments
    env_backend: str = "dmc"
    env: str = "walker-walk"
    env_instance: Any = None
    image_size: tuple[int, int] = (64, 64)
    gym_render_mode: str = "rgb_array"

    # DeepMind Lab options. dmlab_action_repeat is the native DMLab frame
    # repeat. Dreamer's action_repeat is still applied by the shared wrapper
    # stack outside the backend adapter.
    dmlab_action_repeat: int = 4
    dmlab_action_set: Any = None
    dmlab_observations: Any = None
    dmlab_config: Any = None
    dmlab_renderer: str = "hardware"

    # Procgen options. Use env values like "coinrun" or "procgen-coinrun-v0".
    procgen_distribution_mode: str = "easy"
    procgen_num_levels: int = 0
    procgen_start_level: Any = None

    # MuJoCo options. Leave mujoco_xml_path unset to auto-detect whether
    # `env` is a Gymnasium task ID or a native MJCF/MJB source.
    mujoco_xml_path: Any = None
    mujoco_xml_string: Any = None
    mujoco_binary_path: Any = None
    mujoco_camera: Any = None
    mujoco_frame_skip: int = 1
    mujoco_reset_noise_scale: float = 0.0

    # Brax options.
    brax_backend: str = "generalized"
    brax_jit: bool = True
    brax_auto_reset: bool = False
    # Suppress noisy optional MuJoCo/MJX Warp import messages emitted during
    # Brax imports. These messages are harmless when Warp is not installed
    # but can clutter logs, so suppression is enabled by default.
    brax_suppress_warp_warnings: bool = True

    # Unity ML-Agents options.
    unity_file_name: Any = None
    unity_behavior_name: Any = None
    unity_worker_id: int = 0
    unity_base_port: int = 5005
    unity_no_graphics: bool = True
    unity_time_scale: float = 20.0
    unity_quality_level: int = 1

    # Run / experiment control.
    algo: str = "Dreamerv1"
    exp_name: str = "lr1e-3"
    train: bool = True
    evaluate: bool = False
    seed: int = 1
    no_gpu: bool = False
    max_episode_length: int = 1000
    buffer_size: int = 800000
    # NOTE(review): overlaps with `max_episode_length`; confirm which limit
    # the environment wrappers actually enforce.
    time_limit: int = 1000

    # Network activations and sizes.
    cnn_activation_function: str = "relu"
    dense_activation_function: str = "elu"
    obs_embed_size: int = 1024
    num_units: int = 400
    deter_size: int = 200
    stoch_size: int = 30

    # Data collection and training schedule.
    action_repeat: int = 2
    action_noise: float = 0.3
    total_steps: int = 5_000_000
    seed_steps: int = 5000
    update_steps: int = 100
    collect_steps: int = 1000
    batch_size: int = 50
    train_seq_len: int = 50
    imagine_horizon: int = 15

    # Loss weighting, return estimation and the optional discrete model.
    use_disc_model: bool = False
    free_nats: float = 3.0
    discount: float = 0.99
    td_lambda: float = 0.95
    kl_loss_coeff: float = 1.0
    kl_alpha: float = 0.8
    disc_loss_coeff: float = 10.0
    # Presumably used by bucketed/symlog prediction heads. Confirm in
    # DreamerAgent.
    num_buckets: int = 255
    symlog_range: float = 10.0

    # Optimization.
    model_learning_rate: float = 6e-4
    actor_learning_rate: float = 8e-5
    value_learning_rate: float = 8e-5
    adam_epsilon: float = 1e-7
    grad_clip_norm: float = 100.0
    use_amp: bool = True

    # Evaluation.
    test: bool = False
    test_interval: int = 10000
    test_episodes: int = 10

    # Scalar / video logging cadence.
    scalar_freq: int = 1_000
    # Presumably a non-positive value disables video logging. Confirm in
    # DreamerAgent.
    log_video_freq: int = -1
    max_videos_to_save: int = 2
    video_format: str = "gif"  # "gif" or "mp4"
    video_fps: int = 20

    # Checkpointing, replay restore and rendering.
    checkpoint_interval: int = 10000
    checkpoint_path: str = ""
    restore: bool = False
    experience_replay: str = ""
    render: bool = False

    # Logging options.
    enable_wandb: bool = False
    # Required if enable_wandb is True. Kept out of __repr__ so the secret is
    # not written to logs when the config is printed.
    wandb_api_key: str = field(default="", repr=False)
    wandb_project: str = "torchwm"
    wandb_entity: str = ""
    log_dir: str = "runs"
    # NOTE(review): both `log_dir` and `logdir` exist; which one DreamerAgent
    # prefers is not visible here. Confirm before consolidating.
    logdir: Any = None
    # Base directory for DreamerAgent-created relative log directories.
    # If unset, DreamerAgent uses TORCHWM_DATA_DIR or log_dir instead of
    # writing into the package source tree.
    data_dir: Any = None
    log_level: str = "INFO"
    log_file: Any = None
    enable_tensorboard: bool = False
    enable_console_metrics: bool = True
    enable_jsonl: bool = True
    jsonl_filename: str = "metrics.jsonl"
    log_system_stats_freq: int = 1_000
    detect_anomaly: bool = False