from __future__ import annotations
import importlib
import importlib.util
import sys
from collections.abc import Callable
from pathlib import Path
from typing import Any
import gymnasium as gym
import numpy as np
from PIL import Image
from world_models.envs.gym_env import GymImageEnv
from world_models.envs.robotics_env import (
make_gymnasium_env_with_robotics_fallback,
register_gymnasium_robotics_envs,
)
RewardFn = Callable[[Any, Any, np.ndarray, dict[str, Any]], float]
TerminalFn = Callable[[Any, Any, dict[str, Any]], bool]
def _load_mujoco():
if "mujoco" in sys.modules:
return sys.modules["mujoco"]
if importlib.util.find_spec("mujoco") is None:
raise ImportError(
"The native MuJoCo bindings are required for MuJoCoImageEnv. "
"Install them with `pip install mujoco` or `pip install torchwm[mujoco]`."
)
return importlib.import_module("mujoco")
def _validate_model_source(
*,
xml_path: str | Path | None,
xml_string: str | None,
binary_path: str | Path | None,
) -> None:
sources = [xml_path is not None, xml_string is not None, binary_path is not None]
if sum(sources) != 1:
raise ValueError(
"Provide exactly one of xml_path, xml_string, or binary_path for MuJoCoImageEnv."
)
def _is_native_model_source(model: str | Path) -> bool:
model_text = str(model)
return (
model_text.lstrip().startswith("<")
or model_text.endswith((".xml", ".mjb"))
or Path(model_text).exists()
)
def _infer_model_source(model: str | Path) -> dict[str, str | Path]:
model_text = str(model)
if model_text.lstrip().startswith("<"):
return {"xml_string": model_text}
if model_text.endswith(".mjb"):
return {"binary_path": model}
return {"xml_path": model}
[docs]
def make_mujoco_env_from_config(args, size: tuple[int, int]):
"""Build a MuJoCo image environment from a DreamerConfig-like object."""
native_kwargs = {
"seed": args.seed,
"size": size,
"camera": getattr(args, "mujoco_camera", None),
"frame_skip": int(getattr(args, "mujoco_frame_skip", 1)),
"reset_noise_scale": float(getattr(args, "mujoco_reset_noise_scale", 0.0)),
}
if getattr(args, "mujoco_xml_string", None) is not None:
return make_mujoco_env(xml_string=args.mujoco_xml_string, **native_kwargs)
if getattr(args, "mujoco_binary_path", None) is not None:
return make_mujoco_env(binary_path=args.mujoco_binary_path, **native_kwargs)
if getattr(args, "mujoco_xml_path", None) is not None:
return make_mujoco_env(xml_path=args.mujoco_xml_path, **native_kwargs)
gym_kwargs = {}
reset_noise_scale = native_kwargs["reset_noise_scale"]
if reset_noise_scale != 0.0:
gym_kwargs["reset_noise_scale"] = reset_noise_scale
return make_mujoco_env(
args.env,
seed=args.seed,
size=size,
gym_kwargs=gym_kwargs,
)
[docs]
class MuJoCoImageEnv:
"""Native MuJoCo environment adapter for pixel-based world-model training.
The adapter uses the low-level ``mujoco`` Python package directly: models are
compiled from MJCF XML strings/files or MJB binaries via ``mujoco.MjModel``;
simulation state lives in ``mujoco.MjData``; actions are written to
``data.ctrl``; and images are produced with ``mujoco.Renderer``. Observations
follow TorchWM's Dreamer-style contract: ``{"image": uint8[C, H, W]}``.
Native MuJoCo models do not define task rewards or episode termination by
themselves, so callers can supply ``reward_fn`` and ``terminal_fn`` callbacks.
By default, rewards are ``0.0`` and episodes terminate only through external
wrappers such as ``TimeLimit``.
"""
def __init__(
self,
xml_path: str | Path | None = None,
*,
xml_string: str | None = None,
binary_path: str | Path | None = None,
assets: dict[str, bytes] | None = None,
seed: int = 0,
size: tuple[int, int] = (64, 64),
camera: str | int | None = None,
reward_fn: RewardFn | None = None,
terminal_fn: TerminalFn | None = None,
frame_skip: int = 1,
reset_noise_scale: float = 0.0,
default_control_range: tuple[float, float] = (-1.0, 1.0),
):
_validate_model_source(
xml_path=xml_path,
xml_string=xml_string,
binary_path=binary_path,
)
self._mujoco = _load_mujoco()
self._size = (int(size[0]), int(size[1]))
self._camera = camera
self._reward_fn = reward_fn
self._terminal_fn = terminal_fn
self._frame_skip = max(1, int(frame_skip))
self._reset_noise_scale = float(reset_noise_scale)
self._rng = np.random.default_rng(seed)
self._closed = False
if xml_string is not None:
self.model = self._mujoco.MjModel.from_xml_string(xml_string, assets)
elif binary_path is not None:
self.model = self._mujoco.MjModel.from_binary_path(str(binary_path))
else:
self.model = self._mujoco.MjModel.from_xml_path(str(xml_path))
self.data = self._mujoco.MjData(self.model)
height, width = self._size
self._renderer = self._mujoco.Renderer(self.model, height=height, width=width)
self._observation_space = gym.spaces.Dict(
{
"image": gym.spaces.Box(
low=0,
high=255,
shape=(3, height, width),
dtype=np.uint8,
)
}
)
self._action_space = self._build_action_space(default_control_range)
self.max_episode_steps = 1000
def _build_action_space(
self, default_control_range: tuple[float, float]
) -> gym.spaces.Box:
action_dim = int(getattr(self.model, "nu", 0))
if action_dim <= 0:
return gym.spaces.Box(
low=np.zeros((0,), dtype=np.float32),
high=np.zeros((0,), dtype=np.float32),
dtype=np.float32,
)
ctrlrange = np.asarray(
getattr(self.model, "actuator_ctrlrange", []), dtype=np.float32
)
limited = np.asarray(
getattr(self.model, "actuator_ctrllimited", []), dtype=bool
)
if ctrlrange.shape == (action_dim, 2) and limited.shape[0] == action_dim:
default_low, default_high = default_control_range
low = np.where(limited, ctrlrange[:, 0], float(default_low)).astype(
np.float32
)
high = np.where(limited, ctrlrange[:, 1], float(default_high)).astype(
np.float32
)
else:
low = np.full(
(action_dim,), float(default_control_range[0]), dtype=np.float32
)
high = np.full(
(action_dim,), float(default_control_range[1]), dtype=np.float32
)
return gym.spaces.Box(low=low, high=high, dtype=np.float32)
@property
def observation_space(self):
return self._observation_space
@property
def action_space(self):
return self._action_space
def _render_chw(self) -> np.ndarray:
if self._camera is None:
self._renderer.update_scene(self.data)
else:
self._renderer.update_scene(self.data, camera=self._camera)
image = np.asarray(self._renderer.render())
if image.shape[:2] != self._size:
image = np.asarray(
Image.fromarray(image).resize(
(self._size[1], self._size[0]), Image.Resampling.BILINEAR
)
)
if image.ndim == 2:
image = np.repeat(image[..., None], 3, axis=-1)
if image.shape[-1] == 4:
image = image[..., :3]
return image.astype(np.uint8, copy=False).transpose(2, 0, 1).copy()
[docs]
def reset(self, seed: int | None = None):
if seed is not None:
self._rng = np.random.default_rng(seed)
self._mujoco.mj_resetData(self.model, self.data)
if self._reset_noise_scale > 0.0:
if getattr(self.model, "nq", 0):
self.data.qpos[:] += self._rng.normal(
0.0, self._reset_noise_scale, size=self.data.qpos.shape
)
if getattr(self.model, "nv", 0):
self.data.qvel[:] += self._rng.normal(
0.0, self._reset_noise_scale, size=self.data.qvel.shape
)
self._mujoco.mj_forward(self.model, self.data)
return {"image": self._render_chw()}
[docs]
def step(self, action):
action_arr = np.asarray(action, dtype=np.float32).reshape(
self.action_space.shape
)
clipped = np.clip(action_arr, self.action_space.low, self.action_space.high)
if clipped.size:
self.data.ctrl[:] = clipped
self._mujoco.mj_step(self.model, self.data, nstep=self._frame_skip)
info = {
"action": clipped.astype(np.float32, copy=True),
"time": float(getattr(self.data, "time", 0.0)),
"qpos": np.asarray(getattr(self.data, "qpos", []), dtype=np.float64).copy(),
"qvel": np.asarray(getattr(self.data, "qvel", []), dtype=np.float64).copy(),
}
if self._reward_fn is None:
reward = 0.0
else:
reward = float(self._reward_fn(self.model, self.data, clipped, info))
done = (
bool(self._terminal_fn(self.model, self.data, info))
if self._terminal_fn
else False
)
return {"image": self._render_chw()}, reward, done, info
[docs]
def render(self):
return self._render_chw().transpose(1, 2, 0).copy()
[docs]
def close(self):
if self._closed:
return
self._closed = True
close = getattr(self._renderer, "close", None)
if callable(close):
close()
[docs]
def make_mujoco_env(
model: str | Path | None = None,
*,
backend: str = "auto",
seed: int = 0,
size: tuple[int, int] = (64, 64),
render_mode: str = "rgb_array",
gym_kwargs: dict[str, Any] | None = None,
**kwargs,
):
"""Create one MuJoCo image environment factory for tasks and MJCF/MJB models.
Args:
model: Either a Gymnasium MuJoCo task id such as ``"Humanoid-v4"``,
an MJCF XML path/string, or an MJB binary path.
backend: ``"auto"`` infers native vs Gymnasium task mode. Use
``"native"`` for MJCF/MJB, ``"gymnasium"`` for task ids, or
``"robotics"`` for Gymnasium Robotics registrations.
seed: Seed forwarded to the image wrapper.
size: Target ``(height, width)`` image size.
render_mode: Render mode used for Gymnasium MuJoCo task ids.
gym_kwargs: Optional keyword arguments forwarded to ``gymnasium.make``
in task-id mode. Extra ``**kwargs`` are also forwarded there.
**kwargs: Native ``MuJoCoImageEnv`` options for MJCF/MJB mode, or
environment-constructor options for Gymnasium task-id mode.
Returns:
A TorchWM image environment returning ``{"image": uint8[C, H, W]}``.
"""
backend = backend.lower()
explicit_native_source = any(
key in kwargs for key in ("xml_path", "xml_string", "binary_path")
)
use_native = backend in {"native", "mjcf", "mjb"} or (
backend == "auto"
and (
explicit_native_source
or (model is not None and _is_native_model_source(model))
)
)
if use_native:
if model is not None:
kwargs.update(
{k: v for k, v in _infer_model_source(model).items() if k not in kwargs}
)
return MuJoCoImageEnv(seed=seed, size=size, **kwargs)
if backend not in {
"auto",
"gym",
"gymnasium",
"task",
"robotics",
"gymnasium_robotics",
}:
raise ValueError(
f"Unknown MuJoCo backend={backend!r}. Use 'auto', 'native', 'gymnasium', or 'robotics'."
)
if model is None:
raise ValueError(
"A Gymnasium MuJoCo environment id, XML path/string, or MJB path is required."
)
env_kwargs = dict(gym_kwargs or {})
env_kwargs.update(kwargs)
if backend in {"robotics", "gymnasium_robotics"}:
register_gymnasium_robotics_envs()
env = make_gymnasium_env_with_robotics_fallback(
str(model),
render_mode=render_mode,
gym_kwargs=env_kwargs,
)
return GymImageEnv(env, seed=seed, size=size, render_mode=render_mode)