from __future__ import annotations
import datetime
import os
from pathlib import Path
from typing import Any
from world_models.configs.jepa_config import JEPAConfig
from world_models.models.model_io import (
apply_config_overrides,
coerce_config,
resolve_pretrained_file,
)
from world_models.training.train_jepa import main as train_jepa_main
from world_models.export import ExportableAgentMixin
[docs]
class JEPAAgent(ExportableAgentMixin):
"""Convenience interface for configuring and launching JEPA training runs.
Accepts a `JEPAConfig` plus keyword overrides, prepares output folders,
and delegates execution to the JEPA training entrypoint.
"""
def __init__(self, config: JEPAConfig | None = None, **kwargs: Any) -> None:
self.cfg = coerce_config(JEPAConfig, config)
for key, val in kwargs.items():
if key == "logdir":
self.cfg.folder = val
elif hasattr(self.cfg, key):
setattr(self.cfg, key, val)
else:
raise ValueError(f"Invalid argument: {key}")
if not getattr(self.cfg, "write_tag", None):
ts = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
self.cfg.write_tag = f"jepa_{ts}"
os.makedirs(self.cfg.folder, exist_ok=True)
self.cfg.to_yaml(Path(self.cfg.folder) / "config.yaml")
[docs]
@classmethod
def from_config(
cls,
config: JEPAConfig | dict[str, Any] | str | Path | None = None,
**overrides: Any,
) -> "JEPAAgent":
"""Build a JEPA agent from a config object, dict, YAML file, or YAML string."""
return cls(apply_config_overrides(coerce_config(JEPAConfig, config), overrides))
[docs]
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str | Path,
*,
config: JEPAConfig | dict[str, Any] | str | Path | None = None,
checkpoint_filename: str | None = None,
config_filename: str = "config.yaml",
repo_type: str | None = None,
revision: str | None = None,
**overrides: Any,
) -> "JEPAAgent":
"""Create a JEPA agent from local/HF Hub config and checkpoint metadata."""
if config is None:
config_path = resolve_pretrained_file(
pretrained_model_name_or_path,
(config_filename, "jepa_config.yaml", "config.yml"),
repo_type=repo_type,
revision=revision,
)
if config_path is None:
raise FileNotFoundError(
"No config was provided and no config YAML was found beside "
f"{pretrained_model_name_or_path!r}."
)
args = JEPAConfig.from_yaml(config_path)
else:
args = coerce_config(JEPAConfig, config)
checkpoint_candidates = (
(checkpoint_filename,)
if checkpoint_filename is not None
else ("jepa-latest.pth.tar", "checkpoint.pth.tar", "model.pt")
)
checkpoint_path = resolve_pretrained_file(
pretrained_model_name_or_path,
checkpoint_candidates,
repo_type=repo_type,
revision=revision,
)
if checkpoint_path is not None:
args.load_checkpoint = True
args.read_checkpoint = str(checkpoint_path)
return cls(apply_config_overrides(args, overrides))
[docs]
def parameter_count(self, trainable_only: bool = False) -> int:
"""JEPA models are constructed inside training, so no parameters are resident."""
return 0
[docs]
def summary(self) -> dict[str, Any]:
"""Return the configured JEPA run metadata."""
return {
"total_parameters": 0,
"trainable_parameters": 0,
"model_name": self.cfg.model_name,
"config": self.cfg.to_dict(),
}
[docs]
def train(self) -> None:
train_jepa_main(self.cfg)