# Training Guide
This guide covers how to train world models in TorchWM.
## Overview
TorchWM supports training multiple world model algorithms with a unified interface.
## Basic Training Flow

1. Select an algorithm and create a config
2. Set environment/dataset parameters
3. Configure training hyperparameters
4. Initialize the agent and call `train()`
## Dreamer Training

```python
from world_models.models import DreamerAgent
from world_models.configs import DreamerConfig

# Configure
cfg = DreamerConfig()
cfg.env_backend = "dmc"
cfg.env = "walker-walk"
cfg.total_steps = 1_000_000

# Train
agent = DreamerAgent(cfg)
agent.train()
```
## JEPA Training

```python
from world_models.models import JEPAAgent
from world_models.configs import JEPAConfig

cfg = JEPAConfig()
cfg.dataset = "imagenet"
cfg.batch_size = 64
cfg.epochs = 100

agent = JEPAAgent(cfg)
agent.train()
```
## IRIS Training

```python
from world_models.models import IRISAgent
from world_models.configs import IRISConfig

cfg = IRISConfig()
cfg.env_name = "Pong-v5"
cfg.total_epochs = 100

agent = IRISAgent(cfg)
agent.train()
```
## Custom Training Loop

For advanced users, you can implement a custom training loop instead of calling `train()`:

```python
from world_models.configs import DreamerConfig
from world_models.memory import DreamerMemory
from world_models.models import DreamerAgent

cfg = DreamerConfig()
agent = DreamerAgent(cfg)
memory = DreamerMemory(cfg)

for step in range(cfg.total_steps):
    # Collect experience
    experience = agent.collect_episode()
    memory.add_episode(experience)

    # Train
    if step % cfg.update_steps == 0:
        batch = memory.sample_batch()
        metrics = agent.update(batch)

    # Log
    if step % cfg.log_every == 0:
        print(f"Step {step}: {metrics}")
```
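The episode buffer used in the loop above can be approximated with a minimal stdlib-only sketch. This is an illustration of the `add_episode`/`sample_batch` interface, not TorchWM's actual `DreamerMemory` (which stores full trajectories and samples fixed-length subsequences):

```python
import random

class EpisodeMemory:
    """Minimal illustrative episode buffer (not TorchWM's actual DreamerMemory)."""

    def __init__(self, capacity=1000, batch_size=16):
        self.capacity = capacity
        self.batch_size = batch_size
        self.episodes = []

    def add_episode(self, episode):
        # Append the newest episode; evict the oldest once capacity is reached
        self.episodes.append(episode)
        if len(self.episodes) > self.capacity:
            self.episodes.pop(0)

    def sample_batch(self):
        # Draw up to batch_size distinct episodes uniformly at random
        k = min(self.batch_size, len(self.episodes))
        return random.sample(self.episodes, k)
```

A real Dreamer-style buffer additionally slices each sampled episode into fixed-length sequences for the recurrent world model; the FIFO eviction and uniform sampling shown here are the core pattern.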
## Configuration

All training is controlled via config objects.

### Common Parameters

- `seed`: Random seed
- `device`: Training device
- `total_steps`/`epochs`: Training duration
- `batch_size`: Batch size
- `learning_rate`: Learning rate
- `grad_clip_norm`: Gradient clipping threshold
### Logging

- `enable_wandb`: Weights & Biases logging
- `log_dir`: TensorBoard log directory
- `checkpoint_interval`: Save frequency
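Putting these parameters together, a configured run might look like the sketch below. The field names follow the lists above; the specific values are illustrative, and whether a given algorithm's config exposes every field may vary:

```python
from world_models.configs import DreamerConfig

cfg = DreamerConfig()

# Common parameters
cfg.seed = 42
cfg.device = "cuda"
cfg.total_steps = 500_000
cfg.batch_size = 32
cfg.learning_rate = 3e-4
cfg.grad_clip_norm = 100.0

# Logging
cfg.enable_wandb = False
cfg.log_dir = "runs/dreamer-walker"
cfg.checkpoint_interval = 10_000
```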
## Environment Setup

### DMC

```python
cfg.env_backend = "dmc"
cfg.env = "walker-walk"
```

### Gym

```python
cfg.env_backend = "gym"
cfg.env = "Pendulum-v1"
```

### Unity ML-Agents

```python
cfg.env_backend = "unity_mlagents"
cfg.unity_file_name = "path/to/env.exe"
```
## Monitoring Training

### TensorBoard

```shell
tensorboard --logdir runs
```

### Weights & Biases

```python
cfg.enable_wandb = True
cfg.wandb_project = "torchwm"
cfg.wandb_entity = "your-entity"
```
## Checkpointing

Models are saved automatically during training at the interval set by `checkpoint_interval`. To resume training from a saved checkpoint:

```python
cfg.restore = True
cfg.checkpoint_path = "path/to/checkpoint"
```
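The save/resume cycle that TorchWM automates can be illustrated with a generic stdlib-only pattern. The file format and function names here are hypothetical, not TorchWM's actual checkpoint layout:

```python
import json
from pathlib import Path

def save_checkpoint(path, step, state):
    """Write training progress to disk so a crashed run can resume."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps({"step": step, "state": state}))

def load_checkpoint(path):
    """Return (step, state), or (0, None) if no checkpoint exists yet."""
    path = Path(path)
    if not path.exists():
        return 0, None
    ckpt = json.loads(path.read_text())
    return ckpt["step"], ckpt["state"]
```

A resumed run would call `load_checkpoint` once at startup and continue the training loop from the returned step rather than from zero; in practice, model weights are serialized with a tensor-aware format rather than JSON.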
## Distributed Training

For multi-GPU training:

```python
cfg.num_gpus = 4
# TorchWM handles distributed setup automatically
```
## Best Practices

- **Start small**: Use short episodes and few steps for debugging
- **Monitor metrics**: Watch loss curves and environment rewards
- **Tune hyperparameters**: Adjust learning rates and batch sizes
- **Use checkpoints**: Save frequently and resume from failures
- **Log experiments**: Use WandB or TensorBoard for tracking