import os
import cv2
import gym
import torch
import pickle
import pathlib
import numpy as np
# NumPy 1.24 removed the deprecated `np.bool8` alias; recreate it for older
# dependencies (e.g. gym) that still reference it. When the alias already
# exists this re-binds the same object, so it is a no-op.
np.bool8 = getattr(np, "bool8", np.bool_)
import plotly
from plotly.graph_objs import Scatter, Line
from collections import defaultdict
from world_models.memory.planet_memory import Memory
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid, save_image
import torch.nn.functional as F
import yaml
import collections
import collections.abc
try:
    from attrdict import AttrDict
except ImportError:
    # The attrdict package is unmaintained and fails to import on modern
    # Python; provide a minimal drop-in replacement.
    class AttrDict(dict):
        """Dict whose keys are also readable/writable as attributes."""

        def __getattr__(self, name):
            # Missing keys must surface as AttributeError so that
            # hasattr()/getattr() behave like on a normal object.
            try:
                return self[name]
            except KeyError:
                raise AttributeError(name)

        def __setattr__(self, name, value):
            # Attribute assignment is routed into the mapping itself.
            self[name] = value

        def __delattr__(self, name):
            del self[name]
# Old libraries (e.g. attrdict) still look up ABCs such as `Mapping` on the
# top-level `collections` module; re-export everything from `collections.abc`
# for Python >= 3.10 compatibility.
for _abc_name in collections.abc.__all__:
    setattr(collections, _abc_name, getattr(collections.abc, _abc_name))
def load_yml_config(path):
    """Load a YAML file and return its top-level mapping as an AttrDict.

    Bug fix: the previous implementation merged top-level keys with
    ``dictionary += AttrDict({...})``. The plain-dict fallback AttrDict does
    not implement ``+``, so any config with more than one top-level key
    raised TypeError. Building a single AttrDict from the loaded mapping is
    equivalent and works with both AttrDict implementations.

    Args:
        path: Path to a YAML file whose document root is a mapping.

    Returns:
        AttrDict of the file's top-level keys, or None for an empty document
        (matching the old behaviour of returning None when no keys exist).
    """
    with open(path) as file_stream:
        loaded = yaml.safe_load(file_stream)
    if not loaded:
        # Empty document: preserve the legacy "no config -> None" contract.
        return None
    return AttrDict(loaded)
def to_tensor_obs(image):
    """
    Converts the input np img to channel first 64x64 dim torch img.
    """
    # Downscale to the model's expected spatial size, then move the channel
    # axis first (HWC -> CHW) as PyTorch convolutions expect.
    resized = cv2.resize(image, (64, 64), interpolation=cv2.INTER_LINEAR)
    return torch.tensor(resized, dtype=torch.float32).permute(2, 0, 1)
def postprocess_img(image, depth):
    """
    Postprocess an image observation for storage.
    From float32 numpy array [-0.5, 0.5] to uint8 numpy array [0, 255])
    """
    levels = 2**depth
    # Undo the zero-centering and quantize down to `depth` bits...
    quantized = np.floor((image + 0.5) * levels)
    # ...then rescale the bin indices back to the full 8-bit range.
    rescaled = quantized * 2 ** (8 - depth)
    return np.clip(rescaled, 0, 2**8 - 1).astype(np.uint8)
def preprocess_img(image, depth):
    """
    Preprocesses an observation inplace.
    From float32 Tensor [0, 255] to [-0.5, 0.5]
    Also adds some noise to the observations !!
    """
    bins = 2**depth
    # Quantize to `depth` bits and shift into the zero-centered unit range.
    image.div_(2 ** (8 - depth)).floor_().div_(bins).sub_(0.5)
    # Dither within one quantization bin, then clamp back to the valid range.
    noise = torch.randn_like(image)
    image.add_(noise.div_(bins)).clamp_(-0.5, 0.5)
def bottle(func, *tensors):
    """
    Evaluates a func that operates in N x D with inputs of shape N x T x D
    """
    # Fold batch and time into one leading axis, apply func, then unfold.
    batch, time = tensors[0].shape[0], tensors[0].shape[1]
    flattened = [t.reshape(batch * time, *t.shape[2:]) for t in tensors]
    result = func(*flattened)
    return result.view(batch, time, *result.shape[1:])
