"""Environment wrappers for world_models.envs."""

import gym
from PIL import Image
import numpy as np
import datetime
import uuid


class TimeLimit:
    """Cap episode length at a fixed number of wrapper steps.

    When the cap is reached the episode is marked done and, if the wrapped
    environment did not supply one, a default ``discount`` of 1.0 is written
    into ``info`` for downstream learners.
    """

    def __init__(self, env, duration):
        self._env = env
        self._duration = duration
        self._step = None  # None means "reset() required before step()"

    def __getattr__(self, name):
        # Delegate everything else to the wrapped environment.
        return getattr(self._env, name)

    def step(self, action):
        """Step the wrapped env, forcing termination once the limit is hit."""
        assert self._step is not None, "Must reset environment."
        obs, reward, done, info = self._env.step(action)
        self._step += 1
        if self._step < self._duration:
            return obs, reward, done, info
        # Time limit reached: terminate and inject a neutral discount so
        # learners do not treat the timeout as a true terminal state.
        if "discount" not in info:
            info["discount"] = np.asarray(1.0, dtype=np.float32)
        self._step = None
        return obs, reward, True, info

    def reset(self):
        """Reset the wrapped env and restart the step counter."""
        self._step = 0
        return self._env.reset()
class ActionRepeat:
    """Apply the same action for several consecutive environment steps.

    Rewards from the repeated steps are summed, and the repeat loop stops
    early if the episode terminates, mirroring common action-repeat
    behavior in world model papers.
    """

    def __init__(self, env, amount):
        self._env = env
        self._amount = amount

    def __getattr__(self, name):
        # Fall through to the wrapped environment for everything else.
        return getattr(self._env, name)

    def step(self, action):
        """Repeat ``action`` up to ``self._amount`` times, summing reward."""
        reward_sum = 0
        terminated = False
        for _ in range(self._amount):
            obs, reward, terminated, info = self._env.step(action)
            reward_sum += reward
            if terminated:
                break
        return obs, reward_sum, terminated, info
class NormalizeActions:
    """Present a ``[-1, 1]`` action interface for bounded continuous controls.

    Dimensions with finite bounds are rescaled from the normalized range
    back to the wrapped environment's native bounds before each step;
    unbounded dimensions pass through untouched.
    """

    def __init__(self, env):
        self._env = env
        space = env.action_space
        # Only rescale dimensions whose bounds are both finite.
        self._mask = np.isfinite(space.low) & np.isfinite(space.high)
        self._low = np.where(self._mask, space.low, -1)
        self._high = np.where(self._mask, space.high, 1)

    def __getattr__(self, name):
        # Everything else is served by the wrapped environment.
        return getattr(self._env, name)

    @property
    def action_space(self):
        """Box space with finite dimensions clamped to [-1, 1]."""
        ones = np.ones_like(self._low)
        low = np.where(self._mask, -ones, self._low)
        high = np.where(self._mask, ones, self._high)
        return gym.spaces.Box(low, high, dtype=np.float32)

    def step(self, action):
        """Map a normalized action back to native bounds and step the env."""
        rescaled = (action + 1) / 2 * (self._high - self._low) + self._low
        rescaled = np.where(self._mask, rescaled, action)
        return self._env.step(rescaled)
class ObsDict:
    """Wrap raw observations into a single-key dictionary.

    This harmonizes outputs for code paths that expect keyed observations
    (for example ``{"image": ...}`` style world model inputs).
    """

    def __init__(self, env, key="obs"):
        self._env = env
        self._key = key

    def __getattr__(self, name):
        # Delegate unknown attributes to the wrapped environment.
        return getattr(self._env, name)

    @property
    def observation_space(self):
        """Dict space holding the wrapped space under the configured key."""
        return gym.spaces.Dict({self._key: self._env.observation_space})

    @property
    def action_space(self):
        return self._env.action_space

    def _wrap(self, obs):
        # Normalize the raw observation into the keyed-dict format.
        return {self._key: np.array(obs)}

    def step(self, action):
        obs, reward, done, info = self._env.step(action)
        return self._wrap(obs), reward, done, info

    def reset(self):
        return self._wrap(self._env.reset())
class OneHotAction:
    """Accept one-hot action vectors for a discrete-action environment.

    Incoming actions must be valid one-hot vectors; they are converted to
    integer action indices before being forwarded to the underlying
    environment. The advertised action space samples one-hot vectors too.
    """

    def __init__(self, env):
        assert isinstance(env.action_space, gym.spaces.Discrete)
        self._env = env
        # Private RNG used by _sample_action. Without this attribute,
        # `self._random` fell through __getattr__ to the wrapped env and
        # raised AttributeError whenever action_space.sample() was called.
        self._random = np.random.RandomState()

    def __getattr__(self, name):
        # Delegate unknown attributes to the wrapped environment.
        return getattr(self._env, name)

    @property
    def action_space(self):
        """Box of shape (n,) whose ``sample`` draws one-hot vectors."""
        shape = (self._env.action_space.n,)
        space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
        space.sample = self._sample_action
        return space

    def step(self, action):
        """Validate the one-hot ``action`` and step with its argmax index.

        Raises:
            ValueError: if ``action`` is not a valid one-hot vector.
        """
        index = np.argmax(action).astype(int)
        reference = np.zeros_like(action)
        reference[index] = 1
        if not np.allclose(reference, action):
            raise ValueError(f"Invalid one-hot action:\n{action}")
        return self._env.step(index)

    def reset(self):
        return self._env.reset()

    def _sample_action(self):
        # Draw a uniformly random one-hot action vector.
        actions = self._env.action_space.n
        index = self._random.randint(0, actions)
        reference = np.zeros(actions, dtype=np.float32)
        reference[index] = 1.0
        return reference
