Source code for world_models.masks.multiblock

import math
from logging import getLogger
from multiprocessing import Value

import torch
import torch.utils.data

# Module-level logger; getLogger() with no argument returns the root logger.
logger = getLogger()


class MaskCollator(object):
    """Generate multi-block encoder and predictor masks for JEPA training.

    For each sample, this collator samples predictor target blocks and
    context encoder blocks (optionally non-overlapping), then returns
    masked patch indices aligned across the batch.
    """

    def __init__(
        self,
        input_size=(224, 224),
        patch_size=16,
        enc_mask_scale=(0.2, 0.8),
        pred_mask_scale=(0.5, 1.0),
        aspect_ratio=(0.3, 3.0),
        nenc=1,
        npred=2,
        min_keep=4,
        allow_overlap=False,
    ):
        """
        :param input_size: image size in pixels; an int is treated as square
        :param patch_size: side length of one square patch in pixels
        :param enc_mask_scale: (min, max) fraction of the patch grid covered
            by an encoder (context) block
        :param pred_mask_scale: (min, max) fraction of the patch grid covered
            by a predictor (target) block
        :param aspect_ratio: (min, max) aspect-ratio range for predictor blocks
        :param nenc: number of encoder masks per sample
        :param npred: number of predictor masks per sample
        :param min_keep: a sampled mask must keep strictly more than this
            many patches to be accepted
        :param allow_overlap: if True, encoder blocks may overlap predictor
            blocks; if False, encoder blocks are constrained to the
            predictor blocks' complements
        """
        super(MaskCollator, self).__init__()
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.patch_size = patch_size
        # Grid dimensions measured in patches.
        self.height, self.width = (
            input_size[0] // patch_size,
            input_size[1] // patch_size,
        )
        self.enc_mask_scale = enc_mask_scale
        self.pred_mask_scale = pred_mask_scale
        self.aspect_ratio = aspect_ratio
        self.nenc = nenc
        self.npred = npred
        self.min_keep = min_keep
        self.allow_overlap = allow_overlap
        # Process-shared counter so DataLoader workers derive a common
        # per-batch seed for the block-size generator.
        self._itr_counter = Value("i", -1)

    def step(self):
        """Atomically increment and return the shared iteration counter."""
        i = self._itr_counter
        with i.get_lock():
            i.value += 1
            v = i.value
        return v

    def _sample_block_size(self, generator, scale, aspect_ratio_scale):
        """Sample a block (height, width) in patch units.

        A single uniform draw selects both the area fraction (within
        ``scale``) and the aspect ratio (within ``aspect_ratio_scale``),
        so the two are deliberately correlated. The result is clamped to
        fit the patch grid.
        """
        _rand = torch.rand(1, generator=generator).item()
        min_s, max_s = scale
        mask_scale = min_s + (max_s - min_s) * _rand
        max_keep = int(self.height * self.width * mask_scale)
        min_ar, max_ar = aspect_ratio_scale
        aspect_ratio = min_ar + _rand * (max_ar - min_ar)
        h = int(round(math.sqrt(max_keep * aspect_ratio)))
        w = int(round(math.sqrt(max_keep / aspect_ratio)))
        # Ensure at least 1 and not larger than the grid.
        h = max(1, min(h, self.height))
        w = max(1, min(w, self.width))
        return (h, w)

    def _sample_block_mask(self, b_size, acceptable_regions=None):
        """Sample one rectangular block mask of size ``b_size``.

        :param b_size: (height, width) of the block in patch units
        :param acceptable_regions: optional list of HxW {0,1} int tensors;
            the sampled block is intersected with them so the mask only
            keeps patches every region allows
        :return: (mask, mask_complement) where ``mask`` is a 1-D tensor of
            kept flattened patch indices and ``mask_complement`` is an HxW
            tensor that is 0 inside the sampled block and 1 elsewhere
        """
        h, w = b_size

        def constrain_mask(mask, tries=0):
            # Intersect the candidate block with the acceptable regions
            # (in-place). As `tries` grows, progressively drop trailing
            # regions so a valid mask can always be found; this both fixes
            # the previous list-vs-tensor bug (the old code called
            # `.bool()` on the list and silently swallowed the
            # AttributeError, disabling the constraint) and guarantees the
            # search below terminates.
            N = max(len(acceptable_regions) - tries, 0)
            for k in range(N):
                mask *= acceptable_regions[k]

        tries = 0
        timeout = og_timeout = 20
        valid_mask = False
        while not valid_mask:
            # Allow placement anywhere such that top+h <= height.
            top = int(torch.randint(0, max(1, self.height - h + 1), (1,)).item())
            left = int(torch.randint(0, max(1, self.width - w + 1), (1,)).item())
            mask = torch.zeros((self.height, self.width), dtype=torch.int32)
            mask[top : top + h, left : left + w] = 1
            if acceptable_regions is not None:
                constrain_mask(mask, tries)
            mask = torch.nonzero(mask.flatten())
            valid_mask = len(mask) > self.min_keep
            if not valid_mask:
                timeout -= 1
                if timeout == 0:
                    # Relax the constraint set and keep trying.
                    tries += 1
                    timeout = og_timeout
                    logger.warning(
                        f'Mask generator says: "Valid mask not found, decreasing acceptable-regions [{tries}]"'
                    )
        mask = mask.squeeze(1)
        mask_complement = torch.ones((self.height, self.width), dtype=torch.int32)
        mask_complement[top : top + h, left : left + w] = 0
        return mask, mask_complement

    def __call__(self, batch):
        """Collate ``batch`` and attach per-sample encoder/predictor masks.

        :param batch: list of samples accepted by ``default_collate``
        :return: (collated_batch, collated_masks_enc, collated_masks_pred);
            the mask lists hold ``nenc``/``npred`` tensors of shape
            (B, n_keep). Masks are truncated to the batch-wide minimum
            length so they stack into rectangular tensors.
        """
        B = len(batch)
        collated_batch = torch.utils.data.default_collate(batch)

        # One shared seed per batch keeps DataLoader workers in sync on
        # the sampled block sizes.
        seed = self.step()
        g = torch.Generator()
        g.manual_seed(seed)
        p_size = self._sample_block_size(
            generator=g,
            scale=self.pred_mask_scale,
            aspect_ratio_scale=self.aspect_ratio,
        )
        e_size = self._sample_block_size(
            generator=g, scale=self.enc_mask_scale, aspect_ratio_scale=(1.0, 1.0)
        )

        collated_masks_pred, collated_masks_enc = [], []
        min_keep_pred = self.height * self.width
        min_keep_enc = self.height * self.width
        for _ in range(B):
            masks_p, masks_C = [], []
            for _ in range(self.npred):
                mask, mask_C = self._sample_block_mask(p_size)
                masks_p.append(mask)
                masks_C.append(mask_C)
                min_keep_pred = min(min_keep_pred, len(mask))
            collated_masks_pred.append(masks_p)

            # Unless overlap is allowed, encoder blocks must avoid the
            # predictor blocks: constrain them to the complements.
            acceptable_regions = None if self.allow_overlap else masks_C

            masks_e = []
            for _ in range(self.nenc):
                mask, _ = self._sample_block_mask(
                    e_size, acceptable_regions=acceptable_regions
                )
                masks_e.append(mask)
                min_keep_enc = min(min_keep_enc, len(mask))
            collated_masks_enc.append(masks_e)

        collated_masks_pred = [
            [cm[:min_keep_pred] for cm in cm_list] for cm_list in collated_masks_pred
        ]
        collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred)
        collated_masks_enc = [
            [cm[:min_keep_enc] for cm in cm_list] for cm_list in collated_masks_enc
        ]
        collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc)

        return collated_batch, collated_masks_enc, collated_masks_pred