Source code for world_models.models.dreamer_rssm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as distributions
_str_to_activation = {
"relu": nn.ReLU(),
"elu": nn.ELU(),
"tanh": nn.Tanh(),
"leaky_relu": nn.LeakyReLU(),
"sigmoid": nn.Sigmoid(),
"selu": nn.SELU(),
"softplus": nn.Softplus(),
"identity": nn.Identity(),
}
[docs]
class RSSM(nn.Module):
"""Recurrent State-Space Model used by Dreamer for latent dynamics learning.
The RSSM is the core world model component that learns compact representations
of environment dynamics. It maintains a hybrid state consisting of:
1. **Deterministic State (h)**: A recurrent hidden state updated by a GRU,
capturing sequential/temporal information and deterministic transitions.
2. **Stochastic State (s)**: A latent variable representing stochastic,
multi-modal uncertainty in the environment (e.g., ambiguous observations).
The model operates in two modes:
- **Observe Mode**: Updates states using actual observations from the environment.
Uses the representation model: p(s_t | h_t, obs_t)
- **Imagine Mode**: Predicts future states without observations.
Uses the transition/prior model: p(s_t | h_t)
Architecture:
Input: Previous state (h_{t-1}, s_{t-1}) and action a_{t-1}
Process: GRU updates deterministic state, MLP computes stochastic prior/posterior
Output: Updated state (h_t, s_t) and distributions
State Representation:
- deter (h): GRU hidden state, captures sequential context
- stoch (s): Stochastic latent, multi-modal uncertainty
- mean/std: Parameters of the stochastic distribution
Usage with DreamerAgent:
rssm = RSSM(
action_size=action_dim,
stoch_size=30, # Stochastic state dimension
deter_size=200, # Deterministic (GRU) state dimension
hidden_size=200, # MLP hidden layer size
obs_embed_size=256, # Observation embedding from encoder
activation='elu'
)
# Observe with actual observation
posterior = rssm.observe_step(prev_state, prev_action, obs_embed)
# Imagine future without observation
prior = rssm.imagine_step(current_state, action)
Training:
The RSSM is trained by maximizing the ELBO (Evidence Lower Bound):
- KL divergence between prior and posterior encourages the prior to
capture environment dynamics
- Reconstruction loss from decoder ensures state captures observation info
Reference:
Dreamer: Scalable Reinforcement Learning Using World Models
Hafner et al., 2020 - https://arxiv.org/abs/1912.01603
"""
def __init__(
self,
action_size,
stoch_size,
deter_size,
hidden_size,
obs_embed_size,
activation,
):
super().__init__()
self.action_size = action_size
self.stoch_size = stoch_size
self.deter_size = deter_size
self.hidden_size = hidden_size
self.embedding_size = obs_embed_size
self.act_fn = _str_to_activation[activation]
self.rnn = nn.GRUCell(self.deter_size, self.deter_size)
self.fc_state_action = nn.Linear(
self.stoch_size + self.action_size, self.deter_size
)
self.fc_embed_prior = nn.Linear(self.deter_size, self.hidden_size)
self.fc_state_prior = nn.Linear(self.hidden_size, 2 * self.stoch_size)
self.fc_embed_posterior = nn.Linear(
self.embedding_size + self.deter_size, self.hidden_size
)
self.fc_state_posterior = nn.Linear(self.hidden_size, 2 * self.stoch_size)
[docs]
def init_state(self, batch_size, device):
"""Initialize RSSM state with zeros.
Args:
batch_size: Number of parallel sequences
device: torch device for tensors
Returns:
Dictionary containing zero-initialized state components:
- mean, std: Stochastic distribution parameters
- stoch: Stochastic state sample
- deter: Deterministic GRU hidden state
"""
return dict(
mean=torch.zeros(batch_size, self.stoch_size).to(device),
std=torch.zeros(batch_size, self.stoch_size).to(device),
stoch=torch.zeros(batch_size, self.stoch_size).to(device),
deter=torch.zeros(batch_size, self.deter_size).to(device),
)
[docs]
def get_dist(self, mean, std):
"""Create an Independent Normal distribution from mean and std.
Args:
mean: Location parameter
std: Scale parameter
Returns:
Independent Normal distribution with given parameters
"""
distribution = distributions.Normal(mean, std)
distribution = distributions.independent.Independent(distribution, 1)
return distribution
def _gru_input(self, prev_state, prev_action, nonterm):
"""Project [action, stoch] into the GRU input space and apply nonterm.
Per the Danijar Dreamer reference, the previous stochastic state is
masked by `nonterm` before being concatenated with the action, and the
previous deterministic state is also masked by `nonterm` when fed back
to the GRU. This ensures that the state is reset (rather than
propagated) at episode boundaries.
"""
prev_stoch_masked = prev_state["stoch"] * nonterm
x = torch.cat([prev_action, prev_stoch_masked], dim=-1)
x = self.act_fn(self.fc_state_action(x))
return x
[docs]
def observe_step(self, prev_state, prev_action, obs_embed, nonterm=1.0):
"""Update state using actual observation (observe mode).
In observe mode, the RSSM first computes a transition prior from the
previous state and action, then refines the stochastic state using the
actual observation embedding to form the posterior.
Args:
prev_state: Dictionary with 'deter' (h_{t-1}) and 'stoch' (s_{t-1})
prev_action: Previous action a_{t-1}, shape (B, action_size)
obs_embed: Observation embedding from encoder, shape (B, obs_embed_size)
nonterm: Termination mask (1.0 = continue, 0.0 = terminal)
Returns:
A tuple ``(posterior, prior)`` of state dictionaries. The posterior
incorporates observation information; the prior is the transition
prediction before observation. Both share the same deterministic
state because the GRU is only advanced once per timestep.
"""
prior = self.imagine_step(prev_state, prev_action, nonterm)
posterior_embed = self.act_fn(
self.fc_embed_posterior(torch.cat([obs_embed, prior["deter"]], dim=-1))
)
posterior = self.fc_state_posterior(posterior_embed)
mean, std = torch.chunk(posterior, 2, dim=-1)
std = F.softplus(std) + 0.1
sample = mean + torch.randn_like(mean) * std
posterior_state = dict(mean=mean, std=std, stoch=sample, deter=prior["deter"])
return posterior_state, prior
[docs]
def imagine_step(self, prev_state, prev_action, nonterm=1.0):
"""Predict next state without observation (imagine mode).
In imagine mode, the RSSM predicts future states using only the prior
distribution. This is used for planning and policy learning where
actual observations are not available.
Args:
prev_state: Dictionary with 'deter' (h_{t-1}) and 'stoch' (s_{t-1})
prev_action: Previous action a_{t-1}, shape (B, action_size)
nonterm: Termination mask (1.0 = continue, 0.0 = terminal)
Returns:
Dictionary with predicted state containing:
- deter: Predicted deterministic state
- mean, std, stoch: Prior stochastic state distribution
"""
x = self._gru_input(prev_state, prev_action, nonterm)
prior_deter = self.rnn(x, prev_state["deter"] * nonterm)
prior_embed = self.act_fn(self.fc_embed_prior(prior_deter))
prior = self.fc_state_prior(prior_embed)
mean, std = torch.chunk(prior, 2, dim=-1)
std = F.softplus(std) + 0.1
sample = mean + torch.randn_like(mean) * std
return dict(mean=mean, std=std, stoch=sample, deter=prior_deter)
[docs]
def get_prior(self, prev_state, prev_action, nonterm=1.0):
"""Compute prior distribution over stochastic state.
The prior represents the model's belief about the stochastic state
before observing the actual outcome.
Args:
prev_state: Previous state dictionary
prev_action: Previous action
nonterm: Termination mask
Returns:
Dictionary with prior state (no observation information)
"""
return self.imagine_step(prev_state, prev_action, nonterm)
[docs]
def get_posterior(self, prev_state, prev_action, obs_embed, nonterm=1.0):
"""Compute posterior distribution over stochastic state.
The posterior incorporates observation information to produce
a more accurate state estimate.
Args:
prev_state: Previous state dictionary
prev_action: Previous action
obs_embed: Observation embedding
nonterm: Termination mask
Returns:
Dictionary with posterior state (observation-informed). Note that
the previous-state shape ``(B, ...)`` is preserved; the batch
dimension is not flattened.
"""
posterior, _ = self.observe_step(prev_state, prev_action, obs_embed, nonterm)
return posterior
[docs]
def detach_state(self, state):
"""Detach state tensors from computation graph.
Used during DreamerV2 training to prevent gradient flow through
the observation/update pathway.
Args:
state: State dictionary with tensor values
Returns:
Detached state dictionary
"""
return {k: v.detach() for k, v in state.items()}
[docs]
def seq_to_batch(self, state_dict):
"""Convert sequence state to batch format.
Args:
state_dict: Dictionary with sequence-dimension tensors (T, B, ...)
Returns:
Dictionary with batch-dimension tensors (B*T, ...)
"""
return {k: v.reshape(-1, *v.shape[2:]) for k, v in state_dict.items()}
[docs]
def observe_rollout(self, obs_embed, actions, nonterms, init_state, seq_len):
"""Process a sequence of observations (observe mode rollout).
At each timestep we run ``observe_step`` once to obtain the transition
prior (the prediction given the previous state and action) and the
observation-informed posterior. The posterior is then used as the
previous state for the next step, matching the standard Dreamer
inference pattern.
Args:
obs_embed: Observation embeddings, shape (T+1, B, obs_embed_size)
actions: Actions, shape (T, B, action_size)
nonterms: Non-termination flags, shape (T, B, 1)
init_state: Initial state dictionary
seq_len: Sequence length T
Returns:
prior: Dictionary with prior states stacked along the time axis.
posterior: Dictionary with posterior states stacked along the time
axis.
"""
prior_states = []
posterior_states = []
state = init_state
for t in range(seq_len):
posterior, prior = self.observe_step(
state, actions[t], obs_embed[t], nonterms[t]
)
posterior_states.append(posterior)
prior_states.append(prior)
state = posterior
to_stack = ["mean", "std", "stoch", "deter"]
prior = {k: torch.stack([p[k] for p in prior_states], dim=0) for k in to_stack}
posterior = {
k: torch.stack([p[k] for p in posterior_states], dim=0) for k in to_stack
}
return prior, posterior
[docs]
def imagine_rollout(self, policy, init_state, horizon):
"""Generate imagined trajectory using policy (imagine mode rollout).
Args:
policy: Actor network that outputs actions from state features
init_state: Initial state dictionary
horizon: Number of steps to imagine
Returns:
Dictionary with imagined states for each step
"""
states = []
state = init_state
for _ in range(horizon):
features = torch.cat([state["stoch"], state["deter"]], dim=-1)
action = policy(features, deter=False)
state = self.imagine_step(state, action)
states.append(state)
to_stack = ["mean", "std", "stoch", "deter"]
return {k: torch.stack([s[k] for s in states], dim=0) for k in to_stack}
[docs]
def forward(self, x, u):
"""Forward pass for training (computes sequence of states).
Args:
x: Observations, shape (B, T+1, C, H, W)
u: Actions, shape (B, T, action_size)
Returns:
states: List of state dictionaries for each timestep
priors: List of prior distributions (tuples of mean, std)
posteriors: List of posterior distributions (tuples of mean, std)
"""
B = x.size(0)
T = u.size(1)
priors = []
posteriors = []
state = self.init_state(B, x.device)
for t in range(T):
prior = self.get_prior(state, u[:, t])
priors.append((prior["mean"], prior["std"]))
obs_embed = x[:, t + 1].reshape(B, -1)
posterior = self.get_posterior(state, u[:, t], obs_embed)
posteriors.append((posterior["mean"], posterior["std"]))
state = posterior
states = None
return states, priors, posteriors