Source code for world_models.controller.rssm_policy
"""RSSM-based policy for model-predictive control.
This module provides the RSSMPolicy class that implements model-predictive control
using the RSSM (Recurrent State Space Model) latent dynamics model. The policy uses
a Cross-Entropy Method (CEM) for planning actions in latent space.
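Reference:
Hafner et al. (2019). Learning Latent Dynamics for Planning from Pixels (PlaNet),
which introduced the RSSM and the latent-space CEM planner implemented here.
https://arxiv.org/abs/1811.04551
See also: Ha & Schmidhuber (2018). Recurrent World Models Facilitate Policy Evolution.
https://arxiv.org/abs/1809.01999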
"""
import torch
from torch.distributions import Normal
[docs]
class RSSMPolicy:
    """Model-predictive controller that plans actions with the RSSM latent model.

    The policy uses a Cross-Entropy Method (CEM) style loop: it samples
    candidate action sequences from a per-step diagonal Gaussian, rolls them
    forward in latent space with the model's prior, scores the summed
    predicted rewards, and refits the Gaussian to the top-scoring candidates.

    Attributes:
        rssm: The RSSM world model.
        N: Number of candidate action sequences to sample.
        K: Number of top candidates used to refit the proposal.
        T: Number of CEM iterations per planning step.
        H: Planning horizon (number of future steps to consider).
        d: Action dimensionality (``rssm.action_size``).
        device: Device to run computations on.
        state_size: Hidden state dimensionality (``rssm.state_size``).
        latent_size: Latent state dimensionality (``rssm.latent_size``).
        min_stddev: Lower bound applied to the refitted proposal stddev.
        h, s, a: Hidden state, latent state and last action carried between
            calls to ``poll``; ``reset`` initializes them with a leading
            batch dimension of 1.
        mu, stddev: Per-step proposal mean and stddev of shape (H, d) from
            the most recent planning call.

    Example:
        >>> policy = RSSMPolicy(
        ...     model=rssm,
        ...     planning_horizon=12,
        ...     num_candidates=1000,
        ...     num_iterations=5,
        ...     top_candidates=100,
        ...     device='cuda'
        ... )
        >>> policy.reset()
        >>> action = policy.poll(observation)
    """

    def __init__(
        self,
        model,
        planning_horizon: int,
        num_candidates: int,
        num_iterations: int,
        top_candidates: int,
        device: str,
        min_stddev: float = 1e-6,
    ):
        """Initialize the RSSM policy and reset its recurrent state.

        Args:
            model: The RSSM world model. Must expose ``action_size``,
                ``state_size``, ``latent_size``, ``encoder``,
                ``get_init_state``, ``deterministic_state_fwd``,
                ``state_prior`` and ``pred_reward``.
            planning_horizon: Number of future steps to plan ahead (>= 1).
            num_candidates: Number of candidate action sequences to sample.
            num_iterations: Number of CEM iterations per planning step.
            top_candidates: Number of top candidates kept for refitting;
                must lie in ``[1, num_candidates]``.
            device: Device to run computations on.
            min_stddev: Floor for the refitted proposal standard deviation.
                Without it the stddev can collapse to exactly zero (e.g. when
                ``top_candidates == 1``), which ``torch.distributions.Normal``
                rejects on the next CEM iteration.

        Raises:
            ValueError: If ``planning_horizon < 1``, ``top_candidates`` is not
                in ``[1, num_candidates]``, or ``min_stddev <= 0``.
        """
        if planning_horizon < 1:
            raise ValueError(
                f"planning_horizon must be >= 1, got {planning_horizon}"
            )
        if not 1 <= top_candidates <= num_candidates:
            raise ValueError(
                f"top_candidates must be in [1, num_candidates={num_candidates}], "
                f"got {top_candidates}"
            )
        if min_stddev <= 0:
            raise ValueError(f"min_stddev must be > 0, got {min_stddev}")
        super().__init__()
        self.rssm = model
        self.N = num_candidates
        self.K = top_candidates
        self.T = num_iterations
        self.H = planning_horizon
        self.d = self.rssm.action_size
        self.device = device
        self.state_size = self.rssm.state_size
        self.latent_size = self.rssm.latent_size
        self.min_stddev = min_stddev
        # Make the policy usable without an explicit reset() call.
        self.reset()

    def reset(self):
        """Reset the recurrent policy state.

        Sets the hidden state, latent state, and last action to zeros.
        Should be called at the beginning of each episode.
        """
        self.h = torch.zeros(1, self.state_size, device=self.device)
        self.s = torch.zeros(1, self.latent_size, device=self.device)
        self.a = torch.zeros(1, self.d, device=self.device)

    def _poll(self, obs):
        """Run CEM planning from ``obs`` and store the first planned action.

        Updates ``self.h``/``self.s`` from the observation, optimizes the
        proposal ``self.mu``/``self.stddev`` for ``self.T`` iterations, and
        sets ``self.a`` to the first step of the final mean, shape (1, d).

        Args:
            obs: Current observation tensor of shape (channels, height, width).

        Raises:
            ValueError: If ``obs`` is not 3-dimensional.
        """
        if len(obs.shape) != 3:
            raise ValueError("obs should be [CHW]")
        # Start every planning step from a fresh standard-normal proposal.
        self.mu = torch.zeros(self.H, self.d, device=self.device)
        self.stddev = torch.ones(self.H, self.d, device=self.device)
        # Update h and s from the encoded observation and the previous action.
        obs = obs.to(self.device)
        self.h, self.s = self.rssm.get_init_state(
            self.rssm.encoder(obs[None]), self.h, self.s, self.a
        )
        for _ in range(self.T):
            # Candidate action sequences, shape (N, H, d).
            actions = Normal(self.mu, self.stddev).sample((self.N,))
            # expand() is a view; h_t/s_t are rebound each step, never
            # modified in place, so no clone is needed.
            h_t = self.h.expand(self.N, -1)
            s_t = self.s.expand(self.N, -1)
            returns = torch.zeros(self.N, device=self.device)
            for a_t in actions.unbind(dim=1):
                h_t = self.rssm.deterministic_state_fwd(h_t, s_t, a_t)
                s_t = self.rssm.state_prior(h_t, sample=True)
                returns += self.rssm.pred_reward(h_t, s_t).squeeze(-1)
            _, top = torch.topk(returns, self.K, dim=0, largest=True, sorted=False)
            elite = actions[top]
            self.mu = elite.mean(dim=0)
            # Floor the stddev so Normal(...) stays valid next iteration.
            self.stddev = elite.std(dim=0, unbiased=False).clamp_min(self.min_stddev)
        self.a = self.mu[0:1]

    def poll(self, observation: torch.Tensor, explore: bool = False) -> torch.Tensor:
        """Get action for given observation.

        Args:
            observation: Current observation tensor of shape
                (channels, height, width).
            explore: If True, add Gaussian noise (std 0.3) to the selected
                action.

        Returns:
            Action tensor of shape (1, action_size). It is also stored as
            ``self.a`` and fed to the model on the next call.

        Raises:
            ValueError: If ``observation`` is not 3-dimensional.
        """
        with torch.no_grad():
            self._poll(observation)
            if explore:
                # Out-of-place so the noise does not leak into self.mu, which
                # self.a is a view of; the noisy action is what gets stored
                # and conditioned on at the next step.
                self.a = self.a + torch.randn_like(self.a) * 0.3
        return self.a