Training Guide#
This guide covers how to train world models in TorchWM.
Overview#
TorchWM supports training multiple world model algorithms with a unified interface.
Basic Training Flow#
Select an algorithm
Override environment, dataset, or optimization parameters
Initialize the agent
Call
train()and monitor logs/checkpoints
The simplest path is the top-level torchwm API:
import torchwm
agent = torchwm.create_model(
"dreamer",
env_backend="dmc",
env="walker-walk",
total_steps=1_000_000,
)
agent.train()
For research code, the lower-level config and agent classes remain available.
Dreamer Training#
Preferred application API:
import torchwm
agent = torchwm.create_model(
"dreamer",
env_backend="dmc",
env="walker-walk",
total_steps=1_000_000,
)
agent.train()
Equivalent direct API:
from torchwm import DreamerAgent, DreamerConfig
cfg = DreamerConfig()
cfg.env_backend = "dmc"
cfg.env = "walker-walk"
cfg.total_steps = 1_000_000
agent = DreamerAgent(cfg)
agent.train()
JEPA Training#
import torchwm
agent = torchwm.create_model(
"jepa",
dataset="imagenet",
batch_size=64,
epochs=100,
)
agent.train()
IRIS Training#
IRISAgent needs constructor arguments such as action_size and device in
addition to its config, so pass those as constructor overrides:
import torch
import torchwm
agent = torchwm.create_model(
"iris",
env="ALE/Pong-v5",
total_epochs=100,
action_size=4,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
Custom Training Loop#
For advanced users, implement custom training:
from torchwm import DreamerAgent, ReplayBuffer
agent = DreamerAgent(cfg)
memory = ReplayBuffer(
size=100_000,
obs_shape=(3, 64, 64),
action_size=6,
seq_len=50,
batch_size=50,
)
for step in range(cfg.total_steps):
# Collect experience
experience = agent.collect_episode()
memory.add_episode(experience)
# Train
if step % cfg.update_steps == 0:
batch = memory.sample_batch()
metrics = agent.update(batch)
# Log
if step % cfg.scalar_freq == 0:
print(f"Step {step}: {metrics}")
Configuration#
All training is controlled via config objects:
Common Parameters#
seed: Random seeddevice: Training devicetotal_steps/epochs: Training durationbatch_size: Batch sizelearning_rate: Learning rategrad_clip_norm: Gradient clipping
Logging#
enable_wandb: Weights & Biases logginglog_dir: Base log directory for supported agentsscalar_freq/log_interval: Metric logging cadence, depending on the agentcheckpoint_interval: Save frequency
Starter YAML configs#
DIAMOND, IRIS, and JEPA include starter experiment YAML files in
world_models/configs/experiments/. Use them with the unified CLI and optional
OmegaConf/Hydra-style dot-list overrides:
torchwm train diamond --config world_models/configs/experiments/diamond.yaml preset=small seed=1
torchwm train iris --config world_models/configs/experiments/iris.yaml total_epochs=100 env=ALE/Breakout-v5
torchwm train jepa --config world_models/configs/experiments/jepa.yaml optimization.epochs=50 data.batch_size=128
Add --print-config to inspect the composed configuration without launching a run.
Environment Setup#
DMC#
cfg.env_backend = "dmc"
cfg.env = "walker-walk"
DeepMind Lab#
cfg.env_backend = "dmlab"
cfg.env = "rooms_collect_good_objects_train"
cfg.dmlab_action_repeat = 4
Gym#
cfg.env_backend = "gym"
cfg.env = "Pendulum-v1"
Brax#
cfg.env_backend = "brax"
cfg.env = "ant"
cfg.brax_backend = "generalized"
Unity ML-Agents#
cfg.env_backend = "unity_mlagents"
cfg.unity_file_name = "path/to/env.exe"
Monitoring Training#
TensorBoard#
tensorboard --logdir runs
Weights & Biases#
cfg.enable_wandb = True
cfg.wandb_project = "torchwm"
cfg.wandb_entity = "your-entity"
Checkpointing#
Models are automatically saved:
# Resume training
cfg.restore = True
cfg.checkpoint_path = "path/to/checkpoint"
Distributed Training#
For multi-GPU training:
cfg.num_gpus = 4
# TorchWM handles distributed setup automatically
Best Practices#
Start small: Use short episodes and few steps for debugging
Monitor metrics: Watch loss curves and environment rewards
Tune hyperparameters: Adjust learning rates and batch sizes
Use checkpoints: Save frequently and resume from failures
Log experiments: Use WandB or TensorBoard for tracking