Source code for world_models.datasets.cifar10

import torch
from torchvision.datasets import CIFAR10
from logging import getLogger

# Module-level logger shared by the factory function(s) in this file.
# NOTE: getLogger() is called without a name, which per the stdlib logging
# docs returns the root logger rather than a module-specific one.
logger = getLogger()


def make_cifar10(
    transform,
    batch_size,
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    drop_last=True,
    train=True,
    download=False,
):
    """Build a CIFAR-10 dataset together with a distributed dataloader.

    The function wires three objects together: a torchvision ``CIFAR10``
    dataset, a ``DistributedSampler`` that partitions it across
    ``world_size`` replicas, and a ``DataLoader`` that draws from that
    sampler. All three are handed back so the caller keeps access to the
    sampler and the underlying dataset as well as the loader.

    Args:
        transform: Transform passed to ``CIFAR10`` and applied to each image.
        batch_size (int): Number of samples per batch.
        collator (callable, optional): Forwarded to the loader as
            ``collate_fn``; ``None`` leaves the loader's default collation
            in place.
        pin_mem (bool): Forwarded to the loader as ``pin_memory``
            (default: True).
        num_workers (int): Number of loader worker processes (default: 8).
        world_size (int): Forwarded to the sampler as ``num_replicas``
            (default: 1).
        rank (int): Rank of the current process, forwarded to the sampler
            (default: 0).
        root_path (str, optional): Forwarded to ``CIFAR10`` as ``root``,
            i.e. the directory the data is stored in / loaded from.
        drop_last (bool): Forwarded to the loader; whether an incomplete
            final batch is dropped (default: True).
        train (bool): Selects the train split when True, otherwise the test
            split (default: True).
        download (bool): Forwarded to ``CIFAR10``; whether to download the
            data if it is missing (default: False).

    Returns:
        tuple: ``(dataset, data_loader, dist_sampler)`` in that order:
            the ``CIFAR10`` instance, the ``torch.utils.data.DataLoader``
            built on it, and the ``DistributedSampler`` the loader uses.

    Note:
        The sampler is created with only ``dataset``, ``num_replicas`` and
        ``rank``; its shuffle/seed behaviour is left at the library
        defaults. ``persistent_workers`` is always ``False``.

    Example:
        >>> transform = make_transforms(crop_size=224)
        >>> dataset, loader, sampler = make_cifar10(
        ...     transform=transform,
        ...     batch_size=256,
        ...     root_path="./data",
        ...     download=True,
        ... )
    """
    cifar = CIFAR10(
        root=root_path,
        train=train,
        download=download,
        transform=transform,
    )

    # Explicit num_replicas/rank are supplied, so the sampler does not have
    # to look them up itself.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset=cifar,
        num_replicas=world_size,
        rank=rank,
    )

    # Gather the loader configuration in one place before constructing it.
    loader_options = {
        "collate_fn": collator,
        "sampler": sampler,
        "batch_size": batch_size,
        "drop_last": drop_last,
        "pin_memory": pin_mem,
        "num_workers": num_workers,
        "persistent_workers": False,
    }
    loader = torch.utils.data.DataLoader(cifar, **loader_options)

    logger.info("CIFAR10 data loader created")
    return cifar, loader, sampler