class RewardObs:
    """Expose the latest scalar reward in the observation under ``"reward"``.

    Useful for agents that consume reward as part of the observation
    stream during model learning or recurrent policy inference.
    """

    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        # Delegate unknown attributes to the wrapped environment.
        return getattr(self._env, name)

    @property
    def observation_space(self):
        """Wrapped Dict space plus an unbounded scalar ``reward`` entry."""
        # Copy so we never mutate the wrapped env's own spaces dict;
        # mutating it in place made a second access to this property fail
        # the assertion below.
        spaces = dict(self._env.observation_space.spaces)
        assert "reward" not in spaces
        spaces["reward"] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32)
        return gym.spaces.Dict(spaces)

    def step(self, action):
        obs, reward, done, info = self._env.step(action)
        obs["reward"] = reward
        return obs, reward, done, info

    def reset(self):
        # No reward has been produced yet, so report a neutral 0.0.
        obs = self._env.reset()
        obs["reward"] = 0.0
        return obs
class ResizeImage:
    """Resize image-like observation entries to a target spatial size.

    Image keys are discovered from ``env.obs_space`` (entries with more
    than one dimension whose leading two dims differ from ``size``); those
    entries are resized with nearest-neighbor interpolation and the
    advertised observation space is updated accordingly.
    """

    def __init__(self, env, size=(64, 64)):
        self._env = env
        self._size = size
        self._keys = [
            k
            for k, v in env.obs_space.items()
            if len(v.shape) > 1 and v.shape[:2] != size
        ]
        print(f'Resizing keys {",".join(self._keys)} to {self._size}.')
        if self._keys:
            # PIL is only needed when there is actually work to do.
            self._Image = Image

    def __getattr__(self, name):
        if name.startswith("__"):
            raise AttributeError(name)
        try:
            return getattr(self._env, name)
        except AttributeError:
            raise ValueError(name)

    @property
    def obs_space(self):
        """Wrapped obs space with resized keys re-advertised at ``size``."""
        # Copy before editing so the wrapped env's own mapping is never
        # mutated from this property (the original wrote Boxes straight
        # into the env's dict).
        spaces = dict(self._env.obs_space)
        for key in self._keys:
            shape = self._size + spaces[key].shape[2:]
            spaces[key] = gym.spaces.Box(0, 255, shape, np.uint8)
        return spaces

    def step(self, action):
        obs = self._env.step(action)
        for key in self._keys:
            obs[key] = self._resize(obs[key])
        return obs

    def reset(self):
        obs = self._env.reset()
        for key in self._keys:
            obs[key] = self._resize(obs[key])
        return obs

    def _resize(self, image):
        # Nearest-neighbor keeps hard pixel edges and is cheap.
        image = self._Image.fromarray(image)
        image = image.resize(self._size, self._Image.NEAREST)
        return np.array(image)
class RenderImage:
    """Add RGB frames from ``env.render("rgb_array")`` to observations.

    Useful when the base environment returns non-image observations but a
    rendered camera view is needed for world-model training. The frame
    shape is probed once at construction via ``env.render()``.
    """

    def __init__(self, env, key="image"):
        self._env = env
        self._key = key
        # Probe one render to learn the frame shape advertised below.
        self._shape = self._env.render().shape

    def __getattr__(self, name):
        if name.startswith("__"):
            raise AttributeError(name)
        try:
            return getattr(self._env, name)
        except AttributeError:
            raise ValueError(name)

    @property
    def obs_space(self):
        """Wrapped obs space extended with a uint8 image Box."""
        # Copy so the wrapped env's own mapping is never mutated (the
        # original wrote the Box straight into the env's dict).
        spaces = dict(self._env.obs_space)
        spaces[self._key] = gym.spaces.Box(0, 255, self._shape, np.uint8)
        return spaces

    def step(self, action):
        obs = self._env.step(action)
        obs[self._key] = self._env.render("rgb_array")
        return obs

    def reset(self):
        obs = self._env.reset()
        obs[self._key] = self._env.render("rgb_array")
        return obs
class UUID(gym.Wrapper):
    """Gym wrapper that tracks a unique run identifier per environment reset.

    The identifier combines a timestamp with a random UUID and can be used
    to tag episodes or artifacts generated during data collection.
    """

    def __init__(self, env):
        super().__init__(env)
        self.id = self._new_id()

    def reset(self):
        # A fresh identifier marks the start of every episode.
        self.id = self._new_id()
        return self.env.reset()

    @staticmethod
    def _new_id():
        # Format: "<YYYYMMDDTHHMMSS>-<32 hex chars>".
        timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
        return f"{timestamp}-{str(uuid.uuid4().hex)}"
class SelectAction(gym.Wrapper):
    """Forward only one key of a dictionary action to the environment.

    This enables integration with policies that emit action dicts while
    the environment expects a single tensor/array action payload.
    """

    def __init__(self, env, key):
        super().__init__(env)
        self._key = key

    def step(self, action):
        # Discard every entry except the configured key.
        return self.env.step(action[self._key])