Source code for world_models.training.train_jepa

import os

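# When launched via SLURM (one task per GPU), pin this process to its local
# GPU before CUDA is initialized; outside SLURM the SLURM_LOCALID lookup
# raises KeyError and is deliberately ignored.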
try:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["SLURM_LOCALID"]
except Exception:
    pass

import copy
import logging
import sys
import yaml

import numpy as np

import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel

from world_models.masks.multiblock import MaskCollator as MBMaskCollator
from world_models.utils.utils import apply_masks
from world_models.utils.jepa_utils import (
    init_distributed,
    AllReduce,
    CSVLogger,
    gpu_timer,
    grad_logger,
    AverageMeter,
    repeat_interleave_batch,
)
from world_models.datasets.imagenet1k import make_imagenet1k, make_imagefolder
from world_models.datasets.cifar10 import make_cifar10
from world_models.helpers.jepa_helper import load_checkpoint, init_model, init_opt
from world_models.transforms.transforms import make_transforms
from world_models.configs.jepa_config import JEPAConfig

log_timings = True
log_freq = 10
checkpoint_freq = 50

_GLOBAL_SEED = 0
np.random.seed(_GLOBAL_SEED)
torch.manual_seed(_GLOBAL_SEED)
torch.backends.cudnn.benchmark = True
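# cudnn.benchmark autotunes convolution kernels, which pays off here because
# every batch uses the same fixed crop size.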

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger()


def main(args, resume_preempt=False):
    """Run JEPA training using a nested config dict or `JEPAConfig` instance.

    This entrypoint initializes the distributed context, data pipeline,
    masking, models, optimizers/schedulers, checkpointing, and the full
    epoch loop.
    """
    if isinstance(args, JEPAConfig):
        args = args.to_dict()

    # ----------------------------------------------------------------------- #
    #  PASSED IN PARAMS FROM CONFIG FILE
    # ----------------------------------------------------------------------- #

    # -- META
    use_bfloat16 = args["meta"]["use_bfloat16"]
    model_name = args["meta"]["model_name"]
    load_model = args["meta"]["load_checkpoint"] or resume_preempt
    r_file = args["meta"]["read_checkpoint"]
    copy_data = args["meta"]["copy_data"]
    pred_depth = args["meta"]["pred_depth"]
    pred_emb_dim = args["meta"]["pred_emb_dim"]
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)

    # -- DATA
    use_gaussian_blur = args["data"]["use_gaussian_blur"]
    use_horizontal_flip = args["data"]["use_horizontal_flip"]
    use_color_distortion = args["data"]["use_color_distortion"]
    color_jitter = args["data"]["color_jitter_strength"]
    # --
    batch_size = args["data"]["batch_size"]
    pin_mem = args["data"]["pin_mem"]
    num_workers = args["data"]["num_workers"]
    root_path = args["data"]["root_path"]
    image_folder = args["data"]["image_folder"]
    crop_size = args["data"]["crop_size"]
    crop_scale = args["data"]["crop_scale"]
    # --

    # -- MASK
    allow_overlap = args["mask"]["allow_overlap"]  # whether to allow overlap b/w context and target blocks
    patch_size = args["mask"]["patch_size"]  # patch-size for model training
    num_enc_masks = args["mask"]["num_enc_masks"]  # number of context blocks
    min_keep = args["mask"]["min_keep"]  # min number of patches in context block
    enc_mask_scale = args["mask"]["enc_mask_scale"]  # scale of context blocks
    num_pred_masks = args["mask"]["num_pred_masks"]  # number of target blocks
    pred_mask_scale = args["mask"]["pred_mask_scale"]  # scale of target blocks
    aspect_ratio = args["mask"]["aspect_ratio"]  # aspect ratio of target blocks
    # --

    # -- OPTIMIZATION
    ema = args["optimization"]["ema"]
    ipe_scale = args["optimization"]["ipe_scale"]  # scheduler scale factor (def: 1.0)
    wd = float(args["optimization"]["weight_decay"])
    final_wd = float(args["optimization"]["final_weight_decay"])
    num_epochs = args["optimization"]["epochs"]
    warmup = args["optimization"]["warmup"]
    start_lr = args["optimization"]["start_lr"]
    lr = args["optimization"]["lr"]
    final_lr = args["optimization"]["final_lr"]

    # -- LOGGING
    folder = args["logging"]["folder"]
    tag = args["logging"]["write_tag"]
    os.makedirs(folder, exist_ok=True)  # ensure output dir exists
    dump = os.path.join(folder, "params-ijepa.yaml")
    with open(dump, "w") as f:
        yaml.dump(args, f)
    # ----------------------------------------------------------------------- #

    try:
        mp.set_start_method("spawn")
    except Exception:
        pass

    # -- init torch distributed backend
    world_size, rank = init_distributed()
    logger.info(f"Initialized (rank/world-size) {rank}/{world_size}")
    if rank > 0:
        logger.setLevel(logging.ERROR)

    # -- log/checkpointing paths
    log_file = os.path.join(folder, f"{tag}_r{rank}.csv")
    save_path = os.path.join(folder, f"{tag}" + "-ep{epoch}.pth.tar")
    latest_path = os.path.join(folder, f"{tag}-latest.pth.tar")
    load_path = None
    if load_model:
        load_path = os.path.join(folder, r_file) if r_file is not None else latest_path

    # -- make csv_logger
    csv_logger = CSVLogger(
        log_file,
        ("%d", "epoch"),
        ("%d", "itr"),
        ("%.5f", "loss"),
        ("%.5f", "mask-A"),
        ("%.5f", "mask-B"),
        ("%d", "time (ms)"),
    )
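    # I-JEPA trains three modules: the context encoder and predictor are
    # updated by gradient descent, while the target encoder starts as a deep
    # copy of the context encoder and is only ever updated as an exponential
    # moving average of its weights (Step 3 in train_step below).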
("%.5f", "mask-A"), ("%.5f", "mask-B"), ("%d", "time (ms)"), ) # -- init model encoder, predictor = init_model( device=device, patch_size=patch_size, crop_size=crop_size, pred_depth=pred_depth, pred_emb_dim=pred_emb_dim, model_name=model_name, ) target_encoder = copy.deepcopy(encoder) # -- make data transforms mask_collator = MBMaskCollator( input_size=crop_size, patch_size=patch_size, pred_mask_scale=pred_mask_scale, enc_mask_scale=enc_mask_scale, aspect_ratio=aspect_ratio, nenc=num_enc_masks, npred=num_pred_masks, allow_overlap=allow_overlap, min_keep=min_keep, ) transform = make_transforms( crop_size=crop_size, crop_scale=crop_scale, gaussian_blur=use_gaussian_blur, horizontal_flip=use_horizontal_flip, color_distortion=use_color_distortion, color_jitter=color_jitter, ) # -- init data-loaders/samplers dataset_type = args["data"]["dataset"] val_split = args["data"]["val_split"] download = args["data"].get("download", False) if dataset_type.lower() == "imagenet": _, unsupervised_loader, unsupervised_sampler = make_imagenet1k( transform=transform, batch_size=batch_size, collator=mask_collator, pin_mem=pin_mem, training=True, num_workers=num_workers, world_size=world_size, rank=rank, root_path=root_path, image_folder=image_folder, copy_data=copy_data, drop_last=True, ) elif dataset_type.lower() == "cifar10": _, unsupervised_loader, unsupervised_sampler = make_cifar10( transform=transform, batch_size=batch_size, collator=mask_collator, pin_mem=pin_mem, num_workers=num_workers, world_size=world_size, rank=rank, root_path=root_path, drop_last=True, train=True, download=download, # pass through ) else: _, unsupervised_loader, unsupervised_sampler = make_imagefolder( transform=transform, batch_size=batch_size, collator=mask_collator, pin_mem=pin_mem, num_workers=num_workers, world_size=world_size, rank=rank, root_path=root_path, image_folder=image_folder, drop_last=True, val_split=val_split, ) ipe = len(unsupervised_loader) # -- init optimizer and scheduler optimizer, scaler, scheduler, wd_scheduler = init_opt( encoder=encoder, predictor=predictor, wd=wd, final_wd=final_wd, start_lr=start_lr, ref_lr=lr, final_lr=final_lr, iterations_per_epoch=ipe, warmup=warmup, num_epochs=num_epochs, ipe_scale=ipe_scale, use_bfloat16=use_bfloat16, ) is_distributed = ( torch.distributed.is_available() and torch.distributed.is_initialized() and world_size > 1 ) if is_distributed: encoder = DistributedDataParallel(encoder, static_graph=True) predictor = DistributedDataParallel(predictor, static_graph=True) target_encoder = DistributedDataParallel(target_encoder) # keep modules unwrapped when not distributed for p in target_encoder.parameters(): p.requires_grad = False # -- momentum schedule momentum_scheduler = ( ema[0] + i * (ema[1] - ema[0]) / (ipe * num_epochs * ipe_scale) for i in range(int(ipe * num_epochs * ipe_scale) + 1) ) start_epoch = 0 # -- load training checkpoint if load_model: encoder, predictor, target_encoder, optimizer, scaler, start_epoch = ( load_checkpoint( device=device, r_path=load_path, encoder=encoder, predictor=predictor, target_encoder=target_encoder, opt=optimizer, scaler=scaler, ) ) for _ in range(start_epoch * ipe): scheduler.step() wd_scheduler.step() next(momentum_scheduler) mask_collator.step() def save_checkpoint(epoch): save_dict = { "encoder": encoder.state_dict(), "predictor": predictor.state_dict(), "target_encoder": target_encoder.state_dict(), "opt": optimizer.state_dict(), "scaler": None if scaler is None else scaler.state_dict(), "epoch": epoch, "loss": loss_meter.avg, 
"batch_size": batch_size, "world_size": world_size, "lr": lr, } if rank == 0: torch.save(save_dict, latest_path) if (epoch + 1) % checkpoint_freq == 0: torch.save(save_dict, save_path.format(epoch=f"{epoch + 1}")) # -- TRAINING LOOP for epoch in range(start_epoch, num_epochs): logger.info("Epoch %d" % (epoch + 1)) # -- update distributed-data-loader epoch unsupervised_sampler.set_epoch(epoch) loss_meter = AverageMeter() maskA_meter = AverageMeter() maskB_meter = AverageMeter() time_meter = AverageMeter() for itr, (udata, masks_enc, masks_pred) in enumerate(unsupervised_loader): def load_imgs(): # -- unsupervised imgs imgs = udata[0].to(device, non_blocking=True) masks_1 = [u.to(device, non_blocking=True) for u in masks_enc] masks_2 = [u.to(device, non_blocking=True) for u in masks_pred] return (imgs, masks_1, masks_2) imgs, masks_enc, masks_pred = load_imgs() maskA_meter.update(len(masks_enc[0][0])) maskB_meter.update(len(masks_pred[0][0])) def train_step(): _new_lr = scheduler.step() _new_wd = wd_scheduler.step() # -- def forward_target(): with torch.no_grad(): h = target_encoder(imgs) h = F.layer_norm(h, (h.size(-1),)) # normalize over feature-dim B = len(h) # -- create targets (masked regions of h) h = apply_masks(h, masks_pred) h = repeat_interleave_batch(h, B, repeat=len(masks_enc)) return h def forward_context(): z = encoder(imgs, masks_enc) z = predictor(z, masks_enc, masks_pred) return z def loss_fn(z, h): loss = F.smooth_l1_loss(z, h) loss = AllReduce.apply(loss) return loss # Step 1. Forward with torch.cuda.amp.autocast( dtype=torch.bfloat16, enabled=use_bfloat16 ): h = forward_target() z = forward_context() loss = loss_fn(z, h) # Step 2. Backward & step if use_bfloat16: scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() else: loss.backward() optimizer.step() enc_for_log = encoder.module if is_distributed else encoder grad_stats = grad_logger(enc_for_log.named_parameters()) optimizer.zero_grad() # Step 3. momentum update of target encoder with torch.no_grad(): m = next(momentum_scheduler) for param_q, param_k in zip( encoder.parameters(), target_encoder.parameters() ): param_k.data.mul_(m).add_((1.0 - m) * param_q.detach().data) return (float(loss), _new_lr, _new_wd, grad_stats) (loss, _new_lr, _new_wd, grad_stats), etime = gpu_timer(train_step) loss_meter.update(loss) time_meter.update(etime) def log_stats(): csv_logger.log( epoch + 1, itr, loss, maskA_meter.val, maskB_meter.val, etime ) if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss): logger.info( "[%d, %5d] loss: %.3f " "masks: %.1f %.1f " "[wd: %.2e] [lr: %.2e] " "[mem: %.2e] " "(%.1f ms)" % ( epoch + 1, itr, loss_meter.avg, maskA_meter.avg, maskB_meter.avg, _new_wd, _new_lr, torch.cuda.max_memory_allocated() / 1024.0**2, time_meter.avg, ) ) if grad_stats is not None: logger.info( "[%d, %5d] grad_stats: [%.2e %.2e] (%.2e, %.2e)" % ( epoch + 1, itr, grad_stats.first_layer, grad_stats.last_layer, grad_stats.min, grad_stats.max, ) ) log_stats() assert not np.isnan(loss), "loss is nan" logger.info("avg. loss %.3f" % loss_meter.avg) save_checkpoint(epoch + 1)