import gym
from PIL import Image
import numpy as np
import datetime
import uuid
# [docs] - Sphinx source-link marker (documentation export artifact, not code).
class TimeLimit:
    """Terminate episodes after a fixed number of wrapper steps.

    If the wrapped environment does not provide a discount flag at timeout,
    the wrapper injects a default discount of `1.0` for downstream learners.
    """

    def __init__(self, env, duration):
        """Wrap `env` so an episode ends after `duration` calls to `step`."""
        self._env = env
        self._duration = duration
        # `None` means "not reset yet"; `reset` sets the counter to 0.
        self._step = None

    def __getattr__(self, name):
        """Forward attributes not defined on the wrapper to the wrapped env."""
        # Only reached when normal lookup fails. Dunder probes (for example
        # `__setstate__` during copy/pickle) and a missing `_env` must fail
        # fast, otherwise `self._env` re-enters this method and recurses
        # without bound.
        if name.startswith("__") or name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)

    def step(self, action):
        """Step the wrapped env and force `done` once the limit is reached.

        Returns the `(obs, reward, done, info)` tuple of the wrapped env.
        When the step budget is exhausted, `done` is set to `True`,
        `info["discount"]` defaults to a float32 `1.0` if absent, and the
        wrapper requires a `reset` before the next `step`.
        """
        assert self._step is not None, "Must reset environment."
        obs, reward, done, info = self._env.step(action)
        self._step += 1
        if self._step >= self._duration:
            done = True
            if "discount" not in info:
                info["discount"] = np.array(1.0).astype(np.float32)
            # Require an explicit reset before stepping again.
            self._step = None
        return obs, reward, done, info

    def reset(self):
        """Restart the step counter and reset the wrapped environment."""
        self._step = 0
        return self._env.reset()
# [docs] - Sphinx source-link marker (documentation export artifact, not code).
class ActionRepeat:
    """Repeat each action for a fixed number of environment steps.

    Rewards are accumulated and the loop stops early if the environment
    terminates, mirroring common action-repeat behavior in world model papers.
    """

    def __init__(self, env, amount):
        """Wrap `env` so every `step` repeats the action up to `amount` times.

        Raises:
            ValueError: If `amount` is smaller than 1. With no inner step
                there would be no observation to return from `step`.
        """
        if amount < 1:
            raise ValueError(f"amount must be at least 1, got {amount!r}")
        self._env = env
        self._amount = amount

    def __getattr__(self, name):
        """Forward attributes not defined on the wrapper to the wrapped env."""
        # Only reached when normal lookup fails. Dunder probes (copy/pickle)
        # and a missing `_env` must fail fast instead of recursing through
        # `self._env`.
        if name.startswith("__") or name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)

    def step(self, action):
        """Apply `action` repeatedly and return the last transition.

        Returns `(obs, total_reward, done, info)` where `obs`, `done` and
        `info` come from the last inner step and `total_reward` is the sum of
        the rewards of all inner steps that were executed.
        """
        done = False
        total_reward = 0
        current_step = 0
        # Stop as soon as the wrapped env reports termination.
        while current_step < self._amount and not done:
            obs, reward, done, info = self._env.step(action)
            total_reward += reward
            current_step += 1
        return obs, total_reward, done, info
