"""Episode recording and episode-based replay memory for PlaNet (world_models.memory.planet_memory)."""

import numpy as np

from collections import deque
from numpy.random import choice


def _identity(x):
    return x


class Episode:
    """Record the agent's interaction with the environment for a single episode.

    Steps are buffered in Python lists by ``append``. ``terminate`` adds the
    final observation and converts every buffer to a NumPy array, after which
    ``x`` holds ``size + 1`` observations and ``u``/``r``/``t`` hold ``size``
    entries each.
    """

    def __init__(self, postprocess_fn=None):
        self.x = []  # observations, each passed through postprocess_fn
        self.u = []  # actions
        self.t = []  # terminal flags
        self.r = []  # rewards
        self.postprocess_fn = _identity if postprocess_fn is None else postprocess_fn
        self._size = 0

    @property
    def size(self):
        """Number of steps recorded with ``append``."""
        return self._size

    @staticmethod
    def _to_numpy(value):
        """Convert a tensor or array-like value to a host NumPy array.

        Tensor-like values (anything with ``detach`` and ``cpu``, presumably
        ``torch.Tensor``) are detached and moved to the CPU first, so GPU
        tensors and tensors that require grad are accepted. Everything else
        goes through ``np.asarray``.
        """
        if hasattr(value, "detach") and hasattr(value, "cpu"):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    @staticmethod
    def _stack(values):
        """Stack per-step values; an empty list becomes an empty 1-D array."""
        # np.stack raises on an empty sequence, which would make a zero-step
        # episode impossible to terminate.
        return np.stack(values) if len(values) else np.asarray(values)

    def append(self, obs, act, reward, terminal):
        """Record one step: observation, action, reward and terminal flag.

        Args:
            obs: observation for this step (tensor or array-like).
            act: action for this step (tensor or array-like).
            reward: reward value, stored as given.
            terminal: terminal flag, stored as given.
        """
        # Convert before mutating, so a failed conversion leaves the episode
        # unchanged and ``size`` consistent with the buffers.
        x = self.postprocess_fn(self._to_numpy(obs))
        u = self._to_numpy(act)
        self._size += 1
        self.x.append(x)
        self.u.append(u)
        self.r.append(reward)
        self.t.append(terminal)

    def terminate(self, obs):
        """Append the final observation and convert all buffers to NumPy arrays.

        Call at most once. Afterwards the buffers are arrays, so neither
        ``append`` nor a second ``terminate`` can be used.
        """
        self.x.append(self.postprocess_fn(self._to_numpy(obs)))
        self.x = np.stack(self.x)
        self.u = self._stack(self.u)
        self.r = self._stack(self.r)
        self.t = self._stack(self.t)
