import torch
import torch.nn as nn
_str_to_activation = {
"relu": nn.ReLU(),
"elu": nn.ELU(),
"tanh": nn.Tanh(),
"leaky_relu": nn.LeakyReLU(),
"sigmoid": nn.Sigmoid(),
"selu": nn.SELU(),
"softplus": nn.Softplus(),
"identity": nn.Identity(),
}
[docs]
class ConvEncoder(nn.Module):
"""Convolutional observation encoder used by Dreamer world models.
This encoder transforms raw image observations (typically RGB frames from
environments like Atari or DeepMind Control) into compact latent embeddings
that can be processed by the RSSM (Recurrent State-Space Model).
Architecture:
Input: (B, C, H, W) raw images, values in [-0.5, 0.5]
Process: 4 convolutional layers with stride 2, halving spatial dimensions
Output: (B, embed_size) compact representation
The encoder uses a depth doubling pattern: 32 -> 64 -> 128 -> 256 channels.
After convolutions, a fully connected layer projects from 1024 features to
the desired embedding size.
Usage with Dreamer:
encoder = ConvEncoder(
input_shape=(3, 64, 64), # RGB 64x64 images
embed_size=256, # RSSM observation embedding size
activation='relu' # Activation function
)
obs_embedding = encoder(observation) # (B, 256)
Args:
input_shape: Tuple (C, H, W) for input images, typically (3, 64, 64)
embed_size: Output embedding dimension, typically 256 or 1024
activation: Activation function name ('relu', 'elu', 'tanh', etc.)
depth: Base channel depth for first layer (default 32)
"""
def __init__(self, input_shape, embed_size, activation, depth=32):
super().__init__()
self.input_shape = input_shape
self.act_fn = _str_to_activation[activation]
self.depth = depth
self.kernels = [4, 4, 4, 4]
self.embed_size = embed_size
layers = []
for i, kernel_size in enumerate(self.kernels):
in_ch = input_shape[0] if i == 0 else self.depth * (2 ** (i - 1))
out_ch = self.depth * (2**i)
layers.append(nn.Conv2d(in_ch, out_ch, kernel_size, stride=2))
layers.append(self.act_fn)
self.conv_block = nn.Sequential(*layers)
self.fc = (
nn.Identity()
if self.embed_size == 1024
else nn.Linear(1024, self.embed_size)
)
[docs]
def forward(self, inputs):
reshaped = inputs.reshape(-1, *self.input_shape)
embed = self.conv_block(reshaped)
embed = torch.reshape(embed, (*inputs.shape[:-3], -1))
embed = self.fc(embed)
return embed