Source code for world_models.utils.utils

import os
import cv2
import gym
import torch
import pickle
import pathlib
import numpy as np
import glob
import warnings
from typing import Any, Optional, Dict, List


import plotly
from plotly.graph_objs import Scatter, Line

from collections import defaultdict
from world_models.memory.planet_memory import Memory

from torchvision.utils import make_grid, save_image

import torch.nn.functional as F

import yaml

import collections
import collections.abc

from sklearn.manifold import TSNE
import umap

# Visualization dependencies (plotly, sklearn's TSNE, umap) are imported
# unconditionally above, so this flag is always True once the module loads.
HAS_VIZ = True


class _RestrictedReplayUnpickler(pickle.Unpickler):
    """Unpickler that only resolves classes needed by replay buffers."""

    _ALLOWED_GLOBALS = {
        ("builtins", "dict"),
        ("builtins", "list"),
        ("builtins", "set"),
        ("builtins", "slice"),
        ("builtins", "tuple"),
        ("collections", "deque"),
        ("numpy", "dtype"),
        ("numpy", "ndarray"),
        ("numpy.core.multiarray", "_reconstruct"),
        ("numpy._core.multiarray", "_reconstruct"),
        ("world_models.memory.planet_memory", "Episode"),
        ("world_models.memory.planet_memory", "Memory"),
        ("world_models.memory.planet_memory", "_identity"),
    }

    def find_class(self, module: str, name: str) -> type:
        if (module, name) in self._ALLOWED_GLOBALS:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(
            f"global '{module}.{name}' is not allowed in replay buffers"
        )


class AttrDict(dict):
    """``dict`` subclass whose keys can also be read, set and deleted as attributes.

    Attribute access falls back to item access only when normal attribute
    lookup fails, so keys that collide with ``dict`` attributes (e.g.
    ``items``) must be read with ``d["items"]``. Missing keys raise
    :class:`AttributeError` for attribute access and deletion, so
    ``getattr(d, name, default)``, ``hasattr`` and ``delattr`` follow the
    normal attribute protocol; item access keeps ``dict`` semantics.
    """

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        try:
            del self[name]
        except KeyError:
            # Deleting a missing attribute must raise AttributeError, not KeyError.
            raise AttributeError(name) from None
def load_yml_config(path: str) -> AttrDict | None:
    """Load a YAML file and wrap its top-level mapping in an :class:`AttrDict`.

    Only the top level is wrapped; nested mappings stay plain ``dict`` objects.

    Args:
        path: Path to the YAML file (read as UTF-8).

    Returns:
        An ``AttrDict`` holding the document's top-level keys, or ``None`` when
        the document is empty or is an empty mapping.

    Raises:
        TypeError: If the document's top level is not a mapping.
    """
    with open(path, encoding="utf-8") as stream:
        loaded = yaml.safe_load(stream)
    if loaded is None:
        # Empty file.
        return None
    if not isinstance(loaded, dict):
        raise TypeError(
            f"Expected a mapping at the top level of {path}, got {type(loaded).__name__}"
        )
    # The previous key-by-key merge used ``+=``, which dict does not support,
    # so any file with two or more top-level keys raised TypeError.
    return AttrDict(loaded) if loaded else None
def to_tensor_obs(image: np.ndarray) -> torch.Tensor:
    """Resize an HWC numpy image to 64x64 and return it as a float CHW tensor."""
    resized = cv2.resize(image, (64, 64), interpolation=cv2.INTER_LINEAR)
    # HWC -> CHW (channel first).
    return torch.from_numpy(resized).float().permute(2, 0, 1)
def postprocess_img(image: np.ndarray, depth: int) -> np.ndarray:
    """Map an observation from float ``[-0.5, 0.5]`` to ``uint8`` ``[0, 255]``.

    Values are first quantised to ``2**depth`` levels, then spread over the
    8-bit range and clipped to ``[0, 255]``.
    """
    levels = 2**depth
    quantized = np.floor((image + 0.5) * levels)
    spread = quantized * 2 ** (8 - depth)
    return np.clip(spread, 0, 255).astype(np.uint8)
def preprocess_img(image: torch.Tensor, depth: int) -> None:
    """Quantise and centre an observation in place, then add dequantisation noise.

    Maps a float tensor in ``[0, 255]`` to ``2**depth`` levels in
    ``[-0.5, 0.5]``, adds Gaussian noise scaled by ``1 / 2**depth`` and clamps
    the result back to ``[-0.5, 0.5]``. Modifies ``image``; returns ``None``.
    """
    bins = 2**depth
    image.div_(2 ** (8 - depth))
    image.floor_()
    image.div_(bins)
    image.sub_(0.5)
    noise = torch.randn_like(image).div_(bins)
    image.add_(noise)
    image.clamp_(-0.5, 0.5)
def bottle(func: Any, *tensors: torch.Tensor) -> torch.Tensor:
    """Apply ``func`` (which works on ``N x D`` inputs) to ``N x T x D`` inputs.

    The first two axes of every tensor are merged before calling ``func`` and
    split back out of ``func``'s output afterwards.
    """
    batch, steps = tensors[0].shape[:2]
    result = func(*(t.reshape(batch * steps, *t.shape[2:]) for t in tensors))
    return result.view(batch, steps, *result.shape[1:])
