Dreamer: Model-Based RL with Latent Dynamics#
Dreamer is a model-based reinforcement learning algorithm that learns a latent dynamics model from images and trains a behavior policy entirely in the latent space.
Based on papers:
Dream to Control: Learning Behaviors by Latent Imagination (DreamerV1, Hafner et al., 2020)
Mastering Atari with Discrete World Models (DreamerV2, Hafner et al., 2021)
Overview#
Dreamer learns a world model from image observations, then trains an actor-critic policy entirely in the imagination of that world model. The policy is never trained on real transitions directly: its gradients come only from imagined rollouts, so the world model is the only bridge between real experience and learned behavior.
The family has two major versions, documented individually below.
DreamerV1#
Theory#
Recurrent State-Space Model (RSSM) with Gaussian Latents#
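DreamerV1’s RSSM maintains a hybrid state with two components:
1. Deterministic state h_t: a GRU hidden state that captures temporal dependencies and deterministic transitions:
$$h_t = f_\phi(h_{t-1}, s_{t-1}, a_{t-1})$$
2. Stochastic state s_t: a diagonal Gaussian latent variable with stoch_size means and variances, representing uncertainty (d = stoch_size):
$$s_t \sim \mathcal{N}\big(\mu_t, \operatorname{diag}(\sigma_t^2)\big), \qquad \mu_t, \sigma_t \in \mathbb{R}^{d}$$
The model operates in two modes:
Observe mode (training, uses real observations): the posterior combines h_t with the current image o_t.
$$s_t \sim q_\phi(s_t \mid h_t, o_t)$$
Imagine mode (policy training, no observations): the prior predicts the stochastic state from h_t alone, and its samples are fed back into the GRU.
$$\hat{s}_t \sim p_\phi(\hat{s}_t \mid h_t)$$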
World Model Loss#
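The complete world model objective (V1), summed over a training sequence of length T:
$$\mathcal{L}_{\text{model}} = \mathbb{E}_{q_\phi}\Big[\sum_{t=1}^{T} \Big( -\ln p_\phi(o_t \mid h_t, s_t) - \ln p_\phi(r_t \mid h_t, s_t) + \beta\, \mathrm{KL}\big[q_\phi(s_t \mid h_t, o_t) \,\|\, p_\phi(s_t \mid h_t)\big] \Big)\Big]$$
Where:
-ln p(o_t | h_t, s_t) is the image reconstruction loss of the decoder
-ln p(r_t | h_t, s_t) is the reward prediction loss
the KL term pulls the posterior and the prior together, weighted by β (kl_loss_coeff)
V1 applies a single KL coefficient β without balancing.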
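The KL term has a closed form for diagonal Gaussians. The pure-Python sketch below computes it and combines the three per-step terms. It assumes the reconstruction and reward negative log-likelihoods are already computed, and it applies a free-nats floor because the hyperparameter tables list `free_nats = 3.0` for V1 as well; `gaussian_kl` and `world_model_loss_v1` are hypothetical helpers, not TorchWM API.

```python
import math


def gaussian_kl(mu_q, std_q, mu_p, std_p):
    """KL(q || p) between two diagonal Gaussians, summed over latent dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(mu_q, std_q, mu_p, std_p):
        total += math.log(sp / sq) + (sq**2 + (mq - mp) ** 2) / (2 * sp**2) - 0.5
    return total


def world_model_loss_v1(recon_nll, reward_nll, kl, beta=1.0, free_nats=3.0):
    """One time step of the V1 objective: -ln p(o) - ln p(r) + beta * KL.

    The KL is floored at free_nats (an assumption mirroring the free_nats
    hyperparameter); in an autodiff framework, KL values below the floor
    contribute no gradient.
    """
    return recon_nll + reward_nll + beta * max(kl, free_nats)


# Posterior N(1, 1) vs prior N(0, 1) in one dimension.
kl = gaussian_kl([1.0], [1.0], [0.0], [1.0])
print(kl)                                   # 0.5
print(world_model_loss_v1(10.0, 1.0, kl))  # 14.0: the 0.5-nat KL is raised to the 3-nat floor
```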
Actor-Critic in Imagination#
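DreamerV1 rolls out imagined trajectories of horizon H (imagine_horizon) from posterior states using the prior dynamics, and trains the actor and critic purely in latent space:
Actor loss: the DreamerV1 paper maximizes λ-returns by backpropagating analytic value gradients through the learned dynamics, rather than with a REINFORCE estimator:
$$\mathcal{L}_{\text{actor}} = -\mathbb{E}\Big[\sum_{\tau=t}^{t+H} V^\lambda_\tau\Big]$$
Critic loss (regression onto the λ-return, with a stop-gradient sg on the target):
$$\mathcal{L}_{\text{critic}} = \mathbb{E}\Big[\sum_{\tau=t}^{t+H} \tfrac{1}{2}\big(v_\psi(\hat{s}_\tau) - \operatorname{sg}(V^\lambda_\tau)\big)^2\Big]$$
Lambda return (with fixed γ), computed backwards from the bootstrap V^λ_{t+H} = v_ψ(ŝ_{t+H}):
$$V^\lambda_\tau = \hat{r}_\tau + \gamma\big[(1-\lambda)\, v_\psi(\hat{s}_{\tau+1}) + \lambda\, V^\lambda_{\tau+1}\big]$$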
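The λ-return is evaluated backwards from a bootstrap value at the end of the imagined horizon. A minimal pure-Python sketch for a single imagined trajectory follows, together with the critic's regression loss; the function names are illustrative, and a real implementation would operate on batched tensors.

```python
def lambda_returns(rewards, values, gamma=0.99, lam=0.95):
    """Lambda-returns for one imagined trajectory of length H.

    rewards: H predicted rewards r_0 .. r_{H-1}
    values:  H + 1 critic values v(s_0) .. v(s_H); v(s_H) bootstraps the tail
    """
    horizon = len(rewards)
    if len(values) != horizon + 1:
        raise ValueError("values must have one more entry than rewards")
    returns = [0.0] * horizon
    next_return = values[horizon]  # V^lambda at the final step is the critic value
    for t in reversed(range(horizon)):
        next_return = rewards[t] + gamma * ((1 - lam) * values[t + 1] + lam * next_return)
        returns[t] = next_return
    return returns


def critic_loss(values, targets):
    """Mean of 0.5 * (v(s_t) - V^lambda_t)^2; targets act as constants (stop-gradient)."""
    return sum(0.5 * (v - g) ** 2 for v, g in zip(values, targets)) / len(targets)


# With gamma = lam = 1 the lambda-return is the sum of future rewards plus the bootstrap.
print(lambda_returns([1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 2.0], gamma=1.0, lam=1.0))  # [5.0, 4.0, 3.0]
```

Setting lam to 0 recovers one-step TD targets, and lam = 1 recovers Monte Carlo returns with a bootstrap at the horizon; 0.95 trades variance against bias between the two.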
Examples#
import torchwm
agent = torchwm.create_model(
"dreamer",
env_backend="dmc",
env="walker-walk",
total_steps=5_000_000,
)
agent.train()
Explicit V1 config:
from torchwm import DreamerAgent, DreamerConfig
cfg = DreamerConfig()
# Select DreamerV1
cfg.algo = "Dreamerv1"
# Gaussian latent (V1 default)
cfg.stoch_size = 30 # diagonal Gaussian dimensions
cfg.deter_size = 200
# Environment
cfg.env_backend = "dmc"
cfg.env = "walker-walk"
cfg.total_steps = 5_000_000
# KL (single coefficient)
cfg.kl_loss_coeff = 1.0
agent = DreamerAgent(cfg)
agent.train()
torchwm train dreamer --env dmc/walker-walk --algo Dreamerv1 --device cuda
DreamerV2#
Theory#
Recurrent State-Space Model (RSSM) with Categorical Latents#
DreamerV2’s RSSM keeps the same hybrid state structure as V1 but replaces Gaussian latents with discrete categorical latents:
Stochastic state s_t: a concatenation of num_categories one-hot categorical distributions, each with stoch_size classes:
import torch
import torch.nn.functional as F
idx = torch.distributions.Categorical(logits=logits).sample()  # logits: (..., num_categories, classes)
self.stoch = F.one_hot(idx, num_classes=logits.shape[-1]).float().flatten(start_dim=-2)  # e.g. 32 × 32 = 1024 dims
Default: 32 categoricals × 32 classes = 1024 total latent dimensions. Discrete latents can represent multimodal distributions more naturally than a single Gaussian, and the DreamerV2 paper reports that they outperform Gaussian latents on Atari.
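The one-hot sampling can also be sketched without a deep-learning framework. This pure-Python version assumes the logits arrive as a list of `num_categories` lists, each holding `stoch_size` class logits. It omits the straight-through gradient estimator that DreamerV2 uses to backpropagate through the discrete sample, and `sample_categorical_latent` is a hypothetical name.

```python
import math
import random


def sample_categorical_latent(logits, rng=random):
    """Draw one class per categorical and return the flattened one-hot vector.

    logits: num_categories lists of `classes` unnormalized log-probabilities.
    Returns a flat list of length num_categories * classes with one 1.0 per group.
    """
    flat = []
    for group in logits:
        peak = max(group)
        weights = [math.exp(x - peak) for x in group]  # softmax numerators, shifted for stability
        k = rng.choices(range(len(group)), weights=weights)[0]
        flat.extend(1.0 if i == k else 0.0 for i in range(len(group)))
    return flat


# 32 categoricals x 32 classes with uniform logits.
z = sample_categorical_latent([[0.0] * 32 for _ in range(32)], random.Random(0))
print(len(z), sum(z))  # 1024 32.0
```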
World Model Loss with KL Balancing#
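V2 introduces KL balancing: separate weights for the two directions of the KL term, using stop-gradient (sg) so that each direction trains only one side:
$$\mathcal{L}_{\text{KL}} = \alpha\, \mathrm{KL}\big[\operatorname{sg}(q_\phi) \,\|\, p_\phi\big] + (1-\alpha)\, \mathrm{KL}\big[q_\phi \,\|\, \operatorname{sg}(p_\phi)\big]$$
where α (kl_alpha, default 0.8) puts more weight on the term that trains the prior toward the posterior. The prior therefore learns quickly from the posterior, while the posterior is only weakly regularized toward a prior that may still be poorly trained.
Free nats: a threshold (free_nats, default 3 nats) below which the KL is not penalized, i.e. the KL is replaced by max(KL, free_nats).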
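In a pure-Python setting the stop-gradients do not change the loss value: both directions evaluate to the same KL, and sg only decides whether the prior or the posterior receives that term's gradient. The sketch below makes this explicit for categorical latents and applies the free-nats floor to each term separately, which is an assumption about where the floor sits; the helper names are hypothetical.

```python
import math


def categorical_kl(q_probs, p_probs):
    """KL(q || p) for one categorical distribution given as probability lists."""
    return sum(q * math.log(q / p) for q, p in zip(q_probs, p_probs) if q > 0)


def balanced_kl(q_groups, p_groups, alpha=0.8, free_nats=3.0):
    """Forward value and gradient routing of KL balancing over several categoricals.

    alpha * KL(sg(q) || p):       only the prior would receive this gradient.
    (1 - alpha) * KL(q || sg(p)): only the posterior would receive this gradient.
    """
    kl = sum(categorical_kl(q, p) for q, p in zip(q_groups, p_groups))
    prior_term = max(kl, free_nats)
    posterior_term = max(kl, free_nats)
    active = kl > free_nats  # below the floor, max() passes no gradient at all
    return {
        "loss": alpha * prior_term + (1 - alpha) * posterior_term,
        "prior_grad_weight": alpha if active else 0.0,
        "posterior_grad_weight": (1 - alpha) if active else 0.0,
    }


posterior = [[0.5, 0.5]] * 32
prior = [[0.9, 0.1]] * 32
print(balanced_kl(posterior, prior))
```

The loss value equals the plain KL (or the floor); what α changes is only how strongly each network is pushed, which is why the effect is invisible in the loss curve itself.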
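The full world model objective (V2) adds a discount term and uses the balanced KL, where c_t is the episode-continuation flag predicted by the discount head:
$$\mathcal{L}_{\text{model}} = \mathbb{E}_{q_\phi}\Big[\sum_{t=1}^{T} \Big( -\ln p_\phi(o_t \mid h_t, s_t) - \ln p_\phi(r_t \mid h_t, s_t) - \ln p_\phi(c_t \mid h_t, s_t) + \beta\, \mathcal{L}_{\text{KL},t} \Big)\Big]$$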
Discount Head#
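V2 adds a learned discount (termination) head that predicts the episode continuation probability γ̂_t, trained with binary cross-entropy against the continuation flag c_t (1 while the episode continues, 0 at termination):
$$\mathcal{L}_{\text{disc}} = -\big[c_t \ln \hat{\gamma}_t + (1 - c_t) \ln(1 - \hat{\gamma}_t)\big]$$
This matters for Atari, where episodes can end abruptly (for example on life loss). Imagined rollouts never see the environment's done flags, so the world model must predict termination itself instead of assuming every trajectory continues.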
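The discount head's loss and its use in imagination fit in a few lines. This sketch assumes continuation targets of 1 (episode continues) and 0 (termination), and assumes the predicted probability scales the fixed γ; the helper names are hypothetical.

```python
import math


def continuation_bce(pred_continue, target_continue, eps=1e-7):
    """Binary cross-entropy between a predicted continuation probability and a 0/1 target."""
    p = min(max(pred_continue, eps), 1.0 - eps)  # clamp so log() never sees 0
    return -(target_continue * math.log(p) + (1 - target_continue) * math.log(1.0 - p))


def imagined_discount(pred_continue, gamma=0.997):
    """Per-step discount for imagination: the fixed gamma scaled by predicted continuation."""
    return gamma * pred_continue


print(continuation_bce(0.5, 1))  # ln 2, about 0.693
print(imagined_discount(0.0))    # 0.0: a predicted termination cuts off all future value
```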
Architecture Improvements#
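Layer normalization in the GRU and MLP layers for training stability
SiLU activations replace ELU throughout
Two-hot reward encoding replaces MSE: rewards are discretized into 255 bins and the reward head predicts a softmax distribution over those bins
SiLU activations and two-hot reward targets were introduced in DreamerV3; the DreamerV2 paper itself uses ELU activations and an MSE reward loss.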
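Two-hot encoding spreads a scalar target over its two neighboring bins so that the expected bin value reproduces the scalar exactly. The sketch below assumes 255 evenly spaced bins on a hypothetical range of [-20, 20] (DreamerV3 places its bins in symlog space; this sketch does not) and scores a logit vector with softmax cross-entropy. All function names are illustrative.

```python
import math


def bin_centers(low=-20.0, high=20.0, num_bins=255):
    """Evenly spaced bin centers (hypothetical range and spacing)."""
    step = (high - low) / (num_bins - 1)
    return [low + i * step for i in range(num_bins)]


def two_hot(value, low=-20.0, high=20.0, num_bins=255):
    """Split a clipped scalar between its two nearest bins; weights sum to 1."""
    value = min(max(value, low), high)
    step = (high - low) / (num_bins - 1)
    pos = (value - low) / step                # fractional bin index
    lower = min(int(pos), num_bins - 2)       # left neighbor, keeps lower + 1 in range
    frac = min(max(pos - lower, 0.0), 1.0)    # weight on the right neighbor
    target = [0.0] * num_bins
    target[lower] = 1.0 - frac
    target[lower + 1] = frac
    return target


def softmax_cross_entropy(logits, target):
    """-sum_i target_i * log softmax(logits)_i, using a stable log-sum-exp."""
    peak = max(logits)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return -sum(t * (x - log_z) for x, t in zip(logits, target))


target = two_hot(1.234)
expected = sum(w * c for w, c in zip(target, bin_centers()))
print(round(expected, 6))  # 1.234
```

Compared with MSE, the cross-entropy loss is insensitive to the reward's scale, which is the usual motivation for this encoding.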
Actor-Critic in Imagination#
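Same structure as V1, but the λ-return uses the learned discount, scaling γ by the predicted continuation probability γ̂_τ:
$$V^\lambda_\tau = \hat{r}_\tau + \gamma\, \hat{\gamma}_\tau \big[(1-\lambda)\, v_\psi(\hat{s}_{\tau+1}) + \lambda\, V^\lambda_{\tau+1}\big]$$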
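Replacing the fixed γ with per-step discounts changes a single factor in the λ-return recursion. In the sketch below each entry of `discounts` is assumed to already be γ times the predicted continuation probability; a predicted termination (discount 0) stops value from flowing back past that step. The function name is hypothetical.

```python
def lambda_returns_learned(rewards, values, discounts, lam=0.95):
    """Lambda-returns with a separate discount for every imagined step.

    rewards, discounts: length H (discounts[t] = gamma * predicted continuation)
    values: length H + 1; values[H] bootstraps the tail.
    """
    horizon = len(rewards)
    returns = [0.0] * horizon
    next_return = values[horizon]
    for t in reversed(range(horizon)):
        bootstrap = (1 - lam) * values[t + 1] + lam * next_return
        next_return = rewards[t] + discounts[t] * bootstrap
        returns[t] = next_return
    return returns


# The model predicts the episode ends at step 1 (discount 0), so step 1 keeps only its own reward.
print(lambda_returns_learned([1.0, 1.0, 1.0], [0.0, 5.0, 5.0, 5.0], [0.99, 0.0, 0.99]))
```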
Examples#
import torchwm
agent = torchwm.create_model(
"dreamer",
env_backend="atari",
env="PongNoFrameskip-v4",
algo="Dreamerv2",
total_steps=10_000_000,
)
agent.train()
Explicit V2 config:
from torchwm import DreamerAgent, DreamerConfig
cfg = DreamerConfig()
# Select DreamerV2
cfg.algo = "Dreamerv2"
# Categorical latent (V2)
cfg.stoch_size = 32  # classes per categorical distribution
cfg.num_categories = 32 # number of categorical distributions
cfg.deter_size = 200
# Environment
cfg.env_backend = "atari"
cfg.env = "PongNoFrameskip-v4"
cfg.total_steps = 10_000_000
# KL balancing
cfg.kl_alpha = 0.8
cfg.free_nats = 3.0
# Discount (V2 uses learned termination)
cfg.discount = 0.997
agent = DreamerAgent(cfg)
agent.train()
torchwm train dreamer --env atari/PongNoFrameskip-v4 --algo Dreamerv2 --device cuda
Differences Between DreamerV1 and DreamerV2#
| Aspect | DreamerV1 | DreamerV2 |
|---|---|---|
| Latent type | Gaussian (continuous) | One-hot categorical (discrete) |
| Stochastic state | stoch_size-dimensional diagonal Gaussian (default 30) | num_categories one-hot vectors of stoch_size classes (default 32 × 32) |
| KL formulation | Single coefficient β | KL balancing with α (kl_alpha) |
| Discount | Fixed γ | Learned termination predictor |
| Reward loss | MSE | Two-hot discretized cross-entropy |
| Activations | ELU | SiLU |
| Normalization | None in GRU/MLP | LayerNorm in GRU and MLP layers |
| Atari performance | ~40% human-normalized score | ~100% human-normalized score |
| Key advantage | Simpler, fewer hyperparameters | Better on complex/discrete environments |
Categorical vs Gaussian Latents#
V1 uses a diagonal Gaussian for the stochastic state. V2 uses a concatenation of one-hot categorical distributions:
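import torch
import torch.nn.functional as F
# V1: one diagonal Gaussian, sampled with the reparameterization trick
self.stoch = torch.distributions.Normal(mean, std).rsample()  # (..., stoch_size)
# V2: stack of categoricals; logits shape (..., num_categories, classes)
idx = torch.distributions.Categorical(logits=logits).sample()
self.stoch = F.one_hot(idx, num_classes=logits.shape[-1]).float().flatten(start_dim=-2)
Discrete latents can represent multimodal posteriors (e.g., “the robot could be at door A or door B”) more naturally than a single diagonal Gaussian, and may be less prone to posterior collapse.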
KL Balancing#
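| Formulation | V1 | V2 |
|---|---|---|
| KL loss | β · KL(q ‖ p) | α · KL(sg(q) ‖ p) + (1 − α) · KL(q ‖ sg(p)) |
| Stop-gradient | None | On the posterior in the α term (trains the prior); on the prior in the (1 − α) term (trains the posterior) |
| Effect | Single trade-off | Prior is pulled toward the posterior faster than the posterior is regularized toward the prior |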
Discount Head#
|  | V1 | V2 |
|---|---|---|
| Discount | Fixed scalar | Learned |
| Purpose | Simple time discount | Model episode termination (life loss in Atari) |
Usage in TorchWM#
Using config directly#
from torchwm import DreamerAgent, DreamerConfig
cfg = DreamerConfig()
cfg.env_backend = "dmc"
cfg.env = "walker-walk"
cfg.total_steps = 5_000_000
agent = DreamerAgent(cfg)
agent.train()
Core parameters#
| Parameter | Default | Description |
|---|---|---|
| stoch_size | 30 | Stochastic latent dimensions |
| deter_size | 200 | Deterministic hidden size |
| obs_embed_size | 1024 | Encoder embedding size |
| imagine_horizon | 15 | Imagination rollout length |
| discount | 0.99 | Discount factor γ |
| td_lambda | 0.95 | λ-return parameter |
| kl_loss_coeff | 1.0 | KL divergence weight |
Learning Objectives#
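World Model Loss (reconstruction, reward, and KL terms; V2 adds the discount term and KL balancing):
$$\mathcal{L}_{\text{model}} = \mathbb{E}\Big[\sum_t \Big(-\ln p_\phi(o_t \mid h_t, s_t) - \ln p_\phi(r_t \mid h_t, s_t) + \beta\,\mathrm{KL}\big[q_\phi(s_t \mid h_t, o_t) \,\|\, p_\phi(s_t \mid h_t)\big]\Big)\Big]$$
Actor Loss (REINFORCE, with the critic as baseline):
$$\mathcal{L}_{\text{actor}} = -\mathbb{E}\Big[\sum_\tau \ln \pi_\theta(\hat{a}_\tau \mid \hat{s}_\tau)\, \operatorname{sg}\big(V^\lambda_\tau - v_\psi(\hat{s}_\tau)\big)\Big]$$
Critic Loss (MSE):
$$\mathcal{L}_{\text{critic}} = \mathbb{E}\Big[\sum_\tau \tfrac{1}{2}\big(v_\psi(\hat{s}_\tau) - \operatorname{sg}(V^\lambda_\tau)\big)^2\Big]$$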
DreamerV2 Enhancements#
DreamerV2 introduces several improvements:
Discrete latents: Categorical latent variables instead of Gaussian
KL balancing: Separate weighting for prior/posterior KL
Discount model: Learns to predict episode termination
Layer normalization: More stable training
Configuration and Checkpoints#
Dreamer configs are serializable, so experiments can be reproduced from the YAML saved with each run or checkpoint:
from world_models.configs import DreamerConfig
from world_models.models import DreamerAgent
cfg = DreamerConfig()
cfg.env = "walker-walk"
cfg.to_yaml("configs/dreamer_walker.yaml")
agent = DreamerAgent.from_config("configs/dreamer_walker.yaml", seed=7)
agent.train()
# Checkpoints save `config.yaml` beside the weights automatically.
agent.dreamer.save("runs/walker/ckpts/model.pt")
restored = DreamerAgent.from_pretrained("runs/walker/ckpts")
print(restored.summary()["total_parameters"])
For lower-level workflows, the core Dreamer class also supports
Dreamer.from_config(...), Dreamer.from_pretrained(...),
Dreamer.summary(), and Dreamer.parameter_count().
Environment Support#
Dreamer supports multiple backends:
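from world_models.configs import DreamerConfig
cfg = DreamerConfig()
# Pick ONE backend/env pair below; each assignment overrides the previous one.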
cfg.env_backend = "dmc" # DeepMind Control Suite
cfg.env = "walker-walk"
cfg.env_backend = "gym" # Gym/Gymnasium
cfg.env = "Pendulum-v1"
cfg.env_backend = "dmlab" # DeepMind Lab
cfg.env = "rooms_collect_good_objects_train"
cfg.dmlab_action_repeat = 4
cfg.env_backend = "mujoco" # MuJoCo
cfg.env = "Humanoid-v4"
cfg.env_backend = "brax" # JAX/Brax
cfg.env = "ant"
cfg.env_backend = "procgen" # Procgen
cfg.env = "coinrun"
cfg.env_backend = "unity_mlagents" # Unity ML-Agents
cfg.unity_file_name = "env.exe"
CLI#
torchwm train dreamer --env dmc/walker-walk --device cuda
Config Reference#
All configuration is in world_models.configs.dreamer_config.DreamerConfig:
from world_models.configs.dreamer_config import DreamerConfig
config = DreamerConfig()
# Dreamer version
config.algo = "Dreamerv1" # or "Dreamerv2" (default: "Dreamerv1")
# Environment
config.env_backend = "dmc"
config.env = "walker-walk"
config.image_size = (64, 64)
# Model architecture
config.stoch_size = 30
config.deter_size = 200
config.obs_embed_size = 1024
# Training
config.total_steps = 5_000_000
config.batch_size = 50
config.train_seq_len = 50
config.imagine_horizon = 15
config.model_learning_rate = 6e-4
# Actor-critic
config.actor_learning_rate = 8e-5
config.value_learning_rate = 8e-5
config.discount = 0.99
config.td_lambda = 0.95
# KL: kl_alpha is V2-only; free_nats is used by both versions
config.kl_alpha = 0.8
config.free_nats = 3.0
# Exploration
config.action_noise = 0.3
# Logging
config.scalar_freq = 10_000
config.checkpoint_interval = 100_000
config.enable_wandb = False
Key Hyperparameters#
World Model#
| Parameter | V1 Default | V2 Default | Effect |
|---|---|---|---|
| stoch_size (with num_categories in V2) | 30 | 32 classes × 32 categoricals | Total stochastic capacity |
| deter_size | 200 | 200 | GRU hidden size |
| model_learning_rate | 6e-4 | 3e-4 | World model learning rate |
| train_seq_len | 50 | 50 | Sequence length per batch |
| batch_size | 50 | 16 | Sequences per batch |
| free_nats | 3.0 | 3.0 | KL free bits threshold |
| kl_loss_coeff | 1.0 | 1.0 | KL loss coefficient |
| kl_alpha | n/a | 0.8 | KL balancing weight (V2 only) |
Actor-Critic#
| Parameter | V1 Default | V2 Default | Effect |
|---|---|---|---|
| actor_learning_rate | 8e-5 | 8e-5 | Policy learning rate |
| value_learning_rate | 8e-5 | 8e-5 | Critic learning rate |
| imagine_horizon | 15 | 15 | Imagination rollout length |
| discount | 0.99 | 0.997 | Discount factor |
| td_lambda | 0.95 | 0.95 | λ-return parameter |
Environment Interaction#
| Parameter | Default | Effect |
|---|---|---|
|  | 2 | Repeat each action N times |
| action_noise | 0.3 | Exploration noise std |
|  | 5000 | Random steps before training |
| total_steps | 5e6 | Total environment steps |
|  | 1000 | Steps between model updates |
Common Pitfalls#
Posterior collapse#
If the stochastic state is ignored by the dynamics, the model reduces to a deterministic RNN. Symptoms: good reconstruction but imagination diverges.
Fixes:
Increase kl_loss_coeff or adjust kl_alpha (V2)
Decrease free_nats
Reduce stoch_size
Imagination divergence#
The prior predicts states that drift from realistic latents over long horizons.
Fixes:
Keep imagine_horizon short (10–15)
Verify multi-step prediction, not just one-step
NaN loss during training#
Fixes:
Reduce model_learning_rate to 1e-4
Tighten gradient clipping (default 100 → 10)
Enable layer norm
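Clipping by global norm rescales every gradient by the same factor whenever their combined L2 norm exceeds the threshold, so lowering the threshold from 100 to 10 caps the size of each update without changing its direction. A pure-Python sketch of the idea follows; in PyTorch the corresponding utility is `torch.nn.utils.clip_grad_norm_`, which likewise returns the norm measured before clipping. The helper name here is illustrative.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Return (clipped_grads, total_norm) for a list of gradient vectors.

    All vectors share one scale factor, so their relative directions are preserved.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    if total_norm <= max_norm:
        return [list(vec) for vec in grads], total_norm
    scale = max_norm / total_norm
    return [[g * scale for g in vec] for vec in grads], total_norm


clipped, norm = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
print(norm)  # 5.0
```

Logging the returned pre-clipping norm is a cheap way to spot the gradient spikes that usually precede NaN losses.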
Actor never improves#
Fixes:
Increase imagine_horizon for delayed rewards
Increase exploration noise via action_noise
Verify critic loss is decreasing
References#
Hafner, D., Lillicrap, T., Fischer, I., Villegas, R., Ha, D., Lee, H., & Davidson, J. (2019). Learning Latent Dynamics for Planning from Pixels (PlaNet). ICML 2019.
Hafner, D., Lillicrap, T., Ba, J., & Norouzi, M. (2020). Dream to Control: Learning Behaviors by Latent Imagination (DreamerV1). ICLR 2020.
Hafner, D., Lillicrap, T., Norouzi, M., & Ba, J. (2021). Mastering Atari with Discrete World Models (DreamerV2). ICLR 2021.