Dreamer: Model-Based RL with Latent Dynamics

Dreamer is a model-based reinforcement learning algorithm that learns a latent dynamics model from images and trains a behavior policy entirely in the latent space.

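Based on the papers Dream to Control: Learning Behaviors by Latent Imagination (Dreamer) and Mastering Atari with Discrete World Models (DreamerV2), both by Hafner et al.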

Key Idea

Dreamer learns:

  1. World Model: a latent dynamics model that predicts future latent states (and rewards) from past states and actions

  2. Value Model: estimates the expected return from any latent state

  3. Policy: selects actions that maximize expected returns in latent space

The key innovation is learning behaviors purely in imagination: the actor and critic are trained only on trajectories imagined by the world model, so behavior learning requires no additional environment interaction.

Architecture

  • World model (RSSM): CNN encoder (64x64 images) → GRU plus stochastic latent model → transposed-CNN decoder

  • Imagination rollout: state s0, action a0 → imagined future states → λ-return targets

  • Actor-critic learning: actor (policy) and critic (value model)
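The imagination rollout stage can be sketched as a loop that never calls the environment: starting from a latent state, the policy picks an action, a transition model (standing in for the RSSM prior) predicts the next state, and a reward model scores it. This is a minimal sketch with hypothetical names (imagine, plus toy 1-D lambdas in place of trained networks), not TorchWM's implementation. In Dreamer, rollouts start from posterior states of replayed sequences and run for imagine_horizon steps.

```python
def imagine(start_state, policy, transition, reward_model, horizon):
    """Roll the policy forward through the learned model only.

    No environment step is taken: `transition` plays the role of the RSSM prior
    and `reward_model` the role of the reward head.
    """
    states, actions, rewards = [start_state], [], []
    state = start_state
    for _ in range(horizon):
        action = policy(state)
        state = transition(state, action)
        states.append(state)
        actions.append(action)
        rewards.append(reward_model(state))
    return states, actions, rewards


# Toy 1-D stand-ins for the learned networks (hypothetical, illustration only).
states, actions, rewards = imagine(
    start_state=1.0,
    policy=lambda s: -0.5 * s,
    transition=lambda s, a: 0.9 * s + a,
    reward_model=lambda s: -s * s,
    horizon=15,
)
```

The imagined rewards and the critic's values on the imagined states are what the λ-return targets are computed from.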

Components

1. Recurrent State Space Model (RSSM)

The core world model combining:

  • Deterministic hidden state (h_t): Recurrent state (GRU)

  • Stochastic latent state (s_t): Discrete or continuous latent variables

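Dynamics (deterministic transition):

\[\mathbf{h}_t = f(\mathbf{h}_{t-1}, \mathbf{s}_{t-1}, \mathbf{a}_{t-1})\]

Posterior (representation model, conditioned on the observation x_t):

\[\mathbf{s}_t \sim q(\mathbf{s}_t \mid \mathbf{h}_t, \mathbf{x}_t)\]

Prior (transition predictor, used when no observation is available, e.g. during imagination):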

\[\mathbf{s}_t \sim p(\mathbf{s}_t \mid \mathbf{h}_t)\]
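The three equations above can be made concrete with a minimal NumPy sketch of a single RSSM step. It assumes a standard GRU cell for f, single linear layers for the prior and posterior heads, random untrained weights, and a Gaussian standard deviation parameterized as softplus plus a minimum of 0.1; ToyRSSM and dense are hypothetical names, not TorchWM's API. Sizes follow the stoch_size, deter_size and embed_size defaults, with a 1-D action as in Pendulum-v1.

```python
import numpy as np

rng = np.random.default_rng(0)


def dense(in_dim, out_dim):
    # Random weights stand in for trained parameters (illustration only).
    weight = rng.normal(0.0, 0.1, size=(in_dim, out_dim))
    bias = np.zeros(out_dim)
    return lambda x: x @ weight + bias


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def softplus(x):
    return np.logaddexp(0.0, x)  # log(1 + e^x), numerically stable


class ToyRSSM:
    """One RSSM step: GRU transition, Gaussian prior p(s_t|h_t), posterior q(s_t|h_t, x_t)."""

    def __init__(self, stoch=30, deter=200, action=1, embed=1024, min_std=0.1):
        inp = stoch + action
        self.min_std = min_std
        # GRU gates acting on x = [s_{t-1}, a_{t-1}] and h_{t-1}.
        self.update_gate = dense(inp + deter, deter)
        self.reset_gate = dense(inp + deter, deter)
        self.cand_x = dense(inp, deter)
        self.cand_h = dense(deter, deter)
        # Each head outputs a mean and a raw std per latent dimension.
        self.prior_head = dense(deter, 2 * stoch)
        self.post_head = dense(deter + embed, 2 * stoch)

    def gru(self, h_prev, x):
        xh = np.concatenate([x, h_prev])
        z = sigmoid(self.update_gate(xh))
        r = sigmoid(self.reset_gate(xh))
        n = np.tanh(self.cand_x(x) + r * self.cand_h(h_prev))
        return (1.0 - z) * n + z * h_prev

    def gaussian_sample(self, stats):
        mean, raw_std = np.split(stats, 2)
        std = softplus(raw_std) + self.min_std
        return mean + std * rng.standard_normal(mean.shape)

    def step(self, h_prev, s_prev, a_prev, embed=None):
        # h_t = f(h_{t-1}, s_{t-1}, a_{t-1})
        h = self.gru(h_prev, np.concatenate([s_prev, a_prev]))
        # Prior sample s_t ~ p(s_t | h_t): all that is available during imagination.
        prior_s = self.gaussian_sample(self.prior_head(h))
        # Posterior sample s_t ~ q(s_t | h_t, x_t): needs the encoder embedding of x_t.
        post_s = None
        if embed is not None:
            post_s = self.gaussian_sample(self.post_head(np.concatenate([h, embed])))
        return h, prior_s, post_s


rssm = ToyRSSM()
h, prior_s, post_s = rssm.step(
    np.zeros(200), np.zeros(30), np.zeros(1), embed=rng.standard_normal(1024)
)
```

While learning the world model, the posterior sample is carried forward to the next step; during imagination there is no x_t, so the prior sample is used instead.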

2. Encoder/Decoder

  • Encoder: CNN that maps images to latent embeddings

  • Decoder: Transposed CNN that reconstructs images from latents

  • Both use ReLU activations and residual connections
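The embed_size default of 1024 is consistent with a DreamerV1-style encoder. Assuming four convolutions with kernel 4, stride 2, no padding and depths 32, 64, 128, 256 (the configuration described in the Dreamer paper; TorchWM's encoder, which also uses residual connections, may differ), the spatial size shrinks 64 → 31 → 14 → 6 → 2, and flattening gives 2 × 2 × 256 features. The helper names below are hypothetical.

```python
def conv_out(size, kernel=4, stride=2, padding=0):
    # Output size of one convolution along a spatial axis.
    return (size + 2 * padding - kernel) // stride + 1


def encoder_embed_size(image_size=64, depths=(32, 64, 128, 256)):
    # Apply one stride-2 convolution per depth, then flatten height x width x channels.
    size = image_size
    for _ in depths:
        size = conv_out(size)
    return size * size * depths[-1], size


embed, final_hw = encoder_embed_size()
print(final_hw, embed)  # 2 1024
```

The decoder mirrors this path with transposed convolutions to map the latent features back to a 64x64 image.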

3. Reward/Discount Heads

  • Reward model: Predicts reward from latent state

  • Discount model: Predicts episode termination (DreamerV2)
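In DreamerV2 the discount head predicts whether the episode continues, which yields a per-step discount for imagined trajectories (about γ while the episode continues, near 0 at termination); the actor and critic loss terms at each imagined step are then weighted by the cumulative product of the preceding discounts. A minimal sketch, assuming the head emits one logit per step for the probability of continuing, which is then scaled by γ (imagined_discounts and cumulative_weights are hypothetical names):

```python
import math


def sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def imagined_discounts(continue_logits, gamma=0.99):
    # Assumption: one logit per imagined step for P(episode continues), scaled by gamma.
    return [gamma * sigmoid(logit) for logit in continue_logits]


def cumulative_weights(discounts):
    # Loss weight at step t is the product of the discounts of steps 0..t-1 (1.0 at t = 0).
    weights, running = [], 1.0
    for d in discounts:
        weights.append(running)
        running *= d
    return weights


discounts = imagined_discounts([4.0, 4.0, -4.0, -4.0])
weights = cumulative_weights(discounts)
```

Steps imagined after a likely termination thus contribute little to the behavior losses.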

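Training

from torchwm import DreamerAgent, DreamerConfig

# Configure a Gym/Gymnasium task.
cfg = DreamerConfig()
cfg.env_backend = "gym"
cfg.env = "Pendulum-v1"
cfg.total_steps = 1_000_000  # total number of training steps

# Build the agent from the config and run the training loop.
agent = DreamerAgent(cfg)
agent.train()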

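Key Hyperparameters

| Parameter       | Default | Description                                      |
|-----------------|---------|--------------------------------------------------|
| stoch_size      | 30      | Stochastic latent dimensions                     |
| deter_size      | 200     | Deterministic hidden (GRU) state size            |
| embed_size      | 1024    | Encoder embedding size                           |
| imagine_horizon | 15      | Imagination rollout length (steps)               |
| discount        | 0.99    | Discount factor γ                                |
| td_lambda       | 0.95    | λ-return parameter                               |
| kl_loss_coeff   | 1.0     | KL divergence weight (β in the world model loss) |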

Learning Objectives

World Model Loss:

\[\begin{aligned} \mathcal{L}_\mathrm{world} &= \mathcal{L}_\mathrm{reconstruction} + \mathcal{L}_\mathrm{reward} + \beta \cdot \mathcal{L}_\mathrm{KL} \end{aligned}\]
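For Gaussian latents, the KL term above, KL[q(s_t | h_t, x_t) ‖ p(s_t | h_t)], has a closed form for diagonal Gaussians, summed over latent dimensions. The plain-Python sketch below computes it and combines the three terms, assuming β corresponds to kl_loss_coeff. The free_nats floor is an assumption borrowed from the Dreamer paper, which clips the KL below 3 nats; whether TorchWM does this is not stated here, so the sketch defaults to 0 (no clipping). Function names are hypothetical.

```python
import math


def kl_diag_gaussian(mu_q, std_q, mu_p, std_p):
    """KL(q || p) for diagonal Gaussians, summed over latent dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(mu_q, std_q, mu_p, std_p):
        total += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
    return total


def world_model_loss(recon_loss, reward_loss, kl, beta=1.0, free_nats=0.0):
    # beta plays the role of kl_loss_coeff; free_nats=0.0 means the KL is not clipped.
    return recon_loss + reward_loss + beta * max(kl, free_nats)


# Posterior q(s_t | h_t, x_t) vs. prior p(s_t | h_t) for a 2-D latent.
kl = kl_diag_gaussian([0.5, 0.0], [1.0, 0.5], [0.0, 0.0], [1.0, 1.0])
loss = world_model_loss(recon_loss=1.0, reward_loss=0.2, kl=kl, beta=1.0)
```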

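Actor Loss (REINFORCE), where G is the λ-return target of the imagined trajectory (set by discount and td_lambda) and the critic value V(s) serves as a baseline: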

\[\mathcal{L}_\mathrm{actor} = -\mathbb{E}\left[\log \pi(\mathbf{a} \mid \mathbf{s}) \cdot (G - V(\mathbf{s}))\right]\]

Critic Loss (MSE):

\[\mathcal{L}_\mathrm{critic} = \mathbb{E}[(G - V(\mathbf{s}))^2]\]
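The target G in both losses is the λ-return of the imagined trajectory. In the recursive form used by DreamerV2, with predicted rewards r_t for t < H and critic values V(s_t) for t ≤ H, where H is imagine_horizon:

\[G_t = r_t + \gamma\left[(1-\lambda)\,V(\mathbf{s}_{t+1}) + \lambda\,G_{t+1}\right], \qquad G_H = V(\mathbf{s}_H)\]

A minimal plain-Python sketch of this recursion (lambda_returns is a hypothetical helper; a real implementation would work on batched tensors, may use per-step predicted discounts in place of γ, and treats G as a constant target in the critic loss):

```python
def lambda_returns(rewards, values, discount=0.99, lam=0.95):
    """λ-return targets G_t for an imagined trajectory of length H.

    rewards: r_0 .. r_{H-1}; values: V(s_0) .. V(s_H), one extra entry to bootstrap.
    """
    if len(values) != len(rewards) + 1:
        raise ValueError("values must have exactly one more entry than rewards")
    returns = [0.0] * len(rewards)
    next_return = values[-1]  # G_H = V(s_H)
    for t in reversed(range(len(rewards))):
        # G_t = r_t + γ[(1 - λ) V(s_{t+1}) + λ G_{t+1}]
        next_return = rewards[t] + discount * ((1 - lam) * values[t + 1] + lam * next_return)
        returns[t] = next_return
    return returns


# Horizon 15 with the table defaults discount=0.99 and td_lambda=0.95.
targets = lambda_returns(rewards=[1.0] * 15, values=[0.0] * 16)
```

With λ = 0 this reduces to one-step TD targets r_t + γV(s_{t+1}); with λ = 1 it becomes the discounted sum of imagined rewards bootstrapped from V(s_H).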

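DreamerV2 Enhancements

DreamerV2 introduces several improvements:

  1. Discrete latents: categorical latent variables instead of Gaussian ones

  2. KL balancing: the KL term is split into two stop-gradient parts so that the prior is trained toward the posterior more strongly (weight 0.8 in the paper) than the posterior is regularized toward the prior

  3. Discount model: learns to predict episode termination

  4. Layer normalization: more stable training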
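DreamerV2's discrete latent state is a set of categorical variables (32 variables with 32 classes each in the paper) whose one-hot samples are concatenated. The NumPy sketch below samples such a state, assuming the logits for one time step are already computed; sample_discrete_latent is a hypothetical helper, and the configuration names TorchWM uses for these sizes are not covered here. DreamerV2 backpropagates through the sample with straight-through gradients, which NumPy cannot express, so that part appears only as a comment.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample_discrete_latent(logits):
    """Sample one class per categorical variable; return flattened one-hots and probabilities.

    logits: array of shape (num_variables, num_classes), e.g. (32, 32).
    """
    # Softmax over the classes of each variable (shifted for numerical stability).
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    # Gumbel-max trick: argmax(logits + Gumbel noise) is a sample from softmax(logits).
    idx = np.argmax(logits + rng.gumbel(size=logits.shape), axis=-1)
    one_hot = np.eye(logits.shape[-1])[idx]
    # With autodiff, the straight-through estimator would use
    # one_hot + probs - stop_gradient(probs): same forward value, gradients via probs.
    return one_hot.reshape(-1), probs


state, probs = sample_discrete_latent(rng.normal(size=(32, 32)))
```

The flattened 1024-dimensional one-hot vector then takes the place of the Gaussian sample s_t in the RSSM.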

Environment Support

Dreamer supports multiple backends:

from torchwm import DreamerConfig

cfg = DreamerConfig()

# Choose one backend/task pair; each block below overwrites the previous settings.

# DeepMind Control Suite example:
cfg.env_backend = "dmc"
cfg.env = "walker-walk"

# Gym/Gymnasium example:
cfg.env_backend = "gym"
cfg.env = "Pendulum-v1"

# MuJoCo example:
cfg.env_backend = "mujoco"   # MuJoCo task ids or native MJCF/MJB files
cfg.env = "Humanoid-v4"      # or "models/cartpole.xml"
cfg.mujoco_camera = None     # native MJCF/MJB only
cfg.mujoco_frame_skip = 4    # native MJCF/MJB only

# Gymnasium Robotics example (any id registered by the installed package):
cfg.env_backend = "robotics"
cfg.env = "HalfCheetah-v2"

# Brax example:
cfg.env_backend = "brax"     # JAX/Brax
cfg.env = "ant"
cfg.brax_backend = "generalized"

# Unity ML-Agents example:
cfg.env_backend = "unity_mlagents"
cfg.unity_file_name = "env.exe"

For MuJoCo tasks, Dreamer delegates adapter construction to make_mujoco_env_from_config: make_env only selects the backend, while the MuJoCo module decides whether the source is a task id or an XML/MJB file. Use Gymnasium task ids such as Humanoid-v4 for the standard benchmark rewards, or use native MJCF/MJB sources together with MuJoCoImageEnv callbacks for custom rewards and termination logic. Legacy MuJoCo v2/v3 ids and other Gymnasium Robotics tasks can use env_backend="robotics"; TorchWM lists those ids dynamically from the installed gymnasium-robotics package.

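References

  • Hafner, D., Lillicrap, T., Fischer, I., Villegas, R., Ha, D., Lee, H., & Davidson, J. (2019). Learning Latent Dynamics for Planning from Pixels. ICML 2019.

  • Hafner, D., Lillicrap, T., Ba, J., & Norouzi, M. (2020). Dream to Control: Learning Behaviors by Latent Imagination. ICLR 2020.

  • Hafner, D., Lillicrap, T., Norouzi, M., & Ba, J. (2021). Mastering Atari with Discrete World Models. ICLR 2021.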