Source code for world_models.models.dreamer

import os
import random
import time
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as distributions

from collections import OrderedDict

import world_models.envs.wrappers as env_wrapper
from world_models.envs.dmc import DeepMindControlEnv
from world_models.envs.gym_env import GymImageEnv
from world_models.envs.unity_env import UnityMLAgentsEnv
from world_models.memory.dreamer_memory import ReplayBuffer
from world_models.models.dreamer_rssm import RSSM
from world_models.vision.dreamer_decoder import ConvDecoder, DenseDecoder, ActionDecoder
from world_models.vision.dreamer_encoder import ConvEncoder
from world_models.utils.dreamer_utils import Logger, FreezeParameters, compute_return
from world_models.configs.dreamer_config import DreamerConfig

# Force EGL rendering for MuJoCo so DMC environments work headless (no X server).
os.environ["MUJOCO_GL"] = "egl"


def _resolve_image_size(args):
    size = getattr(args, "image_size", (64, 64))
    if isinstance(size, int):
        return (size, size)
    if isinstance(size, (tuple, list)) and len(size) == 2:
        return (int(size[0]), int(size[1]))
    raise ValueError(f"Invalid image_size={size}. Expected int or (H, W).")


def make_env(args):
    """Construct a Dreamer-compatible environment from `DreamerConfig` options.

    Supports DMC, Gym/Gymnasium, and Unity ML-Agents backends and applies
    the standard wrapper stack: action repeat, action normalization, and
    time limit.
    """
    size = _resolve_image_size(args)
    backend = str(getattr(args, "env_backend", "dmc")).lower()
    env_instance = getattr(args, "env_instance", None)
    render_mode = getattr(args, "gym_render_mode", "rgb_array")

    if env_instance is not None:
        # A pre-built env object takes precedence over any backend setting.
        env = GymImageEnv(
            env_instance,
            seed=args.seed,
            size=size,
            render_mode=render_mode,
        )
    elif backend == "dmc":
        env = DeepMindControlEnv(args.env, args.seed, size=size)
    elif backend in {"gym", "gymnasium", "generic"}:
        env = GymImageEnv(
            args.env,
            seed=args.seed,
            size=size,
            render_mode=render_mode,
        )
    elif backend in {"unity", "unity_mlagents", "mlagents"}:
        unity_file_name = getattr(args, "unity_file_name", None)
        if not unity_file_name:
            raise ValueError(
                "unity_file_name must be provided when env_backend='unity_mlagents'."
            )
        env = UnityMLAgentsEnv(
            file_name=unity_file_name,
            behavior_name=getattr(args, "unity_behavior_name", None),
            seed=args.seed,
            size=size,
            worker_id=int(getattr(args, "unity_worker_id", 0)),
            base_port=int(getattr(args, "unity_base_port", 5005)),
            no_graphics=bool(getattr(args, "unity_no_graphics", True)),
            time_scale=float(getattr(args, "unity_time_scale", 20.0)),
            quality_level=int(getattr(args, "unity_quality_level", 1)),
            max_episode_steps=int(getattr(args, "time_limit", 1000)),
        )
    else:
        raise ValueError(
            f"Unknown env_backend='{backend}'. Use one of: dmc, gym, unity_mlagents."
        )

    # Standard Dreamer wrapper stack; the time limit is expressed in
    # post-repeat steps, hence the division by action_repeat.
    env = env_wrapper.ActionRepeat(env, int(args.action_repeat))
    env = env_wrapper.NormalizeActions(env)
    repeat = max(1, int(args.action_repeat))
    env = env_wrapper.TimeLimit(env, max(1, int(args.time_limit) // repeat))
    return env
def preprocess_obs(obs):
    """Convert raw uint8 image observations to Dreamer float input space.

    Pixel values in `[0, 255]` are mapped to roughly `[-0.5, 0.5]`, the
    normalization Dreamer encoders expect.
    """
    return obs.to(torch.float32) / 255.0 - 0.5
class Dreamer:
    """Core Dreamer training system combining world model, actor, and value nets.

    This class owns model construction, replay sampling, imagination rollouts,
    loss computation, optimization steps, evaluation loops, and checkpoint I/O.
    """

    def __init__(self, args, obs_shape, action_size, device, restore=False):
        self.args = args
        self.obs_shape = obs_shape
        self.action_size = action_size
        self.device = device
        # NOTE(review): the `restore` parameter is shadowed by args.restore
        # here — confirm whether the explicit parameter should win instead.
        self.restore = args.restore
        self.restore_path = args.checkpoint_path
        # Sequence replay buffer; sampling yields (train_seq_len, batch_size)
        # windows of transitions.
        self.data_buffer = ReplayBuffer(
            self.args.buffer_size,
            self.obs_shape,
            self.action_size,
            self.args.train_seq_len,
            self.args.batch_size,
        )
        self._build_model(restore=self.restore)

    def _build_model(self, restore):
        """Instantiate all networks and optimizers; optionally load a checkpoint."""
        # Recurrent state-space model: deterministic GRU path + stochastic latent.
        self.rssm = RSSM(
            action_size=self.action_size,
            stoch_size=self.args.stoch_size,
            deter_size=self.args.deter_size,
            hidden_size=self.args.deter_size,  # hidden width tied to deter_size
            obs_embed_size=self.args.obs_embed_size,
            activation=self.args.dense_activation_function,
        ).to(self.device)
        # Policy head over latent features (stoch + deter).
        self.actor = ActionDecoder(
            action_size=self.action_size,
            stoch_size=self.args.stoch_size,
            deter_size=self.args.deter_size,
            units=self.args.num_units,
            n_layers=4,
            activation=self.args.dense_activation_function,
        ).to(self.device)
        self.obs_encoder = ConvEncoder(
            input_shape=self.obs_shape,
            embed_size=self.args.obs_embed_size,
            activation=self.args.cnn_activation_function,
        ).to(self.device)
        self.obs_decoder = ConvDecoder(
            stoch_size=self.args.stoch_size,
            deter_size=self.args.deter_size,
            output_shape=self.obs_shape,
            activation=self.args.cnn_activation_function,
        ).to(self.device)
        # Scalar reward head (Gaussian likelihood).
        self.reward_model = DenseDecoder(
            stoch_size=self.args.stoch_size,
            deter_size=self.args.deter_size,
            output_shape=(1,),
            n_layers=2,
            units=self.args.num_units,
            activation=self.args.dense_activation_function,
            dist="normal",
        ).to(self.device)
        # Value head is deeper (3 layers) than the reward head (2 layers).
        self.value_model = DenseDecoder(
            stoch_size=self.args.stoch_size,
            deter_size=self.args.deter_size,
            output_shape=(1,),
            n_layers=3,
            units=self.args.num_units,
            activation=self.args.dense_activation_function,
            dist="normal",
        ).to(self.device)
        # Optional episode-continuation predictor (Bernoulli head).
        if self.args.use_disc_model:
            self.discount_model = DenseDecoder(
                stoch_size=self.args.stoch_size,
                deter_size=self.args.deter_size,
                output_shape=(1,),
                n_layers=2,
                units=self.args.num_units,
                activation=self.args.dense_activation_function,
                dist="binary",
            ).to(self.device)
        # All world-model components share one Adam optimizer.
        if self.args.use_disc_model:
            self.world_model_params = (
                list(self.rssm.parameters())
                + list(self.obs_encoder.parameters())
                + list(self.obs_decoder.parameters())
                + list(self.reward_model.parameters())
                + list(self.discount_model.parameters())
            )
        else:
            self.world_model_params = (
                list(self.rssm.parameters())
                + list(self.obs_encoder.parameters())
                + list(self.obs_decoder.parameters())
                + list(self.reward_model.parameters())
            )
        self.world_model_opt = optim.Adam(
            self.world_model_params, self.args.model_learning_rate
        )
        self.value_opt = optim.Adam(
            self.value_model.parameters(), self.args.value_learning_rate
        )
        self.actor_opt = optim.Adam(
            self.actor.parameters(), self.args.actor_learning_rate
        )
        # Module groups used with FreezeParameters during actor/value updates.
        if self.args.use_disc_model:
            self.world_model_modules = [
                self.rssm,
                self.obs_encoder,
                self.obs_decoder,
                self.reward_model,
                self.discount_model,
            ]
        else:
            self.world_model_modules = [
                self.rssm,
                self.obs_encoder,
                self.obs_decoder,
                self.reward_model,
            ]
        self.value_modules = [self.value_model]
        self.actor_modules = [self.actor]
        if restore:
            self.restore_checkpoint(self.restore_path)
    def world_model_loss(self, obs, acs, rews, nonterms):
        """Compute the joint world-model loss: KL + image + reward (+ discount).

        Inputs are a replay batch shaped (seq_len, batch, ...). Also caches
        ``self.posterior`` for the subsequent actor/value updates.
        """
        obs = preprocess_obs(obs)
        # Frames from t=1 onward are predicted from actions/nonterms up to t-1.
        obs_embed = self.obs_encoder(obs[1:])
        init_state = self.rssm.init_state(self.args.batch_size, self.device)
        prior, self.posterior = self.rssm.observe_rollout(
            obs_embed, acs[:-1], nonterms[:-1], init_state, self.args.train_seq_len - 1
        )
        features = torch.cat([self.posterior["stoch"], self.posterior["deter"]], dim=-1)
        rew_dist = self.reward_model(features)
        obs_dist = self.obs_decoder(features)
        if self.args.use_disc_model:
            disc_dist = self.discount_model(features)
        prior_dist = self.rssm.get_dist(prior["mean"], prior["std"])
        post_dist = self.rssm.get_dist(self.posterior["mean"], self.posterior["std"])
        if self.args.algo == "Dreamerv2":
            # DreamerV2 KL balancing: mix KL(sg(post) || prior) and
            # KL(post || sg(prior)), weighted by kl_alpha, so the prior and
            # posterior are trained at different rates.
            post_no_grad = self.rssm.detach_state(self.posterior)
            prior_no_grad = self.rssm.detach_state(prior)
            post_mean_no_grad, post_std_no_grad = (
                post_no_grad["mean"],
                post_no_grad["std"],
            )
            prior_mean_no_grad, prior_std_no_grad = (
                prior_no_grad["mean"],
                prior_no_grad["std"],
            )
            kl_loss = self.args.kl_alpha * (
                torch.mean(
                    distributions.kl.kl_divergence(
                        self.rssm.get_dist(post_mean_no_grad, post_std_no_grad),
                        prior_dist,
                    )
                )
            )
            kl_loss += (1 - self.args.kl_alpha) * (
                torch.mean(
                    distributions.kl.kl_divergence(
                        post_dist,
                        self.rssm.get_dist(prior_mean_no_grad, prior_std_no_grad),
                    )
                )
            )
        else:
            # DreamerV1: plain KL with free nats — KL below the free_nats
            # threshold is not penalized.
            kl_loss = torch.mean(distributions.kl.kl_divergence(post_dist, prior_dist))
            kl_loss = torch.max(
                kl_loss, kl_loss.new_full(kl_loss.size(), self.args.free_nats)
            )
        # Negative log-likelihood reconstruction / prediction terms.
        obs_loss = -torch.mean(obs_dist.log_prob(obs[1:]))
        rew_loss = -torch.mean(rew_dist.log_prob(rews[:-1]))
        if self.args.use_disc_model:
            disc_loss = -torch.mean(disc_dist.log_prob(nonterms[:-1]))
        if self.args.use_disc_model:
            model_loss = (
                self.args.kl_loss_coeff * kl_loss
                + obs_loss
                + rew_loss
                + self.args.disc_loss_coeff * disc_loss
            )
        else:
            model_loss = self.args.kl_loss_coeff * kl_loss + obs_loss + rew_loss
        return model_loss
    def actor_loss(self):
        """Actor objective: maximize discounted lambda-returns over imagined rollouts.

        Caches ``self.imag_feat``, ``self.returns`` and ``self.discounts`` for
        the subsequent value update.
        """
        # Start imagination from the (detached) posterior states of the last
        # world-model batch; world-model weights are frozen during this rollout.
        with torch.no_grad():
            posterior = self.rssm.detach_state(self.rssm.seq_to_batch(self.posterior))
        with FreezeParameters(self.world_model_modules):
            imag_states = self.rssm.imagine_rollout(
                self.actor, posterior, self.args.imagine_horizon
            )
        self.imag_feat = torch.cat([imag_states["stoch"], imag_states["deter"]], dim=-1)
        # Reward/value nets are frozen so only the actor receives gradients.
        with FreezeParameters(self.world_model_modules + self.value_modules):
            imag_rew_dist = self.reward_model(self.imag_feat)
            imag_val_dist = self.value_model(self.imag_feat)
            imag_rews = imag_rew_dist.mean
            imag_vals = imag_val_dist.mean
            if self.args.use_disc_model:
                # NOTE(review): `.mean()` is called here while `.mean` is read
                # as a property above — confirm the discount head's
                # distribution actually exposes a callable mean.
                imag_disc_dist = self.discount_model(self.imag_feat)
                discounts = imag_disc_dist.mean().detach()
            else:
                discounts = self.args.discount * torch.ones_like(imag_rews).detach()
        # Lambda-returns bootstrapped with the final imagined value.
        self.returns = compute_return(
            imag_rews[:-1],
            imag_vals[:-1],
            discounts[:-1],
            self.args.td_lambda,
            imag_vals[-1],
        )
        # Cumulative discount weights; first step is always weighted 1.
        discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
        self.discounts = torch.cumprod(discounts, 0).detach()
        actor_loss = -torch.mean(self.discounts * self.returns)
        return actor_loss
[docs] def value_loss(self): with torch.no_grad(): value_feat = self.imag_feat[:-1].detach() value_targ = self.returns.detach() value_dist = self.value_model(value_feat) value_loss = -torch.mean( self.discounts * value_dist.log_prob(value_targ).unsqueeze(-1) ) return value_loss
[docs] def train_one_batch(self): obs, acs, rews, terms = self.data_buffer.sample() obs = torch.tensor(obs, dtype=torch.float32).to(self.device) acs = torch.tensor(acs, dtype=torch.float32).to(self.device) rews = torch.tensor(rews, dtype=torch.float32).to(self.device).unsqueeze(-1) nonterms = ( torch.tensor((1.0 - terms), dtype=torch.float32) .to(self.device) .unsqueeze(-1) ) model_loss = self.world_model_loss(obs, acs, rews, nonterms) self.world_model_opt.zero_grad() model_loss.backward() nn.utils.clip_grad_norm_(self.world_model_params, self.args.grad_clip_norm) self.world_model_opt.step() actor_loss = self.actor_loss() self.actor_opt.zero_grad() actor_loss.backward() nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.grad_clip_norm) self.actor_opt.step() value_loss = self.value_loss() self.value_opt.zero_grad() value_loss.backward() nn.utils.clip_grad_norm_( self.value_model.parameters(), self.args.grad_clip_norm ) self.value_opt.step() return model_loss.item(), actor_loss.item(), value_loss.item()
[docs] def act_with_world_model(self, obs, prev_state, prev_action, explore=False): obs = obs["image"] obs = torch.tensor(obs.copy(), dtype=torch.float32).to(self.device).unsqueeze(0) obs_embed = self.obs_encoder(preprocess_obs(obs)) _, posterior = self.rssm.observe_step(prev_state, prev_action, obs_embed) features = torch.cat([posterior["stoch"], posterior["deter"]], dim=-1) action = self.actor(features, deter=not explore) if explore: action = self.actor.add_exploration(action, self.args.action_noise) return posterior, action
    def act_and_collect_data(self, env, collect_steps):
        """Collect `collect_steps` environment steps with the exploring policy.

        Transitions are appended to the replay buffer. Returns an array with
        the total reward of each (possibly partial) episode encountered.
        """
        obs = env.reset()
        done = False
        prev_state = self.rssm.init_state(1, self.device)
        prev_action = torch.zeros(1, self.action_size).to(self.device)
        episode_rewards = [0.0]
        for i in range(collect_steps):
            with torch.no_grad():
                posterior, action = self.act_with_world_model(
                    obs, prev_state, prev_action, explore=True
                )
            action = action[0].cpu().numpy()
            next_obs, rew, done, info = env.step(action)
            # Wrappers may report the action actually executed via info["action"].
            executed_action = (
                info["action"] if isinstance(info, dict) and ("action" in info) else action
            )
            self.data_buffer.add(obs, executed_action, rew, done)
            episode_rewards[-1] += rew
            if done:
                # Episode ended: reset env and latent/action state; open a new
                # reward slot unless this was the final collection step.
                obs = env.reset()
                done = False
                prev_state = self.rssm.init_state(1, self.device)
                prev_action = torch.zeros(1, self.action_size).to(self.device)
                if i != collect_steps - 1:
                    episode_rewards.append(0.0)
            else:
                obs = next_obs
                prev_state = posterior
                prev_action = (
                    torch.tensor(executed_action, dtype=torch.float32)
                    .to(self.device)
                    .unsqueeze(0)
                )
        return np.array(episode_rewards)
    def evaluate(self, env, eval_episodes, render=False):
        """Run `eval_episodes` episodes with the deterministic policy.

        Returns (per-episode rewards, array of video frame lists). Frames are
        collected only when ``render`` is True; otherwise the lists stay empty.
        """
        episode_rew = np.zeros((eval_episodes))
        video_images = [[] for _ in range(eval_episodes)]
        for i in range(eval_episodes):
            obs = env.reset()
            done = False
            prev_state = self.rssm.init_state(1, self.device)
            prev_action = torch.zeros(1, self.action_size).to(self.device)
            while not done:
                with torch.no_grad():
                    posterior, action = self.act_with_world_model(
                        obs, prev_state, prev_action
                    )
                action = action[0].cpu().numpy()
                next_obs, rew, done, info = env.step(action)
                # Wrappers may report the actually-executed action via info.
                executed_action = (
                    info["action"] if isinstance(info, dict) and ("action" in info) else action
                )
                prev_state = posterior
                prev_action = (
                    torch.tensor(executed_action, dtype=torch.float32)
                    .to(self.device)
                    .unsqueeze(0)
                )
                episode_rew[i] += rew
                if render:
                    # CHW -> HWC for video logging.
                    video_images[i].append(obs["image"].transpose(1, 2, 0).copy())
                obs = next_obs
        return episode_rew, np.array(video_images[: self.args.max_videos_to_save])
[docs] def collect_random_episodes(self, env, seed_steps): obs = env.reset() done = False seed_episode_rews = [0.0] for i in range(seed_steps): action = env.action_space.sample() next_obs, rew, done, info = env.step(action) executed_action = ( info["action"] if isinstance(info, dict) and ("action" in info) else action ) self.data_buffer.add(obs, executed_action, rew, done) seed_episode_rews[-1] += rew if done: obs = env.reset() if i != seed_steps - 1: seed_episode_rews.append(0.0) done = False else: obs = next_obs return np.array(seed_episode_rews)
[docs] def save(self, save_path): torch.save( { "rssm": self.rssm.state_dict(), "actor": self.actor.state_dict(), "reward_model": self.reward_model.state_dict(), "obs_encoder": self.obs_encoder.state_dict(), "obs_decoder": self.obs_decoder.state_dict(), "discount_model": ( self.discount_model.state_dict() if self.args.use_disc_model else None ), "actor_optimizer": self.actor_opt.state_dict(), "value_optimizer": self.value_opt.state_dict(), "world_model_optimizer": self.world_model_opt.state_dict(), }, save_path, )
[docs] def restore_checkpoint(self, ckpt_path): checkpoint = torch.load(ckpt_path) self.rssm.load_state_dict(checkpoint["rssm"]) self.actor.load_state_dict(checkpoint["actor"]) self.reward_model.load_state_dict(checkpoint["reward_model"]) self.obs_encoder.load_state_dict(checkpoint["obs_encoder"]) self.obs_decoder.load_state_dict(checkpoint["obs_decoder"]) if self.args.use_disc_model and (checkpoint["discount_model"] is not None): self.discount_model.load_state_dict(checkpoint["discount_model"]) self.world_model_opt.load_state_dict(checkpoint["world_model_optimizer"]) self.actor_opt.load_state_dict(checkpoint["actor_optimizer"]) self.value_opt.load_state_dict(checkpoint["value_optimizer"])
class DreamerAgent:
    """High-level user API for running Dreamer experiments end to end.

    It builds environments from config, initializes seeds and logging,
    instantiates `Dreamer`, and exposes simple `train()` / `evaluate()`
    methods.
    """

    def __init__(self, config=None, **kwargs):
        if config is None:
            self.args = DreamerConfig()
        else:
            self.args = config
        # Allow keyword overrides of any known config field (plus "logdir").
        for key, value in kwargs.items():
            if hasattr(self.args, key):
                setattr(self.args, key, value)
            elif key == "logdir":
                setattr(self.args, key, value)
            else:
                raise ValueError(f"Invalid argument: {key}")
        # Run artifacts live under <this module's directory>/data/.
        data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data/")
        if not (os.path.exists(data_path)):
            os.makedirs(data_path)
        if hasattr(self.args, "logdir") and self.args.logdir is not None:
            self.logdir = self.args.logdir
        else:
            # Default run name: <env>_<algo>_<exp_name>_<timestamp>.
            self.logdir = (
                self.args.env
                + "_"
                + self.args.algo
                + "_"
                + self.args.exp_name
                + "_"
                + time.strftime("%d-%m-%Y-%H-%M-%S")
            )
        self.logdir = os.path.join(data_path, self.logdir)
        if not (os.path.exists(self.logdir)):
            os.makedirs(self.logdir)
        # Seed all RNG sources for reproducibility.
        random.seed(self.args.seed)
        np.random.seed(self.args.seed)
        torch.manual_seed(self.args.seed)
        if torch.cuda.is_available() and not self.args.no_gpu:
            device = torch.device("cuda")
            torch.cuda.manual_seed(self.args.seed)
        else:
            device = torch.device("cpu")
        # Separate env instances for training rollouts and evaluation.
        self.train_env = make_env(self.args)
        self.test_env = make_env(self.args)
        obs_shape = self.train_env.observation_space["image"].shape
        action_size = self.train_env.action_space.shape[0]
        self.dreamer = Dreamer(
            self.args, obs_shape, action_size, device, self.args.restore
        )
        self.logger = Logger(self.logdir)
[docs] def train(self, total_steps=None): if total_steps is None: total_steps = self.args.total_steps initial_logs = OrderedDict() seed_episode_rews = self.dreamer.collect_random_episodes( self.train_env, self.args.seed_steps // self.args.action_repeat ) global_step = self.dreamer.data_buffer.steps * self.args.action_repeat # without loss of generality intial rews for both train and eval are assumed same initial_logs.update( { "train_avg_reward": np.mean(seed_episode_rews), "train_max_reward": np.max(seed_episode_rews), "train_min_reward": np.min(seed_episode_rews), "train_std_reward": np.std(seed_episode_rews), "eval_avg_reward": np.mean(seed_episode_rews), "eval_max_reward": np.max(seed_episode_rews), "eval_min_reward": np.min(seed_episode_rews), "eval_std_reward": np.std(seed_episode_rews), } ) self.logger.log_scalars(initial_logs, step=0) self.logger.flush() while global_step <= total_steps: print("##################################") print(f"At global step {global_step}") logs = OrderedDict() for _ in range(self.args.update_steps): model_loss, actor_loss, value_loss = self.dreamer.train_one_batch() train_rews = self.dreamer.act_and_collect_data( self.train_env, self.args.collect_steps // self.args.action_repeat ) logs.update( { "model_loss": model_loss, "actor_loss": actor_loss, "value_loss": value_loss, "train_avg_reward": np.mean(train_rews), "train_max_reward": np.max(train_rews), "train_min_reward": np.min(train_rews), "train_std_reward": np.std(train_rews), } ) if global_step % self.args.test_interval == 0: episode_rews, video_images = self.dreamer.evaluate( self.test_env, self.args.test_episodes ) logs.update( { "eval_avg_reward": np.mean(episode_rews), "eval_max_reward": np.max(episode_rews), "eval_min_reward": np.min(episode_rews), "eval_std_reward": np.std(episode_rews), } ) self.logger.log_scalars(logs, global_step) if ( global_step % self.args.log_video_freq == 0 and self.args.log_video_freq != -1 and len(video_images[0]) != 0 ): 
self.logger.log_video( video_images, global_step, self.args.max_videos_to_save ) if global_step % self.args.checkpoint_interval == 0: ckpt_dir = os.path.join(self.logdir, "ckpts/") if not (os.path.exists(ckpt_dir)): os.makedirs(ckpt_dir) self.dreamer.save(os.path.join(ckpt_dir, f"{global_step}_ckpt.pt")) global_step = self.dreamer.data_buffer.steps * self.args.action_repeat self.logger.flush()
    def evaluate(self):
        """Evaluate the trained agent, dump scalar metrics, and save videos.

        Runs ``args.test_episodes`` rendered episodes on the test env, writes
        summary statistics to a pickle, and logs the recorded frames.
        """
        logs = OrderedDict()
        episode_rews, video_images = self.dreamer.evaluate(
            self.test_env, self.args.test_episodes, render=True
        )
        logs.update(
            {
                "test_avg_reward": np.mean(episode_rews),
                "test_max_reward": np.max(episode_rews),
                "test_min_reward": np.min(episode_rews),
                "test_std_reward": np.std(episode_rews),
            }
        )
        self.logger.dump_scalars_to_pickle(logs, 0, log_title="test_scalars.pkl")
        # NOTE(review): train() calls `self.logger.log_video` while this calls
        # `log_videos` — confirm both methods exist on Logger.
        self.logger.log_videos(
            video_images, 0, max_videos_to_save=self.args.max_videos_to_save
        )