Plug TorchWM world models into Stable-Baselines3, TorchRL, and CleanRL
This notebook mirrors the documentation tutorial. It is intentionally written as a template: replace the toy TinyWorldModel with a trained TorchWM checkpoint or your own adapter object. WorldModelEnv provides the Gymnasium surface, while small callables translate between library actions and model-specific latent dynamics.
1. Create a model-agnostic Gymnasium env
The world model does not need to inherit from a TorchWM base class. The adapter functions below define reset, transition, and optional action conversion.
[ ]:
import gymnasium as gym
import numpy as np
from world_models.envs import WorldModelEnv

class TinyWorldModel:
    num_actions = 3

    def sample_initial_state(self, seed=None):
        rng = np.random.default_rng(seed)
        return rng.normal(size=(2,)).astype(np.float32)

    def predict_next_state(self, state, action):
        return (state + action[:2]).astype(np.float32)

    def decode_observation(self, state):
        return state.astype(np.float32)

    def predict_reward(self, state, action):
        return -float(np.linalg.norm(state))

    def predict_terminal(self, state):
        return bool(np.linalg.norm(state) < 0.05)

model = TinyWorldModel()
obs_space = gym.spaces.Box(-np.inf, np.inf, shape=(2,), dtype=np.float32)
action_space = gym.spaces.Discrete(model.num_actions)

def reset_world_model(model, seed=None, options=None):
    state = model.sample_initial_state(seed=seed)
    return {"state": state, "observation": model.decode_observation(state)}

def action_transform(model, action):
    actions = np.array([[0.1, 0.0], [-0.1, 0.0], [0.0, -0.1]], dtype=np.float32)
    return actions[int(action)]

def transition_world_model(model, state, action):
    next_state = model.predict_next_state(state, action)
    return {
        "state": next_state,
        "observation": model.decode_observation(next_state),
        "reward": model.predict_reward(next_state, action),
        "terminated": model.predict_terminal(next_state),
    }

env = WorldModelEnv(
    model,
    observation_space=obs_space,
    action_space=action_space,
    reset_fn=reset_world_model,
    transition_fn=transition_world_model,
    action_transform_fn=action_transform,
    max_episode_steps=25,
)
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
obs, reward, terminated, truncated, info
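To make the contract between the three callables concrete, here is a minimal, dependency-free sketch of the loop an environment wrapper could plausibly run around them. It is an illustration only, not WorldModelEnv's actual implementation: LineModel and imagined_rollout are hypothetical names, the state is a single float instead of an array, and the rule that truncation fires once the step count reaches max_episode_steps is an assumption.

```python
class LineModel:
    """Hypothetical 1-D stand-in for a world model: the state is one float."""

    def sample_initial_state(self, seed=None):
        return 1.0

    def predict_next_state(self, state, delta):
        return state + delta


def reset_fn(model, seed=None, options=None):
    state = model.sample_initial_state(seed=seed)
    return {"state": state, "observation": state}


def action_transform_fn(model, action):
    # Map a discrete action index to a model-specific displacement.
    return (-0.5, 0.5)[int(action)]


def transition_fn(model, state, action):
    next_state = model.predict_next_state(state, action)
    return {
        "state": next_state,
        "observation": next_state,
        "reward": -abs(next_state),
        "terminated": abs(next_state) < 0.05,
    }


def imagined_rollout(model, policy, max_episode_steps):
    """Roll the model forward using only the adapter callables (assumed loop)."""
    out = reset_fn(model, seed=0)
    state, obs = out["state"], out["observation"]
    total_reward, steps = 0.0, 0
    terminated = truncated = False
    while not (terminated or truncated):
        # The policy sees observations; the model sees transformed actions.
        model_action = action_transform_fn(model, policy(obs))
        out = transition_fn(model, state, model_action)
        state, obs = out["state"], out["observation"]
        total_reward += out["reward"]
        steps += 1
        terminated = out["terminated"]
        truncated = not terminated and steps >= max_episode_steps
    return steps, total_reward, terminated, truncated


# Always choosing action 0 moves the state 1.0 -> 0.5 -> 0.0, which terminates.
steps, total_reward, terminated, truncated = imagined_rollout(
    LineModel(), policy=lambda obs: 0, max_episode_steps=25
)
```

The policy only ever sees observations and discrete indices, while the model only ever sees its own state and transformed actions. That separation is what lets the same adapter serve several RL libraries.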
2. Stable-Baselines3
Stable-Baselines3 accepts Gymnasium environments directly. Run its environment checker first, then train as usual. Use MlpPolicy for flat vector observations such as the one above, CnnPolicy for image observations, and MultiInputPolicy for dictionary observations (including dictionaries that contain images).
[ ]:
# Optional dependency example; uncomment after installing stable-baselines3.
# from stable_baselines3 import PPO
# from stable_baselines3.common.env_checker import check_env
# from stable_baselines3.common.monitor import Monitor
#
# check_env(env, warn=True)
# sb3_env = Monitor(env)
# agent = PPO("MlpPolicy", sb3_env, verbose=1)
# agent.learn(total_timesteps=10_000)
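If the observation space changes when you swap in a different world model, the policy name has to change with it. The helper below is a rough heuristic for picking one. pick_sb3_policy is a hypothetical name, not part of Stable-Baselines3, and its image test (a 3-D uint8 space) is looser than the library's own check. The demo uses SimpleNamespace stand-ins so it runs without Gymnasium installed; real Gymnasium spaces expose the same shape, dtype, and spaces attributes.

```python
from types import SimpleNamespace


def pick_sb3_policy(observation_space):
    """Heuristically choose an SB3 policy name from an observation space."""
    # gymnasium.spaces.Dict keeps its sub-spaces in a dict-valued `.spaces`.
    if isinstance(getattr(observation_space, "spaces", None), dict):
        return "MultiInputPolicy"
    shape = getattr(observation_space, "shape", None) or ()
    dtype = str(getattr(observation_space, "dtype", ""))
    # Assumption: treat any 3-D uint8 space as an image.
    if len(shape) == 3 and dtype == "uint8":
        return "CnnPolicy"
    return "MlpPolicy"


# Stand-ins for Box(shape=(2,)), an image Box, and a Dict space.
vector_space = SimpleNamespace(shape=(2,), dtype="float32")
image_space = SimpleNamespace(shape=(64, 64, 3), dtype="uint8")
dict_space = SimpleNamespace(spaces={"image": image_space}, shape=None, dtype=None)

vector_policy = pick_sb3_policy(vector_space)
image_policy = pick_sb3_policy(image_space)
dict_policy = pick_sb3_policy(dict_space)
```

The returned string can be passed as the first argument to PPO in place of the hard-coded "MlpPolicy".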
3. TorchRL
TorchRL can wrap an existing Gymnasium environment with GymWrapper, so your adapter remains the same.
[ ]:
# Optional dependency example; uncomment after installing torchrl.
# from torchrl.envs import GymWrapper, TransformedEnv
# from torchrl.envs.transforms import Compose, DoubleToFloat, StepCounter
#
# base_env = GymWrapper(env)
# torchrl_env = TransformedEnv(base_env, Compose(DoubleToFloat(), StepCounter(max_steps=25)))
# td = torchrl_env.reset()
# td = torchrl_env.rand_step(td)
4. CleanRL
CleanRL scripts typically define a make_env closure and vectorize it with Gymnasium. Put model loading inside the closure when each worker needs an independent model copy.
[ ]:
def make_env(seed: int):
    def thunk():
        return WorldModelEnv(
            model,
            observation_space=obs_space,
            action_space=action_space,
            reset_fn=reset_world_model,
            transition_fn=transition_world_model,
            action_transform_fn=action_transform,
            max_episode_steps=25,
            seed=seed,
        )
    return thunk

# CleanRL-style vectorization:
# envs = gym.vector.SyncVectorEnv([make_env(seed=i) for i in range(4)])
# obs, infos = envs.reset(seed=0)
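The closure above passes the same model object to every worker. When the model carries mutable state, one way to follow the advice about independent copies is to build or copy the model inside the thunk. The sketch below shows that pattern with stand-ins: StubModel and StubEnv are hypothetical placeholders for your model and for WorldModelEnv, and copy.deepcopy stands in for whatever checkpoint-loading call your project uses.

```python
import copy


class StubModel:
    """Hypothetical model with mutable internal state (e.g. a recurrent cache)."""

    def __init__(self):
        self.cache = []


class StubEnv:
    """Hypothetical stand-in for WorldModelEnv so the sketch runs anywhere."""

    def __init__(self, model, seed):
        self.model = model
        self.seed = seed


template_model = StubModel()


def make_env(seed: int, share_model: bool = False):
    def thunk():
        # Copying (or loading a checkpoint) here, inside the thunk, gives
        # each worker its own model instead of a shared reference.
        if share_model:
            worker_model = template_model
        else:
            worker_model = copy.deepcopy(template_model)
        return StubEnv(worker_model, seed=seed)

    return thunk


envs = [make_env(seed=i)() for i in range(4)]

# Mutating one worker's model leaves the others untouched.
envs[0].model.cache.append("worker-0 state")
```

Sharing a single model is still reasonable when it is stateless at inference time, since it avoids holding several copies of the weights in memory.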
5. Checklist
- Match observation_space to decoded observations.
- Match action_space to the policy’s action format and use action_transform_fn for model-specific encodings.
- Use max_episode_steps as the trusted imagination horizon.
- Do not share mutable model state across vector workers unless it is safe.
- Periodically evaluate imagined policies in the real environment to catch model exploitation.
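The last item can be turned into a simple numeric check. The sketch below compares average episode returns collected inside the world model with returns from the real environment and flags a large positive gap. The function names, the tolerance, and the sample numbers are all illustrative assumptions; how you collect the two lists of returns depends on your training loop.

```python
from statistics import mean


def exploitation_gap(imagined_returns, real_returns):
    """Mean imagined return minus mean real return."""
    return mean(imagined_returns) - mean(real_returns)


def looks_exploited(imagined_returns, real_returns, tolerance):
    """Flag policies that score much better in imagination than in reality."""
    return exploitation_gap(imagined_returns, real_returns) > tolerance


# Illustrative numbers only: episode returns from each setting.
imagined = [-2.0, -1.0, -3.0]  # rollouts inside the world model
real = [-6.0, -8.0, -7.0]      # rollouts in the real environment

gap = exploitation_gap(imagined, real)
flagged = looks_exploited(imagined, real, tolerance=2.0)
```

A gap that keeps growing over training suggests the policy is exploiting model errors. Shortening max_episode_steps or retraining the model on fresh real data are common responses.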