"""Source code for world_models.training.train_planet: PlaNet / RSSM training script."""

import pdb
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from tqdm import trange
from functools import partial
import os

from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

from world_models.utils.utils import (
    preprocess_img,
    bottle,
    TensorBoardMetrics,
    save_video,
    flatten_dict,
    postprocess_img,
    normalize_frames_for_saving,
)
from world_models.memory.planet_memory import Memory, Episode
from world_models.models.rssm import RecurrentStateSpaceModel
from world_models.controller.rssm_policy import RSSMPolicy
from world_models.controller.rollout_generator import RolloutGenerator


def train(
    memory,
    rssm,
    optimizer,
    device,
    N=32,
    H=50,
    beta=1.0,
    grads=False,
    *,
    free_nats=3.0,
    grad_clip_norm=1000.0,
    bit_depth=5,
):
    """Run a single optimisation step of the RSSM on replayed sequences.

    Training implementation as indicated in:
    Learning Latent Dynamics for Planning from Pixels, arXiv:1811.04551
    (a.) The Standard Variational Bound Method, using only single step
    predictions.

    Args:
        memory: Replay buffer. ``memory.sample(N, H, time_first=True)`` must
            return ``((x, u, r, t), lengths)``; only ``x`` (observations),
            ``u`` (actions) and ``r`` (rewards) are used here.
        rssm: Model exposing ``encoder``, ``decoder``, ``pred_reward``,
            ``get_init_state``, ``deterministic_state_fwd``, ``state_prior``
            and ``state_posterior``.
        optimizer: Torch optimizer over ``rssm``'s parameters.
        device: Torch device the sampled batch is moved to.
        N: First argument forwarded to ``memory.sample`` (presumably the
            batch size -- confirm against ``Memory.sample``).
        H: Second argument forwarded to ``memory.sample`` (presumably the
            sequence length -- confirm against ``Memory.sample``).
        beta: Weight of the KL term in the total loss.
        grads: When true, also report the gradient norm of every named
            parameter under ``metrics["grad_norms"]``.
        free_nats: Lower clamp applied to the summed KL divergence.
        grad_clip_norm: Maximum L2 norm used for gradient clipping.
        bit_depth: ``depth`` forwarded to ``preprocess_img``.

    Returns:
        dict: ``{"losses": {"kl", "reconstruction", "reward_pred"}}`` holding
        Python floats, plus ``"grad_norms"`` when ``grads`` is true.
    """
    (x, u, r, _), _ = memory.sample(N, H, time_first=True)
    # torch.tensor() always copies, so whatever preprocess_img does to ``x``
    # below cannot leak back into the arrays owned by the replay memory.
    x, u, r = (torch.tensor(arr).float().to(device) for arr in (x, u, r))
    # Return value is discarded, so this presumably quantises ``x`` in place
    # -- NOTE(review): confirm against preprocess_img.
    preprocess_img(x, depth=bit_depth)

    e_t = bottle(rssm.encoder, x)
    h_t, s_t = rssm.get_init_state(e_t[0])

    states, priors, posteriors, posterior_samples = [], [], [], []
    for i, a_t in enumerate(torch.unbind(u, dim=0)):
        h_t = rssm.deterministic_state_fwd(h_t, s_t, a_t)
        posterior = rssm.state_posterior(h_t, e_t[i + 1])
        # rsample() keeps the sample differentiable (reparameterisation).
        s_t = Normal(*posterior).rsample()
        states.append(h_t)
        priors.append(rssm.state_prior(h_t))
        posteriors.append(posterior)
        posterior_samples.append(s_t)

    # Each prior/posterior is a tuple of distribution parameters; stack the
    # per-step parameters along a new leading (time) dimension.
    prior_dist = Normal(*map(torch.stack, zip(*priors)))
    posterior_dist = Normal(*map(torch.stack, zip(*posteriors)))
    states = torch.stack(states)
    posterior_samples = torch.stack(posterior_samples)

    # Squared error summed over dims 2-4 of each frame, averaged over the
    # two leading dims. The first observation only seeds the initial state,
    # so reconstruction targets start at x[1].
    reconstruction = bottle(rssm.decoder, states, posterior_samples)
    rec_loss = (
        F.mse_loss(reconstruction, x[1:], reduction="none").sum((2, 3, 4)).mean()
    )

    # Clamp the per-step KL from below so the model is not penalised for
    # divergence smaller than ``free_nats``.
    kl_floor = torch.full((1,), float(free_nats), device=device)
    kld_loss = torch.max(
        kl_divergence(posterior_dist, prior_dist).sum(-1), kl_floor
    ).mean()

    rew_loss = F.mse_loss(bottle(rssm.pred_reward, states, posterior_samples), r)

    optimizer.zero_grad()
    loss = beta * kld_loss + rec_loss + rew_loss
    loss.backward()
    nn.utils.clip_grad_norm_(rssm.parameters(), grad_clip_norm, norm_type=2)
    optimizer.step()

    metrics = {
        "losses": {
            "kl": kld_loss.item(),
            "reconstruction": rec_loss.item(),
            "reward_pred": rew_loss.item(),
        },
    }
    if grads:
        # Norms are read after clipping, so they reflect the clipped grads.
        metrics["grad_norms"] = {
            k: 0 if v.grad is None else v.grad.norm().item()
            for k, v in rssm.named_parameters()
        }
    return metrics