class Memory(deque):
    """Episode-based replay memory for PlaNet/RSSM training.

    Episodes are stored as variable-length trajectories and sampled as
    fixed-length sub-sequences, with optional time-major formatting for
    sequence models.

    Data lives in ``self.episodes`` and ``self.eps_lengths``. The inherited
    ``deque`` storage itself is not used by any method here.
    """

    def __init__(self, size=None):
        """Maintain a FIFO list of at most ``size`` episodes.

        If size is None (e.g. during unpickling or loading legacy files), the
        deques are created without a fixed maxlen so pickle can restore state.
        """
        # deque(maxlen=None) already means "unbounded", so pass size through.
        self.episodes = deque(maxlen=size)
        self.eps_lengths = deque(maxlen=size)
        if size is not None:
            print(f"Creating memory with len {size} episodes.")

    @property
    def size(self):
        """Total number of steps recorded across all stored episodes."""
        return sum(self.eps_lengths)

    def _append(self, episode: "Episode"):
        """Store one episode together with its step count at insertion time."""
        if not isinstance(episode, Episode):
            raise ValueError("can only append <Episode> or list of <Episode>")
        self.episodes.append(episode)
        self.eps_lengths.append(episode.size)

    def append(self, episodes: "list[Episode]"):
        """Append a single ``Episode`` or a list/tuple of episodes.

        Raises:
            ValueError: if ``episodes`` or any element is not an ``Episode``.
                Elements before an invalid one have already been stored.
        """
        if isinstance(episodes, Episode):
            episodes = [episodes]
        if not isinstance(episodes, (list, tuple)):
            raise ValueError("can only append <Episode> or list of <Episode>")
        for e in episodes:
            self._append(e)

    def sample(
        self, batch_size, tracelen=1, time_first=False, *, max_bytes=200 * 1024 * 1024
    ):
        """Sample ``batch_size`` sub-sequences of ``tracelen`` steps each.

        Episodes are drawn uniformly with replacement among those with at
        least ``tracelen`` usable steps. The start offset inside each drawn
        episode is also uniform.

        Args:
            batch_size: number of sub-sequences N.
            tracelen: steps per sub-sequence H. Each sample holds H + 1
                observations and H actions/rewards/terminals.
            time_first: if True, put the time axis first, i.e. (H, N, ...)
                instead of (N, H, ...).
            max_bytes: refuse to sample when the estimated size of the stacked
                observations exceeds this many bytes. ``None`` disables it.

        Returns:
            ``(rets, lengths)``, where ``rets`` is ``[x, u, r, t]`` with these
            shapes:

            - x: ``(N, H+1, *obs_shape)``
            - u: ``(N, H, *act_shape)``, where 1-D actions get a trailing
              axis of size 1
            - r and t: ``(N, H)``

            When ``tracelen == 1``, u/r/t lose their time axis (u ``(N, ...)``,
            r/t ``(N,)``), and ``time_first`` then only swaps x. ``lengths``
            holds the recorded length of every sampled episode.

        Raises:
            ValueError: memory is empty, no episode is long enough, or the
                sampled segments cannot be stacked.
            MemoryError: the estimated observation batch exceeds ``max_bytes``.
        """
        if len(self.episodes) == 0:
            raise ValueError("Memory is empty; cannot sample.")

        usable = [
            self._usable_steps(ep, n) for ep, n in zip(self.episodes, self.eps_lengths)
        ]
        valid_idxs = [i for i, n in enumerate(usable) if n >= tracelen]
        if not valid_idxs:
            raise ValueError(
                f"No episodes with length >= {tracelen} available to sample."
            )
        self._check_sample_bytes(
            self.episodes[valid_idxs[0]], batch_size, tracelen, max_bytes
        )

        episode_idx = choice(valid_idxs, batch_size)
        # Every valid episode has at least one legal start, so choice() never sees 0.
        init_st_idx = [choice(usable[i] - tracelen + 1) for i in episode_idx]

        x, u, r, t = [], [], [], []
        for i, s in zip(episode_idx, init_st_idx):
            x_arr, u_arr, r_arr, t_arr = self._episode_arrays(self.episodes[i])
            x.append(x_arr[s : s + tracelen + 1])
            u.append(u_arr[s : s + tracelen])
            r.append(r_arr[s : s + tracelen])
            t.append(t_arr[s : s + tracelen])

        rets = self._stack_segments((x, u, r, t), tracelen, episode_idx, init_st_idx)
        if time_first:
            if tracelen == 1:
                # u/r/t were squeezed above and have no time axis left.
                rets[0] = rets[0].swapaxes(1, 0)
            else:
                rets = [a.swapaxes(1, 0) for a in rets]
        lengths = np.array([self.eps_lengths[i] for i in episode_idx])
        return rets, lengths

    @staticmethod
    def _usable_steps(ep, recorded_len):
        """Count the steps of ``ep`` that have both an action and a next observation.

        A terminated episode stores one more observation than steps. An
        episode that was never ``terminate()``'d lacks its final observation,
        so its last step cannot end a sampled segment.
        """
        return min(recorded_len, len(ep.x) - 1)

    @staticmethod
    def _check_sample_bytes(sample_ep, batch_size, tracelen, max_bytes):
        """Raise MemoryError if stacking the sampled observations would be too large.

        The estimate is N * (H + 1) * prod(obs_shape) * itemsize, based on one
        representative episode.
        """
        if max_bytes is None:
            return
        try:
            if isinstance(sample_ep.x, np.ndarray):
                obs_shape = sample_ep.x.shape[1:]
                obs_dtype = sample_ep.x.dtype
            else:
                obs0 = np.asarray(sample_ep.x[0])
                obs_shape = obs0.shape
                obs_dtype = obs0.dtype
            bytes_per_elem = np.dtype(obs_dtype).itemsize
            est_bytes = (
                batch_size * (tracelen + 1) * int(np.prod(obs_shape)) * bytes_per_elem
            )
        except (AttributeError, IndexError, TypeError, ValueError):
            # Best-effort estimate only; later code surfaces any real problem.
            return
        # Raised outside the try so the guard cannot swallow its own error.
        if est_bytes > max_bytes:
            est_mb = est_bytes / (1024 * 1024)
            raise MemoryError(
                f"Sampling would allocate ~{est_mb:.1f} MiB for observations "
                f"(N={batch_size}, H={tracelen}, obs_shape={obs_shape}, dtype={obs_dtype}). "
                "Reduce batch_size or trace length (H) or downsample images."
            )

    @staticmethod
    def _episode_arrays(ep):
        """Return ``(x, u, r, t)`` as arrays, normalizing u to 2-D and r/t to 1-D.

        Episodes that were not ``terminate()``'d still hold lists; they are
        stacked here.
        """
        x_arr = ep.x if isinstance(ep.x, np.ndarray) else np.stack(ep.x)
        u_arr = np.asarray(ep.u if isinstance(ep.u, np.ndarray) else np.stack(ep.u))
        if u_arr.ndim == 1:
            u_arr = u_arr[:, None]  # (T,) -> (T, 1)
        r_arr = np.asarray(ep.r if isinstance(ep.r, np.ndarray) else np.stack(ep.r))
        t_arr = np.asarray(ep.t if isinstance(ep.t, np.ndarray) else np.stack(ep.t))
        return x_arr, u_arr, r_arr.reshape(-1), t_arr.reshape(-1)

    def _stack_segments(self, segments, tracelen, episode_idx, init_st_idx):
        """Stack per-sample segments into batch arrays, with a diagnostic on failure."""
        x, u, r, t = segments
        try:
            if tracelen == 1:
                return [np.stack(x)] + [np.stack(seq)[:, 0] for seq in (u, r, t)]
            return [np.stack(seq) for seq in (x, u, r, t)]
        except ValueError as exc:

            def shapes(lst):
                return [getattr(a, "shape", np.asarray(a).shape) for a in lst]

            info = {
                "x_shapes": shapes(x),
                "u_shapes": shapes(u),
                "r_shapes": shapes(r),
                "t_shapes": shapes(t),
                "episode_idx": episode_idx.tolist(),
                "start_idx": [int(s) for s in init_st_idx],
                "eps_lengths_sampled": [self.eps_lengths[i] for i in episode_idx],
            }
            raise ValueError(
                f"Failed to stack sampled segments; shapes info: {info}"
            ) from exc