Configs Reference#
This page documents all configuration classes in TorchWM.
DreamerConfig#
Configuration for Dreamer agent training.
@dataclass
class DreamerConfig:
    """Configuration for Dreamer agent training.

    Fields are grouped by the section comments below: environment selection
    (with one option group per backend), training run setup, model sizes,
    learning schedule, loss terms, optimizer settings, evaluation and
    checkpointing, and logging. Every field has a default, so
    ``DreamerConfig()`` is constructible without arguments.
    """

    # Environment
    env_backend: str = "dmc"
    env: str = "walker-walk"
    env_instance: Optional[object] = None
    image_size: Tuple[int, int] = (64, 64)
    gym_render_mode: str = "rgb_array"
    # DeepMind Lab (optional)
    dmlab_action_repeat: int = 4
    dmlab_action_set: Optional[object] = None
    dmlab_observations: Optional[list[str]] = None
    dmlab_config: Optional[dict] = None
    dmlab_renderer: str = "hardware"
    # Procgen (optional)
    procgen_distribution_mode: str = "easy"
    # NOTE(review): 0 presumably means "unlimited levels" as in Procgen itself — confirm.
    procgen_num_levels: int = 0
    procgen_start_level: Optional[int] = None
    # MuJoCo (optional)
    mujoco_xml_path: Optional[str] = None
    mujoco_xml_string: Optional[str] = None
    mujoco_binary_path: Optional[str] = None
    mujoco_camera: Optional[Union[str, int]] = None
    mujoco_frame_skip: int = 1
    mujoco_reset_noise_scale: float = 0.0
    # Brax (optional)
    brax_backend: str = "generalized"
    brax_jit: bool = True
    brax_auto_reset: bool = False
    brax_suppress_warp_warnings: bool = True
    # Unity ML-Agents
    unity_file_name: Optional[str] = None
    unity_behavior_name: Optional[str] = None
    unity_worker_id: int = 0
    unity_base_port: int = 5005
    unity_no_graphics: bool = True
    unity_time_scale: float = 20.0
    unity_quality_level: int = 1
    # Training
    algo: str = "Dreamerv1"
    exp_name: str = "lr1e-3"
    train: bool = True
    evaluate: bool = False
    seed: int = 1
    no_gpu: bool = False
    max_episode_length: int = 1000
    buffer_size: int = 800000
    time_limit: int = 1000
    # Model
    cnn_activation_function: str = "relu"
    dense_activation_function: str = "elu"
    obs_embed_size: int = 1024
    num_units: int = 400
    deter_size: int = 200
    stoch_size: int = 30
    action_repeat: int = 2
    action_noise: float = 0.3
    # Learning
    model_learning_rate: float = 6e-4
    actor_learning_rate: float = 8e-5
    value_learning_rate: float = 8e-5
    total_steps: int = int(5e6)
    seed_steps: int = 5000
    update_steps: int = 100
    collect_steps: int = 1000
    batch_size: int = 50
    train_seq_len: int = 50
    imagine_horizon: int = 15
    use_disc_model: bool = False
    # Loss
    free_nats: float = 3.0
    discount: float = 0.99
    td_lambda: float = 0.95
    kl_loss_coeff: float = 1.0
    kl_alpha: float = 0.8
    disc_loss_coeff: float = 10.0
    num_buckets: int = 255
    symlog_range: float = 10.0
    # Optimization
    adam_epsilon: float = 1e-7
    grad_clip_norm: float = 100.0
    # Evaluation and checkpointing
    test: bool = False
    test_interval: int = 10000
    test_episodes: int = 10
    scalar_freq: int = int(1e3)
    # NOTE(review): -1 presumably disables video logging — confirm against the trainer.
    log_video_freq: int = -1
    max_videos_to_save: int = 2
    video_format: str = "gif"
    video_fps: int = 20
    checkpoint_interval: int = 10000
    checkpoint_path: str = ""
    restore: bool = False
    experience_replay: str = ""
    render: bool = False
    # WandB, plus other logging outputs (TensorBoard, console, JSONL)
    enable_wandb: bool = False
    wandb_api_key: str = ""
    wandb_project: str = "torchwm"
    wandb_entity: str = ""
    log_dir: str = "runs"
    data_dir: Optional[str] = None
    log_level: str = "INFO"
    log_file: Optional[str] = None
    enable_tensorboard: bool = False
    enable_console_metrics: bool = True
    enable_jsonl: bool = True
    jsonl_filename: str = "metrics.jsonl"
    log_system_stats_freq: int = int(1e3)
    detect_anomaly: bool = False
