import os
import random
import sys
import time
import importlib.util
import numpy as np
import ctypes
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as distributions
from collections import OrderedDict
from pathlib import Path
from typing import Dict, Union
from typing import Any
import world_models.envs.wrappers as env_wrapper
from world_models.envs.dmc import DeepMindControlEnv
from world_models.envs.gym_env import GymImageEnv
from world_models.envs.mujoco_env import make_mujoco_env_from_config
from world_models.envs.procgen_env import ProcgenImageEnv
from world_models.envs.robotics_env import make_robotics_env
from world_models.envs.brax_env import BraxImageEnv
from world_models.envs.dmlab import DMLabEnv
from world_models.envs.bsuite_env import BSuiteImageEnv
from world_models.envs.unity_env import UnityMLAgentsEnv
from world_models.memory.dreamer_memory import ReplayBuffer
from world_models.models.dreamer_rssm import RSSM
from world_models.vision.dreamer_decoder import ConvDecoder, DenseDecoder, ActionDecoder
from world_models.vision.dreamer_encoder import ConvEncoder
from world_models.utils.dreamer_utils import (
Logger,
FreezeParameters,
compute_return,
symlog as _symlog,
symexp as _symexp,
TwoHotEncoder,
)
from world_models.configs.dreamer_config import DreamerConfig
from world_models.export import ExportableAgentMixin
from world_models.utils.logging_utils import (
assert_finite,
collect_system_stats,
get_package_logger,
setup_logging,
)
logger = get_package_logger(__name__)
# Only set MUJOCO_GL for non-Windows platforms. On Windows the 'egl' value
# causes mujoco to raise a RuntimeError during import. Respect an existing
# environment value if present.
if os.name != "nt" and os.environ.get("MUJOCO_GL") is None:
os.environ["MUJOCO_GL"] = "egl"
[docs]
def get_available_memory() -> int:
"""Get available physical memory in bytes."""
if os.name == "nt": # Windows
class MEMORYSTATUSEX(ctypes.Structure):
_fields_ = [
("dwLength", ctypes.c_ulong),
("dwMemoryLoad", ctypes.c_ulong),
("ullTotalPhys", ctypes.c_ulonglong),
("ullAvailPhys", ctypes.c_ulonglong),
("ullTotalPageFile", ctypes.c_ulonglong),
("ullAvailPageFile", ctypes.c_ulonglong),
("ullTotalVirtual", ctypes.c_ulonglong),
("ullAvailVirtual", ctypes.c_ulonglong),
("ullAvailExtendedVirtual", ctypes.c_ulonglong),
]
memory_status = MEMORYSTATUSEX()
memory_status.dwLength = ctypes.sizeof(MEMORYSTATUSEX)
kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined]
if not kernel32.GlobalMemoryStatusEx(ctypes.byref(memory_status)):
raise OSError("Failed to get memory status")
return memory_status.ullAvailPhys
else: # Linux/Mac
try:
import psutil
return psutil.virtual_memory().available
except ImportError:
# Fallback: read /proc/meminfo on Linux
try:
with open("/proc/meminfo", "r") as f:
for line in f:
if line.startswith("MemAvailable:"):
avail_kb = int(line.split()[1])
return avail_kb * 1024 # Convert KB to bytes
except (FileNotFoundError, ValueError, IndexError):
pass
# Ultimate fallback: assume 8GB available
return 8 * 1024 * 1024 * 1024
def _resolve_image_size(args: Any) -> tuple:
size = getattr(args, "image_size", (64, 64))
if isinstance(size, int):
return (size, size)
if isinstance(size, (tuple, list)) and len(size) == 2:
return (int(size[0]), int(size[1]))
raise ValueError(f"Invalid image_size={size}. Expected int or (H, W).")
[docs]
def make_env(args: Any) -> Any:
"""Construct a Dreamer-compatible environment from `DreamerConfig` options.
Supports DMC, DMLab, Gym/Gymnasium, MuJoCo, Gymnasium Robotics, Procgen,
Brax, BSuite, and Unity ML-Agents backends
and applies the standard wrapper stack: action repeat, action normalization,
and time limit.
"""
size = _resolve_image_size(args)
backend = str(getattr(args, "env_backend", "dmc")).lower()
env_instance = getattr(args, "env_instance", None)
if env_instance is not None:
env: Any = GymImageEnv(
env_instance,
seed=args.seed,
size=size,
render_mode=getattr(args, "gym_render_mode", "rgb_array"),
)
elif backend == "dmc":
env = DeepMindControlEnv(args.env, args.seed, size=size)
elif backend in {"dmlab", "deepmind_lab", "deepmindlab"}:
env = DMLabEnv(
args.env,
seed=args.seed,
size=size,
action_repeat=int(getattr(args, "dmlab_action_repeat", 4)),
action_set=getattr(args, "dmlab_action_set", None),
observations=getattr(args, "dmlab_observations", None),
config=getattr(args, "dmlab_config", None),
renderer=getattr(args, "dmlab_renderer", "hardware"),
)
elif backend in {"gym", "gymnasium", "generic"}:
env = GymImageEnv(
args.env,
seed=args.seed,
size=size,
render_mode=getattr(args, "gym_render_mode", "rgb_array"),
)
elif backend in {"mujoco", "mjcf", "native_mujoco"}:
env = make_mujoco_env_from_config(args, size)
elif backend in {"procgen", "coinrun"}:
env = ProcgenImageEnv(
args.env,
seed=args.seed,
size=size,
distribution_mode=getattr(args, "procgen_distribution_mode", "easy"),
num_levels=int(getattr(args, "procgen_num_levels", 0)),
start_level=getattr(args, "procgen_start_level", None),
max_episode_steps=int(getattr(args, "time_limit", 1000)),
)
elif backend in {"robotics", "gymnasium_robotics"}:
env = make_robotics_env(
args.env,
seed=args.seed,
size=size,
render_mode=getattr(args, "gym_render_mode", "rgb_array"),
)
elif backend in {"bsuite", "behavior_suite", "behaviour_suite"}:
env = BSuiteImageEnv(
args.env,
seed=args.seed,
size=size,
)
elif backend in {"brax", "jax_brax"}:
env = BraxImageEnv(
args.env,
seed=args.seed,
size=size,
backend=getattr(args, "brax_backend", "generalized"),
episode_length=int(getattr(args, "time_limit", 1000)),
auto_reset=bool(getattr(args, "brax_auto_reset", False)),
jit=bool(getattr(args, "brax_jit", True)),
suppress_warp_warnings=bool(
getattr(args, "brax_suppress_warp_warnings", True)
),
)
elif backend in {"unity", "unity_mlagents", "mlagents"}:
unity_file_name = getattr(args, "unity_file_name", None)
if not unity_file_name:
raise ValueError(
"unity_file_name must be provided when env_backend='unity_mlagents'."
)
env = UnityMLAgentsEnv( # type: ignore[no-untyped-call]
file_name=unity_file_name,
behavior_name=getattr(args, "unity_behavior_name", None),
seed=args.seed,
size=size,
worker_id=int(getattr(args, "unity_worker_id", 0)),
base_port=int(getattr(args, "unity_base_port", 5005)),
no_graphics=bool(getattr(args, "unity_no_graphics", True)),
time_scale=float(getattr(args, "unity_time_scale", 20.0)),
quality_level=int(getattr(args, "unity_quality_level", 1)),
max_episode_steps=int(getattr(args, "time_limit", 1000)),
)
else:
raise ValueError(
f"Unknown env_backend='{backend}'. Use one of: dmc, dmlab, gym, mujoco, "
"robotics, procgen, bsuite, brax, unity_mlagents."
)
env = env_wrapper.ActionRepeat(env, int(args.action_repeat))
env = env_wrapper.NormalizeActions(env)
repeat = max(1, int(args.action_repeat))
duration = max(1, int(args.time_limit) // repeat)
env = env_wrapper.TimeLimit(env, duration)
return env
def _coerce_dreamer_config(config: Any | None) -> DreamerConfig:
"""Normalize Dreamer config inputs to a ``DreamerConfig`` instance."""
if config is None:
return DreamerConfig()
if isinstance(config, DreamerConfig):
return config
if isinstance(config, dict):
return DreamerConfig.from_dict(config)
if isinstance(config, (str, Path)):
return DreamerConfig.from_yaml(config)
raise TypeError(
"config must be a DreamerConfig, dict, YAML path/string, or None; "
f"got {type(config).__name__}."
)
def _apply_config_overrides(
config: DreamerConfig, overrides: dict[str, Any]
) -> DreamerConfig:
for key, value in overrides.items():
if not hasattr(config, key):
raise ValueError(f"Invalid DreamerConfig override: {key}")
setattr(config, key, value)
return config
def _default_device(config: DreamerConfig) -> torch.device:
if torch.cuda.is_available() and not config.no_gpu:
return torch.device("cuda")
return torch.device("cpu")
def _find_local_pretrained_file(path: Path, candidates: tuple[str, ...]) -> Path | None:
if path.is_file():
return path
if path.is_dir():
for name in candidates:
candidate = path / name
if candidate.exists():
return candidate
for pattern in ("*.pt", "*.pth", "*.bin"):
matches = sorted(path.glob(pattern))
if matches:
return matches[0]
return None
def _resolve_pretrained_file(
pretrained_model_name_or_path: str | Path,
candidates: tuple[str, ...],
*,
repo_type: str | None = None,
revision: str | None = None,
) -> Path | None:
local_path = Path(pretrained_model_name_or_path)
local_file = _find_local_pretrained_file(local_path, candidates)
if local_file is not None:
return local_file
if importlib.util.find_spec("huggingface_hub") is None:
return None
from huggingface_hub import hf_hub_download
repo_id = str(pretrained_model_name_or_path)
for filename in candidates:
try:
downloaded = hf_hub_download(
repo_id=repo_id,
filename=filename,
repo_type=repo_type,
revision=revision,
)
except Exception:
continue
return Path(downloaded)
return None
def _save_config_next_to_checkpoint(
config: DreamerConfig, checkpoint_path: str | Path
) -> None:
checkpoint = Path(checkpoint_path)
checkpoint.parent.mkdir(parents=True, exist_ok=True)
config.to_yaml(checkpoint.parent / "config.yaml")
[docs]
def preprocess_obs(obs: torch.Tensor) -> torch.Tensor:
"""Convert raw uint8 image observations to Dreamer float input space.
Images are scaled from `[0, 255]` to roughly `[-0.5, 0.5]`, matching the
normalization expected by Dreamer encoders.
"""
obs = obs.to(torch.float32) / 255.0 - 0.5
return obs
[docs]
class Dreamer:
"""Core Dreamer training system combining world model, actor, and value nets.
This class owns model construction, replay sampling, imagination rollouts,
loss computation, optimization steps, evaluation loops, and checkpoint I/O.
"""
def __init__(
self,
args: Any,
obs_shape: Any,
action_size: int,
device: torch.device | str,
restore: bool = False,
) -> None:
self.args = args
self.obs_shape = obs_shape
self.action_size = action_size
self.device = device
self.restore = args.restore
self.restore_path = args.checkpoint_path
# Calculate memory per sample
obs_bytes = np.prod(obs_shape) * 1 # uint8
action_bytes = action_size * 4 # float32
reward_bytes = 4 # float32
terminal_bytes = 4 # float32
bytes_per_sample = obs_bytes + action_bytes + reward_bytes + terminal_bytes
# Get available memory
available_memory = get_available_memory()
# Use 80% of available memory for buffer to leave margin for other processes
max_buffer_size = int((available_memory * 0.8) // bytes_per_sample)
adaptive_buffer_size = min(args.buffer_size, max_buffer_size)
if adaptive_buffer_size < args.buffer_size:
logger.warning(
"Reducing buffer size from %s to %s due to memory constraints.",
args.buffer_size,
adaptive_buffer_size,
)
self.data_buffer = ReplayBuffer(
adaptive_buffer_size,
self.obs_shape,
self.action_size,
self.args.train_seq_len,
self.args.batch_size,
)
self.use_amp = bool(
getattr(args, "use_amp", True)
and getattr(device, "type", str(device)) == "cuda"
)
self._build_model(restore=self.restore)
def _build_model(self, restore: bool = False) -> None:
self.rssm = RSSM(
action_size=self.action_size,
stoch_size=self.args.stoch_size,
deter_size=self.args.deter_size,
hidden_size=self.args.deter_size,
obs_embed_size=self.args.obs_embed_size,
activation=self.args.dense_activation_function,
).to(self.device)
self.actor = ActionDecoder(
action_size=self.action_size,
stoch_size=self.args.stoch_size,
deter_size=self.args.deter_size,
units=self.args.num_units,
n_layers=4,
activation=self.args.dense_activation_function,
).to(self.device)
self.obs_encoder = ConvEncoder(
input_shape=self.obs_shape,
embed_size=self.args.obs_embed_size,
activation=self.args.cnn_activation_function,
).to(self.device)
self.obs_decoder = ConvDecoder(
stoch_size=self.args.stoch_size,
deter_size=self.args.deter_size,
output_shape=self.obs_shape,
activation=self.args.cnn_activation_function,
).to(self.device)
head_dist = "symlog_twohot" if self.args.algo == "Dreamerv2" else "normal"
head_kwargs: Dict[str, Any] = {}
if head_dist == "symlog_twohot":
head_kwargs = {
"num_buckets": getattr(self.args, "num_buckets", 255),
"symlog_range": getattr(self.args, "symlog_range", 10.0),
}
self.reward_model = DenseDecoder(
stoch_size=self.args.stoch_size,
deter_size=self.args.deter_size,
output_shape=(1,),
n_layers=2,
units=self.args.num_units,
activation=self.args.dense_activation_function,
dist=head_dist,
**head_kwargs,
).to(self.device)
self.value_model = DenseDecoder(
stoch_size=self.args.stoch_size,
deter_size=self.args.deter_size,
output_shape=(1,),
n_layers=3,
units=self.args.num_units,
activation=self.args.dense_activation_function,
dist=head_dist,
**head_kwargs,
).to(self.device)
if self.args.use_disc_model:
self.discount_model = DenseDecoder(
stoch_size=self.args.stoch_size,
deter_size=self.args.deter_size,
output_shape=(1,),
n_layers=2,
units=self.args.num_units,
activation=self.args.dense_activation_function,
dist="binary",
).to(self.device)
if self.args.use_disc_model:
self.world_model_params = (
list(self.rssm.parameters())
+ list(self.obs_encoder.parameters())
+ list(self.obs_decoder.parameters())
+ list(self.reward_model.parameters())
+ list(self.discount_model.parameters())
)
else:
self.world_model_params = (
list(self.rssm.parameters())
+ list(self.obs_encoder.parameters())
+ list(self.obs_decoder.parameters())
+ list(self.reward_model.parameters())
)
self.world_model_opt = optim.Adam(
self.world_model_params, self.args.model_learning_rate
)
self.value_opt = optim.Adam(
self.value_model.parameters(), self.args.value_learning_rate
)
self.actor_opt = optim.Adam(
self.actor.parameters(), self.args.actor_learning_rate
)
self.world_model_scaler = torch.amp.GradScaler("cuda", enabled=self.use_amp)
self.actor_scaler = torch.amp.GradScaler("cuda", enabled=self.use_amp)
self.value_scaler = torch.amp.GradScaler("cuda", enabled=self.use_amp)
if self.args.use_disc_model:
self.world_model_modules = [
self.rssm,
self.obs_encoder,
self.obs_decoder,
self.reward_model,
self.discount_model,
]
else:
self.world_model_modules = [
self.rssm,
self.obs_encoder,
self.obs_decoder,
self.reward_model,
]
self.value_modules = [self.value_model]
self.actor_modules = [self.actor]
if restore:
self.restore_checkpoint(self.restore_path)
[docs]
@classmethod
def from_config(
cls,
config: DreamerConfig | dict[str, Any] | str | Path | None = None,
*,
obs_shape: tuple[int, ...] | None = None,
action_size: int | None = None,
device: str | torch.device | None = None,
restore: bool | None = None,
**overrides: Any,
) -> "Dreamer":
"""Build a core Dreamer model from a config object, dict, or YAML file.
``obs_shape`` and ``action_size`` may be supplied directly. When either
is omitted, this method constructs a temporary environment from the
config to infer the model shapes.
"""
args = _apply_config_overrides(_coerce_dreamer_config(config), overrides)
if obs_shape is None or action_size is None:
env = make_env(args)
if obs_shape is None:
obs_shape = tuple(env.observation_space["image"].shape)
if action_size is None:
action_size = int(env.action_space.shape[0])
torch_device = (
torch.device(device) if device is not None else _default_device(args)
)
should_restore = args.restore if restore is None else restore
return cls(args, obs_shape, action_size, torch_device, should_restore)
[docs]
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str | Path,
*,
config: DreamerConfig | dict[str, Any] | str | Path | None = None,
checkpoint_filename: str | None = None,
config_filename: str = "config.yaml",
repo_type: str | None = None,
revision: str | None = None,
map_location: str | torch.device | None = None,
**overrides: Any,
) -> "Dreamer":
"""Load a Dreamer checkpoint from a local path/directory or the HF Hub."""
checkpoint_candidates = (
(checkpoint_filename,)
if checkpoint_filename is not None
else ("model.pt", "pytorch_model.bin", "checkpoint.pt", "ckpt.pt")
)
checkpoint_path = _resolve_pretrained_file(
pretrained_model_name_or_path,
checkpoint_candidates,
repo_type=repo_type,
revision=revision,
)
if checkpoint_path is None:
raise FileNotFoundError(
f"Could not find a Dreamer checkpoint for {pretrained_model_name_or_path!r}."
)
checkpoint = torch.load(
checkpoint_path, map_location=map_location or "cpu", weights_only=True
)
checkpoint_config = (
checkpoint.get("config") if isinstance(checkpoint, dict) else None
)
if config is None and checkpoint_config is not None:
args = _coerce_dreamer_config(checkpoint_config)
elif config is not None:
args = _coerce_dreamer_config(config)
else:
config_path = _resolve_pretrained_file(
pretrained_model_name_or_path,
(config_filename, "dreamer_config.yaml", "config.yml"),
repo_type=repo_type,
revision=revision,
)
if config_path is None:
raise FileNotFoundError(
"No config was provided and no config YAML was found beside "
f"{pretrained_model_name_or_path!r}."
)
args = DreamerConfig.from_yaml(config_path)
args = _apply_config_overrides(args, overrides)
obs_shape = (
checkpoint.get("obs_shape") if isinstance(checkpoint, dict) else None
)
action_size = (
checkpoint.get("action_size") if isinstance(checkpoint, dict) else None
)
model = cls.from_config(
args,
obs_shape=tuple(obs_shape) if obs_shape is not None else None,
action_size=int(action_size) if action_size is not None else None,
device=map_location,
restore=False,
)
model.restore_checkpoint(checkpoint_path, map_location=map_location)
return model
[docs]
def parameter_count(self, trainable_only: bool = False) -> int:
"""Return the total number of parameters owned by the Dreamer modules."""
return sum(
param.numel()
for module in (
self.world_model_modules + self.actor_modules + self.value_modules
)
for param in module.parameters()
if not trainable_only or param.requires_grad
)
[docs]
def summary(self) -> dict[str, Any]:
"""Return a compact parameter-count summary for the Dreamer modules."""
modules = {
"rssm": self.rssm,
"actor": self.actor,
"reward_model": self.reward_model,
"obs_encoder": self.obs_encoder,
"obs_decoder": self.obs_decoder,
"value_model": self.value_model,
}
if self.args.use_disc_model:
modules["discount_model"] = self.discount_model
module_params = {
name: sum(param.numel() for param in module.parameters())
for name, module in modules.items()
}
trainable_params = {
name: sum(
param.numel() for param in module.parameters() if param.requires_grad
)
for name, module in modules.items()
}
return {
"total_parameters": sum(module_params.values()),
"trainable_parameters": sum(trainable_params.values()),
"modules": module_params,
"trainable_modules": trainable_params,
}
[docs]
@assert_finite
def world_model_loss(
self,
obs: torch.Tensor,
acs: torch.Tensor,
rews: torch.Tensor,
nonterms: torch.Tensor,
) -> torch.Tensor:
obs = preprocess_obs(obs)
obs_embed = self.obs_encoder(obs[1:])
init_state = self.rssm.init_state(self.args.batch_size, self.device)
prior, self.posterior = self.rssm.observe_rollout(
obs_embed, acs[:-1], nonterms[:-1], init_state, self.args.train_seq_len - 1
)
features = torch.cat([self.posterior["stoch"], self.posterior["deter"]], dim=-1)
rew_dist = self.reward_model(features)
obs_dist = self.obs_decoder(features)
disc_dist = None
if self.args.use_disc_model:
disc_dist = self.discount_model(features)
prior_dist = self.rssm.get_dist(prior["mean"], prior["std"])
post_dist = self.rssm.get_dist(self.posterior["mean"], self.posterior["std"])
if self.args.algo == "Dreamerv2":
post_no_grad = self.rssm.detach_state(self.posterior)
prior_no_grad = self.rssm.detach_state(prior)
post_mean_no_grad, post_std_no_grad = (
post_no_grad["mean"],
post_no_grad["std"],
)
prior_mean_no_grad, prior_std_no_grad = (
prior_no_grad["mean"],
prior_no_grad["std"],
)
kl_loss = self.args.kl_alpha * (
torch.mean(
distributions.kl.kl_divergence(
self.rssm.get_dist(post_mean_no_grad, post_std_no_grad),
prior_dist,
)
)
)
kl_loss += (1 - self.args.kl_alpha) * (
torch.mean(
distributions.kl.kl_divergence(
post_dist,
self.rssm.get_dist(prior_mean_no_grad, prior_std_no_grad),
)
)
)
else:
kl_loss = torch.mean(distributions.kl.kl_divergence(post_dist, prior_dist))
kl_loss = torch.max(
kl_loss, kl_loss.new_full(kl_loss.size(), self.args.free_nats)
)
if self.args.algo == "Dreamerv2":
target_rew = _symlog(rews[:-1])
else:
target_rew = rews[:-1]
obs_loss = -torch.mean(obs_dist.log_prob(obs[1:]))
rew_loss = -torch.mean(rew_dist.log_prob(target_rew))
model_loss = self.args.kl_loss_coeff * kl_loss + obs_loss + rew_loss
if self.args.use_disc_model and disc_dist is not None:
disc_loss = -torch.mean(disc_dist.log_prob(nonterms[:-1]))
model_loss = model_loss + self.args.disc_loss_coeff * disc_loss
return model_loss
[docs]
@assert_finite
def actor_loss(self) -> torch.Tensor:
with torch.no_grad():
posterior = self.rssm.detach_state(self.rssm.seq_to_batch(self.posterior))
with FreezeParameters(self.world_model_modules):
imag_states = self.rssm.imagine_rollout(
self.actor, posterior, self.args.imagine_horizon
)
self.imag_feat = torch.cat([imag_states["stoch"], imag_states["deter"]], dim=-1)
with FreezeParameters(
list(self.world_model_modules) + list(self.value_modules)
):
imag_rew_dist = self.reward_model(self.imag_feat)
imag_val_dist = self.value_model(self.imag_feat)
imag_rews = imag_rew_dist.mean
imag_vals = imag_val_dist.mean
if self.args.use_disc_model:
imag_disc_dist = self.discount_model(self.imag_feat)
discounts = imag_disc_dist.mean().detach()
else:
discounts = self.args.discount * torch.ones_like(imag_rews).detach()
self.returns = compute_return(
imag_rews[:-1],
imag_vals[:-1],
discounts[:-1],
self.args.td_lambda,
imag_vals[-1],
)
discounts = torch.cat([torch.ones_like(discounts[:1]), discounts[1:-1]], 0)
self.discounts = torch.cumprod(discounts, 0).detach()
if self.args.algo == "Dreamerv2":
weight = self.discounts.detach()
target = _symlog(self.returns.detach())
actor_loss = -torch.mean(weight * target)
else:
actor_loss = -torch.mean(self.discounts * self.returns)
return actor_loss
[docs]
@assert_finite
def value_loss(self) -> torch.Tensor:
with torch.no_grad():
value_feat = self.imag_feat[:-1].detach()
value_targ = self.returns.detach()
value_dist = self.value_model(value_feat)
if self.args.algo == "Dreamerv2":
target = _symlog(value_targ)
log_prob = value_dist.log_prob(target)
else:
log_prob = value_dist.log_prob(value_targ).unsqueeze(-1)
value_loss = -torch.mean(self.discounts * log_prob)
return value_loss
[docs]
def train_one_batch(self) -> list[float]:
obs, acs, rews, terms = self.data_buffer.sample()
obs_t = torch.tensor(obs, dtype=torch.float32).to(self.device)
acs_t = torch.tensor(acs, dtype=torch.float32).to(self.device)
rews_t = torch.tensor(rews, dtype=torch.float32).to(self.device).unsqueeze(-1)
nonterms = (
torch.tensor((1.0 - terms), dtype=torch.float32)
.to(self.device)
.unsqueeze(-1)
)
with torch.amp.autocast(
device_type=getattr(self.device, "type", str(self.device)),
enabled=self.use_amp,
):
model_loss = self.world_model_loss(obs_t, acs_t, rews_t, nonterms)
self.world_model_opt.zero_grad(set_to_none=True)
self.world_model_scaler.scale(model_loss).backward()
self.world_model_scaler.unscale_(self.world_model_opt)
nn.utils.clip_grad_norm_(self.world_model_params, self.args.grad_clip_norm)
self.world_model_scaler.step(self.world_model_opt)
self.world_model_scaler.update()
with torch.amp.autocast(
device_type=getattr(self.device, "type", str(self.device)),
enabled=self.use_amp,
):
actor_loss = self.actor_loss()
self.actor_opt.zero_grad(set_to_none=True)
self.actor_scaler.scale(actor_loss).backward()
self.actor_scaler.unscale_(self.actor_opt)
nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.grad_clip_norm)
self.actor_scaler.step(self.actor_opt)
self.actor_scaler.update()
with torch.amp.autocast(
device_type=getattr(self.device, "type", str(self.device)),
enabled=self.use_amp,
):
value_loss = self.value_loss()
self.value_opt.zero_grad(set_to_none=True)
self.value_scaler.scale(value_loss).backward()
self.value_scaler.unscale_(self.value_opt)
nn.utils.clip_grad_norm_(
self.value_model.parameters(), self.args.grad_clip_norm
)
self.value_scaler.step(self.value_opt)
self.value_scaler.update()
return (
torch.stack([model_loss.detach(), actor_loss.detach(), value_loss.detach()])
.cpu()
.tolist()
)
[docs]
def act_with_world_model(
self,
obs: Any,
prev_state: Any,
prev_action: torch.Tensor,
explore: bool = False,
) -> tuple:
obs = obs["image"]
obs = torch.tensor(obs.copy(), dtype=torch.float32).to(self.device).unsqueeze(0)
obs_embed = self.obs_encoder(preprocess_obs(obs))
_, posterior = self.rssm.observe_step(prev_state, prev_action, obs_embed)
features = torch.cat([posterior["stoch"], posterior["deter"]], dim=-1)
action = self.actor(features, deter=not explore)
if explore:
action = self.actor.add_exploration(action, self.args.action_noise)
return posterior, action
[docs]
def act_and_collect_data(self, env: Any, collect_steps: int) -> np.ndarray:
obs = env.reset()
done = False
prev_state = self.rssm.init_state(1, self.device)
prev_action = torch.zeros(1, self.action_size).to(self.device)
episode_rewards = [0.0]
for i in range(collect_steps):
with torch.no_grad():
posterior, action = self.act_with_world_model(
obs, prev_state, prev_action, explore=True
)
action = action[0].cpu().numpy()
next_obs, rew, done, info = env.step(action)
executed_action = (
info["action"]
if isinstance(info, dict) and ("action" in info)
else action
)
self.data_buffer.add(obs, executed_action, rew, done)
episode_rewards[-1] += rew
if done:
obs = env.reset()
done = False
prev_state = self.rssm.init_state(1, self.device)
if i != collect_steps - 1:
episode_rewards.append(0.0)
else:
obs = next_obs
prev_state = posterior
prev_action = (
torch.tensor(executed_action, dtype=torch.float32)
.to(self.device)
.unsqueeze(0)
)
return np.array(episode_rewards)
[docs]
def evaluate(self, env: Any, eval_episodes: int, render: bool = False) -> tuple:
episode_rew = np.zeros((eval_episodes))
video_images: list[list[Any]] = [[] for _ in range(eval_episodes)]
latents: list[Any] | None = [] if render else None
for i in range(eval_episodes):
obs = env.reset()
done = False
prev_state = self.rssm.init_state(1, self.device)
prev_action = torch.zeros(1, self.action_size).to(self.device)
while not done:
with torch.no_grad():
posterior, action = self.act_with_world_model(
obs, prev_state, prev_action
)
action = action[0].cpu().numpy()
next_obs, rew, done, info = env.step(action)
executed_action = (
info["action"]
if isinstance(info, dict) and ("action" in info)
else action
)
prev_state = posterior
prev_action = (
torch.tensor(executed_action, dtype=torch.float32)
.to(self.device)
.unsqueeze(0)
)
episode_rew[i] += rew
if render:
video_images[i].append(obs["image"].transpose(1, 2, 0).copy())
if latents is not None:
latents.append(
torch.cat([posterior[0], posterior[1]], dim=-1)
.cpu()
.numpy()
)
obs = next_obs
latents_arr: Any = None
if latents is not None and len(latents) > 0:
latents_arr = np.array(latents)
return (
episode_rew,
np.array(video_images[: self.args.max_videos_to_save]),
latents_arr,
)
[docs]
def collect_random_episodes(self, env: Any, seed_steps: int) -> np.ndarray:
obs = env.reset()
done = False
seed_episode_rews = [0.0]
for i in range(seed_steps):
action = env.action_space.sample()
next_obs, rew, done, info = env.step(action)
executed_action = (
info["action"]
if isinstance(info, dict) and ("action" in info)
else action
)
self.data_buffer.add(obs, executed_action, rew, done)
seed_episode_rews[-1] += rew
if done:
obs = env.reset()
if i != seed_steps - 1:
seed_episode_rews.append(0.0)
done = False
else:
obs = next_obs
return np.array(seed_episode_rews)
[docs]
def save(self, save_path: str) -> None:
_save_config_next_to_checkpoint(self.args, save_path)
torch.save(
{
"config": self.args.to_dict(),
"obs_shape": tuple(self.obs_shape),
"action_size": int(self.action_size),
"rssm": self.rssm.state_dict(),
"actor": self.actor.state_dict(),
"reward_model": self.reward_model.state_dict(),
"obs_encoder": self.obs_encoder.state_dict(),
"obs_decoder": self.obs_decoder.state_dict(),
"discount_model": (
self.discount_model.state_dict()
if self.args.use_disc_model
else None
),
"actor_optimizer": self.actor_opt.state_dict(),
"value_optimizer": self.value_opt.state_dict(),
"world_model_optimizer": self.world_model_opt.state_dict(),
},
save_path,
)
[docs]
def restore_checkpoint(
self, ckpt_path: str | Path, map_location: Any = None
) -> None:
checkpoint = torch.load(ckpt_path, map_location=map_location, weights_only=True)
self.rssm.load_state_dict(checkpoint["rssm"])
self.actor.load_state_dict(checkpoint["actor"])
self.reward_model.load_state_dict(checkpoint["reward_model"])
self.obs_encoder.load_state_dict(checkpoint["obs_encoder"])
self.obs_decoder.load_state_dict(checkpoint["obs_decoder"])
if self.args.use_disc_model and (checkpoint["discount_model"] is not None):
self.discount_model.load_state_dict(checkpoint["discount_model"])
self.world_model_opt.load_state_dict(checkpoint["world_model_optimizer"])
self.actor_opt.load_state_dict(checkpoint["actor_optimizer"])
self.value_opt.load_state_dict(checkpoint["value_optimizer"])
[docs]
class DreamerAgent(ExportableAgentMixin):
"""High-level user API for running Dreamer experiments end to end.
It builds environments from config, initializes seeds and logging,
instantiates `Dreamer`, and exposes simple `train()` / `evaluate()` methods.
"""
def __init__(self, config: Any = None, **kwargs: Any) -> None:
self.args = _coerce_dreamer_config(config)
self.last_latents_ref = kwargs.get("last_latents_ref", None)
for key, value in kwargs.items():
if hasattr(self.args, key):
setattr(self.args, key, value)
elif key == "logdir":
# Accept either `logdir` (server/legacy) and mirror into
# both `logdir` and `log_dir` so downstream code using either
# naming convention picks up the value.
setattr(self.args, "logdir", value)
try:
setattr(self.args, "log_dir", value)
except Exception:
pass
elif key == "last_latents_ref":
self.last_latents_ref = value
elif key == "data_path":
# Backwards-compatible alias for the new configurable data_dir.
setattr(self.args, "data_dir", value)
else:
raise ValueError(f"Invalid argument: {key}")
data_path = (
getattr(self.args, "data_dir", None)
or os.environ.get("TORCHWM_DATA_DIR")
or getattr(self.args, "log_dir", None)
or "runs"
)
data_path = os.path.abspath(os.path.expanduser(data_path))
if not (os.path.exists(data_path)):
os.makedirs(data_path)
# Allow caller to pass an absolute `logdir`. If provided and absolute,
# use it verbatim. Otherwise keep the historical behavior of creating
# a subdir under the package data path so examples/tests remain stable.
if hasattr(self.args, "logdir") and self.args.logdir is not None:
self.logdir = self.args.logdir
else:
self.logdir = (
self.args.env
+ "_"
+ self.args.algo
+ "_"
+ self.args.exp_name
+ "_"
+ time.strftime("%d-%m-%Y-%H-%M-%S")
)
# If `self.logdir` is not an absolute path, place it under the
# configurable data directory instead of the package source tree.
if not os.path.isabs(self.logdir):
self.logdir = os.path.join(data_path, self.logdir)
if not (os.path.exists(self.logdir)):
os.makedirs(self.logdir)
self.args.to_yaml(os.path.join(self.logdir, "config.yaml"))
setup_logging(
"world_models",
getattr(self.args, "log_level", "INFO"),
getattr(self.args, "log_file", None),
)
if getattr(self.args, "detect_anomaly", False):
torch.autograd.set_detect_anomaly(True)
logger.warning("torch.autograd anomaly detection is enabled")
random.seed(self.args.seed)
np.random.seed(self.args.seed)
torch.manual_seed(self.args.seed)
if torch.cuda.is_available() and not self.args.no_gpu:
device = torch.device("cuda")
torch.cuda.manual_seed(self.args.seed)
else:
device = torch.device("cpu")
print("WARNING: CUDA not available, using CPU")
self.train_env = make_env(self.args)
self.test_env = make_env(self.args)
obs_shape = self.train_env.observation_space["image"].shape
action_size = self.train_env.action_space.shape[0]
self.dreamer = Dreamer(
self.args, obs_shape, action_size, device, self.args.restore
)
self.logger = Logger(
self.logdir,
enable_wandb=self.args.enable_wandb,
wandb_api_key=self.args.wandb_api_key,
wandb_project=self.args.wandb_project,
wandb_entity=self.args.wandb_entity,
video_format=self.args.video_format,
video_fps=self.args.video_fps,
enable_tensorboard=getattr(self.args, "enable_tensorboard", False),
enable_console=getattr(self.args, "enable_console_metrics", True),
enable_jsonl=getattr(self.args, "enable_jsonl", True),
jsonl_filename=getattr(self.args, "jsonl_filename", "metrics.jsonl"),
)
[docs]
@classmethod
def from_config(
cls,
config: DreamerConfig | dict[str, Any] | str | Path | None = None,
**overrides: Any,
) -> "DreamerAgent":
"""Build a high-level Dreamer agent from a config object, dict, or YAML file."""
return cls(_apply_config_overrides(_coerce_dreamer_config(config), overrides))
[docs]
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str | Path,
*,
config: DreamerConfig | dict[str, Any] | str | Path | None = None,
checkpoint_filename: str | None = None,
config_filename: str = "config.yaml",
repo_type: str | None = None,
revision: str | None = None,
map_location: str | torch.device | None = None,
**overrides: Any,
) -> "DreamerAgent":
"""Create a Dreamer agent and restore weights from a local path or HF Hub."""
checkpoint_candidates = (
(checkpoint_filename,)
if checkpoint_filename is not None
else ("model.pt", "pytorch_model.bin", "checkpoint.pt", "ckpt.pt")
)
checkpoint_path = _resolve_pretrained_file(
pretrained_model_name_or_path,
checkpoint_candidates,
repo_type=repo_type,
revision=revision,
)
if checkpoint_path is None:
raise FileNotFoundError(
f"Could not find a Dreamer checkpoint for {pretrained_model_name_or_path!r}."
)
checkpoint = torch.load(
checkpoint_path, map_location=map_location or "cpu", weights_only=True
)
checkpoint_config = (
checkpoint.get("config") if isinstance(checkpoint, dict) else None
)
if config is None and checkpoint_config is not None:
args = _coerce_dreamer_config(checkpoint_config)
elif config is not None:
args = _coerce_dreamer_config(config)
else:
config_path = _resolve_pretrained_file(
pretrained_model_name_or_path,
(config_filename, "dreamer_config.yaml", "config.yml"),
repo_type=repo_type,
revision=revision,
)
if config_path is None:
raise FileNotFoundError(
"No config was provided and no config YAML was found beside "
f"{pretrained_model_name_or_path!r}."
)
args = DreamerConfig.from_yaml(config_path)
args = _apply_config_overrides(args, overrides)
args.restore = False
agent = cls(args)
agent.dreamer.restore_checkpoint(checkpoint_path, map_location=map_location)
return agent
[docs]
def parameter_count(self, trainable_only: bool = False) -> int:
"""Return the total number of Dreamer parameters."""
return self.dreamer.parameter_count(trainable_only=trainable_only)
[docs]
def summary(self) -> dict[str, Any]:
"""Return a compact parameter-count summary for the wrapped Dreamer model."""
return self.dreamer.summary()
[docs]
def train(self, total_steps: int | None = None) -> None:
if total_steps is None:
total_steps = self.args.total_steps
initial_logs = OrderedDict()
seed_episode_rews = self.dreamer.collect_random_episodes(
self.train_env, self.args.seed_steps // self.args.action_repeat
)
global_step = self.dreamer.data_buffer.steps * self.args.action_repeat
# without loss of generality intial rews for both train and eval are assumed same
initial_logs.update(
{
"train_avg_reward": np.mean(seed_episode_rews),
"train_max_reward": np.max(seed_episode_rews),
"train_min_reward": np.min(seed_episode_rews),
"train_std_reward": np.std(seed_episode_rews),
"eval_avg_reward": np.mean(seed_episode_rews),
"eval_max_reward": np.max(seed_episode_rews),
"eval_min_reward": np.min(seed_episode_rews),
"eval_std_reward": np.std(seed_episode_rews),
}
)
self.logger.log_scalars(initial_logs, step=0)
self.logger.flush()
while global_step <= total_steps:
logger.info("At global step %s", global_step)
logs = OrderedDict()
video_images = []
update_start = time.perf_counter()
for _ in range(self.args.update_steps):
model_loss, actor_loss, value_loss = self.dreamer.train_one_batch()
update_elapsed = max(time.perf_counter() - update_start, 1e-9)
collect_count = self.args.collect_steps // self.args.action_repeat
collect_start = time.perf_counter()
train_rews = self.dreamer.act_and_collect_data(
self.train_env, collect_count
)
collect_elapsed = max(time.perf_counter() - collect_start, 1e-9)
logs.update(
{
"model_loss": model_loss,
"actor_loss": actor_loss,
"value_loss": value_loss,
"train_avg_reward": np.mean(train_rews),
"train_max_reward": np.max(train_rews),
"train_min_reward": np.min(train_rews),
"train_std_reward": np.std(train_rews),
"throughput/env_steps_per_sec": (
collect_count * self.args.action_repeat
)
/ collect_elapsed,
"throughput/grad_steps_per_sec": self.args.update_steps
/ update_elapsed,
"replay/fill_ratio": min(
1.0,
self.dreamer.data_buffer.steps
/ max(1, self.dreamer.data_buffer.size),
),
"replay/steps": self.dreamer.data_buffer.steps,
"replay/episodes": self.dreamer.data_buffer.episodes,
}
)
stats_freq = getattr(self.args, "log_system_stats_freq", 0)
if stats_freq and global_step % stats_freq == 0:
logs.update(collect_system_stats(self.dreamer.device))
if global_step % self.args.test_interval == 0:
episode_rews, video_images, latents = self.dreamer.evaluate(
self.test_env, self.args.test_episodes
)
if self.last_latents_ref is not None and latents is not None:
self.last_latents_ref[0] = latents
logs.update(
{
"eval_avg_reward": np.mean(episode_rews),
"eval_max_reward": np.max(episode_rews),
"eval_min_reward": np.min(episode_rews),
"eval_std_reward": np.std(episode_rews),
}
)
self.logger.log_scalars(logs, global_step)
if (
global_step % self.args.log_video_freq == 0
and self.args.log_video_freq != -1
and len(video_images) > 0
and len(video_images[0]) != 0
):
self.logger.log_videos(
video_images, global_step, self.args.max_videos_to_save
)
if global_step % self.args.checkpoint_interval == 0:
ckpt_dir = os.path.join(self.logdir, "ckpts/")
if not (os.path.exists(ckpt_dir)):
os.makedirs(ckpt_dir)
self.dreamer.save(os.path.join(ckpt_dir, f"{global_step}_ckpt.pt"))
global_step = self.dreamer.data_buffer.steps * self.args.action_repeat
self.logger.flush()
[docs]
def evaluate(self) -> tuple:
logs = OrderedDict()
episode_rews, video_images, latents = self.dreamer.evaluate(
self.test_env, self.args.test_episodes, render=True
)
if self.last_latents_ref is not None and latents is not None:
self.last_latents_ref[0] = latents
logs.update(
{
"test_avg_reward": np.mean(episode_rews),
"test_max_reward": np.max(episode_rews),
"test_min_reward": np.min(episode_rews),
"test_std_reward": np.std(episode_rews),
}
)
self.logger.dump_scalars_to_pickle(logs, 0, log_title="test_scalars.pkl")
self.logger.log_videos(
video_images, 0, max_videos_to_save=self.args.max_videos_to_save
)
self.logger.flush()
return episode_rews, video_images, latents