Source code for world_models.vision.planet_decoder

import torch
import torch.nn as nn
import torch.nn.functional as F


class CNNDecoder(nn.Module):
    """A Convolutional Neural Network (CNN) decoder for reconstructing image outputs.

    The concatenated ``state`` and ``latent`` vectors are projected by a linear
    layer to ``embedding_size`` features, treated as a 1x1 feature map, and
    upsampled by four transposed convolutions. With the kernel sizes and
    strides used here (no padding) the spatial size grows 1 -> 5 -> 13 -> 30
    -> 64, so each decoded observation has shape ``(3, 64, 64)``.

    Args:
        state_size: Number of features in the ``state`` input.
        latent_size: Number of features in the ``latent`` input.
        embedding_size: Width of the linear projection, which is also the
            channel count fed to the first transposed convolution.
        activation_function: Name of a function in ``torch.nn.functional``
            applied after each hidden transposed convolution. An unknown name
            raises ``AttributeError`` at construction time.
    """

    def __init__(
        self, state_size, latent_size, embedding_size, activation_function="relu"
    ):
        super().__init__()
        self.act_fn = getattr(F, activation_function)
        self.embedding_size = embedding_size
        self.fc1 = nn.Linear(latent_size + state_size, embedding_size)
        self.conv1 = nn.ConvTranspose2d(embedding_size, 128, 5, stride=2)
        self.conv2 = nn.ConvTranspose2d(128, 64, 5, stride=2)
        self.conv3 = nn.ConvTranspose2d(64, 32, 6, stride=2)
        self.conv4 = nn.ConvTranspose2d(32, 3, 6, stride=2)

    def forward(self, latent, state):
        """Decode ``latent`` and ``state`` into an observation tensor.

        Args:
            latent: Tensor of shape ``(*leading, latent_size)``.
            state: Tensor of shape ``(*leading, state_size)`` with the same
                leading dimensions as ``latent``.

        Returns:
            Tensor of shape ``(*leading, 3, 64, 64)``. No activation is
            applied after the final transposed convolution.
        """
        # Concatenate on the feature axis so any number of leading dimensions
        # (e.g. batch only, or time x batch) is accepted; for 2-D inputs this
        # is the same as concatenating on dim=1.
        features = torch.cat([state, latent], dim=-1)
        leading_shape = features.shape[:-1]

        hidden = self.fc1(features)
        # Collapse all leading dimensions into one batch axis and present each
        # embedding as a 1x1 feature map for the transposed convolutions.
        hidden = hidden.reshape(-1, self.embedding_size, 1, 1)
        hidden = self.act_fn(self.conv1(hidden))
        hidden = self.act_fn(self.conv2(hidden))
        hidden = self.act_fn(self.conv3(hidden))
        observation = self.conv4(hidden)

        # Restore the caller's leading dimensions in front of (C, H, W).
        return observation.reshape(*leading_shape, *observation.shape[1:])