from __future__ import annotations
from typing import Optional
import gymnasium as gym
import numpy as np
from PIL import Image
[docs]
def make_unity_mlagents_env(env_id: "Optional[str]" = None, **kwargs):
"""Create a Unity ML-Agents environment wrapper.
Factory function that instantiates a UnityMLAgentsEnv with the provided
keyword arguments. Suitable for integrating Unity-based environments
with Dreamer-style world model pipelines.
Args:
**kwargs: Keyword arguments passed to UnityMLAgentsEnv, including:
- file_name (str): Path to the Unity environment binary.
- behavior_name (str, optional): Name of the behavior to use.
- seed (int): Random seed (default: 0).
- size (tuple): Image size as (height, width) (default: (64, 64)).
- worker_id (int): Worker ID for multi-environment setup (default: 0).
- base_port (int): Base port for communication (default: 5005).
- no_graphics (bool): Disable graphics rendering (default: True).
- time_scale (float): Simulation time scale (default: 20.0).
- quality_level (int): Graphics quality level (default: 1).
- max_episode_steps (int): Max steps per episode (default: 1000).
Returns:
UnityMLAgentsEnv: A Gym-compatible wrapper for Unity environments.
"""
# `env_id` is accepted for API compatibility with generic factory
# callers that forward an env id as the first positional argument.
return UnityMLAgentsEnv(**kwargs)
[docs]
class UnityMLAgentsEnv:
"""Gym-like wrapper for Unity ML-Agents environments.
Provides a unified interface for Unity-based environments, converting
observations to image format compatible with pixel-based world models.
Features:
- Supports single-agent control with continuous action spaces.
- Returns observations as {"image": (C, H, W)} with uint8 values.
- Normalizes actions to [-1, 1] range.
- Includes rendered frames in observations for visual policies.
Args:
file_name (str): Path to the Unity environment binary.
behavior_name (str, optional): Name of the behavior to use. If None,
uses the first available behavior.
seed (int): Random seed for environment (default: 0).
size (tuple): Target image size as (height, width) (default: (64, 64)).
worker_id (int): Worker ID for multi-environment setup (default: 0).
base_port (int): Base port for Unity environment communication (default: 5005).
no_graphics (bool): Disable graphics rendering for faster simulation (default: True).
time_scale (float): Simulation time scale multiplier (default: 20.0).
quality_level (int): Graphics quality level 0-5 (default: 1).
max_episode_steps (int): Maximum steps per episode (default: 1000).
Attributes:
observation_space: Dict space with "image" key containing (3, H, W) Box.
action_space: Box space with actions in [-1, 1] range.
max_episode_steps: Maximum steps per episode.
Raises:
ValueError: If no behaviors found or action space is not continuous.
RuntimeError: If no agents available after reset.
"""
def __init__(
self,
file_name,
behavior_name=None,
seed=0,
size=(64, 64),
worker_id=0,
base_port=5005,
no_graphics=True,
time_scale=20.0,
quality_level=1,
max_episode_steps=1000,
):
from mlagents_envs.base_env import ActionTuple
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import (
EngineConfigurationChannel,
)
self._ActionTuple = ActionTuple
self._size = (int(size[0]), int(size[1]))
self._max_episode_steps = int(max_episode_steps)
self._agent_id = None
self._last_image = None
self._engine_channel = EngineConfigurationChannel()
self._engine_channel.set_configuration_parameters(
width=self._size[1],
height=self._size[0],
quality_level=quality_level,
time_scale=float(time_scale),
)
self._env = UnityEnvironment(
file_name=file_name,
seed=seed,
worker_id=worker_id,
base_port=base_port,
no_graphics=no_graphics,
side_channels=[self._engine_channel],
)
self._env.reset()
behavior_names = list(self._env.behavior_specs.keys())
if not behavior_names:
raise ValueError("No Unity behaviors found in the environment.")
if behavior_name is None:
behavior_name = behavior_names[0]
if behavior_name not in self._env.behavior_specs:
raise ValueError(
f"Behavior '{behavior_name}' not found. Available: {behavior_names}"
)
self._behavior_name = behavior_name
self._spec = self._env.behavior_specs[self._behavior_name]
action_spec = self._spec.action_spec
if not action_spec.is_continuous():
raise ValueError(
"UnityMLAgentsEnv currently supports only continuous action spaces."
)
self._action_size = int(action_spec.continuous_size)
@property
def observation_space(self):
return gym.spaces.Dict(
{
"image": gym.spaces.Box(
low=0,
high=255,
shape=(3, self._size[0], self._size[1]),
dtype=np.uint8,
)
}
)
@property
def action_space(self):
return gym.spaces.Box(
low=-1.0, high=1.0, shape=(self._action_size,), dtype=np.float32
)
@property
def max_episode_steps(self):
return self._max_episode_steps
def _extract_agent_data(self, steps, preferred_agent_id):
agent_ids = np.asarray(getattr(steps, "agent_id", []))
if agent_ids.size == 0:
return None
if preferred_agent_id is None:
idx = 0
agent_id = int(agent_ids[idx])
else:
matches = np.where(agent_ids == preferred_agent_id)[0]
if matches.size == 0:
return None
idx = int(matches[0])
agent_id = int(preferred_agent_id)
obs_list = [np.asarray(obs[idx]) for obs in steps.obs]
rewards = np.asarray(getattr(steps, "reward", np.zeros(agent_ids.size)))
reward = float(rewards[idx]) if rewards.size > idx else 0.0
interrupted = False
if hasattr(steps, "interrupted"):
interrupted_arr = np.asarray(steps.interrupted)
if interrupted_arr.size > idx:
interrupted = bool(interrupted_arr[idx])
return agent_id, obs_list, reward, interrupted
def _vector_to_image(self, vector):
arr = np.asarray(vector, dtype=np.float32).reshape(-1)
if arr.size == 0:
return np.zeros((self._size[0], self._size[1], 3), dtype=np.uint8)
vmin = float(arr.min())
vmax = float(arr.max())
if vmax > vmin:
arr = (arr - vmin) / (vmax - vmin)
else:
arr = np.zeros_like(arr)
image = np.zeros((self._size[0], self._size[1], 3), dtype=np.uint8)
bands = min(arr.size, 8)
band_w = max(1, self._size[1] // max(1, bands))
for i in range(bands):
start = i * band_w
end = min(self._size[1], start + band_w)
image[:, start:end, :] = int(255.0 * float(arr[i]))
return image
def _to_hwc_uint8(self, obs):
arr = np.asarray(obs)
if arr.ndim == 1:
image = self._vector_to_image(arr)
elif arr.ndim == 2:
image = np.repeat(arr[..., None], 3, axis=-1)
elif arr.ndim == 3:
image = arr
# Handle CHW -> HWC if needed.
if image.shape[-1] not in (1, 3, 4) and image.shape[0] in (1, 3, 4):
image = image.transpose(1, 2, 0)
if image.shape[-1] == 1:
image = np.repeat(image, 3, axis=-1)
elif image.shape[-1] == 4:
image = image[..., :3]
else:
image = np.zeros((self._size[0], self._size[1], 3), dtype=np.uint8)
image = np.asarray(image)
if image.dtype != np.uint8:
image = image.astype(np.float32)
if image.size > 0 and image.max() <= 1.0:
image = (image * 255.0).clip(0, 255).astype(np.uint8)
else:
image = image.clip(0, 255).astype(np.uint8)
if image.shape[0] != self._size[0] or image.shape[1] != self._size[1]:
image = np.array(
Image.fromarray(image).resize(
(self._size[1], self._size[0]), Image.BILINEAR
)
)
return image
def _obs_list_to_chw_image(self, obs_list):
visual = None
for obs in obs_list:
arr = np.asarray(obs)
if arr.ndim == 3:
visual = arr
break
if visual is None and obs_list:
visual = np.asarray(obs_list[0])
if visual is None:
visual = np.zeros((self._size[0], self._size[1], 3), dtype=np.uint8)
image = self._to_hwc_uint8(visual)
return image.transpose(2, 0, 1).copy()
[docs]
def reset(self):
self._env.reset()
decision_steps, terminal_steps = self._env.get_steps(self._behavior_name)
data = self._extract_agent_data(decision_steps, preferred_agent_id=None)
if data is None:
data = self._extract_agent_data(terminal_steps, preferred_agent_id=None)
if data is None:
raise RuntimeError("No Unity agents were available after reset.")
self._agent_id, obs_list, _, _ = data
image = self._obs_list_to_chw_image(obs_list)
self._last_image = image
return {"image": image}
[docs]
def step(self, action):
if self._agent_id is None:
raise RuntimeError(
"Environment has terminated. Call reset() before step()."
)
action = np.asarray(action, dtype=np.float32).reshape(1, self._action_size)
action = np.clip(action, -1.0, 1.0)
self._env.set_actions(
self._behavior_name,
self._ActionTuple(continuous=action),
)
self._env.step()
decision_steps, terminal_steps = self._env.get_steps(self._behavior_name)
terminal_data = self._extract_agent_data(terminal_steps, self._agent_id)
interrupted = False
if terminal_data is not None:
_, obs_list, reward, interrupted = terminal_data
done = True
self._agent_id = None
else:
decision_data = self._extract_agent_data(decision_steps, self._agent_id)
if decision_data is None:
decision_data = self._extract_agent_data(decision_steps, None)
if decision_data is None:
raise RuntimeError("No decision step data found after Unity step.")
self._agent_id, obs_list, reward, _ = decision_data
done = False
image = self._obs_list_to_chw_image(obs_list)
self._last_image = image
info = {
"discount": np.array(0.0 if done else 1.0, dtype=np.float32),
"action": action[0].copy(),
}
if done:
info["interrupted"] = bool(interrupted)
return {"image": image}, float(reward), bool(done), info
[docs]
def render(self, *args, **kwargs):
if self._last_image is None:
raise RuntimeError("No frame available. Call reset() before render().")
return self._last_image.transpose(1, 2, 0).copy()
[docs]
def close(self):
if hasattr(self, "_env") and self._env is not None:
self._env.close()