from typing import Any
import pdb
from collections import defaultdict
from pprint import pprint
import os
import pickle
import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence as kl_div
from tqdm import trange
from world_models.utils.utils import (
load_memory,
get_combined_params,
save_frames,
plot_metrics,
get_mask,
)
from world_models.models.rssm import RecurrentStateSpaceModel
from world_models.controller.rssm_policy import RSSMPolicy
from world_models.controller.rollout_generator import RolloutGenerator
from world_models.utils.utils import TorchImageEnvWrapper
from world_models.memory.planet_memory import Memory
# Model dimensions: passed positionally to RecurrentStateSpaceModel in main()
# after the action size (presumably state / latent / embedding widths --
# TODO confirm against the model's signature).
STATE_SIZE = 200
LATENT_SIZE = 30
EMBEDDING_SIZE = 1024

# Handed to TorchImageEnvWrapper in main(); presumably the bit depth used to
# quantise image observations.
BIT_DEPTH = 5

# Free-nats threshold; main() materialises it as a one-element tensor.
FREE_NATS = 2
def evaluate(memory: Any, model: Any, path: str, eps: Any) -> None:
    """Run one RSSM reconstruction/prediction evaluation and save visual outputs.

    Samples a single sequence from ``memory``, runs it through ``model`` and
    decodes every timestep twice: once from element 0 of each prior and once
    from element 0 of each posterior (treated as the means). The ground-truth
    frames and both decoded sequences are written with ``save_frames`` for
    qualitative inspection.

    Inference runs under ``torch.no_grad()``. The model's previous train/eval
    mode is restored on exit, so calling this from inside a training loop does
    not leave the model in eval mode.

    Args:
        memory: Replay buffer whose ``sample(1)`` returns
            ``((observations, actions, _, _), lengths)``.
        model: RSSM module; ``model(x, u)`` returns
            ``(states, priors, posteriors)`` and
            ``model.decoder(state, latent)`` decodes a single timestep.
        path: Output path prefix; ``_{eps}`` is appended to it.
        eps: Episode/iteration tag used in the output name.
    """
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Sample one sequence and move it to the model's device as float32.
            (x, u, _, _), _lens = memory.sample(1)
            x = torch.as_tensor(x, dtype=torch.float32, device=device)
            u = torch.as_tensor(u, dtype=torch.float32, device=device)

            # Add a leading batch dim when the sample came back unbatched
            # (presumably x: [T, C, H, W], u: [T, A] or [T] -- TODO confirm).
            if x.dim() == 4:
                x = x.unsqueeze(0)
            if u.dim() == 2:
                u = u.unsqueeze(0)
            elif u.dim() == 1:
                u = u.unsqueeze(0).unsqueeze(-1)

            states, priors, posteriors = model(x, u)
            prior_means = [p[0] for p in priors]
            post_means = [p[0] for p in posteriors]

            pred_prior = torch.stack(
                [model.decoder(s, m) for s, m in zip(states, prior_means)]
            )  # [T, B, C, H, W]
            pred_post = torch.stack(
                [model.decoder(s, m) for s, m in zip(states, post_means)]
            )  # [T, B, C, H, W]

            # save_frames expects target [T+1, C, H, W] and preds [T, C, H, W]:
            # keep batch item 0 only and move everything to the CPU.
            save_frames(
                x[0].cpu(), pred_prior[:, 0].cpu(), pred_post[:, 0].cpu(), f"{path}_{eps}"
            )
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)
def main() -> None:
    """Standalone training loop for RSSM with generated replay fallback support.

    Sets up the device, the RSSM and its Adam optimizer, an ``RSSMPolicy``
    planner, the image environment wrapper and a ``RolloutGenerator``. It then
    loads (or generates and caches) the test and train replay memories and
    adds 5 fresh policy rollouts to the training memory.

    It runs 1000 training iterations:
    - every iteration: plot the accumulated metrics;
    - every 10 iterations: run ``evaluate``;
    - every 25 iterations: checkpoint the model.

    All outputs go under ``results/test_rssm``. Set ``TRAIN_RSSM_DEBUG=1`` to
    drop into ``pdb`` once training finishes.
    """
    results_dir = "results/test_rssm"
    # plot_metrics, evaluate and torch.save all write into this directory and
    # fail on a fresh checkout if it does not exist yet.
    os.makedirs(results_dir, exist_ok=True)

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
        print("WARNING: CUDA not available, using CPU")

    # Explicit float dtype: an int fill value would otherwise produce an int64
    # tensor.
    # NOTE(review): this tensor is never passed to train_rssm below. Confirm
    # whether train_rssm applies free nats itself or should receive it.
    free_nats_tensor = torch.full(
        (1,), float(FREE_NATS), dtype=torch.float32, device=device
    )

    # First positional arg is presumably the action size -- TODO confirm.
    rssm = RecurrentStateSpaceModel(1, STATE_SIZE, LATENT_SIZE, EMBEDDING_SIZE).to(
        device
    )
    optimizer = torch.optim.Adam(get_combined_params(rssm), lr=1e-3)

    # Planning policy that plans with the RSSM being trained.
    policy = RSSMPolicy(
        model=rssm,
        planning_horizon=12,
        num_candidates=1000,
        num_iterations=5,
        top_candidates=50,
        device=device,
    )
    # Env wrapper and rollout generator (example env name).
    env = TorchImageEnvWrapper("CartPole-v1", BIT_DEPTH)
    rollout_gen = RolloutGenerator(env, device, policy=policy)

    def ensure_memory(
        path: str, n_warmup: int = 25, mem_size: int = 1000, random_policy: bool = True
    ) -> Memory:
        """Load the replay memory at ``path``, or generate and cache it.

        When ``path`` does not exist, collect ``n_warmup`` episodes with
        ``rollout_gen`` into a fresh ``Memory(mem_size)``, pickle it to
        ``path`` for later runs, and return it with ``device`` attached.
        """
        if os.path.exists(path):
            # NOTE(review): presumably unpickles the file. Only load replay
            # files this script generated, never untrusted ones.
            return load_memory(path, device)
        print(f"{path} not found → generating {n_warmup} episodes with rollout_gen")
        mem = Memory(mem_size)
        episodes = rollout_gen.rollout_n(n=n_warmup, random_policy=random_policy)
        mem.append(episodes)
        # Dump to a temp file, then rename. An interrupted dump therefore never
        # leaves a truncated file that the exists() check above would accept.
        tmp_path = f"{path}.tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(mem, f)
        os.replace(tmp_path, path)
        mem.device = device
        return mem

    test_data = ensure_memory("test_exp_replay.pth", n_warmup=10, random_policy=True)
    train_data = ensure_memory("train_exp_replay.pth", n_warmup=50, random_policy=False)

    # Warm up training memory with a few rollouts from the planning policy.
    new_eps = rollout_gen.rollout_n(n=5, random_policy=False)
    train_data.append(new_eps)

    global_metrics = defaultdict(list)
    for i in trange(1000, desc="# Episode: ", leave=False):
        episode = i + 1
        # NOTE(review): train_rssm is neither defined nor imported in this
        # file, so this raises NameError unless it is provided elsewhere.
        # Confirm its home module and import it.
        metrics = train_rssm(train_data, rssm, optimizer, record_grads=False)
        for key, values in metrics.items():
            global_metrics[key].extend(values)
        plot_metrics(global_metrics, path=results_dir, prefix="TRAIN_")
        if episode % 10 == 0:
            evaluate(test_data, rssm, f"{results_dir}/eps", episode)
        if episode % 25 == 0:
            torch.save(rssm.state_dict(), f"{results_dir}/ckpt_{episode}.pth")

    # Optional post-training breakpoint for inspecting the trained model.
    if os.getenv("TRAIN_RSSM_DEBUG", "0") == "1":
        pdb.set_trace()


if __name__ == "__main__":
    main()