Source code for world_models.vision.dreamer_encoder

import torch
import torch.nn as nn

_str_to_activation = {
    "relu": nn.ReLU(),
    "elu": nn.ELU(),
    "tanh": nn.Tanh(),
    "leaky_relu": nn.LeakyReLU(),
    "sigmoid": nn.Sigmoid(),
    "selu": nn.SELU(),
    "softplus": nn.Softplus(),
    "identity": nn.Identity(),
}


[docs] class ConvEncoder(nn.Module): """Convolutional observation encoder used by Dreamer world models. Encodes image observations into compact embeddings consumed by the RSSM posterior update network. """ def __init__(self, input_shape, embed_size, activation, depth=32): super().__init__() self.input_shape = input_shape self.act_fn = _str_to_activation[activation] self.depth = depth self.kernels = [4, 4, 4, 4] self.embed_size = embed_size layers = [] for i, kernel_size in enumerate(self.kernels): in_ch = input_shape[0] if i == 0 else self.depth * (2 ** (i - 1)) out_ch = self.depth * (2**i) layers.append(nn.Conv2d(in_ch, out_ch, kernel_size, stride=2)) layers.append(self.act_fn) self.conv_block = nn.Sequential(*layers) self.fc = ( nn.Identity() if self.embed_size == 1024 else nn.Linear(1024, self.embed_size) )
[docs] def forward(self, inputs): reshaped = inputs.reshape(-1, *self.input_shape) embed = self.conv_block(reshaped) embed = torch.reshape(embed, (*inputs.shape[:-3], -1)) embed = self.fc(embed) return embed