Source code for world_models.models.diffusion.actor_critic

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional


class ActorCriticNetwork(nn.Module):
    """Actor-critic network for DIAMOND RL training.

    A stack of stride-2 convolution blocks encodes every frame, the resulting
    feature map is averaged over its spatial dimensions, and a single-layer
    LSTM consumes the per-frame feature vectors. Two linear heads on top of
    the LSTM output produce the policy logits and the state value.

    Args:
        obs_channels: Number of channels in each observation frame.
        action_dim: Number of discrete actions (width of the policy head).
        channels: Output channels of each convolution block, in order. Every
            entry must be divisible by 8 because each block applies
            ``nn.GroupNorm(8, out_ch)``.
        lstm_dim: Hidden size of the LSTM.
    """

    def __init__(
        self,
        obs_channels: int = 3,
        action_dim: int = 18,
        channels: Tuple[int, ...] = (32, 32, 64, 64),
        lstm_dim: int = 512,
    ):
        super().__init__()
        self.obs_channels = obs_channels
        self.action_dim = action_dim

        # Each block halves the spatial resolution (stride 2).
        self.conv_blocks = nn.ModuleList()
        in_ch = obs_channels
        for out_ch in channels:
            self.conv_blocks.append(
                nn.Sequential(
                    nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
                    nn.GroupNorm(8, out_ch),
                    nn.SiLU(),
                )
            )
            in_ch = out_ch

        # The LSTM input is the spatially averaged output of the last block.
        self.lstm = nn.LSTM(
            input_size=channels[-1],
            hidden_size=lstm_dim,
            num_layers=1,
            batch_first=True,
        )
        self.policy_head = nn.Linear(lstm_dim, action_dim)
        self.value_head = nn.Linear(lstm_dim, 1)

    def forward(
        self,
        obs: torch.Tensor,
        hidden_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the shared trunk and both heads over a batch of sequences.

        Args:
            obs: Observations of shape [B, T, C, H, W]. The tensor does not
                need to be contiguous.
            hidden_state: Optional LSTM state ``(h, c)``, each of shape
                [1, B, lstm_dim]. ``None`` starts from a zero state.

        Returns:
            policy_logits: [B, T, action_dim]
            values: [B, T, 1]
            hidden_state: Updated ``(h, c)`` after the last timestep.

        Raises:
            ValueError: If ``obs`` is not 5-dimensional.
        """
        if obs.ndim != 5:
            raise ValueError(
                f"Expected obs of shape [B, T, C, H, W], got ndim {obs.ndim}"
            )
        B, T, C, H, W = obs.shape

        # reshape (not view): inputs produced by transposing or slicing a
        # rollout buffer are non-contiguous and would make view() raise.
        h = obs.reshape(B * T, C, H, W)
        for conv_block in self.conv_blocks:
            h = conv_block(h)

        # Global average pooling over the spatial dims -> [B * T, channels[-1]].
        h = h.mean(dim=(2, 3))
        h = h.reshape(B, T, -1)

        # nn.LSTM treats hidden_state=None as a zero initial state.
        lstm_out, hidden_state = self.lstm(h, hidden_state)

        policy_logits = self.policy_head(lstm_out)
        values = self.value_head(lstm_out)
        return policy_logits, values, hidden_state

    def get_action(
        self,
        obs: torch.Tensor,
        hidden_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        deterministic: bool = False,
    ) -> Tuple[int, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Select one action for a single environment.

        Args:
            obs: A single frame of shape [C, H, W]. A 4-D tensor is
                interpreted as a [T, C, H, W] frame sequence from one
                environment (NOT as a batch); the action for the last frame
                is returned. Use :meth:`get_actions` for batched inputs.
            hidden_state: Optional LSTM state ``(h, c)`` with batch size 1,
                i.e. each of shape [1, 1, lstm_dim].
            deterministic: If True, take the argmax action; otherwise sample
                from the policy distribution.

        Returns:
            action: The selected action as a Python ``int``.
            hidden_state: Updated ``(h, c)``.
        """
        # Add the batch dimension and delegate to the batched interface.
        actions, hidden_state = self.get_actions(
            obs.unsqueeze(0), hidden_state, deterministic=deterministic
        )
        return int(actions[0].item()), hidden_state

    def get_actions(
        self,
        obs: torch.Tensor,
        hidden_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        deterministic: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Select actions for a batch of environments.

        Args:
            obs: Tensor of shape [B, C, H, W] (one frame per environment) or
                [B, T, C, H, W] (a frame sequence per environment, in which
                case the action for the last timestep is returned).
            hidden_state: Optional LSTM state ``(h, c)`` whose batch size
                matches ``B``.
            deterministic: If True, take the argmax action; otherwise sample
                from the policy distribution.

        Returns:
            actions: LongTensor of shape [B].
            hidden_state: Updated ``(h, c)``.

        Raises:
            ValueError: If ``obs`` is neither 4- nor 5-dimensional.
        """
        if obs.ndim == 4:
            # Insert the time axis: [B, C, H, W] -> [B, 1, C, H, W].
            obs_in = obs.unsqueeze(1)
        elif obs.ndim == 5:
            obs_in = obs
        else:
            raise ValueError(f"Unexpected obs ndim {obs.ndim}")

        # Call the module (not .forward) so registered hooks are honoured.
        policy_logits, _, hidden_state = self(obs_in, hidden_state)

        # [B, T, A] -> logits of the most recent timestep, [B, A].
        logits = policy_logits[:, -1, :]

        if deterministic:
            actions = logits.argmax(dim=-1)
        else:
            probs = F.softmax(logits, dim=-1)
            # torch.multinomial returns [B, num_samples]; drop the sample axis.
            actions = torch.multinomial(probs, num_samples=1).squeeze(-1)

        return actions.long(), hidden_state

    def get_value(
        self,
        obs: torch.Tensor,
        hidden_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Estimate the state value for one frame per environment.

        Args:
            obs: Tensor of shape [B, C, H, W].
            hidden_state: Optional LSTM state ``(h, c)`` whose batch size
                matches ``B``.

        Returns:
            values: Tensor of shape [B, 1].
            hidden_state: Updated ``(h, c)``.

        Raises:
            ValueError: If ``obs`` is not 4-dimensional.
        """
        if obs.ndim != 4:
            raise ValueError(
                f"Expected obs of shape [B, C, H, W], got ndim {obs.ndim}"
            )
        # Insert the time axis expected by forward: [B, 1, C, H, W].
        _, values, hidden_state = self(obs.unsqueeze(1), hidden_state)
        # [B, 1, 1] -> [B, 1] (drop the time axis, keep the value channel).
        return values.squeeze(1), hidden_state

    def init_hidden(
        self, batch_size: int, device: torch.device
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Create a zero LSTM state ``(h, c)``, each of shape [1, batch_size, lstm_dim]."""
        h = torch.zeros(1, batch_size, self.lstm.hidden_size, device=device)
        c = torch.zeros(1, batch_size, self.lstm.hidden_size, device=device)
        return (h, c)

    def get_hidden_size(self) -> int:
        """Return the LSTM hidden size."""
        return self.lstm.hidden_size
class RLLoss(nn.Module):
    """RL loss functions for DIAMOND.

    Implements REINFORCE with a value baseline, λ-return targets and an
    entropy bonus.

    Every method that takes ``values`` accepts either the flat layout
    [B, T+1] or the layout with a trailing singleton channel [B, T+1, 1]
    (the shape produced by a linear value head).

    Args:
        discount_factor: Reward discount γ.
        lambda_returns: Mixing coefficient λ of the λ-return.
        entropy_weight: Weight of the entropy bonus in the policy loss.
    """

    def __init__(
        self,
        discount_factor: float = 0.985,
        lambda_returns: float = 0.95,
        entropy_weight: float = 0.001,
    ):
        super().__init__()
        self.discount_factor = discount_factor
        self.lambda_returns = lambda_returns
        self.entropy_weight = entropy_weight

    @staticmethod
    def _flatten_values(values: torch.Tensor) -> torch.Tensor:
        """Return ``values`` as [B, T+1], dropping a trailing size-1 channel if present."""
        # Only a genuine 3-D [B, T+1, 1] tensor is squeezed. Squeezing a 2-D
        # tensor would collapse the time axis whenever it has length 1.
        if values.ndim == 3 and values.shape[-1] == 1:
            return values.squeeze(-1)
        return values

    def compute_lambda_returns(
        self,
        rewards: torch.Tensor,
        values: torch.Tensor,
        dones: torch.Tensor,
    ) -> torch.Tensor:
        """Compute λ-returns by backward recursion.

        For each step ``t`` (from last to first)::

            G_t = r_t + γ (1 - done_t) [(1 - λ) V_{t+1} + λ G_{t+1}]

        with ``G_T`` bootstrapped from the final value ``V_T``.

        Args:
            rewards: [B, T]
            values: [B, T+1] or [B, T+1, 1]
            dones: [B, T]; non-zero/True marks a terminal transition.

        Returns:
            lambda_returns: [B, T]. Uses the dtype of ``rewards`` when that
            is floating point; otherwise a floating dtype so that returns
            are not truncated to integers.
        """
        values = self._flatten_values(values)
        _, T = rewards.shape

        if rewards.is_floating_point():
            out_dtype = rewards.dtype
        else:
            # Integer rewards: store returns in a float buffer instead of
            # silently truncating them.
            out_dtype = torch.promote_types(values.dtype, torch.float32)
        lambda_returns = torch.zeros_like(rewards, dtype=out_dtype)

        # Loop-invariant: 1 where the episode continues, 0 where it ended.
        not_done = 1 - dones.float()

        returns = values[:, -1]
        for t in reversed(range(T)):
            blended = (
                (1 - self.lambda_returns) * values[:, t + 1]
                + self.lambda_returns * returns
            )
            returns = rewards[:, t] + self.discount_factor * not_done[:, t] * blended
            lambda_returns[:, t] = returns

        return lambda_returns

    def policy_loss(
        self,
        policy_logits: torch.Tensor,
        actions: torch.Tensor,
        lambda_returns: torch.Tensor,
        values: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the REINFORCE policy loss with entropy regularization.

        Args:
            policy_logits: [B, T, A]
            actions: [B, T] integer action indices.
            lambda_returns: [B, T]
            values: [B, T+1] or [B, T+1, 1]; the last step is only a
                bootstrap value and is not used as a baseline.

        Returns:
            policy_loss: Scalar tensor.
        """
        values = self._flatten_values(values)

        log_probs = F.log_softmax(policy_logits, dim=-1)
        action_log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

        # The advantage is a fixed weight: no gradient flows into the critic
        # (or the return targets) through the policy loss.
        advantages = lambda_returns - values[:, :-1]
        policy_loss = -(action_log_probs * advantages.detach()).mean()

        # Maximising entropy == minimising its negative.
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1).mean()
        entropy_loss = -self.entropy_weight * entropy

        return policy_loss + entropy_loss

    def value_loss(
        self,
        values: torch.Tensor,
        lambda_returns: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the value loss (MSE between values and λ-returns).

        Args:
            values: [B, T+1] or [B, T+1, 1]; the final bootstrap step is
                excluded from the loss.
            lambda_returns: [B, T]; treated as a constant target.

        Returns:
            value_loss: Scalar tensor.
        """
        target = lambda_returns.detach()
        value_pred = self._flatten_values(values)[:, :-1]
        return F.mse_loss(value_pred, target)