def _print_eval_frame_debug(eval_frames):
    """Print shape, dtype and value-range diagnostics for evaluation frames."""
    print("\n===== EVAL FRAME DEBUG =====")
    if torch.is_tensor(eval_frames):
        print("eval_frames is a TORCH tensor, converting to numpy...")
        ef_np = eval_frames.detach().cpu().numpy()
    else:
        ef_np = np.asarray(eval_frames)
    print("eval_frames.shape =", ef_np.shape)
    print("dtype =", ef_np.dtype)
    print("min =", float(ef_np.min()), "max =", float(ef_np.max()))
    first = ef_np[0]
    print("first_frame.shape =", first.shape)
    if first.ndim == 3:
        print(
            "channel count =",
            first.shape[0] if first.shape[0] <= 4 else first.shape[-1],
        )
        print("channel dims =", first.shape)
        # Print min/max per channel (up to 8 channels). A small leading dim
        # is treated as channels-first, otherwise channels-last is assumed.
        channels_first = first.shape[0] <= 8
        num_channels = first.shape[0] if channels_first else first.shape[-1]
        for c in range(min(num_channels, 8)):
            ch = first[c] if channels_first else first[..., c]
            print(f"channel[{c}] min={ch.min()} max={ch.max()} mean={ch.mean()}")
    print("===== END DEBUG =====\n")


def main():
    """Example PlaNet/RSSM training script with rollout collection and evaluation.

    Builds environment/model/policy objects, iteratively trains on replayed
    episodes, and periodically saves videos and checkpoints under
    ``results/``. A checkpoint is written every ``ckpt_interval`` epochs and
    always after the final epoch. Setting the environment variable
    ``TRAIN_RSSM_DEBUG=1`` drops into ``pdb`` once training has finished.
    """
    # Imported here because the module-level import list does not pull it in.
    from world_models.utils.utils import TorchImageEnvWrapper

    num_epochs = 2
    iters_per_epoch = 150
    ckpt_interval = 25
    bit_depth = 5

    env = TorchImageEnvWrapper("Pendulum-v1", bit_depth=bit_depth)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rssm_model = RecurrentStateSpaceModel(env.action_size).to(device)
    optimizer = torch.optim.Adam(rssm_model.parameters(), lr=1e-3, eps=1e-4)
    policy = RSSMPolicy(
        model=rssm_model,
        planning_horizon=20,
        num_candidates=1000,
        num_iterations=10,
        top_candidates=100,
        device=device,
    )
    rollout_gen = RolloutGenerator(
        env,
        device,
        policy=policy,
        episode_gen=lambda: Episode(partial(postprocess_img, depth=bit_depth)),
        max_episode_steps=100,
    )

    # Seed the replay memory with one random-policy episode.
    mem = Memory(100)
    mem.append(rollout_gen.rollout_n(1, random_policy=True))

    res_dir = "results/"
    # torch.save does not create missing directories.
    os.makedirs(res_dir, exist_ok=True)
    summary = TensorBoardMetrics(res_dir)

    for i in trange(num_epochs, desc="Epoch", leave=False):
        epoch = i + 1

        metrics = {}
        for _ in trange(iters_per_epoch, desc="Iter ", leave=False):
            train_metrics = train(mem, rssm_model.train(), optimizer, device)
            for k, v in flatten_dict(train_metrics).items():
                metrics.setdefault(k, []).append(v)
        # Per-key mean over the whole epoch, computed once instead of being
        # recomputed on every iteration.
        metrics.update(
            {f"{k}_mean": np.array(vals).mean() for k, vals in metrics.items()}
        )
        summary.update(metrics)

        mem.append(rollout_gen.rollout_once(explore=True))
        eval_episode, eval_frames, eval_metrics = rollout_gen.rollout_eval()
        _print_eval_frame_debug(eval_frames)
        mem.append(eval_episode)

        # normalize frames to (T,H,W,3) float in [0,1] before saving
        safe_frames = normalize_frames_for_saving(eval_frames)
        save_video(safe_frames, res_dir, f"vid_{epoch}")
        summary.update(eval_metrics)

        # Also save after the last epoch; otherwise a run shorter than
        # ``ckpt_interval`` epochs would never persist the model.
        if epoch % ckpt_interval == 0 or epoch == num_epochs:
            torch.save(
                rssm_model.state_dict(),
                os.path.join(res_dir, f"ckpt_{epoch}.pth"),
            )

    if os.getenv("TRAIN_RSSM_DEBUG", "0") == "1":
        pdb.set_trace()
# Script entry point: run the example training loop only when this module is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()