Source code for world_models.api

"""User-facing convenience APIs for TorchWM.

The lower-level modules remain available for research workflows, but this module
collects the common discovery and construction paths behind small, predictable
factory functions.
"""

from __future__ import annotations

from dataclasses import asdict, is_dataclass, replace
from importlib import import_module
from inspect import signature
from typing import Any, Callable, NamedTuple, cast


class ModelSpec(NamedTuple):
    """Registry entry for one model that :func:`create_model` can build.

    Fields:
        name: Canonical registry key for the model.
        import_path: ``"module:attr"`` location of the agent class or factory.
        config_path: ``"module:attr"`` location of the default config class, or
            ``None`` when the model has no config object.
        description: Short human-readable summary.
        aliases: Alternative spellings accepted in place of ``name``.
    """

    name: str
    import_path: str
    config_path: str | None = None
    description: str = ""
    aliases: tuple[str, ...] = ()
class EnvBackendSpec(NamedTuple):
    """Registry entry for one environment backend reachable through ``make_env``.

    Fields:
        name: Canonical backend key.
        factory_path: ``"module:attr"`` location of the environment factory,
            which ``make_env`` calls as ``factory(env_id, **kwargs)``.
        description: Short human-readable summary.
        aliases: Alternative spellings accepted in place of ``name``.
    """

    name: str
    factory_path: str
    description: str = ""
    aliases: tuple[str, ...] = ()
# Models exposed through :func:`create_model`, keyed by canonical name.
# Alternative spellings are resolved through each spec's ``aliases``.
MODEL_SPECS: dict[str, ModelSpec] = {
    "dreamer": ModelSpec(
        name="dreamer",
        import_path="world_models.models.dreamer:DreamerAgent",
        config_path="world_models.configs.dreamer_config:DreamerConfig",
        description="High-level Dreamer agent with train/evaluate helpers.",
        aliases=("dreamerv1", "dreamerv2", "dreamer_agent"),
    ),
    "planet": ModelSpec(
        name="planet",
        import_path="world_models.models.planet:Planet",
        description="PlaNet agent and planner for image-based control.",
        aliases=("pla_net",),
    ),
    "jepa": ModelSpec(
        name="jepa",
        import_path="world_models.models.jepa_agent:JEPAAgent",
        config_path="world_models.configs.jepa_config:JEPAConfig",
        description="JEPA self-supervised visual representation trainer.",
        aliases=("ijepa", "i-jepa"),
    ),
    "iris": ModelSpec(
        name="iris",
        import_path="world_models.models.iris_agent:IRISAgent",
        config_path="world_models.configs.iris_config:IRISConfig",
        description="IRIS world model and actor-critic module.",
        aliases=("iris_agent",),
    ),
    "genie": ModelSpec(
        name="genie",
        import_path="world_models.models.genie:create_genie",
        config_path="world_models.configs.genie_config:GenieConfig",
        description="Genie generative interactive environment model.",
        aliases=("genie_base",),
    ),
    "genie-small": ModelSpec(
        name="genie-small",
        import_path="world_models.models.genie:create_genie_small",
        config_path="world_models.configs.genie_config:GenieSmallConfig",
        description="Smaller Genie variant for development and tests.",
        aliases=("genie_small",),
    ),
    "genie-large": ModelSpec(
        name="genie-large",
        import_path="world_models.models.genie:create_genie_large",
        description="Large Genie variant.",
        aliases=("genie_large",),
    ),
    "modular-rssm": ModelSpec(
        name="modular-rssm",
        import_path="world_models.models.modular_rssm:create_modular_rssm",
        description="Factory for a modular recurrent state-space model.",
        aliases=("modular_rssm", "rssm"),
    ),
}

# Environment backends exposed through ``make_env``, keyed by canonical name.
ENV_BACKEND_SPECS: dict[str, EnvBackendSpec] = {
    "auto": EnvBackendSpec(
        name="auto",
        factory_path="world_models.envs:make_env",
        description="Try TorchWM env factories and fall back to Gym.",
        aliases=("default",),
    ),
    "gym": EnvBackendSpec(
        name="gym",
        factory_path="world_models.envs:make_gym_env",
        description="Gym/Gymnasium image environments.",
        aliases=("gymnasium",),
    ),
    "atari": EnvBackendSpec(
        name="atari",
        factory_path="world_models.envs:make_atari_env",
        description="Atari environments through ALE.",
        aliases=("ale",),
    ),
    "mujoco": EnvBackendSpec(
        name="mujoco",
        factory_path="world_models.envs:make_mujoco_env",
        description="MuJoCo physics environments.",
        aliases=("mjcf", "native_mujoco"),
    ),
    "robotics": EnvBackendSpec(
        name="robotics",
        factory_path="world_models.envs:make_robotics_env",
        description="Gymnasium Robotics environments.",
        aliases=("gymnasium_robotics",),
    ),
    "brax": EnvBackendSpec(
        name="brax",
        factory_path="world_models.envs:make_brax_env",
        description="JAX/Brax continuous-control environments.",
        aliases=(),
    ),
    "unity": EnvBackendSpec(
        name="unity",
        factory_path="world_models.envs:make_unity_mlagents_env",
        description="Unity ML-Agents executables.",
        aliases=("mlagents", "unity_mlagents"),
    ),
}


