"""world_models.utils.dreamer_utils — Dreamer utilities: parameter handling, experiment logging, and TD(lambda) return computation."""

import os
import pickle
import torch
import numpy as np
import moviepy as mpy
import cv2

from typing import Iterable
from torch.nn import Module

try:
    import wandb
except ImportError:
    wandb = None


def get_parameters(modules: Iterable[Module]):
    """Flatten the parameters of several torch modules into one list.

    :param modules: iterable of modules; each contributes its ``.parameters()``
        in order
    :returns: a single list of parameters, in module order
    """
    return [param for module in modules for param in module.parameters()]
class FreezeParameters:
    """Context manager that temporarily disables gradients for given modules.

    Useful during imagination or target-network forward passes where gradients
    through certain components should be blocked for speed and correctness.

    Example::

        with FreezeParameters([module]):
            output_tensor = module(input_tensor)
    """

    def __init__(self, modules: Iterable[Module]):
        """
        :param modules: iterable of modules; their parameters have
            ``requires_grad`` set to False while the context is active.
        """
        self.modules = modules
        # Kept for backward compatibility with code that inspects this
        # attribute; it is re-captured on every __enter__ (see below).
        self.param_states = [p.requires_grad for p in self._parameters()]

    def _parameters(self):
        # Flatten the parameters of all managed modules into one list.
        return [p for module in self.modules for p in module.parameters()]

    def __enter__(self):
        # BUG FIX: snapshot at entry time, not only at construction time.
        # The old behavior captured requires_grad once in __init__, so a
        # manager that was reused -- or entered after flags had changed --
        # restored stale values on exit.
        self.param_states = [p.requires_grad for p in self._parameters()]
        for param in self._parameters():
            param.requires_grad = False

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore each parameter's original requires_grad flag.
        for param, state in zip(self._parameters(), self.param_states):
            param.requires_grad = state
class Logger:
    """Experiment logger for scalar metrics and video rollouts.

    Scalars are printed, appended to an on-disk pickle log, and optionally
    forwarded to WandB; videos are written to disk as gif/mp4 previews.
    """

    def __init__(
        self,
        log_dir,
        enable_wandb=False,
        wandb_api_key="",
        wandb_project="torchwm",
        wandb_entity="",
        video_format="gif",
        video_fps=20,
    ):
        # Output directory for scalar pickles and rendered videos.
        self._log_dir = log_dir
        self._n_logged_samples = 10
        self.enable_wandb = enable_wandb
        self.video_format = video_format
        self.video_fps = video_fps
        self._wandb_run = None

        banner = "########################"
        print(banner)
        print("logging outputs to ", log_dir)
        print(banner)

        if not self.enable_wandb:
            return
        # Fail fast on misconfiguration before touching the environment.
        if not wandb_api_key:
            raise ValueError("WandB API key is required when enable_wandb is True")
        if wandb is None:
            raise ImportError("wandb is not installed")
        os.environ["WANDB_API_KEY"] = wandb_api_key
        self._wandb_run = wandb.init(
            project=wandb_project,
            entity=wandb_entity,
            dir=log_dir,
            name=os.path.basename(log_dir),
        )
[docs] def log_scalar(self, scalar, name, step_): if self.enable_wandb and self._wandb_run: self._wandb_run.log({name: scalar}, step=step_)
[docs] def log_scalars(self, scalar_dict, step): for key, value in scalar_dict.items(): print("{} : {}".format(key, value)) self.log_scalar(value, key, step) self.dump_scalars_to_pickle(scalar_dict, step)
[docs] def log_videos( self, videos, step, max_videos_to_save=1, fps=None, video_title="video" ): if fps is None: fps = self.video_fps format = self.video_format # max rollout length max_videos_to_save = np.min([max_videos_to_save, videos.shape[0]]) max_length = videos[0].shape[0] for i in range(max_videos_to_save): if videos[i].shape[0] > max_length: max_length = videos[i].shape[0] # pad rollouts to all be same length for i in range(max_videos_to_save): if videos[i].shape[0] < max_length: padding = np.tile( [videos[i][-1]], (max_length - videos[i].shape[0], 1, 1, 1) ) videos[i] = np.concatenate([videos[i], padding], 0) if format.lower() == "mp4": # Convert to uint8 HWC BGR for OpenCV video_u8 = (videos[i] * 255).astype(np.uint8) if video_u8.shape[-1] == 3: # RGB to BGR video_u8 = video_u8[..., ::-1] new_video_title = video_title + "{}_{}".format(step, i) + ".mp4" filename = os.path.join(self._log_dir, new_video_title) height, width = video_u8.shape[1], video_u8.shape[2] fourcc = cv2.VideoWriter_fourcc(*"mp4v") out = cv2.VideoWriter(filename, fourcc, fps, (width, height)) for frame in video_u8: out.write(frame) out.release() else: # gif clip = mpy.ImageSequenceClip(list(videos[i]), fps=fps) new_video_title = video_title + "{}_{}".format(step, i) + ".gif" filename = os.path.join(self._log_dir, new_video_title) clip.write_gif(filename, fps=fps) # Log to WandB if self.enable_wandb and self._wandb_run: # Convert to numpy array for WandB video_array = np.array(videos[i]) # Shape: (T, H, W, C) # WandB expects (T, H, W, C) for Video self._wandb_run.log( {f"{video_title}_{i}": wandb.Video(video_array, fps=fps)}, step=step )
[docs] def dump_scalars_to_pickle(self, metrics, step, log_title=None): log_path = os.path.join( self._log_dir, "scalar_data.pkl" if log_title is None else log_title ) with open(log_path, "ab") as f: pickle.dump({"step": step, **dict(metrics)}, f)
[docs] def flush(self): pass
def compute_return(rewards, values, discounts, td_lam, last_value):
    """Compute TD(lambda) returns via the backward recursion used by the
    Dreamer actor/value objectives.

    All tensors are time-major: index 0 is the first imagined step.

    :param rewards: per-step rewards, shape (T, ...)
    :param values: per-step value estimates, shape (T, ...)
    :param discounts: per-step discount factors, shape (T, ...)
    :param td_lam: lambda mixing coefficient (scalar)
    :param last_value: bootstrap value for the step after the horizon
    :returns: lambda-returns tensor with the same shape as ``rewards``
    """
    # Value of the successor state at each step; the final step bootstraps
    # from last_value.
    bootstrap = last_value.unsqueeze(0)
    successor_values = torch.cat([values[1:], bootstrap], 0)
    # One-step targets, with the (1 - lambda) weighting applied up front.
    targets = rewards + discounts * successor_values * (1 - td_lam)

    horizon = rewards.shape[0]
    accumulator = last_value
    outputs = []
    # Sweep backward through time, prepending so outputs end up time-ordered.
    for t in reversed(range(horizon)):
        accumulator = targets[t] + discounts[t] * td_lam * accumulator
        outputs.insert(0, accumulator)
    return torch.stack(outputs)