import numpy as np
class ReplayBuffer:
    """Fixed-size circular replay buffer for Dreamer image observations.

    Stores ``(observation image, action, reward, terminal)`` tuples in
    pre-allocated numpy arrays and supports sampling batches of contiguous
    length-``seq_len`` sequences used by world-model unroll training.
    """

    def __init__(self, size, obs_shape, action_size, seq_len, batch_size):
        """
        Args:
            size (int): maximum number of transitions held; the oldest
                entries are overwritten once the buffer is full.
            obs_shape (tuple): shape of a single observation image.
            action_size (int): dimensionality of the flat action vector.
            seq_len (int): length of each sampled contiguous sequence.
            batch_size (int): number of sequences returned by ``sample()``.
        """
        self.size = size
        self.obs_shape = obs_shape
        self.action_size = action_size
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.idx = 0          # next write position (circular)
        self.full = False     # True once the write index has wrapped around
        # Images kept as uint8 to save memory; conversion happens downstream.
        self.observations = np.empty((size, *obs_shape), dtype=np.uint8)
        self.actions = np.empty((size, action_size), dtype=np.float32)
        self.rewards = np.empty((size,), dtype=np.float32)
        self.terminals = np.empty((size,), dtype=np.float32)
        self.steps, self.episodes = 0, 0  # lifetime transition/episode counters

    def add(self, obs, ac, rew, done):
        """Append one transition, overwriting the oldest entry when full.

        Args:
            obs (dict): observation dict containing an ``"image"`` array.
            ac (array-like): action vector of length ``action_size``.
            rew (float): scalar reward.
            done (bool): whether this step terminated the episode.
        """
        self.observations[self.idx] = obs["image"]
        self.actions[self.idx] = ac
        self.rewards[self.idx] = rew
        self.terminals[self.idx] = done
        self.idx = (self.idx + 1) % self.size
        self.full = self.full or self.idx == 0
        self.steps += 1
        self.episodes = self.episodes + (1 if done else 0)

    def _sample_idx(self, L):
        """Return ``L`` consecutive (possibly wrapped) buffer indices.

        Candidate windows that straddle the current write head are rejected,
        because the data on either side of ``self.idx`` is not contiguous
        in time.

        Args:
            L (int): sequence length.

        Returns:
            np.ndarray: shape ``(L,)`` array of buffer indices.

        Raises:
            ValueError: if the buffer holds fewer than ``L`` transitions.
        """
        if not self.full and self.idx < L:
            # Explicit guard: the original code fell through to
            # np.random.randint on an empty range, raising an opaque error.
            raise ValueError(
                "Not enough transitions in buffer to sample a sequence of "
                "length %d (have %d)." % (L, self.idx)
            )
        valid_idx = False
        while not valid_idx:
            if self.full:
                start = np.random.randint(0, self.size)
            else:
                # Valid starts are 0 .. idx - L inclusive; the original
                # exclusive bound `idx - L` could never pick the last one.
                start = np.random.randint(0, self.idx - L + 1)
            idxs = np.arange(start, start + L) % self.size  # wrap when full
            # The write head may only appear as the first element (oldest
            # data); anywhere later means the window crosses it.
            valid_idx = self.idx not in idxs[1:]
        return idxs

    def _retrieve_batch(self, idxs, n, L):
        """Gather a time-major batch for the index matrix ``idxs``.

        Args:
            idxs (np.ndarray): shape ``(n, L)`` matrix of buffer indices.
            n (int): batch size.
            L (int): sequence length.

        Returns:
            tuple: ``(observations, actions, rewards, terminals)`` arrays
            with leading dimensions ``(L, n)`` (time-major, as expected by
            the world-model unroll).
        """
        vec_idxs = idxs.transpose().reshape(-1)  # time-major flattening
        observations = self.observations[vec_idxs]
        return (
            observations.reshape(L, n, *observations.shape[1:]),
            self.actions[vec_idxs].reshape(L, n, -1),
            self.rewards[vec_idxs].reshape(L, n),
            self.terminals[vec_idxs].reshape(L, n),
        )

    def sample(self):
        """Sample ``batch_size`` contiguous sequences of ``seq_len`` steps.

        Returns:
            tuple: time-major ``(obs, actions, rewards, terminals)`` arrays
            with leading shape ``(seq_len, batch_size)``.
        """
        n, L = self.batch_size, self.seq_len
        idx_matrix = np.asarray([self._sample_idx(L) for _ in range(n)])
        return self._retrieve_batch(idx_matrix, n, L)