# [docs] - Sphinx source-link marker (documentation export artifact, not code).
class NormalizeActions:
    """Expose a normalized `[-1, 1]` action space for bounded continuous controls.

    Incoming normalized actions are mapped back to the wrapped environment
    action bounds before stepping the environment.
    """

    def __init__(self, env):
        """Record which action dimensions have finite bounds on both sides."""
        self._env = env
        # True where both the lower and the upper bound are finite.
        self._mask = np.logical_and(
            np.isfinite(env.action_space.low), np.isfinite(env.action_space.high)
        )
        # Non-finite dimensions get placeholder bounds of -1 / 1 so the
        # arithmetic in `step` never touches inf or nan.
        self._low = np.where(self._mask, env.action_space.low, -1)
        self._high = np.where(self._mask, env.action_space.high, 1)

    def __getattr__(self, name):
        """Forward attributes not defined on the wrapper to the wrapped env."""
        # Only reached when normal lookup fails. Dunder probes (copy/pickle)
        # and a missing `_env` must fail fast instead of recursing through
        # `self._env`.
        if name.startswith("__") or name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)

    @property
    def action_space(self):
        """Return the float32 `Box` action space advertised to the agent."""
        low = np.where(self._mask, -np.ones_like(self._low), self._low)
        high = np.where(self._mask, np.ones_like(self._low), self._high)
        return gym.spaces.Box(low, high, dtype=np.float32)

    def step(self, action):
        """Rescale `action` to the env bounds and step the wrapped env.

        Dimensions with finite bounds are mapped linearly from `[-1, 1]` to
        `[low, high]`; all other dimensions are forwarded unchanged.
        """
        original = (action + 1) / 2 * (self._high - self._low) + self._low
        original = np.where(self._mask, original, action)
        return self._env.step(original)
# [docs] - Sphinx source-link marker (documentation export artifact, not code).
class ObsDict:
    """Convert scalar/array observations into a dictionary observation format.

    This harmonizes outputs for code paths that expect keyed observations
    (for example `{"image": ...}` style world model inputs).
    """

    def __init__(self, env, key="obs"):
        """Wrap `env`; observations are stored under `key`."""
        self._env = env
        self._key = key

    def __getattr__(self, name):
        """Forward attributes not defined on the wrapper to the wrapped env."""
        # Only reached when normal lookup fails. Dunder probes (copy/pickle)
        # and a missing `_env` must fail fast instead of recursing through
        # `self._env`.
        if name.startswith("__") or name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)

    @property
    def observation_space(self):
        """Return a `gym.spaces.Dict` holding the env's space under `key`."""
        spaces = {self._key: self._env.observation_space}
        return gym.spaces.Dict(spaces)

    @property
    def action_space(self):
        """Return the wrapped environment's action space unchanged."""
        return self._env.action_space

    def step(self, action):
        """Step the wrapped env and wrap its observation in a dict."""
        obs, reward, done, info = self._env.step(action)
        obs = {self._key: np.array(obs)}
        return obs, reward, done, info

    def reset(self):
        """Reset the wrapped env and wrap the first observation in a dict."""
        obs = self._env.reset()
        obs = {self._key: np.array(obs)}
        return obs
# [docs] - Sphinx source-link marker (documentation export artifact, not code).
class OneHotAction:
    """Wrap discrete-action environments to accept one-hot action vectors.

    The wrapper validates one-hot inputs and converts them to integer action
    indices before forwarding to the underlying environment.
    """

    def __init__(self, env):
        """Wrap `env`, whose action space must be `gym.spaces.Discrete`."""
        assert isinstance(env.action_space, gym.spaces.Discrete)
        self._env = env
        # Random source for `_sample_action`. Without it the lookup of
        # `self._random` falls through `__getattr__` to the wrapped env.
        self._random = np.random.RandomState()

    def __getattr__(self, name):
        """Forward attributes not defined on the wrapper to the wrapped env."""
        # Only reached when normal lookup fails. Dunder probes (copy/pickle)
        # and a missing `_env` must fail fast instead of recursing through
        # `self._env`.
        if name.startswith("__") or name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)

    @property
    def action_space(self):
        """Return a float32 `Box` in `[0, 1]` with one entry per action.

        The space's `sample` is replaced so that it yields one-hot vectors.
        """
        shape = (self._env.action_space.n,)
        space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
        space.sample = self._sample_action
        return space

    def step(self, action):
        """Convert a one-hot `action` to its index and step the wrapped env.

        Raises:
            ValueError: If `action` is not close to an exact one-hot vector.
        """
        index = np.argmax(action).astype(int)
        # Rebuild the ideal one-hot vector and compare it to the input.
        reference = np.zeros_like(action)
        reference[index] = 1
        if not np.allclose(reference, action):
            raise ValueError(f"Invalid one-hot action:\n{action}")
        return self._env.step(index)

    def reset(self):
        """Reset the wrapped environment."""
        return self._env.reset()

    def _sample_action(self):
        """Return a random float32 one-hot action vector."""
        actions = self._env.action_space.n
        index = self._random.randint(0, actions)
        reference = np.zeros(actions, dtype=np.float32)
        reference[index] = 1.0
        return reference
