Source code for world_models.models.modular_rssm
"""Modular RSSM with swappable encoder/decoder/backbone components.
This module provides a flexible architecture for world model research,
allowing researchers to easily swap different encoder, decoder, and backbone
implementations for ablations and experimentation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as distributions
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List, Tuple, Union
# Lookup table from an activation name to a ready-made module instance.
# Each instance is built exactly once, when this module is imported, so every
# lookup of the same name hands back the same module object.
_str_to_activation = {
    name: factory()
    for name, factory in (
        ("relu", nn.ReLU),
        ("elu", nn.ELU),
        ("tanh", nn.Tanh),
        ("leaky_relu", nn.LeakyReLU),
        ("sigmoid", nn.Sigmoid),
        ("selu", nn.SELU),
        ("softplus", nn.Softplus),
        ("identity", nn.Identity),
        ("gelu", nn.GELU),
    )
}
[docs]
class EncoderBase(nn.Module, ABC):
    """Abstract base class for observation encoders.

    Concrete encoders implement :meth:`forward` and are expected to store the
    width of the embedding they produce in ``embed_size``.
    """

    # Annotate expected attribute so type checkers know subclasses expose it.
    embed_size: int

    @abstractmethod
    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Encode observations to embeddings."""
        pass

    def get_embed_size(self) -> int:
        """Return the embedding size.

        The default implementation reads the ``embed_size`` attribute, so
        subclasses that set it do not need to override this method.

        Raises:
            NotImplementedError: If the instance has no ``embed_size``
                attribute and the subclass did not override this method.
        """
        # The class-level annotation alone does not create an attribute, so a
        # subclass that never assigns embed_size still lands in the error path.
        size = getattr(self, "embed_size", None)
        if size is None:
            raise NotImplementedError(
                f"{type(self).__name__} must set 'embed_size' or override "
                "get_embed_size()"
            )
        return int(size)
[docs]
class DecoderBase(nn.Module, ABC):
    """Interface implemented by every observation decoder."""

    @abstractmethod
    def forward(self, features: torch.Tensor) -> Any:
        """Turn latent features into an observation distribution.

        Subclasses must override this method; the base body does nothing.
        """
        ...
