Source code for world_models.vision.planet_encoder

import torch.nn as nn
import torch.nn.functional as F


class CNNEncoder(nn.Module):
    """Convolutional encoder that maps image observations to flat embeddings.

    Four stride-2 ``Conv2d`` layers (kernel size 4, no padding) expand the
    channels 3 -> 32 -> 64 -> 128 -> 256. The result is flattened into rows
    of 1024 features; for 64x64 inputs the final feature map is 256x2x2,
    i.e. exactly 1024 values per sample. The rows are then passed through
    ``fc``.

    Args:
        embedding_size: Width of the returned embedding. When it equals
            1024 the flattened features are returned unchanged
            (``nn.Identity``); otherwise a ``nn.Linear(1024, embedding_size)``
            projection is applied.
        activation_function: Name of a function in ``torch.nn.functional``
            (for example ``"relu"`` or ``"elu"``) applied after every
            convolution. An unknown name raises ``AttributeError``.
    """

    def __init__(self, embedding_size: int, activation_function: str = "relu") -> None:
        super().__init__()
        # Resolve the activation by name from torch.nn.functional.
        self.act_fn = getattr(F, activation_function)
        self.embedding_size = embedding_size
        self.conv1 = nn.Conv2d(3, 32, 4, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 4, stride=2)
        self.conv4 = nn.Conv2d(128, 256, 4, stride=2)
        # The flattened conv features are 1024 wide, so a projection is only
        # needed when a different embedding width is requested.
        if embedding_size == 1024:
            self.fc = nn.Identity()
        else:
            self.fc = nn.Linear(1024, embedding_size)

    def forward(self, observation):
        """Encode a batch of image observations.

        Args:
            observation: Image tensor with 3 channels. The flattening step
                groups the conv output into rows of 1024 values, which
                corresponds to one row per sample for 64x64 inputs.

        Returns:
            Tensor of shape ``(N, embedding_size)``, where ``N`` is the
            number of 1024-wide rows in the conv output.
        """
        hidden = self.act_fn(self.conv1(observation))
        hidden = self.act_fn(self.conv2(hidden))
        hidden = self.act_fn(self.conv3(hidden))
        hidden = self.act_fn(self.conv4(hidden))
        # reshape (not view): view raises RuntimeError when the conv output
        # is not C-contiguous, e.g. under torch.channels_last. reshape
        # returns the same view when possible and copies only when required.
        hidden = hidden.reshape(-1, 1024)
        hidden = self.fc(hidden)
        return hidden