def get_combined_params(*models):
    """
    Returns the combine parameter list of all the models given as input.
    """
    combined = []
    for module in models:
        combined += list(module.parameters())
    return combined
[docs]
def save_video(frames, path, name):
    """
    Saves a video containing frames.
    Accepts frames in either:
        - (T, C, H, W) float in [0,1]
        - (T, H, W, C) float in [0,1]
    Produces {path}/{name}.mp4 and a debug PNG {path}/{name}_debug_frame.png
    with per-channel statistics printed to stdout.
    """
    import numpy as _np

    frames = _np.asarray(frames)
    if frames.ndim != 4:
        raise ValueError(
            f"Expected frames with 4 dims (T, C, H, W) or (T, H, W, C), got shape {frames.shape}"
        )
    # detect layout
    # NOTE(review): a channel-last batch whose *height* is 1, 3 or 4 would be
    # misdetected as channel-first here -- confirm callers never pass such frames.
    if frames.shape[1] in (1, 3, 4):
        # (T, C, H, W)
        is_chw = True
    elif frames.shape[-1] in (1, 3, 4):
        # (T, H, W, C)
        is_chw = False
    else:
        raise ValueError(f"Can't infer channel axis from frames.shape={frames.shape}")
    # convert floats -> uint8 and to HWC format for OpenCV
    if is_chw:
        # (T, C, H, W) -> (T, H, W, C)
        frames_u8 = (frames * 255.0).clip(0, 255).astype("uint8").transpose(0, 2, 3, 1)
    else:
        frames_u8 = (frames * 255.0).clip(0, 255).astype("uint8")
    # Basic per-frame / per-channel sanity checks on first frame
    if frames_u8.shape[-1] not in (1, 3, 4):
        raise ValueError(f"Unexpected channel count: {frames_u8.shape[-1]}")
    first = frames_u8[0]
    ch = first.shape[-1]
    # Per-channel min/max/mean of the first frame, for the debug log below.
    stats = {"min": [], "max": [], "mean": []}
    for c in range(ch):
        stats["min"].append(int(first[..., c].min()))
        stats["max"].append(int(first[..., c].max()))
        stats["mean"].append(float(first[..., c].mean()))
    # Flag grayscale-looking output (all three channels identical).
    equal_ch = False
    if ch >= 3:
        equal_ch = _np.all(first[..., 0] == first[..., 1]) and _np.all(
            first[..., 1] == first[..., 2]
        )
    out_dir = pathlib.Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)
    debug_path = out_dir / f"{name}_debug_frame.png"
    # Debug PNG is best-effort: never fail the video write because of it.
    try:
        to_write = first
        # NOTE(review): when `first` already has a trailing channel axis of 1,
        # `first[..., None]` yields (H, W, 1, 1) and repeat gives (H, W, 1, 3);
        # the intent looks like `_np.repeat(first, 3, axis=-1)` -- confirm.
        if first.ndim == 2 or first.shape[-1] == 1:
            to_write = _np.repeat(first[..., None], 3, axis=-1)
        # [..., ::-1] converts RGB -> BGR for OpenCV.
        cv2.imwrite(str(debug_path), to_write[..., ::-1])
    except Exception as e:
        print(f"Failed to write debug frame PNG: {e}")
    print(
        f"[save_video] frames.shape={frames.shape} inferred_chw={is_chw} -> written_frame_shape={frames_u8.shape}"
    )
    print(
        f"[save_video] first_frame stats min={stats['min']} max={stats['max']} mean={stats['mean']} equal_rgb_channels={equal_ch}"
    )
    print(f"[save_video] debug PNG saved to: {debug_path}")
    # Write video with OpenCV (expecting HWC uint8 BGR)
    H, W = (None, None)
    if frames_u8.ndim == 4:
        H, W = frames_u8.shape[1], frames_u8.shape[2]
    else:
        raise RuntimeError("Unexpected frames_u8 shape after conversion")
    writer = cv2.VideoWriter(
        str(out_dir / f"{name}.mp4"),
        cv2.VideoWriter_fourcc(*"mp4v"),
        25.0,
        (W, H),
        True,
    )
    try:
        for frame in frames_u8:
            # ensure contiguous HWC uint8
            if not frame.flags["C_CONTIGUOUS"]:
                frame = _np.ascontiguousarray(frame)
            # OpenCV expects BGR
            # NOTE(review): the [..., ::-1] view is itself non-contiguous;
            # some OpenCV builds may reject it -- confirm, or wrap in
            # ascontiguousarray if writes fail.
            writer.write(frame[..., ::-1])
    finally:
        writer.release()
    return str(out_dir / f"{name}.mp4")
