# Source code for world_models.envs.procgen_env

"""Procgen environment adapter for TorchWM image-based agents."""

from __future__ import annotations

import importlib
import importlib.util
import sys
from typing import Any

import gymnasium as gym
import numpy as np
from numpy.typing import NDArray
from PIL import Image

# Import name of the optional third-party package that provides ``ProcgenEnv``.
_PROCGEN_PACKAGE = "procgen"

# The 16 Procgen game names accepted by this module, in alphabetical order.
PROCGEN_ENVS = """
    bigfish bossfight caveflyer chaser climber coinrun dodgeball fruitbot
    heist jumper leaper maze miner ninja plunder starpilot
""".split()


def _require_procgen_env_class() -> type[Any]:
    """Return ``procgen.ProcgenEnv``, raising a helpful error if unavailable.

    A module already registered in ``sys.modules`` under the procgen package
    name (for example a test double) is used as-is; otherwise the package is
    imported.

    Returns:
        The ``ProcgenEnv`` attribute of the procgen module.

    Raises:
        ImportError: If the procgen package is not installed, or its import is
            blocked by a ``None`` entry in ``sys.modules``. Import failures for
            *other* modules raised while importing procgen (e.g. a missing
            transitive dependency) propagate unchanged.
        AttributeError: If the module does not define ``ProcgenEnv``.
    """
    try:
        module = importlib.import_module(_PROCGEN_PACKAGE)
    except ModuleNotFoundError as err:
        # Only translate "procgen itself is missing/blocked"; a missing
        # dependency of procgen should surface under its own name.
        if err.name != _PROCGEN_PACKAGE:
            raise
        raise ImportError(
            "Procgen support requires the optional 'procgen' package. "
            "Install it with `pip install torchwm[procgen]` or "
            "`pip install procgen`."
        ) from err
    return module.ProcgenEnv


def _unbatch_procgen_info(info: Any) -> dict[str, Any]:
    """Collapse Procgen's batched ``info`` into a plain per-environment dict.

    * A list/tuple of per-env dicts yields a shallow copy of the first entry,
      or ``{}`` when the sequence is empty.
    * A dict has every value with a leading length-1 axis unwrapped: NumPy
      arrays whose first dimension is 1, and lists/tuples of length 1, are
      replaced by their first element; all other values are kept as-is.
    * Any other input yields ``{}``.
    """

    def _drop_batch_axis(value: Any) -> Any:
        if isinstance(value, np.ndarray):
            return value[0] if value.shape[:1] == (1,) else value
        if isinstance(value, (list, tuple)):
            return value[0] if len(value) == 1 else value
        return value

    if isinstance(info, (list, tuple)):
        if not info:
            return {}
        return dict(info[0])
    if isinstance(info, dict):
        return {key: _drop_batch_axis(value) for key, value in info.items()}
    return {}


def list_procgen_envs() -> list[str]:
    """Return a fresh copy of the Procgen game names this module accepts.

    Because a new list is returned, callers may mutate the result without
    affecting the module-level ``PROCGEN_ENVS``.
    """
    return [*PROCGEN_ENVS]
def normalize_procgen_env_name(env: str) -> str:
    """Map a Procgen Gym id or shorthand to a bare Procgen game name.

    Accepted spellings include ``"coinrun"``, ``"procgen-coinrun-v0"`` and
    ``"procgen:procgen-coinrun-v0"``. Surrounding whitespace is stripped,
    anything up to the first ``":"`` is discarded, and an optional
    ``"procgen-"`` prefix and ``"-v0"`` suffix are removed.

    Args:
        env: Name or id to normalize (converted with ``str`` first).

    Returns:
        The matching entry of ``PROCGEN_ENVS``.

    Raises:
        ValueError: If the normalized name is not a known Procgen game.
    """
    raw = str(env).strip()
    _, separator, after_colon = raw.partition(":")
    name = after_colon if separator else raw

    prefix, suffix = "procgen-", "-v0"
    if name.startswith(prefix):
        name = name[len(prefix):]
    if name.endswith(suffix):
        name = name[: -len(suffix)]

    if name in PROCGEN_ENVS:
        return name
    valid = ", ".join(PROCGEN_ENVS)
    raise ValueError(f"Unknown Procgen environment '{env}'. Valid names: {valid}.")
