import torch
import torch.nn as nn
import torch.optim as optim
from typing import Tuple, Optional
import torch.nn.functional as F
from world_models.configs.iris_config import IRISConfig
from world_models.vision.iris_encoder import IRISEncoder
from world_models.vision.iris_decoder import IRISDecoder
from world_models.models.iris_transformer import IRISTransformer
from world_models.controller.iris_policy import (
CNNFeatureExtractor,
)
[docs]
def compute_lambda_return(
rewards: torch.Tensor,
values: torch.Tensor,
discounts: torch.Tensor,
lambda_coef: float = 0.95,
) -> torch.Tensor:
"""Compute λ-return target for value function training.
Args:
rewards: Rewards (B, T)
values: Value estimates (B, T+1)
discounts: Discount factors (B, T)
lambda_coef: Lambda parameter for bootstrapping
Returns:
lambda_returns: λ-return targets (B, T)
"""
T = rewards.shape[1]
lambda_returns = torch.zeros_like(rewards)
# Start with the last bootstrapped value
lambda_returns[:, T - 1] = rewards[:, T - 1] + discounts[:, T - 1] * values[:, T]
# Compute λ-returns backwards
for t in range(T - 2, -1, -1):
lambda_returns[:, t] = rewards[:, t] + discounts[:, t] * (
(1 - lambda_coef) * values[:, t + 1]
+ lambda_coef * lambda_returns[:, t + 1]
)
return lambda_returns
[docs]
class IRISAgent(nn.Module):
"""Complete IRIS Agent with world model and policy.
Combines:
- Discrete autoencoder (encoder + decoder)
- Transformer world model
- Actor-Critic for policy and value learning
"""
def __init__(
self,
config: IRISConfig,
action_size: int,
device: torch.device,
):
super().__init__()
self.config = config
self.action_size = action_size
self.device = device
# === Discrete Autoencoder ===
self.encoder = IRISEncoder(
vocab_size=config.vocab_size,
tokens_per_frame=config.tokens_per_frame,
embedding_dim=config.token_embedding_dim,
in_channels=config.frame_channels,
base_channels=config.encoder_channels,
frame_shape=config.get_frame_shape(),
).to(device)
self.decoder = IRISDecoder(
vocab_size=config.vocab_size,
embedding_dim=config.token_embedding_dim,
base_channels=config.decoder_depth,
out_channels=config.frame_channels,
frame_shape=config.get_frame_shape(),
).to(device)
# === Transformer World Model ===
self.transformer = IRISTransformer(
vocab_size=config.vocab_size,
tokens_per_frame=config.tokens_per_frame,
action_size=action_size,
embed_dim=config.transformer_embed_dim,
num_layers=config.transformer_layers,
num_heads=config.transformer_heads,
dropout=config.transformer_dropout,
).to(device)
# === Actor-Critic ===
# Combine actor and critic with shared CNN features
self.cnn = CNNFeatureExtractor(
frame_shape=config.get_frame_shape(),
output_size=config.actor_hidden_size,
).to(device)
self.lstm = nn.LSTM(
input_size=config.actor_hidden_size,
hidden_size=config.actor_hidden_size,
num_layers=config.actor_layers,
batch_first=True,
).to(device)
self.actor_head = nn.Linear(config.actor_hidden_size, action_size).to(device)
self.critic_head = nn.Linear(config.actor_hidden_size, 1).to(device)
# === Optimizers ===
self._setup_optimizers()
# === Training state ===
self.global_step = 0
self.current_epoch = 0
def _setup_optimizers(self):
"""Setup separate optimizers for each component."""
# Autoencoder optimizer
self.autoencoder_opt = optim.Adam(
list(self.encoder.parameters()) + list(self.decoder.parameters()),
lr=self.config.model_learning_rate,
betas=(self.config.adam_beta1, self.config.adam_beta2),
weight_decay=self.config.weight_decay,
)
# Transformer optimizer
self.transformer_opt = optim.Adam(
self.transformer.parameters(),
lr=self.config.model_learning_rate,
betas=(self.config.adam_beta1, self.config.adam_beta2),
weight_decay=self.config.weight_decay,
)
# Actor-Critic optimizer
ac_params = (
list(self.cnn.parameters())
+ list(self.lstm.parameters())
+ list(self.actor_head.parameters())
+ list(self.critic_head.parameters())
)
self.ac_opt = optim.Adam(
ac_params,
lr=self.config.actor_learning_rate,
betas=(self.config.adam_beta1, self.config.adam_beta2),
weight_decay=self.config.weight_decay,
)
[docs]
def forward_actor_critic(
self,
frames: torch.Tensor, # (B, T, C, H, W)
hidden: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Forward pass through actor-critic.
Args:
frames: Input frames (B, T, C, H, W)
hidden: Optional LSTM hidden state
Returns:
action_logits: (B, T, action_size)
values: (B, T)
hidden_state: (h, c)
"""
B, T, C, H, W = frames.shape
# CNN features
frames_flat = frames.reshape(B * T, C, H, W)
features = self.cnn(frames_flat)
features = features.reshape(B, T, -1)
# LSTM
if hidden is None:
hidden = self._init_lstm_hidden(B)
lstm_out, hidden = self.lstm(features, hidden)
# Action and value
action_logits = self.actor_head(lstm_out)
values = self.critic_head(lstm_out).squeeze(-1)
return action_logits, values, hidden
def _init_lstm_hidden(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Initialize LSTM hidden state."""
h = torch.zeros(
self.config.actor_layers,
batch_size,
self.config.actor_hidden_size,
device=self.device,
)
c = torch.zeros(
self.config.actor_layers,
batch_size,
self.config.actor_hidden_size,
device=self.device,
)
return (h, c)
[docs]
def act(
self,
frame: torch.Tensor,
epsilon: float = 0.0,
temperature: float = 1.0,
) -> torch.Tensor:
"""Sample action from policy.
Args:
frame: Single frame (B, C, H, W)
epsilon: Random action probability
temperature: Action distribution temperature
Returns:
actions: Selected actions (B,)
"""
self.eval()
with torch.no_grad():
B = frame.shape[0]
frames = frame.unsqueeze(1) # (B, 1, C, H, W)
action_logits, _, _ = self.forward_actor_critic(frames)
action_logits = action_logits.squeeze(1) / temperature
# Epsilon-greedy
if epsilon > 0:
random_mask = torch.rand(B) < epsilon
random_actions = torch.randint(
0, self.action_size, (B,), device=self.device
)
greedy_actions = action_logits.argmax(dim=-1)
actions = torch.where(random_mask, random_actions, greedy_actions)
else:
probs = torch.softmax(action_logits, dim=-1)
actions = torch.multinomial(probs, 1).squeeze(-1)
return actions
[docs]
def imagine_rollout(
self,
initial_frame: torch.Tensor,
horizon: int = 20,
) -> dict:
"""Generate imagined trajectories using world model.
Args:
initial_frame: Starting frame (B, C, H, W)
horizon: Number of steps to imagine
Returns:
trajectory: Dictionary with imagined rollout data
"""
self.eval()
# Encode initial frame
with torch.no_grad():
_, initial_tokens, _ = self.encoder(initial_frame)
# Lists to store trajectory
frames_imagined = [initial_frame]
actions_imagined = []
rewards_imagined = []
current_tokens = initial_tokens
for step in range(horizon):
# Decode current tokens to get "observation"
with torch.no_grad():
reconstructed_frame = self.decoder(
self.encoder.quantizer.decode_indices(current_tokens)
)
# Get action from policy
with torch.no_grad():
action = self.act(reconstructed_frame, epsilon=0.0)
# Predict next tokens
with torch.no_grad():
token_logits, action_hidden = self.transformer.predict_next_tokens(
current_tokens, action
)
next_tokens = torch.argmax(token_logits, dim=-1)
# Reshape from (B, K) to (B, H, W) for decode_indices
tokens_per_dim = int(self.config.tokens_per_frame**0.5)
next_tokens = next_tokens.reshape(
next_tokens.shape[0],
tokens_per_dim,
tokens_per_dim,
)
# Get reward prediction
reward_pred = self.transformer.reward_head(action_hidden).squeeze(-1)
actions_imagined.append(action)
rewards_imagined.append(reward_pred)
frames_imagined.append(reconstructed_frame)
current_tokens = next_tokens
return {
"frames": torch.stack(frames_imagined, dim=1), # (B, H+1, C, H, W)
"actions": (
torch.stack(actions_imagined, dim=1) if actions_imagined else None
),
"rewards": (
torch.stack(rewards_imagined, dim=1) if rewards_imagined else None
),
}
[docs]
def update_autoencoder(self, frames: torch.Tensor) -> dict:
"""Update discrete autoencoder.
Args:
frames: Training frames (B, C, H, W)
Returns:
losses: Dictionary of loss values
"""
self.encoder.train()
self.decoder.train()
# Encode
z_q, indices, vq_loss = self.encoder(frames)
# Decode
reconstruction = self.decoder(z_q)
# Compute losses
recon_loss = F.l1_loss(reconstruction, frames)
loss = recon_loss + vq_loss["vq_loss"]
# Update
self.autoencoder_opt.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(
list(self.encoder.parameters()) + list(self.decoder.parameters()),
self.config.grad_clip_norm,
)
self.autoencoder_opt.step()
return {
"recon_loss": recon_loss.item(),
"vq_loss": vq_loss["vq_loss"].item(),
"perplexity": vq_loss["perplexity"].item(),
"total_loss": loss.item(),
}
[docs]
def update_actor_critic(
self,
imagined_trajectory: dict,
) -> dict:
"""Update actor-critic in imagination.
Args:
imagined_trajectory: Dictionary from imagine_rollout
Returns:
losses: Dictionary of loss values
"""
self.train()
frames = imagined_trajectory["frames"] # (B, T+1, C, H, W)
actions = imagined_trajectory["actions"] # (B, T)
rewards = imagined_trajectory["rewards"] # (B, T)
B, T_plus_1, C, H, W = frames.shape
# Forward pass
action_logits, values, _ = self.forward_actor_critic(
frames[:, :-1]
) # (B, T, A), (B, T)
# Compute log probabilities
action_dist = torch.softmax(action_logits, dim=-1)
action_log_probs = torch.log(action_dist + 1e-8)
# Gather log probs for taken actions
actions_one_hot = F.one_hot(actions, self.action_size).float()
taken_log_probs = (action_log_probs * actions_one_hot).sum(dim=-1) # (B, T)
# Compute λ-returns
discounts = torch.full_like(rewards, self.config.discount)
lambda_returns = compute_lambda_return(
rewards,
torch.cat([values, torch.zeros(B, 1, device=self.device)], dim=1),
discounts,
self.config.td_lambda,
)
# Advantage
advantages = lambda_returns - values # (B, T)
# Actor loss (REINFORCE with baseline)
actor_loss = -(taken_log_probs * advantages.detach()).mean()
# Entropy bonus
entropy = -(action_dist * action_log_probs).sum(dim=-1).mean()
actor_loss -= self.config.entropy_coef * entropy
# Critic loss
value_loss = F.mse_loss(values, lambda_returns.detach())
# Total loss
loss = actor_loss + 0.5 * value_loss
# Update
self.ac_opt.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(
list(self.cnn.parameters())
+ list(self.lstm.parameters())
+ list(self.actor_head.parameters())
+ list(self.critic_head.parameters()),
self.config.grad_clip_norm,
)
self.ac_opt.step()
return {
"actor_loss": actor_loss.item(),
"value_loss": value_loss.item(),
"entropy": entropy.item(),
"total_loss": loss.item(),
}
[docs]
def save(self, path: str):
"""Save agent state."""
torch.save(
{
"encoder": self.encoder.state_dict(),
"decoder": self.decoder.state_dict(),
"transformer": self.transformer.state_dict(),
"cnn": self.cnn.state_dict(),
"lstm": self.lstm.state_dict(),
"actor_head": self.actor_head.state_dict(),
"critic_head": self.critic_head.state_dict(),
"autoencoder_opt": self.autoencoder_opt.state_dict(),
"transformer_opt": self.transformer_opt.state_dict(),
"ac_opt": self.ac_opt.state_dict(),
"config": self.config,
"global_step": self.global_step,
"epoch": self.current_epoch,
},
path,
)
[docs]
def load(self, path: str):
"""Load agent state."""
with torch.serialization.safe_globals([IRISConfig]):
checkpoint = torch.load(path, map_location=self.device)
self.encoder.load_state_dict(checkpoint["encoder"])
self.decoder.load_state_dict(checkpoint["decoder"])
self.transformer.load_state_dict(checkpoint["transformer"])
self.cnn.load_state_dict(checkpoint["cnn"])
self.lstm.load_state_dict(checkpoint["lstm"])
self.actor_head.load_state_dict(checkpoint["actor_head"])
self.critic_head.load_state_dict(checkpoint["critic_head"])
self.autoencoder_opt.load_state_dict(checkpoint["autoencoder_opt"])
self.transformer_opt.load_state_dict(checkpoint["transformer_opt"])
self.ac_opt.load_state_dict(checkpoint["ac_opt"])
self.global_step = checkpoint.get("global_step", 0)
self.current_epoch = checkpoint.get("epoch", 0)