Source code for world_models.datasets.diamond_dataset

import numpy as np
from typing import Dict, List, Tuple
import torch


class ReplayBuffer:
    """
    Fixed-capacity ring buffer of environment transitions.

    Stores (observation, action, reward, done, next_observation) tuples in
    preallocated NumPy arrays. Once ``capacity`` transitions have been added,
    every further ``add`` overwrites the slot at the write position, i.e. the
    oldest stored transition.

    Observations are stored as ``uint8`` arrays of shape ``obs_shape``
    (height, width, channels). When sampled they are converted to float
    tensors scaled by 1/255 and permuted to channels-first layout.
    """

    def __init__(
        self,
        capacity: int = 1000,
        obs_shape: Tuple[int, int, int] = (64, 64, 3),
        action_dim: int = 1,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """
        Args:
            capacity: Maximum number of transitions kept in the buffer.
            obs_shape: Shape of a single observation as (H, W, C).
            action_dim: Number of integer components per action.
            device: Torch device that sampled tensors are moved to.
        """
        self.capacity = capacity
        self.obs_shape = obs_shape
        self.action_dim = action_dim
        self.device = device
        self._allocate()
        self.position = 0  # index of the slot the next `add` writes to
        self.size = 0  # number of valid transitions currently stored

    def _allocate(self) -> None:
        """Create zero-filled storage arrays for the current capacity."""
        frame_shape = (self.capacity,) + tuple(self.obs_shape)
        self.observations = np.zeros(frame_shape, dtype=np.uint8)
        self.actions = np.zeros((self.capacity, self.action_dim), dtype=np.int64)
        self.rewards = np.zeros((self.capacity,), dtype=np.float32)
        self.dones = np.zeros((self.capacity,), dtype=np.bool_)
        self.next_observations = np.zeros(frame_shape, dtype=np.uint8)

    def _frames_to_tensor(self, frames: np.ndarray) -> torch.Tensor:
        """Convert ``uint8`` (..., H, W, C) frames to float (..., C, H, W) in [0, 1]."""
        tensor = torch.from_numpy(frames).float().to(self.device) / 255.0
        lead = tensor.ndim - 3  # number of leading batch/time axes
        return tensor.permute(*range(lead), lead + 2, lead, lead + 1)

    def _gather(self, indices: np.ndarray) -> Dict[str, torch.Tensor]:
        """Collect the transitions at ``indices`` as tensors on ``self.device``.

        ``indices`` is an integer array; its shape becomes the leading shape
        of every returned tensor. Integer-array indexing copies the data, so
        the returned tensors never alias the buffer's storage.
        """
        actions = torch.from_numpy(self.actions[indices]).long().to(self.device)
        # Scalar actions (action_dim == 1) are returned without the trailing axis.
        if actions.ndim > indices.ndim and actions.shape[-1] == 1:
            actions = actions.squeeze(-1)
        return {
            "obs": self._frames_to_tensor(self.observations[indices]),
            "actions": actions,
            "rewards": torch.from_numpy(self.rewards[indices]).float().to(self.device),
            "dones": torch.from_numpy(self.dones[indices]).bool().to(self.device),
            "next_obs": self._frames_to_tensor(self.next_observations[indices]),
        }

    def add(
        self,
        obs: np.ndarray,
        action: int,
        reward: float,
        done: bool,
        next_obs: np.ndarray,
    ):
        """Add a transition to the buffer, overwriting the oldest one when full."""
        slot = self.position
        self.observations[slot] = obs
        self.actions[slot] = action
        self.rewards[slot] = reward
        self.dones[slot] = done
        self.next_observations[slot] = next_obs
        self.position = (slot + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int) -> Dict[str, torch.Tensor]:
        """Sample a random batch of transitions (uniformly, with replacement).

        Returns:
            Dictionary with keys ``obs``, ``actions``, ``rewards``, ``dones``
            and ``next_obs``; every tensor has leading dimension ``batch_size``.

        Raises:
            ValueError: If the buffer holds no transitions.
        """
        if self.size == 0:
            # np.random.randint(0, 0) would also raise ValueError, but with
            # an opaque "low >= high" message.
            raise ValueError("Cannot sample from an empty replay buffer")
        indices = np.random.randint(0, self.size, size=batch_size)
        return self._gather(indices)

    def sample_sequence(
        self,
        batch_size: int,
        sequence_length: int,
        burn_in: int = 0,
    ) -> Dict[str, torch.Tensor]:
        """
        Sample windows of consecutive buffer slots for training.

        Windows are contiguous in storage order only: they are not checked
        against ``dones`` or against the ring buffer's write position. If the
        buffer holds fewer than ``sequence_length`` transitions, windows start
        at slot 0 and run into slots that have not been written yet, so gate
        calls with ``is_ready``.

        Args:
            batch_size: Number of sequences to sample
            sequence_length: Total sequence length (burn_in + horizon)
            burn_in: Number of initial frames to use for conditioning
                (accepted for API compatibility; not used here)

        Returns:
            Dictionary with tensors of shape (batch_size, sequence_length, ...)
        """
        # Start indices are drawn so that start + sequence_length < size,
        # clamped to 0 when the buffer is too short for that.
        max_start = max(self.size - sequence_length - 1, 0)
        starts = np.random.randint(0, max_start + 1, size=batch_size)
        # (B, 1) + (1, T) -> (B, T) index grid; one gather replaces a
        # per-sequence Python loop followed by np.stack.
        indices = starts[:, None] + np.arange(sequence_length)[None, :]
        return self._gather(indices)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return self.size

    def is_ready(self, min_size: int) -> bool:
        """Check if buffer has enough samples."""
        return self.size >= min_size

    def state_dict(self) -> dict:
        """Return a serializable state dict for checkpointing.

        Contains numpy arrays and scalar metadata so it can be saved with
        torch.save or numpy.save. The arrays are the buffer's own storage,
        not copies, so later ``add`` calls are visible through them.
        """
        return {
            "observations": self.observations,
            "actions": self.actions,
            "rewards": self.rewards,
            "dones": self.dones,
            "next_observations": self.next_observations,
            "position": int(self.position),
            "size": int(self.size),
            "capacity": int(self.capacity),
        }

    def load_state_dict(self, state: dict):
        """Load state previously produced by `state_dict()`.

        Storage is reallocated with the saved capacity (falling back to the
        length of the saved observation array) and this buffer's own
        ``obs_shape`` / ``action_dim``; the saved transitions are then copied
        in.
        """
        obs = state["observations"]
        actions = state["actions"]
        rewards = state["rewards"]
        dones = state["dones"]
        next_obs = state["next_observations"]
        pos = int(state.get("position", 0))
        size = int(state.get("size", 0))

        self.capacity = int(state.get("capacity", obs.shape[0]))
        self._allocate()

        # Number of rows that can actually be restored.
        n = max(0, min(size, obs.shape[0], self.capacity))
        if n > 0:
            self.observations[:n] = obs[:n]
            self.next_observations[:n] = next_obs[:n]
            self.actions[:n] = actions[:n]
            self.rewards[:n] = rewards[:n]
            self.dones[:n] = dones[:n]

        self.position = pos % self.capacity if self.capacity > 0 else 0
        # Report only what was restored, so sampling never treats zero-filled
        # slots as stored transitions.
        self.size = n
class SequenceDataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset for sampling sequences from the replay buffer.

    Used for training the diffusion world model. Item ``i`` is the window of
    ``sequence_length`` consecutive buffer slots starting at slot ``i``. The
    dataset reads the buffer live, so its length follows the buffer's size.
    """

    def __init__(
        self,
        replay_buffer: "ReplayBuffer",
        sequence_length: int = 5,  # L (conditioning) + 1 (next frame)
        burn_in: int = 4,
    ):
        """
        Args:
            replay_buffer: Buffer whose storage arrays, ``size`` and ``device``
                are read on every access.
            sequence_length: Number of consecutive transitions per item.
            burn_in: Stored on the instance; not used by this class itself.
        """
        self.replay_buffer = replay_buffer
        self.sequence_length = sequence_length
        self.burn_in = burn_in

    def __len__(self):
        """Return the number of start slots that leave room for a full window."""
        return max(0, self.replay_buffer.size - self.sequence_length)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a sequence starting at idx.

        Negative indices count from the end, as for a list.

        Returns:
            Dictionary with ``obs_seq`` (T, C, H, W), ``action_seq`` and its
            alias ``actions``, ``rewards``, ``dones`` (each with leading
            dimension T) and ``next_obs`` (C, H, W), all on the buffer's
            device.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        length = len(self)
        if idx < 0:
            idx += length
        if not 0 <= idx < length:
            # Without this check, out-of-range indices would read slots past
            # the buffer's valid region (or wrap around the raw storage).
            raise IndexError(
                f"index {idx} is out of range for SequenceDataset of length {length}"
            )

        buffer = self.replay_buffer
        # Integer-array indexing copies the data, so returned tensors never
        # alias the buffer's storage.
        steps = np.arange(idx, idx + self.sequence_length)
        # NOTE(review): next_obs is read from the slot *after* the window,
        # i.e. one transition past the last element of obs_seq -- confirm
        # that this offset is intended.
        next_slot = idx + self.sequence_length

        # keep numpy arrays separate to avoid mypy inferring ndarray types
        obs_seq_np = buffer.observations[steps]
        action_seq_np = buffer.actions[steps]
        reward_seq_np = buffer.rewards[steps]
        done_seq_np = buffer.dones[steps]
        next_obs_np = buffer.next_observations[next_slot]

        # convert and move to device
        device = buffer.device
        obs_seq = torch.from_numpy(obs_seq_np).float().to(device) / 255.0
        # (T, H, W, C) -> (T, C, H, W)
        if obs_seq.ndim == 4:
            obs_seq = obs_seq.permute(0, 3, 1, 2)

        next_obs = torch.from_numpy(next_obs_np).float().to(device) / 255.0
        # (H, W, C) -> (C, H, W)
        if next_obs.ndim == 3:
            next_obs = next_obs.permute(2, 0, 1)

        action_seq = torch.from_numpy(action_seq_np).long().to(device)
        # Scalar actions are returned without the trailing singleton axis.
        if action_seq.ndim > 1 and action_seq.shape[-1] == 1:
            action_seq = action_seq.squeeze(-1)

        rewards = torch.from_numpy(reward_seq_np).float().to(device)
        dones = torch.from_numpy(done_seq_np).bool().to(device)

        return {
            "obs_seq": obs_seq,
            "action_seq": action_seq,
            "actions": action_seq,  # duplicate key for compatibility
            "rewards": rewards,
            "dones": dones,
            "next_obs": next_obs,
        }
def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Stack per-item tensors into batched tensors for the dataloader.

    Only ``obs_seq``, ``action_seq`` and ``next_obs`` are collated; any other
    keys present on the items are left out of the result. Each output tensor
    gains a new leading batch dimension.
    """
    collated_keys = ("obs_seq", "action_seq", "next_obs")
    return {
        key: torch.stack([sample[key] for sample in batch])
        for key in collated_keys
    }