# [docs] - Sphinx source-link marker (documentation export artifact, not code).
class RewardObs:
    """Augment observations with the latest scalar reward under `obs["reward"]`.

    Useful for agents that consume reward as part of the observation stream
    during model learning or recurrent policy inference.
    """

    def __init__(self, env):
        """Wrap `env`, whose observations must support item assignment."""
        self._env = env

    def __getattr__(self, name):
        """Forward attributes not defined on the wrapper to the wrapped env."""
        # Only reached when normal lookup fails. Dunder probes (copy/pickle)
        # and a missing `_env` must fail fast instead of recursing through
        # `self._env`.
        if name.startswith("__") or name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)

    @property
    def observation_space(self):
        """Return the env's dict space extended with a `"reward"` entry."""
        # Work on a copy: inserting into the wrapped env's own mapping would
        # leak the extra key into the env and make a second access trip the
        # assertion below.
        spaces = self._env.observation_space.spaces.copy()
        assert "reward" not in spaces
        spaces["reward"] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32)
        return gym.spaces.Dict(spaces)

    def step(self, action):
        """Step the wrapped env and store the reward in the observation."""
        obs, reward, done, info = self._env.step(action)
        obs["reward"] = reward
        return obs, reward, done, info

    def reset(self):
        """Reset the wrapped env; the first observation carries reward 0.0."""
        obs = self._env.reset()
        obs["reward"] = 0.0
        return obs
# [docs] - Sphinx source-link marker (documentation export artifact, not code).
class ResizeImage:
    """Resize image-like observation entries to a target spatial size.

    The wrapper discovers image keys from `env.obs_space`, applies nearest
    neighbor resizing, and updates the advertised observation space shapes.

    `size` is compared against the first two entries of each space shape, so
    it is interpreted in array order (first axis, then second axis).
    `step` and `reset` index the value returned by the wrapped env by key, so
    the wrapped env is expected to return the observation mapping itself.
    """

    def __init__(self, env, size=(64, 64)):
        """Select the keys of `env.obs_space` that need resizing to `size`."""
        self._env = env
        # Normalize to a tuple: a list would never compare equal to a shape
        # tuple and could not be concatenated with one in `obs_space`.
        self._size = tuple(size)
        # Entries with more than one dimension whose leading two dimensions
        # differ from the target size.
        self._keys = [
            k
            for k, v in env.obs_space.items()
            if len(v.shape) > 1 and tuple(v.shape[:2]) != self._size
        ]
        print(f'Resizing keys {",".join(self._keys)} to {self._size}.')
        if self._keys:
            self._Image = Image

    def __getattr__(self, name):
        """Forward unknown attributes to the wrapped env.

        Raises:
            AttributeError: For dunder names and for a missing `_env`.
            ValueError: If the wrapped env does not have the attribute.
        """
        # A missing `_env` must fail fast, otherwise `self._env` below would
        # re-enter this method and recurse without bound.
        if name.startswith("__") or name == "_env":
            raise AttributeError(name)
        try:
            return getattr(self._env, name)
        except AttributeError:
            raise ValueError(name)

    @property
    def obs_space(self):
        """Return the env's observation spaces with resized image shapes."""
        spaces = self._env.obs_space
        # Copy plain dicts so the wrapped env's own mapping is not rewritten
        # with the resized shapes.
        if isinstance(spaces, dict):
            spaces = spaces.copy()
        for key in self._keys:
            shape = self._size + tuple(spaces[key].shape[2:])
            spaces[key] = gym.spaces.Box(0, 255, shape, np.uint8)
        return spaces

    def step(self, action):
        """Step the wrapped env and resize the selected observation entries."""
        obs = self._env.step(action)
        for key in self._keys:
            obs[key] = self._resize(obs[key])
        return obs

    def reset(self):
        """Reset the wrapped env and resize the selected observation entries."""
        obs = self._env.reset()
        for key in self._keys:
            obs[key] = self._resize(obs[key])
        return obs

    def _resize(self, image):
        """Return `image` resized with nearest-neighbor sampling."""
        image = self._Image.fromarray(image)
        # PIL expects (width, height) while `self._size` follows array order
        # (rows, columns); reverse it so the output array has shape
        # `self._size + channels`, matching `obs_space`.
        image = image.resize(self._size[::-1], self._Image.NEAREST)
        image = np.array(image)
        return image
