Source code for world_models.datasets.imagenet1k

import os
import subprocess
import time

import numpy as np

from logging import getLogger

import torch
import torchvision
from torch.utils.data import random_split

logger = getLogger()


def make_imagenet1k(
    transform,
    batch_size,
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    image_folder=None,
    training=True,
    copy_data=False,
    drop_last=True,
    subset_file=None,
):
    """Build an ImageNet-1K dataset and dataloader with distributed sampling support.

    This helper optionally restricts data to a subset file and returns the
    `(dataset, dataloader, sampler)` tuple used by training scripts.
    """
    dataset = ImageNet(
        root=root_path,
        image_folder=image_folder,
        transform=transform,
        train=training,
        copy_data=copy_data,
        index_targets=False,
    )
    if subset_file is not None:
        dataset = ImageNetSubset(dataset, subset_file)
    logger.info("ImageNet dataset created")
    dist_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset=dataset, num_replicas=world_size, rank=rank
    )
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=collator,
        sampler=dist_sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        pin_memory=pin_mem,
        num_workers=num_workers,
        persistent_workers=False,
    )
    logger.info("ImageNet unsupervised data loader created")
    return dataset, data_loader, dist_sampler
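
# A minimal single-process usage sketch (the paths are placeholders, not part of
# this module). `DistributedSampler` works here without `torch.distributed`
# being initialized because `num_replicas` and `rank` are passed explicitly:
#
#   import torchvision.transforms as T
#
#   dataset, loader, sampler = make_imagenet1k(
#       transform=T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()]),
#       batch_size=64,
#       root_path="/datasets",                      # assumed mount point
#       image_folder="imagenet_full_size/061417/",
#       training=True,
#       copy_data=False,                            # read directly from network storage
#   )
#   for epoch in range(2):
#       sampler.set_epoch(epoch)  # reshuffle per epoch under distributed sampling
#       for imgs, labels in loader:
#           ...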


class ImageNet(torchvision.datasets.ImageFolder):
    """ImageNet dataset wrapper with optional local copy/extract workflow.

    The class extends `torchvision.datasets.ImageFolder` and can stage data
    from shared storage into local scratch space for faster multi-process
    training on cluster environments.
    """

    def __init__(
        self,
        root,
        image_folder="imagenet_full_size/061417/",
        tar_file="imagenet_full_size-061417.tar.gz",
        transform=None,
        train=True,
        job_id=None,
        local_rank=None,
        copy_data=True,
        index_targets=False,
    ):
        """
        ImageNet Dataset wrapper (can copy data locally to machine)

        :param root: root network directory for ImageNet data
        :param image_folder: path to images inside root network directory
        :param tar_file: zipped image_folder inside root network directory
        :param train: whether to load train data (or validation)
        :param job_id: scheduler job-id used to create dir on local machine
        :param copy_data: whether to copy data from network file locally
        :param index_targets: whether to index the id of each labeled image
        """
        suffix = "train/" if train else "val/"
        data_path = None
        if copy_data:
            logger.info("copying data locally")
            data_path = copy_imgnt_locally(
                root=root,
                suffix=suffix,
                image_folder=image_folder,
                tar_file=tar_file,
                job_id=job_id,
                local_rank=local_rank,
            )
        if (not copy_data) or (data_path is None):
            data_path = os.path.join(root, image_folder, suffix)
        logger.info(f"data-path {data_path}")

        super(ImageNet, self).__init__(root=data_path, transform=transform)
        logger.info("Initialized ImageNet")

        if index_targets:
            self.targets = []
            for sample in self.samples:
                self.targets.append(sample[1])
            self.targets = np.array(self.targets)
            self.samples = np.array(self.samples)

            mint = None
            self.target_indices = []
            for t in range(len(self.classes)):
                indices = np.squeeze(np.argwhere(self.targets == t)).tolist()
                self.target_indices.append(indices)
                mint = len(indices) if mint is None else min(mint, len(indices))
                logger.debug(f"num-labeled target {t} {len(indices)}")
            logger.info(f"min. labeled indices {mint}")
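
# What `index_targets=True` enables, as a hypothetical sketch (the root path is
# a placeholder): per-class index lists, useful e.g. for class-balanced sampling.
#
#   ds = ImageNet(root="/datasets", train=True, copy_data=False, index_targets=True)
#   n_per_class = [len(ix) for ix in ds.target_indices]  # labeled images per class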


class ImageNetSubset(object):
    """View over an `ImageNet` dataset filtered by an explicit image-id list.

    The subset file contains target image names; only matching samples are
    kept while preserving transforms and label mapping from the base dataset.
    """

    def __init__(self, dataset, subset_file):
        """
        ImageNetSubset

        :param dataset: ImageNet dataset object
        :param subset_file: '.txt' file containing IDs of IN1K images to keep
        """
        self.dataset = dataset
        self.subset_file = subset_file
        self.filter_dataset_(subset_file)

    def filter_dataset_(self, subset_file):
        """Filter self.dataset to a subset"""
        root = self.dataset.root
        class_to_idx = self.dataset.class_to_idx
        # -- update samples to subset of IN1k targets/samples
        new_samples = []
        logger.info(f"Using {subset_file}")
        with open(subset_file, "r") as rfile:
            for line in rfile:
                class_name = line.split("_")[0]
                target = class_to_idx[class_name]
                img = line.split("\n")[0]
                new_samples.append((os.path.join(root, class_name, img), target))
        self.samples = new_samples

    @property
    def classes(self):
        return self.dataset.classes

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, target = self.samples[index]
        img = self.dataset.loader(path)
        if self.dataset.transform is not None:
            img = self.dataset.transform(img)
        if self.dataset.target_transform is not None:
            target = self.dataset.target_transform(target)
        return img, target
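
# Expected subset-file format, inferred from `filter_dataset_` (one image id per
# line, with the class synset as the prefix before the first underscore), e.g.:
#
#   n01440764_10026.JPEG
#   n01440764_10027.JPEG
#
# A hypothetical usage sketch (placeholder paths):
#
#   base = ImageNet(root="/datasets", train=True, copy_data=False)
#   subset = ImageNetSubset(base, "imagenet_subsets/1percent.txt")
#   img, target = subset[0]  # transforms and label mapping come from `base`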


def copy_imgnt_locally(
    root,
    suffix,
    image_folder="imagenet_full_size/061417/",
    tar_file="imagenet_full_size-061417.tar.gz",
    job_id=None,
    local_rank=None,
):
    """Copy and extract ImageNet archives to per-job local scratch storage.

    In SLURM environments this reduces network filesystem pressure by
    unpacking once per job and synchronizing worker processes with a signal
    file.
    """
    if job_id is None:
        try:
            job_id = os.environ["SLURM_JOBID"]
        except Exception:
            logger.info("No job-id, will load directly from network file")
            return None

    if local_rank is None:
        try:
            local_rank = int(os.environ["SLURM_LOCALID"])
        except Exception:
            logger.info("No local-rank, will load directly from network file")
            return None

    source_file = os.path.join(root, tar_file)
    target = f"/scratch/slurm_tmpdir/{job_id}/"
    target_file = os.path.join(target, tar_file)
    data_path = os.path.join(target, image_folder, suffix)
    logger.info(f"{source_file}\n{target}\n{target_file}\n{data_path}")

    tmp_sgnl_file = os.path.join(target, "copy_signal.txt")

    if not os.path.exists(data_path):
        if local_rank == 0:
            commands = [["tar", "-xf", source_file, "-C", target]]
            for cmnd in commands:
                start_time = time.time()
                logger.info(f"Executing {cmnd}")
                subprocess.run(cmnd)
                logger.info(f"Cmnd took {(time.time() - start_time) / 60.} min.")
            with open(tmp_sgnl_file, "w") as f:
                print("Done copying locally.", file=f)
        else:
            while not os.path.exists(tmp_sgnl_file):
                time.sleep(60)
                logger.info(f"{local_rank}: Checking {tmp_sgnl_file}")

    return data_path
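
# Under SLURM, `copy_imgnt_locally` infers its arguments from the environment:
# SLURM_JOBID selects the per-job scratch directory and SLURM_LOCALID decides
# which process extracts the tar (local rank 0) while the others poll the
# signal file. Outside SLURM it returns None and `ImageNet.__init__` falls back
# to the network path. A hypothetical manual invocation (placeholder values;
# assumes /scratch/slurm_tmpdir/<job_id>/ exists on the node):
#
#   data_path = copy_imgnt_locally(
#       root="/datasets",
#       suffix="train/",
#       job_id="12345678",   # assumed scheduler job id
#       local_rank=0,
#   )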


def make_imagefolder(
    transform,
    batch_size,
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    image_folder=None,
    drop_last=True,
    val_split: float | None = None,
):
    """Create an ImageFolder dataset loader for custom folder-structured datasets.

    Supports an optional train/validation split and distributed sampling,
    making it a drop-in replacement for the ImageNet loaders in training
    scripts.
    """
    dataset = torchvision.datasets.ImageFolder(
        root=os.path.join(root_path, image_folder) if image_folder else root_path,
        transform=transform,
    )
    if val_split:
        val_size = int(len(dataset) * val_split)
        train_size = len(dataset) - val_size
        dataset, _ = random_split(dataset, [train_size, val_size])
    dist_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset=dataset, num_replicas=world_size, rank=rank
    )
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=collator,
        sampler=dist_sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        pin_memory=pin_mem,
        num_workers=num_workers,
        persistent_workers=False,
    )
    logger.info("ImageFolder data loader created")
    return dataset, data_loader, dist_sampler
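
# Hypothetical usage sketch for a custom class-per-folder dataset, holding out
# 10% of the images for validation. Note that `random_split` is unseeded here,
# so the split differs between runs (and between ranks) unless callers fix the
# global torch seed beforehand:
#
#   dataset, loader, sampler = make_imagefolder(
#       transform=torchvision.transforms.ToTensor(),
#       batch_size=32,
#       root_path="/data/my_dataset",   # assumed folder of class subdirectories
#       val_split=0.1,
#   )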