Source code for world_models.controller.rollout_generator

import numpy as np
import torch
from collections import defaultdict

from tqdm import trange
from torchvision.utils import make_grid

from world_models.memory.planet_memory import Episode
from world_models.utils.utils import StreamingVideoWriter


class RolloutGenerator:
    """Generate episodes by stepping an environment with a policy.

    The generator drives ``env`` for a fixed number of steps per episode and
    records every transition in an episode container created by
    ``episode_gen``. It supports plain data-collection rollouts
    (:meth:`rollout_once`, :meth:`rollout_n`) and evaluation rollouts that
    additionally record reconstructions, reward predictions and, optionally,
    a streamed video (:meth:`rollout_eval`, :meth:`rollout_eval_n`).

    Interfaces used by this class (as seen in the method bodies):

    * ``env``: ``reset()``, ``step(action)`` returning a 4-tuple
      ``(obs, reward, terminal, info)``, ``sample_random_action()`` and, when
      ``max_episode_steps`` is not given, a ``max_episode_steps`` attribute.
    * ``policy``: ``reset()``, ``poll(obs[, explore])`` and, for evaluation,
      the attributes ``h``, ``s`` and ``rssm`` (with ``decoder`` and
      ``pred_reward`` callables).
    * episode container: ``append(obs, act, reward, terminal)`` and
      ``terminate(obs)``.

    Args:
        env: Environment to roll out in.
        device: Device observations are moved to before polling the policy.
        policy: Policy to act with. ``None`` means random actions for the
            collection rollouts; evaluation rollouts require a policy.
        max_episode_steps: Number of environment steps per episode. Defaults
            to ``env.max_episode_steps`` when ``None``.
        episode_gen: Zero-argument callable producing an empty episode
            container. Defaults to ``Episode``.
        name: Label used in progress-bar descriptions. Defaults to
            ``"Rollout Generator"``.
        enable_streaming_video: Whether :meth:`rollout_eval` streams frames
            to a video file. Only takes effect if ``streaming_video_path`` is
            also set.
        streaming_video_path: Destination passed to ``StreamingVideoWriter``.
        streaming_video_fps: Frame rate passed to ``StreamingVideoWriter``.
        streaming_video_format: Format passed to ``StreamingVideoWriter``.
    """

    def __init__(
        self,
        env,
        device,
        policy=None,
        max_episode_steps=None,
        episode_gen=None,
        name=None,
        enable_streaming_video=False,
        streaming_video_path=None,
        streaming_video_fps=20,
        streaming_video_format="mp4",
    ):
        self.env = env
        self.device = device
        self.policy = policy
        self.episode_gen = episode_gen or Episode
        self.name = name or "Rollout Generator"
        self.max_episode_steps = max_episode_steps
        if self.max_episode_steps is None:
            # Fall back to the horizon advertised by the environment.
            self.max_episode_steps = self.env.max_episode_steps
        self.enable_streaming_video = enable_streaming_video
        self.streaming_video_path = streaming_video_path
        self.streaming_video_fps = streaming_video_fps
        self.streaming_video_format = streaming_video_format
        # Created lazily by rollout_eval(); None whenever no video is open.
        self.video_writer = None

    def rollout_once(self, random_policy=False, explore=False) -> "Episode":
        """Perform a single rollout and return the recorded episode.

        Args:
            random_policy: If true, actions come from
                ``env.sample_random_action()`` instead of the policy. Forced
                to true (with a printed notice) when no policy is set.
            explore: Forwarded as the second argument of ``policy.poll``.

        Returns:
            The episode container filled with ``max_episode_steps``
            transitions and terminated with the last observation.
        """
        if self.policy is None and not random_policy:
            random_policy = True
            print("Policy is None. Using random policy instead!!")
        if not random_policy:
            self.policy.reset()
        eps = self.episode_gen()
        obs = self.env.reset()
        des = f"{self.name} Ts"
        # NOTE(review): the loop always runs for max_episode_steps and does
        # not stop when `terminal` is true -- confirm the env tolerates this.
        for _ in trange(self.max_episode_steps, desc=des, leave=False):
            if random_policy:
                act = self.env.sample_random_action()
            else:
                act = self.policy.poll(obs.to(self.device), explore).flatten()
            nobs, reward, terminal, _ = self.env.step(act)
            eps.append(obs, act, reward, terminal)
            obs = nobs
        # `obs` is the latest observation (identical to the last `nobs`) and,
        # unlike `nobs`, is also bound when max_episode_steps is 0.
        eps.terminate(obs)
        return eps

    def rollout_n(self, n=1, random_policy=False) -> "list[Episode]":
        """Perform ``n`` rollouts and return the episodes as a list.

        Args:
            n: Number of episodes to collect.
            random_policy: Forwarded to :meth:`rollout_once`. Forced to true
                (with a printed notice) when no policy is set.

        Returns:
            A list with one episode per rollout, in collection order.
        """
        if self.policy is None and not random_policy:
            random_policy = True
            print("Policy is None. Using random policy instead!!")
        des = f"{self.name} EPS"
        return [
            self.rollout_once(random_policy=random_policy)
            for _ in trange(n, desc=des, leave=False)
        ]

    def rollout_eval_n(self, n):
        """Run ``n`` evaluation rollouts and aggregate their results.

        Args:
            n: Number of evaluation episodes.

        Returns:
            A tuple ``(episodes, frames, metrics)`` where ``episodes`` and
            ``frames`` are lists with one entry per rollout and ``metrics``
            maps every metric name to the list of its per-rollout values.
        """
        metrics = defaultdict(list)
        episodes, frames = [], []
        for _ in range(n):
            # rollout_eval() returns (episode, frames, metrics, latents);
            # the trailing latents entry is not aggregated here. The starred
            # target also accepts overrides that return only three values.
            e, f, m, *_ = self.rollout_eval()
            episodes.append(e)
            frames.append(f)
            for k, v in m.items():
                metrics[k].append(v)
        return episodes, frames, metrics

    def rollout_eval(self, collect_latents=False):
        """Run one evaluation rollout with the policy and collect diagnostics.

        At every step the policy is polled without gradient tracking, the
        current policy state ``(h, s)`` is decoded into a reconstruction of
        the observation and fed to the reward predictor, and a side-by-side
        frame of observation and reconstruction is stored. If streaming video
        is enabled and a path is configured, each frame is also written to a
        ``StreamingVideoWriter``, which is closed when the rollout ends, even
        if the rollout raises.

        Args:
            collect_latents: If true, also record
                ``concat(policy.h, policy.s)`` at every step.

        Returns:
            A tuple ``(eps, frames, metrics, latents_list)``:

            * ``eps``: the recorded episode container.
            * ``frames``: ``np.stack`` of the per-step grid frames.
            * ``metrics``: dict with ``"eval/episode_reward"`` (sum of
              rewards), ``"eval/reconstruction_loss"`` (per-step summed
              absolute reconstruction error) and ``"eval/reward_pred_loss"``
              (absolute difference between the reward returned by step ``t``
              and the reward predicted at step ``t + 1``).
            * ``latents_list``: list of per-step latent arrays, or ``None``
              when ``collect_latents`` is false.

        Raises:
            AssertionError: If no policy is set.
        """
        assert self.policy is not None, "Policy is None!!"
        self.policy.reset()
        eps = self.episode_gen()
        obs = self.env.reset()
        des = f"{self.name} Eval Ts"
        frames = []
        latents_list = [] if collect_latents else None
        if self.enable_streaming_video and self.streaming_video_path:
            self.video_writer = StreamingVideoWriter(
                self.streaming_video_path,
                fps=self.streaming_video_fps,
                format=self.streaming_video_format,
            )
        metrics = {}
        rec_losses = []
        pred_r, act_r = [], []
        eps_reward = 0
        try:
            # NOTE(review): as in rollout_once, `terminal` does not end the
            # loop -- confirm the env tolerates stepping past termination.
            for _ in trange(self.max_episode_steps, desc=des, leave=False):
                with torch.no_grad():
                    act = self.policy.poll(obs.to(self.device)).flatten()
                    dec = (
                        self.policy.rssm.decoder(self.policy.h, self.policy.s)
                        .squeeze()
                        .cpu()
                        .clamp_(-0.5, 0.5)
                    )
                    rec_losses.append(((obs - dec).abs()).sum().item())
                    # Shift both images by +0.5 before tiling; presumably the
                    # observations live in [-0.5, 0.5] -- TODO confirm.
                    frame = make_grid([obs + 0.5, dec + 0.5], nrow=2).numpy()
                    frames.append(frame)
                    if self.video_writer:
                        self.video_writer.write_frame(frame)
                    if collect_latents:
                        latents_list.append(
                            torch.cat(
                                [self.policy.h, self.policy.s], dim=-1
                            )
                            .cpu()
                            .numpy()
                        )
                    pred_r.append(
                        self.policy.rssm.pred_reward(
                            self.policy.h, self.policy.s
                        )
                        .cpu()
                        .flatten()
                        .item()
                    )
                nobs, reward, terminal, _ = self.env.step(act)
                eps.append(obs, act, reward, terminal)
                act_r.append(reward)
                eps_reward += reward
                obs = nobs
            # `obs` equals the last `nobs` and is bound even for zero steps.
            eps.terminate(obs)
        finally:
            # Always release the video file, and drop the closed writer so a
            # later call never writes to a stale handle.
            if self.video_writer:
                self.video_writer.close()
                self.video_writer = None
        metrics["eval/episode_reward"] = eps_reward
        metrics["eval/reconstruction_loss"] = rec_losses
        # pred_r[t] is computed before env.step at iteration t, so the reward
        # returned by step t is compared with the prediction from step t + 1.
        metrics["eval/reward_pred_loss"] = np.abs(
            np.array(act_r)[:-1] - np.array(pred_r)[1:]
        )
        return eps, np.stack(frames), metrics, latents_list