import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
class IRISActor(nn.Module):
    """Recurrent actor network for the IRIS agent.

    Maps image frames (reconstructed frames in IRIS) to logits over a discrete
    action space. Every frame is encoded on its own by a CNN feature extractor.
    The per-frame features are processed in time order by a multi-layer LSTM,
    and a linear head turns each LSTM output into action logits. An optional
    burn-in sequence can first be run through the CNN and LSTM to warm up the
    recurrent state before the frames that actually produce logits.

    Architecture:
        - CNN: ``CNNFeatureExtractor(frame_shape)``; its ``output_size`` is
          used as the LSTM input size.
        - LSTM: ``num_layers`` batch-first layers of width ``hidden_size``.
        - Linear: maps each ``hidden_size`` LSTM output to ``action_size``
          logits.

    The actor constructs its own CNN and LSTM; this class does not share
    weights with any other module.

    Args:
        action_size (int): Number of discrete actions.
        hidden_size (int): LSTM hidden state size (default: 512).
        num_layers (int): Number of LSTM layers (default: 4).
        frame_shape (tuple): Shape of one input frame as (C, H, W)
            (default: (3, 64, 64)).

    Attributes:
        action_size (int): Number of discrete actions.
        hidden_size (int): LSTM hidden state size.
        num_layers (int): Number of LSTM layers.
        frame_shape (tuple): Input frame shape.
    """

    def __init__(
        self,
        action_size: int,
        hidden_size: int = 512,
        num_layers: int = 4,
        frame_shape: Tuple[int, int, int] = (3, 64, 64),
    ):
        super().__init__()
        self.action_size = action_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.frame_shape = frame_shape
        # CNN feature extractor, owned by this actor.
        self.cnn = CNNFeatureExtractor(frame_shape)
        # LSTM for temporal processing over the per-frame features.
        self.lstm = nn.LSTM(
            input_size=self.cnn.output_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        # Action head: one logit per discrete action.
        self.action_head = nn.Linear(hidden_size, action_size)

    def _encode(self, frames: torch.Tensor) -> torch.Tensor:
        """Run the CNN on every frame of a (B, T, C, H, W) batch; return (B, T, F)."""
        batch, steps = frames.shape[:2]
        features = self.cnn(frames.reshape(batch * steps, *frames.shape[2:]))
        return features.reshape(batch, steps, -1)

    def forward(
        self,
        frames: torch.Tensor,
        hidden_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        burn_in_frames: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Compute action logits for a single frame or a sequence of frames.

        Args:
            frames: A sequence ``(B, T, C, H, W)`` or a single step
                ``(B, C, H, W)``.
            hidden_state: Optional ``(h, c)`` LSTM state, each of shape
                ``(num_layers, B, hidden_size)``. A zero state is used when
                omitted.
            burn_in_frames: Optional ``(B, T_burn, C, H, W)`` context frames.
                They are run through the CNN and LSTM before ``frames``,
                starting from ``hidden_state`` if one is given. Only the
                resulting state is kept; no logits are returned for them.

        Returns:
            action_logits: ``(B, T, action_size)``, or ``(B, action_size)``
                for single-step input.
            hidden_state: Updated ``(h, c)``, each of shape
                ``(num_layers, B, hidden_size)``. It can be passed directly as
                ``hidden_state`` to the next call.

        Raises:
            ValueError: If ``frames`` is not 4-D/5-D or ``burn_in_frames`` is
                not 5-D.
        """
        single_step = frames.dim() == 4
        if single_step:
            frames = frames.unsqueeze(1)  # (B, 1, C, H, W)
        elif frames.dim() != 5:
            raise ValueError(
                "frames must be (B, C, H, W) or (B, T, C, H, W), "
                f"got shape {tuple(frames.shape)}"
            )

        features = self._encode(frames)  # (B, T, feature_size)

        if burn_in_frames is not None:
            if burn_in_frames.dim() != 5:
                raise ValueError(
                    "burn_in_frames must be (B, T_burn, C, H, W), "
                    f"got shape {tuple(burn_in_frames.shape)}"
                )
            # Warm up the recurrent state on the context frames and discard
            # their outputs. Continue from a caller-supplied state rather than
            # silently dropping it (None still means a zero start).
            _, hidden_state = self.lstm(self._encode(burn_in_frames), hidden_state)

        if hidden_state is None:
            hidden_state = self.init_hidden_state(frames.shape[0], frames.device)
        lstm_out, hidden_state = self.lstm(features, hidden_state)

        action_logits = self.action_head(lstm_out)  # (B, T, action_size)
        if single_step:
            action_logits = action_logits.squeeze(1)  # (B, action_size)
        # The (h, c) state is deliberately NOT squeezed. nn.LSTM requires a
        # 3-D state for batched input, so a squeezed state (which happens when
        # num_layers == 1) could not be fed back in on the next step.
        return action_logits, hidden_state

    def init_hidden_state(
        self,
        batch_size: int,
        device: torch.device,
        dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return an all-zero ``(h, c)`` LSTM state.

        Args:
            batch_size: Batch size B.
            device: Device for the new tensors.
            dtype: Tensor dtype. Defaults to the LSTM parameter dtype, so the
                state stays compatible after ``.double()`` / ``.half()``.

        Returns:
            ``(h, c)``, each of shape ``(num_layers, batch_size, hidden_size)``.
        """
        if dtype is None:
            dtype = self.lstm.weight_ih_l0.dtype
        shape = (self.num_layers, batch_size, self.hidden_size)
        h = torch.zeros(shape, device=device, dtype=dtype)
        c = torch.zeros(shape, device=device, dtype=dtype)
        return (h, c)

    def get_action(
        self,
        frame: torch.Tensor,
        temperature: float = 1.0,
        deterministic: bool = False,
    ) -> torch.Tensor:
        """Select actions for a single frame, starting from a zero LSTM state.

        The forward pass runs in eval mode without gradients. The module's
        previous train/eval mode is restored afterwards.

        Args:
            frame: Single frame batch ``(B, C, H, W)``.
            temperature: Softmax temperature used when sampling (higher means
                more random). It must be > 0 when ``deterministic`` is False.
                It is ignored for greedy selection, since dividing by a
                positive temperature cannot change the argmax.
            deterministic: If True, return the argmax action; otherwise sample
                from ``softmax(logits / temperature)``.

        Returns:
            Selected action indices of shape ``(B,)``.

        Raises:
            ValueError: If sampling is requested with ``temperature <= 0``.
        """
        if not deterministic and temperature <= 0:
            raise ValueError(
                f"temperature must be > 0 when sampling, got {temperature}"
            )
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                action_logits, _ = self.forward(frame)
                if deterministic:
                    return action_logits.argmax(dim=-1)
                probs = F.softmax(action_logits / temperature, dim=-1)
                return torch.multinomial(probs, 1).squeeze(-1)
        finally:
            # Don't leave a training-time caller stuck in eval mode.
            self.train(was_training)
class IRISCritic(nn.Module):
    """Recurrent critic (value) network for the IRIS agent.

    Estimates a scalar value for every frame of a sequence. It has the same
    layout as the actor (a per-frame CNN, a multi-layer LSTM over time, and a
    linear head), but the head outputs one value instead of action logits.

    The critic constructs its own CNN and LSTM; this class does not share
    weights with the actor or any other module.

    Architecture:
        - CNN: ``CNNFeatureExtractor(frame_shape)``; its ``output_size`` is
          used as the LSTM input size.
        - LSTM: ``num_layers`` batch-first layers of width ``hidden_size``.
        - Linear: maps each ``hidden_size`` LSTM output to one scalar value.

    Args:
        hidden_size (int): LSTM hidden state size (default: 512).
        num_layers (int): Number of LSTM layers (default: 4).
        frame_shape (tuple): Shape of one input frame as (C, H, W)
            (default: (3, 64, 64)).

    Attributes:
        hidden_size (int): LSTM hidden state size.
        num_layers (int): Number of LSTM layers.
        frame_shape (tuple): Input frame shape.
    """

    def __init__(
        self,
        hidden_size: int = 512,
        num_layers: int = 4,
        frame_shape: Tuple[int, int, int] = (3, 64, 64),
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.frame_shape = frame_shape
        # CNN feature extractor, owned by this critic.
        self.cnn = CNNFeatureExtractor(frame_shape)
        # LSTM for temporal processing over the per-frame features.
        self.lstm = nn.LSTM(
            input_size=self.cnn.output_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        # Value head: one scalar per time step.
        self.value_head = nn.Linear(hidden_size, 1)

    def _encode(self, frames: torch.Tensor) -> torch.Tensor:
        """Run the CNN on every frame of a (B, T, C, H, W) batch; return (B, T, F)."""
        batch, steps = frames.shape[:2]
        features = self.cnn(frames.reshape(batch * steps, *frames.shape[2:]))
        return features.reshape(batch, steps, -1)

    def forward(
        self,
        frames: torch.Tensor,
        hidden_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Estimate values for a single frame or a sequence of frames.

        Args:
            frames: A sequence ``(B, T, C, H, W)`` or, like the actor, a single
                step ``(B, C, H, W)``.
            hidden_state: Optional ``(h, c)`` LSTM state, each of shape
                ``(num_layers, B, hidden_size)``. A zero state is used when
                omitted.

        Returns:
            values: Value estimates ``(B, T)``, or ``(B,)`` for single-step
                input.
            hidden_state: Updated ``(h, c)``, each of shape
                ``(num_layers, B, hidden_size)``.

        Raises:
            ValueError: If ``frames`` is not 4-D or 5-D.
        """
        single_step = frames.dim() == 4
        if single_step:
            frames = frames.unsqueeze(1)  # (B, 1, C, H, W)
        elif frames.dim() != 5:
            raise ValueError(
                "frames must be (B, C, H, W) or (B, T, C, H, W), "
                f"got shape {tuple(frames.shape)}"
            )

        features = self._encode(frames)  # (B, T, feature_size)
        if hidden_state is None:
            hidden_state = self.init_hidden_state(frames.shape[0], frames.device)
        lstm_out, hidden_state = self.lstm(features, hidden_state)

        values = self.value_head(lstm_out).squeeze(-1)  # (B, T)
        if single_step:
            values = values.squeeze(1)  # (B,)
        return values, hidden_state

    def init_hidden_state(
        self,
        batch_size: int,
        device: torch.device,
        dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return an all-zero ``(h, c)`` LSTM state.

        Args:
            batch_size: Batch size B.
            device: Device for the new tensors.
            dtype: Tensor dtype. Defaults to the LSTM parameter dtype, so the
                state stays compatible after ``.double()`` / ``.half()``.

        Returns:
            ``(h, c)``, each of shape ``(num_layers, batch_size, hidden_size)``.
        """
        if dtype is None:
            dtype = self.lstm.weight_ih_l0.dtype
        shape = (self.num_layers, batch_size, self.hidden_size)
        h = torch.zeros(shape, device=device, dtype=dtype)
        c = torch.zeros(shape, device=device, dtype=dtype)
        return (h, c)
class IRISPolicy(nn.Module):
    """Policy wrapper around a single :class:`IRISActor`.

    Exposes a small interface over one actor network:

    - ``forward`` returns action logits.
    - ``act`` selects actions.
    - ``init_hidden`` builds a zero LSTM state.

    This module contains no critic. Value estimation, if needed, must come
    from a separately constructed critic network.

    Args:
        action_size (int): Number of discrete actions.
        hidden_size (int): LSTM hidden state size (default: 512).
        num_layers (int): Number of LSTM layers (default: 4).
        frame_shape (tuple): Shape of one input frame as (C, H, W)
            (default: (3, 64, 64)).

    Attributes:
        actor (IRISActor): The actor network used for action selection.
        action_size (int): Number of discrete actions.
        hidden_size (int): LSTM hidden state size.
        num_layers (int): Number of LSTM layers.
        frame_shape (tuple): Input frame shape.

    Example:
        >>> policy = IRISPolicy(action_size=18, hidden_size=512, num_layers=4,
        ...                     frame_shape=(3, 64, 64))
        >>> frame = torch.rand(1, 3, 64, 64)
        >>> action = policy.act(frame, temperature=1.0, deterministic=False)
    """

    def __init__(
        self,
        action_size: int,
        hidden_size: int = 512,
        num_layers: int = 4,
        frame_shape: Tuple[int, int, int] = (3, 64, 64),
    ):
        super().__init__()
        # Keep the configuration on the policy itself, as documented above,
        # so callers don't have to reach into ``self.actor``.
        self.action_size = action_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.frame_shape = frame_shape
        self.actor = IRISActor(
            action_size=action_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            frame_shape=frame_shape,
        )

    def forward(self, frames: torch.Tensor) -> torch.Tensor:
        """Return the actor's action logits for ``frames``.

        Accepts ``(B, T, C, H, W)`` or ``(B, C, H, W)`` input, as the actor
        does. Each call starts from a zero LSTM state, and the updated state
        is discarded.
        """
        action_logits, _ = self.actor(frames)
        return action_logits

    def act(
        self,
        frame: torch.Tensor,
        temperature: float = 1.0,
        deterministic: bool = False,
    ) -> torch.Tensor:
        """Select actions for a ``(B, C, H, W)`` frame via ``actor.get_action``."""
        return self.actor.get_action(frame, temperature, deterministic)

    def init_hidden(
        self, batch_size: int, device: torch.device
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the actor's zero ``(h, c)`` LSTM state for ``batch_size``."""
        return self.actor.init_hidden_state(batch_size, device)