Source code for world_models.envs.robotics_env
from __future__ import annotations
import importlib
import importlib.util
import sys
from typing import Any
import gymnasium as gym
from world_models.envs.gym_env import GymImageEnv
# Import name of the optional Gymnasium Robotics distribution; it is also the
# module prefix used to recognise Robotics entry points in Gymnasium's registry.
_GYMNASIUM_ROBOTICS_PACKAGE: str = "gymnasium_robotics"
# Marker searched for in exception text to detect Gymnasium's "MuJoCo v2/v3
# tasks moved to gymnasium-robotics" error. Keep it lower-case: the exception
# message is lower-cased before the substring check.
_MOVED_MUJOCO_MESSAGE: str = "gymnasium-robotics"
def _registry_ids() -> set[str]:
    """Snapshot every environment id currently in Gymnasium's registry, as strings."""
    # Iterating the registry mapping yields its keys (the environment ids).
    return set(map(str, gym.envs.registry))
def _robotics_ids_from_registry() -> set[str]:
    """Return the ids of registered Gymnasium specs provided by Gymnasium Robotics.

    A spec counts as a Robotics spec when its entry point lives in the
    ``gymnasium_robotics`` package itself or in one of its submodules.
    String entry points use Gymnasium's ``"module.path:attribute"`` form, so
    only the module part (before ``":"``) is compared. Callable entry points
    are matched on their ``__module__`` attribute.

    Returns:
        The matching environment ids, converted to ``str``.
    """
    return {
        str(env_id)
        for env_id, spec in gym.envs.registry.items()
        if _is_robotics_module(
            _entry_point_module(getattr(spec, "entry_point", ""))
        )
    }


def _entry_point_module(entry_point: Any) -> str:
    """Return the module name an entry point refers to, or "" when unknown."""
    if isinstance(entry_point, str):
        # "package.module:Attr" -> "package.module"; a bare module path is kept.
        return entry_point.partition(":")[0]
    # Callables (or None) may lack __module__ or carry a non-string value such
    # as None; treat those as "unknown" instead of crashing on .startswith.
    module_name = getattr(entry_point, "__module__", None)
    return module_name if isinstance(module_name, str) else ""


def _is_robotics_module(module_name: str) -> bool:
    """Return whether *module_name* is the Robotics package or one of its submodules."""
    return module_name == _GYMNASIUM_ROBOTICS_PACKAGE or module_name.startswith(
        f"{_GYMNASIUM_ROBOTICS_PACKAGE}."
    )
[docs]
def is_moved_mujoco_error(exc: BaseException) -> bool:
    """Tell whether *exc* looks like Gymnasium's "MuJoCo v2/v3 tasks moved" error.

    The check is a case-insensitive substring search for the
    ``gymnasium-robotics`` marker in the exception text. Any exception whose
    message mentions that name is therefore reported as a match.
    """
    lowered_message = str(exc).lower()
    return lowered_message.find(_MOVED_MUJOCO_MESSAGE) >= 0
[docs]
def register_gymnasium_robotics_envs():
    """Import Gymnasium Robotics so its environments are registered with Gymnasium.

    Gymnasium moved legacy MuJoCo v2/v3 task registrations into the external
    ``gymnasium-robotics`` package. Current Gymnasium Robotics versions register
    environments during import, while older plugin-style installations may rely
    on ``gymnasium.register_envs``; this helper supports both paths.

    Returns:
        The ``gymnasium_robotics`` module object. It is reused from
        ``sys.modules`` when the package has already been imported.

    Raises:
        ImportError: If the package cannot be located and has not been
            imported yet. Errors raised while importing the package propagate
            unchanged.
    """
    package_name = _GYMNASIUM_ROBOTICS_PACKAGE
    try:
        package_spec = importlib.util.find_spec(package_name)
    except ValueError:
        # find_spec raises ValueError for an already-imported module whose
        # __spec__ is None; fall back to the sys.modules check below.
        package_spec = None

    already_imported = package_name in sys.modules
    if package_spec is None and not already_imported:
        raise ImportError(
            "Gymnasium Robotics is required for env_backend='robotics' and "
            "Gymnasium MuJoCo v2/v3 task ids. Install it with "
            "`pip install gymnasium-robotics` or `pip install torchwm[robotics]`."
        )

    # Snapshot the registry so we can tell whether importing added any specs.
    ids_before_import = _registry_ids()
    if already_imported:
        module = sys.modules[package_name]
    else:
        module = importlib.import_module(package_name)

    # Robotics specs are visible, or the import changed the registry: calling
    # gym.register_envs on top could cause duplicate-registration errors.
    if _robotics_ids_from_registry() or ids_before_import != _registry_ids():
        return module

    # Plugin-style fallback for third-party packages that expose registration
    # through gym.register_envs(package); only used when it exists.
    register_envs = getattr(gym, "register_envs", None)
    if callable(register_envs):
        register_envs(module)
    return module
