"""User-facing convenience APIs for TorchWM.
The lower-level modules remain available for research workflows, but this module
collects the common discovery and construction paths behind small, predictable
factory functions.
"""
from __future__ import annotations
from dataclasses import asdict, is_dataclass, replace
from importlib import import_module
from inspect import signature
from typing import Any, Callable, NamedTuple, cast
[docs]
class ModelSpec(NamedTuple):
"""Metadata describing a model available through :func:`create_model`."""
name: str
import_path: str
config_path: str | None = None
description: str = ""
aliases: tuple[str, ...] = ()
[docs]
class EnvBackendSpec(NamedTuple):
"""Metadata describing an environment backend available through ``make_env``."""
name: str
factory_path: str
description: str = ""
aliases: tuple[str, ...] = ()
MODEL_SPECS: dict[str, ModelSpec] = {
"dreamer": ModelSpec(
name="dreamer",
import_path="world_models.models.dreamer:DreamerAgent",
config_path="world_models.configs.dreamer_config:DreamerConfig",
description="High-level Dreamer agent with train/evaluate helpers.",
aliases=("dreamerv1", "dreamerv2", "dreamer_agent"),
),
"planet": ModelSpec(
name="planet",
import_path="world_models.models.planet:Planet",
description="PlaNet agent and planner for image-based control.",
aliases=("pla_net",),
),
"jepa": ModelSpec(
name="jepa",
import_path="world_models.models.jepa_agent:JEPAAgent",
config_path="world_models.configs.jepa_config:JEPAConfig",
description="JEPA self-supervised visual representation trainer.",
aliases=("ijepa", "i-jepa"),
),
"iris": ModelSpec(
name="iris",
import_path="world_models.models.iris_agent:IRISAgent",
config_path="world_models.configs.iris_config:IRISConfig",
description="IRIS world model and actor-critic module.",
aliases=("iris_agent",),
),
"genie": ModelSpec(
name="genie",
import_path="world_models.models.genie:create_genie",
config_path="world_models.configs.genie_config:GenieConfig",
description="Genie generative interactive environment model.",
aliases=("genie_base",),
),
"genie-small": ModelSpec(
name="genie-small",
import_path="world_models.models.genie:create_genie_small",
config_path="world_models.configs.genie_config:GenieSmallConfig",
description="Smaller Genie variant for development and tests.",
aliases=("genie_small",),
),
"genie-large": ModelSpec(
name="genie-large",
import_path="world_models.models.genie:create_genie_large",
description="Large Genie variant.",
aliases=("genie_large",),
),
"modular-rssm": ModelSpec(
name="modular-rssm",
import_path="world_models.models.modular_rssm:create_modular_rssm",
description="Factory for a modular recurrent state-space model.",
aliases=("modular_rssm", "rssm"),
),
}
ENV_BACKEND_SPECS: dict[str, EnvBackendSpec] = {
"auto": EnvBackendSpec(
name="auto",
factory_path="world_models.envs:make_env",
description="Try TorchWM env factories and fall back to Gym.",
aliases=("default",),
),
"gym": EnvBackendSpec(
name="gym",
factory_path="world_models.envs:make_gym_env",
description="Gym/Gymnasium image environments.",
aliases=("gymnasium",),
),
"atari": EnvBackendSpec(
name="atari",
factory_path="world_models.envs:make_atari_env",
description="Atari environments through ALE.",
aliases=("ale",),
),
"mujoco": EnvBackendSpec(
name="mujoco",
factory_path="world_models.envs:make_mujoco_env",
description="MuJoCo physics environments.",
aliases=("mjcf", "native_mujoco"),
),
"robotics": EnvBackendSpec(
name="robotics",
factory_path="world_models.envs:make_robotics_env",
description="Gymnasium Robotics environments.",
aliases=("gymnasium_robotics",),
),
"brax": EnvBackendSpec(
name="brax",
factory_path="world_models.envs:make_brax_env",
description="JAX/Brax continuous-control environments.",
aliases=(),
),
"unity": EnvBackendSpec(
name="unity",
factory_path="world_models.envs:make_unity_mlagents_env",
description="Unity ML-Agents executables.",
aliases=("mlagents", "unity_mlagents"),
),
}
def _normalize(name: str) -> str:
return name.strip().lower().replace("_", "-")
def _alias_map(specs: dict[str, ModelSpec] | dict[str, EnvBackendSpec]) -> dict[str, str]:
aliases: dict[str, str] = {}
for canonical, spec in specs.items():
aliases[_normalize(canonical)] = canonical
for alias in spec.aliases:
aliases[_normalize(alias)] = canonical
return aliases
def _resolve_model_name(name: str) -> str:
aliases = _alias_map(MODEL_SPECS)
try:
return aliases[_normalize(name)]
except KeyError as exc:
available = ", ".join(list_models())
raise ValueError(f"Unknown model {name!r}. Available models: {available}") from exc
def _resolve_backend_name(name: str) -> str:
aliases = _alias_map(ENV_BACKEND_SPECS)
try:
return aliases[_normalize(name)]
except KeyError as exc:
available = ", ".join(list_env_backends())
raise ValueError(f"Unknown environment backend {name!r}. Available: {available}") from exc
def _load_object(import_path: str) -> Any:
module_name, attr = import_path.split(":", maxsplit=1)
module = import_module(module_name)
return getattr(module, attr)
def _config_to_dict(config: Any) -> dict[str, Any]:
if config is None:
return {}
if isinstance(config, dict):
return dict(config)
if is_dataclass(config):
return asdict(cast(Any, config))
return {
key: value
for key, value in vars(config).items()
if not key.startswith("_") and not callable(value)
}
def _config_fields(config: Any) -> set[str]:
if config is None:
return set()
if isinstance(config, dict):
return set(config)
if is_dataclass(config):
return set(config.__dataclass_fields__) # type: ignore[attr-defined]
return {key for key in vars(config) if not key.startswith("_")}
def _split_config_and_constructor_overrides(
config: Any, overrides: dict[str, Any]
) -> tuple[dict[str, Any], dict[str, Any]]:
config_keys = _config_fields(config)
config_overrides = {key: value for key, value in overrides.items() if key in config_keys}
constructor_overrides = {
key: value for key, value in overrides.items() if key not in config_keys
}
return config_overrides, constructor_overrides
def _apply_overrides(config: Any, overrides: dict[str, Any]) -> Any:
if not overrides:
return config
if isinstance(config, dict):
updated = dict(config)
updated.update(overrides)
return updated
if is_dataclass(config):
valid = set(config.__dataclass_fields__) # type: ignore[attr-defined]
invalid = sorted(set(overrides) - valid)
if invalid:
raise ValueError(f"Invalid config override(s): {', '.join(invalid)}")
return replace(cast(Any, config), **overrides)
for key, value in overrides.items():
if not hasattr(config, key):
raise ValueError(f"Invalid config override: {key}")
setattr(config, key, value)
return config
def _supported_kwargs(factory: Callable[..., Any], kwargs: dict[str, Any]) -> dict[str, Any]:
params = signature(factory).parameters
if any(param.kind == param.VAR_KEYWORD for param in params.values()):
return dict(kwargs)
supported = {name for name in params if name != "self"}
return {key: value for key, value in kwargs.items() if key in supported}
def _call_with_supported_kwargs(factory: Callable[..., Any], kwargs: dict[str, Any]) -> Any:
params = signature(factory).parameters
accepts_kwargs = any(param.kind == param.VAR_KEYWORD for param in params.values())
if accepts_kwargs:
return factory(**kwargs)
supported = {name for name in params if name != "self"}
filtered = {key: value for key, value in kwargs.items() if key in supported}
ignored = sorted(set(kwargs) - supported)
if ignored:
ignored_list = ", ".join(ignored)
raise ValueError(
f"Unsupported argument(s) for {factory.__name__}: {ignored_list}"
)
return factory(**filtered)
[docs]
def list_models() -> list[str]:
"""Return canonical model names accepted by :func:`create_model`."""
return sorted(MODEL_SPECS)
[docs]
def get_model_spec(name: str) -> ModelSpec:
"""Return metadata for a model name or alias."""
return MODEL_SPECS[_resolve_model_name(name)]
[docs]
def list_env_backends() -> list[str]:
"""Return canonical backend names accepted by :func:`make_env`."""
return sorted(ENV_BACKEND_SPECS)
[docs]
def get_env_backend_spec(name: str) -> EnvBackendSpec:
"""Return metadata for an environment backend name or alias."""
return ENV_BACKEND_SPECS[_resolve_backend_name(name)]
[docs]
def create_config(model: str, **overrides: Any) -> Any:
"""Create the default config object for ``model`` and apply overrides.
Examples:
>>> cfg = create_config("dreamer", env="walker-walk", seed=7)
>>> cfg.env
'walker-walk'
"""
spec = get_model_spec(model)
if spec.config_path is None:
if overrides:
raise ValueError(f"Model {spec.name!r} does not define a config object")
return None
config_cls = _load_object(spec.config_path)
config = config_cls()
return _apply_overrides(config, overrides)
[docs]
def create_model(model: str, config: Any | None = None, **overrides: Any) -> Any:
"""Instantiate a model or agent from a simple string name.
``config`` is optional for models that define a config class. Keyword
overrides are applied to the config when possible, otherwise they are passed
directly to the underlying constructor/factory.
Examples:
>>> agent = create_model("dreamer", env="walker-walk", total_steps=1000)
>>> genie = create_model("genie-small", image_size=32)
"""
spec = get_model_spec(model)
factory = _load_object(spec.import_path)
if spec.config_path is not None:
if config is None:
config = create_config(spec.name)
config_overrides, constructor_overrides = _split_config_and_constructor_overrides(
config, overrides
)
config = _apply_overrides(config, config_overrides)
if spec.name in {"genie", "genie-small", "genie-large"}:
kwargs = _supported_kwargs(factory, _config_to_dict(config))
kwargs.update(constructor_overrides)
return _call_with_supported_kwargs(factory, kwargs)
return factory(config, **constructor_overrides)
kwargs = _config_to_dict(config)
kwargs.update(overrides)
return _call_with_supported_kwargs(factory, kwargs)
[docs]
def make_env(env_id: str, backend: str = "auto", **kwargs: Any) -> Any:
"""Create an environment with a consistent TorchWM entry point.
Args:
env_id: Environment id, XML path, Unity executable path, or backend-specific id.
backend: One of :func:`list_env_backends`; ``"auto"`` tries TorchWM's
compatibility helper.
**kwargs: Backend-specific options.
"""
spec = get_env_backend_spec(backend)
factory = _load_object(spec.factory_path)
if spec.name == "auto":
kwargs.setdefault("backend", backend)
return factory(env_id, **kwargs)
[docs]
def list_envs(model: str | None = None) -> list[str] | dict[str, list[str]]:
"""List known environment ids, optionally filtered by model family."""
from world_models.catalog import ENVIRONMENTS_BY_MODEL
if model is None:
return {key: list(value) for key, value in ENVIRONMENTS_BY_MODEL.items()}
canonical = _normalize(model).replace("-", "")
try:
return list(ENVIRONMENTS_BY_MODEL[canonical])
except KeyError as exc:
available = ", ".join(sorted(ENVIRONMENTS_BY_MODEL))
raise ValueError(f"Unknown model environment catalog {model!r}: {available}") from exc
__all__ = [
"EnvBackendSpec",
"ModelSpec",
"MODEL_SPECS",
"ENV_BACKEND_SPECS",
"create_config",
"create_model",
"get_env_backend_spec",
"get_model_spec",
"list_env_backends",
"list_envs",
"list_models",
"make_env",
]