# Source code for world_models.models.rssm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from world_models.vision.planet_encoder import CNNEncoder
from world_models.vision.planet_decoder import CNNDecoder


class RecurrentStateSpaceModel(nn.Module):
    """
    A Recurrent State Space Model (RSSM) for modeling latent dynamics in sequential data.

    The model keeps two pieces of state per batch element:

    - a deterministic state ``h`` of width ``state_size`` carried by a GRU cell;
    - a stochastic latent ``s`` of width ``latent_size`` described by a
      diagonal Gaussian (mean, std).

    The *prior* over ``s`` is predicted from ``h`` alone, the *posterior* from
    ``h`` concatenated with an observation embedding.  A small MLP predicts a
    scalar reward from ``(h, s)``.
    """

    # Constant added to the softplus output so a predicted std is never below it.
    _MIN_STD = 0.1
    # Upper bound applied to every predicted std.
    _MAX_STD = 10.0

    def __init__(
        self,
        action_size,
        state_size=200,
        latent_size=30,
        hidden_size=200,
        embed_size=1024,
        activation_function="relu",
    ):
        """
        Build all sub-modules.

        Args:
            action_size: Width of the action vector fed to the transition.
            state_size: Width of the deterministic (GRU) state.
            latent_size: Width of the stochastic latent state.
            hidden_size: Width of the hidden layers of the prior, posterior
                and reward heads.
            embed_size: Width of the observation embedding; passed to the
                encoder/decoder and used as input width of the posterior head.
            activation_function: Name of a function in ``torch.nn.functional``
                (looked up with ``getattr``).
        """
        super().__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.latent_size = latent_size
        self.act_fn = getattr(F, activation_function)

        # NOTE: sub-modules are created in the original order so parameter
        # ordering and seeded initialisation stay unchanged.
        self.encoder = CNNEncoder(embed_size)
        self.decoder = CNNDecoder(state_size, latent_size, embed_size)
        self.grucell = nn.GRUCell(state_size, state_size)
        self.lat_act_layer = nn.Linear(latent_size + action_size, state_size)

        self.fc_prior_1 = nn.Linear(state_size, hidden_size)
        self.fc_prior_m = nn.Linear(hidden_size, latent_size)
        self.fc_prior_s = nn.Linear(hidden_size, latent_size)

        self.fc_posterior_1 = nn.Linear(state_size + embed_size, hidden_size)
        self.fc_posterior_m = nn.Linear(hidden_size, latent_size)
        self.fc_posterior_s = nn.Linear(hidden_size, latent_size)

        self.fc_reward_1 = nn.Linear(state_size + latent_size, hidden_size)
        self.fc_reward_2 = nn.Linear(hidden_size, hidden_size)
        self.fc_reward_3 = nn.Linear(hidden_size, 1)

    def _gaussian_head(self, features, mean_layer, std_layer, sample):
        """Turn hidden features into Gaussian (mean, std), or one sample of it."""
        mean = mean_layer(features)
        # softplus(.) >= 0, so after adding _MIN_STD only the upper clamp can bind.
        std = F.softplus(std_layer(features)) + self._MIN_STD
        std = torch.clamp(std, max=self._MAX_STD)
        if sample:
            # Reparameterised draw: standard-normal noise scaled by std.
            return mean + torch.randn_like(mean) * std
        return mean, std

    def get_init_state(self, enc, h_t=None, s_t=None, a_t=None, mean=False):
        """
        Returns the initial posterior given the observation.

        Advances the deterministic state one step with ``(h_t, s_t, a_t)`` and
        infers the stochastic latent from the *updated* deterministic state and
        the observation embedding, matching what ``forward`` does per step.

        Args:
            enc: Observation embedding, ``[B, embed_size]``.
            h_t: Previous deterministic state ``[B, state_size]``; zeros if None.
            s_t: Previous stochastic latent ``[B, latent_size]``; zeros if None.
            a_t: Previous action ``[B, action_size]``; zeros if None.
            mean: If True return the posterior mean, otherwise a sample from
                the posterior.

        Returns:
            Tuple ``(h_tp1, s_tp1)`` of the new deterministic state and latent.
        """
        batch = enc.size(0)

        def _zeros(width):
            return torch.zeros(batch, width, device=enc.device, dtype=enc.dtype)

        if h_t is None:
            h_t = _zeros(self.state_size)
        if s_t is None:
            s_t = _zeros(self.latent_size)
        if a_t is None:
            a_t = _zeros(self.action_size)

        h_tp1 = self.deterministic_state_fwd(h_t, s_t, a_t)
        if mean:
            s_tp1, _ = self.state_posterior(h_tp1, enc)
        else:
            s_tp1 = self.state_posterior(h_tp1, enc, sample=True)
        return h_tp1, s_tp1

    def deterministic_state_fwd(self, h_t, s_t, a_t):
        """
        Deterministic transition update.

        Accepted action layouts (``B`` is ``h_t.size(0)``):

        - ``[B, action_size]``: used as is;
        - ``[1, action_size]`` or ``[action_size]`` (unbatched): expanded to
          ``[B, action_size]``;
        - ``[B]`` when ``action_size == 1``: one scalar action per batch
          element, reshaped to ``[B, 1]``;
        - scalar: reshaped to ``[1, 1]`` and expanded over the batch.

        Non-tensor actions are converted, and the action is always moved to the
        device and dtype of ``h_t``.

        Args:
            h_t: Deterministic state ``[B, state_size]``.
            s_t: Stochastic latent ``[B, latent_size]``.
            a_t: Action in one of the layouts above.

        Returns:
            The next deterministic state ``[B, state_size]``.

        Raises:
            ValueError: If a 2-D action has a batch size that is neither 1 nor ``B``.
        """
        batch = h_t.size(0)

        if not isinstance(a_t, torch.Tensor):
            a_t = torch.tensor(a_t, dtype=h_t.dtype, device=h_t.device)
        a_t = a_t.to(device=h_t.device, dtype=h_t.dtype)

        if a_t.dim() == 0:
            a_t = a_t.view(1, 1)
        elif a_t.dim() == 1:
            # A length-B vector is only a batch of scalar actions when the
            # action really is one-dimensional; otherwise it is a single
            # action vector (this matters when B == action_size).
            if self.action_size == 1 and a_t.numel() == batch:
                a_t = a_t.view(-1, 1)
            else:
                a_t = a_t.view(1, -1)

        if a_t.dim() == 2 and a_t.size(0) != batch:
            if a_t.size(0) != 1:
                raise ValueError(
                    f"Action batch size {a_t.size(0)} != state batch size {h_t.size(0)}"
                )
            a_t = a_t.expand(batch, -1)

        # Concatenate stochastic state and action, project to state width,
        # then let the GRU cell update the deterministic state.
        joint = torch.cat([s_t, a_t], dim=1)
        projected = self.act_fn(self.lat_act_layer(joint))
        return self.grucell(projected, h_t)

    def state_prior(self, h_t, sample=False):
        """
        Returns the prior distribution over the latent state given the deterministic state.

        Args:
            h_t: Deterministic state ``[B, state_size]``.
            sample: If True return one sample instead of the parameters.

        Returns:
            ``(mean, std)`` when ``sample`` is False, otherwise a sampled latent;
            every tensor has last dimension ``latent_size``.
        """
        hidden = self.act_fn(self.fc_prior_1(h_t))
        return self._gaussian_head(hidden, self.fc_prior_m, self.fc_prior_s, sample)

    def state_posterior(self, h_t, e_t, sample=False):
        """
        Returns the state posterior given the deterministic state and obs.

        Args:
            h_t: Deterministic state ``[B, state_size]``.
            e_t: Observation embedding ``[B, embed_size]``.
            sample: If True return one sample instead of the parameters.

        Returns:
            ``(mean, std)`` when ``sample`` is False, otherwise a sampled latent;
            every tensor has last dimension ``latent_size``.
        """
        hidden = self.act_fn(self.fc_posterior_1(torch.cat([h_t, e_t], dim=1)))
        return self._gaussian_head(
            hidden, self.fc_posterior_m, self.fc_posterior_s, sample
        )

    def pred_reward(self, h_t, s_t):
        """
        Predict the reward from deterministic state and stochastic latent.

        Args:
            h_t: Deterministic state, last dimension ``state_size``.
            s_t: Stochastic latent, last dimension ``latent_size``.

        Returns:
            Reward tensor with the inputs' leading dimensions; only the
            trailing size-1 feature dimension is removed, so a batch of one
            keeps its batch dimension.
        """
        r = self.act_fn(self.fc_reward_1(torch.cat([h_t, s_t], dim=-1)))
        r = self.act_fn(self.fc_reward_2(r))
        return self.fc_reward_3(r).squeeze(-1)

    def rollout_prior(self, act, h_t, s_t):
        """
        Roll the model forward under the prior for a sequence of actions.

        At each step the deterministic state is advanced and the next latent
        is sampled from the prior (no observations are used).

        Args:
            act: Actions, iterated along dim 0 (one entry per time step).
            h_t: Initial deterministic state ``[B, state_size]``.
            s_t: Initial stochastic latent ``[B, latent_size]``.

        Returns:
            Tuple ``(states, latents)`` stacked along a new leading time
            dimension: ``[T, B, state_size]`` and ``[T, B, latent_size]``.
        """
        states, latents = [], []
        for a_t in torch.unbind(act, dim=0):
            h_t = self.deterministic_state_fwd(h_t, s_t, a_t)
            s_t = self.state_prior(h_t, sample=True)
            states.append(h_t)
            latents.append(s_t)
        return torch.stack(states), torch.stack(latents)

    def forward(self, x, u):
        """
        Forward through the RSSM for a batch of sequences.

        Inputs:
            x: Tensor [B, T+1, C, H, W]  (observations including initial frame)
            u: Tensor [B, T, action_size] (actions for T steps)
        Returns:
            states: list[T] of tensors [B, state_size]
            priors: list[T] of tuples (mean, std) each [B, latent_size]
            posteriors: list[T] of tuples (mean, std) each [B, latent_size]

        The posterior *mean* (not a sample) is fed back as the stochastic
        latent for the next step.
        """
        B = x.size(0)
        T = u.size(1)

        # Encode all frames at once: [B*(T+1), C, H, W] -> [B, T+1, embed_size].
        # reshape (not view) so non-contiguous inputs are accepted.
        frames = x.reshape(B * (T + 1), *x.shape[2:])
        e = self.encoder(frames).view(B, T + 1, -1)

        h = torch.zeros(B, self.state_size, device=x.device, dtype=e.dtype)
        s = torch.zeros(B, self.latent_size, device=x.device, dtype=e.dtype)

        states, priors, posteriors = [], [], []
        for t in range(T):
            h = self.deterministic_state_fwd(h, s, u[:, t])
            prior = self.state_prior(h)  # (mean, std)
            posterior = self.state_posterior(h, e[:, t + 1])  # (mean, std)

            states.append(h)
            priors.append(prior)
            posteriors.append(posterior)

            s = posterior[0]
        return states, priors, posteriors