# Source code for world_models.configs.jepa_config

import os
from typing import Tuple, Dict, Any


class JEPAConfig:
    """
    Minimal configuration container for JEPA training.

    Converts to the nested dict expected by `train_jepa.main`.

    Every setting is a plain instance attribute initialised to a default in
    ``__init__``. Defaults can be replaced either by passing keyword
    arguments to the constructor or by assigning to the attribute afterwards;
    ``to_dict`` always reflects the current attribute values.
    """

    # Section name -> attribute names, in the exact order they are emitted by
    # ``to_dict``. This is the single source of truth for the nested layout,
    # so the dict cannot drift out of sync with a hand-written copy.
    _SECTIONS: Dict[str, Tuple[str, ...]] = {
        "meta": (
            "use_bfloat16",
            "model_name",
            "load_checkpoint",
            "read_checkpoint",
            "copy_data",
            "pred_depth",
            "pred_emb_dim",
        ),
        "data": (
            "dataset",
            "val_split",
            "use_gaussian_blur",
            "use_horizontal_flip",
            "use_color_distortion",
            "color_jitter_strength",
            "batch_size",
            "pin_mem",
            "num_workers",
            "root_path",
            "image_folder",
            "crop_size",
            "crop_scale",
            "download",
        ),
        "mask": (
            "allow_overlap",
            "patch_size",
            "num_enc_masks",
            "min_keep",
            "enc_mask_scale",
            "num_pred_masks",
            "pred_mask_scale",
            "aspect_ratio",
        ),
        "optimization": (
            "ema",
            "ipe_scale",
            "weight_decay",
            "final_weight_decay",
            "epochs",
            "warmup",
            "start_lr",
            "lr",
            "final_lr",
        ),
        "logging": (
            "folder",
            "write_tag",
        ),
    }

    def __init__(self, **overrides: Any) -> None:
        """
        Populate every setting with its default, then apply ``overrides``.

        Args:
            **overrides: Optional ``name=value`` pairs replacing the default
                of an existing setting (e.g. ``JEPAConfig(batch_size=128)``).
                Calling with no arguments yields the stock defaults.

        Raises:
            TypeError: If an override names a setting that does not exist.
        """
        # meta
        self.use_bfloat16: bool = False
        self.model_name: str = "vit_base"
        self.load_checkpoint: bool = False
        self.read_checkpoint: str | None = None
        self.copy_data: bool = False
        self.pred_depth: int = 6
        self.pred_emb_dim: int = 384

        # data
        self.dataset: str = "imagenet"  # "imagenet" or "imagefolder"
        # optional fraction for val split when using imagefolder
        self.val_split: float | None = None
        self.use_gaussian_blur: bool = True
        self.use_horizontal_flip: bool = True
        self.use_color_distortion: bool = True
        self.color_jitter_strength: float = 0.5
        self.batch_size: int = 64
        self.pin_mem: bool = True
        self.num_workers: int = 8
        # Read from the environment at construction time; falls back to
        # "/data/imagenet" when IMAGENET_ROOT is not set.
        self.root_path: str = os.environ.get("IMAGENET_ROOT", "/data/imagenet")
        self.image_folder: str = "train"
        self.crop_size: int = 224
        self.crop_scale: Tuple[float, float] = (0.67, 1.0)
        self.download: bool = False  # allow CIFAR10 download if missing

        # mask
        self.allow_overlap: bool = False
        self.patch_size: int = 16
        self.num_enc_masks: int = 1
        self.min_keep: int = 4
        self.enc_mask_scale: Tuple[float, float] = (0.15, 0.2)
        self.num_pred_masks: int = 1
        self.pred_mask_scale: Tuple[float, float] = (0.15, 0.2)
        self.aspect_ratio: Tuple[float, float] = (0.75, 1.5)

        # optimization
        self.ema: Tuple[float, float] = (0.996, 1.0)
        self.ipe_scale: float = 1.0
        self.weight_decay: float = 0.04
        self.final_weight_decay: float = 0.4
        self.epochs: int = 300
        self.warmup: int = 40
        self.start_lr: float = 1e-6
        self.lr: float = 1.5e-4
        self.final_lr: float = 1e-6

        # logging
        self.folder: str = "results/jepa"
        self.write_tag: str = "jepa_run"

        # Apply overrides last so they win over the defaults. Unknown names
        # are rejected instead of silently creating an attribute that
        # to_dict() would never emit (the usual symptom of a typo).
        known = vars(self)
        unknown = sorted(name for name in overrides if name not in known)
        if unknown:
            raise TypeError(
                f"{type(self).__name__}() got unexpected setting(s): "
                + ", ".join(unknown)
            )
        for name, value in overrides.items():
            setattr(self, name, value)

    def to_dict(self) -> Dict[str, Dict[str, Any]]:
        """
        Return the current settings as a nested ``{section: {name: value}}`` dict.

        Sections are ``"meta"``, ``"data"``, ``"mask"``, ``"optimization"``
        and ``"logging"``, in that order. A fresh dict is built on every
        call, so mutating the result does not affect this config object.
        """
        return {
            section: {name: getattr(self, name) for name in names}
            for section, names in self._SECTIONS.items()
        }