Source code for world_models.utils.data_utils
from torch.utils.data import DataLoader, Dataset
from typing import Iterator, Optional
import multiprocessing
[docs]
def create_efficient_dataloader(
    dataset: Dataset,
    batch_size: int,
    num_workers: Optional[int] = None,
    pin_memory: bool = True,
    prefetch_factor: int = 2,
    persistent_workers: bool = True,
) -> DataLoader:
    """Create a shuffling ``DataLoader`` with throughput-oriented defaults.

    The returned loader always shuffles (``shuffle=True``) and drops the
    final incomplete batch (``drop_last=True``).

    Args:
        dataset: Dataset to load from.
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes. When ``None``, uses the
            machine's CPU count capped at 8. Pass ``0`` to load in the
            calling process.
        pin_memory: Forwarded to ``DataLoader(pin_memory=...)``.
        prefetch_factor: Forwarded to ``DataLoader`` when worker processes
            are used; ignored when ``num_workers`` is 0.
        persistent_workers: Forwarded to ``DataLoader`` when worker
            processes are used; ignored when ``num_workers`` is 0.

    Returns:
        The configured ``DataLoader``.
    """
    if num_workers is None:
        num_workers = min(multiprocessing.cpu_count(), 8)

    # DataLoader raises ValueError if prefetch_factor or persistent_workers
    # are set while num_workers == 0, so these multiprocessing-only options
    # are forwarded only when worker processes are actually requested.
    worker_options = {}
    if num_workers > 0:
        worker_options["prefetch_factor"] = prefetch_factor
        worker_options["persistent_workers"] = persistent_workers

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True,
        **worker_options,
    )
[docs]
def prefetch_iterator(iterator: Iterator, buffer_size: int = 3) -> Iterator:
    """Yield items from ``iterator`` while keeping a look-ahead buffer.

    Up to ``buffer_size`` items are pulled from the source before the first
    item is yielded; after each yielded item one more is pulled to top the
    buffer back up. Everything runs synchronously in the caller's thread, so
    this provides look-ahead buffering rather than background loading.

    Args:
        iterator: Source of items. Any iterable is accepted (it is passed
            through ``iter()`` once), so lists, ranges and objects that only
            implement ``__iter__`` work as well as true iterators.
        buffer_size: Number of items to buffer ahead. Values below 1 behave
            like 1.

    Yields:
        The items of ``iterator`` in their original order.
    """
    from collections import deque
    from itertools import islice

    # Normalise once: a no-op for real iterators, and it makes next() below
    # valid for plain iterables instead of raising TypeError.
    source = iter(iterator)

    # Initial fill. At least one item is always requested so that a
    # non-positive buffer_size still yields every item.
    buffer: deque = deque(islice(source, max(buffer_size, 1)))

    exhausted = False
    while buffer:
        yield buffer.popleft()
        if exhausted:
            # Drain what is left without polling a finished source again.
            continue
        try:
            buffer.append(next(source))
        except StopIteration:
            exhausted = True