[docs]
def combine_videos(
    video_dir, output_name="combined.mp4", pattern="vid_*.mp4", fps=25, resize=True
):
    """
    Combine all videos matching `pattern` in `video_dir` into a single MP4 file.
    Returns the output filepath (string).
    Example:
        combine_videos("results/planet", output_name="all_training.mp4")

    Args:
        video_dir: Directory containing the input clips.
        output_name: File name of the combined MP4, written into `video_dir`.
        pattern: Glob pattern selecting input clips; concatenated in sorted order.
        fps: Frame rate of the output video.
        resize: When True, frames are resized to the first clip's dimensions.

    Raises:
        FileNotFoundError: If no file matches `pattern`.
        RuntimeError: If the first matching video cannot be opened.
    """
    import glob

    files = sorted(glob.glob(os.path.join(video_dir, pattern)))
    if len(files) == 0:
        raise FileNotFoundError(f"No videos found in {video_dir} matching {pattern}")
    # probe first video for size
    cap0 = cv2.VideoCapture(files[0])
    if not cap0.isOpened():
        cap0.release()
        raise RuntimeError(f"Failed to open video {files[0]}")
    width = int(cap0.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap0.get(cv2.CAP_PROP_FRAME_HEIGHT))
    cap0.release()
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out_path = str(pathlib.Path(video_dir) / output_name)
    writer = cv2.VideoWriter(out_path, fourcc, float(fps), (width, height), True)
    try:
        for f in files:
            cap = cv2.VideoCapture(f)
            if not cap.isOpened():
                # Skip unreadable clips instead of aborting the whole merge.
                cap.release()
                continue
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # frame is BGR height x width x channels
                if resize and (frame.shape[1] != width or frame.shape[0] != height):
                    frame = cv2.resize(frame, (width, height))
                # NOTE(review): with resize=False a size-mismatched frame is
                # passed through; VideoWriter may silently drop it -- confirm.
                writer.write(frame)
            cap.release()
    finally:
        writer.release()
    return out_path
def ensure_results_dir_exists(results_dir):
    """
    Simple helper to validate a results directory exists.
    Raises FileNotFoundError if not present.
    """
    if os.path.isdir(results_dir):
        return
    raise FileNotFoundError(f"Results directory does not exist: {results_dir}")
