import numpy as np
from typing import Tuple, Optional, List
class IRISReplayBuffer:
    """Replay buffer for IRIS (Imagined Rollouts with Implicit Successor) training.

    Stores (observation, action, reward, terminal) tuples in a ring buffer
    and supports sampling contiguous sequences for world model training.

    Features:
        - Ring buffer with fixed capacity (FIFO eviction when full)
        - Stores uint8 images for memory efficiency
        - Samples sequences with validation to avoid episode boundaries
          (best effort, bounded rejection sampling)
        - Never samples a sequence that straddles the write head, so the
          newest and oldest transitions are never glued together
        - Supports sequence sampling for temporal learning

    Memory Layout:
        - observations: (capacity, C, H, W) uint8
        - actions: (capacity, action_size) float32
        - rewards: (capacity,) float32
        - terminals: (capacity,) float32

    Args:
        size (int): Maximum number of transitions to store. Must be positive.
        obs_shape (tuple): Shape of observations as (C, H, W).
        action_size (int): Dimension of actions.
        seq_len (int): Length of sequences to sample (default: 20).
        batch_size (int): Number of sequences per batch (default: 64).

    Attributes:
        size (int): Buffer capacity.
        obs_shape (tuple): Observation shape.
        action_size (int): Action dimension.
        seq_len (int): Sequence length.
        batch_size (int): Batch size.
        idx (int): Slot that the next call to ``add`` will write.
        full (bool): True once the write position has wrapped at least once.
        steps (int): Total transitions added.
        episodes (int): Number of episode terminations observed.

    Raises:
        ValueError: If ``size`` is not positive.
    """

    # Upper bound on rejection-sampling rounds in ``_sample_idxs``. Keeps
    # sampling from looping forever when (almost) every window holds a terminal.
    _MAX_RESAMPLE_ROUNDS = 100

    def __init__(
        self,
        size: int,
        obs_shape: Tuple[int, int, int],
        action_size: int,
        seq_len: int = 20,
        batch_size: int = 64,
    ):
        if size <= 0:
            # A zero capacity would otherwise only surface later as a
            # ZeroDivisionError inside ``add``.
            raise ValueError(f"size must be a positive integer, got {size}")

        self.size = size
        self.obs_shape = obs_shape  # (C, H, W)
        self.action_size = action_size
        self.seq_len = seq_len
        self.batch_size = batch_size

        self.idx = 0
        self.full = False
        self.steps = 0
        self.episodes = 0

        self.observations = np.zeros((size, *obs_shape), dtype=np.uint8)
        self.actions = np.zeros((size, action_size), dtype=np.float32)
        self.rewards = np.zeros((size,), dtype=np.float32)
        self.terminals = np.zeros((size,), dtype=np.float32)

    def add(self, obs: np.ndarray, action: np.ndarray, reward: float, terminal: bool):
        """Add a transition to the buffer, overwriting the oldest one when full.

        Args:
            obs: Observation array with shape (C, H, W); stored as uint8.
            action: Action array with shape (action_size,).
            reward: Scalar reward value.
            terminal: Boolean indicating if episode terminated.
        """
        slot = self.idx
        self.observations[slot] = obs
        self.actions[slot] = action
        self.rewards[slot] = reward
        self.terminals[slot] = float(terminal)

        self.idx = (slot + 1) % self.size
        # Wrapping back to slot 0 means every slot has been written once.
        self.full = self.full or self.idx == 0
        self.steps += 1
        if terminal:
            self.episodes += 1

    def sample_sequence(
        self, seq_len: Optional[int] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Sample a batch of sequences for world model training.

        If the buffer holds fewer than ``seq_len + 1`` transitions, every
        sequence starts at slot 0 and may include slots that were never
        written (all zeros); callers should check ``len(buffer)`` first.

        Args:
            seq_len: Sequence length to use; defaults to ``self.seq_len``.

        Returns:
            observations: (batch_size, seq_len+1, C, H, W)
            actions: (batch_size, seq_len, action_size)
            rewards: (batch_size, seq_len)
            terminals: (batch_size, seq_len)
        """
        L = self.seq_len if seq_len is None else seq_len
        starts = self._sample_idxs(L, self.batch_size)

        # One extra observation per sequence (the next-frame target). The
        # modulo maps positions past the end of the arrays back to the start.
        obs_pos = (starts[:, None] + np.arange(L + 1)) % self.size
        step_pos = obs_pos[:, :L]

        # Integer-array indexing returns fresh copies, not views of the buffer.
        return (
            self.observations[obs_pos],
            self.actions[step_pos],
            self.rewards[step_pos],
            self.terminals[step_pos],
        )

    def _sample_idxs(self, L: int, n: int) -> np.ndarray:
        """Sample n starting slots for sequences of L steps (L+1 observations).

        Starts are drawn in chronological order relative to the oldest stored
        transition, so a sequence never crosses the write head. Starts whose
        first ``L - 1`` steps contain a terminal are redrawn, for at most
        ``_MAX_RESAMPLE_ROUNDS`` rounds; anything still invalid after that is
        returned as-is (best effort).

        Returns:
            Integer array of shape (n,). All zeros when the buffer does not
            yet hold ``L + 1`` transitions.
        """
        num_starts = len(self) - L
        if num_starts <= 0:
            return np.zeros(n, dtype=int)

        # Oldest stored transition: slot 0 until the ring wraps, afterwards
        # the slot that is about to be overwritten.
        origin = self.idx if self.full else 0
        # Steps that must be terminal-free; the final step may end an episode.
        interior = np.arange(max(L - 1, 0))

        def draw(count: int) -> np.ndarray:
            offsets = np.random.randint(0, num_starts, size=count)
            return (origin + offsets) % self.size

        idxs = draw(n)
        for _ in range(self._MAX_RESAMPLE_ROUNDS):
            window = (idxs[:, None] + interior) % self.size
            crosses_episode = (self.terminals[window] > 0).any(axis=1)
            num_bad = int(crosses_episode.sum())
            if num_bad == 0:
                break
            idxs[crosses_episode] = draw(num_bad)
        return idxs

    def sample_single(self) -> Tuple[np.ndarray, np.ndarray, float, float]:
        """Sample a single stored transition uniformly at random.

        Returns:
            Tuple of (observation, action, reward, terminal) for one slot.

        Raises:
            ValueError: If the buffer is empty.
        """
        stored = len(self)
        if stored == 0:
            raise ValueError("Cannot sample from an empty replay buffer")
        idx = np.random.randint(0, stored)
        return (
            self.observations[idx],
            self.actions[idx],
            self.rewards[idx],
            self.terminals[idx],
        )

    def __len__(self):
        """Return the number of transitions currently stored."""
        return self.size if self.full else self.idx

    @property
    def buffer_capacity(self):
        """Returns the total capacity of the buffer."""
        return self.size
class IRISOnPolicyBuffer:
    """List-backed scratch buffer for transitions gathered while acting.

    Intended as temporary storage for the current episode before the data is
    handed on (for example to the main replay buffer). Transitions are kept in
    four parallel Python lists that grow with every ``add`` and are replaced
    by fresh empty lists on ``clear``.

    Useful for:
        - Collecting complete episode trajectories
        - Storing data before batch processing
        - Temporary storage during environment interaction

    Note:
        ``max_steps`` is only recorded on the instance; ``add`` does not
        check it, so the lists can grow past that value.

    Args:
        max_steps (int): Nominal maximum number of steps (default: 1000).

    Attributes:
        max_steps (int): Value passed to the constructor.
        observations (list): List of observations.
        actions (list): List of actions.
        rewards (list): List of rewards.
        terminals (list): List of terminal flags, stored as floats.
    """

    def __init__(self, max_steps: int = 1000):
        self.max_steps = max_steps
        # Annotated so static type checkers know the element types.
        self.observations: List[np.ndarray] = list()
        self.actions: List[np.ndarray] = list()
        self.rewards: List[float] = list()
        self.terminals: List[float] = list()

    def add(self, obs: np.ndarray, action: np.ndarray, reward: float, terminal: bool):
        """Append one transition; the terminal flag is converted to float."""
        columns = (
            (self.observations, obs),
            (self.actions, action),
            (self.rewards, reward),
            (self.terminals, float(terminal)),
        )
        for column, value in columns:
            column.append(value)

    def clear(self):
        """Drop all stored transitions by rebinding each attribute to a new list."""
        self.observations, self.actions = [], []
        self.rewards, self.terminals = [], []

    def get_arrays(self):
        """Return (observations, actions, rewards, terminals) as numpy arrays."""
        stored = (self.observations, self.actions, self.rewards, self.terminals)
        return tuple(np.array(column) for column in stored)

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.observations)