import numpy as np
from collections import deque
from numpy.random import choice
def _identity(x):
return x
[docs]
class Episode:
"""Records the agent's interaction with the environment for a single episode.
Stores observations, actions, rewards, and terminal flags during a single
trajectory. At termination, converts all lists to numpy arrays for
efficient batch processing.
Attributes:
x (list or np.ndarray): Observations collected during the episode.
u (list or np.ndarray): Actions taken.
r (list or np.ndarray): Rewards received.
t (list or np.ndarray): Terminal flags (0.0 = continue, 1.0 = terminal).
info (dict): Additional episode metadata.
Args:
postprocess_fn (callable, optional): Function to apply to observations
before storing (e.g., normalization). Default: identity function.
Example:
>>> episode = Episode()
>>> episode.append(obs, action, reward, False)
>>> episode.append(obs, action, reward, True)
>>> episode.terminate(final_obs)
>>> print(episode.x.shape) # Now a numpy array
"""
def __init__(self, postprocess_fn=None):
self.x = []
self.u = []
self.t = []
self.r = []
self.postprocess_fn = _identity if postprocess_fn is None else postprocess_fn
self._size = 0
@property
def size(self):
return self._size
[docs]
def append(self, obs, act, reward, terminal):
self._size += 1
self.x.append(self.postprocess_fn(obs.numpy()))
self.u.append(act.cpu().numpy())
self.r.append(reward)
self.t.append(terminal)
[docs]
def terminate(self, obs):
self.x.append(self.postprocess_fn(obs.numpy()))
self.x = np.stack(self.x)
self.u = np.stack(self.u)
self.r = np.stack(self.r)
self.t = np.stack(self.t)
[docs]
class Memory(deque):
"""Episode-based replay memory for PlaNet/RSSM training.
Stores episodes as variable-length trajectories and supports sampling
sub-sequences for training. Implements a ring-buffer style eviction
when capacity is reached.
Features:
- Stores complete episodes as lists of transitions
- Samples contiguous sub-sequences for sequence models
- Supports time-major formatting (time-first) for RNN input
- Memory usage estimation to prevent OOM errors
Args:
size (int, optional): Maximum number of episodes to store. If None,
deque grows without limit (useful for unpickling).
Attributes:
episodes (deque): Collection of Episode objects.
eps_lengths (deque): Length of each episode.
size (property): Total number of transitions across all episodes.
Example:
>>> memory = Memory(size=100)
>>> memory.append([episode1, episode2])
>>> batch, lengths = memory.sample(batch_size=32, tracelen=50)
"""
def __init__(self, size=None):
"""Initialize memory with optional episode capacity.
Args:
size (int, optional): Maximum number of episodes. If None,
creates unbounded deques for pickle compatibility.
"""
maxlen = size if size is not None else None
self.episodes = deque(maxlen=maxlen)
self.eps_lengths = deque(maxlen=maxlen)
if size is not None:
print(f"Creating memory with len {size} episodes.")
@property
def size(self):
return sum(self.eps_lengths)
def _append(self, episode: Episode):
if isinstance(episode, Episode):
self.episodes.append(episode)
self.eps_lengths.append(episode.size)
else:
raise ValueError("can only append <Episode> or list of <Episode>")
[docs]
def append(self, episodes: list[Episode]):
if isinstance(episodes, Episode):
episodes = [episodes]
if isinstance(episodes, list):
for e in episodes:
self._append(e)
else:
raise ValueError("can only append <Episode> or list of <Episode>")
[docs]
def sample(self, batch_size, tracelen=1, time_first=False):
"""Sample random sub-sequences from stored episodes.
Randomly selects episodes and starting positions to create batches
of contiguous sequences for training sequence models.
Args:
batch_size (int): Number of sequences to sample.
tracelen (int): Length of each sequence (default: 1).
time_first (bool): If True, returns tensors with time dimension
first (T, B, ...) instead of batch first (B, T, ...).
Returns:
tuple: (observations, actions, rewards, terminals, lengths)
- observations: (batch, tracelen+1, *obs_shape) or (tracelen+1, batch, ...)
- actions: (batch, tracelen, action_dim) or (tracelen, batch, ...)
- rewards: (batch, tracelen) or (tracelen, batch)
- terminals: (batch, tracelen) or (tracelen, batch)
- lengths: (batch,) original episode lengths for each sample
Raises:
ValueError: If memory is empty or no episodes meet minimum length.
MemoryError: If estimated memory usage exceeds 200 MiB threshold.
"""
if len(self.episodes) == 0:
raise ValueError("Memory is empty; cannot sample.")
valid_idxs = [i for i, L in enumerate(self.eps_lengths) if L >= tracelen]
if len(valid_idxs) == 0:
raise ValueError(
f"No episodes with length >= {tracelen} available to sample."
)
# Quick memory usage estimation to avoid large unexpected allocations.
# Estimate bytes required for stacked observations: N * (tracelen+1) * prod(obs_shape) * bytes_per_elem
try:
sample_ep = self.episodes[valid_idxs[0]]
if isinstance(sample_ep.x, np.ndarray):
obs_shape = sample_ep.x.shape[1:] # (C,H,W)
obs_dtype = sample_ep.x.dtype
else:
obs0 = np.asarray(sample_ep.x[0])
obs_shape = obs0.shape
obs_dtype = obs0.dtype
bytes_per_elem = np.dtype(obs_dtype).itemsize
est_bytes = int(
batch_size * (tracelen + 1) * np.prod(obs_shape) * bytes_per_elem
)
MAX_BYTES = 200 * 1024 * 1024 # 200 MiB threshold (tunable)
if est_bytes > MAX_BYTES:
est_mb = est_bytes / (1024 * 1024)
raise MemoryError(
f"Sampling would allocate ~{est_mb:.1f} MiB for observations "
f"(N={batch_size}, H={tracelen}, obs_shape={obs_shape}, dtype={obs_dtype}). "
"Reduce batch_size or trace length (H) or downsample images."
)
except Exception:
# If anything goes wrong in estimation, continue and let later code raise the original error.
pass
episode_idx = choice(valid_idxs, batch_size)
init_st_idx = []
for i in episode_idx:
max_start = self.eps_lengths[i] - tracelen + 1
if max_start <= 0:
init_st_idx.append(0)
else:
init_st_idx.append(choice(max_start))
x, u, r, t = [], [], [], []
for n, (i, s) in enumerate(zip(episode_idx, init_st_idx)):
ep = self.episodes[i]
# ensure arrays (support episodes that weren't .terminate()'d)
x_arr = np.stack(ep.x) if not isinstance(ep.x, np.ndarray) else ep.x
# Normalize action arrays to have shape (T, action_dim)
u_arr = np.stack(ep.u) if not isinstance(ep.u, np.ndarray) else ep.u
u_arr = np.asarray(u_arr)
if u_arr.ndim == 1:
u_arr = u_arr[:, None] # (T,) -> (T,1)
# Ensure rewards and terminals are 1D arrays of length T
r_arr = np.stack(ep.r) if not isinstance(ep.r, np.ndarray) else ep.r
r_arr = np.asarray(r_arr).reshape(-1)
t_arr = np.stack(ep.t) if not isinstance(ep.t, np.ndarray) else ep.t
t_arr = np.asarray(t_arr).reshape(-1)
x.append(x_arr[s : s + tracelen + 1])
u.append(u_arr[s : s + tracelen])
r.append(r_arr[s : s + tracelen])
t.append(t_arr[s : s + tracelen])
try:
if tracelen == 1:
rets = [np.stack(x)] + [np.stack(i)[:, 0] for i in (u, r, t)]
else:
rets = [np.stack(i) for i in (x, u, r, t)]
except ValueError as exc:
def shapes(lst):
return [getattr(a, "shape", np.asarray(a).shape) for a in lst]
info = {
"x_shapes": shapes(x),
"u_shapes": shapes(u),
"r_shapes": shapes(r),
"t_shapes": shapes(t),
"episode_idx": episode_idx.tolist(),
"start_idx": init_st_idx,
"eps_lengths_sampled": [self.eps_lengths[i] for i in episode_idx],
}
raise ValueError(
f"Failed to stack sampled segments; shapes info: {info}"
) from exc
if time_first:
rets = [a.swapaxes(1, 0) for a in rets]
lengths = np.array([self.eps_lengths[i] for i in episode_idx])
return rets, lengths