"""Source code for world_models.utils.jepa_utils."""

import math
import torch
from logging import getLogger
import os
import torch.distributed as dist

# Module-level logger; getLogger() with no name returns the root logger.
logger = getLogger()


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill `tensor` in-place from a truncated normal distribution, no grad.

    Uses the inverse-CDF method: sample uniformly in the image of the
    truncation interval under the standard-normal CDF, then map back via
    the inverse error function, and finally scale/shift to N(mean, std).

    Args:
        tensor: tensor to initialize in place.
        mean, std: parameters of the underlying normal distribution.
        a, b: lower/upper truncation bounds (in the same units as `mean`).

    Returns:
        The same `tensor`, modified in place.
    """
    def norm_cdf(x):
        # CDF of the standard normal distribution.
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    with torch.no_grad():
        # CDF values of the standardized truncation bounds.
        l_ = norm_cdf((a - mean) / std)
        u_ = norm_cdf((b - mean) / std)
        # Sample U in [2*l-1, 2*u-1]; erfinv(U) is then standard-normal
        # restricted to the (standardized) interval [a, b].
        tensor.uniform_(2 * l_ - 1, 2 * u_ - 1)
        tensor.erfinv_()

        # Rescale from the erfinv output to N(mean, std).
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to guard against floating-point excursions past the bounds.
        tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2, b=2.0):
    """In-place truncated-normal initialization of `tensor`.

    Draws samples from N(mean, std) restricted to the interval [a, b]
    and returns the (modified) input tensor.
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
def repeat_interleave_batch(x, B, repeat):
    """Repeat each contiguous batch chunk of size `B` `repeat` times.

    The input is viewed as `len(x) // B` consecutive chunks along dim 0;
    each chunk is duplicated `repeat` times before moving on to the next,
    so chunk ordering is preserved. Returns the concatenated result.
    """
    num_chunks = len(x) // B
    pieces = []
    for chunk_idx in range(num_chunks):
        chunk = x[chunk_idx * B : (chunk_idx + 1) * B]
        for _ in range(repeat):
            pieces.append(chunk)
    return torch.cat(pieces, dim=0)
class WarmupCosineSchedule(object):
    """Linear-warmup + cosine-decay learning-rate schedule.

    For the first `warmup_steps` calls to `step()`, the LR ramps linearly
    from `start_lr` to `ref_lr`; afterwards it follows a cosine curve from
    `ref_lr` down toward `final_lr` over the remaining `T_max - warmup_steps`
    steps. Each `step()` writes the new LR into every optimizer param group
    and returns it.
    """

    def __init__(
        self,
        optimizer,
        warmup_steps,
        start_lr,
        ref_lr,
        T_max,
        last_epoch=-1,
        final_lr=0.0,
    ):
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.ref_lr = ref_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        # Cosine phase length excludes the warmup portion.
        self.T_max = T_max - warmup_steps
        self._step = 0.0

    def step(self):
        self._step += 1
        if self._step < self.warmup_steps:
            # Linear warmup from start_lr toward ref_lr.
            frac = float(self._step) / float(max(1, self.warmup_steps))
            new_lr = self.start_lr + frac * (self.ref_lr - self.start_lr)
        else:
            # Cosine decay from ref_lr toward final_lr, floored at final_lr.
            frac = float(self._step - self.warmup_steps) / float(max(1, self.T_max))
            cosine = 0.5 * (1.0 + math.cos(math.pi * frac))
            new_lr = max(
                self.final_lr,
                self.final_lr + (self.ref_lr - self.final_lr) * cosine,
            )

        for group in self.optimizer.param_groups:
            group["lr"] = new_lr
        return new_lr
class CosineWDSchedule(object):
    """Cosine schedule for optimizer weight decay.

    Anneals weight decay from `ref_wd` toward `final_wd` over `T_max`
    steps. Param groups carrying a truthy `WD_exclude` entry are left
    untouched (e.g. bias/norm parameters kept at zero decay).
    """

    def __init__(
        self,
        optimizer,
        ref_wd,
        T_max,
        final_wd=0.0,
    ):
        self.optimizer = optimizer
        self.ref_wd = ref_wd
        self.final_wd = final_wd
        self.T_max = T_max
        self._step = 0.0

    def step(self):
        self._step += 1
        frac = self._step / self.T_max
        new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (
            1.0 + math.cos(math.pi * frac)
        )

        # Clamp so the schedule never overshoots its endpoint, whichever
        # direction (increasing or decreasing) it is moving in.
        if self.final_wd <= self.ref_wd:
            new_wd = max(self.final_wd, new_wd)
        else:
            new_wd = min(self.final_wd, new_wd)

        for group in self.optimizer.param_groups:
            if not group.get("WD_exclude", False):
                group["weight_decay"] = new_wd
        return new_wd
def gpu_timer(closure, log_timings=True):
    """Run `closure()` and measure its CUDA execution time.

    Returns `(result, elapsed_ms)`. Timing uses CUDA events and therefore
    requires an available CUDA device; when timing is disabled or CUDA is
    unavailable, `elapsed_ms` is -1.0.
    """
    timing_enabled = log_timings and torch.cuda.is_available()

    start_event = end_event = None
    if timing_enabled:
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()

    result = closure()

    elapsed_ms = -1.0
    if timing_enabled:
        end_event.record()
        # Events record asynchronously; synchronize before reading the delta.
        torch.cuda.synchronize()
        elapsed_ms = start_event.elapsed_time(end_event)

    return result, elapsed_ms
