import numpy as np
from typing import Tuple, Deque, Any
from collections import deque, namedtuple
class ReplayBuffer:
    """Fixed-size replay buffer for Dreamer with image observations.

    Stores (observation, action, reward, terminal) transitions in a ring
    buffer and samples contiguous sequences for world-model training.

    Key Features:
        - Ring buffer with fixed capacity (oldest transitions are overwritten
          once the buffer is full)
        - Stores raw uint8 images to save memory
        - Samples sequences (not single transitions) for temporal modeling
        - Never returns a sequence that crosses the write head, i.e. one that
          would splice the newest transition onto the oldest surviving one

    Memory Layout:
        - observations: (capacity, *obs_shape) uint8 images
        - actions: (capacity, action_size) float32
        - rewards: (capacity,) float32
        - terminals: (capacity,) float32 (1.0 = terminal, 0.0 = continue)

    Sampling Process:
        1. Draw a random start index over the written part of the buffer
        2. Take ``seq_len`` consecutive indices, wrapping around the ring
        3. Redraw if the window crosses the write head
        4. Stack ``batch_size`` such windows into a time-major batch

    Sequences are NOT filtered by terminal flags, so a sampled window may
    contain the end of one episode and the start of the next; use the
    returned ``terminals`` to detect such boundaries.

    Usage with Dreamer:
        buffer = ReplayBuffer(
            size=100000,            # Max transitions to store
            obs_shape=(3, 64, 64),  # RGB images
            action_size=6,          # Continuous action dim
            seq_len=50,             # Sequence length for training
            batch_size=50,          # Parallel sequences per batch
        )
        # Add transitions during interaction
        buffer.add({"image": image}, action, reward, done)
        # Sample batch for world model training
        obs_batch, action_batch, reward_batch, term_batch = buffer.sample()
        # Shapes: (seq_len, batch, *obs_shape), (seq_len, batch, action_size),
        #         (seq_len, batch), (seq_len, batch)

    Memory Efficiency:
        - Uses uint8 for images (1 byte per pixel vs 4 for float32)
        - Each transition is stored once; overlapping sampled windows share it
        - Capacity is fixed at construction by ``size``

    Note:
        ``add`` receives the observation as a dict and stores only its
        ``"image"`` entry; ``sample`` returns plain image arrays.
    """

    def __init__(
        self,
        size: int,
        obs_shape: Tuple[int, ...],
        action_size: int,
        seq_len: int,
        batch_size: int,
    ):
        """Initialize replay buffer with fixed capacity.

        Args:
            size: Maximum number of transitions to store.
            obs_shape: Shape of a single image observation, e.g. (C, H, W).
            action_size: Dimension of action vectors.
            seq_len: Length of sequences returned by ``sample``.
            batch_size: Number of sequences per batch.
        """
        self.size = size
        self.obs_shape = obs_shape
        self.action_size = action_size
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.idx = 0  # next slot to write (the ring's write head)
        self.full = False  # True once the ring has wrapped at least once
        # Zero-initialized (not np.empty) so slots that have not been written
        # yet read back as zeros instead of arbitrary memory contents.
        self.observations = np.zeros((size, *obs_shape), dtype=np.uint8)
        self.actions = np.zeros((size, action_size), dtype=np.float32)
        self.rewards = np.zeros((size,), dtype=np.float32)
        self.terminals = np.zeros((size,), dtype=np.float32)
        # Total transitions / terminal transitions ever added (not capped).
        self.steps, self.episodes = 0, 0

    def add(self, obs: dict, ac: np.ndarray, rew: float, done: float) -> None:
        """Write one transition at the write head and advance it.

        Args:
            obs: Observation dict; only its ``"image"`` entry is stored.
            ac: Action taken, shape (action_size,).
            rew: Scalar reward.
            done: Terminal flag, truthy if the episode ended.
        """
        self.observations[self.idx] = obs["image"]
        self.actions[self.idx] = ac
        self.rewards[self.idx] = rew
        self.terminals[self.idx] = done
        self.idx = (self.idx + 1) % self.size
        # The head returning to slot 0 means every slot has been written.
        self.full = self.full or self.idx == 0
        self.steps += 1
        self.episodes += 1 if done else 0

    def _sample_idx(self, L: int) -> np.ndarray:
        """Sample the buffer indices of one length-``L`` sequence.

        When the buffer is full, start indices are drawn from the whole ring
        and windows wrap around its end. A window is redrawn if the write
        head ``self.idx`` appears after its first element, because the slots
        on either side of the head hold the newest and the oldest data and
        are not consecutive in time. Terminal flags are NOT checked.

        When the buffer is not full, starts are drawn from
        ``[0, self.idx - L]`` so the window covers written slots only.

        If fewer than ``L`` transitions are available (or the capacity is
        smaller than ``L``), the fixed window ``0..L-1`` (mod capacity) is
        returned as a fallback; it may cover unwritten, zero-filled slots.

        Args:
            L: Sequence length.

        Returns:
            Integer array of ``L`` buffer indices.
        """
        if self.full:
            if self.size < L:
                return np.arange(0, L) % max(1, self.size)
            # Any start is possible; starting exactly at the write head is
            # always accepted (size >= L), so the loop below terminates.
            high = self.size
        else:
            # randint's upper bound is exclusive: the last valid start is
            # idx - L, whose window ends at the last written slot idx - 1.
            high = self.idx - L + 1
            if high <= 0:
                return np.arange(0, L) % max(1, self.size)
        while True:
            start = np.random.randint(0, high)
            idxs = np.arange(start, start + L) % self.size
            if self.idx not in idxs[1:]:
                return idxs

    def _retrieve_batch(self, idxs: np.ndarray, n: int, L: int):
        """Gather a time-major batch of sequences.

        Args:
            idxs: Buffer indices of the n sequences, shape (n, L).
            n: Number of sequences (batch size).
            L: Sequence length.

        Returns:
            observations: (L, n, *obs_shape)
            actions: (L, n, action_size)
            rewards: (L, n)
            terminals: (L, n)
        """
        # Transposing to (L, n) before flattening makes the gather time-major.
        vec_idxs = idxs.transpose().reshape(-1)
        observations = self.observations[vec_idxs]
        return (
            observations.reshape(L, n, *observations.shape[1:]),
            self.actions[vec_idxs].reshape(L, n, -1),
            self.rewards[vec_idxs].reshape(L, n),
            self.terminals[vec_idxs].reshape(L, n),
        )

    def sample(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Sample a batch of ``batch_size`` sequences of length ``seq_len``.

        Returns:
            tuple: (observations, actions, rewards, terminals)
                - observations: (seq_len, batch, *obs_shape) uint8
                - actions: (seq_len, batch, action_size)
                - rewards: (seq_len, batch)
                - terminals: (seq_len, batch)
        """
        n = self.batch_size
        L = self.seq_len
        idxs = np.asarray([self._sample_idx(L) for _ in range(n)])
        obs, acs, rews, terms = self._retrieve_batch(idxs, n, L)
        return obs, acs, rews, terms
class Memory:
    """Simple deque-based memory for storing transitions.

    Keeps the most recent ``capacity`` transitions (the bounded deque drops
    the oldest entry automatically) and supports uniform random sampling
    without replacement.

    Args:
        capacity: Maximum number of transitions to store.

    Usage:
        memory = Memory(capacity=10000)
        memory.append(obs, action, reward, done, info)
        batch = memory.sample(batch_size=32)
    """

    def __init__(self, capacity: int = 10000) -> None:
        # Bounded deque: appending beyond ``capacity`` evicts the oldest entry.
        self.memory: Deque[tuple[Any, ...]] = deque(maxlen=capacity)

    def append(self, *args) -> None:
        """Append one transition to memory.

        Args:
            *args: Transition fields, stored together as a single tuple.
                Typically (observation, action, reward, done, info).
        """
        self.memory.append(args)

    def sample(self, batch_size: int) -> list:
        """Sample a random batch of distinct stored transitions.

        Args:
            batch_size: Number of transitions to sample.

        Returns:
            list: ``batch_size`` transition tuples, sampled without
            replacement.

        Raises:
            ValueError: If ``batch_size`` is negative or larger than the
                number of stored transitions (raised by ``random.sample``).
        """
        from random import sample

        return sample(self.memory, batch_size)

    def __len__(self) -> int:
        """Return the number of stored transitions."""
        return len(self.memory)
class Episode:
    """Stores a single episode as parallel per-step lists.

    ``observation`` holds the initial observation followed by one entry per
    ``append``; ``action``, ``reward`` and ``terminal`` hold one entry per
    recorded step; ``info`` maps each info key to the list of values seen
    for it.

    Args:
        observation: Initial observation. A ``list`` is adopted as the
            observation history itself (as before); anything else is
            wrapped in a new one-element list.
        action: First action (optional).
        reward: First reward (optional).
        terminal: First terminal flag (optional).
        info: Initial info dict (optional). Each value becomes the first
            entry of that key's list; the caller's dict is not modified.

    Usage:
        episode = Episode(obs, info=info)
        episode.append(action, obs, reward, done, info)
        episodes = [episode for _ in range(num_episodes)]
        # Use with Planet agent for planning
        imag_state, imag_reward, imag_action = planet.imagine(episodes)
    """

    _fields = ["observation", "action", "reward", "terminal", "info"]

    def __init__(self, observation, action=None, reward=None, terminal=None, info=None):
        # Keep observations in a list so ``append`` and ``__len__`` work for
        # any observation type; an existing list is kept unchanged.
        if isinstance(observation, list):
            self.observation = observation
        else:
            self.observation = [observation]
        self.action = [action] if action is not None else []
        self.reward = [reward] if reward is not None else []
        self.terminal = [terminal] if terminal is not None else []
        # Use the key -> list-of-values layout that ``append`` extends, and
        # build a new dict so the caller's ``info`` is never mutated.
        self.info = {k: [v] for k, v in info.items()} if info is not None else {}

    def append(self, action, observation, reward, terminal, info=None) -> None:
        """Record one environment step.

        Args:
            action: Action taken.
            observation: Observation received after the action.
            reward: Reward received.
            terminal: Terminal flag for this step.
            info: Optional info dict; each value is appended to the list
                stored under its key (created on first sight).
        """
        self.action.append(action)
        self.observation.append(observation)
        self.reward.append(reward)
        self.terminal.append(terminal)
        if info is not None:
            for key, value in info.items():
                self.info.setdefault(key, []).append(value)

    def __len__(self) -> int:
        """Return the number of stored observations (initial one included)."""
        return len(self.observation)