JEPAConfig#
Configuration for JEPA training.
@dataclass
class JEPAConfig:
    """Configuration for JEPA training.

    Fields are grouped by the section comments below: meta/model selection,
    data loading and augmentation, masking, optimization schedule, and
    logging. Every field has a default, so ``JEPAConfig()`` is constructible
    without arguments.
    """

    # Meta
    use_bfloat16: bool = False
    model_name: str = "vit_base"
    load_checkpoint: bool = False
    read_checkpoint: Optional[str] = None
    copy_data: bool = False
    pred_depth: int = 6
    pred_emb_dim: int = 384
    # Data
    dataset: str = "imagenet"
    val_split: Optional[float] = None
    use_gaussian_blur: bool = True
    use_horizontal_flip: bool = True
    use_color_distortion: bool = True
    color_jitter_strength: float = 0.5
    batch_size: int = 64
    pin_mem: bool = True
    num_workers: int = 8
    # The IMAGENET_ROOT environment variable is read once, when the class
    # body is executed, not each time an instance is created.
    root_path: str = os.environ.get("IMAGENET_ROOT", "/data/imagenet")
    image_folder: str = "train"
    crop_size: int = 224
    crop_scale: Tuple[float, float] = (0.67, 1.0)
    download: bool = False
    # Mask
    allow_overlap: bool = False
    patch_size: int = 16
    num_enc_masks: int = 1
    min_keep: int = 4
    enc_mask_scale: Tuple[float, float] = (0.15, 0.2)
    num_pred_masks: int = 1
    pred_mask_scale: Tuple[float, float] = (0.15, 0.2)
    aspect_ratio: Tuple[float, float] = (0.75, 1.5)
    # Optimization
    ema: Tuple[float, float] = (0.996, 1.0)
    ipe_scale: float = 1.0
    weight_decay: float = 0.04
    final_weight_decay: float = 0.4
    epochs: int = 300
    warmup: int = 40
    start_lr: float = 1e-6
    lr: float = 1.5e-4
    final_lr: float = 1e-6
    # Logging
    folder: str = "results/jepa"
    write_tag: str = "jepa_run"
    enable_wandb: bool = False
    wandb_api_key: str = ""
    wandb_project: str = "torchwm"
    wandb_entity: str = ""
    enable_sweep: bool = False
    # A literal ``{}`` default is rejected by @dataclass (ValueError: mutable
    # default); default_factory gives each instance its own empty dict.
    sweep_config: Dict[str, Any] = field(default_factory=dict)
IRISConfig#
Configuration for IRIS training.
@dataclass
class IRISConfig:
    """Configuration for IRIS training.

    Fields are grouped by the section comments below: the discrete
    autoencoder (VQVAE), the transformer world model, the actor-critic,
    the training schedule, the Atari 100k benchmark setup, and logging.
    Every field has a default, so ``IRISConfig()`` is constructible without
    arguments.
    """

    # Discrete Autoencoder (VQVAE)
    frame_height: int = 64
    frame_width: int = 64
    frame_channels: int = 3
    vocab_size: int = 512
    tokens_per_frame: int = 16
    token_embedding_dim: int = 512
    encoder_channels: int = 64
    encoder_layers: int = 4
    encoder_residual_blocks: int = 2
    decoder_depth: int = 32
    reconstruction_weight: float = 1.0
    commitment_weight: float = 0.25
    perceptual_weight: float = 1.0
    # Transformer (World Model)
    transformer_timesteps: int = 20
    transformer_embed_dim: int = 256
    transformer_layers: int = 10
    transformer_heads: int = 4
    transformer_dropout: float = 0.1
    # Actor-Critic
    imagination_horizon: int = 15
    discount: float = 0.99
    td_lambda: float = 0.9
    entropy_coef: float = 0.01
    actor_hidden_size: int = 512
    actor_layers: int = 4
    value_hidden_size: int = 512
    value_layers: int = 3
    # Training
    total_epochs: int = 600
    collection_epochs: int = 500
    env_steps_per_epoch: int = 200
    training_steps_per_epoch: int = 250
    model_learning_rate: float = 1e-4
    actor_learning_rate: float = 1e-4
    value_learning_rate: float = 1e-4
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    weight_decay: float = 0.01
    grad_clip_norm: float = 10.0
    collect_epsilon: float = 0.1
    eval_temperature: float = 0.1
    # NOTE(review): the three start_*_after values are presumably epoch
    # indices at which each component begins training — confirm.
    start_autoencoder_after: int = 1
    start_transformer_after: int = 15
    start_actor_critic_after: int = 35
    autoencoder_batch_size: int = 256
    transformer_batch_size: int = 64
    actor_critic_batch_size: int = 64
    # Atari 100k Benchmark
    atari_100k: bool = True
    max_env_steps: int = 100000
    env_backend: str = "gym"
    env: str = "ALE/Pong-v5"
    action_repeat: int = 4
    # Logging
    log_interval: int = 1000
    eval_episodes: int = 100
    checkpoint_interval: int = 50
