# Source code for world_models.envs.dmlab

from __future__ import annotations

import importlib
from collections.abc import Sequence
from typing import Any

import gymnasium as gym
import numpy as np
from PIL import Image

# Name of the DeepMind Lab observation that carries the interleaved RGB frame.
RGB_OBSERVATION = "RGB_INTERLEAVED"

# Small discrete action table for DMLab agents. Every row is one native
# DeepMind Lab action vector with seven integer components, in this order:
# (look_lr, look_ud, strafe_lr, move_fb, fire, jump, crouch).
DEFAULT_ACTION_SET = np.array(
    (
        (0, 0, 0, 0, 0, 0, 0),  # do nothing
        (20, 0, 0, 0, 0, 0, 0),  # turn right
        (-20, 0, 0, 0, 0, 0, 0),  # turn left
        (0, 0, 0, 1, 0, 0, 0),  # walk forward
        (0, 0, 0, -1, 0, 0, 0),  # walk backward
        (0, 0, -1, 0, 0, 0, 0),  # sidestep left
        (0, 0, 1, 0, 0, 0, 0),  # sidestep right
        (20, 0, 0, 1, 0, 0, 0),  # walk forward while turning right
        (-20, 0, 0, 1, 0, 0, 0),  # walk forward while turning left
    ),
    dtype=np.intc,
)


# Catalogue of DeepMind Lab level names, grouped by task family. Nothing in
# the visible part of this module validates a requested level against this
# list; it is exposed as a plain list of strings.
# NOTE(review): DMLab-30 levels are usually addressed with a
# ``contributed/dmlab30/`` prefix -- confirm whether callers add it before
# passing a name to ``deepmind_lab.Lab``.
DMLAB_LEVELS = [
    # "rooms_*" levels
    "rooms_collect_good_objects_train",
    "rooms_collect_good_objects_test",
    "rooms_exploit_deferred_effects_train",
    "rooms_exploit_deferred_effects_test",
    "rooms_select_nonmatching_object",
    "rooms_watermaze",
    "rooms_keys_doors_puzzle",
    # "language_*" levels
    "language_select_described_object",
    "language_select_located_object",
    "language_execute_random_task",
    # "nav_maze_*" levels
    "nav_maze_static_01",
    "nav_maze_static_02",
    "nav_maze_random_goal_01",
    "nav_maze_random_goal_02",
    # "lt_*" levels
    "lt_chasm",
]


def make_dmlab_env(level: str, **kwargs: Any) -> "DMLabEnv":
    """Build a :class:`DMLabEnv` for the given DeepMind Lab level.

    This is a thin factory: ``level`` and every extra keyword argument are
    forwarded unchanged to the :class:`DMLabEnv` constructor.

    Args:
        level: DeepMind Lab level name, for example
            ``"rooms_collect_good_objects_train"``.
        **kwargs: Extra keyword arguments handed to :class:`DMLabEnv`.

    Returns:
        DMLabEnv: A Gym-like wrapper whose observations are
        ``{"image": (C, H, W)}`` uint8 arrays and whose actions are
        normalized one-hot vectors.
    """
    env = DMLabEnv(level=level, **kwargs)
    return env
class _OneHotActionSpace(gym.spaces.Box):
    """Box action space whose samples are normalized one-hot vectors.

    The space is a float32 Box over ``[-1, 1]`` with one entry per discrete
    action. A sample has ``1.0`` at the chosen action and ``-1.0`` everywhere
    else.
    """

    def __init__(self, actions: int):
        """Create the space.

        Args:
            actions: Number of discrete actions (length of the one-hot vector).
        """
        self._actions = int(actions)
        super().__init__(
            low=-1.0,
            high=1.0,
            shape=(self._actions,),
            dtype=np.float32,
        )

    def sample(self, mask: None = None, probability: None = None) -> np.ndarray:
        """Draw a uniformly random one-hot action.

        The index is drawn from the space's own ``np_random`` generator, so
        ``space.seed(...)`` makes sampling reproducible (the global
        ``np.random`` state is not consulted).

        Args:
            mask: Unsupported; must be ``None``.
            probability: Unsupported; must be ``None``.

        Returns:
            A float32 vector of ``-1.0`` with a single ``1.0`` entry.

        Raises:
            ValueError: If ``mask`` or ``probability`` is provided.
        """
        if mask is not None or probability is not None:
            raise ValueError("DMLab action sampling does not support masks.")
        action: np.ndarray = np.full((self._actions,), -1.0, dtype=np.float32)
        # Use the per-space generator rather than the process-wide RNG so
        # that seeding the space controls the sampled actions.
        chosen = int(self.np_random.integers(self._actions))
        action[chosen] = 1.0
        return action
