# Source code for world_models.vision.dreamer_decoder

import torch
import torch.nn as nn
from torch.distributions import Normal, TransformedDistribution, Bernoulli
from torch.distributions.independent import Independent
import numpy as np
import torch.distributions as distributions
from torch.distributions import constraints
import torch.nn.functional as F


_str_to_activation = {
    "relu": nn.ReLU(),
    "elu": nn.ELU(),
    "tanh": nn.Tanh(),
    "leaky_relu": nn.LeakyReLU(),
    "sigmoid": nn.Sigmoid(),
    "selu": nn.SELU(),
    "softplus": nn.Softplus(),
    "identity": nn.Identity(),
}


class TanhBijector(distributions.Transform):
    """Bijective `tanh` transform used to squash Gaussian actions to `[-1, 1]`.

    Provides stable inverse and log-determinant Jacobian computations for
    transformed-action distributions.
    """

    def __init__(self):
        super().__init__()
        self.bijective = True
        self.domain = constraints.real
        self.codomain = constraints.interval(-1.0, 1.0)

    @property
    def sign(self):
        # tanh is monotonically increasing everywhere.
        return 1.0

    def _call(self, x):
        return torch.tanh(x)

    def atanh(self, x):
        # Inverse hyperbolic tangent, written out explicitly.
        return 0.5 * torch.log((1 + x) / (1 - x))

    def _inverse(self, y: torch.Tensor):
        # Clamp in-range values strictly inside the open interval so that
        # atanh never sees exactly +/-1 (which would yield +/-inf);
        # out-of-range values are passed through untouched.
        in_range = torch.abs(y) <= 1.0
        safe = torch.where(in_range, torch.clamp(y, -0.99999997, 0.99999997), y)
        return self.atanh(safe)

    def log_abs_det_jacobian(self, x, y):
        # Numerically stable form of log(1 - tanh(x)^2) = log|dy/dx|.
        return 2.0 * (np.log(2) - x - F.softplus(-2.0 * x))
class ConvDecoder(nn.Module):
    """Convolutional Dreamer decoder from latent features to image distributions.

    Maps concatenated stochastic/deterministic states into transposed-conv
    outputs and returns a factorized Normal distribution over pixels.
    """

    def __init__(self, stoch_size, deter_size, output_shape, activation, depth=32):
        super().__init__()
        self.output_shape = output_shape
        self.depth = depth
        self.kernels = [5, 5, 6, 6]
        self.act_fn = _str_to_activation[activation]

        # Project the latent vector to the 32*depth-channel 1x1 feature map
        # that seeds the transposed-convolution stack.
        self.dense = nn.Linear(stoch_size + deter_size, 32 * self.depth)

        # Channel schedule: 32*depth at the input, then halving powers of two
        # down to `depth`, and finally the image channel count.
        n = len(self.kernels)
        channels = [32 * self.depth]
        channels += [self.depth * (2 ** e) for e in range(n - 2, -1, -1)]
        channels += [output_shape[0]]

        layers = []
        for i, (kernel, c_in, c_out) in enumerate(
            zip(self.kernels, channels[:-1], channels[1:])
        ):
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel, stride=2))
            # No activation after the last layer: it parameterizes the mean.
            if i < n - 1:
                layers.append(self.act_fn)
        self.convtranspose = nn.Sequential(*layers)

    def forward(self, features):
        """Decode `features` of shape (..., stoch+deter) into a pixel distribution."""
        batch_dims = features.shape[:-1]
        hidden = self.dense(features)
        hidden = hidden.reshape(-1, 32 * self.depth, 1, 1)
        decoded = self.convtranspose(hidden)
        mean = decoded.reshape(*batch_dims, *self.output_shape)
        # Unit-variance Normal with the image dims treated as one event.
        return Independent(Normal(mean, 1), len(self.output_shape))
class DenseDecoder(nn.Module):
    """MLP decoder for reward/value/discount prediction from latent features.

    The output distribution type is configurable (`normal`, `binary`, or raw
    tensor).
    """

    def __init__(
        self, stoch_size, deter_size, output_shape, n_layers, units, activation, dist
    ):
        super().__init__()
        self.input_size = stoch_size + deter_size
        self.output_shape = output_shape
        self.n_layers = n_layers
        self.units = units
        self.act_fn = _str_to_activation[activation]
        self.dist = dist

        # Hidden widths: input once, then `units` repeated n_layers times.
        widths = [self.input_size] + [self.units] * self.n_layers
        layers = []
        for w_in, w_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(w_in, w_out))
            layers.append(self.act_fn)
        # Final projection to the flattened output size (no activation).
        layers.append(nn.Linear(self.units, int(np.prod(self.output_shape))))
        self.model = nn.Sequential(*layers)

    def forward(self, features):
        """Return the configured distribution (or raw tensor) over the outputs.

        Raises:
            NotImplementedError: if `self.dist` is not one of
                "normal", "binary", or "none".
        """
        out = self.model(features)
        event_dims = len(self.output_shape)
        if self.dist == "normal":
            # Fixed unit variance; the network predicts only the mean.
            return Independent(Normal(out, 1), event_dims)
        if self.dist == "binary":
            return Independent(Bernoulli(logits=out), event_dims)
        if self.dist == "none":
            return out
        raise NotImplementedError(self.dist)