def get_combined_params(*models: Any) -> list:
    """Return every model's parameters concatenated into one list, in argument order."""
    return [param for model in models for param in model.parameters()]
def save_video(frames: Any, path: str, name: str) -> str:
    """Write ``frames`` to ``{path}/{name}.mp4`` and return that file path.

    Accepted layouts are ``(T, C, H, W)`` and ``(T, H, W, C)`` with ``C`` in
    ``{1, 3, 4}``; when both axis 1 and the last axis look like channel axes,
    axis 1 wins. Float frames are expected in ``[0, 1]`` and are scaled and
    clipped to ``uint8``; frames that are already ``uint8`` are used as-is.
    Channels are interpreted as RGB(A): single-channel frames are replicated
    to three channels and a fourth channel is dropped, because the OpenCV
    writer is opened in colour mode and is fed 3-channel BGR frames.

    As a debugging aid the first frame is also written (best effort) to
    ``{path}/{name}_debug_frame.png`` and its per-channel min/max/mean are
    printed to stdout. ``path`` is created if it does not exist.

    Args:
        frames: Video frames, anything ``np.asarray`` accepts.
        path: Output directory.
        name: Base file name without extension.

    Returns:
        The path of the written ``.mp4`` file as a string.

    Raises:
        ValueError: If ``frames`` is not 4-D or no channel axis can be inferred.
        RuntimeError: If OpenCV cannot open the output video for writing.
    """
    frames = np.asarray(frames)
    if frames.ndim != 4:
        raise ValueError(
            f"Expected frames with 4 dims (T, C, H, W) or (T, H, W, C), got shape {frames.shape}"
        )

    if frames.shape[1] in (1, 3, 4):
        is_chw = True
    elif frames.shape[-1] in (1, 3, 4):
        is_chw = False
    else:
        raise ValueError(f"Can't infer channel axis from frames.shape={frames.shape}")

    # Scale float input to uint8. Multiplying uint8 input by 255 would saturate
    # almost every pixel, so uint8 frames pass through unchanged.
    if frames.dtype == np.uint8:
        frames_u8 = frames
    else:
        frames_u8 = (frames * 255.0).clip(0, 255).astype("uint8")
    if is_chw:
        frames_u8 = frames_u8.transpose(0, 2, 3, 1)  # -> (T, H, W, C)

    # Sanity statistics on the first frame, before any channel conversion.
    first = frames_u8[0]
    ch = first.shape[-1]
    stats: dict[str, list[Any]] = {
        "min": [int(first[..., c].min()) for c in range(ch)],
        "max": [int(first[..., c].max()) for c in range(ch)],
        "mean": [float(first[..., c].mean()) for c in range(ch)],
    }
    equal_ch = ch >= 3 and bool(
        np.all(first[..., 0] == first[..., 1])
        and np.all(first[..., 1] == first[..., 2])
    )

    # The colour writer (and the PNG debug frame) need exactly 3 channels:
    # replicate single-channel frames and drop a fourth channel.
    if ch == 1:
        rgb = np.repeat(frames_u8, 3, axis=-1)
    elif ch == 4:
        rgb = frames_u8[..., :3]
    else:
        rgb = frames_u8

    out_dir = pathlib.Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)
    debug_path = out_dir / f"{name}_debug_frame.png"
    try:
        # RGB -> BGR for OpenCV.
        cv2.imwrite(str(debug_path), np.ascontiguousarray(rgb[0][..., ::-1]))
    except Exception as e:
        # The debug PNG is best effort; a failure must not block the video.
        print(f"Failed to write debug frame PNG: {e}")

    print(
        f"[save_video] frames.shape={frames.shape} inferred_chw={is_chw} -> written_frame_shape={frames_u8.shape}"
    )
    print(
        f"[save_video] first_frame stats min={stats['min']} max={stats['max']} mean={stats['mean']} equal_rgb_channels={equal_ch}"
    )
    print(f"[save_video] debug PNG saved to: {debug_path}")

    height, width = rgb.shape[1], rgb.shape[2]
    video_path = out_dir / f"{name}.mp4"
    writer = cv2.VideoWriter(
        str(video_path),
        cv2.VideoWriter_fourcc(*"mp4v"),  # type: ignore[attr-defined]
        25.0,
        (width, height),  # OpenCV expects the frame size as (width, height)
        True,
    )
    if not writer.isOpened():
        writer.release()
        raise RuntimeError(f"Failed to open video writer for {video_path}")
    try:
        for frame in rgb:
            # RGB -> BGR, contiguous for OpenCV.
            writer.write(np.ascontiguousarray(frame[..., ::-1]))
    finally:
        writer.release()
    return str(video_path)