DiamondConfig#
Configuration for Diamond (Diffusion + RL) training.
@dataclass
class DiamondConfig:
    """Configuration for Diamond (Diffusion + RL) training.

    Fields are grouped by the section comments below: preset selection,
    environment, diffusion model, reward/termination model, RL agent,
    training schedule, device, logging, and operator parameters. List-valued
    fields use ``default_factory`` so each instance gets its own list.
    """

    # Preset
    preset: Optional[str] = None  # "small", "medium", "large"
    # Environment
    game: str = "Breakout-v5"
    seed: int = 0
    obs_size: int = 64
    frameskip: int = 4
    max_noop: int = 30
    terminate_on_life_loss: bool = True
    reward_clip: List[int] = field(default_factory=lambda: [-1, 0, 1])
    num_conditioning_frames: int = 4
    # Diffusion Model
    diffusion_channels: List[int] = field(default_factory=lambda: [64, 64, 64, 64])
    diffusion_res_blocks: int = 2
    diffusion_cond_dim: int = 256
    sigma_data: float = 0.5
    sigma_min: float = 0.002
    sigma_max: float = 80.0
    rho: int = 7
    p_mean: float = -0.4
    p_std: float = 1.2
    sampling_method: str = "euler"
    num_sampling_steps: int = 3
    # Reward/Termination Model
    reward_channels: List[int] = field(default_factory=lambda: [32, 32, 32, 32])
    reward_res_blocks: int = 2
    reward_cond_dim: int = 128
    reward_lstm_dim: int = 512
    burn_in_length: int = 4
    # RL Agent
    actor_channels: List[int] = field(default_factory=lambda: [32, 32, 64, 64])
    actor_res_blocks: int = 1
    actor_lstm_dim: int = 512
    # Training
    num_epochs: int = 1000
    training_steps_per_epoch: int = 400
    batch_size: int = 32
    environment_steps_per_epoch: int = 100
    epsilon_greedy: float = 0.01
    imagination_horizon: int = 15
    discount_factor: float = 0.985
    entropy_weight: float = 0.001
    lambda_returns: float = 0.95
    learning_rate: float = 1e-4
    adam_epsilon: float = 1e-8
    weight_decay_diffusion: float = 1e-2
    weight_decay_reward: float = 1e-2
    weight_decay_actor: float = 0.0
    # Device
    # The CUDA availability check runs once, when the class body is executed
    # (i.e. at import time), not each time an instance is created.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Logging
    log_interval: int = 10
    eval_interval: int = 50
    save_interval: int = 100
    # Operator parameters
    operator_state_dim: int = 32
    operator_action_dim: int = 4
Usage Patterns#
Basic Configuration#
# Start from the defaults, then override individual fields by attribute.
# NOTE(review): file paths elsewhere on this page use the `world_models`
# package name — confirm that `torchwm` is the importable package.
from torchwm import DreamerConfig
cfg = DreamerConfig()
cfg.env = "walker-walk"
cfg.total_steps = 1_000_000
Environment-Specific Configs#
# The groups below are alternatives: each one reassigns cfg.env_backend,
# so apply only the group for the backend you want.
# DMC
cfg.env_backend = "dmc"
cfg.env = "walker-walk"
# DeepMind Lab
cfg.env_backend = "dmlab"
cfg.env = "rooms_collect_good_objects_train"
cfg.dmlab_action_repeat = 4
# Gym
cfg.env_backend = "gym"
cfg.env = "Pendulum-v1"
# MuJoCo example:
cfg.env_backend = "mujoco"
cfg.env = "Humanoid-v4" # or "models/cartpole.xml"
cfg.mujoco_camera = None # native MJCF/MJB only
cfg.mujoco_frame_skip = 4 # native MJCF/MJB only
# Brax example:
cfg.env_backend = "brax"
cfg.env = "ant"
cfg.brax_backend = "generalized"
# Unity
cfg.env_backend = "unity_mlagents"
cfg.unity_file_name = "env.exe"
Training Configs#
# Basic training
# These three values equal the DreamerConfig defaults; they are spelled out
# here to show which fields to change.
cfg.batch_size = 50
cfg.model_learning_rate = 6e-4
cfg.total_steps = 5_000_000
# Logging
cfg.enable_wandb = True
cfg.wandb_project = "my_project"
# Checkpointing (DreamerConfig default is 10000)
cfg.checkpoint_interval = 100_000
Advanced Configs#
# Custom model sizes (DreamerConfig defaults: 1024, 400, 200)
cfg.obs_embed_size = 2048
cfg.num_units = 600
cfg.deter_size = 300
# Exploration noise (Dreamer); DreamerConfig default is 0.3
cfg.action_noise = 0.5
# Loss weights (DreamerConfig defaults: kl_loss_coeff=1.0, free_nats=3.0)
cfg.kl_loss_coeff = 0.1
cfg.free_nats = 1.0
Experiment YAML and OmegaConf overrides#
TorchWM provides a shared experiment configuration layer in
world_models.experiments. Training entrypoints can compose their Python
configuration defaults with a YAML file and Hydra/OmegaConf-style dot-list
overrides, while still receiving plain Python dictionaries or config objects at
runtime.
Built-in YAML starters live under world_models/configs/experiments/:
- diamond.yaml for DIAMOND Atari experiments.
- iris.yaml for IRIS Atari experiments.
- jepa.yaml for JEPA image pretraining experiments.
Examples:
torchwm train diamond --config world_models/configs/experiments/diamond.yaml preset=small seed=1
torchwm train iris --config world_models/configs/experiments/iris.yaml total_epochs=100 env=ALE/Breakout-v5
torchwm train jepa --config world_models/configs/experiments/jepa.yaml optimization.epochs=50 data.batch_size=128
Use --print-config with these entrypoints to inspect the fully composed config
without launching a run.