Source code for world_models.envs.brax_env

from __future__ import annotations

import importlib
import importlib.util
from typing import Any

import gymnasium as gym
import numpy as np
from PIL import Image


def make_brax_env(env, **kwargs):
    """Build the TorchWM image wrapper for a Brax environment.

    Args:
        env: Either the name of a Brax environment (for example, ``"ant"``)
            or an already constructed Brax environment object that provides
            ``reset(rng)`` and ``step(state, action)``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            :class:`BraxImageEnv`.

    Returns:
        BraxImageEnv: A Gym-like wrapper whose observations have the form
        ``{"image": (C, H, W)}`` and whose continuous actions live in the
        Brax ``[-1, 1]`` range.
    """
    wrapped_env = BraxImageEnv(env=env, **kwargs)
    return wrapped_env
class BraxImageEnv:
    """Gym-like adapter for training TorchWM world models on Brax tasks.

    Brax environments are functional JAX environments: ``reset`` consumes a
    PRNG key and returns a state, while ``step`` consumes the previous state
    plus an action and returns the next state. This adapter stores the Brax
    state between calls and converts state observations into image
    observations compatible with pixel-based TorchWM agents such as Dreamer.

    If a Brax renderer is not available, vector observations are rendered as
    deterministic feature-band images so training code can still consume a
    pixel stream. The original vector observation is also exposed through
    ``info["vector_observation"]`` after ``step`` for diagnostics.
    """

    def __init__(
        self,
        env,
        seed: int = 0,
        size: tuple[int, int] = (64, 64),
        backend: str | None = None,
        episode_length: int | None = None,
        auto_reset: bool = False,
        jit: bool = True,
        suppress_warp_warnings: bool = True,
        **env_kwargs,
    ):
        """Import the JAX/Brax stack, build the environment and its spaces.

        Args:
            env: Brax environment name, or a pre-built environment object
                (any non-string value is used as-is).
            seed: Seed for the JAX PRNG key that drives ``reset``.
            size: Output image size as ``(height, width)``.
            backend: Optional Brax backend name, forwarded as the
                ``backend`` keyword when ``env`` is a name.
            episode_length: When given, the environment is built through
                ``brax.envs.create`` with this episode length; otherwise
                ``brax.envs.get_environment`` is used.
            auto_reset: Forwarded to ``brax.envs.create``; it has no effect
                when ``episode_length`` is ``None``.
            jit: Whether to wrap ``reset``/``step`` in ``jax.jit``.
            suppress_warp_warnings: Filter the known Warp import messages
                while importing ``brax.envs``.
            **env_kwargs: Extra keyword arguments for environment creation.

        Raises:
            ImportError: If ``jax`` or ``brax`` is not installed.
            ValueError: If the environment exposes no ``action_size``.
        """
        self._size = (int(size[0]), int(size[1]))
        self._seed = int(seed)
        self._jit = bool(jit)
        self._state = None

        install_hint = "Install Brax support with `pip install torchwm[brax]`."
        self._jax = _require_module("jax", install_hint)
        self._jnp = _require_module("jax.numpy", install_hint)
        self._suppress_warp_warnings = bool(suppress_warp_warnings)
        self._brax_envs = _require_module(
            "brax.envs",
            install_hint,
            suppress_warp_warnings=self._suppress_warp_warnings,
        )
        self._env = self._make_env(
            env,
            backend=backend,
            episode_length=episode_length,
            auto_reset=auto_reset,
            env_kwargs=env_kwargs,
        )
        self._rng = self._jax.random.PRNGKey(self._seed)
        self._reset_fn = (
            self._jax.jit(self._env.reset) if self._jit else self._env.reset
        )
        self._step_fn = self._jax.jit(self._env.step) if self._jit else self._env.step

        # A single lookup is enough: the previous fallback repeated the very
        # same getattr call and could never produce a different answer.
        action_size = getattr(self._env, "action_size", None)
        if action_size is None:
            raise ValueError("Brax environment must define an action_size attribute.")
        self._action_size = int(action_size)

        self._action_space = gym.spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(self._action_size,),
            dtype=np.float32,
        )
        self._observation_space = gym.spaces.Dict(
            {
                "image": gym.spaces.Box(
                    low=0,
                    high=255,
                    shape=(3, self._size[0], self._size[1]),
                    dtype=np.uint8,
                )
            }
        )

    def _make_env(
        self,
        env,
        backend: str | None,
        episode_length: int | None,
        auto_reset: bool,
        env_kwargs: dict[str, Any],
    ):
        """Return ``env`` unchanged, or build it by name via ``brax.envs``."""
        if not isinstance(env, str):
            return env

        kwargs = dict(env_kwargs)
        if backend is not None:
            kwargs.setdefault("backend", backend)
        if episode_length is None:
            # No wrappers requested: fetch the bare environment by name.
            return self._brax_envs.get_environment(env, **kwargs)
        return self._brax_envs.create(
            env_name=env,
            episode_length=int(episode_length),
            action_repeat=1,
            auto_reset=bool(auto_reset),
            **kwargs,
        )

    @property
    def observation_space(self):
        """Dict space with a single uint8 ``"image"`` entry of shape (3, H, W)."""
        return self._observation_space

    @property
    def action_space(self):
        """Float32 box in ``[-1, 1]`` with one entry per action dimension."""
        return self._action_space

    @property
    def max_episode_steps(self):
        """Episode length reported by the wrapped env, or 1000 if unknown."""
        for name in ("episode_length", "eps_length", "_episode_length"):
            value = getattr(self._env, name, None)
            if value is not None:
                return int(value)
        return 1000

    def _split_key(self):
        """Advance the stored PRNG key and return a fresh subkey."""
        self._rng, key = self._jax.random.split(self._rng)
        return key

    def _to_numpy(self, value):
        """Fetch ``value`` through ``jax.device_get`` as a NumPy array."""
        return np.asarray(self._jax.device_get(value))

    def _vector_to_image(self, vector):
        """Render up to 16 leading features as vertical grey bands (H, W, 3).

        Non-finite entries are treated as 0 and the vector is min-max
        normalised over all of its entries before the bands are drawn.
        """
        vec = np.asarray(vector, dtype=np.float32).reshape(-1)
        if vec.size == 0:
            return np.zeros((self._size[0], self._size[1], 3), dtype=np.uint8)

        finite = np.isfinite(vec)
        if not finite.all():
            vec = np.where(finite, vec, 0.0)

        vmin = float(vec.min())
        vmax = float(vec.max())
        if vmax > vmin:
            vec = (vec - vmin) / (vmax - vmin)
        else:
            vec = np.zeros_like(vec)

        image = np.zeros((self._size[0], self._size[1], 3), dtype=np.uint8)
        bands = min(16, vec.size)
        band_w = max(1, self._size[1] // max(1, bands))
        for i in range(bands):
            start = i * band_w
            # The last band absorbs any leftover columns.
            end = (
                self._size[1] if i == bands - 1 else min(self._size[1], start + band_w)
            )
            image[:, start:end, :] = int(255.0 * float(vec[i]))
        return image

    @staticmethod
    def _float_image_to_uint8(image):
        """Convert a non-uint8 image array to uint8.

        Non-finite values are replaced by 0 first, matching
        ``_vector_to_image``; otherwise a single NaN would defeat the range
        test below and reach the uint8 cast, whose result is undefined.
        Arrays whose values all lie in ``[0, 1]`` are scaled by 255, all
        others are clipped to ``[0, 255]``.
        """
        pixels = np.nan_to_num(
            np.asarray(image).astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0
        )
        if pixels.size > 0 and pixels.max() <= 1.0 and pixels.min() >= 0.0:
            pixels = pixels * 255.0
        return pixels.clip(0, 255).astype(np.uint8)

    def _obs_to_hwc_image(self, obs):
        """Turn an observation of any rank into an (H, W, 3) uint8 image."""
        arr = self._to_numpy(obs)
        if arr.ndim == 1:
            image = self._vector_to_image(arr)
        elif arr.ndim == 2:
            # Greyscale plane: replicate into three channels.
            image = np.repeat(arr[..., None], 3, axis=-1)
        elif arr.ndim == 3:
            image = arr
            # Looks channel-first rather than channel-last: move channels last.
            if image.shape[-1] not in (1, 3, 4) and image.shape[0] in (1, 3, 4):
                image = image.transpose(1, 2, 0)
            if image.shape[-1] == 1:
                image = np.repeat(image, 3, axis=-1)
            elif image.shape[-1] == 4:
                # Keep the first three channels only.
                image = image[..., :3]
        else:
            image = self._vector_to_image(arr)

        if image.dtype != np.uint8:
            image = self._float_image_to_uint8(image)
        return image

    def _to_hwc_uint8_image(self, obs):
        """Convert ``obs`` to an (H, W, 3) uint8 image at the configured size."""
        image = self._obs_to_hwc_image(obs)
        if image.shape[0] != self._size[0] or image.shape[1] != self._size[1]:
            # PIL expects (width, height), the reverse of our (H, W) order.
            image = np.array(
                Image.fromarray(image).resize(
                    (self._size[1], self._size[0]), Image.BILINEAR
                )
            )
        return image

    def _to_chw_uint8_image(self, obs):
        """Convert ``obs`` to a contiguous (3, H, W) uint8 image."""
        return self._to_hwc_uint8_image(obs).transpose(2, 0, 1).copy()

    def _state_to_obs(self, state):
        """Wrap the state's observation as ``{"image": (3, H, W) uint8}``."""
        return {"image": self._to_chw_uint8_image(state.obs)}

    def _metrics_to_info(self, state):
        """Merge the state's ``metrics`` and ``info`` dicts into one dict.

        Values are converted to NumPy where possible; entries from ``info``
        overwrite same-named entries from ``metrics``.
        """
        info = {}
        for source_name in ("metrics", "info"):
            source = getattr(state, source_name, None)
            if isinstance(source, dict):
                for key, value in source.items():
                    try:
                        info[key] = self._to_numpy(value)
                    except TypeError:
                        # Not array-convertible: keep the raw value.
                        info[key] = value
        return info

    def reset(self):
        """Start a new episode and return its first observation dict."""
        self._state = self._reset_fn(self._split_key())
        return self._state_to_obs(self._state)

    def step(self, action):
        """Advance the environment by one step.

        Args:
            action: Array-like with ``action_size`` entries; values are
                clipped to ``[-1, 1]`` before being passed to Brax.

        Returns:
            tuple: ``(obs, reward, done, info)`` where ``info`` carries the
            state's metrics/info plus ``"discount"`` (unless the state
            already provides one), ``"action"`` and ``"vector_observation"``.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self._state is None:
            raise RuntimeError("Must call reset() before step().")

        clipped = np.clip(
            np.asarray(action, dtype=np.float32).reshape(self._action_size),
            -1.0,
            1.0,
        )
        brax_action = self._jnp.asarray(clipped)
        self._state = self._step_fn(self._state, brax_action)

        reward = float(np.asarray(self._to_numpy(self._state.reward)).reshape(()))
        done = bool(np.asarray(self._to_numpy(self._state.done)).reshape(()))

        info = self._metrics_to_info(self._state)
        info.setdefault("discount", np.array(0.0 if done else 1.0, dtype=np.float32))
        info["action"] = clipped.copy()
        info["vector_observation"] = self._to_numpy(self._state.obs).copy()
        return self._state_to_obs(self._state), reward, done, info

    def render(self, *args, **kwargs):
        """Return the current observation as an (H, W, 3) uint8 image.

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self._state is None:
            raise RuntimeError("No frame available. Call reset() before render().")
        # Build the HWC frame directly instead of transposing a CHW copy
        # back; copy so callers never alias the stored state.
        return self._to_hwc_uint8_image(self._state.obs).copy()

    def close(self):
        """Drop the stored Brax state."""
        self._state = None
def _require_module(
    module_name: str, install_hint: str, *, suppress_warp_warnings: bool = False
):
    """Import an optional dependency, or fail with an install hint.

    Args:
        module_name: Dotted name of the module to import.
        install_hint: Sentence appended to the error message when the module
            (or its top-level package) cannot be found.
        suppress_warp_warnings: When true and ``module_name`` starts with
            ``"brax"``, output printed during the import is captured and
            replayed without the known Warp messages.

    Returns:
        The imported module object.

    Raises:
        ImportError: If the top-level package or the module itself is not
            installed, or if the import itself fails.
    """
    parent_name = module_name.split(".", 1)[0]
    if importlib.util.find_spec(parent_name) is None:
        raise ImportError(
            f"Missing optional dependency `{parent_name}`. {install_hint}"
        )
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(
            f"Missing optional dependency `{module_name}`. {install_hint}"
        )

    if not (suppress_warp_warnings and module_name.startswith("brax")):
        return importlib.import_module(module_name)

    # Some optional backends (notably MuJoCo/MJX's Warp shim) print noisy
    # import-time messages like:
    #   Failed to import warp: No module named 'warp'
    #   Failed to import mujoco_warp: No module named 'mujoco_warp'
    # These messages are harmless when the optional backend is not present
    # but pollute console output during tests and normal runs, so those two
    # lines are filtered while any other import output is replayed.
    import io
    import sys
    from contextlib import redirect_stderr, redirect_stdout

    noise_prefixes = ("Failed to import warp:", "Failed to import mujoco_warp:")
    out_buf = io.StringIO()
    err_buf = io.StringIO()
    try:
        with redirect_stdout(out_buf), redirect_stderr(err_buf):
            return importlib.import_module(module_name)
    finally:
        # Replay in a ``finally`` so diagnostics are not lost when the import
        # raises, and send each line back to the stream it was written to.
        # The redirects are already undone here, so ``sys.stdout`` and
        # ``sys.stderr`` are the caller's streams again.
        for buf, stream in ((out_buf, sys.stdout), (err_buf, sys.stderr)):
            for line in buf.getvalue().splitlines():
                if not line.startswith(noise_prefixes):
                    print(line, file=stream)