Plug TorchWM world models into Stable-Baselines3, TorchRL, and CleanRL

This notebook mirrors the documentation tutorial. It is intentionally written as a template: replace the toy TinyWorldModel with a trained TorchWM checkpoint or your own adapter object. WorldModelEnv provides the standard Gymnasium interface (reset, step, and the observation and action spaces), while small adapter callables translate between each library's actions and the model-specific latent dynamics.

1. Create a model-agnostic Gymnasium env

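The world model does not need to inherit from a TorchWM base class. Instead, the adapter functions below define how to reset the model, how to advance it by one transition, and (optionally) how to convert a library action into the model's action format.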

import gymnasium as gym
import numpy as np

from world_models.envs import WorldModelEnv


class TinyWorldModel:
    """Toy 2-D point-mass "world model" standing in for a trained TorchWM model."""

    num_actions = 4  # +x, -x, +y, -y (with only -y, the agent could never raise y)

    def sample_initial_state(self, seed=None):
        rng = np.random.default_rng(seed)
        return rng.normal(size=(2,)).astype(np.float32)

    def predict_next_state(self, state, action):
        # `action` is the 2-D displacement produced by action_transform.
        return (state + action).astype(np.float32)

    def decode_observation(self, state):
        # Identity decoder: the latent state is also the observation.
        return state.astype(np.float32)

    def predict_reward(self, state, action):
        # Dense reward: negative distance to the origin.
        return -float(np.linalg.norm(state))

    def predict_terminal(self, state):
        # The episode ends once the point is close enough to the origin.
        return bool(np.linalg.norm(state) < 0.05)


model = TinyWorldModel()
obs_space = gym.spaces.Box(-np.inf, np.inf, shape=(2,), dtype=np.float32)
action_space = gym.spaces.Discrete(model.num_actions)


def reset_world_model(model, seed=None, options=None):
    state = model.sample_initial_state(seed=seed)
    return {"state": state, "observation": model.decode_observation(state)}


def action_transform(model, action):
    # Map the Discrete index chosen by the agent to a 2-D displacement.
    actions = np.array(
        [[0.1, 0.0], [-0.1, 0.0], [0.0, 0.1], [0.0, -0.1]], dtype=np.float32
    )
    return actions[int(action)]


def transition_world_model(model, state, action):
    next_state = model.predict_next_state(state, action)
    return {
        "state": next_state,
        "observation": model.decode_observation(next_state),
        "reward": model.predict_reward(next_state, action),
        "terminated": model.predict_terminal(next_state),
    }


env = WorldModelEnv(
    model,
    observation_space=obs_space,
    action_space=action_space,
    reset_fn=reset_world_model,
    transition_fn=transition_world_model,
    action_transform_fn=action_transform,
    max_episode_steps=25,
)

obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
obs, reward, terminated, truncated, info
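To make the adapter contract concrete, here is a rough, stdlib-only sketch of how an environment like WorldModelEnv might compose the three callables. It is an assumption for illustration, not TorchWM's implementation: the class name `SketchWorldModelEnv` and the toy `Walker1D` model are hypothetical, and the sketch assumes reset_fn returns a dict with "state" and "observation" while transition_fn adds "reward" and "terminated", as in the cell above.

```python
from typing import Any, Callable, Optional


class SketchWorldModelEnv:
    """Hypothetical stand-in for WorldModelEnv that only illustrates the adapter contract."""

    def __init__(
        self,
        model: Any,
        reset_fn: Callable[..., dict],
        transition_fn: Callable[[Any, Any, Any], dict],
        action_transform_fn: Optional[Callable[[Any, Any], Any]] = None,
        max_episode_steps: Optional[int] = None,
    ) -> None:
        self.model = model
        self.reset_fn = reset_fn
        self.transition_fn = transition_fn
        self.action_transform_fn = action_transform_fn
        self.max_episode_steps = max_episode_steps
        self._state: Any = None
        self._steps = 0

    def reset(self, seed=None, options=None):
        # reset_fn returns {"state": ..., "observation": ...}.
        out = self.reset_fn(self.model, seed=seed, options=options)
        self._state = out["state"]
        self._steps = 0
        return out["observation"], {}

    def step(self, action):
        if self._state is None:
            raise RuntimeError("call reset() before step()")
        # Convert the library's action (e.g. a Discrete index) into the model's format.
        if self.action_transform_fn is not None:
            action = self.action_transform_fn(self.model, action)
        # transition_fn returns the next state, observation, reward, and terminal flag.
        out = self.transition_fn(self.model, self._state, action)
        self._state = out["state"]
        self._steps += 1
        terminated = bool(out.get("terminated", False))
        # Like Gymnasium's TimeLimit wrapper, report truncation once the horizon is hit.
        truncated = (
            self.max_episode_steps is not None and self._steps >= self.max_episode_steps
        )
        return out["observation"], float(out.get("reward", 0.0)), terminated, truncated, {}


# Hypothetical 1-D toy model: actions 0 and 1 move the state by -0.5 and +0.5.
class Walker1D:
    start = 1.0


def toy_reset(model, seed=None, options=None):
    return {"state": model.start, "observation": [model.start]}


def toy_transition(model, state, action):
    nxt = round(state + action, 6)
    return {"state": nxt, "observation": [nxt], "reward": -abs(nxt), "terminated": nxt == 0.0}


sketch_env = SketchWorldModelEnv(
    Walker1D(),
    toy_reset,
    toy_transition,
    action_transform_fn=lambda model, a: (-0.5, 0.5)[a],
    max_episode_steps=3,
)
```

The design point the sketch highlights is that termination comes from the model (transition_fn), while truncation comes only from the step budget, so an agent can distinguish "the model says the task ended" from "we stopped trusting the rollout".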

2. Stable-Baselines3

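Stable-Baselines3 (version 2.0 and later) accepts Gymnasium environments directly. Run its environment checker first, then train as usual. Use MlpPolicy for flat vector observations like the ones in this example, CnnPolicy for image observations, and MultiInputPolicy for dictionary observations (including dictionaries that contain images).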

# Optional dependency example; uncomment after installing stable-baselines3.
# from stable_baselines3 import PPO
# from stable_baselines3.common.env_checker import check_env
# from stable_baselines3.common.monitor import Monitor
#
# check_env(env, warn=True)
# sb3_env = Monitor(env)
# agent = PPO("MlpPolicy", sb3_env, verbose=1)
# agent.learn(total_timesteps=10_000)
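Alongside `check_env`, a short rollout with random actions catches common contract violations (wrong tuple lengths, non-boolean flags, NaN rewards from a diverging model) before a long training run. This is a stdlib-only sketch: `smoke_test_env` is a hypothetical helper, not part of Stable-Baselines3 or TorchWM, and it complements rather than replaces SB3's checker.

```python
import math
from typing import Any, Callable


def _is_bool(x: Any) -> bool:
    # Accept Python bools and NumPy boolean scalars (which expose dtype.kind == "b").
    return isinstance(x, bool) or getattr(getattr(x, "dtype", None), "kind", None) == "b"


def smoke_test_env(
    env: Any, sample_action: Callable[[], Any], n_steps: int = 100, seed: int = 0
) -> int:
    """Roll out random actions and check Gymnasium-style reset/step return values.

    Hypothetical helper. Returns the number of episodes that ended
    (terminated or truncated) during the rollout.
    """
    reset_out = env.reset(seed=seed)
    if not (isinstance(reset_out, tuple) and len(reset_out) == 2):
        raise TypeError("reset() must return (obs, info)")
    episodes = 0
    for t in range(n_steps):
        step_out = env.step(sample_action())
        if not (isinstance(step_out, tuple) and len(step_out) == 5):
            raise TypeError("step() must return (obs, reward, terminated, truncated, info)")
        _, reward, terminated, truncated, info = step_out
        if not (_is_bool(terminated) and _is_bool(truncated)):
            raise TypeError(f"step {t}: terminated/truncated must be booleans")
        # A diverging world model often shows up as NaN or infinite rewards.
        if not math.isfinite(float(reward)):
            raise ValueError(f"step {t}: non-finite reward {reward!r}")
        if not isinstance(info, dict):
            raise TypeError(f"step {t}: info must be a dict")
        if terminated or truncated:
            episodes += 1
            env.reset()
    return episodes
```

With the environment from section 1, this might be called as `smoke_test_env(env, env.action_space.sample)`.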

3. TorchRL

TorchRL can wrap an existing Gymnasium environment with GymWrapper, so your adapter remains the same.

# Optional dependency example; uncomment after installing torchrl.
# from torchrl.envs import GymWrapper, TransformedEnv
# from torchrl.envs.transforms import Compose, DoubleToFloat, StepCounter
#
# base_env = GymWrapper(env)
# torchrl_env = TransformedEnv(base_env, Compose(DoubleToFloat(), StepCounter(max_steps=25)))
# td = torchrl_env.reset()
# td = torchrl_env.rand_step(td)

4. CleanRL

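CleanRL scripts typically define a make_env function that returns a thunk (a zero-argument closure that builds one environment) and vectorize those thunks with Gymnasium's vector API. Load the model inside the thunk when each worker needs an independent model copy. The example below shares the single global model, which is safe here only because TinyWorldModel holds no mutable state.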

def make_env(seed: int):
    def thunk():
        return WorldModelEnv(
            model,
            observation_space=obs_space,
            action_space=action_space,
            reset_fn=reset_world_model,
            transition_fn=transition_world_model,
            action_transform_fn=action_transform,
            max_episode_steps=25,
            seed=seed,
        )

    return thunk


# CleanRL-style vectorization:
# envs = gym.vector.SyncVectorEnv([make_env(seed=i) for i in range(4)])
# obs, infos = envs.reset(seed=0)
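For a stateful model, the per-worker pattern can be sketched as follows, assuming two hypothetical factories: `load_model()` (for instance, one that loads a TorchWM checkpoint) and `build_env(model, seed)` (for instance, one that calls WorldModelEnv with the adapters from section 1). The helper name `make_env_thunks` is also hypothetical.

```python
from typing import Any, Callable, List


def make_env_thunks(
    load_model: Callable[[], Any],
    build_env: Callable[[Any, int], Any],
    num_envs: int,
    base_seed: int = 0,
) -> List[Callable[[], Any]]:
    """Return one zero-argument thunk per worker; each builds its own model copy."""

    def make_thunk(rank: int) -> Callable[[], Any]:
        def thunk() -> Any:
            # Runs when the vector env calls the thunk (inside the worker process
            # for AsyncVectorEnv), so every environment gets its own model instance.
            model = load_model()
            return build_env(model, base_seed + rank)

        return thunk

    return [make_thunk(rank) for rank in range(num_envs)]
```

The inner `make_thunk(rank)` binds each worker's rank when the thunk is created, which avoids Python's late-binding pitfall where closures defined directly in a loop all see the loop variable's final value. The result could then be passed to Gymnasium, for example `gym.vector.SyncVectorEnv(make_env_thunks(load_model, build_env, num_envs=4))`.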

5. Checklist

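  • Match observation_space to whatever your reset and transition adapters return as "observation" (shape, dtype, and bounds).

  • Match action_space to the action format the RL library’s policy emits, and use action_transform_fn to convert it into the model’s own encoding.

  • Set max_episode_steps to the horizon over which you trust the model’s rollouts; prediction errors compound over longer imagined trajectories.

  • Do not share mutable model state (for example, recurrent hidden states or internal caches) across vector workers unless interleaved use is safe; otherwise load a separate copy in each worker’s thunk.

  • Periodically evaluate policies trained in imagination in the real environment to catch model exploitation, where the policy learns to exploit model errors instead of solving the task.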
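The last checklist item can be made routine by comparing the average return a policy achieves in the world-model environment with the return it gets in the real environment. This is a stdlib-only sketch: `episode_return` and `imagination_gap` are hypothetical helpers, `policy` is assumed to be any callable mapping an observation to an action, and both environments are assumed to follow the Gymnasium reset/step conventions.

```python
import statistics
from typing import Any, Callable, Iterable, Tuple


def episode_return(
    env: Any, policy: Callable[[Any], Any], seed: int, max_steps: int = 1_000
) -> float:
    """Run one episode with `policy` and return the undiscounted sum of rewards."""
    obs, _ = env.reset(seed=seed)
    total = 0.0
    for _ in range(max_steps):
        obs, reward, terminated, truncated, _ = env.step(policy(obs))
        total += float(reward)
        if terminated or truncated:
            break
    return total


def imagination_gap(
    imagined_env: Any,
    real_env: Any,
    policy: Callable[[Any], Any],
    seeds: Iterable[int],
) -> Tuple[float, float, float]:
    """Return (mean imagined return, mean real return, imagined minus real)."""
    seeds = list(seeds)  # iterate twice, once per environment
    imagined = statistics.fmean(episode_return(imagined_env, policy, s) for s in seeds)
    real = statistics.fmean(episode_return(real_env, policy, s) for s in seeds)
    # A large positive gap suggests the policy exploits errors in the world model.
    return imagined, real, imagined - real
```

Using the same seeds for both environments does not make their initial states match, so the gap is best read as a trend tracked over training rather than an exact per-episode comparison. A gap that widens while imagined return keeps improving is a sign that the policy may be exploiting the model.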