[docs]
def list_gymnasium_robotics_envs() -> list[str]:
    """List all Gymnasium Robotics ids registered by the installed package.

    The ids are read from Gymnasium's registry after importing Gymnasium
    Robotics, rather than from a hand-maintained subset, so newly added
    Robotics environments show up automatically.

    Returns:
        The Robotics environment ids in sorted order. The list is empty when
        registering Gymnasium Robotics raises ``ImportError``, for example
        because the optional dependency is not installed.
    """
    try:
        register_gymnasium_robotics_envs()
    except ImportError:
        # Optional dependency unavailable: report "no Robotics environments".
        return []
    else:
        return sorted(_robotics_ids_from_registry())
[docs]
def make_gymnasium_env_with_robotics_fallback(
    env: str,
    *,
    render_mode: str = "rgb_array",
    gym_kwargs: dict[str, Any] | None = None,
    **kwargs,
):
    """Create a Gymnasium env and retry after Robotics registration if needed.

    ``gym_kwargs`` and ``**kwargs`` are merged (``**kwargs`` wins on key
    collisions) and forwarded to ``gymnasium.make``. Each creation attempt
    first passes ``render_mode``. If that raises ``TypeError``, it is retried
    once without ``render_mode``.

    If an attempt raises the ImportError that ``is_moved_mujoco_error``
    recognises, Gymnasium Robotics is imported to register its environments
    and creation is attempted again.

    Args:
        env: Environment id; converted with ``str``.
        render_mode: Render mode passed to ``gymnasium.make`` when accepted.
        gym_kwargs: Optional keyword arguments forwarded to ``gymnasium.make``.
        **kwargs: Extra keyword arguments forwarded to ``gymnasium.make``.

    Returns:
        The environment returned by ``gymnasium.make``.

    Raises:
        ImportError: If creation fails with an ImportError that is not the
            moved-MuJoCo error, or if the Robotics fallback is needed but
            Gymnasium Robotics cannot be imported.
    """
    env_id = str(env)
    env_kwargs = dict(gym_kwargs or {})
    env_kwargs.update(kwargs)
    try:
        return _make_with_render_mode_fallback(env_id, render_mode, env_kwargs)
    except ImportError as exc:
        if not is_moved_mujoco_error(exc):
            raise
        # Registering inside the handler keeps the original moved-MuJoCo error
        # attached as __context__ if the retry fails as well.
        register_gymnasium_robotics_envs()
        return _make_with_render_mode_fallback(env_id, render_mode, env_kwargs)


def _make_with_render_mode_fallback(
    env_id: str, render_mode: str, env_kwargs: dict[str, Any]
):
    """Call ``gymnasium.make`` with ``render_mode``, retrying without it on TypeError."""
    try:
        return gym.make(env_id, render_mode=render_mode, **env_kwargs)
    except TypeError:
        # Presumably for environments whose constructor rejects render_mode.
        # This also covers a "render_mode" key inside env_kwargs, which makes
        # the first call fail with a duplicated-keyword TypeError.
        return gym.make(env_id, **env_kwargs)
[docs]
def make_robotics_env(
    env: str,
    *,
    seed: int = 0,
    size: tuple[int, int] = (64, 64),
    render_mode: str = "rgb_array",
    gym_kwargs: dict[str, Any] | None = None,
    **kwargs,
):
    """Build a TorchWM ``GymImageEnv`` around a Gymnasium Robotics environment.

    Gymnasium Robotics is imported first, which registers its ids, so ``env``
    may be any id the package provides.

    Args:
        env: Any environment id registered by ``gymnasium-robotics``.
        seed: Seed forwarded to ``GymImageEnv``.
        size: Target ``(height, width)`` image size.
        render_mode: Render mode forwarded to ``gymnasium.make`` and
            ``GymImageEnv``.
        gym_kwargs: Optional keyword arguments forwarded to ``gymnasium.make``.
        **kwargs: Additional keyword arguments forwarded to ``gymnasium.make``.

    Returns:
        A ``GymImageEnv`` that emits ``{"image": uint8[C, H, W]}`` observations.

    Raises:
        ImportError: If Gymnasium Robotics is not installed.
    """
    register_gymnasium_robotics_envs()
    gymnasium_env = make_gymnasium_env_with_robotics_fallback(
        env, render_mode=render_mode, gym_kwargs=gym_kwargs, **kwargs
    )
    return GymImageEnv(gymnasium_env, seed=seed, size=size, render_mode=render_mode)