[docs]
def save_frames(target, pred_prior, pred_posterior, name, n_rows=5):
    """
    Save side-by-side target, prior-prediction, and posterior-prediction frames.
    The function accepts tensors with optional time dimension and writes a PNG
    grid to ``{name}.png``. Spatial sizes are aligned per timestep before
    concatenation and values are normalized to ``[0, 1]`` when needed.
    """

    def ensure_time_dim(x):
        # Promote a single [C, H, W] frame to a length-1 sequence [1, C, H, W].
        if not torch.is_tensor(x):
            x = torch.tensor(x)
        if x.dim() == 3:
            return x.unsqueeze(0)
        return x

    tgt = ensure_time_dim(target).float()
    pp = ensure_time_dim(pred_prior).float()
    ppp = ensure_time_dim(pred_posterior).float()
    # Predictions at step t are paired with the target at t+1 (one target
    # frame is consumed as the starting observation); fall back to the raw
    # target length when the sequence is too short for that pairing.
    T_pred = pp.shape[0]
    T_tgt = max(0, tgt.shape[0] - 1)
    n = min(T_pred, T_tgt if T_tgt > 0 else T_pred)
    if n == 0:
        n = min(T_pred, tgt.shape[0])
    frames = []
    for t in range(n):
        # Clamp indices so shorter sequences repeat their last frame
        # instead of raising IndexError.
        if tgt.shape[0] > 1 and t + 1 < tgt.shape[0]:
            tf = tgt[t + 1]
        else:
            tf = tgt[min(t, tgt.shape[0] - 1)]
        pr = pp[min(t, pp.shape[0] - 1)]
        ppr = ppp[min(t, ppp.shape[0] - 1)]
        H, W = tf.shape[1], tf.shape[2]

        def match_size(img, H, W):
            # Bilinearly resize a [C, H', W'] frame to the target's [C, H, W].
            img4 = img.unsqueeze(0)
            if img4.shape[2] != H or img4.shape[3] != W:
                img4 = F.interpolate(
                    img4, size=(H, W), mode="bilinear", align_corners=False
                )
            return img4.squeeze(0)

        tf = match_size(tf, H, W)
        pr = match_size(pr, H, W)
        ppr = match_size(ppr, H, W)

        def clamp01(x):
            # Min-max normalize only when values fall outside [0, 1].
            if x.min() < 0 or x.max() > 1:
                return (x - x.min()) / (x.max() - x.min() + 1e-8)
            return x

        tf = clamp01(tf)
        pr = clamp01(pr)
        ppr = clamp01(ppr)
        # concatenate side-by-side along width (dim=2 because [C,H,W])
        row = torch.cat([tf, pr, ppr], dim=2)
        frames.append(row)
    if len(frames) == 0:
        raise RuntimeError(f"No frames to save for {name}")
    grid = make_grid(torch.stack(frames), nrow=n_rows, normalize=False)
    out_dir = os.path.dirname(name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    save_image(grid, f"{name}.png")
def get_mask(tensor, lengths):
    """
    Build a batch-first validity mask from sequence lengths.
    ``tensor`` may be a tensor/array with shape ``(N, T, ...)`` or ``(N,)``.
    The returned mask marks valid timesteps with ones up to each element in
    ``lengths`` and preserves device/dtype conventions from the input.
    """
    if not torch.is_tensor(tensor):
        tensor = torch.tensor(tensor)
    if tensor.dim() == 1:
        # No time axis: build a fresh (N, max(lengths)) mask.
        batch = tensor.shape[0]
        horizon = int(max(lengths))
        mask = torch.zeros(
            (batch, horizon), dtype=tensor.dtype, device=tensor.device
        )
    else:
        # Batch-first input [N, T, ...]: mirror its full shape so the mask
        # can be multiplied elementwise against the tensor.
        mask = torch.zeros_like(tensor)
        batch = tensor.shape[0]
    for row in range(batch):
        mask[row, : int(lengths[row])] = 1.0
    return mask
[docs]
def load_memory(path, device):
    """
    Loads an experience replay buffer (backwards-compatible with older pickle formats).
    Converts legacy list/.data formats into the current Memory(episodes) object.

    Args:
        path: Path to the pickled replay buffer.
        device: Device tag attached to the container and each episode.

    Returns:
        A Memory-like object with ``device`` set on it and its episodes.
    """
    # SECURITY NOTE: pickle.load executes arbitrary code -- only load
    # checkpoints from trusted sources.
    with open(path, "rb") as f:
        memory = pickle.load(f)
    # If file contains a plain list of Episode objects -> wrap into Memory
    # (assumes Memory(capacity) + Memory.append(list) -- see planet_memory).
    if isinstance(memory, list):
        mem = Memory(len(memory))
        mem.append(memory)
        memory = mem
    # If old object had `.data` attribute (legacy) -> convert to Memory
    if hasattr(memory, "data"):
        try:
            mem = Memory(len(memory.data))
            mem.append(memory.data)
            memory = mem
        except Exception:
            # fallback: just attach device to elements
            for e in memory.data:
                setattr(e, "device", device)
    # If object already has `.episodes`, ensure devices are set
    if hasattr(memory, "episodes"):
        for e in memory.episodes:
            setattr(e, "device", device)
    # Final attach device for the container itself
    setattr(memory, "device", device)
    return memory
def flatten_dict(data, sep=".", prefix=""):
    """Flattens a nested dict into a single-level dict.
    Example:
        {'a': 2, 'b': {'c': 20}} -> {'a': 2, 'b.c': 20}

    Args:
        data: Nested dictionary to flatten.
        sep: Separator joining parent and child key names.
        prefix: Prefix applied to every key (used internally by recursion).

    Returns:
        A single-level dict mapping joined key paths to leaf values.
    """

    def build_key(parent, child):
        return f"{parent}{sep}{child}" if parent else str(child)

    flattened = {}
    for key, val in data.items():
        flat_key = build_key(prefix, key)
        if isinstance(val, dict):
            # Bug fix: recurse with the full joined path (`flat_key`), not the
            # bare `key`, so grandchildren keep their ancestors' prefixes
            # (previously {'a': {'b': {'c': 1}}} flattened to {'b.c': 1}).
            flattened.update(flatten_dict(val, sep=sep, prefix=flat_key))
        else:
            flattened[flat_key] = val
    return flattened
def normalize_frames_for_saving(frames):
    """
    Ensure frames are in shape (T, H, W, 3) with float values in [0,1].
    Handles inputs in (T, C, H, W) or (T, H, W, C), repeats single-channel -> RGB,
    drops alpha if present, and maps [-0.5,0.5] -> [0,1] when detected.
    """
    import numpy as _np

    frames = _np.asarray(frames).astype(_np.float32)
    if frames.ndim != 4:
        raise ValueError(f"Expected 4D frames array, got shape {frames.shape}")
    # Channel-first batches are moved to channel-last; anything else must
    # already be channel-last.
    if frames.shape[1] in (1, 3, 4):
        frames = frames.transpose(0, 2, 3, 1)
    elif frames.shape[-1] not in (1, 3, 4):
        raise ValueError(f"Can't infer channel axis from frames.shape={frames.shape}")
    channels = frames.shape[-1]
    if channels == 1:
        frames = _np.repeat(frames, 3, axis=-1)
    elif channels == 4:
        frames = frames[..., :3]
    lo, hi = float(frames.min()), float(frames.max())
    # Values hovering around zero are taken to be in [-0.5, 0.5] and shifted.
    if lo >= -0.6 and hi <= 0.6:
        return (frames + 0.5).clip(0.0, 1.0)
    return frames.clip(0.0, 1.0)
class TensorBoardMetrics:
    """Plots and (optionally) stores metrics for an experiment."""

    def __init__(self, path):
        self.writer = SummaryWriter(path)
        self.steps = defaultdict(lambda: 0)
        self.summary = {}

    def assign_type(self, key, val):
        """Pick and cache the writer method matching the value's type."""
        if isinstance(val, (list, tuple)):
            # Sequences are logged as histograms after conversion to ndarray.
            def write_hist(tag, value, step):
                self.writer.add_histogram(tag, np.array(value), step)

            self.summary[key] = write_hist
        elif isinstance(val, (np.ndarray, torch.Tensor)):
            self.summary[key] = self.writer.add_histogram
        elif isinstance(val, (float, int)):
            self.summary[key] = self.writer.add_scalar
        else:
            raise ValueError(f"Datatype {type(val)} not allowed")

    def update(self, metrics: dict):
        """Log a (possibly nested) metrics dict, advancing one step per key."""
        flat = flatten_dict(metrics)
        for dotted_key, val in flat.items():
            tag = dotted_key.replace(".", "/")
            if self.summary.get(tag) is None:
                self.assign_type(tag, val)
            self.summary[tag](tag, val, self.steps[tag])
            self.steps[tag] += 1
[docs]
def apply_model(model, inputs, ignore_dim=None):
    """Placeholder helper for generic model application across input structures.
    Currently not implemented; kept as an extension hook for future utility code.

    Args:
        model: Callable/model that would be applied to ``inputs``.
        inputs: Input structure the model would consume.
        ignore_dim: Reserved; unused in this stub.

    Returns:
        None. This is intentionally a no-op.
    """
    pass
def plot_metrics(metrics, path, prefix):
    """Render and save line plots for each metric series in a dictionary."""
    os.makedirs(path, exist_ok=True)
    for metric_name, series in metrics.items():
        # One HTML line plot per metric, indexed by step.
        lineplot(np.arange(len(series)), series, f"{prefix}{metric_name}", path)
[docs]
def lineplot(xs, ys, title, path="", xaxis="episode"):
    """Create a Plotly line plot for scalar, dict, or ensemble-series data.
    Supports uncertainty-band plotting when `ys` is a 2D array.

    Args:
        xs: X coordinates (rebuilt per-series when `ys` is a dict).
        ys: Scalar series, a dict of named series, or a 2D array whose last
            axis is an ensemble dimension.
        title: Plot title; also used as the output HTML file name.
        path: Directory the HTML file is written into.
        xaxis: Label of the x axis.
    """
    # Shared trace styles for the min/max/mean/std-band plot.
    MAX_LINE = Line(color="rgb(0, 132, 180)", dash="dash")
    MIN_LINE = Line(color="rgb(0, 132, 180)", dash="dash")
    NO_LINE = Line(color="rgba(0, 0, 0, 0)")
    MEAN_LINE = Line(color="rgb(0, 172, 237)")
    std_colour = "rgba(29, 202, 255, 0.2)"
    if isinstance(ys, dict):
        # One named trace per dict entry; x is rebuilt from each series length.
        data = []
        for key, val in ys.items():
            xs = np.arange(len(val))
            data.append(Scatter(x=xs, y=np.array(val), name=key))
    elif np.asarray(ys, dtype=np.float32).ndim == 2:
        # Ensemble data: mean with a +/-1 std band plus min/max envelopes.
        ys = np.asarray(ys, dtype=np.float32)
        ys_mean, ys_std = ys.mean(-1), ys.std(-1)
        ys_upper, ys_lower = ys_mean + ys_std, ys_mean - ys_std
        l_max = Scatter(x=xs, y=ys.max(-1), line=MAX_LINE, name="Max")
        l_min = Scatter(x=xs, y=ys.min(-1), line=MIN_LINE, name="Min")
        l_stu = Scatter(x=xs, y=ys_upper, line=NO_LINE, showlegend=False)
        # "tonexty" fills between a trace and the previously added one, so
        # the ordering l_stu -> l_mean -> l_stl defines the shaded band.
        l_mean = Scatter(
            x=xs,
            y=ys_mean,
            fill="tonexty",
            fillcolor=std_colour,
            line=MEAN_LINE,
            name="Mean",
        )
        l_stl = Scatter(
            x=xs,
            y=ys_lower,
            fill="tonexty",
            fillcolor=std_colour,
            line=NO_LINE,
            name="-1 Std. Dev.",
            showlegend=False,
        )
        data = [l_stu, l_mean, l_stl, l_min, l_max]
    else:
        # Simple scalar series.
        data = [Scatter(x=xs, y=ys, line=MEAN_LINE)]
    plotly.offline.plot(
        {
            "data": data,
            "layout": dict(title=title, xaxis={"title": xaxis}, yaxis={"title": title}),
        },
        filename=os.path.join(path, title + ".html"),
        auto_open=False,
    )
class TorchImageEnvWrapper:
    """
    Torch Env Wrapper that wraps a gym env and makes interactions using Tensors.
    Also returns observations in image form.
    """

    def __init__(self, env, bit_depth, observation_shape=None, act_rep=2):
        """
        Args:
            env: A gym environment instance, or an env id string for gym.make.
            bit_depth: Bit depth used when preprocessing image observations.
            observation_shape: Unused; kept for backward compatibility.
            act_rep: Number of env steps performed per `step` call.
        """
        if isinstance(env, str):
            try:
                # Gym >= 0.26 selects the render mode at construction time.
                self.env = gym.make(env, render_mode="rgb_array")
                self._render_mode_supported = True
            except TypeError:
                # Older gym rejects the kwarg; fall back to
                # render(mode="rgb_array") at call time.
                self.env = gym.make(env)
                self._render_mode_supported = False
        else:
            self.env = env
            self._render_mode_supported = True
        self.bit_depth = bit_depth
        self.action_repeats = act_rep

    def _get_frame(self, last_obs=None):
        """Call env.render robustly across gym versions. Returns ndarray frame or None.
        If rendering fails (OverflowError / pygame Surface) fallback to last_obs or a
        synthesized image from the observation vector.
        """
        frame = None
        try:
            out = self.env.render()
        except Exception:
            out = None
        # gym may return (frame, info)
        if isinstance(out, tuple):
            out = out[0]
        if isinstance(out, np.ndarray):
            frame = out
        else:
            # try alternative render signature (gym < 0.26)
            try:
                out = self.env.render(mode="rgb_array")
                if isinstance(out, tuple):
                    out = out[0]
                if isinstance(out, np.ndarray):
                    frame = out
            except Exception:
                frame = None
        # If frame is still not ndarray, and last_obs is image-like, use that
        if frame is None and last_obs is not None:
            frame = self._obs_to_frame(
                last_obs if not isinstance(last_obs, tuple) else last_obs[0]
            )
        # If still none, synthesize a simple visualization from a 1D state vector
        if frame is None and last_obs is not None:
            try:
                obs = last_obs if not isinstance(last_obs, tuple) else last_obs[0]
                arr = np.asarray(obs)
                # if simple vector -> make a 64x64 RGB gradient from values
                if arr.ndim == 1 or (arr.ndim == 2 and 1 in arr.shape):
                    vals = (
                        (arr.flatten() - arr.min())
                        if arr.max() != arr.min()
                        else arr.flatten()
                    )
                    if vals.size == 0:
                        vals = np.zeros(1)
                    vals = (vals - vals.min()) / (vals.max() - vals.min() + 1e-8)
                    canvas = np.zeros((64, 64, 3), dtype=np.uint8)
                    for i, v in enumerate(vals[:8]):  # encode first few dims as bands
                        band = int(255 * v)
                        canvas[:, i * 8 : (i + 1) * 8, :] = band
                    frame = canvas
            except Exception:
                frame = None
        return frame

    def _obs_to_frame(self, obs):
        """Convert common observation formats to an HWC image when possible."""
        if isinstance(obs, tuple):
            obs = obs[0]
        if isinstance(obs, dict):
            # Prefer well-known image keys, then fall back to any value that
            # converts to a frame.
            for key in ("image", "pixels", "rgb", "observation"):
                if key in obs:
                    frame = self._obs_to_frame(obs[key])
                    if frame is not None:
                        return frame
            for value in obs.values():
                frame = self._obs_to_frame(value)
                if frame is not None:
                    return frame
            return None
        if not isinstance(obs, np.ndarray):
            try:
                obs = np.asarray(obs)
            except Exception:
                return None
        if obs.ndim == 3:
            frame = obs
            # Channel-first images are moved to channel-last.
            if frame.shape[-1] not in (1, 3, 4) and frame.shape[0] in (1, 3, 4):
                frame = frame.transpose(1, 2, 0)
            if frame.shape[-1] == 1:
                frame = np.repeat(frame, 3, axis=-1)
            elif frame.shape[-1] == 4:
                frame = frame[..., :3]
            return frame
        if obs.ndim == 2:
            # Grayscale -> RGB.
            return np.repeat(obs[..., None], 3, axis=-1)
        if obs.ndim == 1:
            # State vector -> synthesized 64x64 image with one vertical band
            # per (normalized) leading dimension.
            vals = obs.astype(np.float32).reshape(-1)
            if vals.size == 0:
                return None
            if vals.max() != vals.min():
                vals = (vals - vals.min()) / (vals.max() - vals.min() + 1e-8)
            else:
                vals = np.zeros_like(vals)
            canvas = np.zeros((64, 64, 3), dtype=np.uint8)
            for i, v in enumerate(vals[:8]):
                canvas[:, i * 8 : (i + 1) * 8, :] = int(255 * float(v))
            return canvas
        return None

    def reset(self):
        """Reset the env and return a preprocessed (3, 64, 64) float tensor."""
        ret = self.env.reset()
        obs = ret[0] if isinstance(ret, tuple) else ret
        frame = self._obs_to_frame(obs)
        if frame is None:
            frame = self._get_frame(last_obs=obs)
        if frame is None:
            raise RuntimeError(
                "Environment did not provide an RGB frame on reset. "
                "Use an env that supports image rendering or instantiate with render_mode='rgb_array'."
            )
        x = to_tensor_obs(frame)
        preprocess_img(x, self.bit_depth)
        return x

    def step(self, u):
        """Apply action `u` (tensor or array) with action repeat.

        Returns:
            (obs_tensor, accumulated_reward, done, info)
        """
        if isinstance(u, torch.Tensor):
            u_t = u.cpu().detach()
        else:
            u_t = u
        if getattr(self.env.action_space, "n", None) is not None:
            # Discrete action space: accept either logits/one-hot vectors
            # (argmax) or a scalar action index (rounded and clipped).
            n = int(self.env.action_space.n)
            arr = u_t.numpy() if isinstance(u_t, torch.Tensor) else np.asarray(u_t)
            arr = arr.reshape(-1)
            if arr.size > 1:
                action = int(np.argmax(arr))
            else:
                val = float(arr[0]) if arr.size else 0.0
                action = int(np.clip(int(round(val)), 0, n - 1))
        else:
            action = u_t.numpy() if isinstance(u_t, torch.Tensor) else np.asarray(u_t)
        rwds = 0
        last_d = False
        last_i = {}
        last_obs = None
        for _ in range(self.action_repeats):
            ret = self.env.step(action)
            if len(ret) == 4:
                obs, r, d, i = ret
            else:
                obs, r, term, trunc, i = ret
                d = term or trunc
            rwds += r
            last_d = d
            last_i = i
            last_obs = obs
            if d:
                # Bug fix: stop repeating once the episode ends -- stepping a
                # terminated/truncated env is undefined behavior in gym and
                # raises in newer versions.
                break
        frame = self._obs_to_frame(
            last_obs if not isinstance(last_obs, tuple) else last_obs[0]
        )
        if frame is None:
            frame = self._get_frame(last_obs=last_obs)
        if frame is None:
            raise RuntimeError(
                "Environment did not provide an RGB frame on step. "
                "Use an env that supports image rendering or instantiate with render_mode='rgb_array'."
            )
        x = to_tensor_obs(frame)
        preprocess_img(x, self.bit_depth)
        return x, rwds, last_d, last_i

    def render(self):
        self.env.render()

    def close(self):
        self.env.close()

    @property
    def observation_size(self):
        # Fixed by to_tensor_obs: channel-first 64x64 RGB.
        return (3, 64, 64)

    @property
    def action_size(self):
        """Flattened action dimensionality (1 for discrete spaces)."""
        space = getattr(self.env, "action_space", None)
        if space is None:
            return 1
        shp = getattr(space, "shape", None)
        if shp and len(shp) > 0:
            try:
                prod = int(np.prod(shp))
                return prod if prod > 0 else 1
            except Exception:
                pass
        if getattr(space, "n", None) is not None:
            return 1
        # Last resort: infer from a sampled action.
        try:
            sample = space.sample()
            arr = np.asarray(sample)
            if arr.ndim == 0:
                return 1
            return int(arr.size)
        except Exception:
            return 1

    def sample_random_action(self):
        """Sample a random action from the wrapped env as a tensor."""
        return torch.tensor(self.env.action_space.sample())

    @property
    def max_episode_steps(self):
        """Return environment max episode steps (compatible with TimeLimit/spec)."""
        if (
            hasattr(self.env, "_max_episode_steps")
            and self.env._max_episode_steps is not None
        ):
            return self.env._max_episode_steps
        if (
            getattr(self.env, "spec", None) is not None
            and getattr(self.env.spec, "max_episode_steps", None) is not None
        ):
            return int(self.env.spec.max_episode_steps)
        # Sensible default when the env declares no limit.
        return 1000
def apply_masks(x, masks):
    """Gather token subsets from patch sequences using index masks.
    Each mask selects token positions from `x`; selected groups are concatenated
    along the batch dimension.
    """
    gathered = []
    feature_dim = x.shape[-1]
    for mask in masks:
        # Expand the (B, K) index mask across the feature dimension so that
        # torch.gather picks up whole token vectors.
        index = mask.unsqueeze(-1).repeat(1, 1, feature_dim)
        gathered.append(torch.gather(x, 1, index))
    return torch.cat(gathered, dim=0)