Source code for world_models.helpers.jepa_helper

import logging
import sys

import torch

import world_models.models.vit as vit
from world_models.utils.jepa_utils import WarmupCosineSchedule, CosineWDSchedule
from world_models.utils.jepa_utils import trunc_normal_

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger()


[docs] def load_checkpoint( device, r_path, encoder, predictor, target_encoder, opt, scaler, ): """Load JEPA training state from disk into model and optimizer objects. Restores encoder, predictor, optional target encoder, optimizer state, and optional AMP scaler, returning the resumed epoch for training restart. """ try: checkpoint = torch.load(r_path, map_location=torch.device("cpu")) epoch = checkpoint["epoch"] # -- loading encoder pretrained_dict = checkpoint["encoder"] msg = encoder.load_state_dict(pretrained_dict) logger.info(f"loaded pretrained encoder from epoch {epoch} with msg: {msg}") # -- loading predictor pretrained_dict = checkpoint["predictor"] msg = predictor.load_state_dict(pretrained_dict) logger.info(f"loaded pretrained encoder from epoch {epoch} with msg: {msg}") # -- loading target_encoder if target_encoder is not None: print(list(checkpoint.keys())) pretrained_dict = checkpoint["target_encoder"] msg = target_encoder.load_state_dict(pretrained_dict) logger.info(f"loaded pretrained encoder from epoch {epoch} with msg: {msg}") # -- loading optimizer opt.load_state_dict(checkpoint["opt"]) if scaler is not None: scaler.load_state_dict(checkpoint["scaler"]) logger.info(f"loaded optimizers from epoch {epoch}") logger.info(f"read-path: {r_path}") del checkpoint except Exception as e: logger.info(f"Encountered exception when loading checkpoint {e}") epoch = 0 return encoder, predictor, target_encoder, opt, scaler, epoch
[docs] def init_model( device, patch_size=16, model_name="vit_base", crop_size=224, pred_depth=6, pred_emb_dim=384, ): """Initialize JEPA encoder and predictor modules with ViT backbones. Applies truncated-normal parameter initialization, moves modules to the requested device, and returns `(encoder, predictor)`. """ encoder = vit.__dict__[model_name](img_size=[crop_size], patch_size=patch_size) predictor = vit.__dict__["vit_predictor"]( num_patches=encoder.patch_embed.num_patches, embed_dim=encoder.embed_dim, predictor_embed_dim=pred_emb_dim, depth=pred_depth, num_heads=encoder.num_heads, ) def init_weights(m): if isinstance(m, torch.nn.Linear): trunc_normal_(m.weight, std=0.02) if m.bias is not None: torch.nn.init.constant_(m.bias, 0) elif isinstance(m, torch.nn.LayerNorm): torch.nn.init.constant_(m.bias, 0) torch.nn.init.constant_(m.weight, 1.0) for m in encoder.modules(): init_weights(m) for m in predictor.modules(): init_weights(m) encoder.to(device) predictor.to(device) logger.info(encoder) return encoder, predictor
[docs] def init_opt( encoder, predictor, iterations_per_epoch, start_lr, ref_lr, warmup, num_epochs, wd=1e-6, final_wd=1e-6, final_lr=0.0, use_bfloat16=False, ipe_scale=1.25, ): """Build optimizer, AMP scaler, LR scheduler, and weight-decay scheduler for JEPA. Parameters are grouped to exclude bias/norm tensors from weight decay, matching typical transformer training best practices. """ param_groups = [ { "params": ( p for n, p in encoder.named_parameters() if ("bias" not in n) and (len(p.shape) != 1) ) }, { "params": ( p for n, p in predictor.named_parameters() if ("bias" not in n) and (len(p.shape) != 1) ) }, { "params": ( p for n, p in encoder.named_parameters() if ("bias" in n) or (len(p.shape) == 1) ), "WD_exclude": True, "weight_decay": 0, }, { "params": ( p for n, p in predictor.named_parameters() if ("bias" in n) or (len(p.shape) == 1) ), "WD_exclude": True, "weight_decay": 0, }, ] logger.info("Using AdamW") optimizer = torch.optim.AdamW(param_groups) scheduler = WarmupCosineSchedule( optimizer, warmup_steps=int(warmup * iterations_per_epoch), start_lr=start_lr, ref_lr=ref_lr, final_lr=final_lr, T_max=int(ipe_scale * num_epochs * iterations_per_epoch), ) wd_scheduler = CosineWDSchedule( optimizer, ref_wd=wd, final_wd=final_wd, T_max=int(ipe_scale * num_epochs * iterations_per_epoch), ) scaler = torch.cuda.amp.GradScaler("cuda") if use_bfloat16 else None return optimizer, scaler, scheduler, wd_scheduler