class _ProcgenActionSpace(gym.spaces.Box):
    """Continuous ``Box[-1, 1]^n`` stand-in for a discrete Procgen action set.

    Samples are one-hot-like: every entry is ``-1`` except a single ``+1`` at
    a uniformly random index.
    """

    def __init__(self, n: int):
        # Exposed as ``n`` like ``gym.spaces.Discrete`` (presumably so
        # discrete-aware callers can read the action count off the space).
        self.n = int(n)
        super().__init__(low=-1.0, high=1.0, shape=(self.n,), dtype=np.float32)

    def sample(self, mask: Any = None, probability: Any = None) -> NDArray[np.float32]:
        """Draw a random one-hot-like action from the space's own RNG.

        Uses ``self.np_random`` (rather than the global ``np.random``) so that
        ``space.seed(...)`` makes samples reproducible, as with other
        Gymnasium spaces. ``mask`` and ``probability`` are accepted only for
        signature compatibility and are ignored.
        """
        del mask, probability
        idx = int(self.np_random.integers(0, self.n))
        action: NDArray[np.float32] = np.full((self.n,), -1.0, dtype=np.float32)
        action[idx] = 1.0
        return action
def make_procgen_env(env: str, **kwargs: Any) -> "ProcgenImageEnv":
    """Build a :class:`ProcgenImageEnv` for a single Procgen game.

    Args:
        env: Procgen game name (``"coinrun"``) or Gym-style id
            (``"procgen-coinrun-v0"``, ``"procgen:procgen-coinrun-v0"``).
        **kwargs: Keyword options passed straight through to
            :class:`ProcgenImageEnv` (e.g. ``seed``, ``size``).

    Returns:
        ProcgenImageEnv: Adapter whose observations are
        ``{"image": (3, H, W) uint8}`` and whose actions are one-hot-like
        float vectors.
    """
    adapter_kwargs: dict[str, Any] = {"env": env, **kwargs}
    return ProcgenImageEnv(**adapter_kwargs)
