import pdb
from collections import defaultdict
from pprint import pprint
import os
import pickle
import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence as kl_div
from tqdm import trange
from world_models.utils.utils import (
load_memory,
get_combined_params,
save_frames,
plot_metrics,
get_mask,
)
from world_models.models.rssm import RecurrentStateSpaceModel
from world_models.controller.rssm_policy import RSSMPolicy
from world_models.controller.rollout_generator import RolloutGenerator
from world_models.utils.utils import TorchImageEnvWrapper
from world_models.memory.planet_memory import Memory
# Bit depth handed to TorchImageEnvWrapper in main(); presumably the number of
# bits per channel that image observations are quantised to — confirm there.
BIT_DEPTH = 5

# Plain number here; main() rebinds it (via `global`) to a 1-element tensor on
# the training device. Presumably the KL "free nats" floor used by the training
# step — TODO confirm against train_rssm.
FREE_NATS = 2

# Positional sizes passed to RecurrentStateSpaceModel(1, STATE_SIZE,
# LATENT_SIZE, EMBEDDING_SIZE) in main(); the names suggest deterministic
# state, stochastic latent and observation-embedding widths — verify against
# the model's constructor.
STATE_SIZE, LATENT_SIZE, EMBEDDING_SIZE = 200, 30, 1024
[docs]
def evaluate(memory, model, path, eps):
    """Run one RSSM reconstruction/prediction evaluation and save visual outputs.

    Samples a single sequence with ``memory.sample(1)``, runs it through
    ``model`` and decodes every timestep twice: once from the first element
    of each prior, and once from the first element of each posterior (used
    here as the means). Both decodings are written next to the ground-truth
    frames with ``save_frames`` for qualitative inspection.

    Args:
        memory: Replay memory whose ``sample(1)`` returns
            ``((observations, actions, _, _), lengths)``.
        model: RSSM module. It is called as ``model(x, u)`` and must return
            ``(states, priors, posteriors)`` and expose ``decoder(state, latent)``.
        path: Output prefix. Frames are saved under ``f"{path}_{eps}"``.
        eps: Episode/iteration number appended to ``path``.

    The model is switched to eval mode for the duration of the call. Its
    previous train/eval mode is restored afterwards, even on error, so calling
    this from inside a training loop does not leave the model in eval mode.
    Gradients are disabled because this is inference only.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():  # inference only: don't build an autograd graph
            device = next(model.parameters()).device
            # sample and convert to torch tensors on the model device
            (x, u, _, _), _lens = memory.sample(1)
            x = torch.tensor(x).float().to(device)
            u = torch.tensor(u).float().to(device)
            # Add a leading batch dim when the sample came back unbatched.
            if x.dim() == 4:
                x = x.unsqueeze(0)
            if u.dim() == 2:
                u = u.unsqueeze(0)
            elif u.dim() == 1:
                # scalar action per step -> (1, T, 1)
                u = u.unsqueeze(0).unsqueeze(-1)

            states, priors, posteriors = model(x, u)
            prior_means = [p[0] for p in priors]
            post_means = [p[0] for p in posteriors]
            pred1_list = [
                model.decoder(states[t], prior_means[t]) for t in range(len(states))
            ]
            pred2_list = [
                model.decoder(states[t], post_means[t]) for t in range(len(states))
            ]
            # stack => [T, B, C, H, W]; per the original note, save_frames
            # presumably expects target [T+1, C, H, W] and preds [T, C, H, W]
            pred1 = torch.stack(pred1_list)
            pred2 = torch.stack(pred2_list)
            # move to CPU and keep only batch element 0 for saving
            save_frames(
                x[0].cpu(), pred1[:, 0].cpu(), pred2[:, 0].cpu(), f"{path}_{eps}"
            )
    finally:
        # Restore whatever mode the caller had (train() accepts a bool).
        model.train(was_training)
[docs]
def main(*, num_episodes=1000, results_dir="results/test_rssm"):
    """Standalone training loop for RSSM with generated replay fallback support.

    Initializes environment/policy/memory, trains over episodes, logs metrics,
    and periodically evaluates and checkpoints the model:

      1. Build the RSSM, an Adam optimizer over ``get_combined_params(rssm)``,
         and an ``RSSMPolicy`` planner that uses the RSSM.
      2. Load the test/train replay memories from disk. If a file is missing,
         generate it with ``RolloutGenerator`` and pickle it for reuse.
      3. Add 5 fresh policy rollouts to the training memory.
      4. For ``num_episodes`` iterations: train, accumulate and plot metrics,
         evaluate every 10 iterations and checkpoint every 25.

    Args:
        num_episodes: Number of training iterations (default 1000, as before).
        results_dir: Directory for plots, evaluation frames and checkpoints.
            Created if it does not exist.

    Side effects:
        Rebinds the module-level ``FREE_NATS`` to a 1-element float tensor on
        the training device. It is presumably read by ``train_rssm``.
        Writes replay pickles to the working directory when they are missing.
        Set ``TRAIN_RSSM_DEBUG=1`` to drop into pdb after training.
    """
    global FREE_NATS
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # torch.full infers int64 from an int fill value, so force float.
    # float() also accepts the 1-element tensor left by a previous main() call.
    FREE_NATS = torch.full((1,), float(FREE_NATS), device=device)

    # torch.save() below does not create missing parent directories.
    os.makedirs(results_dir, exist_ok=True)

    rssm = RecurrentStateSpaceModel(1, STATE_SIZE, LATENT_SIZE, EMBEDDING_SIZE).to(
        device
    )
    optimizer = torch.optim.Adam(get_combined_params(rssm), lr=1e-3)
    # create a planning policy that uses the trained rssm
    policy = RSSMPolicy(
        model=rssm,
        planning_horizon=12,
        num_candidates=1000,
        num_iterations=5,
        top_candidates=50,
        device=device,
    )
    # create an env wrapper and rollout generator (example env name)
    env = TorchImageEnvWrapper("CartPole-v1", BIT_DEPTH)
    rollout_gen = RolloutGenerator(env, device, policy=policy)

    def ensure_memory(path, n_warmup=25, mem_size=1000, random_policy=True):
        """Load the replay memory at ``path``, or generate and pickle a new one.

        If the file is missing, ``n_warmup`` episodes are collected with
        ``rollout_gen`` into a fresh ``Memory(mem_size)``, which is pickled to
        ``path`` so later runs can reuse it.
        """
        if os.path.exists(path):
            # NOTE(review): if load_memory unpickles this file, only point it
            # at replay files you generated yourself.
            return load_memory(path, device)
        print(f"{path} not found → generating {n_warmup} episodes with rollout_gen")
        mem = Memory(mem_size)
        mem.append(rollout_gen.rollout_n(n=n_warmup, random_policy=random_policy))
        with open(path, "wb") as f:
            pickle.dump(mem, f)
        # Set after dumping so the pickled file matches what earlier runs wrote.
        mem.device = device
        return mem

    test_data = ensure_memory("test_exp_replay.pth", n_warmup=10, random_policy=True)
    train_data = ensure_memory("train_exp_replay.pth", n_warmup=50, random_policy=False)
    # top up the training memory with a few fresh planner rollouts
    train_data.append(rollout_gen.rollout_n(n=5, random_policy=False))

    global_metrics = defaultdict(list)
    for i in trange(num_episodes, desc="# Episode: ", leave=False):
        episode = i + 1
        # evaluate() puts the model in eval mode, so re-enable train mode
        # before every optimisation step.
        rssm.train()
        # NOTE(review): train_rssm is neither defined nor imported in this file
        # as shown — confirm it is provided before running.
        metrics = train_rssm(train_data, rssm, optimizer, record_grads=False)
        for name, values in metrics.items():
            global_metrics[name].extend(values)
        plot_metrics(global_metrics, path=results_dir, prefix="TRAIN_")
        if episode % 10 == 0:
            evaluate(test_data, rssm, f"{results_dir}/eps", episode)
        if episode % 25 == 0:
            torch.save(rssm.state_dict(), f"{results_dir}/ckpt_{episode}.pth")

    if os.getenv("TRAIN_RSSM_DEBUG", "0") == "1":
        pdb.set_trace()


if __name__ == "__main__":
    main()