Source code for world_models.models.diffusion.reward_termination

import torch
import torch.nn as nn
from typing import Tuple, Optional


class ConvBlock(nn.Module):
    """Convolutional block with adaptive group normalization.

    Applies a 4x4 convolution (padding 1), then ``GroupNorm`` with 8 groups,
    then modulates the normalized features with a per-channel scale and bias
    predicted from a conditioning vector, and finally a SiLU activation::

        out = SiLU(norm(conv(x)) * (1 + scale) + bias)

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels; must be divisible
            by 8 because of ``GroupNorm(8, out_channels)``.
        cond_dim: Size of the last dimension of the conditioning tensor.
        stride: Convolution stride (default 2).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        cond_dim: int,
        stride: int = 2,
    ):
        super().__init__()
        # Submodule names/creation order are part of the state_dict layout.
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=4, stride=stride, padding=1
        )
        self.norm = nn.GroupNorm(num_groups=8, num_channels=out_channels)
        self.act = nn.SiLU()
        # One linear layer predicts both the scale and the bias halves.
        self.cond_embed = nn.Linear(cond_dim, 2 * out_channels)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` and modulate it with ``cond``.

        Args:
            x: Feature map ``[N, in_channels, H, W]``.
            cond: Conditioning tensor ``[N, cond_dim]``; its projected scale
                and bias are broadcast over the spatial dimensions.

        Returns:
            Activated feature map ``[N, out_channels, H', W']``.
        """
        normed = self.norm(self.conv(x))
        scale, bias = torch.chunk(self.cond_embed(cond), 2, dim=-1)
        # [N, C] -> [N, C, 1, 1] so they broadcast over H and W.
        scale = scale[..., None, None]
        bias = bias[..., None, None]
        return self.act(normed * (1 + scale) + bias)
