Source code for world_models.datasets.tinyworlds

"""
TinyWorlds Dataset Loaders

Loads pre-processed video datasets from HuggingFace for training Genie-style world models.
Based on: https://github.com/AlmondGod/tinyworlds

Available datasets:
- PICO_DOOM: Minimal Doom gameplay
- PONG: Classic Pong
- ZELDA: Zelda Ocarina of Time (2D)
- POLE_POSITION: Racing game
- SONIC: Sonic the Hedgehog
"""

import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from typing import Optional, Tuple, Dict, List, Any
from dataclasses import dataclass
import logging
import os
from pathlib import Path
import hashlib
import urllib.request
import zipfile
import shutil

logger = logging.getLogger(__name__)

# h5py is an optional dependency: when it cannot be imported the name is bound
# to None (so later code can test ``h5py is None``) and a warning is logged.
try:
    import h5py
except ImportError:
    h5py = None
    logger.warning("h5py not installed. Install with: pip install h5py")

# Bind both names to None up front so they exist at module level even when the
# huggingface_hub import below fails.
hf_hub_download: Any = None
list_repo_files: Any = None
try:
    from huggingface_hub import hf_hub_download, list_repo_files  # type: ignore

    # Flag recording that the huggingface_hub import succeeded.
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False
    logger.warning(
        "huggingface_hub not installed. Install with: pip install huggingface_hub"
    )


@dataclass
class TinyWorldsConfig:
    """Settings bundle for a TinyWorlds dataset / dataloader.

    Every field has a default, so ``TinyWorldsConfig()`` is a valid
    configuration. The field names mirror the keyword arguments accepted by
    the dataset and dataloader factories in this module.
    """

    # Name of the game dataset to use.
    dataset_name: str = "SONIC"
    # Number of frames per video sequence.
    num_frames: int = 16
    # Target (square) image size in pixels.
    image_size: int = 64
    # Number of sequences per batch.
    batch_size: int = 4
    # Number of data loading worker processes.
    num_workers: int = 4
    # Directory for cached dataset files; None leaves the choice to the consumer.
    cache_dir: Optional[str] = None
    # Name of the data split.
    split: str = "train"
