World Model Env

WorldModelEnv wraps a trained or adapter-backed world model as a Gymnasium-compliant environment. It lets you run simulated rollouts with the same API shape expected by RL libraries such as Stable-Baselines3, TorchRL, and CleanRL:

  • reset(seed=..., options=...) -> (observation, info)

  • step(action) -> (observation, reward, terminated, truncated, info)

  • observation_space, action_space, render(), and close()

The wrapper is model-agnostic. Either the model exposes one of the common transition methods (env_step, step, predict_step, predict, imagine_step, transition, or __call__), or you pass explicit adapter callables such as transition_fn and reset_fn. For end-to-end examples with Stable-Baselines3, TorchRL, and CleanRL, see the RL library integration tutorial and its notebook version.
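To make the model-agnostic lookup concrete, here is a minimal sketch of how a wrapper could pick a transition callable. The method names are taken from the list above; the helper resolve_transition, the lookup order, and the TinyModel stand-in are assumptions for illustration and may not match what WorldModelEnv does internally.

```python
# Hypothetical sketch: the names come from the documented list, but the
# lookup order and this helper are assumptions, not the library's code.
CANDIDATE_METHODS = (
    "env_step", "step", "predict_step", "predict",
    "imagine_step", "transition", "__call__",
)


def resolve_transition(model, transition_fn=None):
    """Return a callable (state, action) -> result for the given model."""
    if transition_fn is not None:
        # An explicit adapter wins and receives the model as first argument.
        return lambda state, action: transition_fn(model, state, action)
    for name in CANDIDATE_METHODS:
        method = getattr(model, name, None)
        if callable(method):
            return method
    raise TypeError("model exposes no known transition method")


class TinyModel:
    """Stand-in model with a single imagine_step method."""

    def imagine_step(self, state, action):
        return state + action


step = resolve_transition(TinyModel())
print(step(1, 2))  # 3
```

Passing an explicit transition_fn is the safer choice when a model exposes several of these methods with different meanings, because it removes any dependence on lookup order.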

Basic usage

import gymnasium as gym
import numpy as np

from world_models.envs import WorldModelEnv


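def transition(model, state, action):
    # Advance the latent state by one imagined step.
    next_state = model.imagine_step(state, action)
    # Decode the observation, reward, and termination flag from the new state.
    obs = model.decode(next_state)
    reward = model.reward(next_state)
    done = model.termination(next_state)
    # Return a mapping so the wrapper can track the latent state between steps.
    return {
        "state": next_state,
        "observation": obs,
        "reward": reward,
        "terminated": done,
    }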


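# trained_model and initial_latent are placeholders for your own model
# and its starting latent state.
env = WorldModelEnv(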
    trained_model,
    observation_space=gym.spaces.Box(0, 255, shape=(3, 64, 64), dtype=np.uint8),
    action_space=gym.spaces.Box(-1.0, 1.0, shape=(6,), dtype=np.float32),
    initial_state=initial_latent,
    transition_fn=transition,
    max_episode_steps=50,
)

obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
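Because the wrapper follows the Gymnasium five-tuple convention, a full episode is the usual loop over step until terminated or truncated is set. The sketch below uses a hypothetical CountdownEnv stand-in so that it runs without a trained model; the rollout helper and the stand-in are illustrative names, not part of the library.

```python
class CountdownEnv:
    """Hypothetical stand-in that mimics the reset/step signatures above."""

    def __init__(self, max_episode_steps=5):
        self.max_episode_steps = max_episode_steps
        self.steps = 0

    def reset(self, seed=None, options=None):
        self.steps = 0
        return 0.0, {}

    def step(self, action):
        self.steps += 1
        truncated = self.steps >= self.max_episode_steps
        # observation, reward, terminated, truncated, info
        return float(self.steps), 1.0, False, truncated, {}


def rollout(env, policy, seed=0):
    """Run one episode and return (total_reward, number_of_steps)."""
    obs, info = env.reset(seed=seed)
    total_reward, steps = 0.0, 0
    terminated = truncated = False
    while not (terminated or truncated):
        action = policy(obs)
        obs, reward, terminated, truncated, info = env.step(action)
        total_reward += float(reward)
        steps += 1
    return total_reward, steps


print(rollout(CountdownEnv(max_episode_steps=5), policy=lambda obs: 0))
# (5.0, 5)
```

With a real WorldModelEnv, the same helper should work by passing the env built above and a policy such as lambda obs: env.action_space.sample().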

Action adapters

Use action_transform_fn when the RL library should optimize over a different action representation from the one the world model consumes. For example, expose a Discrete action space to the RL library and convert each action into a one-hot vector or a learned latent action vector before the model step:

def action_transform(model, action):
    # Map a Discrete action index to the one-hot vector the model consumes.
    one_hot = np.zeros(model.num_actions, dtype=np.float32)
    one_hot[int(action)] = 1.0
    return one_hot

env = WorldModelEnv(
    trained_model,
    observation_space=obs_space,
    action_space=gym.spaces.Discrete(trained_model.num_actions),
    transition_fn=transition,
    action_transform_fn=action_transform,
)
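The latent-vector variant mentioned above can follow the same (model, action) signature. This is a minimal sketch that assumes the latent action vectors are available as a plain lookup table; make_latent_action_transform and the table contents are hypothetical, and a real model would more likely hold these vectors in an embedding layer.

```python
def make_latent_action_transform(action_table):
    """Build an action_transform_fn that looks up a latent vector per action."""

    def action_transform(model, action):
        index = int(action)
        if not 0 <= index < len(action_table):
            raise ValueError(f"action {index} outside [0, {len(action_table)})")
        # Return a copy so the caller cannot mutate the table.
        return list(action_table[index])

    return action_transform


# Hypothetical 3-action table with 2-dimensional latent action vectors.
table = [[0.0, 0.0], [0.5, -0.5], [1.0, 1.0]]
transform = make_latent_action_transform(table)
print(transform(None, 1))  # [0.5, -0.5]
```

The resulting callable can be passed as action_transform_fn alongside a Discrete action space whose size matches the table length.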

Factory and public API

The direct factory is make_world_model_env:

from world_models.envs import make_world_model_env

env = make_world_model_env(
    trained_model,
    observation_space=obs_space,
    action_space=act_space,
    transition_fn=transition,
)

The top-level torchwm.make_env factory also accepts a world-model backend:

import torchwm

env = torchwm.make_env(
    trained_model,
    backend="world-model",
    observation_space=obs_space,
    action_space=act_space,
    transition_fn=transition,
)

Aliases for the world-model backend name include world_model, model, and wm.

Accepted adapter return forms

reset_fn may return:

  • obs

  • (obs, info)

  • (state, obs)

  • (state, obs, info)

  • a mapping with state, observation/obs/image, and optional info
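The reset forms above can be reduced to a single (state, obs, info) triple. The sketch below shows one way to do that. Note that (obs, info) and (state, obs) are both 2-tuples, so the code assumes that a mapping in second position means info; this disambiguation rule, the normalize_reset name, and the None placeholder for a missing state are assumptions, not documented behavior.

```python
from collections.abc import Mapping

OBS_KEYS = ("observation", "obs", "image")


def _first_present(mapping, keys):
    """Return the value of the first key found in the mapping."""
    for key in keys:
        if key in mapping:
            return mapping[key]
    raise KeyError(f"expected one of {keys}")


def normalize_reset(result):
    """Return (state, obs, info) from any of the listed reset forms."""
    if isinstance(result, Mapping):
        obs = _first_present(result, OBS_KEYS)
        return result.get("state"), obs, dict(result.get("info") or {})
    if isinstance(result, tuple):
        if len(result) == 3:
            state, obs, info = result
            return state, obs, dict(info or {})
        if len(result) == 2:
            first, second = result
            # Assumption: a mapping in second position means (obs, info).
            if isinstance(second, Mapping):
                return None, first, dict(second)
            return first, second, {}
    # Anything else is treated as a bare observation.
    return None, result, {}


print(normalize_reset("o"))              # (None, 'o', {})
print(normalize_reset(("s", "o")))       # ('s', 'o', {})
print(normalize_reset(("o", {"k": 1})))  # (None, 'o', {'k': 1})
```

A sketch like this would misread a bare dictionary observation as the mapping form, so with Dict observation spaces an explicit tuple or mapping return is the less ambiguous choice.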

transition_fn or model transition methods may return:

  • (obs, reward, terminated, truncated, info)

  • (obs, reward, done, info)

  • (state, obs, reward)

  • (state, obs, reward, terminated, truncated, info)

  • a mapping with state/next_state, observation/obs/image, reward, terminated/done, truncated, and optional info
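Since each tuple form above has a distinct length, a wrapper can dispatch on length alone. The following sketch illustrates that idea; normalize_transition is a hypothetical helper, fields an adapter did not supply are left as None, and mapping a legacy done flag onto terminated is an assumption about how the wrapper treats it.

```python
from collections.abc import Mapping


def _first_present(mapping, keys, default=None):
    """Return the value of the first key found, or the default."""
    for key in keys:
        if key in mapping:
            return mapping[key]
    return default


def normalize_transition(result):
    """Return a dict with state, obs, reward, terminated, truncated, info."""
    out = dict(state=None, obs=None, reward=None,
               terminated=None, truncated=None, info={})
    if isinstance(result, Mapping):
        out["state"] = _first_present(result, ("state", "next_state"))
        out["obs"] = _first_present(result, ("observation", "obs", "image"))
        out["reward"] = result.get("reward")
        out["terminated"] = _first_present(result, ("terminated", "done"))
        out["truncated"] = result.get("truncated")
        out["info"] = dict(result.get("info") or {})
        return out
    fields_by_length = {
        3: ("state", "obs", "reward"),
        4: ("obs", "reward", "terminated", "info"),  # done -> terminated
        5: ("obs", "reward", "terminated", "truncated", "info"),
        6: ("state", "obs", "reward", "terminated", "truncated", "info"),
    }
    if not isinstance(result, tuple) or len(result) not in fields_by_length:
        raise ValueError("unsupported transition return form")
    out.update(zip(fields_by_length[len(result)], result))
    out["info"] = dict(out["info"] or {})
    return out


print(normalize_transition(("s1", "o1", 1.5))["reward"])  # 1.5
```

Returning a mapping from transition_fn, as in the basic usage example, avoids relying on tuple length and keeps the adapter readable when fields are added later.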

If a transition omits the reward or the termination flag, pass reward_fn and/or terminal_fn to supply the missing value. Without them, a missing reward defaults to 0.0 and a missing termination flag defaults to False. max_episode_steps sets truncated=True once the simulated rollout reaches the time limit.
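The fallback and time-limit rules can be summarized in a few lines. This is a sketch under stated assumptions: finalize_step is a hypothetical helper, and it assumes reward_fn and terminal_fn each take the next state as their only argument, which may differ from the signatures the wrapper actually uses.

```python
def finalize_step(fields, step_count, max_episode_steps=None,
                  reward_fn=None, terminal_fn=None):
    """Fill missing reward/termination and apply the time limit.

    `fields` is a dict with optional state, reward, terminated, truncated.
    Assumption: reward_fn and terminal_fn take the next state; the real
    callables may receive different arguments.
    """
    state = fields.get("state")
    reward = fields.get("reward")
    if reward is None:
        reward = reward_fn(state) if reward_fn is not None else 0.0
    terminated = fields.get("terminated")
    if terminated is None:
        terminated = terminal_fn(state) if terminal_fn is not None else False
    truncated = bool(fields.get("truncated") or False)
    # The time limit only ever turns truncation on, never off.
    if max_episode_steps is not None and step_count >= max_episode_steps:
        truncated = True
    return float(reward), bool(terminated), truncated


# No reward or flag supplied, and the 50-step limit has been reached.
print(finalize_step({"state": 3}, step_count=50, max_episode_steps=50))
# (0.0, False, True)

# Fallback callables compute the missing values from the state.
print(finalize_step({"state": 3}, 1, 50,
                    reward_fn=lambda s: s * 2.0,
                    terminal_fn=lambda s: s > 2))
# (6.0, True, False)
```

Keeping terminated and truncated separate matters for value bootstrapping: a time-limit truncation does not mean the modeled task actually ended.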