class RewardTerminationModel(nn.Module):
    """
    Reward and termination prediction model.

    CNN + LSTM architecture following DIAMOND paper specifications. Each
    frame is encoded by a stack of action-conditioned ``ConvBlock``s (stride
    2), globally average-pooled over space, and the resulting per-step
    features are fed through a single-layer LSTM. Two linear heads produce
    reward logits (3 classes: -1, 0, 1) and termination logits (2 classes).

    Args:
        obs_channels: Number of observation channels (3 for RGB)
        action_dim: Number of possible actions
        channels: List of channel sizes for conv blocks (must be non-empty;
            each size must be divisible by 8 for the blocks' GroupNorm)
        lstm_dim: LSTM hidden dimension
        cond_dim: Conditioning dimension for adaptive norm

    Raises:
        ValueError: If ``channels`` is empty.
    """

    def __init__(
        self,
        obs_channels: int = 3,
        action_dim: int = 18,
        channels: Tuple[int, ...] = (32, 32, 32, 32),
        lstm_dim: int = 512,
        cond_dim: int = 128,
    ):
        super().__init__()
        if not channels:
            # Without conv blocks there is no feature size for the LSTM input.
            raise ValueError("channels must contain at least one conv block size")
        self.obs_channels = obs_channels
        self.action_dim = action_dim
        self.lstm_dim = lstm_dim

        # The action embedding conditions every conv block's adaptive norm.
        self.action_embed = nn.Embedding(action_dim, cond_dim)

        self.conv_blocks = nn.ModuleList()
        in_ch = obs_channels
        for out_ch in channels:
            self.conv_blocks.append(ConvBlock(in_ch, out_ch, cond_dim, stride=2))
            in_ch = out_ch

        self.lstm = nn.LSTM(
            input_size=channels[-1],
            hidden_size=lstm_dim,
            num_layers=1,
            batch_first=True,
        )
        self.reward_head = nn.Linear(lstm_dim, 3)
        self.termination_head = nn.Linear(lstm_dim, 2)

    def forward(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor,
        hidden_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Forward pass of reward/termination model.

        Args:
            obs: Observations [B, T, C, H, W] (need not be contiguous)
            actions: Integer actions [B, T] (need not be contiguous)
            hidden_state: Optional (h, c) hidden states; ``None`` starts
                the LSTM from a zero state

        Returns:
            reward_logits: Reward predictions [B, T, 3] (for -1, 0, 1)
            termination_logits: Termination predictions [B, T, 2]
            hidden_state: Updated (h, c) hidden states
        """
        B, T, C, H, W = obs.shape
        # reshape (not view): sliced/permuted batches are non-contiguous and
        # .view() would raise on them.
        obs_flat = obs.reshape(B * T, C, H, W)
        actions_flat = actions.reshape(B * T)
        action_emb = self.action_embed(actions_flat)

        h = obs_flat
        for conv_block in self.conv_blocks:
            h = conv_block(h, action_emb)

        # Global average pool over H, W -> [B*T, channels[-1]].
        h = h.mean(dim=[2, 3])
        h = h.reshape(B, T, -1)

        # nn.LSTM treats hx=None as a zero initial state.
        lstm_out, hidden_state = self.lstm(h, hidden_state)

        reward_logits = self.reward_head(lstm_out)
        termination_logits = self.termination_head(lstm_out)
        return reward_logits, termination_logits, hidden_state

    def predict(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor,
        hidden_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Predict reward and termination for a single step.

        Args:
            obs: Single observation [B, C, H, W]
            actions: Single action [B]
            hidden_state: Optional (h, c) hidden states

        Returns:
            reward: Predicted reward classes as float tensor [B] (values -1, 0, 1)
            terminated: Predicted termination bool tensor [B]
            hidden_state: Updated (h, c) hidden states
        """
        # Add a length-1 time axis so the sequence forward can be reused.
        obs = obs.unsqueeze(1)
        actions = actions.unsqueeze(1)
        reward_logits, term_logits, hidden_state = self.forward(
            obs, actions, hidden_state
        )
        # Class indices 0, 1, 2 map back to rewards -1, 0, 1.
        reward = reward_logits.argmax(dim=-1) - 1
        terminated = term_logits.argmax(dim=-1).bool()
        return reward.squeeze(-1).float(), terminated.squeeze(-1), hidden_state

    def init_hidden(
        self, batch_size: int, device: torch.device
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return zero (h, c) LSTM states, each [1, batch_size, lstm_dim]."""
        h = torch.zeros(1, batch_size, self.lstm_dim, device=device)
        c = torch.zeros(1, batch_size, self.lstm_dim, device=device)
        return (h, c)
class RewardTerminationLoss(nn.Module):
    """Loss function for reward and termination prediction.

    Sum of two mean-reduced cross-entropy losses: a 3-class reward loss
    (classes for rewards -1, 0, 1) and a 2-class termination loss.
    """

    def __init__(self):
        super().__init__()
        self.reward_criterion = nn.CrossEntropyLoss(reduction="mean")
        self.termination_criterion = nn.CrossEntropyLoss(reduction="mean")

    def forward(
        self,
        reward_logits: torch.Tensor,
        termination_logits: torch.Tensor,
        rewards: torch.Tensor,
        terminated: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute loss for reward and termination predictions.

        Args:
            reward_logits: [B, T, 3]
            termination_logits: [B, T, 2]
            rewards: Rewards [B, T]. Only the sign is used: negative -> class 0,
                zero -> class 1, positive -> class 2 (so -1, 0, 1 map to
                0, 1, 2 and other magnitudes are clipped).
            terminated: Termination flags [B, T]

        Returns:
            total_loss, reward_loss, termination_loss
        """
        # sign() clips arbitrary rewards to {-1, 0, 1}; +1 shifts to {0, 1, 2}.
        reward_targets = (torch.sign(rewards) + 1).long()
        # Use reshape everywhere: inputs may be non-contiguous, and elementwise
        # ops / dtype casts preserve their strides, so .view(-1) could fail.
        reward_loss = self.reward_criterion(
            reward_logits.reshape(-1, 3), reward_targets.reshape(-1)
        )
        termination_loss = self.termination_criterion(
            termination_logits.reshape(-1, 2), terminated.long().reshape(-1)
        )
        total_loss = reward_loss + termination_loss
        return total_loss, reward_loss, termination_loss