import os
import random
import time
import numpy as np
import ctypes
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as distributions
from collections import OrderedDict
import world_models.envs.wrappers as env_wrapper
from world_models.envs.dmc import DeepMindControlEnv
from world_models.envs.gym_env import GymImageEnv
from world_models.envs.mujoco_env import make_mujoco_env_from_config
from world_models.envs.robotics_env import make_robotics_env
from world_models.envs.brax_env import BraxImageEnv
from world_models.envs.unity_env import UnityMLAgentsEnv
from world_models.memory.dreamer_memory import ReplayBuffer
from world_models.models.dreamer_rssm import RSSM
from world_models.vision.dreamer_decoder import ConvDecoder, DenseDecoder, ActionDecoder
from world_models.vision.dreamer_encoder import ConvEncoder
from world_models.utils.dreamer_utils import (
Logger,
FreezeParameters,
compute_return,
symlog as _symlog,
symexp as _symexp,
TwoHotEncoder,
)
from world_models.configs.dreamer_config import DreamerConfig
# Only set MUJOCO_GL for non-Windows platforms. On Windows the 'egl' value
# causes mujoco to raise a RuntimeError during import. Respect an existing
# environment value if present.
if os.name != "nt" and os.environ.get("MUJOCO_GL") is None:
os.environ["MUJOCO_GL"] = "egl"
[docs]
def get_available_memory():
"""Get available physical memory in bytes."""
if os.name == "nt": # Windows
class MEMORYSTATUSEX(ctypes.Structure):
_fields_ = [
("dwLength", ctypes.c_ulong),
("dwMemoryLoad", ctypes.c_ulong),
("ullTotalPhys", ctypes.c_ulonglong),
("ullAvailPhys", ctypes.c_ulonglong),
("ullTotalPageFile", ctypes.c_ulonglong),
("ullAvailPageFile", ctypes.c_ulonglong),
("ullTotalVirtual", ctypes.c_ulonglong),
("ullAvailVirtual", ctypes.c_ulonglong),
("ullAvailExtendedVirtual", ctypes.c_ulonglong),
]
memory_status = MEMORYSTATUSEX()
memory_status.dwLength = ctypes.sizeof(MEMORYSTATUSEX)
kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
if not kernel32.GlobalMemoryStatusEx(ctypes.byref(memory_status)):
raise OSError("Failed to get memory status")
return memory_status.ullAvailPhys
else: # Linux/Mac
try:
import psutil
return psutil.virtual_memory().available
except ImportError:
# Fallback: read /proc/meminfo on Linux
try:
with open("/proc/meminfo", "r") as f:
for line in f:
if line.startswith("MemAvailable:"):
avail_kb = int(line.split()[1])
return avail_kb * 1024 # Convert KB to bytes
except (FileNotFoundError, ValueError, IndexError):
pass
# Ultimate fallback: assume 8GB available
return 8 * 1024 * 1024 * 1024
def _resolve_image_size(args):
size = getattr(args, "image_size", (64, 64))
if isinstance(size, int):
return (size, size)
if isinstance(size, (tuple, list)) and len(size) == 2:
return (int(size[0]), int(size[1]))
raise ValueError(f"Invalid image_size={size}. Expected int or (H, W).")
[docs]
def make_env(args):
"""Construct a Dreamer-compatible environment from `DreamerConfig` options.
Supports DMC, Gym/Gymnasium, MuJoCo, Gymnasium Robotics, Brax, and Unity ML-Agents backends
and applies the standard wrapper stack: action repeat, action normalization,
and time limit.
"""
size = _resolve_image_size(args)
backend = str(getattr(args, "env_backend", "dmc")).lower()
env_instance = getattr(args, "env_instance", None)
if env_instance is not None:
env = GymImageEnv(
env_instance,
seed=args.seed,
size=size,
render_mode=getattr(args, "gym_render_mode", "rgb_array"),
)
elif backend == "dmc":
env = DeepMindControlEnv(args.env, args.seed, size=size)
elif backend in {"gym", "gymnasium", "generic"}:
env = GymImageEnv(
args.env,
seed=args.seed,
size=size,
render_mode=getattr(args, "gym_render_mode", "rgb_array"),
)
elif backend in {"mujoco", "mjcf", "native_mujoco"}:
env = make_mujoco_env_from_config(args, size)
elif backend in {"robotics", "gymnasium_robotics"}:
env = make_robotics_env(
args.env,
seed=args.seed,
size=size,
render_mode=getattr(args, "gym_render_mode", "rgb_array"),
)
elif backend in {"brax", "jax_brax"}:
env = BraxImageEnv(
args.env,
seed=args.seed,
size=size,
backend=getattr(args, "brax_backend", "generalized"),
episode_length=int(getattr(args, "time_limit", 1000)),
auto_reset=bool(getattr(args, "brax_auto_reset", False)),
jit=bool(getattr(args, "brax_jit", True)),
suppress_warp_warnings=bool(
getattr(args, "brax_suppress_warp_warnings", True)
),
)
elif backend in {"unity", "unity_mlagents", "mlagents"}:
unity_file_name = getattr(args, "unity_file_name", None)
if not unity_file_name:
raise ValueError(
"unity_file_name must be provided when env_backend='unity_mlagents'."
)
env = UnityMLAgentsEnv(
file_name=unity_file_name,
behavior_name=getattr(args, "unity_behavior_name", None),
seed=args.seed,
size=size,
worker_id=int(getattr(args, "unity_worker_id", 0)),
base_port=int(getattr(args, "unity_base_port", 5005)),
no_graphics=bool(getattr(args, "unity_no_graphics", True)),
time_scale=float(getattr(args, "unity_time_scale", 20.0)),
quality_level=int(getattr(args, "unity_quality_level", 1)),
max_episode_steps=int(getattr(args, "time_limit", 1000)),
)
else:
raise ValueError(
f"Unknown env_backend='{backend}'. Use one of: dmc, gym, mujoco, robotics, brax, unity_mlagents."
)
env = env_wrapper.ActionRepeat(env, int(args.action_repeat))
env = env_wrapper.NormalizeActions(env)
repeat = max(1, int(args.action_repeat))
duration = max(1, int(args.time_limit) // repeat)
env = env_wrapper.TimeLimit(env, duration)
return env
[docs]
def preprocess_obs(obs):
"""Convert raw uint8 image observations to Dreamer float input space.
Images are scaled from `[0, 255]` to roughly `[-0.5, 0.5]`, matching the
normalization expected by Dreamer encoders.
"""
obs = obs.to(torch.float32) / 255.0 - 0.5
return obs
[docs]
class Dreamer:
"""Core Dreamer training system combining world model, actor, and value nets.
This class owns model construction, replay sampling, imagination rollouts,
loss computation, optimization steps, evaluation loops, and checkpoint I/O.
"""
def __init__(self, args, obs_shape, action_size, device, restore=False):
self.args = args
self.obs_shape = obs_shape
self.action_size = action_size
self.device = device
self.restore = args.restore
self.restore_path = args.checkpoint_path
# Calculate memory per sample
obs_bytes = np.prod(obs_shape) * 1 # uint8
action_bytes = action_size * 4 # float32
reward_bytes = 4 # float32
terminal_bytes = 4 # float32
bytes_per_sample = obs_bytes + action_bytes + reward_bytes + terminal_bytes
# Get available memory
available_memory = get_available_memory()
# Use 80% of available memory for buffer to leave margin for other processes
max_buffer_size = int((available_memory * 0.8) // bytes_per_sample)
adaptive_buffer_size = min(args.buffer_size, max_buffer_size)
if adaptive_buffer_size < args.buffer_size:
print(
f"Reducing buffer size from {args.buffer_size} to {adaptive_buffer_size} due to memory constraints."
)
self.data_buffer = ReplayBuffer(
adaptive_buffer_size,
self.obs_shape,
self.action_size,
self.args.train_seq_len,
self.args.batch_size,
)
self._build_model(restore=self.restore)
def _build_model(self, restore):
self.rssm = RSSM(
action_size=self.action_size,
stoch_size=self.args.stoch_size,
deter_size=self.args.deter_size,
hidden_size=self.args.deter_size,
obs_embed_size=self.args.obs_embed_size,
activation=self.args.dense_activation_function,
).to(self.device)
self.actor = ActionDecoder(
action_size=self.action_size,
stoch_size=self.args.stoch_size,
deter_size=self.args.deter_size,
units=self.args.num_units,
n_layers=4,
activation=self.args.dense_activation_function,
).to(self.device)
self.obs_encoder = ConvEncoder(
input_shape=self.obs_shape,
embed_size=self.args.obs_embed_size,
activation=self.args.cnn_activation_function,
).to(self.device)
self.obs_decoder = ConvDecoder(
stoch_size=self.args.stoch_size,
deter_size=self.args.deter_size,
output_shape=self.obs_shape,
activation=self.args.cnn_activation_function,
).to(self.device)
head_dist = "symlog_twohot" if self.args.algo == "Dreamerv2" else "normal"
head_kwargs = {}
if head_dist == "symlog_twohot":
head_kwargs = {
"num_buckets": getattr(self.args, "num_buckets", 255),
"symlog_range": getattr(self.args, "symlog_range", 10.0),
}
self.reward_model = DenseDecoder(
stoch_size=self.args.stoch_size,
deter_size=self.args.deter_size,
output_shape=(1,),
n_layers=2,
units=self.args.num_units,
activation=self.args.dense_activation_function,
dist=head_dist,
**head_kwargs,
).to(self.device)
self.value_model = DenseDecoder(
stoch_size=self.args.stoch_size,
deter_size=self.args.deter_size,
output_shape=(1,),
n_layers=3,
units=self.args.num_units,
activation=self.args.dense_activation_function,
dist=head_dist,
**head_kwargs,
).to(self.device)
if self.args.use_disc_model:
self.discount_model = DenseDecoder(
stoch_size=self.args.stoch_size,
deter_size=self.args.deter_size,
output_shape=(1,),
n_layers=2,
units=self.args.num_units,
activation=self.args.dense_activation_function,
dist="binary",
).to(self.device)
if self.args.use_disc_model:
self.world_model_params = (
list(self.rssm.parameters())
+ list(self.obs_encoder.parameters())
+ list(self.obs_decoder.parameters())
+ list(self.reward_model.parameters())
+ list(self.discount_model.parameters())
)
else:
self.world_model_params = (
list(self.rssm.parameters())
+ list(self.obs_encoder.parameters())
+ list(self.obs_decoder.parameters())
+ list(self.reward_model.parameters())
)
self.world_model_opt = optim.Adam(
self.world_model_params, self.args.model_learning_rate
)
self.value_opt = optim.Adam(
self.value_model.parameters(), self.args.value_learning_rate
)
self.actor_opt = optim.Adam(
self.actor.parameters(), self.args.actor_learning_rate
)
if self.args.use_disc_model:
self.world_model_modules = [
self.rssm,
self.obs_encoder,
self.obs_decoder,
self.reward_model,
self.discount_model,
]
else:
self.world_model_modules = [
self.rssm,
self.obs_encoder,
self.obs_decoder,
self.reward_model,
]
self.value_modules = [self.value_model]
self.actor_modules = [self.actor]
if restore:
self.restore_checkpoint(self.restore_path)
[docs]
def world_model_loss(self, obs, acs, rews, nonterms):
obs = preprocess_obs(obs)
obs_embed = self.obs_encoder(obs[1:])
init_state = self.rssm.init_state(self.args.batch_size, self.device)
prior, self.posterior = self.rssm.observe_rollout(
obs_embed, acs[:-1], nonterms[:-1], init_state, self.args.train_seq_len - 1
)
features = torch.cat([self.posterior["stoch"], self.posterior["deter"]], dim=-1)
rew_dist = self.reward_model(features)
obs_dist = self.obs_decoder(features)
disc_dist = None
if self.args.use_disc_model:
disc_dist = self.discount_model(features)
prior_dist = self.rssm.get_dist(prior["mean"], prior["std"])
post_dist = self.rssm.get_dist(self.posterior["mean"], self.posterior["std"])
if self.args.algo == "Dreamerv2":
post_no_grad = self.rssm.detach_state(self.posterior)
prior_no_grad = self.rssm.detach_state(prior)
post_mean_no_grad, post_std_no_grad = (
post_no_grad["mean"],
post_no_grad["std"],
)
prior_mean_no_grad, prior_std_no_grad = (
prior_no_grad["mean"],
prior_no_grad["std"],
)
kl_loss = self.args.kl_alpha * (
torch.mean(
distributions.kl.kl_divergence(
self.rssm.get_dist(post_mean_no_grad, post_std_no_grad),
prior_dist,
)
)
)
kl_loss += (1 - self.args.kl_alpha) * (
torch.mean(
distributions.kl.kl_divergence(
post_dist,
self.rssm.get_dist(prior_mean_no_grad, prior_std_no_grad),
)
)
)
else:
kl_loss = torch.mean(distributions.kl.kl_divergence(post_dist, prior_dist))
kl_loss = torch.max(
kl_loss, kl_loss.new_full(kl_loss.size(), self.args.free_nats)
)
if self.args.algo == "Dreamerv2":
target_rew = _symlog(rews[:-1])
else:
target_rew = rews[:-1]
obs_loss = -torch.mean(obs_dist.log_prob(obs[1:]))
rew_loss = -torch.mean(rew_dist.log_prob(target_rew))
model_loss = self.args.kl_loss_coeff * kl_loss + obs_loss + rew_loss
if self.args.use_disc_model and disc_dist is not None:
disc_loss = -torch.mean(disc_dist.log_prob(nonterms[:-1]))
model_loss = model_loss + self.args.disc_loss_coeff * disc_loss
return model_loss
[docs]
def actor_loss(self):
with torch.no_grad():
posterior = self.rssm.detach_state(self.rssm.seq_to_batch(self.posterior))
with FreezeParameters(self.world_model_modules):
imag_states = self.rssm.imagine_rollout(
self.actor, posterior, self.args.imagine_horizon
)
self.imag_feat = torch.cat([imag_states["stoch"], imag_states["deter"]], dim=-1)
with FreezeParameters(self.world_model_modules + self.value_modules):
imag_rew_dist = self.reward_model(self.imag_feat)
imag_val_dist = self.value_model(self.imag_feat)
imag_rews = imag_rew_dist.mean
imag_vals = imag_val_dist.mean
if self.args.use_disc_model:
imag_disc_dist = self.discount_model(self.imag_feat)
discounts = imag_disc_dist.mean().detach()
else:
discounts = self.args.discount * torch.ones_like(imag_rews).detach()
self.returns = compute_return(
imag_rews[:-1],
imag_vals[:-1],
discounts[:-1],
self.args.td_lambda,
imag_vals[-1],
)
discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
self.discounts = torch.cumprod(discounts, 0).detach()
if self.args.algo == "Dreamerv2":
weight = self.discounts.detach()
target = _symlog(self.returns.detach())
actor_loss = -torch.mean(weight * target)
else:
actor_loss = -torch.mean(self.discounts * self.returns)
return actor_loss
[docs]
def value_loss(self):
with torch.no_grad():
value_feat = self.imag_feat[:-1].detach()
value_targ = self.returns.detach()
value_dist = self.value_model(value_feat)
if self.args.algo == "Dreamerv2":
target = _symlog(value_targ)
log_prob = value_dist.log_prob(target)
else:
log_prob = value_dist.log_prob(value_targ).unsqueeze(-1)
value_loss = -torch.mean(self.discounts * log_prob)
return value_loss
[docs]
def train_one_batch(self):
obs, acs, rews, terms = self.data_buffer.sample()
obs = torch.tensor(obs, dtype=torch.float32).to(self.device)
acs = torch.tensor(acs, dtype=torch.float32).to(self.device)
rews = torch.tensor(rews, dtype=torch.float32).to(self.device).unsqueeze(-1)
nonterms = (
torch.tensor((1.0 - terms), dtype=torch.float32)
.to(self.device)
.unsqueeze(-1)
)
model_loss = self.world_model_loss(obs, acs, rews, nonterms)
self.world_model_opt.zero_grad()
model_loss.backward()
nn.utils.clip_grad_norm_(self.world_model_params, self.args.grad_clip_norm)
self.world_model_opt.step()
actor_loss = self.actor_loss()
self.actor_opt.zero_grad()
actor_loss.backward()
nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.grad_clip_norm)
self.actor_opt.step()
value_loss = self.value_loss()
self.value_opt.zero_grad()
value_loss.backward()
nn.utils.clip_grad_norm_(
self.value_model.parameters(), self.args.grad_clip_norm
)
self.value_opt.step()
return model_loss.item(), actor_loss.item(), value_loss.item()
[docs]
def act_with_world_model(self, obs, prev_state, prev_action, explore=False):
obs = obs["image"]
obs = torch.tensor(obs.copy(), dtype=torch.float32).to(self.device).unsqueeze(0)
obs_embed = self.obs_encoder(preprocess_obs(obs))
_, posterior = self.rssm.observe_step(prev_state, prev_action, obs_embed)
features = torch.cat([posterior["stoch"], posterior["deter"]], dim=-1)
action = self.actor(features, deter=not explore)
if explore:
action = self.actor.add_exploration(action, self.args.action_noise)
return posterior, action
[docs]
def act_and_collect_data(self, env, collect_steps):
obs = env.reset()
done = False
prev_state = self.rssm.init_state(1, self.device)
prev_action = torch.zeros(1, self.action_size).to(self.device)
episode_rewards = [0.0]
for i in range(collect_steps):
with torch.no_grad():
posterior, action = self.act_with_world_model(
obs, prev_state, prev_action, explore=True
)
action = action[0].cpu().numpy()
next_obs, rew, done, info = env.step(action)
executed_action = (
info["action"]
if isinstance(info, dict) and ("action" in info)
else action
)
self.data_buffer.add(obs, executed_action, rew, done)
episode_rewards[-1] += rew
if done:
obs = env.reset()
done = False
prev_state = self.rssm.init_state(1, self.device)
prev_action = torch.zeros(1, self.action_size).to(self.device)
if i != collect_steps - 1:
episode_rewards.append(0.0)
else:
obs = next_obs
prev_state = posterior
prev_action = (
torch.tensor(executed_action, dtype=torch.float32)
.to(self.device)
.unsqueeze(0)
)
return np.array(episode_rewards)
[docs]
def evaluate(self, env, eval_episodes, render=False):
episode_rew = np.zeros((eval_episodes))
video_images = [[] for _ in range(eval_episodes)]
latents = [] if render else None
for i in range(eval_episodes):
obs = env.reset()
done = False
prev_state = self.rssm.init_state(1, self.device)
prev_action = torch.zeros(1, self.action_size).to(self.device)
while not done:
with torch.no_grad():
posterior, action = self.act_with_world_model(
obs, prev_state, prev_action
)
action = action[0].cpu().numpy()
next_obs, rew, done, info = env.step(action)
executed_action = (
info["action"]
if isinstance(info, dict) and ("action" in info)
else action
)
prev_state = posterior
prev_action = (
torch.tensor(executed_action, dtype=torch.float32)
.to(self.device)
.unsqueeze(0)
)
episode_rew[i] += rew
if render:
video_images[i].append(obs["image"].transpose(1, 2, 0).copy())
if latents is not None:
latents.append(
torch.cat([posterior[0], posterior[1]], dim=-1)
.cpu()
.numpy()
)
obs = next_obs
if latents is not None and len(latents) > 0:
latents = np.array(latents)
return (
episode_rew,
np.array(video_images[: self.args.max_videos_to_save]),
latents,
)
[docs]
def collect_random_episodes(self, env, seed_steps):
obs = env.reset()
done = False
seed_episode_rews = [0.0]
for i in range(seed_steps):
action = env.action_space.sample()
next_obs, rew, done, info = env.step(action)
executed_action = (
info["action"]
if isinstance(info, dict) and ("action" in info)
else action
)
self.data_buffer.add(obs, executed_action, rew, done)
seed_episode_rews[-1] += rew
if done:
obs = env.reset()
if i != seed_steps - 1:
seed_episode_rews.append(0.0)
done = False
else:
obs = next_obs
return np.array(seed_episode_rews)
[docs]
def save(self, save_path):
torch.save(
{
"rssm": self.rssm.state_dict(),
"actor": self.actor.state_dict(),
"reward_model": self.reward_model.state_dict(),
"obs_encoder": self.obs_encoder.state_dict(),
"obs_decoder": self.obs_decoder.state_dict(),
"discount_model": (
self.discount_model.state_dict()
if self.args.use_disc_model
else None
),
"actor_optimizer": self.actor_opt.state_dict(),
"value_optimizer": self.value_opt.state_dict(),
"world_model_optimizer": self.world_model_opt.state_dict(),
},
save_path,
)
[docs]
def restore_checkpoint(self, ckpt_path):
checkpoint = torch.load(ckpt_path)
self.rssm.load_state_dict(checkpoint["rssm"])
self.actor.load_state_dict(checkpoint["actor"])
self.reward_model.load_state_dict(checkpoint["reward_model"])
self.obs_encoder.load_state_dict(checkpoint["obs_encoder"])
self.obs_decoder.load_state_dict(checkpoint["obs_decoder"])
if self.args.use_disc_model and (checkpoint["discount_model"] is not None):
self.discount_model.load_state_dict(checkpoint["discount_model"])
self.world_model_opt.load_state_dict(checkpoint["world_model_optimizer"])
self.actor_opt.load_state_dict(checkpoint["actor_optimizer"])
self.value_opt.load_state_dict(checkpoint["value_optimizer"])
[docs]
class DreamerAgent:
"""High-level user API for running Dreamer experiments end to end.
It builds environments from config, initializes seeds and logging,
instantiates `Dreamer`, and exposes simple `train()` / `evaluate()` methods.
"""
def __init__(self, config=None, **kwargs):
if config is None:
self.args = DreamerConfig()
else:
self.args = config
self.last_latents_ref = kwargs.get("last_latents_ref", None)
for key, value in kwargs.items():
if hasattr(self.args, key):
setattr(self.args, key, value)
elif key == "logdir":
# Accept either `logdir` (server/legacy) and mirror into
# both `logdir` and `log_dir` so downstream code using either
# naming convention picks up the value.
setattr(self.args, "logdir", value)
try:
setattr(self.args, "log_dir", value)
except Exception:
pass
elif key == "last_latents_ref":
self.last_latents_ref = value
else:
raise ValueError(f"Invalid argument: {key}")
data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data/")
if not (os.path.exists(data_path)):
os.makedirs(data_path)
# Allow caller to pass an absolute `logdir`. If provided and absolute,
# use it verbatim. Otherwise keep the historical behavior of creating
# a subdir under the package data path so examples/tests remain stable.
if hasattr(self.args, "logdir") and self.args.logdir is not None:
self.logdir = self.args.logdir
else:
self.logdir = (
self.args.env
+ "_"
+ self.args.algo
+ "_"
+ self.args.exp_name
+ "_"
+ time.strftime("%d-%m-%Y-%H-%M-%S")
)
# If `self.logdir` is not an absolute path, place it under package data_path
if not os.path.isabs(self.logdir):
self.logdir = os.path.join(data_path, self.logdir)
if not (os.path.exists(self.logdir)):
os.makedirs(self.logdir)
random.seed(self.args.seed)
np.random.seed(self.args.seed)
torch.manual_seed(self.args.seed)
if torch.cuda.is_available() and not self.args.no_gpu:
device = torch.device("cuda")
torch.cuda.manual_seed(self.args.seed)
else:
device = torch.device("cpu")
self.train_env = make_env(self.args)
self.test_env = make_env(self.args)
obs_shape = self.train_env.observation_space["image"].shape
action_size = self.train_env.action_space.shape[0]
self.dreamer = Dreamer(
self.args, obs_shape, action_size, device, self.args.restore
)
self.logger = Logger(
self.logdir,
self.args.enable_wandb,
self.args.wandb_api_key,
self.args.wandb_project,
self.args.wandb_entity,
self.args.video_format,
self.args.video_fps,
)
[docs]
def train(self, total_steps=None):
if total_steps is None:
total_steps = self.args.total_steps
initial_logs = OrderedDict()
seed_episode_rews = self.dreamer.collect_random_episodes(
self.train_env, self.args.seed_steps // self.args.action_repeat
)
global_step = self.dreamer.data_buffer.steps * self.args.action_repeat
# without loss of generality intial rews for both train and eval are assumed same
initial_logs.update(
{
"train_avg_reward": np.mean(seed_episode_rews),
"train_max_reward": np.max(seed_episode_rews),
"train_min_reward": np.min(seed_episode_rews),
"train_std_reward": np.std(seed_episode_rews),
"eval_avg_reward": np.mean(seed_episode_rews),
"eval_max_reward": np.max(seed_episode_rews),
"eval_min_reward": np.min(seed_episode_rews),
"eval_std_reward": np.std(seed_episode_rews),
}
)
self.logger.log_scalars(initial_logs, step=0)
self.logger.flush()
while global_step <= total_steps:
print("##################################")
print(f"At global step {global_step}")
logs = OrderedDict()
for _ in range(self.args.update_steps):
model_loss, actor_loss, value_loss = self.dreamer.train_one_batch()
train_rews = self.dreamer.act_and_collect_data(
self.train_env, self.args.collect_steps // self.args.action_repeat
)
logs.update(
{
"model_loss": model_loss,
"actor_loss": actor_loss,
"value_loss": value_loss,
"train_avg_reward": np.mean(train_rews),
"train_max_reward": np.max(train_rews),
"train_min_reward": np.min(train_rews),
"train_std_reward": np.std(train_rews),
}
)
if global_step % self.args.test_interval == 0:
episode_rews, video_images, latents = self.dreamer.evaluate(
self.test_env, self.args.test_episodes
)
if self.last_latents_ref is not None and latents is not None:
self.last_latents_ref[0] = latents
logs.update(
{
"eval_avg_reward": np.mean(episode_rews),
"eval_max_reward": np.max(episode_rews),
"eval_min_reward": np.min(episode_rews),
"eval_std_reward": np.std(episode_rews),
}
)
self.logger.log_scalars(logs, global_step)
if (
global_step % self.args.log_video_freq == 0
and self.args.log_video_freq != -1
and len(video_images[0]) != 0
):
self.logger.log_video(
video_images, global_step, self.args.max_videos_to_save
)
if global_step % self.args.checkpoint_interval == 0:
ckpt_dir = os.path.join(self.logdir, "ckpts/")
if not (os.path.exists(ckpt_dir)):
os.makedirs(ckpt_dir)
self.dreamer.save(os.path.join(ckpt_dir, f"{global_step}_ckpt.pt"))
global_step = self.dreamer.data_buffer.steps * self.args.action_repeat
self.logger.flush()
[docs]
def evaluate(self):
logs = OrderedDict()
episode_rews, video_images, latents = self.dreamer.evaluate(
self.test_env, self.args.test_episodes, render=True
)
if self.last_latents_ref is not None and latents is not None:
self.last_latents_ref[0] = latents
logs.update(
{
"test_avg_reward": np.mean(episode_rews),
"test_max_reward": np.max(episode_rews),
"test_min_reward": np.min(episode_rews),
"test_std_reward": np.std(episode_rews),
}
)
self.logger.dump_scalars_to_pickle(logs, 0, log_title="test_scalars.pkl")
self.logger.log_videos(
video_images, 0, max_videos_to_save=self.args.max_videos_to_save
)
self.logger.flush()
return episode_rews, video_images, latents