class CSVLogger(object):
    """Append-only CSV logger with printf-style per-column formatting.

    Each extra constructor argument is a `(format_string, column_name)`
    pair; the header row is written immediately on construction, and
    `log()` appends one formatted row per call.
    """

    def __init__(self, fname, *argv):
        self.fname = fname
        self.types = []
        # Write the header row and remember each column's format string.
        with open(self.fname, "+a") as f:
            for idx, col in enumerate(argv, 1):
                self.types.append(col[0])
                sep = "," if idx < len(argv) else "\n"
                print(col[1], end=sep, file=f)

    def log(self, *argv):
        # Append one row, rendering each value with its column's format.
        with open(self.fname, "+a") as f:
            for idx, (fmt, val) in enumerate(zip(self.types, argv), 1):
                sep = "," if idx < len(argv) else "\n"
                print(fmt % val, end=sep, file=f)
class AverageMeter(object):
    """Running statistics tracker: last value, mean, min, max, sum, count."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clean slate; extrema start at their identity elements.
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.min = float("inf")
        self.max = float("-inf")

    def update(self, val, n=1):
        self.val = val
        try:
            self.max = max(self.max, val)
            self.min = min(self.min, val)
        except Exception:
            # Non-comparable values: skip extrema, still track sum/mean.
            pass
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
def grad_logger(named_params):
    """Collect gradient-norm statistics over `named_params` for logging.

    Skips biases, 1-D parameters, and parameters without a gradient.
    Besides the aggregate stats (an AverageMeter), records the gradient
    norm of the first and last parameters whose name contains "qkv";
    both default to 0.0 when no such parameter is seen.
    """
    stats = AverageMeter()
    stats.first_layer = None
    stats.last_layer = None
    for name, param in named_params:
        if param.grad is None:
            continue
        # Ignore biases and 1-D params (e.g. norm weights).
        if name.endswith(".bias") or len(param.shape) == 1:
            continue
        grad_norm = float(torch.norm(param.grad.data))
        stats.update(grad_norm)
        if "qkv" in name:
            stats.last_layer = grad_norm
            if stats.first_layer is None:
                stats.first_layer = grad_norm
    if stats.first_layer is None or stats.last_layer is None:
        stats.first_layer = 0.0
        stats.last_layer = 0.0
    return stats
def init_distributed(port=40112, rank_and_world_size=(None, None)):
    """Initialize the torch.distributed NCCL process group when possible.

    Rank and world size come from the argument pair or, failing that,
    from SLURM environment variables. Returns `(world_size, rank)`,
    falling back to `(1, 0)` (single-process mode) whenever distributed
    setup cannot proceed. Sets MASTER_ADDR/MASTER_PORT as a side effect.
    """
    if dist.is_available() and dist.is_initialized():
        # Already set up; just report the existing group's layout.
        return dist.get_world_size(), dist.get_rank()

    rank, world_size = rank_and_world_size
    os.environ["MASTER_ADDR"] = "localhost"

    if rank is None or world_size is None:
        # Fall back to SLURM-provided process info.
        try:
            world_size = int(os.environ["SLURM_NTASKS"])
            rank = int(os.environ["SLURM_PROCID"])
            os.environ["MASTER_ADDR"] = os.environ["HOSTNAME"]
        except Exception:
            logger.info("SLURM vars not set (distributed training not available)")
            return 1, 0

    try:
        os.environ["MASTER_PORT"] = str(port)
        torch.distributed.init_process_group(
            backend="nccl", world_size=world_size, rank=rank
        )
    except Exception as e:
        world_size, rank = 1, 0
        logger.info(f"distributed training not available {e}")
    return world_size, rank
class AllGather(torch.autograd.Function):
    """Differentiable all-gather across distributed workers.

    Forward concatenates every worker's tensor along dim 0; backward
    all-reduces the incoming gradient and returns only this worker's
    slice. Acts as the identity when not running distributed.
    """

    @staticmethod
    def forward(ctx, x):
        if not (
            dist.is_available()
            and dist.is_initialized()
            and dist.get_world_size() > 1
        ):
            return x
        x = x.contiguous()
        gathered = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, x)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grads):
        if not (
            dist.is_available()
            and dist.is_initialized()
            and dist.get_world_size() > 1
        ):
            return grads
        # Sum gradients from all workers, then keep this rank's slice.
        chunk = grads.shape[0] // dist.get_world_size()
        start = chunk * dist.get_rank()
        grads = grads.contiguous()
        dist.all_reduce(grads)
        return grads[start : start + chunk]
class AllReduceSum(torch.autograd.Function):
    """Sum a tensor across workers in forward; pass gradients through.

    Acts as the identity when not running distributed.
    """

    @staticmethod
    def forward(ctx, x):
        distributed = (
            dist.is_available()
            and dist.is_initialized()
            and dist.get_world_size() > 1
        )
        if distributed:
            x = x.contiguous()
            dist.all_reduce(x)
        return x

    @staticmethod
    def backward(ctx, grads):
        # Gradient passes through unchanged.
        return grads
class AllReduce(torch.autograd.Function):
    """Average a tensor across all workers in the forward pass.

    Divides by the world size before the all-reduce so the summed result
    is the mean; gradients pass through unchanged. Acts as the identity
    when not running distributed.
    """

    @staticmethod
    def forward(ctx, x):
        in_dist_mode = (
            dist.is_available()
            and dist.is_initialized()
            and dist.get_world_size() > 1
        )
        if in_dist_mode:
            # Pre-divide so the summing all-reduce yields the mean.
            x = x.contiguous() / dist.get_world_size()
            dist.all_reduce(x)
        return x

    @staticmethod
    def backward(ctx, grads):
        # Gradient passes through unchanged.
        return grads