class DMLabEnv:
    """Gym-style adapter for DeepMind Lab 3D environments.

    The native ``deepmind_lab`` API exposes RGB observations as HWC arrays and
    expects a seven-element integer action vector. This adapter presents a
    TorchWM-friendly image observation dict (``"image"`` in CHW uint8 layout)
    and a Box action space containing a one-hot vector in ``[-1, 1]`` so it
    composes with Dreamer's normalization wrappers.
    """

    def __init__(
        self,
        level: str,
        seed: int = 0,
        size: tuple[int, int] = (64, 64),
        action_repeat: int = 4,
        action_set: Sequence[Sequence[int]] | np.ndarray | None = None,
        observations: Sequence[str] | None = None,
        config: dict[str, Any] | None = None,
        renderer: str = "hardware",
        **lab_kwargs: Any,
    ):
        """Create the underlying ``deepmind_lab.Lab`` instance.

        Args:
            level: DeepMind Lab level name.
            seed: Base seed; episode ``k`` is reset with ``seed + k``.
            size: ``(height, width)`` of the returned image.
            action_repeat: Number of native steps per :meth:`step` call.
            action_set: 2D table of native action vectors, one row per
                discrete action. Defaults to ``DEFAULT_ACTION_SET``.
            observations: Native observation names to request. The RGB
                observation is always requested.
            config: Extra DeepMind Lab config entries; keys and values are
                converted to strings.
            renderer: Renderer name forwarded to ``deepmind_lab.Lab``.
            **lab_kwargs: Extra keyword arguments for ``deepmind_lab.Lab``.

        Raises:
            ValueError: If ``action_repeat`` is below 1 or ``action_set`` is
                not a non-empty 2D array.
            ImportError: If the ``deepmind_lab`` module is unavailable.
        """
        self._level = str(level)
        self._seed = int(seed)
        self._episode = 0
        self._size = (int(size[0]), int(size[1]))
        self._action_repeat = int(action_repeat)
        if self._action_repeat < 1:
            raise ValueError("action_repeat must be >= 1.")

        self._action_set = np.asarray(
            DEFAULT_ACTION_SET if action_set is None else action_set, dtype=np.intc
        )
        if self._action_set.ndim != 2:
            raise ValueError("action_set must be a 2D array of DMLab action vectors.")
        if self._action_set.shape[0] < 1:
            # An empty table would yield a zero-length action space that can
            # never be sampled or decoded.
            raise ValueError("action_set must contain at least one action.")

        self._observations = list(observations or [RGB_OBSERVATION])
        if RGB_OBSERVATION not in self._observations:
            self._observations.insert(0, RGB_OBSERVATION)

        # DeepMind Lab takes its config as string -> string.
        lab_config = {"width": str(self._size[1]), "height": str(self._size[0])}
        if config:
            lab_config.update({str(key): str(value) for key, value in config.items()})
        self._config = lab_config

        deepmind_lab = _require_deepmind_lab()
        self._env = deepmind_lab.Lab(
            self._level,
            self._observations,
            config=self._config,
            renderer=renderer,
            **lab_kwargs,
        )
        self._last_obs: dict[str, np.ndarray] | None = None
        self._action_space = _OneHotActionSpace(self._action_set.shape[0])
        self._observation_space = self._build_observation_space()

    def _build_observation_space(self) -> gym.spaces.Dict:
        """Describe the observations returned by :meth:`reset` / :meth:`step`.

        ``"image"`` is always present. Extra entries are added only for
        observations that were requested from the environment, because those
        are the only keys ``_read_obs`` can return. Querying the native spec
        is best-effort: if it fails, only ``"image"`` is advertised.
        """
        spaces = {
            "image": gym.spaces.Box(
                low=0,
                high=255,
                shape=(3, self._size[0], self._size[1]),
                dtype=np.uint8,
            )
        }

        specs: Any = []
        spec_fn = getattr(self._env, "observation_spec", None)
        if spec_fn is not None:
            try:
                specs = spec_fn()
            except Exception:
                # Deliberate best-effort: a backend without a usable spec
                # still gets a valid image-only observation space.
                specs = []
        if isinstance(specs, dict):
            specs = [
                {"name": key, **value} if isinstance(value, dict) else value
                for key, value in specs.items()
            ]

        requested = set(self._observations)
        for spec in specs or []:
            is_mapping = isinstance(spec, dict)
            name = spec.get("name") if is_mapping else getattr(spec, "name", None)
            # Skip the RGB frame (exposed as "image") and anything that was
            # not requested and therefore never appears in an observation.
            if not name or name == RGB_OBSERVATION or name not in requested:
                continue

            shape = spec.get("shape") if is_mapping else getattr(spec, "shape", ())
            dtype = spec.get("dtype") if is_mapping else getattr(spec, "dtype", None)
            if shape is None:
                shape = ()
            if dtype is None:
                dtype = np.float32
            np_dtype = np.dtype(dtype)
            if np_dtype.kind not in "biuf":
                # A Box can only describe numeric data; such observations are
                # still returned in the observation dict, just not advertised.
                continue

            if np.issubdtype(np_dtype, np.integer):
                info = np.iinfo(np_dtype)
                low, high = info.min, info.max
            else:
                low, high = -np.inf, np.inf
            spaces[name] = gym.spaces.Box(low, high, tuple(shape), dtype=np_dtype)
        return gym.spaces.Dict(spaces)

    @property
    def observation_space(self) -> gym.spaces.Dict:
        """Dict space describing the observation dictionary."""
        return self._observation_space

    @property
    def action_space(self) -> _OneHotActionSpace:
        """One-hot Box space with one entry per row of the action table."""
        return self._action_space

    @property
    def max_episode_steps(self) -> int:
        """Upper bound on agent steps per episode derived from the config.

        Uses the ``fps`` and ``episode_length_seconds`` config entries (both
        default to 60) divided by the action repeat, and is never below 1.
        """
        # Config values are stored as strings and may be written as floats
        # (e.g. "60.0"), so parse through float before truncating.
        fps = int(float(self._config.get("fps", 60)))
        episode_length = int(float(self._config.get("episode_length_seconds", 60)))
        return max(1, (fps * episode_length) // self._action_repeat)

    def reset(self) -> dict[str, np.ndarray]:
        """Start a new episode and return its first observation.

        Each call uses a fresh seed: the base seed plus the number of episodes
        started so far.
        """
        seed = self._seed + self._episode
        self._episode += 1
        self._env.reset(seed=seed)
        self._last_obs = self._read_obs()
        return self._last_obs

    def step(
        self, action: np.ndarray
    ) -> tuple[dict[str, np.ndarray], float, bool, dict[str, Any]]:
        """Apply ``action`` for ``action_repeat`` native steps.

        Args:
            action: One-hot (or score) vector over the action table, a scalar
                integer index into it, or a native integer action vector.

        Returns:
            ``(obs, reward, done, info)``. When the episode has ended the last
            valid observation is returned again (or a blank image if there is
            none). ``info`` carries ``"discount"``, the one-hot ``"action"``
            and the native ``"dmlab_action"``.
        """
        native_action = self._to_native_action(action)
        reward = float(self._env.step(native_action, num_steps=self._action_repeat))
        done = not bool(self._env.is_running())
        if done:
            # Observations are not read once the environment stopped running;
            # repeat the last frame instead.
            obs = self._last_obs if self._last_obs is not None else self._empty_obs()
        else:
            obs = self._read_obs()
            self._last_obs = obs
        info = {
            "discount": np.array(0.0 if done else 1.0, dtype=np.float32),
            "action": self._action_to_one_hot(native_action),
            "dmlab_action": native_action.copy(),
        }
        return obs, reward, done, info

    def render(self, *args: Any, **kwargs: Any) -> np.ndarray:
        """Return the most recent frame as an HWC uint8 array.

        The render mode may be given positionally or as ``mode=...``; only
        ``"rgb_array"`` is accepted. Before the first observation a black
        frame is returned.

        Raises:
            ValueError: If a mode other than ``"rgb_array"`` is requested.
        """
        # Honour a positional mode too, so ``render("human")`` is rejected
        # instead of being silently treated as "rgb_array".
        mode = args[0] if args else kwargs.get("mode", "rgb_array")
        if mode != "rgb_array":
            raise ValueError("Only render mode 'rgb_array' is supported.")
        if self._last_obs is None:
            return np.zeros((self._size[0], self._size[1], 3), dtype=np.uint8)
        return self._last_obs["image"].transpose(1, 2, 0).copy()

    def close(self) -> None:
        """Close the underlying environment if it offers a ``close`` method."""
        close = getattr(self._env, "close", None)
        if close is not None:
            close()

    def _to_native_action(self, action: np.ndarray) -> np.ndarray:
        """Convert an agent action into a native integer action vector.

        Accepted forms, checked in this order:

        1. An integer vector whose length equals the native action length is
           passed through unchanged (as a copy).
        2. A scalar integer is used as an index into the action table.
        3. Anything else is flattened and must have one entry per table row;
           the arg-max entry selects the row.

        Raises:
            ValueError: If the index is out of range or the vector length does
                not match the number of actions.
        """
        arr = np.asarray(action)
        num_actions, native_len = self._action_set.shape
        is_integer = np.issubdtype(arr.dtype, np.integer)

        if is_integer and arr.shape == (native_len,):
            return arr.astype(np.intc, copy=True)

        if is_integer and arr.ndim == 0:
            index = int(arr)
            if not 0 <= index < num_actions:
                raise ValueError(
                    f"Action index {index} is out of range for {num_actions} actions."
                )
        else:
            flat = arr.reshape(-1)
            if flat.size != num_actions:
                raise ValueError(
                    f"Expected an action vector with {num_actions} entries, "
                    f"got shape {arr.shape}."
                )
            index = int(np.argmax(flat))
        return self._action_set[index].astype(np.intc, copy=True)

    def _action_to_one_hot(self, native_action: np.ndarray) -> np.ndarray:
        """Encode a native action as a ``[-1, 1]`` one-hot vector.

        The first matching row of the action table is marked; if no row
        matches, index 0 is marked.
        """
        matches = np.all(self._action_set == native_action, axis=1)
        index = int(np.argmax(matches)) if matches.any() else 0
        action: np.ndarray = -np.ones((self._action_set.shape[0],), dtype=np.float32)
        action[index] = 1.0
        return action

    def _read_obs(self) -> dict[str, np.ndarray]:
        """Read native observations and convert the RGB frame to ``"image"``."""
        raw = self._env.observations()
        obs = {"image": self._to_chw_uint8(raw[RGB_OBSERVATION])}
        for key, value in raw.items():
            if key != RGB_OBSERVATION:
                obs[key] = np.asarray(value)
        return obs

    def _empty_obs(self) -> dict[str, np.ndarray]:
        """Return an observation containing only a black image."""
        return {"image": np.zeros((3, self._size[0], self._size[1]), dtype=np.uint8)}

    def _to_chw_uint8(self, image: np.ndarray) -> np.ndarray:
        """Normalize an RGB frame to a ``(3, H, W)`` uint8 array.

        Channel-first input is moved to channel-last, single-channel input is
        repeated to three channels, a fourth channel is dropped, non-uint8
        data is clipped to ``[0, 255]``, and the frame is resized bilinearly
        when its size differs from the configured one.

        Raises:
            ValueError: If ``image`` is not three-dimensional.
        """
        image = np.asarray(image)
        if image.ndim != 3:
            raise ValueError(
                f"Expected DMLab RGB image with 3 dims, got {image.shape}."
            )
        # Treat the input as CHW only when the first axis looks like channels
        # and the last one does not.
        if image.shape[0] in (1, 3, 4) and image.shape[-1] not in (1, 3, 4):
            image = image.transpose(1, 2, 0)
        if image.shape[-1] == 1:
            image = np.repeat(image, 3, axis=-1)
        elif image.shape[-1] == 4:
            image = image[..., :3]
        if image.dtype != np.uint8:
            image = image.astype(np.float32).clip(0, 255).astype(np.uint8)
        if image.shape[:2] != self._size:
            # PIL takes (width, height), the reverse of the stored size.
            image = np.array(
                Image.fromarray(image).resize(
                    (self._size[1], self._size[0]), Image.BILINEAR
                )
            )
        return image.transpose(2, 0, 1).copy()
def _require_deepmind_lab() -> Any:
    """Import and return the ``deepmind_lab`` module.

    The import happens lazily so this module can be imported on machines
    without DeepMind Lab installed.

    Raises:
        ImportError: If ``deepmind_lab`` cannot be imported; the original
            import error is chained as the cause.
    """
    try:
        module = importlib.import_module("deepmind_lab")
    except ImportError as exc:
        message = (
            "DeepMind Lab support requires the `deepmind_lab` Python module. "
            "Install DeepMind Lab manually or build it with `pip install dmlab-gym` "
            "and `dmlab-gym build`."
        )
        raise ImportError(message) from exc
    return module