class TinyWorldsDataset(Dataset):
    """Dataset for TinyWorlds game video data.

    Loads pre-processed frames from an HDF5 file, either a local file or one
    downloaded from the HuggingFace datasets repository. The file must
    contain a ``"videos"`` or ``"frames"`` dataset of rank 5 (treated as
    ``N, T, H, W, C``) or rank 4 (treated as ``N, T, H, W``, single channel).

    Each item is a float tensor of shape ``(C, num_frames, H, W)`` with values
    divided by 255. Single-channel clips are repeated to 3 channels, and
    frames are resized to ``image_size x image_size`` when they differ.

    The HDF5 handle is opened lazily on first item access and is dropped when
    the dataset is pickled, so each ``DataLoader`` worker process opens its
    own handle instead of inheriting one from the parent.
    """

    DATASET_CONFIGS = {
        "PICO_DOOM": {
            "repo_id": "AlmondGod/tinyworlds",
            "filename": "picodoom_frames.h5",
            "description": "Minimal Doom gameplay",
        },
        "PONG": {
            "repo_id": "AlmondGod/tinyworlds",
            "filename": "pong_frames.h5",
            "description": "Classic Pong",
        },
        "ZELDA": {
            "repo_id": "AlmondGod/tinyworlds",
            "filename": "zelda_frames.h5",
            "description": "Zelda Ocarina of Time (2D)",
        },
        "POLE_POSITION": {
            "repo_id": "AlmondGod/tinyworlds",
            "filename": "pole_position_frames.h5",
            "description": "Racing game",
        },
        "SONIC": {
            "repo_id": "AlmondGod/tinyworlds",
            "filename": "sonic_frames.h5",
            "description": "Sonic the Hedgehog",
        },
    }

    def __init__(
        self,
        dataset_name: str = "SONIC",
        num_frames: int = 16,
        image_size: int = 64,
        split: str = "train",
        cache_dir: Optional[str] = None,
        download: bool = True,
        data_file: Optional[str] = None,
    ):
        """Validate the dataset name and load the dataset metadata.

        Args:
            dataset_name: Key of ``DATASET_CONFIGS``.
            num_frames: Number of frames returned per item.
            image_size: Target height and width of returned frames.
            split: Stored on the instance; not otherwise used by this class.
            cache_dir: Directory holding dataset files. Defaults to
                ``~/.cache/tinyworlds`` (created if missing).
            download: When no ``data_file`` is given, download the file if it
                is not already in ``cache_dir``.
            data_file: Explicit path to an HDF5 file; takes precedence over
                ``download`` and ``cache_dir`` lookup.

        Raises:
            ValueError: If ``dataset_name`` is unknown or the data has an
                unexpected shape.
            FileNotFoundError: If the dataset file does not exist.
            RuntimeError: If a required package is missing or the download
                fails.
        """
        super().__init__()
        # Bound before any validation so __del__ always finds the attribute,
        # even when construction fails part-way through.
        self._data_file: Optional[Any] = None
        self._data_path: Optional[Path] = None

        if dataset_name not in self.DATASET_CONFIGS:
            raise ValueError(
                f"Unknown dataset: {dataset_name}. "
                f"Available: {list(self.DATASET_CONFIGS.keys())}"
            )

        self.dataset_name = dataset_name
        self.config = self.DATASET_CONFIGS[dataset_name]
        self.num_frames = num_frames
        self.image_size = image_size
        self.split = split
        self.cache_dir = cache_dir or self._get_default_cache_dir()
        self.data_file = data_file

        # Metadata placeholders; filled in by _load_data().
        self.num_samples = 0
        self.video_length = 0
        self.raw_height = 0
        self.raw_width = 0
        self.channels = 0
        self.data_layout: Optional[str] = None

        if data_file:
            self._load_data(Path(data_file))
        elif download:
            self._download_or_load_data()
        else:
            self._load_data()

    def _get_default_cache_dir(self) -> str:
        """Return ``~/.cache/tinyworlds``, creating the directory if needed."""
        cache = os.path.expanduser("~/.cache/tinyworlds")
        os.makedirs(cache, exist_ok=True)
        return cache

    def _get_local_path(self) -> Path:
        """Return the expected location of this dataset's file in the cache dir."""
        return Path(self.cache_dir) / self.config["filename"]

    def _download_or_load_data(self) -> None:
        """Load the cached dataset file, downloading it first if it is absent.

        Raises:
            RuntimeError: If ``huggingface_hub`` is not installed, or if the
                download or the subsequent load fails (the original error is
                chained as ``__cause__``).
        """
        local_path = self._get_local_path()

        if local_path.exists():
            logger.info(f"Found cached dataset at {local_path}")
            self._load_data()
            return

        if not HF_AVAILABLE:
            raise RuntimeError(
                "huggingface_hub not installed. Cannot download datasets.\n"
                "Install with: pip install huggingface_hub"
            )

        logger.info(f"Downloading {self.dataset_name} dataset from HuggingFace...")
        logger.info("This may take several minutes depending on your connection.")

        try:
            downloaded_path = hf_hub_download(
                repo_id=self.config["repo_id"],
                filename=self.config["filename"],
                repo_type="dataset",
                cache_dir=self.cache_dir,
            )

            downloaded_file = Path(downloaded_path)
            if not downloaded_file.exists():
                raise FileNotFoundError(f"Downloaded file not found: {downloaded_path}")

            # Copy the downloaded file to the flat path that
            # _get_local_path() checks on later runs.
            if downloaded_file != local_path and not local_path.exists():
                local_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(downloaded_file, local_path)

            logger.info(f"Downloaded to: {local_path}")
            self._load_data()

        except Exception as e:
            raise RuntimeError(
                f"Failed to download {self.dataset_name}: {e}\n"
                "You can manually download from: "
                f"https://huggingface.co/datasets/{self.config['repo_id']}/tree/main"
            ) from e

    @staticmethod
    def _select_dataset(handle: Any) -> Any:
        """Return the ``"videos"`` entry of *handle*, else its ``"frames"`` entry.

        Raises:
            KeyError: If neither key is present.
        """
        for key in ("videos", "frames"):
            if key in handle:
                return handle[key]
        available_keys = list(handle.keys())
        raise KeyError(
            f"No 'videos' or 'frames' key found. Available: {available_keys}"
        )

    def _load_data(self, file_path: Optional[Path] = None) -> None:
        """Read shape metadata from the HDF5 file and remember its path.

        The file is closed again before returning; item access reopens it
        lazily (see ``_ensure_open``).

        Args:
            file_path: File to read; defaults to the cache location.

        Raises:
            RuntimeError: If h5py is not installed.
            FileNotFoundError: If the file does not exist.
            KeyError: If the file has neither a 'videos' nor a 'frames' entry.
            ValueError: If the data shape is unavailable or not rank 4 or 5.
        """
        if h5py is None:
            raise RuntimeError(
                "h5py not installed. Cannot load dataset files. "
                "Install with: pip install h5py"
            )

        local_path = Path(file_path) if file_path is not None else self._get_local_path()

        if not local_path.exists():
            raise FileNotFoundError(f"Dataset file not found: {local_path}")

        with h5py.File(local_path, "r") as f:
            data = self._select_dataset(f)
            # h5py dataset objects are dynamically typed; access shape via getattr
            shape = getattr(data, "shape", None)

        if not isinstance(shape, tuple):
            raise ValueError(
                f"Unable to determine data shape for dataset: {self.dataset_name}"
            )

        if len(shape) == 5:
            self.channels = shape[4]
            self.data_layout = "NTHWC"
        elif len(shape) == 4:
            self.channels = 1
            self.data_layout = "NTHW"
        else:
            raise ValueError(f"Unexpected data shape: {shape}")

        self.num_samples = shape[0]
        self.video_length = shape[1]
        self.raw_height = shape[2]
        self.raw_width = shape[3]
        self._data_path = local_path

        logger.info(
            f"Loaded {self.dataset_name}: {self.num_samples} videos, "
            f"{self.video_length} frames each, {self.raw_height}x{self.raw_width}"
        )

    def _ensure_open(self) -> Any:
        """Return the open file handle, opening it on first use.

        Raises:
            RuntimeError: If no dataset file has been loaded.
        """
        handle = getattr(self, "_data_file", None)
        if handle is None:
            path = getattr(self, "_data_path", None)
            if path is None:
                raise RuntimeError("Data file is not loaded")
            handle = h5py.File(path, "r")
            self._data_file = handle
        return handle

    def __getstate__(self) -> Dict[str, Any]:
        """Return picklable state: the instance dict without the open handle."""
        state = self.__dict__.copy()
        state["_data_file"] = None
        return state

    def __len__(self) -> int:
        """Return the number of videos in the dataset."""
        return self.num_samples

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return clip *idx* as a float tensor of shape ``(C, num_frames, H, W)``.

        Longer clips are subsampled at evenly spaced frame indices; shorter
        clips are padded by repeating their last frame.

        Raises:
            RuntimeError: If no dataset file has been loaded.
            ValueError: If the stored data layout is not recognised.
        """
        handle = self._ensure_open()
        clip = np.asarray(self._select_dataset(handle)[idx])

        if self.data_layout == "NTHW":
            clip = clip[..., np.newaxis]  # add a trailing channel axis
        elif self.data_layout != "NTHWC":
            raise ValueError(f"Unknown data layout: {self.data_layout}")

        total_frames = clip.shape[0]
        if total_frames >= self.num_frames:
            indices = np.linspace(0, total_frames - 1, self.num_frames).astype(int)
            clip = clip[indices]
        else:
            padding = np.tile(clip[-1:], (self.num_frames - total_frames, 1, 1, 1))
            clip = np.concatenate([clip, padding], axis=0)

        # Data whose maximum is <= 1.0 is treated as normalised and scaled up
        # to 0..255; everything is then quantised to uint8.
        if clip.max() <= 1.0:
            clip = (clip * 255).astype(np.uint8)
        else:
            clip = clip.astype(np.uint8)

        video = torch.from_numpy(clip).float()  # (T, H, W, C)
        if video.shape[-1] == 1:
            video = video.expand(-1, -1, -1, 3)  # repeat the single channel

        # Move channels to the front with a real permutation. A plain reshape
        # from (T, C, H, W) to (C, T, H, W) keeps the memory order and would
        # therefore interleave frames and channels.
        video = video.permute(3, 0, 1, 2).contiguous()  # (C, T, H, W)

        if self.image_size is not None and (
            video.shape[2] != self.image_size or video.shape[3] != self.image_size
        ):
            resize_transform = transforms.Resize((self.image_size, self.image_size))
            video = resize_transform(video)

        return video / 255.0

    def __del__(self):
        """Close the file handle, if one is open (best effort)."""
        # getattr: the attribute may be missing on partially built instances.
        handle = getattr(self, "_data_file", None)
        if handle is None:
            return
        try:
            handle.close()
        except Exception:
            # Finalizers must not raise; the handle may already be invalid.
            pass

    def get_info(self) -> Dict:
        """Return dataset information."""
        return {
            "name": self.dataset_name,
            "description": self.config["description"],
            "num_samples": self.num_samples,
            "video_length": self.video_length,
            "raw_resolution": f"{self.raw_height}x{self.raw_width}",
            "channels": self.channels,
        }
class TinyWorldsDataLoader:
    """Factory class for creating TinyWorlds dataloaders."""

    # Snapshot of the known dataset names, taken when the class is defined.
    DATASET_NAMES = list(TinyWorldsDataset.DATASET_CONFIGS.keys())

    @staticmethod
    def create_dataloader(
        dataset_name: str = "SONIC",
        num_frames: int = 16,
        image_size: int = 64,
        batch_size: int = 4,
        num_workers: int = 4,
        shuffle: bool = True,
        cache_dir: Optional[str] = None,
        download: bool = True,
        data_file: Optional[str] = None,
    ) -> Tuple[TinyWorldsDataset, DataLoader]:
        """Create a dataloader for TinyWorlds dataset.

        Args:
            dataset_name: Name of the game dataset (PICO_DOOM, PONG, ZELDA, POLE_POSITION, SONIC)
            num_frames: Number of frames per video sequence
            image_size: Target image size (will resize frames)
            batch_size: Batch size
            num_workers: Number of data loading workers
            shuffle: Whether to shuffle the data. The same flag is also passed
                as ``drop_last``, so a trailing partial batch is dropped
                exactly when shuffling is enabled.
            cache_dir: Directory to cache downloaded datasets
            download: Whether to download if not cached
            data_file: Optional path to a local dataset file, forwarded to
                ``TinyWorldsDataset``

        Returns:
            Tuple of (dataset, dataloader)

        Usage:
            dataset, loader = TinyWorldsDataLoader.create_dataloader(
                dataset_name="SONIC",
                num_frames=16,
                image_size=64,
                batch_size=4
            )
        """
        dataset = TinyWorldsDataset(
            dataset_name=dataset_name,
            num_frames=num_frames,
            image_size=image_size,
            cache_dir=cache_dir,
            download=download,
            data_file=data_file,
        )

        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=False,
            drop_last=shuffle,
        )

        logger.info(
            f"Created {dataset_name} dataloader: {len(dataset)} samples, "
            f"{len(loader)} batches, batch_size={batch_size}"
        )

        return dataset, loader

    @staticmethod
    def list_available_datasets() -> List[str]:
        """List all available dataset names.

        Returns a new list on every call so callers cannot modify the shared
        ``DATASET_NAMES`` class attribute by accident.
        """
        return list(TinyWorldsDataLoader.DATASET_NAMES)

    @staticmethod
    def get_dataset_info(dataset_name: str) -> Dict:
        """Get information about a specific dataset without downloading.

        Returns a copy of the registry entry (``repo_id``, ``filename``,
        ``description``), so editing the result leaves
        ``TinyWorldsDataset.DATASET_CONFIGS`` untouched.

        Raises:
            ValueError: If ``dataset_name`` is not a known dataset.
        """
        if dataset_name not in TinyWorldsDataset.DATASET_CONFIGS:
            raise ValueError(
                f"Unknown dataset: {dataset_name}. "
                f"Available: {list(TinyWorldsDataset.DATASET_CONFIGS.keys())}"
            )
        return dict(TinyWorldsDataset.DATASET_CONFIGS[dataset_name])
def create_tinyworlds_dataloader(
    dataset_name: str = "SONIC",
    num_frames: int = 16,
    image_size: int = 64,
    batch_size: int = 4,
    num_workers: int = 4,
    shuffle: bool = True,
    cache_dir: Optional[str] = None,
    download: bool = True,
    data_file: Optional[str] = None,
) -> Tuple[TinyWorldsDataset, DataLoader]:
    """Build a TinyWorlds dataset together with its dataloader.

    Module-level convenience wrapper: every argument is forwarded unchanged
    to ``TinyWorldsDataLoader.create_dataloader``.

    Args:
        dataset_name: Name of the game dataset (PICO_DOOM, PONG, ZELDA, POLE_POSITION, SONIC)
        num_frames: Number of frames per video sequence
        image_size: Target image size
        batch_size: Batch size
        num_workers: Number of data loading workers
        shuffle: Whether to shuffle
        cache_dir: Cache directory for datasets
        download: Download if not cached
        data_file: Optional path to a local dataset file

    Returns:
        Tuple of (dataset, dataloader)

    Usage:
        dataset, loader = create_tinyworlds_dataloader(
            dataset_name="SONIC",
            num_frames=16,
            batch_size=4
        )
        for batch in loader:
            # batch shape: (B, C, T, H, W)
            ...
    """
    loader_options = dict(
        dataset_name=dataset_name,
        num_frames=num_frames,
        image_size=image_size,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        cache_dir=cache_dir,
        download=download,
        data_file=data_file,
    )
    return TinyWorldsDataLoader.create_dataloader(**loader_options)
def download_all_datasets(cache_dir: Optional[str] = None) -> Dict[str, Optional[str]]:
    """Fetch every known TinyWorlds dataset into one cache directory.

    Each dataset is handled independently: a failure is logged and recorded
    as ``None`` in the result, and the remaining datasets are still attempted.

    Args:
        cache_dir: Directory to cache downloaded datasets; defaults to
            ``~/.cache/tinyworlds``.

    Returns:
        Dictionary mapping each dataset name to its local file path, or to
        ``None`` when that dataset could not be downloaded or loaded.
    """
    target_dir = cache_dir or os.path.expanduser("~/.cache/tinyworlds")
    paths_by_name: Dict[str, Optional[str]] = {}

    for name in TinyWorldsDataLoader.list_available_datasets():
        logger.info(f"Downloading {name}...")
        try:
            # Building the dataset is what triggers the download; the loader
            # half of the returned pair is not needed here.
            built = create_tinyworlds_dataloader(
                dataset_name=name,
                download=True,
                cache_dir=target_dir,
            )
            dataset = built[0]
            local_file = dataset._get_local_path()
            paths_by_name[name] = str(local_file)
            logger.info(f" -> {name} downloaded successfully")
        except Exception as e:
            logger.error(f" -> {name} failed: {e}")
            paths_by_name[name] = None

    return paths_by_name