from functools import partial
import os
import torch
import numpy as np
from world_models.utils.utils import to_tensor_obs, preprocess_img
from world_models.models.rssm import RecurrentStateSpaceModel
from world_models.controller.rssm_policy import RSSMPolicy
from world_models.controller.rollout_generator import RolloutGenerator
from world_models.utils.utils import (
TensorBoardMetrics,
save_video,
postprocess_img,
TorchImageEnvWrapper,
)
from world_models.memory.planet_memory import Memory, Episode
from world_models.training.train_planet import train as planet_train
[docs]
class Planet:
"""
High-level Planet wrapper.
Usage example:
from world_models.models.planet import Planet
p = Planet(env='CartPole-v1', bit_depth=5)
p.train(epochs=50)
"""
def __init__(
self,
env,
bit_depth=5,
device=None,
state_size=200,
latent_size=30,
embedding_size=1024,
memory_size=100,
policy_cfg=None,
headless=False,
max_episode_steps=None,
action_repeats=1,
results_dir=None,
):
if headless:
os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
self.device = device or (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
self.bit_depth = bit_depth
if isinstance(env, str):
self.env = TorchImageEnvWrapper(env, bit_depth, None, action_repeats)
elif hasattr(env, "action_size"):
self.env = env
else:
self.env = self._wrap_raw_env(env, bit_depth, action_repeats)
self.rssm = RecurrentStateSpaceModel(
self.env.action_size, state_size, latent_size, embedding_size
).to(self.device)
self.optimizer = torch.optim.Adam(
self.rssm.parameters(), lr=3e-4, eps=1e-4, weight_decay=1e-5
)
policy_cfg = policy_cfg or {}
self.policy = RSSMPolicy(
model=self.rssm,
planning_horizon=policy_cfg.get("planning_horizon", 20),
num_candidates=policy_cfg.get("num_candidates", 1000),
num_iterations=policy_cfg.get("num_iterations", 10),
top_candidates=policy_cfg.get("top_candidates", 100),
device=self.device,
)
env_max_steps = (
max_episode_steps or getattr(self.env, "max_episode_steps", None) or 100
)
self.rollout_gen = RolloutGenerator(
self.env,
self.device,
policy=self.policy,
episode_gen=lambda: Episode(partial(postprocess_img, depth=self.bit_depth)),
max_episode_steps=env_max_steps,
)
self.memory = Memory(memory_size)
self.summary = None
self.results_dir = results_dir or "results/planet"
def _wrap_raw_env(self, env, bit_depth, action_repeats):
class SimpleEnvWrapper:
def __init__(self, env, bit_depth, action_repeats):
self.env = env
self.bit_depth = bit_depth
self.action_repeats = action_repeats
@property
def action_size(self):
if hasattr(self.env.action_space, "n"):
return 1
else:
return self.env.action_space.shape[0]
@property
def observation_size(self):
return (3, 64, 64)
@property
def max_episode_steps(self):
if (
hasattr(self.env, "_max_episode_steps")
and self.env._max_episode_steps is not None
):
return self.env._max_episode_steps
if hasattr(self.env, "spec") and hasattr(
self.env.spec, "max_episode_steps"
):
return self.env.spec.max_episode_steps
return 1000
def sample_random_action(self):
return torch.tensor(self.env.action_space.sample(), dtype=torch.float32)
def reset(self):
obs = self.env.reset()
if isinstance(obs, tuple):
obs = obs[0]
frame = self._get_frame(obs)
if frame is None or frame.size == 0:
raise RuntimeError("Environment returned invalid frame on reset")
x = to_tensor_obs(frame)
if torch.isnan(x).any() or torch.isinf(x).any():
print(
f"Invalid tensor before preprocessing: min={x.min()}, max={x.max()}"
)
raise ValueError("NaN/Inf in observation tensor")
preprocess_img(x, self.bit_depth)
if torch.isnan(x).any() or torch.isinf(x).any():
print(
f"Invalid tensor after preprocessing: min={x.min()}, max={x.max()}"
)
raise ValueError("NaN/Inf after preprocessing")
return x
def step(self, action):
if isinstance(action, torch.Tensor):
action = action.detach().numpy()
if hasattr(self.env.action_space, "n"):
action = int(np.clip(action, 0, self.env.action_space.n - 1))
obs, reward, term, trunc, info = self.env.step(action)
done = term or trunc
frame = self._get_frame(obs)
x = to_tensor_obs(frame)
preprocess_img(x, self.bit_depth)
return x, reward, done, info
def _get_frame(self, obs):
frame = self.env.render()
if isinstance(frame, tuple):
frame = frame[0]
if isinstance(frame, np.ndarray) and frame.ndim == 3:
return frame
if isinstance(obs, np.ndarray) and obs.ndim == 1:
vals = (obs - obs.min()) / (obs.max() - obs.min() + 1e-8)
canvas = np.zeros((64, 64, 3), dtype=np.uint8)
for i, v in enumerate(vals[:8]):
band = int(255 * v)
canvas[:, i * 8 : (i + 1) * 8, :] = band
return canvas
return np.zeros((64, 64, 3), dtype=np.uint8)
def __getattr__(self, name):
return getattr(self.env, name)
return SimpleEnvWrapper(env, bit_depth, action_repeats)
[docs]
def warmup(self, n_episodes=1, random_policy=True):
"""Collect n_episodes of rollouts into memory (used as warmup)."""
eps = self.rollout_gen.rollout_n(n=n_episodes, random_policy=random_policy)
self.memory.append(eps)
[docs]
def train(
self,
epochs=100,
steps_per_epoch=150,
batch_size=32,
H=50,
beta=1.0,
save_every=25,
record_grads=False,
results_dir=None,
scheduler_type="step",
scheduler_kwargs=None,
):
"""
High-level training loop. Delegates single-step training to the existing `train` function.
Args:
scheduler_type (str): Type of scheduler to use ("step", "cosine", "exponential", "plateau", None)
scheduler_kwargs (dict): Additional arguments for the scheduler
"""
# allow caller to override results dir for this training run
if results_dir is not None:
self.results_dir = results_dir
os.makedirs(self.results_dir, exist_ok=True)
self.summary = TensorBoardMetrics(self.results_dir)
# Initialize learning rate scheduler
scheduler = None
if scheduler_type is not None and scheduler_type.lower() != "none":
scheduler_kwargs = scheduler_kwargs or {}
if scheduler_type.lower() == "step":
# Default: reduce LR by factor of 0.5 every 50 epochs
step_size = scheduler_kwargs.get("step_size", 50)
gamma = scheduler_kwargs.get("gamma", 0.5)
scheduler = torch.optim.lr_scheduler.StepLR(
self.optimizer, step_size=step_size, gamma=gamma
)
elif scheduler_type.lower() == "cosine":
# Cosine annealing with warm restarts
T_max = scheduler_kwargs.get("T_max", epochs)
eta_min = scheduler_kwargs.get("eta_min", 1e-6)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
self.optimizer, T_max=T_max, eta_min=eta_min
)
elif scheduler_type.lower() == "exponential":
# Exponential decay
gamma = scheduler_kwargs.get("gamma", 0.95)
scheduler = torch.optim.lr_scheduler.ExponentialLR(
self.optimizer, gamma=gamma
)
elif scheduler_type.lower() == "plateau":
# Reduce on plateau (based on loss)
mode = scheduler_kwargs.get("mode", "min")
factor = scheduler_kwargs.get("factor", 0.5)
patience = scheduler_kwargs.get("patience", 10)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer, mode=mode, factor=factor, patience=patience
)
else:
print(
f"Warning: Unknown scheduler type '{scheduler_type}'. No scheduler will be used."
)
if len(self.memory.episodes) == 0:
self.warmup(n_episodes=1, random_policy=True)
for ep in range(epochs):
metrics = {}
epoch_loss = 0 # Track average loss for ReduceLROnPlateau
for _ in range(steps_per_epoch):
train_metrics = planet_train(
self.memory,
self.rssm.train(),
self.optimizer,
self.device,
N=batch_size,
H=H,
beta=beta,
grads=record_grads,
)
for k, v in train_metrics.items():
if k not in metrics:
metrics[k] = []
if isinstance(v, dict):
for ik, iv in v.items():
metrics.setdefault(f"{k}/{ik}", []).append(iv)
else:
metrics[k].append(v)
compact = {}
for k, vs in metrics.items():
try:
compact[k] = np.mean(vs)
except Exception:
compact[k] = vs
# Calculate total loss for ReduceLROnPlateau
if "losses/kl" in compact and "losses/reconstruction" in compact:
epoch_loss = compact["losses/kl"] + compact["losses/reconstruction"]
if "losses/reward_pred" in compact:
epoch_loss += compact["losses/reward_pred"]
# Add current learning rate to metrics
current_lr = self.optimizer.param_groups[0]["lr"]
compact["learning_rate"] = current_lr
self.summary.update(compact)
# Step the scheduler
if scheduler is not None:
if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
# ReduceLROnPlateau needs the metric to track
scheduler.step(epoch_loss)
else:
# Other schedulers just need step()
scheduler.step()
# Log learning rate changes
new_lr = self.optimizer.param_groups[0]["lr"]
if new_lr != current_lr:
print(
f"Epoch {ep+1}: Learning rate changed from {current_lr:.2e} to {new_lr:.2e}"
)
self.memory.append(self.rollout_gen.rollout_once(explore=True))
eval_episode, eval_frames, eval_metrics = self.rollout_gen.rollout_eval()
self.memory.append(eval_episode)
save_video(eval_frames, self.results_dir, f"vid_{ep+1}")
self.summary.update(eval_metrics)
if (ep + 1) % save_every == 0:
torch.save(
self.rssm.state_dict(),
os.path.join(self.results_dir, f"ckpt_{ep+1}.pth"),
)
return self.results_dir