[docs]
class BackboneBase(nn.Module, ABC):
    """Interface implemented by every recurrent dynamics backbone."""

    # Size attributes are declared here so that type checkers know concrete
    # backbone instances expose them.
    stoch_size: int
    deter_size: int

    @abstractmethod
    def init_state(
        self, batch_size: int, device: torch.device
    ) -> Dict[str, torch.Tensor]:
        """Build the initial hidden state for ``batch_size`` sequences."""
        ...

    @abstractmethod
    def forward(
        self,
        state: Dict[str, torch.Tensor],
        action: torch.Tensor,
        obs_embed: Optional[torch.Tensor] = None,
        nonterm: float = 1.0,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Advance the dynamics by one step.

        Returns:
            A ``(prior, posterior)`` pair of state dictionaries.
        """
        ...
[docs]
class ConvEncoder(EncoderBase):
    """Convolutional encoder from Dreamer (image observations).

    Four stride-2 convolutions with 4x4 kernels are applied; channel counts
    are ``depth``, ``2*depth``, ``4*depth`` and ``8*depth``. The flattened
    convolution output is projected to ``embed_size`` by a linear layer,
    unless it already has that width, in which case the projection is an
    identity.

    Args:
        input_shape: ``(channels, height, width)`` of a single observation.
        embed_size: Width of the returned embedding.
        activation: Key into ``_str_to_activation``.
        depth: Channel multiplier of the convolution stack.

    Raises:
        ValueError: If ``input_shape`` is spatially too small for the
            convolution stack.
    """

    def __init__(
        self,
        input_shape: Tuple[int, int, int],
        embed_size: int,
        activation: str = "elu",
        depth: int = 32,
    ):
        super().__init__()
        self.input_shape = input_shape
        self.embed_size = embed_size
        self.act_fn = _str_to_activation[activation]
        self.depth = depth
        self.kernels = [4, 4, 4, 4]
        stride = 2

        # Use a generic module list since we append diverse nn.Module objects
        layers: List[nn.Module] = []
        in_ch, height, width = input_shape
        for i, kernel_size in enumerate(self.kernels):
            if height < kernel_size or width < kernel_size:
                raise ValueError(
                    f"input_shape {tuple(input_shape)} is too small for "
                    f"{len(self.kernels)} stride-{stride} convolutions"
                )
            out_ch = self.depth * (2**i)
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size, stride=stride))
            layers.append(self.act_fn)
            in_ch = out_ch
            # Output size of an unpadded convolution: floor((n - k) / s) + 1.
            height = (height - kernel_size) // stride + 1
            width = (width - kernel_size) // stride + 1
        self.conv_block = nn.Sequential(*layers)

        # Derive the flattened width from the actual geometry instead of
        # assuming 1024, which only holds for depth=32 with 64x64 inputs.
        conv_out_size = in_ch * height * width
        self.fc = (
            nn.Linear(conv_out_size, self.embed_size)
            if self.embed_size != conv_out_size
            else nn.Identity()
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Encode ``obs`` whose last three dims equal ``input_shape``.

        All leading dimensions are flattened for the convolutions and
        restored on the output, which ends in the embedding dimension.
        """
        reshaped = obs.reshape(-1, *self.input_shape)
        embed = self.conv_block(reshaped)
        embed = torch.reshape(embed, (*obs.shape[:-3], -1))
        embed = self.fc(embed)
        return embed
[docs]
class MLPEncoder(EncoderBase):
    """MLP encoder for state-based observations.

    Args:
        input_dim: Size of the last dimension of the observation tensor.
        embed_size: Width of the returned embedding.
        hidden_sizes: Widths of the hidden layers; ``None`` selects
            ``[256, 256]``.
        activation: Key into ``_str_to_activation``; applied after every
            hidden layer but not after the output layer.
    """

    def __init__(
        self,
        input_dim: int,
        embed_size: int,
        hidden_sizes: Optional[List[int]] = None,
        activation: str = "elu",
    ):
        super().__init__()
        # None stands in for the former list default so that no mutable
        # object is shared between calls.
        if hidden_sizes is None:
            hidden_sizes = [256, 256]
        self.input_dim = input_dim
        self.embed_size = embed_size
        self.act_fn = _str_to_activation[activation]
        # Use a generic module list since we append diverse nn.Module objects
        layers: List[nn.Module] = []
        in_dim = input_dim
        for hidden_size in hidden_sizes:
            layers.extend([nn.Linear(in_dim, hidden_size), self.act_fn])
            in_dim = hidden_size
        layers.append(nn.Linear(in_dim, embed_size))
        self.model = nn.Sequential(*layers)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Encode ``obs`` by passing it through the MLP.

        Without this override the class inherits the abstract ``forward`` of
        ``EncoderBase`` and cannot be instantiated.
        """
        return self.model(obs)
[docs]
class ViTEncoder(EncoderBase):
    """Vision Transformer encoder for image observations.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, a learned class token is prepended, a
    learned positional embedding is added, and the sequence is processed by
    ``depth`` transformer blocks. The normalised class token is returned as
    the embedding.

    Args:
        input_shape: ``(channels, height, width)`` of a single observation.
        embed_size: Token width, which is also the returned embedding width.
        patch_size: Side length of a square patch.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: Hidden width of each block's MLP relative to ``embed_size``.
        activation: Activation name forwarded to every block.
    """

    def __init__(
        self,
        input_shape: Tuple[int, int, int],
        embed_size: int,
        patch_size: int = 8,
        depth: int = 6,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        activation: str = "gelu",
    ):
        super().__init__()
        self.input_shape = input_shape
        self.embed_size = embed_size
        self.patch_size = patch_size
        c, h, w = input_shape
        self.num_patches = (h // patch_size) * (w // patch_size)
        self.patch_embed = nn.Conv2d(
            c, embed_size, kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_size))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_size))
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(embed_size, num_heads, mlp_ratio, activation)
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(embed_size)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Encode images given as ``(..., C, H, W)``.

        A 4-D input is processed directly. Inputs with additional leading
        dimensions have those dimensions flattened into one batch axis and
        restored on the output, in line with ``ConvEncoder``.
        """
        is_plain_batch = obs.dim() == 4
        lead_shape = obs.shape[:-3]
        # Flatten on the input's own trailing dims (not self.input_shape) so a
        # wrongly sized image still fails at the positional-embedding add
        # instead of being silently reinterpreted.
        x = obs if is_plain_batch else obs.reshape(-1, *obs.shape[-3:])
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        for block in self.blocks:
            x = block(x)
        embed = self.norm(x[:, 0])
        if is_plain_batch:
            return embed
        return embed.reshape(*lead_shape, embed.shape[-1])
[docs]
class TransformerBlock(nn.Module):
    """Transformer block for ViT encoder.

    Pre-norm layout: LayerNorm, self-attention and a residual add, followed
    by LayerNorm, a two-layer MLP and a second residual add.

    Args:
        embed_size: Token width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP relative to ``embed_size``.
        activation: Key into ``_str_to_activation`` used inside the MLP.
    """

    def __init__(
        self, embed_size: int, num_heads: int, mlp_ratio: float, activation: str
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_size)
        self.attn = nn.MultiheadAttention(embed_size, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_size)
        self.mlp = nn.Sequential(
            nn.Linear(embed_size, int(embed_size * mlp_ratio)),
            _str_to_activation[activation],
            nn.Linear(int(embed_size * mlp_ratio), embed_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` and return a tensor of the same shape."""
        # Normalise once and reuse the result as query, key and value rather
        # than running the same LayerNorm three times.
        normed = self.norm1(x)
        x = x + self.attn(normed, normed, normed)[0]
        x = x + self.mlp(self.norm2(x))
        return x
[docs]
class ConvDecoder(DecoderBase):
    """Transposed-convolution decoder producing image distributions.

    A linear layer lifts the latent features to ``32 * depth`` channels on a
    1x1 grid; four stride-2 transposed convolutions then upsample it. The
    result is used as the mean of a unit-variance Normal.
    """

    def __init__(
        self,
        stoch_size: int,
        deter_size: int,
        output_shape: Tuple[int, int, int],
        activation: str = "elu",
        depth: int = 32,
    ):
        super().__init__()
        self.output_shape = output_shape
        self.depth = depth
        self.kernels = [5, 5, 6, 6]
        self.act_fn = _str_to_activation[activation]
        self.dense = nn.Linear(stoch_size + deter_size, 32 * depth)

        # Channel widths at every layer boundary: the dense output first,
        # then successively halved multiples of depth, then image channels.
        n_layers = len(self.kernels)
        widths = [32 * depth]
        widths.extend(depth * (2 ** (n_layers - 2 - i)) for i in range(n_layers - 1))
        widths.append(output_shape[0])

        stages: list[nn.Module] = []
        final_idx = n_layers - 1
        for idx, (c_in, c_out, k) in enumerate(
            zip(widths[:-1], widths[1:], self.kernels)
        ):
            stages.append(nn.ConvTranspose2d(c_in, c_out, k, stride=2))
            # The last layer emits the mean directly, without an activation.
            if idx < final_idx:
                stages.append(self.act_fn)
        self.convtranspose = nn.Sequential(*stages)

    def forward(self, features: torch.Tensor) -> Any:
        """Decode ``features`` into an ``Independent(Normal(mean, 1))``.

        Leading dimensions of ``features`` are kept as the batch shape; the
        dimensions of ``output_shape`` are treated as event dimensions.
        """
        batch_shape = features.shape[:-1]
        hidden = self.dense(features)
        hidden = torch.reshape(hidden, [-1, 32 * self.depth, 1, 1])
        image = self.convtranspose(hidden)
        mean = torch.reshape(image, (*batch_shape, *self.output_shape))
        base = distributions.Normal(mean, 1)
        return distributions.independent.Independent(base, len(self.output_shape))
[docs]
class MLPDecoder(DecoderBase):
    """MLP decoder for state-based observations.

    Args:
        stoch_size: Width of the stochastic part of the latent features.
        deter_size: Width of the deterministic part of the latent features.
        output_dim: Width of the decoded output.
        hidden_sizes: Widths of the hidden layers; ``None`` selects
            ``[256, 256]``.
        activation: Key into ``_str_to_activation``; applied after every
            hidden layer but not after the output layer.
        dist: ``"normal"`` wraps the output in a unit-variance Normal; any
            other value makes :meth:`forward` return the raw tensor.
    """

    def __init__(
        self,
        stoch_size: int,
        deter_size: int,
        output_dim: int,
        hidden_sizes: Optional[List[int]] = None,
        activation: str = "elu",
        dist: str = "normal",
    ):
        super().__init__()
        # None stands in for the former list default so that no mutable
        # object is shared between calls.
        if hidden_sizes is None:
            hidden_sizes = [256, 256]
        self.input_size = stoch_size + deter_size
        self.output_dim = output_dim
        self.act_fn = _str_to_activation[activation]
        self.dist = dist
        layers: List[nn.Module] = []
        in_dim = self.input_size
        for hidden_size in hidden_sizes:
            layers.extend([nn.Linear(in_dim, hidden_size), self.act_fn])
            in_dim = hidden_size
        layers.append(nn.Linear(in_dim, output_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, features: torch.Tensor) -> Any:
        """Decode ``features``.

        Returns:
            ``Independent(Normal(out, 1), 1)`` when ``dist == "normal"``,
            otherwise the MLP output tensor itself.
        """
        out = self.model(features)
        if self.dist == "normal":
            return distributions.independent.Independent(
                distributions.Normal(out, 1), 1
            )
        return out
[docs]
class GRUBackbone(BackboneBase):
    """Standard RSSM dynamics built around a single ``nn.GRUCell``.

    The deterministic state is carried by the GRU; the stochastic state is a
    diagonal Gaussian whose parameters come from small MLP heads (one for the
    prior, one for the posterior).
    """

    def __init__(
        self,
        action_size: int,
        stoch_size: int,
        deter_size: int,
        hidden_size: int,
        embed_size: int,
        activation: str = "elu",
    ):
        super().__init__()
        self.action_size = action_size
        self.stoch_size = stoch_size
        self.deter_size = deter_size
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.act_fn = _str_to_activation[activation]

        self.rnn = nn.GRUCell(deter_size, deter_size)
        self.fc_state_action = nn.Linear(stoch_size + action_size, deter_size)
        # Prior head: deterministic state -> Gaussian parameters.
        self.fc_embed_prior = nn.Linear(deter_size, hidden_size)
        self.fc_state_prior = nn.Linear(hidden_size, 2 * stoch_size)
        # Posterior head: observation embedding + deterministic state.
        self.fc_embed_posterior = nn.Linear(embed_size + deter_size, hidden_size)
        self.fc_state_posterior = nn.Linear(hidden_size, 2 * stoch_size)

    @property
    def embedding_size(self) -> int:
        """Width of the observation embedding this backbone expects."""
        return self.embed_size

    def init_state(
        self, batch_size: int, device: torch.device
    ) -> Dict[str, torch.Tensor]:
        """Return an all-zero state for ``batch_size`` sequences on ``device``."""
        widths = {
            "mean": self.stoch_size,
            "std": self.stoch_size,
            "stoch": self.stoch_size,
            "deter": self.deter_size,
        }
        return {
            key: torch.zeros(batch_size, width, device=device)
            for key, width in widths.items()
        }

    @staticmethod
    def _gaussian_state(
        params: torch.Tensor, deter: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Split ``params`` into mean/std, draw a sample and pack the state."""
        mean, raw_std = torch.chunk(params, 2, dim=-1)
        # softplus keeps the scale positive; 0.1 is added as a lower bound.
        std = F.softplus(raw_std) + 0.1
        stoch = mean + torch.randn_like(mean) * std
        return {"mean": mean, "std": std, "stoch": stoch, "deter": deter}

    def forward(
        self,
        state: Dict[str, torch.Tensor],
        action: torch.Tensor,
        obs_embed: Optional[torch.Tensor] = None,
        nonterm: float = 1.0,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Run one step and return ``(prior, posterior)``.

        When ``obs_embed`` is ``None`` the posterior is the prior itself
        (the same dictionary object).
        """
        prior = self._imagine_step(state, action, nonterm)
        if obs_embed is None:
            return prior, prior

        joint = torch.cat([obs_embed, prior["deter"]], dim=-1)
        hidden = self.act_fn(self.fc_embed_posterior(joint))
        posterior = self._gaussian_state(
            self.fc_state_posterior(hidden), prior["deter"]
        )
        return prior, posterior

    def _imagine_step(
        self, state: Dict[str, torch.Tensor], action: torch.Tensor, nonterm: float = 1.0
    ) -> Dict[str, torch.Tensor]:
        """Predict the next state from ``state`` and ``action`` alone.

        Both the stochastic and the deterministic part of the incoming state
        are multiplied by ``nonterm`` before use.
        """
        masked_stoch = state["stoch"] * nonterm
        rnn_input = self.act_fn(
            self.fc_state_action(torch.cat([masked_stoch, action], dim=-1))
        )
        deter = self.rnn(rnn_input, state["deter"] * nonterm)
        hidden = self.act_fn(self.fc_embed_prior(deter))
        return self._gaussian_state(self.fc_state_prior(hidden), deter)
[docs]
class LSTMBackbone(BackboneBase):
    """RSSM dynamics built around a single ``nn.LSTMCell``.

    Same layout as the GRU variant, except that the state dictionary also
    carries the LSTM cell state under ``"cell"``.
    """

    def __init__(
        self,
        action_size: int,
        stoch_size: int,
        deter_size: int,
        hidden_size: int,
        embed_size: int,
        activation: str = "elu",
    ):
        super().__init__()
        self.action_size = action_size
        self.stoch_size = stoch_size
        self.deter_size = deter_size
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.act_fn = _str_to_activation[activation]

        self.rnn = nn.LSTMCell(deter_size, deter_size)
        self.fc_state_action = nn.Linear(stoch_size + action_size, deter_size)
        # Prior head: deterministic state -> Gaussian parameters.
        self.fc_embed_prior = nn.Linear(deter_size, hidden_size)
        self.fc_state_prior = nn.Linear(hidden_size, 2 * stoch_size)
        # Posterior head: observation embedding + deterministic state.
        self.fc_embed_posterior = nn.Linear(embed_size + deter_size, hidden_size)
        self.fc_state_posterior = nn.Linear(hidden_size, 2 * stoch_size)

    @property
    def embedding_size(self) -> int:
        """Width of the observation embedding this backbone expects."""
        return self.embed_size

    def init_state(
        self, batch_size: int, device: torch.device
    ) -> Dict[str, torch.Tensor]:
        """Return an all-zero state (including ``"cell"``) on ``device``."""
        widths = {
            "mean": self.stoch_size,
            "std": self.stoch_size,
            "stoch": self.stoch_size,
            "deter": self.deter_size,
            "cell": self.deter_size,
        }
        return {
            key: torch.zeros(batch_size, width, device=device)
            for key, width in widths.items()
        }

    @staticmethod
    def _sample_gaussian(
        params: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split ``params`` into mean/std and draw a reparameterised sample."""
        mean, raw_std = torch.chunk(params, 2, dim=-1)
        # softplus keeps the scale positive; 0.1 is added as a lower bound.
        std = F.softplus(raw_std) + 0.1
        return mean, std, mean + torch.randn_like(mean) * std

    def forward(
        self,
        state: Dict[str, torch.Tensor],
        action: torch.Tensor,
        obs_embed: Optional[torch.Tensor] = None,
        nonterm: float = 1.0,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Run one step and return ``(prior, posterior)``.

        When ``obs_embed`` is ``None`` the posterior is the prior itself
        (the same dictionary object).
        """
        prior = self._imagine_step(state, action, nonterm)
        if obs_embed is None:
            return prior, prior

        joint = torch.cat([obs_embed, prior["deter"]], dim=-1)
        hidden = self.act_fn(self.fc_embed_posterior(joint))
        mean, std, stoch = self._sample_gaussian(self.fc_state_posterior(hidden))

        # Cell state for the posterior: take the prior's, then the incoming
        # state's, and only fall back to zeros if neither provides one.
        if "cell" in prior:
            cell = prior["cell"]
        elif "cell" in state:
            cell = state["cell"]
        else:
            cell = torch.zeros_like(prior["deter"])

        posterior = {
            "mean": mean,
            "std": std,
            "stoch": stoch,
            "deter": prior["deter"],
            "cell": cell,
        }
        return prior, posterior

    def _imagine_step(
        self, state: Dict[str, torch.Tensor], action: torch.Tensor, nonterm: float = 1.0
    ) -> Dict[str, torch.Tensor]:
        """Predict the next state from ``state`` and ``action`` alone.

        The stochastic, hidden and cell parts of the incoming state are each
        multiplied by ``nonterm``. A missing ``"cell"`` entry is treated as
        zeros.
        """
        masked_stoch = state["stoch"] * nonterm
        rnn_input = self.act_fn(
            self.fc_state_action(torch.cat([masked_stoch, action], dim=-1))
        )
        prev_hidden = state["deter"] * nonterm
        if "cell" in state:
            prev_cell = state["cell"]
        else:
            prev_cell = torch.zeros_like(state["deter"])
        h, c = self.rnn(rnn_input, (prev_hidden, prev_cell * nonterm))

        hidden = self.act_fn(self.fc_embed_prior(h))
        mean, std, stoch = self._sample_gaussian(self.fc_state_prior(hidden))
        return {"mean": mean, "std": std, "stoch": stoch, "deter": h, "cell": c}
[docs]
class TransformerBackbone(BackboneBase):
    """Transformer-based dynamics backbone for long-range dependencies.

    The action and the previous stochastic state are embedded, summed, and
    passed through an ``nn.TransformerEncoder``. The activated output of
    ``fc_embed_prior`` is what gets stored under ``"deter"``; the incoming
    ``state["deter"]`` is not read by the step itself.

    Args:
        action_size: Width of the action vector.
        stoch_size: Width of the stochastic latent.
        deter_size: Width of the ``"deter"`` feature; also used as the
            transformer's feed-forward width.
        embed_size: Transformer model width; also the expected width of
            observation embeddings.
        num_heads: Attention heads per layer.
        num_layers: Number of encoder layers.
        activation: Used both as a key into ``_str_to_activation`` and as the
            activation string for ``nn.TransformerEncoderLayer``, which per
            the PyTorch docs accepts only ``"relu"`` or ``"gelu"`` as strings.
    """

    def __init__(
        self,
        action_size: int,
        stoch_size: int,
        deter_size: int,
        embed_size: int,
        num_heads: int = 4,
        num_layers: int = 2,
        activation: str = "gelu",
    ):
        super().__init__()
        self.action_size = action_size
        self.stoch_size = stoch_size
        self.deter_size = deter_size
        self.embed_size = embed_size
        self.act_fn = _str_to_activation[activation]
        self.action_embed = nn.Linear(action_size, embed_size)
        self.stoch_embed = nn.Linear(stoch_size, embed_size)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            dim_feedforward=deter_size,
            activation=activation,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_embed_prior = nn.Linear(embed_size, deter_size)
        self.fc_state_prior = nn.Linear(deter_size, 2 * stoch_size)
        self.fc_embed_posterior = nn.Linear(embed_size + deter_size, deter_size)
        self.fc_state_posterior = nn.Linear(deter_size, 2 * stoch_size)

    @property
    def embedding_size(self) -> int:
        """Width of the observation embedding this backbone expects."""
        return self.embed_size

    def init_state(
        self, batch_size: int, device: torch.device
    ) -> Dict[str, torch.Tensor]:
        """Return an all-zero state for ``batch_size`` sequences on ``device``."""
        return {
            "mean": torch.zeros(batch_size, self.stoch_size, device=device),
            "std": torch.zeros(batch_size, self.stoch_size, device=device),
            "stoch": torch.zeros(batch_size, self.stoch_size, device=device),
            "deter": torch.zeros(batch_size, self.deter_size, device=device),
        }

    def forward(
        self,
        state: Dict[str, torch.Tensor],
        action: torch.Tensor,
        obs_embed: Optional[torch.Tensor] = None,
        nonterm: float = 1.0,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Run one step and return ``(prior, posterior)``.

        When ``obs_embed`` is ``None`` the posterior is the prior itself
        (the same dictionary object).
        """
        prior = self._imagine_step(state, action, nonterm)
        if obs_embed is None:
            return prior, prior

        # Handle both 2D (batch, embed) and 3D (seq, batch, embed) embeddings
        if obs_embed.dim() == 3:
            # Sequence data: repeat the prior features along the leading axis
            # so they can be concatenated with every step's embedding.
            seq_len = obs_embed.shape[0]
            prior_deter_expanded = prior["deter"].unsqueeze(0).expand(seq_len, -1, -1)
        else:
            # Single step: 2D embeddings
            prior_deter_expanded = prior["deter"]

        posterior_embed = self.act_fn(
            self.fc_embed_posterior(
                torch.cat([obs_embed, prior_deter_expanded], dim=-1)
            )
        )
        mean, std = torch.chunk(self.fc_state_posterior(posterior_embed), 2, dim=-1)
        std = F.softplus(std) + 0.1
        sample = mean + torch.randn_like(mean) * std
        posterior = {
            "mean": mean,
            "std": std,
            "stoch": sample,
            "deter": prior["deter"],
        }
        return prior, posterior

    def _imagine_step(
        self, state: Dict[str, torch.Tensor], action: torch.Tensor, nonterm: float = 1.0
    ) -> Dict[str, torch.Tensor]:
        """Predict the next state from ``state["stoch"]`` and ``action``."""
        action_emb = self.action_embed(action)
        stoch_emb = self.stoch_embed(state["stoch"] * nonterm)
        x = action_emb + stoch_emb

        single_step = x.dim() == 2
        if single_step:
            # The encoder is built with batch_first=True, i.e. it expects
            # (batch, seq, feature). Insert the length-1 sequence axis at
            # dim 1; inserting it at dim 0 would present the whole batch as
            # one sequence and let unrelated samples attend to each other.
            x = x.unsqueeze(1)
        x = self.transformer(x)
        if single_step:
            h = x.squeeze(1)
        else:
            # NOTE(review): 3D inputs are passed through as they arrive; if
            # callers use a (seq, batch, feature) layout here, it does not
            # match batch_first=True - confirm against callers.
            h = x.squeeze(0) if x.shape[0] == 1 else x

        prior_embed = self.act_fn(self.fc_embed_prior(h))
        mean, std = torch.chunk(self.fc_state_prior(prior_embed), 2, dim=-1)
        std = F.softplus(std) + 0.1
        sample = mean + torch.randn_like(mean) * std
        return {"mean": mean, "std": std, "stoch": sample, "deter": prior_embed}
[docs]
class ModularRSSM(nn.Module):
    """Modular RSSM with swappable encoder, decoder, and backbone.

    This class allows researchers to easily experiment with different:
    - Encoders: Conv, MLP, ViT
    - Decoders: Conv, MLP
    - Backbones: GRU, LSTM, Transformer

    Example:
        >>> encoder = ConvEncoder((3, 64, 64), embed_size=1024)
        >>> decoder = ConvDecoder(32, 200, (3, 64, 64))
        >>> backbone = GRUBackbone(action_size=6, stoch_size=32, deter_size=200, hidden_size=200, embed_size=1024)
        >>> rssm = ModularRSSM(encoder, decoder, backbone)
    """

    # Keys every backbone state is expected to contain.
    _CORE_STATE_KEYS = ("mean", "std", "stoch", "deter")

    def __init__(
        self,
        encoder: EncoderBase,
        decoder: DecoderBase,
        backbone: BackboneBase,
        reward_decoder: Optional[DecoderBase] = None,
    ) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.backbone = backbone
        self.reward_decoder = reward_decoder

    @property
    def stoch_size(self) -> int:
        """Stochastic latent width, as reported by the backbone."""
        return self.backbone.stoch_size

    @property
    def deter_size(self) -> int:
        """Deterministic state width, as reported by the backbone."""
        return self.backbone.deter_size

    @property
    def embed_size(self) -> int:
        """Observation-embedding width, read from ``backbone.embedding_size``."""
        # Some backbones may expose embedding_size as an int attribute or
        # property. Cast to int to make the return type explicit for
        # type-checkers.
        return int(self.backbone.embedding_size)  # type: ignore[arg-type]

    def init_state(
        self, batch_size: int, device: torch.device
    ) -> Dict[str, torch.Tensor]:
        """Return the backbone's initial state."""
        return self.backbone.init_state(batch_size, device)

    def get_dist(
        self, mean: torch.Tensor, std: torch.Tensor
    ) -> distributions.Distribution:
        """Build a diagonal Gaussian with the last dim as the event dim."""
        distribution = distributions.Normal(mean, std)
        return distributions.independent.Independent(distribution, 1)

    def observe_step(
        self,
        prev_state: Dict[str, torch.Tensor],
        prev_action: torch.Tensor,
        obs: torch.Tensor,
        nonterm: Any = 1.0,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Encode ``obs`` and return the backbone's ``(prior, posterior)``."""
        obs_embed = self.encoder(obs)
        prior, posterior = self.backbone.forward(
            prev_state, prev_action, obs_embed, nonterm
        )
        return prior, posterior

    def imagine_step(
        self,
        prev_state: Dict[str, torch.Tensor],
        prev_action: torch.Tensor,
        nonterm: Any = 1.0,
    ) -> Dict[str, torch.Tensor]:
        """Return the backbone's prior for one step without an observation."""
        prior, _ = self.backbone.forward(prev_state, prev_action, None, nonterm)
        return prior

    @staticmethod
    def _match_rank(mask: torch.Tensor, rank: int) -> torch.Tensor:
        """Give ``mask`` exactly ``rank`` dims by editing trailing size-1 dims."""
        # Drop surplus trailing singleton dims, e.g. (B, 1, 1) -> (B, 1).
        while mask.dim() > rank and mask.shape[-1] == 1:
            mask = mask.squeeze(-1)
        # Append singleton dims when the mask is too short, e.g. (B,) -> (B, 1).
        if mask.dim() < rank:
            mask = mask.reshape(*mask.shape, *([1] * (rank - mask.dim())))
        return mask

    def observe_rollout(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor,
        nonterms: torch.Tensor,
        prev_state: Dict[str, torch.Tensor],
        horizon: int,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """Filter ``horizon`` steps of observations.

        ``obs``, ``actions`` and ``nonterms`` are indexed along dim 0 by the
        step number. Each step's action is multiplied by that step's nonterm
        mask, and the posterior of one step is the input state of the next.

        Returns:
            ``(priors, posteriors)``, each a state dictionary whose tensors
            are stacked along dim 0.
        """
        priors = []
        posteriors = []
        for t in range(horizon):
            action_t = actions[t]
            # Align the mask's rank with the action so the product broadcasts
            # per sample for both (B,) and (B, 1) masks. Squeezing a (B, 1)
            # mask to (B,) would not broadcast against (B, A) actions.
            nonterm_t = self._match_rank(nonterms[t], action_t.dim())
            prev_action = action_t * nonterm_t
            # NOTE(review): a constant 1.0 (not nonterm_t) is forwarded, so
            # only the action is masked here, never the recurrent state.
            # Confirm this is intended.
            prior_state, posterior_state = self.observe_step(
                prev_state, prev_action, obs[t], 1.0
            )
            priors.append(prior_state)
            posteriors.append(posterior_state)
            prev_state = posterior_state
        return self._stack_states(priors), self._stack_states(posteriors)

    def imagine_rollout(
        self,
        actor: nn.Module,
        prev_state: Dict[str, torch.Tensor],
        horizon: int,
    ) -> Dict[str, torch.Tensor]:
        """Roll the prior forward ``horizon`` steps using ``actor``'s actions.

        The actor receives the detached concatenation of ``"stoch"`` and
        ``"deter"``. The returned tensors are stacked along dim 0.
        """
        rssm_state = prev_state
        next_states = []
        for _ in range(horizon):
            features = torch.cat(
                [rssm_state["stoch"], rssm_state["deter"]], dim=-1
            ).detach()
            action = actor(features)
            rssm_state = self.imagine_step(rssm_state, action)
            next_states.append(rssm_state)
        return self._stack_states(next_states)

    def decode_reward(self, features: torch.Tensor):
        """Apply the reward decoder to ``features``.

        Raises:
            ValueError: If no reward decoder was supplied at construction.
        """
        if self.reward_decoder is None:
            raise ValueError("Reward decoder not provided")
        return self.reward_decoder(features)

    def _stack_states(
        self, states: List[Dict[str, torch.Tensor]], dim: int = 0
    ) -> Dict[str, torch.Tensor]:
        """Stack a list of state dictionaries key-by-key along ``dim``.

        The four core keys are always stacked. Any additional key that occurs
        in every state (such as the LSTM backbone's ``"cell"``) is stacked as
        well, so that it is not lost.
        """
        keys = list(self._CORE_STATE_KEYS)
        if states:
            keys.extend(
                key
                for key in states[0]
                if key not in self._CORE_STATE_KEYS
                and all(key in s for s in states)
            )
        return {key: torch.stack([s[key] for s in states], dim=dim) for key in keys}

    def detach_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return a copy of ``state`` with every tensor detached."""
        return {k: v.detach() for k, v in state.items()}

    def seq_to_batch(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Merge the first two dims of every tensor that has more than two dims."""
        return {
            k: v.reshape(v.shape[0] * v.shape[1], *v.shape[2:]) if v.dim() > 2 else v
            for k, v in state.items()
        }
[docs]
def create_modular_rssm(
    encoder_type: str = "conv",
    decoder_type: str = "conv",
    backbone_type: str = "gru",
    obs_shape: Union[Tuple[int, int, int], Tuple[int]] = (3, 64, 64),
    action_size: int = 6,
    stoch_size: int = 32,
    deter_size: int = 200,
    embed_size: int = 1024,
    hidden_size: int = 200,
    activation: str = "elu",
    **kwargs,
) -> ModularRSSM:
    """Factory function to create a modular RSSM with specified components.

    Args:
        encoder_type: Type of encoder ("conv", "mlp", "vit")
        decoder_type: Type of decoder ("conv", "mlp")
        backbone_type: Type of backbone ("gru", "lstm", "transformer")
        obs_shape: Shape of observations (C, H, W) for images or (D,) for state
        action_size: Action space dimension
        stoch_size: Stochastic latent dimension
        deter_size: Deterministic hidden dimension
        embed_size: Encoder embedding dimension
        hidden_size: Hidden layer dimension
        activation: Activation function name. Not applied to the ViT encoder
            or the transformer backbone, which keep their own settings.
        **kwargs: ``num_heads`` and ``num_layers`` are read for the
            transformer backbone; other entries are ignored.

    Returns:
        Configured ModularRSSM instance. A reward decoder with output width 1
        is always attached.

    Raises:
        ValueError: If any of the three type strings is not recognised.
    """
    # Image-based components fall back to (3, 64, 64) when obs_shape is not
    # three-dimensional.
    image_shape: Tuple[int, int, int] = (3, 64, 64)  # type: ignore
    if len(obs_shape) == 3:
        image_shape = (int(obs_shape[0]), int(obs_shape[1]), int(obs_shape[2]))  # type: ignore

    # Declare with base types so mypy understands the variable may hold any
    # concrete implementation chosen below.
    encoder: EncoderBase
    if encoder_type == "conv":
        encoder = ConvEncoder(image_shape, embed_size, activation)
    elif encoder_type == "mlp":
        # Forward the requested activation, as is done for the MLP decoder,
        # instead of silently using the encoder's own default.
        encoder = MLPEncoder(obs_shape[0], embed_size, activation=activation)
    elif encoder_type == "vit":
        encoder = ViTEncoder(image_shape, embed_size)
    else:
        raise ValueError(f"Unknown encoder type: {encoder_type}")

    decoder: DecoderBase
    if decoder_type == "conv":
        decoder = ConvDecoder(stoch_size, deter_size, image_shape, activation)
    elif decoder_type == "mlp":
        decoder = MLPDecoder(
            stoch_size, deter_size, obs_shape[0], activation=activation
        )
    else:
        raise ValueError(f"Unknown decoder type: {decoder_type}")

    backbone: BackboneBase
    if backbone_type == "gru":
        backbone = GRUBackbone(
            action_size, stoch_size, deter_size, hidden_size, embed_size, activation
        )
    elif backbone_type == "lstm":
        backbone = LSTMBackbone(
            action_size, stoch_size, deter_size, hidden_size, embed_size, activation
        )
    elif backbone_type == "transformer":
        num_heads = kwargs.get("num_heads", 4)
        num_layers = kwargs.get("num_layers", 2)
        backbone_activation = "gelu"  # PyTorch Transformer only supports relu/gelu
        backbone = TransformerBackbone(
            action_size,
            stoch_size,
            deter_size,
            embed_size,
            num_heads,
            num_layers,
            backbone_activation,
        )
    else:
        raise ValueError(f"Unknown backbone type: {backbone_type}")

    reward_decoder = MLPDecoder(
        stoch_size, deter_size, 1, activation=activation, dist="normal"
    )
    return ModularRSSM(encoder, decoder, backbone, reward_decoder)