"""Source code for world_models.vision.planet_encoder: a CNN encoder for image observations."""
import torch.nn as nn
import torch.nn.functional as F
[docs]
class CNNEncoder(nn.Module):
    """A Convolutional Neural Network (CNN) encoder for processing image inputs.

    Four 4x4, stride-2 convolutions (3 -> 32 -> 64 -> 128 -> 256 channels), each
    followed by ``activation_function``, reduce an image to a 256 x 2 x 2 feature
    map. That map is flattened to 1024 values per sample and, unless
    ``embedding_size`` is already 1024, projected with a linear layer.

    Args:
        embedding_size: Length of the output embedding. When it equals 1024 the
            flattened conv features are returned unchanged (``nn.Identity``).
        activation_function: Name of a function in ``torch.nn.functional``
            (e.g. ``"relu"``, ``"elu"``) applied after every convolution.
    """

    # Per-sample length of the flattened conv4 output: 256 channels * 2 * 2.
    _FLAT_FEATURES = 1024

    def __init__(self, embedding_size, activation_function="relu"):
        super().__init__()
        # Resolved by name, so any torch.nn.functional callable can be used.
        self.act_fn = getattr(F, activation_function)
        self.embedding_size = embedding_size
        self.conv1 = nn.Conv2d(3, 32, 4, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 4, stride=2)
        self.conv4 = nn.Conv2d(128, 256, 4, stride=2)
        # No projection is needed when the conv features already have the
        # requested size; nn.Identity is parameter-free, so no "fc.*" weights.
        if embedding_size == self._FLAT_FEATURES:
            self.fc = nn.Identity()
        else:
            self.fc = nn.Linear(self._FLAT_FEATURES, embedding_size)

    def forward(self, observation):
        """Encode images into embeddings.

        Args:
            observation: Image tensor of shape ``(B, 3, H, W)`` or unbatched
                ``(3, H, W)``. ``H`` and ``W`` must shrink to 2 after the four
                convolutions (e.g. 64x64 images).

        Returns:
            Tensor of shape ``(B, embedding_size)``; ``B == 1`` for unbatched input.

        Raises:
            ValueError: If the spatial size does not produce exactly 1024 conv
                features per sample.
        """
        hidden = self.act_fn(self.conv1(observation))
        hidden = self.act_fn(self.conv2(hidden))
        hidden = self.act_fn(self.conv3(hidden))
        hidden = self.act_fn(self.conv4(hidden))
        if hidden.dim() == 3:
            # Unbatched input: Conv2d returned (C, H, W); add a batch axis so
            # the result is (1, embedding_size).
            hidden = hidden.unsqueeze(0)
        # Flatten per sample so a wrong image size cannot be silently folded
        # into the batch axis; flatten also accepts non-contiguous tensors.
        hidden = hidden.flatten(start_dim=1)
        if hidden.size(1) != self._FLAT_FEATURES:
            raise ValueError(
                f"CNNEncoder expected {self._FLAT_FEATURES} conv features per "
                f"sample (e.g. 3x64x64 images), got {hidden.size(1)} for "
                f"observation shape {tuple(observation.shape)}"
            )
        return self.fc(hidden)