Source code for world_models.models.dreamer_rssm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as distributions

_str_to_activation = {
    "relu": nn.ReLU(),
    "elu": nn.ELU(),
    "tanh": nn.Tanh(),
    "leaky_relu": nn.LeakyReLU(),
    "sigmoid": nn.Sigmoid(),
    "selu": nn.SELU(),
    "softplus": nn.Softplus(),
    "identity": nn.Identity(),
}


class RSSM(nn.Module):
    """Recurrent State-Space Model used by Dreamer for latent dynamics learning.

    The RSSM maintains a deterministic recurrent state and a stochastic latent
    state, and provides transition/posterior updates plus rollout helpers.
    """

    def __init__(
        self,
        action_size,
        stoch_size,
        deter_size,
        hidden_size,
        obs_embed_size,
        activation,
    ):
        super().__init__()

        self.action_size = action_size
        self.stoch_size = stoch_size
        self.deter_size = deter_size  # GRU hidden units
        self.hidden_size = hidden_size  # intermediate fc-layer hidden units
        self.embedding_size = obs_embed_size
        self.act_fn = _str_to_activation[activation]

        self.rnn = nn.GRUCell(self.deter_size, self.deter_size)
        self.fc_state_action = nn.Linear(
            self.stoch_size + self.action_size, self.deter_size
        )
        self.fc_embed_prior = nn.Linear(self.deter_size, self.hidden_size)
        self.fc_state_prior = nn.Linear(self.hidden_size, 2 * self.stoch_size)
        self.fc_embed_posterior = nn.Linear(
            self.embedding_size + self.deter_size, self.hidden_size
        )
        self.fc_state_posterior = nn.Linear(self.hidden_size, 2 * self.stoch_size)
    def init_state(self, batch_size, device):
        """Return an all-zero initial RSSM state for a batch."""
        return dict(
            mean=torch.zeros(batch_size, self.stoch_size, device=device),
            std=torch.zeros(batch_size, self.stoch_size, device=device),
            stoch=torch.zeros(batch_size, self.stoch_size, device=device),
            deter=torch.zeros(batch_size, self.deter_size, device=device),
        )
    def get_dist(self, mean, std):
        # Diagonal Gaussian over the stochastic state; Independent absorbs the
        # last axis into the event shape so log-probs and KLs sum over it.
        distribution = distributions.Normal(mean, std)
        distribution = distributions.independent.Independent(distribution, 1)
        return distribution
    def observe_step(self, prev_state, prev_action, obs_embed, nonterm=1.0):
        # One filtering step: predict the prior from the previous state and
        # action, then correct it using the observation embedding.
        prior = self.imagine_step(prev_state, prev_action, nonterm)

        posterior_embed = self.act_fn(
            self.fc_embed_posterior(torch.cat([obs_embed, prior["deter"]], dim=-1))
        )
        posterior = self.fc_state_posterior(posterior_embed)
        mean, std = torch.chunk(posterior, 2, dim=-1)
        std = F.softplus(std) + 0.1  # keep the standard deviation positive
        sample = mean + torch.randn_like(mean) * std

        posterior = {"mean": mean, "std": std, "stoch": sample, "deter": prior["deter"]}
        return prior, posterior
    def imagine_step(self, prev_state, prev_action, nonterm=1.0):
        # One transition step without an observation: advance the deterministic
        # state with the GRU, then sample the stochastic state from the prior.
        state_action = self.act_fn(
            self.fc_state_action(
                torch.cat([prev_state["stoch"] * nonterm, prev_action], dim=-1)
            )
        )
        deter = self.rnn(state_action, prev_state["deter"] * nonterm)

        prior_embed = self.act_fn(self.fc_embed_prior(deter))
        mean, std = torch.chunk(self.fc_state_prior(prior_embed), 2, dim=-1)
        std = F.softplus(std) + 0.1  # keep the standard deviation positive
        sample = mean + torch.randn_like(mean) * std

        prior = {"mean": mean, "std": std, "stoch": sample, "deter": deter}
        return prior
    def observe_rollout(self, obs_embed, actions, nonterms, prev_state, horizon):
        priors = []
        posteriors = []
        for t in range(horizon):
            prev_action = actions[t] * nonterms[t]
            prior_state, posterior_state = self.observe_step(
                prev_state, prev_action, obs_embed[t], nonterms[t]
            )
            priors.append(prior_state)
            posteriors.append(posterior_state)
            prev_state = posterior_state

        priors = self.stack_states(priors, dim=0)
        posteriors = self.stack_states(posteriors, dim=0)
        return priors, posteriors
    def imagine_rollout(self, actor, prev_state, horizon):
        rssm_state = prev_state
        next_states = []
        for t in range(horizon):
            action = actor(
                torch.cat([rssm_state["stoch"], rssm_state["deter"]], dim=-1).detach()
            )
            rssm_state = self.imagine_step(rssm_state, action)
            next_states.append(rssm_state)

        next_states = self.stack_states(next_states)
        return next_states
    def stack_states(self, states, dim=0):
        return dict(
            mean=torch.stack([state["mean"] for state in states], dim=dim),
            std=torch.stack([state["std"] for state in states], dim=dim),
            stoch=torch.stack([state["stoch"] for state in states], dim=dim),
            deter=torch.stack([state["deter"] for state in states], dim=dim),
        )
    def detach_state(self, state):
        return dict(
            mean=state["mean"].detach(),
            std=state["std"].detach(),
            stoch=state["stoch"].detach(),
            deter=state["deter"].detach(),
        )
    def seq_to_batch(self, state):
        # Merge the leading (sequence, batch) dimensions of every tensor in the
        # state dict into a single batch dimension.
        return {
            key: torch.reshape(
                state[key],
                (state[key].shape[0] * state[key].shape[1], *state[key].shape[2:]),
            )
            for key in ("mean", "std", "stoch", "deter")
        }
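

A minimal usage sketch follows (not part of the module above): it runs one posterior rollout on random dummy tensors, computes the posterior-to-prior KL term that Dreamer-style training typically minimizes, and then imagines forward with a stand-in linear actor. All sizes, shapes, and the linear actor are illustrative assumptions, not values defined by this module.

if __name__ == "__main__":
    # Illustrative sizes only; the module does not prescribe any of these.
    horizon, batch, action_size, embed_size = 10, 4, 6, 128
    device = "cpu"

    rssm = RSSM(
        action_size=action_size,
        stoch_size=30,
        deter_size=200,
        hidden_size=200,
        obs_embed_size=embed_size,
        activation="elu",
    )

    # Dummy (horizon, batch, feature) sequences standing in for encoder
    # outputs, actions, and non-terminal masks from a replay buffer.
    obs_embed = torch.randn(horizon, batch, embed_size, device=device)
    actions = torch.randn(horizon, batch, action_size, device=device)
    nonterms = torch.ones(horizon, batch, 1, device=device)

    prev_state = rssm.init_state(batch, device)
    priors, posteriors = rssm.observe_rollout(
        obs_embed, actions, nonterms, prev_state, horizon
    )

    # KL(posterior || prior) is the usual latent dynamics loss term.
    kl = distributions.kl_divergence(
        rssm.get_dist(posteriors["mean"], posteriors["std"]),
        rssm.get_dist(priors["mean"], priors["std"]),
    ).mean()
    print("mean KL:", kl.item())

    # Imagination from the filtered states, using a random linear layer as a
    # stand-in for a learned actor over the concatenated (stoch, deter) features.
    actor = nn.Linear(30 + 200, action_size)
    start = rssm.detach_state(rssm.seq_to_batch(posteriors))
    imagined = rssm.imagine_rollout(actor, start, horizon=5)
    print("imagined stoch shape:", tuple(imagined["stoch"].shape))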