Training Guide

This guide covers how to train world models in TorchWM.

Overview

TorchWM trains several world model algorithms (Dreamer, JEPA, and IRIS, covered below) through a unified interface.

Basic Training Flow

  1. Select an algorithm

  2. Override environment, dataset, or optimization parameters

  3. Initialize the agent

  4. Call train() and monitor logs/checkpoints

The simplest path is the top-level torchwm API:

import torchwm

agent = torchwm.create_model(
    "dreamer",
    env_backend="dmc",
    env="walker-walk",
    total_steps=1_000_000,
)
agent.train()

For research code, the lower-level config and agent classes (for example DreamerConfig and DreamerAgent) remain available and can be used directly.
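Conceptually, a top-level factory such as create_model maps an algorithm name to its config and agent classes, applies the keyword overrides to the config, and builds the agent. The sketch below illustrates that pattern with hypothetical stand-ins (ToyConfig, ToyAgent, REGISTRY, create_model_sketch); it is not TorchWM's actual implementation.

```python
from dataclasses import dataclass, fields, replace


@dataclass
class ToyConfig:
    # Hypothetical stand-in for a config class such as DreamerConfig.
    env_backend: str = "dmc"
    env: str = "cartpole-swingup"
    total_steps: int = 100_000


class ToyAgent:
    # Hypothetical stand-in for an agent class such as DreamerAgent.
    def __init__(self, cfg: ToyConfig):
        self.cfg = cfg

    def train(self) -> str:
        return f"training on {self.cfg.env} for {self.cfg.total_steps} steps"


# Hypothetical registry: algorithm name -> (config class, agent class).
REGISTRY = {"toy": (ToyConfig, ToyAgent)}


def create_model_sketch(name: str, **overrides):
    config_cls, agent_cls = REGISTRY[name]
    valid = {f.name for f in fields(config_cls)}
    unknown = set(overrides) - valid
    if unknown:
        # Fail loudly on misspelled fields instead of silently ignoring them.
        raise ValueError(f"unknown config fields: {sorted(unknown)}")
    # Start from the defaults and overwrite only the fields that were passed.
    cfg = replace(config_cls(), **overrides)
    return agent_cls(cfg)


agent = create_model_sketch("toy", env="walker-walk", total_steps=1_000_000)
```

Rejecting unknown override names up front is one reasonable design choice, since a typo such as `evn=` then raises an error rather than training with a default; this guide does not say whether create_model does the same.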

Dreamer Training

Preferred application API:

import torchwm

agent = torchwm.create_model(
    "dreamer",
    env_backend="dmc",
    env="walker-walk",
    total_steps=1_000_000,
)
agent.train()

Equivalent direct API:

from torchwm import DreamerAgent, DreamerConfig

cfg = DreamerConfig()
cfg.env_backend = "dmc"
cfg.env = "walker-walk"
cfg.total_steps = 1_000_000

agent = DreamerAgent(cfg)
agent.train()

JEPA Training

JEPA trains on a dataset rather than in an environment, so pass dataset and epoch settings:

import torchwm

agent = torchwm.create_model(
    "jepa",
    dataset="imagenet",
    batch_size=64,
    epochs=100,
)
agent.train()

IRIS Training

IRISAgent needs constructor arguments such as action_size and device in addition to its config, so pass those as constructor overrides:

import torch
import torchwm

agent = torchwm.create_model(
    "iris",
    env_name="Pong-v5",
    total_epochs=100,
    action_size=4,  # number of discrete actions in the environment
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
agent.train()
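One way a factory can accept both config fields and extra constructor arguments in a single call is to split the keyword overrides: names matching a config field go to the config, and names matching the agent's constructor signature go to the constructor. The sketch below shows that idea with hypothetical stand-ins (ToyIRISConfig, ToyIRISAgent, split_overrides); it illustrates the pattern only and does not reproduce IRISAgent's real signature. The device is a plain string here to keep the sketch dependency-free.

```python
import inspect
from dataclasses import dataclass, fields, replace


@dataclass
class ToyIRISConfig:
    # Hypothetical config fields, loosely mirroring the IRIS example above.
    env_name: str = "Breakout-v5"
    total_epochs: int = 10


class ToyIRISAgent:
    # Hypothetical agent whose constructor needs more than the config.
    def __init__(self, cfg: ToyIRISConfig, action_size: int, device: str = "cpu"):
        self.cfg = cfg
        self.action_size = action_size
        self.device = device


def split_overrides(config_cls, agent_cls, overrides):
    """Partition overrides into config fields and extra constructor kwargs."""
    config_names = {f.name for f in fields(config_cls)}
    # Constructor parameters other than self and the config argument itself.
    ctor_names = set(inspect.signature(agent_cls.__init__).parameters) - {"self", "cfg"}
    cfg_kwargs, ctor_kwargs = {}, {}
    for key, value in overrides.items():
        if key in config_names:
            cfg_kwargs[key] = value
        elif key in ctor_names:
            ctor_kwargs[key] = value
        else:
            raise TypeError(f"unexpected override: {key!r}")
    return cfg_kwargs, ctor_kwargs


cfg_kwargs, ctor_kwargs = split_overrides(
    ToyIRISConfig,
    ToyIRISAgent,
    {"env_name": "Pong-v5", "total_epochs": 100, "action_size": 4, "device": "cpu"},
)
agent = ToyIRISAgent(replace(ToyIRISConfig(), **cfg_kwargs), **ctor_kwargs)
```

If a name were both a config field and a constructor parameter, this sketch would route it to the config; a real factory has to pick some such precedence rule.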

Custom Training Loop

Advanced users can write their own training loop around the agent and a replay buffer:

from torchwm import DreamerAgent, DreamerConfig, ReplayBuffer

cfg = DreamerConfig()
cfg.env_backend = "dmc"
cfg.env = "walker-walk"  # 6-dimensional action space

agent = DreamerAgent(cfg)
memory = ReplayBuffer(
    size=100_000,
    obs_shape=(3, 64, 64),  # channels-first 64x64 RGB frames
    action_size=6,
    seq_len=50,
    batch_size=50,
)

metrics = {}  # defined up front so the log line never hits an unbound name
for step in range(cfg.total_steps):
    # Collect experience
    experience = agent.collect_episode()
    memory.add_episode(experience)

    # Train
    if step % cfg.update_steps == 0:
        batch = memory.sample_batch()
        metrics = agent.update(batch)

    # Log
    if step % cfg.log_every == 0:
        print(f"Step {step}: {metrics}")
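The loop above trains on fixed-length sequences (seq_len) drawn from stored episodes. The sketch below shows one way such a sequence replay buffer can work with plain Python containers: it stores whole episodes up to a step capacity, evicts the oldest episodes first, and samples batch_size windows of seq_len consecutive steps. SequenceBuffer and its internals are hypothetical; this guide does not describe how TorchWM's ReplayBuffer is implemented.

```python
import random
from collections import deque


class SequenceBuffer:
    """Hypothetical episode-based buffer that samples fixed-length windows."""

    def __init__(self, size, seq_len, batch_size, seed=0):
        self.size = size              # maximum number of stored steps
        self.seq_len = seq_len        # length of each sampled sequence
        self.batch_size = batch_size  # number of sequences per batch
        self.episodes = deque()
        self.num_steps = 0
        self.rng = random.Random(seed)

    def add_episode(self, episode):
        # episode: list of per-step transitions (e.g. dicts of obs/action/reward)
        self.episodes.append(list(episode))
        self.num_steps += len(episode)
        # Evict the oldest episodes once the step capacity is exceeded,
        # always keeping at least the newest one.
        while self.num_steps > self.size and len(self.episodes) > 1:
            self.num_steps -= len(self.episodes.popleft())

    def sample_batch(self):
        # Only episodes at least seq_len long can supply a full window.
        candidates = [ep for ep in self.episodes if len(ep) >= self.seq_len]
        if not candidates:
            raise ValueError("no episode is long enough to sample a sequence")
        batch = []
        for _ in range(self.batch_size):
            ep = self.rng.choice(candidates)
            start = self.rng.randint(0, len(ep) - self.seq_len)  # inclusive bounds
            batch.append(ep[start:start + self.seq_len])
        return batch


memory = SequenceBuffer(size=1_000, seq_len=50, batch_size=8)
for ep_id in range(3):
    memory.add_episode([{"episode": ep_id, "t": t} for t in range(100)])
batch = memory.sample_batch()
```

Sampling windows never cross an episode boundary, which keeps each sequence temporally consistent. Picking the episode first and the offset second weights short episodes more heavily per step than uniform step sampling would.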

Configuration

All training is controlled via config objects:

Common Parameters

  • seed: Random seed

  • device: Training device

  • total_steps / epochs: Training duration (Dreamer uses total_steps, JEPA uses epochs, IRIS uses total_epochs)

  • batch_size: Batch size

  • learning_rate: Learning rate

  • grad_clip_norm: Maximum gradient norm used for gradient clipping
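A gradient-norm threshold such as grad_clip_norm is commonly applied by global norm: when the L2 norm of all gradients taken together exceeds the threshold, every gradient is multiplied by the same factor so the combined norm equals the threshold. In PyTorch this is what torch.nn.utils.clip_grad_norm_ does. The dependency-free sketch below shows the arithmetic on plain lists, assuming TorchWM follows the global-norm convention (the guide does not say); clip_by_global_norm is a hypothetical helper.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient vectors so their combined L2 norm <= max_norm."""
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    if total_norm <= max_norm:
        return grads, total_norm
    scale = max_norm / total_norm
    # One shared factor for all gradients preserves the update direction.
    return [[g * scale for g in vec] for vec in grads], total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]  # global norm = sqrt(9 + 16) = 5.0
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
```

Clipping the global norm, rather than each tensor separately, keeps the relative size of gradients across layers intact while bounding the step size.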

Logging

  • enable_wandb: Enable Weights & Biases logging

  • log_dir: TensorBoard log directory

  • checkpoint_interval: How often checkpoints are saved

Environment Setup

DMC

cfg.env_backend = "dmc"
cfg.env = "walker-walk"

Gym

cfg.env_backend = "gym"
cfg.env = "Pendulum-v1"

Brax

cfg.env_backend = "brax"
cfg.env = "ant"
cfg.brax_backend = "generalized"

Unity ML-Agents

cfg.env_backend = "unity_mlagents"
cfg.unity_file_name = "path/to/env.exe"
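Each backend reads a different set of config fields. A small validation step can catch a missing field before any environment is built; the sketch below checks the fields used in the snippets above (env, brax_backend, unity_file_name). The REQUIRED_FIELDS table and validate_env_config helper are hypothetical, and TorchWM may validate its config differently or not at all.

```python
from types import SimpleNamespace

# Hypothetical mapping: backend name -> config fields it needs, based on
# the snippets in this section (not on TorchWM's actual validation rules).
REQUIRED_FIELDS = {
    "dmc": ["env"],
    "gym": ["env"],
    "brax": ["env", "brax_backend"],
    "unity_mlagents": ["unity_file_name"],
}


def validate_env_config(cfg):
    backend = getattr(cfg, "env_backend", None)
    if backend not in REQUIRED_FIELDS:
        raise ValueError(f"unknown env_backend: {backend!r}")
    # A field counts as missing if it is absent, None, or an empty string.
    missing = [name for name in REQUIRED_FIELDS[backend] if not getattr(cfg, name, None)]
    if missing:
        raise ValueError(f"env_backend={backend!r} requires: {', '.join(missing)}")


# SimpleNamespace stands in for a real config object in this sketch.
validate_env_config(SimpleNamespace(env_backend="brax", env="ant", brax_backend="generalized"))
```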

Monitoring Training

TensorBoard

Point TensorBoard at the directory set by log_dir (runs in this example):

tensorboard --logdir runs

Weights & Biases

cfg.enable_wandb = True
cfg.wandb_project = "torchwm"
cfg.wandb_entity = "your-entity"

Checkpointing

Checkpoints are saved automatically every checkpoint_interval. To resume training from a saved checkpoint:

# Resume training
cfg.restore = True
cfg.checkpoint_path = "path/to/checkpoint"
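The save-and-resume pattern behind checkpoint_interval, restore, and checkpoint_path can be sketched with the standard library: save the training state every checkpoint_interval steps and, when restore is set, continue from the saved step. This toy uses JSON files and a plain dict as the state, and the function names are hypothetical; a real PyTorch checkpoint would normally hold model and optimizer state_dict() objects written with torch.save.

```python
import json
import os
import tempfile


def save_checkpoint(path, step, state):
    # Write to a temporary file first, then swap it in with os.replace,
    # so an interrupted write does not clobber the previous checkpoint.
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump({"step": step, "state": state}, f)
    os.replace(tmp_path, path)


def load_checkpoint(path):
    with open(path) as f:
        ckpt = json.load(f)
    return ckpt["step"], ckpt["state"]


def train(total_steps, checkpoint_interval, checkpoint_path, restore=False):
    step, state = 0, {"updates": 0}
    if restore and os.path.exists(checkpoint_path):
        step, state = load_checkpoint(checkpoint_path)
    while step < total_steps:
        state["updates"] += 1  # stand-in for one training update
        step += 1
        if step % checkpoint_interval == 0:
            save_checkpoint(checkpoint_path, step, state)
    return step, state


with tempfile.TemporaryDirectory() as run_dir:
    path = os.path.join(run_dir, "checkpoint.json")
    # First run stops at step 250; its last save happened at step 200.
    train(total_steps=250, checkpoint_interval=100, checkpoint_path=path)
    resumed_step, _ = load_checkpoint(path)
    final_step, final_state = train(
        total_steps=500, checkpoint_interval=100, checkpoint_path=path, restore=True
    )
```

Work done after the last save (steps 201 to 250 here) is repeated on resume, so a shorter checkpoint_interval loses less progress at the cost of more disk I/O.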

Distributed Training

For multi-GPU training, set num_gpus:

cfg.num_gpus = 4
# TorchWM handles distributed setup automatically

Best Practices

  1. Start small: Use short episodes and few steps for debugging

  2. Monitor metrics: Watch loss curves and environment rewards

  3. Tune hyperparameters: Adjust learning rates and batch sizes

  4. Use checkpoints: Save frequently and resume from failures

  5. Log experiments: Use Weights & Biases or TensorBoard for tracking