# Source code for world_models.vision.planet_encoder
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNNEncoder(nn.Module):
    """A Convolutional Neural Network (CNN) encoder for processing image inputs.

    Four ``nn.Conv2d`` layers (kernel 4, stride 2, no padding) turn a
    3-channel image into a 256-channel feature map. For a 64x64 input the
    spatial size shrinks 64 -> 31 -> 14 -> 6 -> 2, giving 256 x 2 x 2 =
    ``FEATURE_SIZE`` (1024) features per observation. These are flattened and,
    unless ``embedding_size`` already equals ``FEATURE_SIZE``, projected to
    ``embedding_size`` by a linear layer.

    Args:
        embedding_size: Size of the embedding returned per observation.
        activation_function: Name of a function in ``torch.nn.functional``
            (e.g. ``"relu"``) applied after every conv layer.
    """

    # Flattened conv features per observation: 256 channels x 2 x 2 spatial.
    FEATURE_SIZE: int = 1024

    def __init__(self, embedding_size: int, activation_function: str = "relu") -> None:
        super().__init__()
        # Looked up by name; an unknown name raises AttributeError here.
        self.act_fn = getattr(F, activation_function)
        self.embedding_size = embedding_size
        self.conv1 = nn.Conv2d(3, 32, 4, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 4, stride=2)
        self.conv4 = nn.Conv2d(128, 256, 4, stride=2)
        # No projection is needed when the flattened features already have
        # the requested size.
        self.fc: nn.Linear | nn.Identity = (
            nn.Identity()
            if embedding_size == self.FEATURE_SIZE
            else nn.Linear(self.FEATURE_SIZE, embedding_size)
        )

    def forward(self, observation: torch.Tensor) -> torch.Tensor:
        """Encode image observations into embeddings.

        Args:
            observation: 3-channel image tensor accepted by ``nn.Conv2d``,
                normally ``(batch, 3, H, W)``. The spatial size must reduce
                to ``FEATURE_SIZE`` flattened features per observation
                (a 2x2 map, e.g. for 64x64 input).

        Returns:
            Tensor of shape ``(batch, embedding_size)``.

        Raises:
            ValueError: If the final conv feature map does not contain
                exactly ``FEATURE_SIZE`` values per observation.
        """
        hidden = observation
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            hidden = self.act_fn(conv(hidden))

        # Without this check, view(-1, 1024) would silently fold surplus
        # spatial features into the batch dimension (a 128x128 image yields
        # 256x6x6 = 9 * 1024 values -> 9 rows) or fail with an opaque error.
        per_observation = hidden.size(-3) * hidden.size(-2) * hidden.size(-1)
        if per_observation != self.FEATURE_SIZE:
            raise ValueError(
                f"CNNEncoder expects {self.FEATURE_SIZE} conv features per "
                f"observation (256x2x2, e.g. a 64x64 input), got feature map "
                f"{tuple(hidden.shape[-3:])} from observation of shape "
                f"{tuple(observation.shape)}"
            )

        hidden = hidden.view(-1, self.FEATURE_SIZE)
        return self.fc(hidden)