import os
import subprocess
import time
from logging import getLogger

import numpy as np
import torch
import torchvision
from torch.utils.data import random_split

logger = getLogger()
def make_imagenet1k(
    transform,
    batch_size,
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    image_folder=None,
    training=True,
    copy_data=False,
    drop_last=True,
    subset_file=None,
):
    """Build an ImageNet-1K dataset and dataloader with distributed sampling support.

    This helper optionally restricts the dataset to the images listed in a
    subset file and returns the `(dataset, dataloader, sampler)` tuple used by
    training scripts.
    """
    dataset = ImageNet(
        root=root_path,
        image_folder=image_folder,
        transform=transform,
        train=training,
        copy_data=copy_data,
        index_targets=False,
    )
    if subset_file is not None:
        dataset = ImageNetSubset(dataset, subset_file)
    logger.info("ImageNet dataset created")
    dist_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset=dataset, num_replicas=world_size, rank=rank
    )
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=collator,
        sampler=dist_sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        pin_memory=pin_mem,
        num_workers=num_workers,
        persistent_workers=False,
    )
    logger.info("ImageNet unsupervised data loader created")
    return dataset, data_loader, dist_sampler
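
# A minimal usage sketch for `make_imagenet1k`, assuming ImageNet data at a
# hypothetical "/datasets" root; the transform and paths are illustrative
# assumptions, not values this module defines.
def _example_make_imagenet1k():
    import torchvision.transforms as T

    transform = T.Compose([
        T.RandomResizedCrop(224),
        T.ToTensor(),
    ])
    dataset, loader, sampler = make_imagenet1k(
        transform=transform,
        batch_size=128,
        root_path="/datasets",  # hypothetical network root
        image_folder="imagenet_full_size/061417/",
        world_size=1,  # take these from your distributed config in real runs
        rank=0,
    )
    for epoch in range(2):
        # DistributedSampler only reshuffles when told the epoch advanced.
        sampler.set_epoch(epoch)
        for images, labels in loader:
            pass  # forward/backward pass goes here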

class ImageNet(torchvision.datasets.ImageFolder):
    """ImageNet dataset wrapper with an optional local copy/extract workflow.

    The class extends `torchvision.datasets.ImageFolder` and can stage data
    from shared storage into local scratch space for faster multi-process
    training on cluster environments.
    """

    def __init__(
        self,
        root,
        image_folder="imagenet_full_size/061417/",
        tar_file="imagenet_full_size-061417.tar.gz",
        transform=None,
        train=True,
        job_id=None,
        local_rank=None,
        copy_data=True,
        index_targets=False,
    ):
        """
        ImageNet

        Dataset wrapper (can copy data locally to machine)

        :param root: root network directory for ImageNet data
        :param image_folder: path to images inside root network directory
        :param tar_file: zipped image_folder inside root network directory
        :param transform: transforms to apply to each loaded image
        :param train: whether to load train data (or validation)
        :param job_id: scheduler job-id used to create dir on local machine
        :param local_rank: local process rank used to coordinate the copy
        :param copy_data: whether to copy data from network file locally
        :param index_targets: whether to index the id of each labeled image
        """
        suffix = "train/" if train else "val/"
        data_path = None
        if copy_data:
            logger.info("copying data locally")
            data_path = copy_imgnt_locally(
                root=root,
                suffix=suffix,
                image_folder=image_folder,
                tar_file=tar_file,
                job_id=job_id,
                local_rank=local_rank,
            )
        if (not copy_data) or (data_path is None):
            data_path = os.path.join(root, image_folder, suffix)
        logger.info(f"data-path {data_path}")
        super(ImageNet, self).__init__(root=data_path, transform=transform)
        logger.info("Initialized ImageNet")

        if index_targets:
            # Build a per-class list of sample indices for labeled sampling.
            self.targets = []
            for sample in self.samples:
                self.targets.append(sample[1])
            self.targets = np.array(self.targets)
            self.samples = np.array(self.samples)

            mint = None
            self.target_indices = []
            for t in range(len(self.classes)):
                indices = np.squeeze(np.argwhere(self.targets == t)).tolist()
                self.target_indices.append(indices)
                mint = len(indices) if mint is None else min(mint, len(indices))
                logger.debug(f"num-labeled target {t} {len(indices)}")
            logger.info(f"min. labeled indices {mint}")
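
# A minimal usage sketch, assuming data at a hypothetical "/datasets" root:
# constructing the wrapper directly with per-class index bookkeeping enabled.
def _example_imagenet_index_targets():
    ds = ImageNet(
        root="/datasets",
        train=True,
        copy_data=False,     # read straight from the network filesystem
        index_targets=True,  # populates ds.targets and ds.target_indices
    )
    # Indices of all labeled images belonging to the first class.
    return ds.target_indices[0]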

class ImageNetSubset(object):
    """View over an `ImageNet` dataset filtered by an explicit image-id list.

    The subset file contains target image names; only matching samples are
    kept while preserving transforms and label mapping from the base dataset.
    """

    def __init__(self, dataset, subset_file):
        """
        ImageNetSubset

        :param dataset: ImageNet dataset object
        :param subset_file: '.txt' file containing IDs of IN1K images to keep
        """
        self.dataset = dataset
        self.subset_file = subset_file
        self.filter_dataset_(subset_file)
    def filter_dataset_(self, subset_file):
        """Filter self.dataset to a subset"""
        root = self.dataset.root
        class_to_idx = self.dataset.class_to_idx
        # -- update samples to subset of IN1k targets/samples
        new_samples = []
        logger.info(f"Using {subset_file}")
        with open(subset_file, "r") as rfile:
            for line in rfile:
                class_name = line.split("_")[0]
                target = class_to_idx[class_name]
                img = line.split("\n")[0]
                new_samples.append((os.path.join(root, class_name, img), target))
        self.samples = new_samples

    @property
    def classes(self):
        return self.dataset.classes

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, target = self.samples[index]
        img = self.dataset.loader(path)
        if self.dataset.transform is not None:
            img = self.dataset.transform(img)
        if self.dataset.target_transform is not None:
            target = self.dataset.target_transform(target)
        return img, target
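
# The subset file implied by `filter_dataset_` holds one image filename per
# line, where the prefix before the first underscore is the WordNet class id
# (which is also the class sub-folder), e.g.:
#
#     n01440764_10026.JPEG
#     n01443537_2772.JPEG
#
# A minimal usage sketch, with hypothetical paths:
def _example_imagenet_subset():
    base = ImageNet(root="/datasets", train=True, copy_data=False)
    subset = ImageNetSubset(base, subset_file="/path/to/subset.txt")
    img, target = subset[0]  # the base dataset's transforms still apply
    return img, target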

def copy_imgnt_locally(
    root,
    suffix,
    image_folder="imagenet_full_size/061417/",
    tar_file="imagenet_full_size-061417.tar.gz",
    job_id=None,
    local_rank=None,
):
    """Copy and extract ImageNet archives to per-job local scratch storage.

    In SLURM environments this reduces network filesystem pressure by unpacking
    once per job and synchronizing worker processes with a signal file.
    """
    if job_id is None:
        try:
            job_id = os.environ["SLURM_JOBID"]
        except Exception:
            logger.info("No job-id, will load directly from network file")
            return None

    if local_rank is None:
        try:
            local_rank = int(os.environ["SLURM_LOCALID"])
        except Exception:
            logger.info("No local-rank, will load directly from network file")
            return None

    source_file = os.path.join(root, tar_file)
    target = f"/scratch/slurm_tmpdir/{job_id}/"
    target_file = os.path.join(target, tar_file)
    data_path = os.path.join(target, image_folder, suffix)
    logger.info(f"{source_file}\n{target}\n{target_file}\n{data_path}")

    tmp_sgnl_file = os.path.join(target, "copy_signal.txt")

    if not os.path.exists(data_path):
        if local_rank == 0:
            # Only the first process on each node extracts the archive.
            commands = [["tar", "-xf", source_file, "-C", target]]
            for cmnd in commands:
                start_time = time.time()
                logger.info(f"Executing {cmnd}")
                subprocess.run(cmnd)
                logger.info(f"Cmnd took {(time.time()-start_time)/60.} min.")
            with open(tmp_sgnl_file, "w+") as f:
                print("Done copying locally.", file=f)
        else:
            # Other local ranks wait for the signal file written by rank 0.
            while not os.path.exists(tmp_sgnl_file):
                time.sleep(60)
                logger.info(f"{local_rank}: Checking {tmp_sgnl_file}")

    return data_path
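
# Behavior sketch: on a SLURM node, local rank 0 untars the archive into
# /scratch/slurm_tmpdir/$SLURM_JOBID while the other local ranks poll for the
# signal file; without SLURM_JOBID/SLURM_LOCALID the function returns None and
# callers fall back to the network path. An explicit call with hypothetical
# arguments might look like:
def _example_copy_locally():
    data_path = copy_imgnt_locally(
        root="/datasets",  # hypothetical network root holding the tar file
        suffix="train/",
        job_id="12345",    # normally read from SLURM_JOBID
        local_rank=0,      # normally read from SLURM_LOCALID
    )
    return data_path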

def make_imagefolder(
    transform,
    batch_size,
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    image_folder=None,
    drop_last=True,
    val_split: float | None = None,
):
    """Create an ImageFolder dataset loader for custom folder-structured datasets.

    Supports an optional train/validation split and distributed sampling, making
    it a drop-in replacement for the ImageNet loaders in training scripts.
    """
    dataset = torchvision.datasets.ImageFolder(
        root=os.path.join(root_path, image_folder) if image_folder else root_path,
        transform=transform,
    )
    if val_split:
        val_size = int(len(dataset) * val_split)
        train_size = len(dataset) - val_size
        # Seed the split (seed value is an arbitrary choice) so every
        # distributed rank keeps the same train subset; an unseeded split
        # could diverge across processes.
        dataset, _ = random_split(
            dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(42),
        )
    dist_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset=dataset, num_replicas=world_size, rank=rank
    )
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=collator,
        sampler=dist_sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        pin_memory=pin_mem,
        num_workers=num_workers,
        persistent_workers=False,
    )
    logger.info("ImageFolder data loader created")
    return dataset, data_loader, dist_sampler
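
# A minimal usage sketch, assuming a class-per-folder dataset at a
# hypothetical "/data/my_dataset" root, holding out 10% for validation.
def _example_make_imagefolder():
    import torchvision.transforms as T

    transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
    dataset, loader, sampler = make_imagefolder(
        transform=transform,
        batch_size=64,
        root_path="/data/my_dataset",
        val_split=0.1,  # uses the seeded random_split above
    )
    return dataset, loader, sampler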