Source code for world_models.controller.rssm_policy

import torch
from torch.distributions import Normal


class RSSMPolicy:
    """Model-predictive controller using Cross-Entropy Method (CEM) with RSSM.

    Plans actions by optimizing a sequence of future actions in the RSSM's
    latent space. Uses the Cross-Entropy Method to refine action sequences
    based on predicted returns.

    Algorithm:
        1. Initialize a Gaussian distribution over action sequences.
        2. Sample N candidate action sequences.
        3. Roll out each sequence in RSSM latent space.
        4. Score by predicted cumulative rewards.
        5. Keep the top K candidates and fit a Gaussian to them.
        6. Repeat for T iterations.
        7. Execute the first action of the final mean sequence.

    Why latent space planning?
        - Images are high-dimensional; latent states are compact.
        - Enables thousands of rollouts in parallel.
        - Dynamics model is more accurate in latent space.

    Args:
        model: RSSM instance for latent dynamics. It must expose
            ``action_size``, ``state_size``, ``latent_size``, ``encoder``,
            ``get_init_state``, ``deterministic_state_fwd``, ``state_prior``
            and ``pred_reward``.
        planning_horizon: Number of future steps to plan (H), at least 1.
        num_candidates: Number of action sequences to sample (N).
        num_iterations: CEM refinement iterations (T).
        top_candidates: Number of best candidates to keep (K), with
            ``1 <= K <= N``.
        device: torch device.

    Raises:
        ValueError: If ``planning_horizon < 1`` or ``top_candidates`` is not
            in ``[1, num_candidates]``.

    Usage with Planet agent:
        policy = RSSMPolicy(
            model=rssm,
            planning_horizon=12,
            num_candidates=1000,
            num_iterations=8,
            top_candidates=100,
            device='cuda'
        )
        policy.reset()
        action = policy.poll(observation)  # (1, action_dim)

        # For continuous control:
        next_obs, reward, done, info = env.step(action)

    Comparison with Dreamer:
        - RSSMPolicy: Online planning, chooses actions by optimization at
          each step.
        - DreamerActor: Train actor network to predict actions from states.
        - Dreamer is more sample-efficient for complex tasks; CEM is more
          flexible.
    """

    # Floor for the refitted standard deviation. When the elite set has a
    # single member (K == 1) or all elites agree, the fitted std is exactly 0
    # and Normal() rejects a non-positive scale on the next CEM iteration.
    _MIN_STDDEV = 1e-6

    def __init__(
        self,
        model,
        planning_horizon,
        num_candidates,
        num_iterations,
        top_candidates,
        device,
    ):
        super().__init__()
        # Fail at construction time rather than deep inside torch.topk or
        # with an empty action during planning.
        if planning_horizon < 1:
            raise ValueError(
                f"planning_horizon must be >= 1, got {planning_horizon}"
            )
        if not 1 <= top_candidates <= num_candidates:
            raise ValueError(
                "top_candidates must satisfy 1 <= top_candidates <= "
                f"num_candidates, got top_candidates={top_candidates}, "
                f"num_candidates={num_candidates}"
            )

        self.rssm = model
        self.N = num_candidates
        self.K = top_candidates
        self.T = num_iterations
        self.H = planning_horizon
        self.d = self.rssm.action_size
        self.device = device
        self.state_size = self.rssm.state_size
        self.latent_size = self.rssm.latent_size

    def reset(self):
        """Zero the deterministic state, latent state and previous action.

        Must be called before the first ``poll`` (the planner reads ``h``,
        ``s`` and ``a``, which only exist after a reset).
        """
        self.h = torch.zeros(1, self.state_size, device=self.device)
        self.s = torch.zeros(1, self.latent_size, device=self.device)
        self.a = torch.zeros(1, self.d, device=self.device)

    def _poll(self, obs):
        """Update the belief from ``obs`` and plan the next action with CEM.

        Stores the planned action in ``self.a`` (shape ``(1, action_size)``)
        and leaves the final sampling distribution in ``self.mu`` and
        ``self.stddev`` (each ``(H, action_size)``).

        Args:
            obs: Single observation with exactly three dimensions ([CHW]).

        Raises:
            ValueError: If ``obs`` does not have exactly three dimensions.
        """
        # observation could be of shape [CHW] but only 1 timestep
        if len(obs.shape) != 3:
            raise ValueError("obs should be [CHW]")

        self.mu = torch.zeros(self.H, self.d, device=self.device)
        self.stddev = torch.ones(self.H, self.d, device=self.device)

        # Fold the new observation and the previously stored action into the
        # recurrent state before planning from it.
        self.h, self.s = self.rssm.get_init_state(
            self.rssm.encoder(obs[None]), self.h, self.s, self.a
        )

        for _ in range(self.T):
            rwds = torch.zeros(self.N, device=self.device)
            # (N, H, d): one action sequence per candidate.
            actions = Normal(self.mu, self.stddev).sample((self.N,))
            h_t = self.h.clone().expand(self.N, -1)
            s_t = self.s.clone().expand(self.N, -1)

            # Step all candidates through the horizon in parallel,
            # accumulating the predicted reward of each sequence.
            for a_t in torch.unbind(actions, dim=1):
                h_t = self.rssm.deterministic_state_fwd(h_t, s_t, a_t)
                s_t = self.rssm.state_prior(h_t, sample=True)
                rwds += self.rssm.pred_reward(h_t, s_t).squeeze(-1)

            # Refit the sampling distribution to the K best sequences.
            _, k = torch.topk(rwds, self.K, dim=0, largest=True, sorted=False)
            elites = actions[k]
            self.mu = elites.mean(dim=0)
            self.stddev = elites.std(dim=0, unbiased=False).clamp_min(
                self._MIN_STDDEV
            )

        # Copy so that later changes to the action cannot write through into
        # the planner's mean.
        self.a = self.mu[0:1].clone()

    def poll(self, observation, explore=False, *, noise_std=0.3):
        """Plan and return the next action for ``observation``.

        Planning runs under ``torch.no_grad()``.

        Args:
            observation: Single observation with three dimensions ([CHW]).
            explore: If True, add zero-mean Gaussian noise to the planned
                action.
            noise_std: Standard deviation of the exploration noise.

        Returns:
            Tensor of shape ``(1, action_size)``. The same tensor is kept in
            ``self.a`` (noise included) and is passed to the model's state
            update on the next call.

        Raises:
            ValueError: If ``observation`` does not have exactly three
                dimensions.
        """
        with torch.no_grad():
            self._poll(observation)
            if explore:
                # Out-of-place so the noise never touches self.mu.
                self.a = self.a + torch.randn_like(self.a) * noise_std
        return self.a