def _normalize(name: str) -> str:
    """Canonicalize a user-supplied name: trim, lowercase, and map ``_`` to ``-``."""
    return name.strip().lower().replace("_", "-")


def _alias_map(specs: dict[str, ModelSpec] | dict[str, EnvBackendSpec]) -> dict[str, str]:
    """Map every normalized canonical name and alias in ``specs`` to its canonical key.

    The map is rebuilt on every call, so later changes to ``specs`` are always
    reflected. If two entries share a normalized alias, the one visited last wins.
    """
    aliases: dict[str, str] = {}
    for canonical, spec in specs.items():
        aliases[_normalize(canonical)] = canonical
        for alias in spec.aliases:
            aliases[_normalize(alias)] = canonical
    return aliases


def _resolve_model_name(name: str) -> str:
    """Return the canonical ``MODEL_SPECS`` key for a model name or alias.

    Raises:
        ValueError: If ``name`` matches no registered model; the message lists
            the canonical names.
    """
    aliases = _alias_map(MODEL_SPECS)
    try:
        return aliases[_normalize(name)]
    except KeyError as exc:
        available = ", ".join(list_models())
        raise ValueError(f"Unknown model {name!r}. Available models: {available}") from exc


def _resolve_backend_name(name: str) -> str:
    """Return the canonical ``ENV_BACKEND_SPECS`` key for a backend name or alias.

    Raises:
        ValueError: If ``name`` matches no registered backend; the message lists
            the canonical names.
    """
    aliases = _alias_map(ENV_BACKEND_SPECS)
    try:
        return aliases[_normalize(name)]
    except KeyError as exc:
        available = ", ".join(list_env_backends())
        raise ValueError(f"Unknown environment backend {name!r}. Available: {available}") from exc


def _load_object(import_path: str) -> Any:
    """Import and return the attribute named by a ``"package.module:attr"`` path."""
    module_name, attr = import_path.split(":", maxsplit=1)
    module = import_module(module_name)
    return getattr(module, attr)


def _config_to_dict(config: Any) -> dict[str, Any]:
    """Flatten a config into a plain ``dict`` of field values.

    * ``None`` gives ``{}``.
    * A ``dict`` gives a shallow copy.
    * A dataclass instance goes through :func:`dataclasses.asdict`, which recurses
      into nested dataclasses and copies values.
    * Any other object gives its public (non-underscore), non-callable instance
      attributes from ``vars()``.
    """
    if config is None:
        return {}
    if isinstance(config, dict):
        return dict(config)
    if is_dataclass(config):
        return asdict(cast(Any, config))
    return {
        key: value
        for key, value in vars(config).items()
        if not key.startswith("_") and not callable(value)
    }


def _config_fields(config: Any) -> set[str]:
    """Return the names that count as config fields for override routing.

    NOTE(review): unlike :func:`_config_to_dict`, callable instance attributes
    are included here for plain objects.
    """
    if config is None:
        return set()
    if isinstance(config, dict):
        return set(config)
    if is_dataclass(config):
        return set(config.__dataclass_fields__)  # type: ignore[attr-defined]
    return {key for key in vars(config) if not key.startswith("_")}