# [docs] - Sphinx source-link marker (documentation export artifact, not code).
class RenderImage:
    """Inject RGB renders from `env.render("rgb_array")` into observations.

    This is useful when the base environment returns non-image observations
    but a rendered camera view is needed for world-model training.

    `step` and `reset` index the value returned by the wrapped env by key, so
    the wrapped env is expected to return the observation mapping itself.
    """

    def __init__(self, env, key="image"):
        """Wrap `env`; rendered frames are stored under `key`."""
        self._env = env
        self._key = key
        # Probe with the same render mode that `step` and `reset` use, so the
        # advertised shape matches the frames that are actually injected.
        self._shape = self._env.render("rgb_array").shape

    def __getattr__(self, name):
        """Forward unknown attributes to the wrapped env.

        Raises:
            AttributeError: For dunder names and for a missing `_env`.
            ValueError: If the wrapped env does not have the attribute.
        """
        # A missing `_env` must fail fast, otherwise `self._env` below would
        # re-enter this method and recurse without bound.
        if name.startswith("__") or name == "_env":
            raise AttributeError(name)
        try:
            return getattr(self._env, name)
        except AttributeError:
            raise ValueError(name)

    @property
    def obs_space(self):
        """Return the env's observation spaces plus the rendered image space."""
        spaces = self._env.obs_space
        # Copy plain dicts so the extra key is not written into the wrapped
        # env's own mapping.
        if isinstance(spaces, dict):
            spaces = spaces.copy()
        spaces[self._key] = gym.spaces.Box(0, 255, self._shape, np.uint8)
        return spaces

    def step(self, action):
        """Step the wrapped env and attach the current RGB render."""
        obs = self._env.step(action)
        obs[self._key] = self._env.render("rgb_array")
        return obs

    def reset(self):
        """Reset the wrapped env and attach the current RGB render."""
        obs = self._env.reset()
        obs[self._key] = self._env.render("rgb_array")
        return obs
class UUID(gym.Wrapper):
    """Gym wrapper that tracks a unique run identifier per environment reset.

    The ID combines timestamp and UUID and can be used to tag episodes or
    artifacts generated during data collection.
    """

    def __init__(self, env):
        """Wrap `env` and assign an initial identifier to `self.id`."""
        super().__init__(env)
        self.id = self._new_id()

    def reset(self, **kwargs):
        """Assign a fresh identifier and reset the wrapped environment.

        Keyword arguments are forwarded unchanged to the wrapped env's
        `reset`; calling without arguments behaves as before.
        """
        self.id = self._new_id()
        return self.env.reset(**kwargs)

    @staticmethod
    def _new_id():
        """Return `<local timestamp>-<uuid4 hex>` as a string."""
        timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
        return f"{timestamp}-{uuid.uuid4().hex}"
# [docs] - Sphinx source-link marker (documentation export artifact, not code).
class SelectAction(gym.Wrapper):
    """Gym wrapper that unpacks one entry of a keyed action before stepping.

    Policies may emit a mapping of actions while the wrapped environment takes
    a single action payload; this wrapper forwards only the configured entry.
    """

    def __init__(self, env, key):
        """Wrap `env`; `key` names the action entry passed on to it."""
        super().__init__(env)
        self._key = key

    def step(self, action):
        """Look up `action[key]` and step the wrapped env with that value."""
        selected = action[self._key]
        return self.env.step(selected)