# Source code for world_models.masks.random
from multiprocessing import Value
from logging import getLogger
import torch
logger = getLogger()
[docs]
class MaskCollator(object):
"""Generate random context/prediction patch splits for masked training.
A random permutation of patch indices is sampled per image; a configurable
fraction is assigned to context and the remainder to prediction targets.
"""
def __init__(self, ratio=(0.4, 0.6), input_size=(224, 224), patch_size=16):
super(MaskCollator, self).__init__()
if not isinstance(input_size, tuple):
input_size = (input_size,) * 2
self.patch_size = patch_size
self.height, self.width = (
input_size[0] // patch_size,
input_size[1] // patch_size,
)
self.ratio = ratio
self._itr_counter = Value("i", -1)
[docs]
def step(self):
i = self._itr_counter
with i.get_lock():
i.value += 1
v = i.value
return v
def __call__(self, batch):
B = len(batch)
collated_batch = torch.utils.data.default_collate(batch)
seed = self.step()
g = torch.Generator()
g.manual_seed(seed)
ratio = self.ratio
ratio = ratio[0] + torch.rand(1, generator=g).item() * (ratio[1] - ratio[0])
num_patches = self.height * self.width
num_keep = int(num_patches * (1 - ratio))
collated_masks_pred, collated_masks_enc = [], []
for _ in range(B):
m = torch.randperm(num_patches)
collated_masks_enc.append(m[:num_keep])
collated_masks_pred.append(m[num_keep:])
collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred)
collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc)
return collated_batch, collated_masks_enc, collated_masks_pred