Source code for world_models.masks.default

from logging import getLogger

import torch

logger = getLogger()


[docs] class DefaultCollator(object): """Simple collator that returns batch data and no masking metadata. This is used when training code expects the JEPA-style collator return shape `(batch, masks_enc, masks_pred)` but masking is disabled. """ def __call__(self, batch): collated_batch = torch.utils.data.default_collate(batch) return collated_batch, None, None