Source code for world_models.datasets.cifar10

import torch
from torchvision.datasets import CIFAR10
from logging import getLogger

logger = getLogger()


[docs] def make_cifar10( transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, drop_last=True, train=True, download=False, # new ): """Create CIFAR-10 dataset objects and a distributed-capable dataloader. Returns the dataset, sampler, and loader configured with the provided transform/collator so callers can plug the loader directly into JEPA or diffusion training loops. """ dataset = CIFAR10( root=root_path, train=train, download=download, transform=transform, ) dist_sampler = torch.utils.data.distributed.DistributedSampler( dataset=dataset, num_replicas=world_size, rank=rank ) data_loader = torch.utils.data.DataLoader( dataset, collate_fn=collator, sampler=dist_sampler, batch_size=batch_size, drop_last=drop_last, pin_memory=pin_mem, num_workers=num_workers, persistent_workers=False, ) logger.info("CIFAR10 data loader created") return dataset, data_loader, dist_sampler