class SampleDist:
    """Distribution wrapper that estimates statistics via Monte Carlo sampling.

    Provides approximated `mean`, `mode`, and `entropy` helpers for transformed
    distributions where analytic forms may be inconvenient.

    Args:
        dist: wrapped torch distribution; must support `rsample`, `log_prob`,
            and `expand`.
        samples: number of Monte Carlo samples used by the estimators.
    """

    def __init__(self, dist, samples=100):
        self._dist = dist
        self._samples = samples

    @property
    def name(self):
        return "SampleDist"

    def __getattr__(self, name):
        # Delegate anything not defined here (e.g. `rsample`, `log_prob`,
        # `batch_shape`) to the wrapped distribution.
        return getattr(self._dist, name)

    def mean(self):
        """Monte Carlo estimate of the mean over `self._samples` draws.

        Bug fix: `rsample` expects a sample *shape* (an iterable), not a bare
        int — `rsample(self._samples)` raised a TypeError inside torch's
        `_extended_shape`. Wrap the count in a 1-tuple.
        """
        sample = self._dist.rsample((self._samples,))
        return torch.mean(sample, 0)

    def mode(self):
        """Approximate mode: the drawn sample with the highest log-probability.

        NOTE(review): assumes samples are rank-3 `(samples, batch, feature)`
        after expansion — confirm before using with other event shapes.
        """
        dist = self._dist.expand((self._samples, *self._dist.batch_shape))
        sample = dist.rsample()
        logprob = dist.log_prob(sample)
        batch_size = sample.size(1)
        feature_size = sample.size(2)
        # Pick, per batch element, the sample index with the largest log-prob.
        indices = (
            torch.argmax(logprob, dim=0)
            .reshape(1, batch_size, 1)
            .expand(1, batch_size, feature_size)
        )
        return torch.gather(sample, 0, indices).squeeze(0)

    def entropy(self):
        """Monte Carlo entropy estimate: -E[log p(x)] over fresh draws."""
        dist = self._dist.expand((self._samples, *self._dist.batch_shape))
        sample = dist.rsample()
        logprob = dist.log_prob(sample)
        return -torch.mean(logprob, 0)

    def sample(self):
        """Draw a single (non-reparameterized) sample from the wrapped dist."""
        return self._dist.sample()
class ActionDecoder(nn.Module):
    """Dreamer actor head producing squashed continuous actions from latent
    features.

    Outputs a transformed Gaussian policy with optional deterministic mode and
    utility for additive exploration noise.
    """

    def __init__(
        self,
        action_size,
        stoch_size,
        deter_size,
        n_layers,
        units,
        activation,
        min_std=1e-4,
        init_std=5,
        mean_scale=5,
    ):
        super().__init__()
        self.action_size = action_size
        self.stoch_size = stoch_size
        self.deter_size = deter_size
        self.units = units
        self.act_fn = _str_to_activation[activation]
        self.n_layers = n_layers
        self._min_std = min_std
        self._init_std = init_std
        self._mean_scale = mean_scale

        # Hidden widths: latent input once, then `units` repeated n_layers times.
        widths = [self.stoch_size + self.deter_size] + [self.units] * self.n_layers
        layers = []
        for w_in, w_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(w_in, w_out))
            layers.append(self.act_fn)
        # Final layer emits concatenated (mean, std) parameters per action dim.
        layers.append(nn.Linear(self.units, 2 * self.action_size))
        self.action_model = nn.Sequential(*layers)

    def forward(self, features, deter=False):
        """Return an action tensor: approximate mode if `deter`, else a sample."""
        raw = self.action_model(features)
        mean, std = torch.chunk(raw, 2, dim=-1)

        # Offset such that softplus(0 + raw_init_std) == init_std.
        raw_init_std = np.log(np.exp(self._init_std) - 1)
        # Soft-clip the mean into [-mean_scale, mean_scale] without killing gradients.
        action_mean = self._mean_scale * torch.tanh(mean / self._mean_scale)
        action_std = F.softplus(std + raw_init_std) + self._min_std

        policy = Independent(
            TransformedDistribution(Normal(action_mean, action_std), TanhBijector()),
            1,
        )
        policy = SampleDist(policy)
        return policy.mode() if deter else policy.rsample()

    def add_exploration(self, action, action_noise=0.3):
        """Perturb `action` with Gaussian noise and clamp back into [-1, 1]."""
        noisy = Normal(action, action_noise).rsample()
        return torch.clamp(noisy, -1, 1)