def _split_config_and_constructor_overrides(
    config: Any, overrides: dict[str, Any]
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Partition ``overrides`` by whether each key is a field of ``config``.

    Returns:
        ``(config_overrides, constructor_overrides)``. The first holds keys found
        by :func:`_config_fields`; the second holds everything else.
    """
    config_keys = _config_fields(config)
    config_overrides = {key: value for key, value in overrides.items() if key in config_keys}
    constructor_overrides = {
        key: value for key, value in overrides.items() if key not in config_keys
    }
    return config_overrides, constructor_overrides


def _apply_overrides(config: Any, overrides: dict[str, Any]) -> Any:
    """Return ``config`` with ``overrides`` applied.

    * ``dict``: an updated shallow copy is returned; any key is accepted.
    * Dataclass: a new instance is built with :func:`dataclasses.replace`.
    * Other objects: attributes are set in place and the same object is
      returned; every key must already exist as an attribute.

    Raises:
        ValueError: If an override names a field or attribute the config lacks.
            For plain objects, no attribute is modified when this is raised.
    """
    if not overrides:
        return config
    if isinstance(config, dict):
        updated = dict(config)
        updated.update(overrides)
        return updated
    if is_dataclass(config):
        valid = set(config.__dataclass_fields__)  # type: ignore[attr-defined]
        invalid = sorted(set(overrides) - valid)
        if invalid:
            raise ValueError(f"Invalid config override(s): {', '.join(invalid)}")
        return replace(cast(Any, config), **overrides)
    # Validate every key before mutating anything, so one bad key cannot leave
    # the caller's config object half-updated.
    unknown = [key for key in overrides if not hasattr(config, key)]
    if unknown:
        raise ValueError(f"Invalid config override: {', '.join(unknown)}")
    for key, value in overrides.items():
        setattr(config, key, value)
    return config


def _supported_kwargs(factory: Callable[..., Any], kwargs: dict[str, Any]) -> dict[str, Any]:
    """Return the subset of ``kwargs`` that ``factory`` can accept by keyword.

    If ``factory`` declares ``**kwargs``, everything is kept. Otherwise only names
    of positional-or-keyword and keyword-only parameters (other than ``self``)
    are kept. Positional-only and ``*args`` parameters are skipped because
    passing them by name raises ``TypeError``.
    """
    params = signature(factory).parameters.values()
    if any(param.kind is param.VAR_KEYWORD for param in params):
        return dict(kwargs)
    supported = {
        param.name
        for param in params
        if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
        and param.name != "self"
    }
    return {key: value for key, value in kwargs.items() if key in supported}


def _call_with_supported_kwargs(factory: Callable[..., Any], kwargs: dict[str, Any]) -> Any:
    """Call ``factory(**kwargs)`` after checking every key is accepted.

    Raises:
        ValueError: If any key in ``kwargs`` cannot be passed to ``factory`` by
            keyword (see :func:`_supported_kwargs`).
    """
    accepted = _supported_kwargs(factory, kwargs)
    rejected = sorted(set(kwargs) - set(accepted))
    if rejected:
        # functools.partial objects and callable instances have no __name__.
        factory_name = getattr(factory, "__name__", repr(factory))
        raise ValueError(
            f"Unsupported argument(s) for {factory_name}: {', '.join(rejected)}"
        )
    return factory(**accepted)
def list_models() -> list[str]:
    """Return every canonical model name registered for :func:`create_model`, sorted.

    Aliases are not listed, but they are accepted wherever a model name is.
    """
    names = list(MODEL_SPECS.keys())
    names.sort()
    return names
def get_model_spec(name: str) -> ModelSpec:
    """Look up the :class:`ModelSpec` for a canonical model name or any alias.

    Raises:
        ValueError: If ``name`` does not match any registered model.
    """
    canonical_name = _resolve_model_name(name)
    return MODEL_SPECS[canonical_name]
def list_env_backends() -> list[str]:
    """Return every canonical backend name registered for :func:`make_env`, sorted.

    Aliases are not listed, but they are accepted wherever a backend name is.
    """
    backends = list(ENV_BACKEND_SPECS.keys())
    backends.sort()
    return backends
def get_env_backend_spec(name: str) -> EnvBackendSpec:
    """Look up the :class:`EnvBackendSpec` for a canonical backend name or alias.

    Raises:
        ValueError: If ``name`` does not match any registered backend.
    """
    canonical_backend = _resolve_backend_name(name)
    return ENV_BACKEND_SPECS[canonical_backend]
def create_config(model: str, **overrides: Any) -> Any:
    """Instantiate ``model``'s default config class, then apply keyword overrides.

    Args:
        model: Canonical model name or alias (see :func:`list_models`).
        **overrides: Field values to set on the freshly created config.

    Returns:
        The config with overrides applied, or ``None`` when the model has no
        registered config class and no overrides were given.

    Raises:
        ValueError: If ``model`` is unknown, if overrides are passed for a model
            without a config class, or (for dataclass and attribute-style
            configs) if an override does not name an existing field.

    Examples:
        >>> cfg = create_config("dreamer", env="walker-walk", seed=7)
        >>> cfg.env
        'walker-walk'
    """
    spec = get_model_spec(model)
    config_location = spec.config_path
    if config_location is not None:
        default_config = _load_object(config_location)()
        return _apply_overrides(default_config, overrides)
    if not overrides:
        return None
    raise ValueError(f"Model {spec.name!r} does not define a config object")
def create_model(model: str, config: Any | None = None, **overrides: Any) -> Any:
    """Instantiate a model or agent from its registry name.

    Models with a registered config class:
        ``config`` defaults to ``create_config(model)``. Overrides whose key is a
        field of that config are applied to it; the rest are passed to the
        constructor/factory as keyword arguments. Genie variants receive the
        config flattened into the keyword arguments their factory accepts. Every
        other such model is called as ``factory(config, **remaining_overrides)``.

    Models without a config class:
        ``config`` (if given) is flattened to a dict, merged with ``overrides``,
        and passed as keyword arguments.

    Raises:
        ValueError: If ``model`` is unknown, or (Genie variants and models
            without a config class) if a keyword argument is not accepted by
            the underlying factory.

    Examples:
        >>> agent = create_model("dreamer", env="walker-walk", total_steps=1000)
        >>> genie = create_model("genie-small", image_size=32)
    """
    spec = get_model_spec(model)
    factory = _load_object(spec.import_path)

    if spec.config_path is None:
        # No config class: everything becomes a keyword argument of the factory.
        call_kwargs = {**_config_to_dict(config), **overrides}
        return _call_with_supported_kwargs(factory, call_kwargs)

    base_config = create_config(spec.name) if config is None else config
    config_overrides, constructor_overrides = _split_config_and_constructor_overrides(
        base_config, overrides
    )
    effective_config = _apply_overrides(base_config, config_overrides)

    if spec.name not in {"genie", "genie-small", "genie-large"}:
        return factory(effective_config, **constructor_overrides)

    # Genie factories take flat keyword arguments: keep only the config fields
    # the factory accepts, then layer the non-config overrides on top.
    genie_kwargs = _supported_kwargs(factory, _config_to_dict(effective_config))
    genie_kwargs.update(constructor_overrides)
    return _call_with_supported_kwargs(factory, genie_kwargs)
def make_env(env_id: str, backend: str = "auto", **kwargs: Any) -> Any:
    """Create an environment with a consistent TorchWM entry point.

    Args:
        env_id: Environment id, XML path, Unity executable path, or
            backend-specific id.
        backend: One of :func:`list_env_backends` or one of their aliases;
            ``"auto"`` tries TorchWM's compatibility helper.
        **kwargs: Backend-specific options, forwarded to the backend factory.

    Returns:
        Whatever the backend factory returns for ``factory(env_id, **kwargs)``.

    Raises:
        ValueError: If ``backend`` is not a registered backend name or alias.
    """
    spec = get_env_backend_spec(backend)
    factory = _load_object(spec.factory_path)
    if spec.name == "auto":
        # Forward the canonical name rather than the caller's raw spelling.
        # Aliases such as "default" (or " AUTO ") resolve to "auto" here, but the
        # underlying factory is not guaranteed to understand them.
        kwargs.setdefault("backend", spec.name)
    return factory(env_id, **kwargs)
def list_envs(model: str | None = None) -> list[str] | dict[str, list[str]]:
    """List known environment ids, optionally restricted to one model family.

    Args:
        model: Model family to look up. It is normalized (trimmed, lowercased)
            and stripped of ``-``/``_`` before the catalog lookup. When ``None``,
            the whole catalog is returned.

    Returns:
        A list of environment ids for ``model``, or a fresh ``{family: [ids]}``
        dict when ``model`` is ``None``.

    Raises:
        ValueError: If ``model`` has no catalog entry; the message lists the
            known catalog keys.
    """
    from world_models.catalog import ENVIRONMENTS_BY_MODEL

    if model is None:
        return {family: list(env_ids) for family, env_ids in ENVIRONMENTS_BY_MODEL.items()}

    catalog_key = _normalize(model).replace("-", "")
    try:
        return list(ENVIRONMENTS_BY_MODEL[catalog_key])
    except KeyError as exc:
        known = ", ".join(sorted(ENVIRONMENTS_BY_MODEL))
        raise ValueError(f"Unknown model environment catalog {model!r}: {known}") from exc
# Public names exported by ``from world_models.api import *``.
__all__ = [
    "EnvBackendSpec",
    "ModelSpec",
    "MODEL_SPECS",
    "ENV_BACKEND_SPECS",
    "create_config",
    "create_model",
    "get_env_backend_spec",
    "get_model_spec",
    "list_env_backends",
    "list_envs",
    "list_models",
    "make_env",
]