Source code for world_models.datasets.video_datasets

import torch
from torch.utils.data import Dataset, DataLoader
from typing import Optional, Callable, List, Tuple, Union, Any, cast, Sequence
from pathlib import Path
import numpy as np
from dataclasses import dataclass
import logging
import cv2
from PIL import Image
import h5py

logger = logging.getLogger(__name__)


@dataclass
class DatasetConfig:
    """Loader settings shared by every dataset configuration.

    Attributes:
        num_frames: Number of frames per video.
        image_size: Target image size (height and width).
        batch_size: Batch size for the dataloader.
        num_workers: Number of workers for data loading.
        pin_memory: Whether to pin memory for faster GPU transfer.
        shuffle: Whether to shuffle data.
    """

    # Clip geometry
    num_frames: int = 16
    image_size: int = 64

    # DataLoader behaviour
    batch_size: int = 4
    num_workers: int = 4
    pin_memory: bool = True
    shuffle: bool = True
class VideoDatasetBase(Dataset):
    """Shared scaffolding for the video datasets in this module.

    A subclass provides two hooks: ``_get_video_paths`` (which samples exist)
    and ``_load_video`` (how one sample is read). This class supplies the
    length, the optional transform and the optional rescaling to [0, 1].
    """

    def __init__(
        self,
        data_source: Union[str, Path, List[str], List[Path]],
        num_frames: int = 16,
        image_size: int = 64,
        transform: Optional[Callable] = None,
        normalize: bool = True,
    ):
        """Store the settings, then ask the subclass to enumerate its samples.

        Args:
            data_source: Path or list of paths to data.
            num_frames: Number of frames per video.
            image_size: Target image size (height and width).
            transform: Optional callable applied to each loaded video.
            normalize: When True, videos whose maximum exceeds 1.0 are
                divided by 255.
        """
        self.data_source: Union[str, Path, List[str], List[Path]] = data_source
        self.num_frames = num_frames
        self.image_size = image_size
        self.transform = transform
        self.normalize = normalize

        # Entries are Path objects for file-backed datasets and plain integer
        # indices for in-memory ones. This runs last so the hook can read
        # every attribute set above.
        self.video_paths: Sequence[Union[Path, int]] = self._get_video_paths()

    def _get_video_paths(self) -> Sequence[Union[Path, int]]:
        """Get list of video file paths. Override in subclass."""
        raise NotImplementedError

    def _load_video(self, idx: int) -> torch.Tensor:
        """Load a single video. Override in subclass."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the number of entries produced by ``_get_video_paths``."""
        return len(self.video_paths)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Load sample ``idx``, apply the transform, then optionally rescale."""
        clip = self._load_video(idx)

        if self.transform is not None:
            clip = self.transform(clip)

        if not self.normalize:
            return clip

        # Only data that looks like the 0-255 range is rescaled.
        if clip.max() > 1.0:
            return clip / 255.0
        return clip
class VideoFolderDataset(VideoDatasetBase):
    """Dataset that loads videos from a folder.

    Supports common video formats: .mp4, .avi, .mkv, .webm, .mov

    Each item is a float tensor of ``num_frames`` RGB frames, each resized to
    ``image_size`` x ``image_size`` (frames first, channels last).

    Usage:
        dataset = VideoFolderDataset(
            data_source="/path/to/videos",
            num_frames=16,
            image_size=64
        )
    """

    def __init__(
        self,
        data_source: Union[str, Path, List[str], List[Path]],
        num_frames: int = 16,
        image_size: int = 64,
        transform: Optional[Callable] = None,
        normalize: bool = True,
        extensions: Tuple[str, ...] = (".mp4", ".avi", ".mkv", ".webm", ".mov"),
        recursive: bool = True,
    ):
        """Configure the dataset and discover the video files.

        Args:
            data_source: A directory, a single video file, or a list of files.
            num_frames: Number of frames per video.
            image_size: Target image size (height and width).
            transform: Optional callable applied to each loaded video.
            normalize: Forwarded to the base class (rescale 0-255 to 0-1).
            extensions: File suffixes searched for inside a directory.
            recursive: Whether sub-directories are searched as well.
        """
        # Must be set before super().__init__, which calls _get_video_paths().
        self.extensions = extensions
        self.recursive = recursive
        super().__init__(data_source, num_frames, image_size, transform, normalize)

    def _get_video_paths(self) -> Sequence[Union[Path, int]]:
        """Return the sorted video files described by ``data_source``.

        Raises:
            FileNotFoundError: If a single path is given and does not exist.
        """
        if isinstance(self.data_source, (list, tuple)):
            return [Path(p) for p in self.data_source]

        data_path = Path(cast(Union[str, Path], self.data_source))
        if not data_path.exists():
            raise FileNotFoundError(f"Data path not found: {data_path}")
        if data_path.is_file():
            return [data_path]

        finder = data_path.rglob if self.recursive else data_path.glob
        # A set avoids listing a file twice when two patterns match it.
        found = {p for ext in self.extensions for p in finder(f"*{ext}")}
        return sorted(found)

    def _load_video(self, idx: int) -> torch.Tensor:
        """Decode video ``idx``, sample ``num_frames`` frames and resize them.

        Raises:
            ValueError: If no frame could be decoded from the file.
        """
        video_path = cast(Path, self.video_paths[idx])
        cap = cv2.VideoCapture(str(video_path))
        frames_list: List[np.ndarray] = []
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frames_list.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # Release the capture even if decoding raised.
            cap.release()

        if not frames_list:
            # Without this guard the sampler fails with an opaque IndexError.
            raise ValueError(f"No frames could be read from {video_path}")

        sampled = self._sample_frames(frames_list)
        size = (self.image_size, self.image_size)
        frames_array = np.stack(
            [np.array(Image.fromarray(frame).resize(size)) for frame in sampled]
        )
        return torch.from_numpy(frames_array).float()

    def _sample_frames(self, frames: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
        """Pick ``num_frames`` evenly spaced frames.

        Short clips are stretched by repeating frames, because the evenly
        spaced indices are truncated to integers.

        Args:
            frames: An ndarray with frames on the first axis, or a list of
                equally shaped frame arrays.

        Returns:
            An ndarray with ``num_frames`` frames on the first axis.

        Raises:
            ValueError: If ``frames`` is empty.
        """
        total_frames = len(frames)
        if total_frames == 0:
            raise ValueError("Cannot sample frames from an empty video")
        if total_frames == self.num_frames:
            return np.asarray(frames)

        indices = np.linspace(0, total_frames - 1, self.num_frames).astype(int)
        if isinstance(frames, list):
            # Index the list first so only the selected frames are copied.
            return np.stack([frames[i] for i in indices])
        return frames[indices]
class ImageFolderDataset(VideoDatasetBase):
    """Dataset that loads image sequences from folders.

    Each subfolder is treated as a video sequence. The first ``num_frames``
    images (in sort order) are used; shorter sequences are padded by
    repeating the last image.

    Usage:
        dataset = ImageFolderDataset(
            data_source="/path/to/images",
            num_frames=16,
            image_size=64
        )
    """

    def __init__(
        self,
        data_source: Union[str, Path, List[str], List[Path]],
        num_frames: int = 16,
        image_size: int = 64,
        transform: Optional[Callable] = None,
        normalize: bool = True,
        extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png", ".bmp"),
        image_sort_key: Optional[Callable] = None,
    ):
        """Configure the dataset and discover the sequence folders.

        Args:
            data_source: A directory of sequence folders, a single image file
                (its parent folder becomes the only sequence), or a list of
                sequence folders.
            num_frames: Number of frames per video.
            image_size: Target image size (height and width).
            transform: Optional callable applied to each loaded video.
            normalize: Forwarded to the base class (rescale 0-255 to 0-1).
            extensions: Image suffixes searched for inside each sequence.
            image_sort_key: Key function ordering the image paths of a
                sequence. Defaults to a natural sort on the file stem.
        """
        self.extensions = extensions
        # A named function (not a lambda) keeps the dataset picklable, which
        # DataLoader workers need when they are started with "spawn".
        self.image_sort_key = image_sort_key or self._default_sort_key
        super().__init__(data_source, num_frames, image_size, transform, normalize)

    @staticmethod
    def _default_sort_key(path: Path) -> Tuple[List[Tuple[int, int, str]], str]:
        """Natural sort key: digit runs compare as numbers, text as text.

        Purely numeric stems keep their numeric order ("2" before "10").
        Names such as "frame_2" / "frame_10" are now ordered as well, instead
        of all collapsing onto the same key.
        """
        import re  # local import: the module does not import ``re``

        chunks = [
            (0, int(chunk), "") if chunk.isdecimal() else (1, 0, chunk.lower())
            for chunk in re.split(r"(\d+)", path.stem)
            if chunk
        ]
        # The file name breaks ties deterministically (e.g. "1.jpg" / "01.png").
        return chunks, path.name

    def _get_video_paths(self) -> Sequence[Union[Path, int]]:
        """Return the sorted sequence folders described by ``data_source``.

        Raises:
            FileNotFoundError: If a single path is given and does not exist.
        """
        if isinstance(self.data_source, (list, tuple)):
            return [Path(p) for p in self.data_source]

        data_path = Path(cast(Union[str, Path], self.data_source))
        if not data_path.exists():
            raise FileNotFoundError(f"Data path not found: {data_path}")
        if data_path.is_file():
            return [data_path.parent]

        return sorted(d for d in data_path.iterdir() if d.is_dir())

    def _load_video(self, idx: int) -> torch.Tensor:
        """Load sequence ``idx`` as a float tensor of RGB frames.

        Raises:
            ValueError: If the sequence folder contains no matching image.
        """
        seq_path = cast(Path, self.video_paths[idx])
        image_files = sorted(
            (p for ext in self.extensions for p in seq_path.glob(f"*{ext}")),
            key=self.image_sort_key,
        )
        if not image_files:
            raise ValueError(f"No images found in {seq_path}")

        size = (self.image_size, self.image_size)
        frames = []
        for img_path in image_files[: self.num_frames]:
            # The context manager closes the file handle right after decoding.
            with Image.open(img_path) as img:
                frames.append(np.array(img.convert("RGB").resize(size)))

        # Pad short sequences by repeating the final frame.
        while len(frames) < self.num_frames:
            frames.append(frames[-1].copy())

        frames_array = np.stack(frames[: self.num_frames])
        return torch.from_numpy(frames_array).float()
class NumPyDataset(VideoDatasetBase):
    """Dataset that loads videos from numpy files.

    Supports .npy and .npz files. A 5-D array is indexed along its first
    axis; any other array is returned whole for every index. ``num_frames``
    and ``image_size`` are stored but not applied: no sampling or resizing
    happens here.

    Usage:
        dataset = NumPyDataset(
            data_source="/path/to/videos.npy",
            num_frames=16,
            image_size=64
        )
    """

    def __init__(
        self,
        data_source: Union[str, Path],
        num_frames: int = 16,
        image_size: int = 64,
        transform: Optional[Callable] = None,
        normalize: bool = True,
        key: Optional[str] = None,
    ):
        """Read the array once to learn its shape and keep it cached.

        Args:
            data_source: Path to a .npy or .npz file.
            num_frames: Stored on the dataset; not applied by this class.
            image_size: Stored on the dataset; not applied by this class.
            transform: Optional callable applied to each loaded video.
            normalize: Forwarded to the base class (rescale 0-255 to 0-1).
            key: Array name inside a .npz archive; defaults to the first one.

        Raises:
            ValueError: If a .npz archive contains no arrays.
        """
        self.key = key
        # Kept so the attribute still exists. The archive is read once and
        # closed instead of holding an open (unpicklable) handle.
        self.npz_data = None

        self._data: Optional[np.ndarray] = self._read_array(Path(data_source))
        data = self._data

        # NOTE(review): a 4-D array reports shape[0] samples, yet only 5-D
        # arrays are indexed per sample in _load_video, so every index of a
        # 4-D array returns the whole array. Behaviour kept as-is; confirm
        # whether 4-D means one video or a stack of videos.
        self.num_samples = data.shape[0] if data.ndim >= 4 else 1
        self.is_5d = data.ndim == 5

        super().__init__(data_source, num_frames, image_size, transform, normalize)

    def _read_array(self, data_path: Path) -> np.ndarray:
        """Load the backing array from ``data_path`` and close the file."""
        # SECURITY: allow_pickle=True can execute arbitrary code while loading
        # a crafted file. Only point this dataset at trusted files.
        loaded = np.load(data_path, allow_pickle=True)
        if data_path.suffix != ".npz":
            return loaded

        with loaded as archive:
            if self.key is None:
                if not archive.files:
                    raise ValueError(f"No arrays found in {data_path}")
                self.key = archive.files[0]
            return archive[self.key]

    def __getstate__(self) -> dict:
        """Drop the cached array so pickled copies reload it lazily."""
        state = self.__dict__.copy()
        state["_data"] = None
        return state

    def _get_video_paths(self) -> Sequence[Union[Path, int]]:
        """Return integer indices; numpy-backed entries have no file paths."""
        return list(range(self.num_samples))

    def _load_video(self, idx: int) -> torch.Tensor:
        """Return sample ``idx`` as a float tensor.

        The array is read from disk at most once per process. The original
        code re-read the whole file on every call.
        """
        if self._data is None:
            self._data = self._read_array(Path(self.data_source))
        data = self._data

        video = data[idx] if self.is_5d else data
        # torch.tensor copies, so callers can never mutate the cached array.
        return torch.tensor(np.asarray(video), dtype=torch.float32)
class RLEnvironmentDataset(VideoDatasetBase):
    """Dataset for RL environment recordings.

    Loads trajectories stored as:
    - .npz files with 'observations' and 'actions' keys
    - Directory with episode folders

    Observations may be (T, H, W), (T, H, W, C) or (T, C, H, W) with C in
    {1, 3}. Single-channel episodes come back as (num_frames, S, S) and
    three-channel episodes as (num_frames, S, S, 3), where S is
    ``image_size``.

    Usage:
        dataset = RLEnvironmentDataset(
            data_source="/path/to/rl_episodes",
            num_frames=16,
            image_size=64
        )
    """

    def __init__(
        self,
        data_source: Union[str, Path],
        num_frames: int = 16,
        image_size: int = 64,
        transform: Optional[Callable] = None,
        normalize: bool = True,
        obs_key: str = "observations",
    ):
        """Configure the dataset and discover the episode files.

        Args:
            data_source: A single .npz episode or a directory searched
                recursively for .npz files.
            num_frames: Number of frames per video.
            image_size: Target image size (height and width).
            transform: Optional callable applied to each loaded video.
            normalize: Forwarded to the base class (rescale 0-255 to 0-1).
            obs_key: Name of the observation array inside each .npz file.
        """
        self.obs_key = obs_key
        super().__init__(data_source, num_frames, image_size, transform, normalize)

    def _get_video_paths(self) -> Sequence[Union[Path, int]]:
        """Return the sorted episode files described by ``data_source``.

        Raises:
            FileNotFoundError: If ``data_source`` does not exist.
        """
        data_path = Path(cast(Union[str, Path], self.data_source))
        if not data_path.exists():
            # Same error as the folder datasets, not a silently empty dataset.
            raise FileNotFoundError(f"Data path not found: {data_path}")

        # Always defined, so no hasattr() probing is needed later.
        self.single_file = data_path.is_file() and data_path.suffix == ".npz"
        if self.single_file:
            return [data_path]
        return sorted(data_path.rglob("*.npz"))

    @staticmethod
    def _unwrap_observations(observations: Any) -> np.ndarray:
        """Turn the stored observation entry into a plain image array."""
        # A dict saved with np.savez comes back as a 0-d object array.
        if (
            isinstance(observations, np.ndarray)
            and observations.dtype == object
            and observations.ndim == 0
        ):
            observations = observations.item()

        if isinstance(observations, dict):
            for name in ("image", "pixels"):
                if name in observations:
                    return np.asarray(observations[name])
            if not observations:
                raise ValueError("Observation dictionary is empty")
            return np.asarray(next(iter(observations.values())))

        return np.asarray(observations)

    def _load_video(self, idx: int) -> torch.Tensor:
        """Load episode ``idx``, sample or pad to ``num_frames`` and resize.

        Raises:
            ValueError: If the observations are empty or not 3-D/4-D.
        """
        episode_path = cast(Path, self.video_paths[idx])
        # SECURITY: allow_pickle=True can execute arbitrary code while loading
        # a crafted file. Only load episodes from trusted sources.
        with np.load(str(episode_path), allow_pickle=True) as data:
            observations = self._unwrap_observations(data[self.obs_key])

        if observations.ndim == 3:
            # (T, H, W) grayscale: add an explicit channel axis.
            observations = observations[..., np.newaxis]
        if observations.ndim != 4:
            raise ValueError(
                f"Unexpected observation shape {observations.shape} in {episode_path}"
            )
        if observations.shape[0] == 0:
            raise ValueError(f"No observations found in {episode_path}")

        # Bring channel-first data (T, C, H, W) to channel-last. The original
        # tested frame.shape[-1] == 3 after moving channels to the front, so
        # RGB episodes were handed to PIL channel-first and crashed.
        if observations.shape[-1] not in (1, 3) and observations.shape[1] in (1, 3):
            observations = np.transpose(observations, (0, 2, 3, 1))
        if observations.shape[-1] == 1:
            observations = observations[..., 0]

        total_frames = observations.shape[0]
        if total_frames >= self.num_frames:
            indices = np.linspace(0, total_frames - 1, self.num_frames).astype(int)
            observations = observations[indices]
        else:
            # Pad short episodes by repeating the final observation.
            padding = np.repeat(
                observations[-1:], self.num_frames - total_frames, axis=0
            )
            observations = np.concatenate([observations, padding], axis=0)

        if observations.dtype != np.uint8:
            # Floats in [0, 1] would all truncate to 0/1 when cast to uint8.
            if (
                np.issubdtype(observations.dtype, np.floating)
                and observations.size
                and observations.max() <= 1.0
            ):
                observations = observations * 255.0
            observations = np.clip(observations, 0, 255).astype(np.uint8)

        size = (self.image_size, self.image_size)
        processed = [
            np.array(Image.fromarray(np.ascontiguousarray(frame)).resize(size))
            for frame in observations
        ]
        return torch.from_numpy(np.stack(processed)).float()
class HDF5Dataset(VideoDatasetBase):
    """Dataset that loads videos from HDF5 files.

    Supports pre-processed video datasets stored in HDF5 format.

    Expected structure: HDF5 file with 'videos' dataset of shape
    (N, T, H, W, C) or (N, T, C, H, W). A 4-D dataset is read as grayscale
    (N, T, H, W). The layout is guessed from the last axis: a size of at
    most 4 is taken to be the channel axis.

    Usage:
        dataset = HDF5Dataset(
            data_source="/path/to/videos.h5",
            num_frames=16,
            image_size=64
        )
    """

    def __init__(
        self,
        data_source: Union[str, Path],
        num_frames: int = 16,
        image_size: int = 64,
        transform: Optional[Callable] = None,
        normalize: bool = True,
        key: str = "videos",
        memmap: bool = False,
    ):
        """Inspect the file's shape; the file is reopened lazily on first read.

        Args:
            data_source: Path to the HDF5 file.
            num_frames: Number of frames per video.
            image_size: Target image size (height and width).
            transform: Optional callable applied to each loaded video.
            normalize: Forwarded to the base class (rescale 0-255 to 0-1).
            key: Name of the dataset inside the file.
            memmap: Accepted and stored for compatibility. It never changed
                how data is read: the file is opened read-only and one sample
                is read per access either way.

        Raises:
            FileNotFoundError: If the file does not exist.
            KeyError: If ``key`` is not present in the file.
            ValueError: If the dataset is neither 4-D nor 5-D.
        """
        self.key = key
        self.memmap = memmap
        self._h5_file = None

        data_path = Path(data_source)
        if not data_path.exists():
            raise FileNotFoundError(f"HDF5 file not found: {data_path}")

        with h5py.File(data_path, "r") as f:
            if key not in f:
                available_keys = list(f.keys())
                raise KeyError(
                    f"Key '{key}' not found in HDF5. Available: {available_keys}"
                )
            # mypy doesn't understand h5py dataset types; treat as Any.
            shape = tuple(int(dim) for dim in cast(Any, f[key]).shape)

        if len(shape) == 5:
            self.num_samples, self.video_length = shape[0], shape[1]
            if shape[-1] <= 4:
                # A small trailing axis is the channel axis. The original
                # code attached the channel-first label to this case.
                self.data_layout = "NTHWC"
                self.raw_height, self.raw_width, self.channels = shape[2:]
            else:
                self.data_layout = "NTCHW"
                self.channels, self.raw_height, self.raw_width = shape[2:]
        elif len(shape) == 4:
            self.num_samples, self.video_length = shape[0], shape[1]
            self.raw_height, self.raw_width = shape[2], shape[3]
            self.channels = 1
            self.data_layout = "NTHW"
        else:
            raise ValueError(f"Unexpected data shape: {shape}")

        super().__init__(data_source, num_frames, image_size, transform, normalize)

    def __getstate__(self) -> dict:
        """Drop the open file handle so the dataset can be pickled."""
        state = self.__dict__.copy()
        state["_h5_file"] = None
        return state

    def _get_video_paths(self) -> Sequence[Union[Path, int]]:
        """Return integer indices; HDF5 entries have no individual paths."""
        return list(range(self.num_samples))

    def _open_h5(self):
        """Open the file read-only on first use and reuse the handle."""
        if self._h5_file is None:
            self._h5_file = h5py.File(str(self.data_source), "r")
        return self._h5_file

    def _load_video(self, idx: int) -> torch.Tensor:
        """Read sample ``idx``, sample or pad to ``num_frames`` and resize.

        Raises:
            ValueError: If the stored clip has no frames.
        """
        f = self._open_h5()
        video = np.asarray(cast(Any, f[self.key])[idx])

        if self.data_layout == "NTCHW":
            # One sample is (T, C, H, W); move channels last for PIL.
            video = np.transpose(video, (0, 2, 3, 1))
        if video.ndim == 4 and video.shape[-1] == 1:
            video = np.repeat(video, 3, axis=-1)

        total_frames = video.shape[0]
        if total_frames == 0:
            raise ValueError(f"Sample {idx} contains no frames")
        if total_frames >= self.num_frames:
            indices = np.linspace(0, total_frames - 1, self.num_frames).astype(int)
            video = video[indices]
        else:
            # np.repeat pads along the frame axis for 3-D grayscale clips as
            # well; np.tile with four repeats only worked on 4-D clips.
            padding = np.repeat(video[-1:], self.num_frames - total_frames, axis=0)
            video = np.concatenate([video, padding], axis=0)

        # Data that never exceeds 1.0 is assumed to be stored in [0, 1].
        unit_range = bool(video.max() <= 1.0)
        needs_resize = (
            video.shape[1] != self.image_size or video.shape[2] != self.image_size
        )

        if needs_resize:
            if video.dtype != np.uint8:
                if unit_range:
                    video = video * 255
                # PIL needs uint8 here; non-uint8 data above 1.0 used to crash.
                video = np.clip(video, 0, 255).astype(np.uint8)
            size = (self.image_size, self.image_size)
            video = np.stack(
                [
                    np.array(
                        Image.fromarray(np.ascontiguousarray(frame)).resize(
                            size, Image.Resampling.BILINEAR
                        )
                    )
                    for frame in video
                ]
            )
        elif unit_range:
            video = (video * 255).astype(np.uint8)

        return torch.from_numpy(np.ascontiguousarray(video)).float()

    def __len__(self) -> int:
        """Return the number of videos stored in the file."""
        return self.num_samples

    def __del__(self):
        """Close the file handle if one was opened."""
        # getattr: __init__ may have failed before the attribute was set.
        handle = getattr(self, "_h5_file", None)
        if handle is not None:
            handle.close()
            self._h5_file = None
def create_video_dataloader(
    dataset_type: str,
    data_source: Union[str, Path, List[str]],
    num_frames: int = 16,
    image_size: int = 64,
    batch_size: int = 4,
    num_workers: int = 4,
    shuffle: bool = True,
    pin_memory: bool = True,
    **kwargs,
) -> Tuple[Dataset, DataLoader]:
    """Factory function to create video dataloaders.

    Args:
        dataset_type: Type of dataset ("video_folder", "image_folder",
            "numpy", "rl", "hdf5")
        data_source: Path or list of paths to data
        num_frames: Number of frames per video
        image_size: Target image size (height and width)
        batch_size: Batch size for dataloader
        num_workers: Number of workers for data loading
        shuffle: Whether to shuffle data
        pin_memory: Whether to pin memory for faster GPU transfer
        **kwargs: Additional arguments for specific dataset types. They are
            passed to the dataset constructor unchanged, so they must be
            accepted by the chosen dataset class.

    Returns:
        Tuple of (dataset, dataloader)

    Raises:
        ValueError: If ``dataset_type`` is not a known dataset type.

    Usage:
        dataset, loader = create_video_dataloader(
            dataset_type="video_folder",
            data_source="/path/to/videos",
            num_frames=16,
            image_size=64,
            batch_size=4
        )
    """
    dataset_classes = {
        "video_folder": VideoFolderDataset,
        "image_folder": ImageFolderDataset,
        "numpy": NumPyDataset,
        "rl": RLEnvironmentDataset,
        # Defined in this module but previously unreachable via the factory.
        "hdf5": HDF5Dataset,
    }

    try:
        dataset_class = dataset_classes[dataset_type]
    except KeyError:
        raise ValueError(
            f"Unknown dataset type: {dataset_type}. Available: {list(dataset_classes.keys())}"
        ) from None

    dataset = dataset_class(
        data_source=data_source,
        num_frames=num_frames,
        image_size=image_size,
        **kwargs,
    )

    num_samples = len(dataset)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        # Shuffling is switched off for an empty dataset.
        shuffle=shuffle and num_samples > 0,
        num_workers=num_workers,
        pin_memory=pin_memory,
        # Incomplete batches are dropped only when shuffling and at least one
        # full batch exists.
        drop_last=shuffle and num_samples >= batch_size,
    )

    logger.info(
        "Created %s dataloader with %d samples, batch_size=%s",
        dataset_type,
        num_samples,
        batch_size,
    )
    return dataset, loader
@dataclass
class VideoDatasetConfig(DatasetConfig):
    """Configuration for video datasets.

    Adds the dataset selection and the dataset-specific options to the
    loader settings inherited from ``DatasetConfig``.

    Attributes:
        dataset_type: Name of the dataset type to build.
        data_source: Path to the data.
        extensions: Video file suffixes to look for.
        recursive: Whether sub-directories are searched.
        obs_key: Name of the observation array in RL recordings.
    """

    # What to load and from where
    dataset_type: str = "video_folder"
    data_source: str = ""

    # Folder-scanning options
    extensions: Tuple[str, ...] = (".mp4", ".avi", ".mkv")
    recursive: bool = True

    # RL-recording option
    obs_key: str = "observations"
def create_video_dataset_from_config(
    config: VideoDatasetConfig,
) -> Tuple[Dataset, DataLoader]:
    """Create video dataset and dataloader from config.

    Only the options that the selected dataset type understands are
    forwarded. Previously ``extensions``, ``recursive`` and ``obs_key`` were
    sent to every dataset class, and each class rejected at least one of
    them with a ``TypeError``.

    Args:
        config: Dataset and loader settings.

    Returns:
        Tuple of (dataset, dataloader)
    """
    video_default_extensions = (".mp4", ".avi", ".mkv")
    extensions = getattr(config, "extensions", video_default_extensions)

    dataset_options = {}
    if config.dataset_type == "video_folder":
        dataset_options["extensions"] = extensions
        dataset_options["recursive"] = getattr(config, "recursive", True)
    elif config.dataset_type == "image_folder":
        # The config default lists video containers, which would match no
        # images. Forward the list only when the caller changed it.
        if tuple(extensions) != video_default_extensions:
            dataset_options["extensions"] = extensions
    elif config.dataset_type == "rl":
        dataset_options["obs_key"] = getattr(config, "obs_key", "observations")

    return create_video_dataloader(
        dataset_type=config.dataset_type,
        data_source=config.data_source,
        num_frames=config.num_frames,
        image_size=config.image_size,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        shuffle=config.shuffle,
        pin_memory=config.pin_memory,
        **dataset_options,
    )