def combine_videos(
    video_dir: str,
    output_name: str = "combined.mp4",
    pattern: str = "vid_*.mp4",
    fps: int = 25,
    resize: bool = True,
) -> str:
    """Concatenate every video in ``video_dir`` matching ``pattern`` into one MP4.

    Inputs are read in sorted filename order. The output frame size is taken
    from the first input; with ``resize=True`` frames of other sizes are
    resized to it. Inputs that OpenCV cannot open are skipped. The output
    file itself is never used as an input, even if it matches ``pattern``.

    Example:
        combine_videos("results/planet", output_name="all_training.mp4")

    Args:
        video_dir: Directory containing the input videos; the output is
            written there too.
        output_name: File name of the combined video.
        pattern: Glob pattern selecting the input videos.
        fps: Frame rate of the output video.
        resize: Whether to resize frames whose size differs from the first video.

    Returns:
        The output file path as a string.

    Raises:
        FileNotFoundError: If no input videos match ``pattern``.
        RuntimeError: If the first input or the output cannot be opened.
    """
    out_path = str(pathlib.Path(video_dir) / output_name)
    out_abs = os.path.abspath(out_path)
    # Exclude the output so a re-run (or a matching output_name) does not
    # read the file being written.
    files = [
        f
        for f in sorted(glob.glob(os.path.join(video_dir, pattern)))
        if os.path.abspath(f) != out_abs
    ]
    if not files:
        raise FileNotFoundError(f"No videos found in {video_dir} matching {pattern}")

    # Probe the first video for the output frame size.
    cap0 = cv2.VideoCapture(files[0])
    try:
        if not cap0.isOpened():
            raise RuntimeError(f"Failed to open video {files[0]}")
        width = int(cap0.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap0.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        cap0.release()

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # type: ignore[attr-defined]
    writer = cv2.VideoWriter(out_path, fourcc, float(fps), (width, height), True)
    if not writer.isOpened():
        writer.release()
        raise RuntimeError(f"Failed to open video writer for {out_path}")
    try:
        for f in files:
            cap = cv2.VideoCapture(f)
            try:
                if not cap.isOpened():
                    continue  # unreadable input: skip it
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    # frame is BGR, height x width x channels
                    if resize and frame.shape[:2] != (height, width):
                        frame = cv2.resize(frame, (width, height))
                    writer.write(frame)
            finally:
                cap.release()
    finally:
        writer.release()
    return out_path
def ensure_results_dir_exists(results_dir: str) -> None:
    """Check that ``results_dir`` is an existing directory.

    Raises:
        FileNotFoundError: If ``results_dir`` is not a directory.
    """
    if os.path.isdir(results_dir):
        return
    raise FileNotFoundError(f"Results directory does not exist: {results_dir}")
def save_frames(
    target: torch.Tensor,
    pred_prior: torch.Tensor,
    pred_posterior: torch.Tensor,
    name: str,
    n_rows: int = 5,
) -> None:
    """Save side-by-side target, prior-prediction and posterior-prediction frames.

    Each input may be a tensor or array shaped ``(C, H, W)`` (single frame) or
    ``(T, C, H, W)``. For every timestep a row ``[target | prior | posterior]``
    is built by concatenating along the width; rows are arranged with
    ``make_grid`` (``nrow=n_rows``) and written to ``{name}.png``. The parent
    directory of ``name`` is created if needed.

    Target frame ``t + 1`` is paired with prediction ``t`` whenever the target
    sequence is long enough (presumably because predictions start one step
    after the first observation — confirm with callers). Predictions are
    bilinearly resized to the target's spatial size, and any image with
    values outside ``[0, 1]`` is min-max rescaled into that range.

    All inputs are detached and moved to the CPU first, so targets and
    predictions may live on different devices.

    Raises:
        RuntimeError: If there are no frames to save.
    """

    def ensure_time_dim(x: torch.Tensor | np.ndarray) -> torch.Tensor:
        # Convert to a CPU tensor and give single frames a leading time axis.
        if not torch.is_tensor(x):
            x = torch.tensor(x)
        x = x.detach().cpu()
        return x.unsqueeze(0) if x.dim() == 3 else x

    def match_size(img: torch.Tensor, height: int, width: int) -> torch.Tensor:
        # Bilinearly resize a (C, H, W) image to (C, height, width) if needed.
        if img.shape[1] == height and img.shape[2] == width:
            return img
        return F.interpolate(
            img.unsqueeze(0), size=(height, width), mode="bilinear", align_corners=False
        ).squeeze(0)

    def to_unit_range(x: torch.Tensor) -> torch.Tensor:
        # Min-max rescale only when values fall outside [0, 1].
        if x.min() < 0 or x.max() > 1:
            return (x - x.min()) / (x.max() - x.min() + 1e-8)
        return x

    tgt = ensure_time_dim(target).float()
    pp = ensure_time_dim(pred_prior).float()
    ppp = ensure_time_dim(pred_posterior).float()

    T_pred = pp.shape[0]
    T_tgt = max(0, tgt.shape[0] - 1)
    n = min(T_pred, T_tgt if T_tgt > 0 else T_pred)
    if n == 0:
        n = min(T_pred, tgt.shape[0])

    frames = []
    for t in range(n):
        if tgt.shape[0] > 1 and t + 1 < tgt.shape[0]:
            tf = tgt[t + 1]
        else:
            tf = tgt[min(t, tgt.shape[0] - 1)]
        pr = pp[min(t, pp.shape[0] - 1)]
        ppr = ppp[min(t, ppp.shape[0] - 1)]

        H, W = tf.shape[1], tf.shape[2]
        pr = match_size(pr, H, W)
        ppr = match_size(ppr, H, W)

        tf = to_unit_range(tf)
        pr = to_unit_range(pr)
        ppr = to_unit_range(ppr)

        # Concatenate side by side along width (dim=2 for [C, H, W]).
        frames.append(torch.cat([tf, pr, ppr], dim=2))

    if not frames:
        raise RuntimeError(f"No frames to save for {name}")

    grid = make_grid(torch.stack(frames), nrow=n_rows, normalize=False)
    out_dir = os.path.dirname(name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    save_image(grid, f"{name}.png")
def get_mask(tensor: Any, lengths: Any) -> torch.Tensor:
    """Return a batch-first validity mask built from per-sample ``lengths``.

    * 1-D ``tensor`` of shape ``(N,)``: the mask has shape ``(N, max(lengths))``
      and ``tensor``'s dtype and device.
    * ``tensor`` of shape ``(N, T, ...)``: the mask is ``zeros_like(tensor)``.

    In both cases row ``i`` is set to one for its first ``int(lengths[i])``
    positions along axis 1 (all trailing axes included). Non-tensor input is
    converted with ``torch.tensor`` first.
    """
    data = tensor if torch.is_tensor(tensor) else torch.tensor(tensor)

    if data.dim() == 1:
        width = int(max(lengths))
        mask = torch.zeros(
            (data.shape[0], width), dtype=data.dtype, device=data.device
        )
    else:
        mask = torch.zeros_like(data)

    for row in range(data.shape[0]):
        mask[row, : int(lengths[row])] = 1.0
    return mask
def load_memory(path: str, device: torch.device, *, trusted: bool = False) -> Any:
    """Load a pickled experience replay buffer and attach ``device`` to it.

    Unrestricted pickle loading can execute arbitrary code, so the file is
    always read with ``_RestrictedReplayUnpickler``, which only resolves the
    replay-buffer classes and numpy containers on its allow-list.
    ``trusted`` is kept for backwards compatibility; passing ``True`` only
    emits a ``DeprecationWarning`` and does NOT enable unrestricted loading.

    Legacy formats are converted to ``Memory``:

    * a bare list of episodes is wrapped in ``Memory(len(list))``;
    * an object exposing ``.data`` is rebuilt as ``Memory`` from that data;
      if that conversion raises, the object is kept and ``device`` is set on
      each element of ``.data`` instead.

    Finally ``device`` is set on every element of ``.episodes`` (when
    present) and on the returned container itself.
    """
    if trusted:
        warnings.warn(
            "load_memory(trusted=True) is deprecated and no longer enables "
            "unrestricted pickle loading.",
            DeprecationWarning,
            stacklevel=2,
        )

    with open(path, "rb") as handle:
        loaded = _RestrictedReplayUnpickler(handle).load()

    if isinstance(loaded, list):
        wrapped = Memory(len(loaded))
        wrapped.append(loaded)
        loaded = wrapped

    if hasattr(loaded, "data"):
        try:
            wrapped = Memory(len(loaded.data))
            wrapped.append(loaded.data)
            loaded = wrapped
        except Exception:
            # Best effort: keep the legacy object but tag its elements.
            for episode in loaded.data:
                episode.device = device

    if hasattr(loaded, "episodes"):
        for episode in loaded.episodes:
            episode.device = device

    loaded.device = device
    return loaded
def flatten_dict(data: dict, sep: str = ".", prefix: str = "") -> dict:
    """Flatten a nested mapping into a single-level dict.

    Keys along each path are joined with ``sep``. A non-empty ``prefix`` is
    prepended (with ``sep``) to every resulting key. Nested empty mappings
    contribute no entries.

    Example:
        {'a': 2, 'b': {'c': 20}} -> {'a': 2, 'b.c': 20}
        {'a': {'b': {'c': 1}}}   -> {'a.b.c': 1}
    """
    flattened: dict = {}
    for key, val in data.items():
        flat_key = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(val, collections.abc.Mapping):
            # Recurse with the full path so deeper levels keep every ancestor
            # (previously only the immediate parent key was passed down).
            flattened.update(flatten_dict(val, sep=sep, prefix=flat_key))
        else:
            flattened[flat_key] = val
    return flattened
def normalize_frames_for_saving(frames: Any) -> np.ndarray:
    """Return frames as a float32 ``(T, H, W, 3)`` array with values in ``[0, 1]``.

    Accepts ``(T, C, H, W)`` or ``(T, H, W, C)`` with ``C`` in ``{1, 3, 4}``
    (axis 1 wins when both could be the channel axis). Single-channel frames
    are repeated to RGB and a fourth (alpha) channel is dropped.

    Value range handling:

    * ``uint8`` input is scaled by ``1/255``;
    * other input whose values all lie in ``[-0.6, 0.6]`` is treated as the
      centred ``[-0.5, 0.5]`` range, shifted by ``+0.5`` and clipped;
    * anything else is clipped to ``[0, 1]``.

    Raises:
        ValueError: If the input is not 4-D or no channel axis can be inferred.
    """
    raw = np.asarray(frames)
    # uint8 frames are already in [0, 255]; clipping them to [0, 1] would turn
    # every non-zero pixel white, so they are rescaled instead.
    is_uint8 = raw.dtype == np.uint8
    frames = raw.astype(np.float32)
    if frames.ndim != 4:
        raise ValueError(f"Expected 4D frames array, got shape {frames.shape}")

    if frames.shape[1] in (1, 3, 4):
        frames = frames.transpose(0, 2, 3, 1)
    elif frames.shape[-1] not in (1, 3, 4):
        raise ValueError(f"Can't infer channel axis from frames.shape={frames.shape}")

    ch = frames.shape[-1]
    if ch == 1:
        frames = np.repeat(frames, 3, axis=-1)
    elif ch == 4:
        frames = frames[..., :3]

    if is_uint8:
        return frames / 255.0

    mn, mx = float(frames.min()), float(frames.max())
    if mn >= -0.6 and mx <= 0.6:
        frames = (frames + 0.5).clip(0.0, 1.0)
    else:
        frames = frames.clip(0.0, 1.0)
    return frames
class TensorBoardMetrics:
    """Track how many times each metric has been updated during an experiment.

    Despite the name, this implementation writes no logs: ``path`` is accepted
    but unused, ``assign_type`` is a no-op hook and ``summary`` is never
    populated by this class.
    """

    def __init__(self, path: str) -> None:
        # Per-metric update counters, keyed by "/"-separated metric name.
        self.steps: defaultdict[str, int] = defaultdict(lambda: 0)
        # Reserved for per-metric summary entries; nothing here fills it.
        self.summary: dict[str, Any] = {}

    def assign_type(self, key: str, val: Any) -> None:
        """Hook called for keys that have no ``summary`` entry; currently a no-op."""
        return None

    def update(self, metrics: dict) -> None:
        """Count one update for every (flattened) metric in ``metrics``.

        Nested keys are flattened with ``flatten_dict`` and their ``.``
        separators are replaced by ``/``.
        """
        for dotted, value in flatten_dict(metrics).items():
            tag = dotted.replace(".", "/")
            if self.summary.get(tag) is None:
                self.assign_type(tag, value)
            self.steps[tag] += 1
def apply_model(model: Any, inputs: Any, ignore_dim: Any = None) -> None:
    """Extension hook for applying a model across structured inputs.

    Not implemented: the arguments are ignored and ``None`` is returned.
    """
    return None
def plot_metrics(metrics: dict, path: str, prefix: str) -> None:
    """Create ``path`` if needed and save one line plot per metric series.

    Each series is plotted against its index with ``lineplot``, titled
    ``prefix + key``.
    """
    os.makedirs(path, exist_ok=True)
    for metric_name, series in metrics.items():
        steps = np.arange(len(series))
        lineplot(steps, series, f"{prefix}{metric_name}", path)
def lineplot(
    xs: np.ndarray | list,
    ys: Any,
    title: str,
    path: str = "",
    xaxis: str = "episode",
) -> None:
    """Save a Plotly line plot of ``ys`` to ``{path}/{title}.html``.

    ``ys`` may be:

    * a dict mapping series name -> values: one trace per entry, each plotted
      against ``0..len(values)-1`` (``xs`` is ignored);
    * a 2-D array-like of shape ``(len(xs), k)``: the mean over the last axis
      is drawn with a +/-1 standard-deviation band plus dashed min/max lines;
    * anything else: a single line against ``xs``.
    """
    # Plain dicts are accepted wherever plotly expects a line specification and
    # avoid the deprecated ``plotly.graph_objs.Line`` class.
    band_edge_line = dict(color="rgb(0, 132, 180)", dash="dash")
    no_line = dict(color="rgba(0, 0, 0, 0)")
    mean_line = dict(color="rgb(0, 172, 237)")
    std_colour = "rgba(29, 202, 255, 0.2)"

    if isinstance(ys, dict):
        data = [
            Scatter(x=np.arange(len(val)), y=np.array(val), name=key)
            for key, val in ys.items()
        ]
    else:
        ys_arr = np.asarray(ys, dtype=np.float32)
        if ys_arr.ndim == 2:
            ys_mean, ys_std = ys_arr.mean(-1), ys_arr.std(-1)
            ys_upper, ys_lower = ys_mean + ys_std, ys_mean - ys_std
            l_max = Scatter(x=xs, y=ys_arr.max(-1), line=band_edge_line, name="Max")
            l_min = Scatter(x=xs, y=ys_arr.min(-1), line=band_edge_line, name="Min")
            # Trace order matters: "tonexty" fills down to the previous trace,
            # so upper -> mean -> lower produces the std band.
            l_stu = Scatter(x=xs, y=ys_upper, line=no_line, showlegend=False)
            l_mean = Scatter(
                x=xs,
                y=ys_mean,
                fill="tonexty",
                fillcolor=std_colour,
                line=mean_line,
                name="Mean",
            )
            l_stl = Scatter(
                x=xs,
                y=ys_lower,
                fill="tonexty",
                fillcolor=std_colour,
                line=no_line,
                name="-1 Std. Dev.",
                showlegend=False,
            )
            data = [l_stu, l_mean, l_stl, l_min, l_max]
        else:
            data = [Scatter(x=xs, y=ys, line=mean_line)]

    plotly.offline.plot(
        {
            "data": data,
            "layout": dict(title=title, xaxis={"title": xaxis}, yaxis={"title": title}),
        },
        filename=os.path.join(path, title + ".html"),
        auto_open=False,
    )
[docs] class TorchImageEnvWrapper: """ Torch Env Wrapper that wraps a gym env and makes interactions using Tensors. Also returns observations in image form. """ def __init__( self, env: Any, bit_depth: int, observation_shape: Any = None, act_rep: int = 2 ) -> None: if isinstance(env, str): try: self.env = gym.make(env, render_mode="rgb_array") self._render_mode_supported = True except ImportError as exc: from world_models.envs.robotics_env import ( is_moved_mujoco_error, make_robotics_env, ) if not is_moved_mujoco_error(exc): raise self.env = make_robotics_env(env, render_mode="rgb_array") self._render_mode_supported = True except TypeError: self.env = gym.make(env) self._render_mode_supported = False else: self.env = env self._render_mode_supported = True self.bit_depth = bit_depth self.action_repeats = act_rep def _get_frame(self, last_obs: Any = None) -> np.ndarray | None: """Call env.render robustly across gym versions. Returns ndarray frame or None. If rendering fails (OverflowError / pygame Surface) fallback to last_obs or a synthesized image from the observation vector. 
""" frame = None try: out = self.env.render() except Exception: out = None # gym may return (frame, info) if isinstance(out, tuple): out = out[0] if isinstance(out, np.ndarray): frame = out else: # try alternative render signature try: out = self.env.render(mode="rgb_array") if isinstance(out, tuple): out = out[0] if isinstance(out, np.ndarray): frame = out except Exception: frame = None # If frame is still not ndarray, and last_obs is image-like, use that if frame is None and last_obs is not None: frame = self._obs_to_frame( last_obs if not isinstance(last_obs, tuple) else last_obs[0] ) # If still none, synthesize a simple visualization from a 1D state vector if frame is None and last_obs is not None: try: obs = last_obs if not isinstance(last_obs, tuple) else last_obs[0] arr = np.asarray(obs) # if simple vector -> make a 64x64 RGB gradient from values if arr.ndim == 1 or (arr.ndim == 2 and 1 in arr.shape): vals = ( (arr.flatten() - arr.min()) if arr.max() != arr.min() else arr.flatten() ) if vals.size == 0: vals = np.zeros(1) vals = (vals - vals.min()) / (vals.max() - vals.min() + 1e-8) canvas = np.zeros((64, 64, 3), dtype=np.uint8) for i, v in enumerate(vals[:8]): # encode first few dims as bands band = int(255 * v) canvas[:, i * 8 : (i + 1) * 8, :] = band frame = canvas except Exception: frame = None return frame def _obs_to_frame(self, obs: Any) -> np.ndarray | None: """Convert common observation formats to an HWC image when possible.""" if isinstance(obs, tuple): obs = obs[0] if isinstance(obs, dict): for key in ("image", "pixels", "rgb", "observation"): if key in obs: frame = self._obs_to_frame(obs[key]) if frame is not None: return frame for value in obs.values(): frame = self._obs_to_frame(value) if frame is not None: return frame return None if not isinstance(obs, np.ndarray): try: obs = np.asarray(obs) except Exception: return None if obs.ndim == 3: frame = obs if frame.shape[-1] not in (1, 3, 4) and frame.shape[0] in (1, 3, 4): frame = 
frame.transpose(1, 2, 0) if frame.shape[-1] == 1: frame = np.repeat(frame, 3, axis=-1) elif frame.shape[-1] == 4: frame = frame[..., :3] return frame if obs.ndim == 2: return np.repeat(obs[..., None], 3, axis=-1) if obs.ndim == 1: vals = obs.astype(np.float32).reshape(-1) if vals.size == 0: return None if vals.max() != vals.min(): vals = (vals - vals.min()) / (vals.max() - vals.min() + 1e-8) else: vals = np.zeros_like(vals) canvas = np.zeros((64, 64, 3), dtype=np.uint8) for i, v in enumerate(vals[:8]): canvas[:, i * 8 : (i + 1) * 8, :] = int(255 * float(v)) return canvas return None
[docs] def reset(self) -> torch.Tensor: ret = self.env.reset() obs = ret[0] if isinstance(ret, tuple) else ret frame = self._obs_to_frame(obs) if frame is None: frame = self._get_frame(last_obs=obs) if frame is None: raise RuntimeError( "Environment did not provide an RGB frame on reset. " "Use an env that supports image rendering or instantiate with render_mode='rgb_array'." ) x = to_tensor_obs(frame) preprocess_img(x, self.bit_depth) return x
[docs] def step(self, u: torch.Tensor | np.ndarray | list | float | int) -> tuple: if isinstance(u, torch.Tensor): u_t = u.cpu().detach() else: u_t = u if getattr(self.env.action_space, "n", None) is not None: n = int(self.env.action_space.n) arr = u_t.numpy() if isinstance(u_t, torch.Tensor) else np.asarray(u_t) arr = arr.reshape(-1) if arr.size > 1: action = int(np.argmax(arr)) else: val = float(arr[0]) if arr.size else 0.0 action = int(np.clip(int(round(val)), 0, n - 1)) else: action = u_t.numpy() if isinstance(u_t, torch.Tensor) else np.asarray(u_t) rwds = 0 last_d = False last_i = {} last_obs = None for _ in range(self.action_repeats): ret = self.env.step(action) if len(ret) == 4: obs, r, d, i = ret else: obs, r, term, trunc, i = ret d = term or trunc rwds += r last_d = d last_i = i last_obs = obs frame = self._obs_to_frame( last_obs if not isinstance(last_obs, tuple) else last_obs[0] ) if frame is None: frame = self._get_frame(last_obs=last_obs) if frame is None: raise RuntimeError( "Environment did not provide an RGB frame on step. " "Use an env that supports image rendering or instantiate with render_mode='rgb_array'." ) x = to_tensor_obs(frame) preprocess_img(x, self.bit_depth) return x, rwds, last_d, last_i
[docs] def render(self) -> None: self.env.render()
[docs] def close(self) -> None: self.env.close()
@property def observation_size(self) -> tuple[int, int, int]: return (3, 64, 64) @property def action_size(self) -> int: space = getattr(self.env, "action_space", None) if space is None: return 1 shp = getattr(space, "shape", None) if shp and len(shp) > 0: try: prod = int(np.prod(shp)) return prod if prod > 0 else 1 except Exception: pass if getattr(space, "n", None) is not None: return 1 try: sample = space.sample() arr = np.asarray(sample) if arr.ndim == 0: return 1 return int(arr.size) except Exception: return 1
[docs] def sample_random_action(self) -> torch.Tensor: return torch.tensor(self.env.action_space.sample())
@property def max_episode_steps(self) -> int: """Return environment max episode steps (compatible with TimeLimit/spec).""" if ( hasattr(self.env, "_max_episode_steps") and self.env._max_episode_steps is not None ): return self.env._max_episode_steps if ( getattr(self.env, "spec", None) is not None and getattr(self.env.spec, "max_episode_steps", None) is not None ): return int(self.env.spec.max_episode_steps) return 1000
def apply_masks(x: torch.Tensor, masks: list[torch.Tensor]) -> torch.Tensor:
    """Gather token subsets from patch sequences using index masks.

    For each mask ``m`` (index tensor over axis 1 of ``x``), the selected
    tokens are gathered with their full feature vectors; the per-mask results
    are concatenated along the batch axis.
    """
    feature_dim = x.shape[-1]
    gathered = [
        torch.gather(x, 1, m.unsqueeze(-1).repeat(1, 1, feature_dim))
        for m in masks
    ]
    return torch.cat(gathered, dim=0)
def visualize_latent_tsne(
    latents: torch.Tensor | np.ndarray,
    labels: np.ndarray | list | None = None,
    save_path: str | None = None,
    perplexity: int = 30,
) -> plotly.graph_objs.Figure:
    """Project latents to 2-D with t-SNE and return a Plotly scatter figure.

    Args:
        latents: Tensor or array of shape ``(N, D)``; tensors are detached and
            moved to the CPU.
        labels: Optional length-``N`` list or array of labels; one trace per
            unique label is drawn so points are coloured by label.
        save_path: If given, the figure is also written there as HTML.
        perplexity: t-SNE perplexity (scikit-learn requires it to be smaller
            than ``N``).

    Returns:
        The ``plotly.graph_objs.Figure``.
    """
    if isinstance(latents, torch.Tensor):
        latents = latents.detach().cpu().numpy()

    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    latents_2d = tsne.fit_transform(latents)

    fig = plotly.graph_objs.Figure()
    if labels is not None:
        # A plain list compared with ``==`` yields one bool instead of an
        # element-wise mask, so convert to an array first.
        label_arr = np.asarray(labels)
        for label in np.unique(label_arr):
            idx = label_arr == label
            fig.add_trace(
                Scatter(
                    x=latents_2d[idx, 0],
                    y=latents_2d[idx, 1],
                    mode="markers",
                    name=str(label),
                    marker=dict(size=5),
                )
            )
    else:
        fig.add_trace(
            Scatter(
                x=latents_2d[:, 0],
                y=latents_2d[:, 1],
                mode="markers",
                marker=dict(size=5),
            )
        )

    fig.update_layout(
        title="Latent Space t-SNE", xaxis_title="t-SNE 1", yaxis_title="t-SNE 2"
    )
    if save_path:
        fig.write_html(save_path)
    return fig
def visualize_latent_umap(
    latents: torch.Tensor | np.ndarray,
    labels: np.ndarray | list | None = None,
    save_path: str | None = None,
    n_neighbors: int = 15,
) -> plotly.graph_objs.Figure:
    """Project latents to 2-D with UMAP and return a Plotly scatter figure.

    Args:
        latents: Tensor or array of shape ``(N, D)``; tensors are detached and
            moved to the CPU.
        labels: Optional length-``N`` list or array of labels; one trace per
            unique label is drawn so points are coloured by label.
        save_path: If given, the figure is also written there as HTML.
        n_neighbors: UMAP ``n_neighbors`` parameter.

    Returns:
        The ``plotly.graph_objs.Figure``.
    """
    if isinstance(latents, torch.Tensor):
        latents = latents.detach().cpu().numpy()

    reducer = umap.UMAP(n_neighbors=n_neighbors, random_state=42)
    latents_2d = reducer.fit_transform(latents)

    fig = plotly.graph_objs.Figure()
    if labels is not None:
        # A plain list compared with ``==`` yields one bool instead of an
        # element-wise mask, so convert to an array first.
        label_arr = np.asarray(labels)
        for label in np.unique(label_arr):
            idx = label_arr == label
            fig.add_trace(
                Scatter(
                    x=latents_2d[idx, 0],
                    y=latents_2d[idx, 1],
                    mode="markers",
                    name=str(label),
                    marker=dict(size=5),
                )
            )
    else:
        fig.add_trace(
            Scatter(
                x=latents_2d[:, 0],
                y=latents_2d[:, 1],
                mode="markers",
                marker=dict(size=5),
            )
        )

    fig.update_layout(
        title="Latent Space UMAP", xaxis_title="UMAP 1", yaxis_title="UMAP 2"
    )
    if save_path:
        fig.write_html(save_path)
    return fig
class StreamingVideoWriter:
    """Write video frames to disk one at a time with OpenCV.

    The ``cv2.VideoWriter`` is created lazily by the first :meth:`write_frame`
    call. Instances can be used as context managers; leaving the ``with``
    block calls :meth:`close`.

    Args:
        path: Output video file path.
        fps: Frames per second.
        frame_shape: Frame size forwarded unchanged to ``cv2.VideoWriter``,
            which expects ``(width, height)``. If ``None``, it is set from the
            first frame as ``(width, height)``. (Older docs described this as
            ``(height, width)``, but the value has always been passed to
            OpenCV as-is.)
        format: ``'mp4'`` (``mp4v`` codec) or ``'avi'`` (``XVID`` codec),
            case-insensitive.
    """

    def __init__(
        self, path: str, fps: int = 20, frame_shape: Any = None, format: str = "mp4"
    ) -> None:
        self.path = path
        self.fps = fps
        self.frame_shape = frame_shape
        self.format = format.lower()
        self.writer = None

    def __enter__(self) -> "StreamingVideoWriter":
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        self.close()

    def write_frame(self, frame: np.ndarray) -> None:
        """Append one frame to the video.

        Args:
            frame: Array of shape ``(H, W)``, ``(H, W, 1)``, ``(H, W, 3)`` or
                ``(H, W, 4)``; ``uint8`` in ``[0, 255]`` or float in ``[0, 1]``
                (non-``uint8`` input is scaled by 255 and clipped). Colour
                frames are treated as RGB(A): a fourth channel is dropped and
                channels are reversed to OpenCV's BGR order. Grayscale frames
                are replicated to three channels, because the writer is opened
                in colour mode.

        Raises:
            ValueError: If ``format`` is neither ``'mp4'`` nor ``'avi'``.
            RuntimeError: If OpenCV cannot open the output file.
        """
        frame = np.asarray(frame)
        if self.writer is None:
            if self.frame_shape is None:
                self.frame_shape = frame.shape[:2][::-1]  # (W, H)
            if self.format == "mp4":
                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            elif self.format == "avi":
                fourcc = cv2.VideoWriter_fourcc(*"XVID")
            else:
                raise ValueError("Unsupported format")
            writer = cv2.VideoWriter(self.path, fourcc, self.fps, self.frame_shape)
            if not writer.isOpened():
                writer.release()
                raise RuntimeError(f"Failed to open video writer for {self.path}")
            self.writer = writer

        if frame.dtype != np.uint8:
            # Clip before casting so out-of-range values saturate instead of wrapping.
            frame = np.clip(frame * 255, 0, 255).astype(np.uint8)

        # The colour writer needs 3 channels.
        if frame.ndim == 2:
            frame = np.repeat(frame[..., None], 3, axis=-1)
        elif frame.shape[-1] == 1:
            frame = np.repeat(frame, 3, axis=-1)
        elif frame.shape[-1] == 4:
            frame = frame[..., :3]
        if frame.shape[-1] == 3:
            frame = frame[..., ::-1]  # RGB -> BGR
        self.writer.write(np.ascontiguousarray(frame))

    def close(self) -> None:
        """Release the underlying writer, if one was opened. Safe to call repeatedly."""
        if self.writer is not None:
            self.writer.release()
            self.writer = None