import numpy as np
import gymnasium as gym
from gymnasium import spaces
from typing import Tuple, Dict, Optional, Any, overload
import cv2
[docs]
class DiamondAtariWrapper(gym.Wrapper):
"""
Atari wrapper for DIAMOND following the paper specifications:
- frameskip: number of frames to skip (default 4)
- max_noop: maximum number of noop actions at reset (default 30)
- terminate_on_life_loss: terminate episode when life is lost (default True)
- reward_clip: clip rewards to [-1, 0, 1] (default True)
- resize: resize observations to specified size (default 64x64)
"""
def __init__(
self,
env: gym.Env,
frameskip: int = 4,
max_noop: int = 30,
terminate_on_life_loss: bool = True,
reward_clip: bool = True,
resize: Optional[Tuple[int, int]] = (64, 64),
):
super().__init__(env)
self.frameskip = frameskip
self.max_noop = max_noop
self.terminate_on_life_loss = terminate_on_life_loss
self.reward_clip = reward_clip
self.resize = resize
self.lives = 0
self._last_lives = 0
if resize is not None:
self._height, self._width = resize
self.observation_space = spaces.Box(
low=0, high=255, shape=(self._height, self._width, 3), dtype=np.uint8
)
def _apply_frameskip(self, action: int) -> Tuple[Any, float, bool, Dict[str, Any]]:
"""Apply frameskip by repeating the action.
Returns (obs, total_reward, done, info) where `done` is a collapsed
boolean indicating termination/truncation for older gym APIs.
"""
total_reward = 0.0
done = False
info: Dict[str, Any] = {}
obs: Any = None
for _ in range(self.frameskip):
ret = self.env.step(action)
# gymnasium returns (obs, reward, terminated, truncated, info)
if isinstance(ret, tuple) and len(ret) == 5:
obs, reward, terminated, truncated, info = ret
else:
# older gym: (obs, reward, done, info)
obs, reward, single_done, info = ret # type: ignore[assignment]
terminated = bool(single_done)
truncated = False
total_reward += float(reward)
if terminated or (locals().get("truncated", False)):
done = True
break
if self.terminate_on_life_loss:
# ale attribute may or may not exist depending on backend; runtime
# checks are used here. Type-checkers don't know about `ale`, so
# use hasattr guards and ignore the attribute access for mypy.
if hasattr(self.env, "ale") and hasattr(
getattr(self.env, "ale"), "lives"
):
try:
# type: ignore[attr-defined]
self.lives = self.env.ale.lives()
except Exception:
# some backends expose lives as attribute or method; ignore failures
pass
if self.lives < self._last_lives and self.lives > 0:
done = True
info["life_lost"] = True
break
self._last_lives = self.lives
if self.reward_clip:
total_reward = float(np.clip(total_reward, -1, 1))
assert obs is not None
return obs, total_reward, done, info
[docs]
def step(self, action: int) -> Any: # type: ignore[override]
"""Step the environment.
For backwards compatibility with older gym APIs this wrapper returns a
4-tuple: (obs, reward, done, info). Internally it supports gymnasium's
5-tuple and collapses (terminated, truncated) into a single `done` bool.
"""
obs, reward, done, info = self._apply_frameskip(action)
if self.resize is not None:
obs = self._resize_obs(obs)
# Return legacy 4-tuple (obs, reward, done, info)
return obs, reward, bool(done), info
[docs]
def reset(self, **kwargs) -> Tuple[Any, Dict[str, Any]]:
obs, info = self.env.reset(**kwargs)
if self.resize is not None:
obs = self._resize_obs(obs)
if self.terminate_on_life_loss:
if hasattr(self.env, "ale") and hasattr(self.env.ale, "lives"):
self.lives = self.env.ale.lives()
else:
self.lives = 0
self._last_lives = self.lives
noops = np.random.randint(1, self.max_noop + 1)
for _ in range(noops):
action = self.env.action_space.sample()
if action == 0:
# gymnasium env.step returns (obs, reward, terminated, truncated, info)
step_ret = self.env.step(action)
if len(step_ret) == 5:
obs_step, _, terminated, truncated, _ = step_ret
done = bool(terminated or truncated)
else:
# fallback for older gym API
obs_step, _, done, _ = step_ret
if self.resize is not None:
obs_step = self._resize_obs(obs_step)
obs = obs_step
if done:
obs, info = self.env.reset(**kwargs)
if self.resize is not None:
obs = self._resize_obs(obs)
break
return obs, info
def _resize_obs(self, obs: np.ndarray) -> np.ndarray:
"""Resize observation to target size."""
if obs.shape[:2] == (self._height, self._width):
return obs
# Use OpenCV for resize. Import is done at module scope to keep
# behavior consistent and straightforward.
obs = cv2.resize(obs, (self._width, self._height), interpolation=cv2.INTER_AREA)
return obs.astype(np.uint8)
[docs]
def make_diamond_atari_env(
game: str,
frameskip: int = 4,
max_noop: int = 30,
terminate_on_life_loss: bool = True,
reward_clip: bool = True,
resize: Tuple[int, int] = (64, 64),
seed: Optional[int] = None,
) -> DiamondAtariWrapper:
"""
Create a DIAMOND-compatible Atari environment.
Args:
game: Atari game name (e.g., "Breakout-v5")
frameskip: Number of frames to skip between actions
max_noop: Maximum number of noop actions at reset
terminate_on_life_loss: Whether to terminate on life loss
reward_clip: Whether to clip rewards to [-1, 0, 1]
resize: Target size for observations
seed: Random seed
Returns:
DiamondAtariWrapper: Configured Atari environment
"""
env = gym.make(
game,
obs_type="rgb",
frameskip=1,
repeat_action_probability=0.0,
full_action_space=False,
)
if seed is not None:
env.reset(seed=seed)
env.action_space.seed(seed)
env = DiamondAtariWrapper(
env=env,
frameskip=frameskip,
max_noop=max_noop,
terminate_on_life_loss=terminate_on_life_loss,
reward_clip=reward_clip,
resize=resize,
)
return env