World Models Deep Dive (Ha & Schmidhuber, 2018)#
This page is a comprehensive technical reference for the Ha & Schmidhuber World Models implementation in TorchWM. It covers the architecture, data pipeline, training details, inference, configuration, and common pitfalls.
Architecture Overview#
The world model consists of three components trained separately and in sequence (V, then M, then C):
graph LR
A["Raw pixels 64×64×3"] --> B["V: ConvVAE encoder"]
B --> C["Latent z (32-d)"]
C --> D["M: MDN-RNN"]
D --> E["Hidden h + predicted z"]
C --> F["C: Linear controller"]
E --> F
F --> G["Action a"]
G --> H["Environment"]
H --> A
A --> B
| Stage | Component | Function | Trained With | File |
|---|---|---|---|---|
| V | ConvVAE | Encodes 64×64 RGB → 32-d latent | Reconstruction loss (MSE + KL) | |
| M | MDN-RNN | Predicts next latent as Gaussian mixture `p(zₜ₊₁ \| aₜ, zₜ, hₜ)` | GMM NLL + BCE + MSE | |
| C | Linear Controller | Maps `[z, h]` → action | CMA-ES (reward maximization) | |
Key design decisions#
Latent compression: The VAE compresses 64×64×3 = 12,288 pixels into 32 floats. This makes the MDRNN’s GMM output tractable and the controller’s parameter count tiny (~10³), enabling black-box optimization with CMA-ES.
Gaussian mixture output: The MDRNN predicts the next latent as a mixture of Gaussians rather than a single Gaussian. This captures multimodal futures (e.g., “turn left” vs. “turn right” from the same state).
Hidden state is critical: The controller receives both the latent `z` and the RNN hidden state `h`. The paper shows removing `h` drops CarRacing score from 906±21 → 632±251, confirming that temporal memory is essential.
CMA-ES over backprop: The controller is trained with evolution strategies rather than gradient descent. This avoids differentiating through the environment or the world model, and works on sparse/delayed rewards.
Stage 1: Vision — ConvVAE#
Model architecture#
The ConvVAE follows a standard convolutional encoder-decoder structure:
Encoder:                                    Decoder:
3×64×64 input                               32-d latent
└─ Conv2d(3, 32, 4, stride=2)               └─ Linear(32, 1024)
└─ Conv2d(32, 64, 4, stride=2)              └─ ConvTranspose2d(64, 64, 5, stride=2)
└─ Conv2d(64, 128, 4, stride=2)             └─ ConvTranspose2d(64, 32, 5, stride=2)
└─ Conv2d(128, 256, 4, stride=2)            └─ ConvTranspose2d(32, 32, 6, stride=2)
└─ Flatten → Linear(1024, 2×latent_size)    └─ ConvTranspose2d(32, 3, 6, stride=2)
└─ Returns (mu, logsigma)                   └─ Sigmoid → 3×64×64 output
Key classes in world_models.vision.VAE.ConvVAE:
`ConvVAEEncoder`: Encodes images → `(mu, logsigma)`.
`ConvVAEDecoder`: Decodes latent → reconstructed image.
`ConvVAE`: Combines encoder + decoder, exposes `forward(x)` → `(recon, mu, logsigma)`.
Loss function#
Defined in world_models.losses.convae_loss:
MSE term: Measures pixel-level reconstruction quality.
KL term: Regularizes the latent distribution toward a standard normal prior. When `μ=0` and `logσ=0`, the KL term is exactly zero, and the loss equals the reconstruction loss alone.
The `size_average=False` flag on MSE (equivalent to `reduction="sum"` in current PyTorch) means the loss is summed, not averaged over pixels. This is intentional: the KL term is also summed over latent dimensions, so both terms use the same reduction. Their magnitudes still differ, because the reconstruction term sums far more elements than the KL term.
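The two terms can be illustrated with a minimal plain-Python sketch. This is not the TorchWM `convae_loss` code: the function names are hypothetical, the inputs are flat lists rather than tensors, and `logsigma` is assumed to be the log of the standard deviation.

```python
import math

def kl_to_standard_normal(mu, logsigma):
    """KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions."""
    return -0.5 * sum(
        1.0 + 2.0 * ls - m * m - math.exp(2.0 * ls)
        for m, ls in zip(mu, logsigma)
    )

def summed_squared_error(recon, target):
    """Reconstruction term with sum (not mean) reduction."""
    return sum((r - t) ** 2 for r, t in zip(recon, target))

def vae_loss(recon, target, mu, logsigma):
    """Total loss: summed reconstruction error plus summed KL."""
    return summed_squared_error(recon, target) + kl_to_standard_normal(mu, logsigma)

# With mu = 0 and logsigma = 0 the KL term vanishes,
# so the total equals the reconstruction term alone.
loss = vae_loss([0.5, 0.25], [0.0, 0.0], mu=[0.0, 0.0], logsigma=[0.0, 0.0])
```

Moving `mu` away from zero or `logsigma` away from zero makes the KL term strictly positive, which is what pulls the encoder toward the prior.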
Data pipeline#
Training data is collected by running the environment with random actions:
from world_models.training.train_world_model import generate_rollouts

generate_rollouts(
    data_dir="./data/carracing",
    env_name="CarRacing-v2",
    num_rollouts=1000,
    seq_len=1000,
    num_workers=8,
)
Each rollout is saved as a .npz file containing:
`observations`: `(T, 64, 64, 3)` uint8 array
`actions`: `(T, action_size)` float32 array
`rewards`: `(T,)` float32 array
`terminals`: `(T,)` float32 array
Dataset classes for VAE training:
`ObservationDataset`: Returns individual frames (not sequences). Used by `train_convvae.py`. Extends `RolloutDataset` and overrides `_get_data()` to return only the observation tensor.
`RolloutDataset`: Base class that loads `.npz` files, manages a circular buffer of open files, and splits files into train/test sets. Each sample returns a `dict(observation, action, reward, terminal)`.
`SequenceDataset`: Used for MDRNN training with raw images (when precomputed latents are disabled). Returns sequences of observations, actions, rewards, terminals, and next observations.
`LatentSequenceDataset`: Used for MDRNN training with precomputed latents. Operates on pre-encoded numpy arrays rather than raw images, reducing memory and avoiding repeated VAE encoding.
Train/test split#
The num_test_files parameter controls how many .npz files are reserved for
validation:
dataset = ObservationDataset(
    root="./data",
    train=True,
    num_test_files=600,  # last 600 files → test set
)
When len(files) > num_test_files and num_test_files > 0, the last
num_test_files files are used for test. If there aren’t enough files, all
data goes to training.
Warning
Passing num_test_files=0 was historically broken: files[:-0] evaluates to
files[:0] (empty list) in Python due to -0 == 0. This was fixed by guarding
the split with num_test_files > 0.
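The guarded split described above can be sketched as follows. This is an illustrative stand-alone function, not the `RolloutDataset` code; `split_files` and its arguments are hypothetical names.

```python
def split_files(files, num_test_files, train=True):
    """Reserve the last `num_test_files` entries for testing, if possible."""
    if 0 < num_test_files < len(files):
        train_files = files[:-num_test_files]
        test_files = files[-num_test_files:]
    else:
        # Covers num_test_files == 0 and the too-few-files case:
        # everything goes to training, nothing to test.
        train_files, test_files = list(files), []
    return train_files if train else test_files

files = [f"rollout_{i}.npz" for i in range(5)]
unguarded = files[:-0]                    # the historical bug: an empty list
train_split = split_files(files, 2)       # first three files
test_split = split_files(files, 2, train=False)  # last two files
no_holdout = split_files(files, 0)        # all five files
```

The single chained comparison `0 < num_test_files < len(files)` covers both conditions from the text, so the negative-zero slice is never evaluated.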
Training loop#
See world_models.training.train_convvae.train_convae():
Load pretrained VAE checkpoint if available (`noreload=False`).
Create `ObservationDataset` with `A.Compose` transforms (resize + optional flip).
Train for `num_epochs` using `Adam(..., lr=learning_rate)`.
Validate after each epoch.
Reduce LR on plateau via `ReduceLROnPlateau`.
Early stop via `EarlyStopping`.
Save best checkpoint as `best.tar`, current as `checkpoint.tar`.
Generate sample images every `sample_interval` epochs.
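The validate / early-stop / save-best part of this loop can be sketched with a small patience counter. This is not TorchWM's `EarlyStopping` class, whose interface is not shown here; `PatienceStopper` and `min_delta` are hypothetical names, and the validation losses are made up for illustration.

```python
class PatienceStopper:
    """Signals a stop after `patience` consecutive epochs without improvement."""

    def __init__(self, patience, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss   # new best: the point where best.tar would be written
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = PatienceStopper(patience=3)
history = [1.0, 0.9, 0.95, 0.96, 0.97, 0.5]  # illustrative validation losses
stopped_at = None
for epoch, val_loss in enumerate(history):
    if stopper.step(val_loss):
        stopped_at = epoch
        break
```

Here training stops at epoch 4, after three epochs that fail to beat 0.9, and the later 0.5 is never seen. That is the usual trade-off when choosing the patience value.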
Note
The VAE is compiled with torch.compile when CUDA is available for faster
training. This can be disabled if compatibility issues arise.
Stage 2: Memory — MDN-RNN#
Why a Gaussian mixture?#
A single Gaussian assumes the next latent follows a unimodal distribution, but many environments are fundamentally multimodal. From the same state, different actions lead to different futures, and even the same action may have stochastic outcomes. The Mixture Density Network (MDN) handles this by predicting:
\[p(z_{t+1} \mid a_t, z_t, h_t) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(z_{t+1} \mid \mu_k, \sigma_k)\]
where K is the number of Gaussian components (typically 5), π_k are the
mixture weights, and (μ_k, σ_k) are the component parameters.
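To see why a mixture matters, the following minimal sketch draws samples from a 1-D two-component mixture using only the standard library. The weights, means, and sigmas are made-up illustration values, not parameters from a trained MDN-RNN.

```python
import random

def sample_mixture(weights, means, sigmas, rng):
    """Draw one sample from a 1-D Gaussian mixture."""
    # Pick a component according to the mixture weights, then sample from it.
    k = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    return rng.gauss(means[k], sigmas[k])

rng = random.Random(0)
weights, means, sigmas = [0.5, 0.5], [-2.0, 2.0], [0.1, 0.1]
samples = [sample_mixture(weights, means, sigmas, rng) for _ in range(200)]

# A single Gaussian fitted to these samples would be centered here,
# in a region that contains none of them.
mixture_mean = sum(w * m for w, m in zip(weights, means))
```

Every sample lands near -2 or +2, while the overall mean is 0. A unimodal prediction would place most of its mass on an outcome that never occurs.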
MDRNN vs MDRNNCell#
The module world_models.models.mdrnn provides two variants:
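| Class | RNN Type | Use Case | Forward Signature |
|---|---|---|---|
| `MDRNN` | `nn.LSTM` | Training: processes full sequences at once | `(actions, latents) → (mus, sigmas, logpi, rs, ds)` |
| `MDRNNCell` | `nn.LSTMCell` | Inference: one step at a time | `(action, latent, (h, c)) → (mus, sigmas, logpi, r, d, (h, c))` |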
Both share the same _MDRNNBase parent, which defines the gmm_linear output
layer. The output head maps the RNN hidden state to the GMM parameters:
self.gmm_linear = nn.Linear(hiddens, (2 * latents + 1) * gaussians + 2)
This produces:
`gaussians × latents` means (`mus`)
`gaussians × latents` sigmas (raw, then exponentiated)
`gaussians` logits (log-softmax → `logpi`)
1 reward logit
1 terminal logit
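The size of this output head can be checked with a short sketch. The helper name `gmm_head_layout` is hypothetical, and the slice order (means, sigmas, mixture logits, reward, terminal) is an assumption that follows the order of the list above; the actual TorchWM layout may differ.

```python
def gmm_head_layout(latents, gaussians):
    """Return (total_size, slices) for the flat output of gmm_linear."""
    stride = gaussians * latents
    slices = {
        "mus": slice(0, stride),
        "sigmas": slice(stride, 2 * stride),
        "pi": slice(2 * stride, 2 * stride + gaussians),
        "reward": slice(2 * stride + gaussians, 2 * stride + gaussians + 1),
        "terminal": slice(2 * stride + gaussians + 1, 2 * stride + gaussians + 2),
    }
    # Same expression as the nn.Linear output size shown above.
    total = (2 * latents + 1) * gaussians + 2
    return total, slices

total, slices = gmm_head_layout(latents=32, gaussians=5)
```

For `latents=32` and `gaussians=5` the head has 327 outputs: 160 means, 160 sigmas, 5 mixture logits, and 2 scalars.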
Weight transfer for inference#
Training uses MDRNN (batched LSTM). Inference uses MDRNNCell (single-step
LSTMCell). The weights must be copied between the two:
import torch
from world_models.models.mdrnn import MDRNN, MDRNNCell

# Load the trained batched model
batch_rnn = MDRNN(latents=32, actions=3, hiddens=256, gaussians=5)
batch_rnn.load_state_dict(torch.load("mdrnn_best.tar")["state_dict"])
# Copy the LSTM layer-0 weights into the single-step cell
cell_rnn = MDRNNCell(latents=32, actions=3, hiddens=256, gaussians=5)
cell_rnn.rnn.weight_ih.data.copy_(batch_rnn.rnn.weight_ih_l0.data)
cell_rnn.rnn.weight_hh.data.copy_(batch_rnn.rnn.weight_hh_l0.data)
cell_rnn.rnn.bias_ih.data.copy_(batch_rnn.rnn.bias_ih_l0.data)
cell_rnn.rnn.bias_hh.data.copy_(batch_rnn.rnn.bias_hh_l0.data)
# The GMM head has identical shapes in both classes
cell_rnn.gmm_linear.load_state_dict(batch_rnn.gmm_linear.state_dict())
The LSTM’s weight_ih_l0 maps to LSTMCell’s weight_ih, weight_hh_l0 to
weight_hh, and similarly for biases. The gmm_linear weights are shared
identically via load_state_dict.
Important
The initial hidden state for MDRNNCell must be created via get_init_hidden()
and updated on every step. A common bug is to reuse the same initial hidden
state, which causes the cell to “restart” from zeros each time, destroying
temporal memory.
Loss function#
Computed in world_models.training.train_mdn_rnn.get_loss():
GMM loss (`world_models.losses.gmm_loss`): Negative log-likelihood of the observed next latent under the predicted Gaussian mixture. Uses numerically stable log-sum-exp:
\[\mathcal{L}_{\text{GMM}} = -\log\sum_k \pi_k \cdot \mathcal{N}(z_{t+1} | \mu_k, \sigma_k)\]
BCE loss: Binary cross-entropy for terminal flag prediction.
MSE loss: Mean squared error for reward prediction (only if `include_reward=True`).
Scaling factor: The total loss is divided by `latent_size + 2` (or `latent_size + 1` if reward is excluded) to balance the GMM loss scale.
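The scaling rule can be written out as a small sketch. `mdrnn_total_loss` is a hypothetical helper operating on plain floats, not the `get_loss()` implementation, and the loss values are invented to make the arithmetic easy to follow.

```python
def mdrnn_total_loss(gmm, bce, mse, latent_size, include_reward=True):
    """Combine the loss terms and normalize by the number of predicted targets."""
    if include_reward:
        scale = latent_size + 2   # latent dims + reward + terminal
        return (gmm + bce + mse) / scale
    scale = latent_size + 1       # latent dims + terminal
    return (gmm + bce) / scale

# Illustrative values only: a GMM term near latent_size gives a total near 1.
with_reward = mdrnn_total_loss(gmm=33.0, bce=0.5, mse=0.5, latent_size=32)
without_reward = mdrnn_total_loss(
    gmm=32.5, bce=0.5, mse=9.9, latent_size=32, include_reward=False
)
```

When `include_reward=False`, the MSE value is ignored entirely and the divisor drops by one to match.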
Precomputed latents#
To avoid encoding every batch through the VAE during MDRNN training (which is
memory-intensive), train_mdn_rnn.py supports precomputed latents:
Precomputation: `precompute_latents()` loads the trained VAE, encodes all observations in all rollouts, and saves the result as a single `.npz` file:
    latent_data = np.load("data/carracing/latents/latents_32.npz")
    latent_data.keys()  # latents, actions, rewards, terminals
Training with precomputed latents: Uses `LatentSequenceDataset`, which operates directly on the numpy arrays without any VAE encoding during training.
Training without precomputed latents: Uses `SequenceDataset`, which returns raw image sequences. The `data_pass()` function encodes them on the fly via `to_latent()`. This is slower but requires less disk space.
Stage 3: Controller — CMA-ES#
Linear controller#
Defined in world_models.models.controller.Controller:
import torch.nn as nn

class Controller(nn.Module):
    def __init__(self, latent_size, hidden_size, action_size):
        super().__init__()  # must run before submodules are assigned
        self.fc = nn.Linear(latent_size + hidden_size, action_size)

    def forward(self, state):
        return self.fc(state)  # state = concat([z, h])
The controller is a single linear layer with no activation, bias included.
Input is the concatenation of the latent vector z and RNN hidden state h.
Output is the action vector (clamped to [-1, 1] by the environment wrapper).
Total parameters: (latent_size + hidden_size + 1) × action_size. For
CarRacing (32 + 256 + 1) × 3 = 867 parameters.
CMA-ES optimization#
The controller is trained with Covariance Matrix Adaptation Evolution Strategy
(CMA-ES) via the cma package, not gradient descent:
Initialize: Start with a mean parameter vector (from random init) and a covariance matrix.
Sample: Generate a population of candidate parameter vectors from the current distribution.
Evaluate: For each candidate, run a full episode rollout in the environment and collect the total reward. This is done in parallel across multiple worker processes.
Update: The CMA-ES algorithm adjusts the mean and covariance toward regions that produced higher rewards.
Repeat until the target return is reached or convergence.
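The sample / evaluate / update cycle can be sketched with a deliberately simplified evolution strategy. Note that this is not CMA-ES: it uses isotropic noise with a fixed decay instead of an adapted covariance matrix, and it does not use the `cma` package. The names `simple_es` and `toy_return` are hypothetical, and the toy objective stands in for an episode rollout.

```python
import random

def simple_es(fitness, dim, pop_size=16, n_elite=4, sigma=0.5,
              generations=60, seed=0):
    """Maximize `fitness` with a basic sample/evaluate/update loop."""
    rng = random.Random(seed)
    mean = [0.0] * dim                      # Initialize
    for _ in range(generations):
        # Sample: perturb the mean with isotropic Gaussian noise
        population = [[m + sigma * rng.gauss(0.0, 1.0) for m in mean]
                      for _ in range(pop_size)]
        # Evaluate: in the real pipeline each call would be an environment rollout
        ranked = sorted(population, key=fitness, reverse=True)
        # Update: move the mean to the average of the best candidates
        elites = ranked[:n_elite]
        mean = [sum(e[i] for e in elites) / n_elite for i in range(dim)]
        sigma *= 0.95                       # shrink the search radius
    return mean

target = [1.0, -2.0, 0.5]

def toy_return(params):
    # Stand-in for episode reward: higher is better, maximal at `target`
    return -sum((p - t) ** 2 for p, t in zip(params, target))

solution = simple_es(toy_return, dim=3)
```

CMA-ES replaces the fixed `sigma` decay and isotropic noise with a learned step size and full covariance, which is what lets it follow elongated or rotated reward landscapes.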
Parallel evaluation#
train_controller.py uses torch.multiprocessing for parallel rollout:
Master process: Runs CMA-ES, dispatches parameter vectors to workers.
Worker processes (`slave_routine`): Each loads the trained VAE + MDRNNCell, creates an environment, and runs rollouts with the received controller parameters.
Weight transfer in workers: Each worker converts the batched `MDRNN` checkpoint to an `MDRNNCell` for recurrent inference:
def _run_rollout(ctrl_params, logdir, env_name, action_size, time_limit, device):
    # Load VAE and MDRNN checkpoints, convert to MDRNNCell
    # Create environment
    # Run episode with controller(h, z) → action → step → update h
    return total_reward
Note
Each worker has its own copy of the VAE and MDRNNCell on its assigned GPU.
The torch.multiprocessing spawn model ensures clean CUDA context per process.
The flatten_parameters / load_parameters utilities convert between the
Controller’s nn.Parameter tensors and flat numpy arrays for CMA-ES.
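The round trip between structured parameters and a flat vector can be sketched in plain Python. This is not the library's `flatten_parameters` / `load_parameters` code, which works on `nn.Parameter` tensors and numpy arrays; `flatten_linear` and `unflatten_linear` are hypothetical names operating on nested lists, and the weight values are arbitrary.

```python
def flatten_linear(weight, bias):
    """Concatenate a (rows x cols) weight matrix and a bias vector into one flat list."""
    return [w for row in weight for w in row] + list(bias)

def unflatten_linear(flat, rows, cols):
    """Inverse of flatten_linear: rebuild the weight matrix and the bias."""
    expected = rows * cols + rows
    if len(flat) != expected:
        raise ValueError(f"expected {expected} values, got {len(flat)}")
    weight = [flat[r * cols:(r + 1) * cols] for r in range(rows)]
    bias = flat[rows * cols:]
    return weight, bias

# CarRacing sizes: 3 actions, 32 + 256 inputs
rows, cols = 3, 32 + 256
weight = [[0.01 * r] * cols for r in range(rows)]
bias = [0.0, 0.1, 0.2]

flat = flatten_linear(weight, bias)                 # what the optimizer sees
restored_weight, restored_bias = unflatten_linear(flat, rows, cols)
```

The flat vector has 867 entries, matching the parameter count given for the linear controller, and the length check catches a candidate vector that was built for a different controller shape.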
Sampling for robust evaluation#
Each candidate controller is evaluated n_samples times (default 4) with
different random seeds, and the rewards are averaged. This reduces the variance
from stochastic environment dynamics and the VAE’s random sampling (z = μ + σ · ε).
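A minimal sketch of this averaging, with hypothetical names throughout: `evaluate_candidate` is not the TorchWM function, and `noisy_rollout` is a stand-in for a real episode whose return depends on the seed.

```python
import random

def evaluate_candidate(rollout, params, n_samples=4, base_seed=0):
    """Average the return of `rollout(params, seed=...)` over several seeds."""
    returns = [rollout(params, seed=base_seed + i) for i in range(n_samples)]
    return sum(returns) / n_samples

def noisy_rollout(params, seed):
    # Hypothetical stand-in for an episode: a true return plus seed-dependent noise
    rng = random.Random(seed)
    return sum(params) + rng.gauss(0.0, 1.0)

score = evaluate_candidate(noisy_rollout, [1.0, 2.0], n_samples=4)
```

Averaging over `n_samples` independent rollouts reduces the standard deviation of the score by a factor of `sqrt(n_samples)`, at the cost of `n_samples` times as many episodes per candidate.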
Inference Pipeline#
The complete inference loop (from test_trained_model() in
train_world_model.py):
import torch

# 1. Load models
vae = ConvVAE(img_channels=3, latent_size=32)
vae.load_state_dict(torch.load("vae/best.tar")["state_dict"])
cell_rnn = MDRNNCell(latents=32, actions=3, hiddens=256, gaussians=5)
# Transfer weights from trained MDRNN (see weight transfer section above)
ctrl = Controller(latent_size=32, hidden_size=256, action_size=3)
ctrl.load_state_dict(torch.load("ctrl/best.tar")["state_dict"])

# 2. Reset environment and hidden state
# `env` is a Gymnasium environment; `preprocess` turns a frame into a (1, 3, 64, 64) float tensor
obs, _ = env.reset()
h, c = cell_rnn.get_init_hidden(1)  # fresh hidden state per episode

# 3. Rollout loop
with torch.no_grad():
    for step in range(1000):
        # Encode observation → latent
        mu, logsigma = vae.encoder(preprocess(obs))
        z = mu + logsigma.exp() * torch.randn_like(logsigma)
        # Controller computes action from concat([z, h])
        action = ctrl(torch.cat([z, h], dim=-1)).cpu().numpy().flatten()
        # Step environment (Gymnasium reports terminated and truncated separately)
        next_obs, reward, terminated, truncated, _ = env.step(action)
        # Update RNN hidden state with (action, latent)
        action_t = torch.as_tensor(action, dtype=torch.float32, device=z.device).unsqueeze(0)
        _, _, _, _, _, (h, c) = cell_rnn(action_t, z, (h, c))
        obs = next_obs
        if terminated or truncated:
            break
Configuration Reference#
WMVAEConfig#
Fields marked with ✓ have defaults, fields marked with · are required.
| Field | Type | Default | Description |
|---|---|---|---|
|  |  | — | Input image height (pixels) |
|  |  | — | Input image width (pixels) |
|  |  |  | Dimensionality of VAE latent space |
|  |  |  | Training device |
|  |  |  | Samples per training batch |
|  |  |  | Number of training epochs |
|  |  |  | Path to rollout data (.npz files) |
|  |  |  | Adam learning rate |
|  |  |  | Checkpoint and log directory |
|  |  |  | Skip loading existing checkpoints |
|  |  |  | Skip saving sample images |
|  |  |  | LR scheduler patience (epochs) |
|  |  |  | LR multiplicative factor |
|  |  |  | Early stopping patience |
|  |  |  | Epochs between sample saves |
|  |  |  | Additional custom parameters |
WMMDNRNNConfig#
| Field | Type | Default | Description |
|---|---|---|---|
|  |  |  | Latent space dimensionality |
|  |  |  | Action space dimensionality |
|  |  |  | RNN hidden units |
|  |  |  | GMM mixture components |
|  |  |  | Training device |
|  |  |  | Sequences per batch |
|  |  |  | Sequence length per sample |
|  |  |  | Training epochs |
|  |  |  | Rollout data path |
|  |  |  | RMSprop learning rate |
|  |  |  | Checkpoint directory |
|  |  |  | Skip loading checkpoints |
|  |  |  | Include reward in loss |
|  |  |  | LR scheduler patience |
|  |  |  | LR multiplicative factor |
|  |  |  | Early stopping patience |
|  |  |  | Additional custom parameters |
WMControllerConfig#
| Field | Type | Default | Description |
|---|---|---|---|
|  |  |  | Latent space dimensionality |
|  |  |  | RNN hidden dimensionality |
|  |  |  | Action space dimensionality |
|  |  |  | Environment for evaluation |
|  |  |  | Checkpoint directory |
|  |  |  | Rollout samples per candidate |
|  |  |  | CMA-ES population size |
|  |  |  | Stop when return ≥ target |
|  |  |  | Max parallel workers |
|  |  |  | Show progress bars |
|  |  |  | Max steps per episode |
|  |  |  | Additional custom parameters |
Loss Function Details#
conv_vae_loss_fn#
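The loss is the sum of two terms: a reconstruction term (squared error summed over all pixels and channels) and a KL term (divergence of the encoder's diagonal Gaussian from the standard normal prior, summed over latent dimensions).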
Note
Both terms use sum reduction (not mean). With latent_size=32 and
image size 64×64, the KL term sums 32 values, while MSE sums 3×64×64 values.
The reconstruction term therefore dominates numerically.
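Written out, a loss of this form would read as follows. This is a reconstruction from the description on this page, not a copy of the library source: it assumes `logsigma` is the log standard deviation, writes the reconstruction as x̂ and the latent size as d, and leaves out any summation over the batch.

```latex
\[
\mathcal{L}_{\text{VAE}}
  = \underbrace{\sum_{i} (\hat{x}_i - x_i)^2}_{\text{reconstruction (summed MSE)}}
  \;+\;
  \underbrace{-\tfrac{1}{2} \sum_{j=1}^{d}
      \left(1 + 2\log\sigma_j - \mu_j^2 - \sigma_j^2\right)}_{\text{KL}}
\]
```

Setting every μ_j = 0 and log σ_j = 0 makes each KL summand 1 + 0 − 0 − 1 = 0, which matches the statement that the KL term is exactly zero in that case.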
gmm_loss#
Implementation uses numerically stable log-space computation:
normal_dist = Normal(mus, sigmas)
g_log_probs = logpi + normal_dist.log_prob(latent_next_obs).sum(dim=-1)
max_log_probs = g_log_probs.max(dim=-1, keepdim=True)[0]
g_log_probs = g_log_probs - max_log_probs # stabilize
log_prob = max_log_probs.squeeze() + torch.log(torch.exp(g_log_probs).sum(dim=-1))
return -log_prob.mean() # negative log-likelihood
The max_log_probs subtraction makes the largest exponent exactly zero, so the exp sum can neither overflow nor underflow to zero.
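The effect of the stabilization is easy to reproduce with the standard library alone. The sketch below is a scalar illustration, not the tensor code above; the two log-probabilities are made-up values chosen so that the naive version fails.

```python
import math

def logsumexp_naive(log_probs):
    """log(sum(exp(x))) computed directly."""
    return math.log(sum(math.exp(lp) for lp in log_probs))

def logsumexp_stable(log_probs):
    """Same quantity, with the maximum factored out before exponentiating."""
    m = max(log_probs)
    return m + math.log(sum(math.exp(lp - m) for lp in log_probs))

component_log_probs = [-1000.0, -1001.0]

try:
    naive = logsumexp_naive(component_log_probs)
except ValueError:
    naive = None   # exp(-1000) underflows to 0.0, so log(0.0) raises

stable = logsumexp_stable(component_log_probs)
```

After the shift, the largest term contributes `exp(0) = 1`, so the sum is at least 1 and its logarithm is always finite.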
Testing Guide#
The test suite covers all components with 83 tests across 8 files:
| Test file | Tests | What it covers |
|---|---|---|
| `tests/vision/test_convvae.py` | 10 | ConvVAE forward shapes, gradient flow, reconstruction range, latent size variants |
| `tests/models/test_mdrnn.py` | 16 | MDRNN/MDRNNCell shapes, differentiability, hidden state updates, weight transfer |
| `tests/models/test_controller_wm.py` | 5 | Controller shapes, inference modes, differentiability |
| `tests/configs/test_wm_config.py` | 11 | Config creation, defaults, validation, extra key proxying, serialization |
| `tests/losses/test_convae_loss.py` | 6 | Loss scalar, positivity, reconstruction ordering, differentiability, KL behavior |
| `tests/losses/test_gmm_loss.py` | 6 | Loss scalar, positivity, prediction ordering, differentiability, GMM variants |
| `tests/utils/test_train_utils.py` | 17 | EarlyStopping, ReduceLROnPlateau modes, state dict roundtrip |
| `tests/datasets/test_wm_dataset.py` | 12 | RolloutDataset, ObservationDataset, SequenceDataset, LatentSequenceDataset |
Run all tests:
python -m pytest tests/vision/test_convvae.py tests/models/test_mdrnn.py \
tests/models/test_controller_wm.py tests/configs/test_wm_config.py \
tests/losses/test_convae_loss.py tests/losses/test_gmm_loss.py \
tests/utils/test_train_utils.py tests/datasets/test_wm_dataset.py -v
Notable test patterns#
Testing differentiable gradient flow:
def test_differentiable(self, model):
    actions = torch.randn(seq_len, bs, 3, requires_grad=True)
    latents = torch.randn(seq_len, bs, 32, requires_grad=True)
    mus, sigmas, logpi, rs, ds = model(actions, latents)
    loss = mus.sum() + sigmas.sum() + logpi.sum() + rs.sum() + ds.sum()
    loss.backward()
    for name, param in model.named_parameters():
        assert param.grad is not None, f"{name} has no gradient"
Testing weight transfer (MDRNN → MDRNNCell):
def test_weight_transfer_from_mdrnn(self):
    batch_rnn = MDRNN(latents=32, actions=3, hiddens=256, gaussians=5)
    cell_rnn = MDRNNCell(latents=32, actions=3, hiddens=256, gaussians=5)
    # Copy LSTM weights
    cell_rnn.rnn.weight_ih.data.copy_(batch_rnn.rnn.weight_ih_l0.data)
    cell_rnn.rnn.weight_hh.data.copy_(batch_rnn.rnn.weight_hh_l0.data)
    cell_rnn.rnn.bias_ih.data.copy_(batch_rnn.rnn.bias_ih_l0.data)
    cell_rnn.rnn.bias_hh.data.copy_(batch_rnn.rnn.bias_hh_l0.data)
    cell_rnn.gmm_linear.load_state_dict(batch_rnn.gmm_linear.state_dict())
    # Reference outputs from the batched model (seq_len and bs are test constants)
    actions = torch.randn(seq_len, bs, 3)
    latents = torch.randn(seq_len, bs, 32)
    mus_batch, *_ = batch_rnn(actions, latents)
    # Compare outputs step-by-step (MUST update hidden state!)
    h, c = cell_rnn.get_init_hidden(bs)
    cell_outs = []
    for t in range(seq_len):
        out = cell_rnn(actions[t], latents[t], (h, c))
        cell_outs.append(out)
        _, _, _, _, _, (h, c) = out  # ← critical: update h, c
    # Each single-step output must match the corresponding batched time step
    for t in range(seq_len):
        assert torch.allclose(cell_outs[t][0], mus_batch[t], atol=1e-4, rtol=1e-3)
Testing non-leaf gradient (gmm_loss differentiable):
When a loss function takes logpi that was computed with .log_softmax(), the
logpi tensor is not a leaf. To verify gradient flow, create and check a leaf:
def test_differentiable(self):
    logpi_raw = torch.randn(2, 5, 5, requires_grad=True)
    logpi = logpi_raw.log_softmax(dim=-1)
    loss = gmm_loss(batch, mus, sigmas, logpi)
    loss.backward()
    assert logpi_raw.grad is not None  # ✅ leaf tensor
    # logpi.grad would be None ⚠️ (non-leaf)
CLI Usage#
The unified training script world_models.training.train_world_model provides a
complete CLI:
# Generate data + train all 3 stages
python -m world_models.training.train_world_model --env CarRacing-v2
# Train only specific stages
python -m world_models.training.train_world_model --env CarRacing-v2 --stage vae
python -m world_models.training.train_world_model --env CarRacing-v2 --stage rnn
python -m world_models.training.train_world_model --env CarRacing-v2 --stage ctrl
# Generate rollouts only
python -m world_models.training.train_world_model --env CarRacing-v2 --generate_only
# Test a trained model
python -m world_models.training.train_world_model --env CarRacing-v2 --test
# With custom directories
python -m world_models.training.train_world_model \
--env CarRacing-v2 \
--data_dir ./data/carracing \
--logdir ./results/carracing \
--latent_size 32 \
--rnn_hidden 256 \
--vae_epochs 50 \
--rnn_epochs 30 \
--ctrl_pop_size 16
Common Pitfalls#
2. files[:-0] returns empty list#
# ❌ When num_test_files=0:
self.files = self.files[:-num_test_files] # evaluates to self.files[:0] = []
Fix: Guard with num_test_files > 0.
3. Mismatched dataloader parallelism#
ObservationDataset and SequenceDataset use a circular buffer pattern
(load_next_buffer()) that is incompatible with num_workers > 0 in the
DataLoader. Always use num_workers=0 for these datasets.
LatentSequenceDataset does not use a circular buffer and can use
num_workers=4, pin_memory=True.
4. VAE reconstruction + KL scale imbalance#
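With size_average=False on MSE, the reconstruction loss is summed over
3 × 64 × 64 = 12,288 pixel values while KL sums over latent_size = 32 dimensions.
The reconstruction term therefore sums 384× as many elements and usually dominates.
If you add a β-VAE weighting on the KL term, β = (3 * height * width) / latent_size
(384 here) equalizes the element counts; equivalently, multiply the reconstruction
term by latent_size / (3 * height * width).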
5. MDRNN loss scaling factor#
The GMM loss sums log-probabilities over latent_size dimensions, so its raw
value is on the order of latent_size. The total loss divides by latent_size + 2
(or latent_size + 1 without reward) to keep the scale around 1. If you change
latent_size, the raw GMM term changes proportionally, so keep the divisor tied
to latent_size rather than hard-coding it.
6. CMA-ES workers and GPU memory#
Each worker process loads a full copy of the VAE and MDRNNCell. On GPU, this
can exhaust memory with many workers. The max_workers config caps this,
and workers are distributed across available GPUs.
Extending the World Model#
Adding a new environment#
Ensure the environment conforms to Gymnasium’s API (`reset`, `step`, `action_space`, `observation_space`).
Set `env_name` in `WMControllerConfig`.
The `GymImageEnv` wrapper handles image resizing to 64×64 and uint8 conversion.
Changing the VAE architecture#
Override the ConvVAEEncoder / ConvVAEDecoder classes. The encoder must
return (mu, logsigma) with logsigma.shape == mu.shape. The decoder
must return a tensor matching the input image shape.
Changing the RNN cell#
Replace nn.LSTM / nn.LSTMCell in MDRNN.__init__ / MDRNNCell.__init__.
The hidden state (h, c) convention must be preserved for weight transfer.
Adding a non-linear controller#
Replace nn.Linear with an MLP in Controller.__init__. Note that CMA-ES
scales poorly with parameter count: an MLP with 2 hidden layers of 64 units
has (32+256)×64 + 64×64 + 64×3 = 22,720 weights (22,851 with biases) vs. 867
for linear, and CMA-ES maintains a full covariance matrix whose size grows with
the square of the parameter count. Consider using a smaller population size or
switching to backprop-based controller training for larger networks.
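The parameter counts above can be reproduced with a short helper. `dense_param_count` is a hypothetical name, and the layer widths are the ones used in the example (288 inputs, two hidden layers of 64, 3 actions).

```python
def dense_param_count(layer_sizes, bias=True):
    """Number of parameters in a fully connected net with the given layer widths."""
    total = 0
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        total += n_in * n_out + (n_out if bias else 0)
    return total

linear = dense_param_count([32 + 256, 3])                            # weights + biases
mlp_weights = dense_param_count([32 + 256, 64, 64, 3], bias=False)   # weights only
mlp_total = dense_param_count([32 + 256, 64, 64, 3])                 # weights + biases

# A full covariance matrix over the MLP parameters has this many entries.
covariance_entries = mlp_total ** 2
```

The MLP has roughly 26× as many parameters as the linear controller, but its covariance matrix has roughly 700× as many entries. The covariance, not the raw parameter count, is what makes full CMA-ES impractical at this size.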