class ProcgenImageEnv:
    """Single-environment, image-only view of Procgen's vectorized API.

    ``procgen.ProcgenEnv`` always builds a batch of environments. This adapter
    requests a batch of one (``num_envs=1``) and strips the leading batch axis
    from observations, rewards, dones and info, so callers interact with it
    like a single environment.

    Observations are ``{"image": uint8 array of shape (3, H, W)}`` where
    ``(H, W)`` is ``size``. Actions are one-hot-like float vectors in
    ``[-1, 1]`` with one entry per discrete Procgen action, matching TorchWM's
    other discrete image adapters.

    Args:
        env: Procgen game name or Gym-style id (see
            :func:`normalize_procgen_env_name`).
        seed: Stored seed; also used as ``start_level`` when that is ``None``.
        size: Output image ``(height, width)``; frames of another size are
            bilinearly resized.
        distribution_mode: Forwarded to ``ProcgenEnv``.
        num_levels: Forwarded to ``ProcgenEnv`` (cast to ``int``).
        start_level: Forwarded to ``ProcgenEnv``; defaults to ``seed``.
        action_n: Fallback action count, used only when the upstream
            ``action_space`` is missing or has no ``n`` attribute.
        **procgen_kwargs: Extra ``ProcgenEnv`` options. ``max_episode_steps``
            (default 1000) is removed here and only exposed through
            :attr:`max_episode_steps`; this adapter does not enforce it.

    Raises:
        ImportError: If the optional ``procgen`` package is unavailable.
        ValueError: If ``env`` is not a known Procgen game.
    """

    def __init__(
        self,
        env: str,
        seed: int = 0,
        size: tuple[int, int] = (64, 64),
        distribution_mode: str = "easy",
        num_levels: int = 0,
        start_level: int | None = None,
        action_n: int = 15,
        **procgen_kwargs: Any,
    ):
        # Resolve the optional dependency first, so a missing install is
        # reported before the environment name is validated.
        procgen_env_cls = _require_procgen_env_class()
        self.env_name = normalize_procgen_env_name(env)
        height, width = int(size[0]), int(size[1])
        self._size = (height, width)
        self._seed = int(seed)
        self._last_image: np.ndarray | None = None
        self._last_obs: Any = None

        # ``max_episode_steps`` is an adapter-level option; it must not be
        # forwarded to the upstream constructor.
        episode_limit = int(procgen_kwargs.pop("max_episode_steps", 1000))
        first_level = self._seed if start_level is None else start_level

        self._env = procgen_env_cls(
            num_envs=1,
            env_name=self.env_name,
            distribution_mode=distribution_mode,
            num_levels=int(num_levels),
            start_level=int(first_level),
            **procgen_kwargs,
        )

        # Prefer the action count reported by the upstream space when present.
        upstream_space = getattr(self._env, "action_space", None)
        if upstream_space is not None and hasattr(upstream_space, "n"):
            action_n = int(upstream_space.n)
        self._discrete_n = int(action_n)
        self._action_space = _ProcgenActionSpace(self._discrete_n)

        image_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(3, height, width),
            dtype=np.uint8,
        )
        self._observation_space = gym.spaces.Dict({"image": image_space})
        self._max_episode_steps = episode_limit

    @property
    def observation_space(self) -> gym.spaces.Dict:
        """Dict space with a single ``"image"`` Box of shape ``(3, H, W)``."""
        return self._observation_space

    @property
    def action_space(self) -> _ProcgenActionSpace:
        """One-hot-like ``Box[-1, 1]`` space with one entry per action."""
        return self._action_space

    @property
    def max_episode_steps(self) -> int:
        """Episode length limit recorded at construction (not enforced here)."""
        return self._max_episode_steps

    def _to_native_action(self, action: Any) -> tuple[int, NDArray[np.float32]]:
        """Map a model action to ``(discrete_index, one_hot_like_vector)``.

        A flattened action with exactly ``n`` (> 1) entries is decoded by
        arg-max. Otherwise the first element is passed through Python's
        ``round`` and used as the index; an empty action selects index 0.
        The index is clipped into ``[0, n - 1]`` and re-encoded as ``-1``
        everywhere with ``+1`` at the chosen slot.
        """
        flat = np.asarray(action, dtype=np.float32).reshape(-1)
        n = self._discrete_n
        if flat.size > 1 and flat.size == n:
            choice = int(np.argmax(flat))
        elif flat.size == 0:
            choice = 0
        else:
            choice = int(round(float(flat[0])))
        choice = int(np.clip(choice, 0, n - 1))
        one_hot = np.full((n,), -1.0, dtype=np.float32)
        one_hot[choice] = 1.0
        return choice, one_hot

    def _extract_rgb(self, obs: Any) -> np.ndarray:
        """Pull a single channel-last uint8 frame out of a Procgen observation.

        Tuples are reduced to their first element. Dicts are searched for the
        keys ``rgb``, ``image``, ``pixels``, ``observation`` (in that order),
        then for the first value that converts to a 3-D or 4-D array. A
        leading batch axis is dropped, channel-first frames are moved to
        channel-last, single-channel frames are replicated to three channels,
        a fourth (alpha) channel is discarded, and non-uint8 data is converted
        to uint8 (scaled by 255 first when its maximum is <= 1).

        Raises:
            RuntimeError: If a dict holds no usable frame, or the frame is not
                3-D after removing the batch axis.
        """
        if isinstance(obs, tuple):
            obs = obs[0]
        if isinstance(obs, dict):
            preferred = next(
                (
                    key
                    for key in ("rgb", "image", "pixels", "observation")
                    if key in obs
                ),
                None,
            )
            if preferred is not None:
                return self._extract_rgb(obs[preferred])
            for entry in obs.values():
                as_array = np.asarray(entry)
                if as_array.ndim in (3, 4):
                    return self._extract_rgb(as_array)
            raise RuntimeError("Procgen observation did not contain an RGB frame.")

        frame = np.asarray(obs)
        if frame.ndim == 4:
            # Drop the batch axis of the one-environment vector.
            frame = frame[0]
        if frame.ndim != 3:
            raise RuntimeError(
                "Expected Procgen RGB observation with 3 dimensions, "
                f"got {frame.shape}."
            )

        channel_counts = (1, 3, 4)
        looks_channel_first = (
            frame.shape[0] in channel_counts
            and frame.shape[-1] not in channel_counts
        )
        if looks_channel_first:
            frame = frame.transpose(1, 2, 0)

        channels = frame.shape[-1]
        if channels == 1:
            frame = np.repeat(frame, 3, axis=-1)
        elif channels == 4:
            frame = frame[..., :3]

        if frame.dtype == np.uint8:
            return frame
        # Heuristic: float-like frames with max <= 1 are treated as [0, 1].
        scaled = frame.astype(np.float32)
        if scaled.size > 0 and scaled.max() <= 1.0:
            scaled = scaled * 255.0
        return scaled.clip(0, 255).astype(np.uint8)

    def _to_chw_uint8_image(self, obs: Any) -> NDArray[np.uint8]:
        """Return the observation's frame as a new ``(3, H, W)`` uint8 array.

        Frames whose spatial size differs from the configured ``size`` are
        resized with PIL's bilinear filter first.
        """
        frame = self._extract_rgb(obs)
        height, width = self._size
        if frame.shape[:2] != (height, width):
            # PIL takes (width, height) for resize targets.
            resized = Image.fromarray(frame).resize((width, height), Image.BILINEAR)
            frame = np.array(resized)
        return frame.transpose(2, 0, 1).copy()

    def reset(self) -> dict[str, NDArray[np.uint8]]:
        """Reset the underlying environment and return the first observation."""
        self._last_obs = self._env.reset()
        self._last_image = self._to_chw_uint8_image(self._last_obs)
        return {"image": self._last_image}

    def step(
        self, action: Any
    ) -> tuple[dict[str, NDArray[np.uint8]], float, bool, dict[str, Any]]:
        """Apply one action and return ``(observation, reward, done, info)``.

        ``reward`` and ``done`` are taken from the first (only) element of
        the upstream batch. ``info`` always carries ``"action"`` (the
        one-hot-like action actually applied, after clipping) and
        ``"discount"`` (``0.0`` when done, else ``1.0``) unless the upstream
        info already provides a discount.

        NOTE(review): if the upstream vector env auto-resets on ``done`` (as
        batched envs commonly do), the observation returned alongside
        ``done=True`` is the first frame of the next episode -- confirm.
        """
        index, applied = self._to_native_action(action)
        obs, reward, done, info = self._env.step(np.array([index], dtype=np.int32))

        is_done = bool(np.asarray(done).ravel()[0])
        step_reward = float(np.asarray(reward).ravel()[0])
        details = _unbatch_procgen_info(info)
        details.setdefault(
            "discount", np.array(0.0 if is_done else 1.0, dtype=np.float32)
        )
        details["action"] = applied.copy()

        self._last_obs = obs
        frame = self._to_chw_uint8_image(obs)
        self._last_image = frame
        return {"image": frame}, step_reward, is_done, details

    def render(self, *args: Any, **kwargs: Any) -> NDArray[np.uint8]:
        """Return a copy of the latest frame as ``(H, W, 3)``; arguments are ignored.

        Raises:
            RuntimeError: If no frame has been produced yet.
        """
        frame = self._last_image
        if frame is None:
            raise RuntimeError("No frame available. Call reset() before render().")
        return np.transpose(frame, (1, 2, 0)).copy()

    def close(self) -> None:
        """Close the underlying environment if it exposes ``close``."""
        if not hasattr(self._env, "close"):
            return
        self._env.close()