Source code for world_models.models.controller

"""Linear Controller for World Models.

This module provides a simple linear controller that maps latent states
and recurrent hidden states to actions. The controller is trained using
CMA-ES (Covariance Matrix Adaptation Evolution Strategy).

Reference:
    Ha & Schmidhuber (2018). Recurrent World Models Facilitate Policy Evolution.
    https://arxiv.org/abs/1805.11111
"""

import torch
import torch.nn as nn


[docs] class Controller(nn.Module): """Linear controller that maps latent + hidden state to actions. This is a simple linear controller that takes the latent state and recurrent hidden state as input and outputs actions. It is trained separately from the world model using black-box optimization (CMA-ES). Attributes: latent_size: Dimensionality of latent state from VAE. hidden_size: Dimensionality of RSSM hidden state. action_size: Dimensionality of action space. Example: >>> controller = Controller(latent_size=32, hidden_size=200, action_size=3) >>> state = torch.cat([latent, hidden], dim=-1) >>> action = controller(state) """ def __init__(self, latent_size: int, hidden_size: int, action_size: int): """Initialize the linear controller. Args: latent_size: Dimensionality of latent state from VAE. hidden_size: Dimensionality of RSSM hidden state. action_size: Dimensionality of action space. """ super(Controller, self).__init__() self.latent_size = latent_size self.hidden_size = hidden_size self.action_size = action_size self.fc = nn.Linear(latent_size + hidden_size, action_size)
[docs] def forward(self, state: torch.Tensor) -> torch.Tensor: """Compute actions from latent and hidden states. Args: state: Concatenated [latent, hidden] state tensor. Returns: Action tensor of shape (batch, action_size). """ return self.fc(state)