import math
from multiprocessing import Value
from logging import getLogger
import torch
logger = getLogger()
class MaskCollator(object):
"""Generate multi-block encoder and predictor masks for JEPA training.
For each sample, this collator samples predictor target blocks and
context encoder blocks (optionally non-overlapping), then returns masked
patch indices aligned across the batch.
"""
def __init__(
self,
input_size=(224, 224),
patch_size=16,
enc_mask_scale=(0.2, 0.8),
pred_mask_scale=(0.5, 1.0),
aspect_ratio=(0.3, 3.0),
nenc=1,
npred=2,
min_keep=4,
allow_overlap=False,
):
super(MaskCollator, self).__init__()
if not isinstance(input_size, tuple):
input_size = (input_size,) * 2
self.patch_size = patch_size
self.height, self.width = (
input_size[0] // patch_size,
input_size[1] // patch_size,
)
self.enc_mask_scale = enc_mask_scale
self.pred_mask_scale = pred_mask_scale
self.aspect_ratio = aspect_ratio
self.nenc = nenc
self.npred = npred
self.min_keep = min_keep
self.allow_overlap = allow_overlap
self._itr_counter = Value("i", -1)
def step(self):
i = self._itr_counter
with i.get_lock():
i.value += 1
v = i.value
return v
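
    # NOTE: seeding the per-batch generator from this shared counter means
    # every DataLoader worker draws the same block sizes for a given
    # iteration, keeping masks consistent across workers (an inference from
    # the code, not something stated in the source).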
def _sample_block_size(self, generator, scale, aspect_ratio_scale):
_rand = torch.rand(1, generator=generator).item()
        # sample the block scale (fraction of the patch grid to keep)
        min_s, max_s = scale
        mask_scale = min_s + (max_s - min_s) * _rand
        max_keep = int(self.height * self.width * mask_scale)
        # sample the aspect ratio (reuses the same uniform draw as the scale)
        min_ar, max_ar = aspect_ratio_scale
        aspect_ratio = min_ar + _rand * (max_ar - min_ar)
h = int(round(math.sqrt(max_keep * aspect_ratio)))
w = int(round(math.sqrt(max_keep / aspect_ratio)))
# ensure at least 1 and not larger than grid
h = max(1, min(h, self.height))
w = max(1, min(w, self.width))
return (h, w)
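
    # Worked example (illustrative numbers, not from the source): a 224x224
    # input with patch_size=16 gives a 14x14 = 196-patch grid. A draw of
    # _rand = 0.5 with scale=(0.2, 0.8) yields mask_scale = 0.5, so
    # max_keep = 98; with aspect_ratio=(0.3, 3.0) the same draw yields
    # aspect_ratio = 1.65, hence h = round(sqrt(98 * 1.65)) = 13 and
    # w = round(sqrt(98 / 1.65)) = 8, both within the 14x14 grid.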
def _sample_block_mask(self, b_size, acceptable_regions=None):
h, w = b_size
        def constrain_mask(mask, tries=0):
            """Zero out mask positions outside the acceptable regions (in-place).

            acceptable_regions is a list of HxW {0,1} complement masks; each
            retry round drops one region from the end of the list, so the
            constraint is progressively relaxed until a valid mask is found.
            """
            N = max(int(len(acceptable_regions) - tries), 0)
            for k in range(N):
                mask *= acceptable_regions[k]
tries = 0
timeout = og_timeout = 20
valid_mask = False
while not valid_mask:
# allow placement anywhere such that top+h <= height
top = int(torch.randint(0, max(1, self.height - h + 1), (1,)).item())
left = int(torch.randint(0, max(1, self.width - w + 1), (1,)).item())
mask = torch.zeros((self.height, self.width), dtype=torch.int32)
mask[top : top + h, left : left + w] = 1
            if acceptable_regions is not None:
                constrain_mask(mask, tries)
            # keep flat indices of surviving patches; retry if too few remain
            mask = torch.nonzero(mask.flatten())
            valid_mask = len(mask) > self.min_keep
if not valid_mask:
timeout -= 1
if timeout == 0:
tries += 1
timeout = og_timeout
logger.warning(
f'Mask generator says: "Valid mask not found, decreasing acceptable-regions [{tries}]"'
)
        mask = mask.squeeze(1)  # (N, 1) -> (N,) tensor of flat patch indices
mask_complement = torch.ones((self.height, self.width), dtype=torch.int32)
mask_complement[top : top + h, left : left + w] = 0
return mask, mask_complement
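
    # The complement masks returned above mark everything outside the sampled
    # block with ones; when allow_overlap is False, __call__ passes them as
    # acceptable_regions for the encoder blocks, zeroing out candidate context
    # positions that fall inside any predictor target block.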
def __call__(self, batch):
B = len(batch)
collated_batch = torch.utils.data.default_collate(batch)
seed = self.step()
g = torch.Generator()
g.manual_seed(seed)
p_size = self._sample_block_size(
generator=g,
scale=self.pred_mask_scale,
aspect_ratio_scale=self.aspect_ratio,
)
e_size = self._sample_block_size(
generator=g, scale=self.enc_mask_scale, aspect_ratio_scale=(1.0, 1.0)
)
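        # note: the unit aspect-ratio range makes encoder (context) blocks
        # square up to clamping; only their scale varies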
        collated_masks_pred, collated_masks_enc = [], []
        # track the smallest mask drawn anywhere in the batch (initialised to
        # the full grid size, the maximum possible)
        min_keep_pred = self.height * self.width
        min_keep_enc = self.height * self.width
for _ in range(B):
masks_p, masks_C = [], []
for _ in range(self.npred):
mask, mask_C = self._sample_block_mask(p_size)
masks_p.append(mask)
masks_C.append(mask_C)
min_keep_pred = min(min_keep_pred, len(mask))
collated_masks_pred.append(masks_p)
            acceptable_regions = masks_C
            if self.allow_overlap:
                # no overlap constraint: context blocks may cover the targets
                acceptable_regions = None
masks_e = []
for _ in range(self.nenc):
mask, _ = self._sample_block_mask(
e_size, acceptable_regions=acceptable_regions
)
masks_e.append(mask)
min_keep_enc = min(min_keep_enc, len(mask))
collated_masks_enc.append(masks_e)
        # truncate every mask to the batch-wide minimum length so that
        # default_collate can stack them into (B, K) index tensors
        collated_masks_pred = [
            [cm[:min_keep_pred] for cm in cm_list] for cm_list in collated_masks_pred
        ]
        collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred)
collated_masks_enc = [
[cm[:min_keep_enc] for cm in cm_list] for cm_list in collated_masks_enc
]
collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc)
return collated_batch, collated_masks_enc, collated_masks_pred
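

# Minimal usage sketch (my assumption: a list of random image tensors stands
# in for a real dataset; shapes follow the class defaults above).
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    images = [torch.randn(3, 224, 224) for _ in range(8)]
    collator = MaskCollator(input_size=(224, 224), patch_size=16)
    loader = DataLoader(images, batch_size=4, collate_fn=collator)
    imgs, masks_enc, masks_pred = next(iter(loader))
    # imgs: (4, 3, 224, 224); masks_enc: nenc tensors of shape (4, K_enc);
    # masks_pred: npred tensors of shape (4, K_pred), K = common truncated length
    print(imgs.shape, [m.shape for m in masks_enc], [m.shape for m in masks_pred])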