Source code for world_models.models.jepa_agent

import os
import datetime
from world_models.configs.jepa_config import JEPAConfig
from world_models.training.train_jepa import main as train_jepa_main


[docs] class JEPAAgent: """Convenience interface for configuring and launching JEPA training runs. Accepts a `JEPAConfig` plus keyword overrides, prepares output folders, and delegates execution to the JEPA training entrypoint. """ def __init__(self, config: JEPAConfig | None = None, **kwargs): self.cfg = config if config is not None else JEPAConfig() for key, val in kwargs.items(): if key == "logdir": self.cfg.folder = val elif hasattr(self.cfg, key): setattr(self.cfg, key, val) else: raise ValueError(f"Invalid argument: {key}") if not getattr(self.cfg, "write_tag", None): ts = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") self.cfg.write_tag = f"jepa_{ts}" os.makedirs(self.cfg.folder, exist_ok=True)
[docs] def train(self): train_jepa_main(self.cfg)