Source code for world_models.benchmarks.adapters

from __future__ import annotations

import copy
import logging
from typing import Any, List

import numpy as np

from world_models.configs.diamond_config import DiamondConfig
from world_models.configs.dreamer_config import DreamerConfig
from world_models.configs.iris_config import IRISConfig
from world_models.models.dreamer import DreamerAgent
from world_models.training.train_diamond import DiamondAgent
from world_models.training.train_iris import IRISTrainer


class BaseAdapter:
    """Common interface that every benchmark adapter implements.

    The base class only records the environment specification and the seed;
    subclasses build the concrete agent/trainer and override
    ``load_checkpoint`` and ``evaluate``.
    """

    def __init__(self, env_spec: Any | None = None, seed: int = 0, **kwargs):
        """Remember ``env_spec`` and ``seed``.

        Extra keyword arguments are accepted so subclasses can share the
        same call signature; the base class does not use them.
        """
        self.seed = seed
        self.env_spec = env_spec

    def load_checkpoint(self, path: str):
        """Load model weights from ``path``. Subclasses must override.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def evaluate(self, num_episodes: int = 1, render: bool = False):
        """Run evaluation and return a standardized result.

        The preferred result format is a dict whose ``'episode_returns'``
        key maps to a ``List[float]``. Subclasses must override.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError
class DiamondAdapter(BaseAdapter):
    """Benchmark adapter that drives a ``DiamondAgent``.

    Keyword arguments read from ``kwargs``:
        preset: forwarded to ``DiamondConfig(preset=...)`` (default ``None``).
        device: assigned to ``cfg.device`` (default ``"cuda"``).
    """

    def __init__(self, env_spec: Any | None = None, seed: int = 0, **kwargs):
        """Build a ``DiamondConfig`` from the arguments and create the agent.

        Args:
            env_spec: a dict with a ``"game"`` key, a bare game-name string,
                or any object exposing a ``game`` attribute. When no game
                can be derived, the config's own ``game`` value is kept.
            seed: written to ``cfg.seed``.
        """
        super().__init__(env_spec, seed)

        if isinstance(env_spec, dict):
            game = env_spec.get("game")
        elif isinstance(env_spec, str):
            game = env_spec
        else:
            game = getattr(env_spec, "game", None)

        cfg = DiamondConfig(preset=kwargs.get("preset"))
        if game:
            cfg.game = game
        cfg.seed = seed
        cfg.device = kwargs.get("device", "cuda")

        self.agent = DiamondAgent(cfg)

    def load_checkpoint(self, path: str):
        """Delegate to ``DiamondAgent.load_checkpoint``; errors propagate."""
        self.agent.load_checkpoint(path)

    def evaluate(self, num_episodes: int = 1, render: bool = False):
        """Collect per-episode returns by running one-episode evaluations.

        The agent is asked for ``num_episodes=1`` repeatedly so that each
        call contributes its own entry (the agent's ``evaluate`` presumably
        returns a mean over the requested episodes -- confirm against
        ``DiamondAgent``).

        Args:
            num_episodes: number of single-episode evaluations to run.
            render: accepted for interface compatibility; it is not
                forwarded to the agent.

        Returns:
            ``{"episode_returns": List[float]}``. An evaluation that raises
            contributes ``0.0`` so the run can continue, and the failure is
            logged with its traceback instead of being dropped silently.
        """
        log = logging.getLogger(__name__)
        episode_returns: List[float] = []
        for episode in range(num_episodes):
            try:
                result = self.agent.evaluate(num_episodes=1)
                episode_returns.extend(self._coerce_returns(result))
            except Exception:
                # Best-effort: keep the benchmark running, but leave a trace
                # so a 0.0 in the results is not mistaken for a real score.
                log.warning(
                    "DIAMOND evaluation episode %d failed; recording 0.0",
                    episode,
                    exc_info=True,
                )
                episode_returns.append(0.0)
        return {"episode_returns": episode_returns}

    @staticmethod
    def _coerce_returns(result: Any) -> List[float]:
        """Normalize one ``agent.evaluate`` result into a list of floats."""
        if isinstance(result, dict) and "episode_returns" in result:
            raw = result["episode_returns"]
            if raw is None:
                return [0.0]
            # np.asarray handles lists, tuples, scalars and ndarrays alike;
            # a plain truthiness test on a multi-element ndarray would raise.
            values = np.asarray(raw, dtype=float).ravel()
            return values.tolist() if values.size else [0.0]
        try:
            return [float(result)]
        except (TypeError, ValueError):
            logging.getLogger(__name__).warning(
                "Unrecognized DIAMOND evaluate() result of type %s; recording 0.0",
                type(result).__name__,
            )
            return [0.0]
class IRISAdapter(BaseAdapter):
    """Benchmark adapter that drives an ``IRISTrainer``.

    Keyword arguments read from ``kwargs``:
        device: forwarded to the trainer (default ``"cuda"``).
        config: an existing config object; a fresh ``IRISConfig()`` is
            created when it is missing or ``None``.
    """

    def __init__(self, env_spec: Any | None = None, seed: int = 0, **kwargs):
        """Create the trainer for the requested game.

        Args:
            env_spec: a dict with a ``"game"`` key, a bare game-name string,
                or any object exposing a ``game`` attribute (the last form
                matches what the other adapters in this module accept).
                When no game can be derived, ``cfg.env`` is used.
            seed: forwarded to the trainer.
        """
        super().__init__(env_spec, seed)

        if isinstance(env_spec, dict):
            game = env_spec.get("game")
        elif isinstance(env_spec, str):
            game = env_spec
        else:
            game = getattr(env_spec, "game", None)

        config = kwargs.get("config")
        cfg = IRISConfig() if config is None else config

        self.trainer = IRISTrainer(
            game=game or cfg.env,
            device=kwargs.get("device", "cuda"),
            seed=seed,
            config=cfg,
        )

    def load_checkpoint(self, path: str):
        """Load weights through ``trainer.agent.load(path)``.

        Raises:
            AttributeError: if the trainer has no ``agent`` or the agent has
                no ``load`` attribute. Errors from ``load`` itself propagate.
        """
        agent = getattr(self.trainer, "agent", None)
        if agent is None or not hasattr(agent, "load"):
            raise AttributeError("IRIS agent does not expose load(path)")
        agent.load(path)

    def evaluate(self, num_episodes: int = 1, render: bool = False):
        """Run the trainer's evaluation and standardize its output.

        Args:
            num_episodes: forwarded to ``trainer.evaluate``.
            render: forwarded to ``trainer.evaluate``; also selects how the
                result is unpacked.

        Returns:
            With ``render=True`` the result is unpacked as
            ``(episode_returns, videos, latents)`` and returned under the
            keys ``"episode_returns"``, ``"videos"`` and ``"latents"``.
            Otherwise, when the result is a dict containing
            ``"eval_mean_return"``, that mean is repeated ``num_episodes``
            times -- these are NOT true per-episode values. Any other
            result yields an empty ``"episode_returns"`` list and a warning.
        """
        res = self.trainer.evaluate(num_episodes=num_episodes, render=render)

        if render:
            ep_returns, videos, latents = res
            return {
                # Plain floats keep the output consistent with the
                # documented List[float] contract.
                "episode_returns": [float(r) for r in ep_returns],
                "videos": videos,
                "latents": latents,
            }

        if isinstance(res, dict) and "eval_mean_return" in res:
            mean_return = float(res["eval_mean_return"])
            return {"episode_returns": [mean_return] * num_episodes}

        logging.getLogger(__name__).warning(
            "Unrecognized IRIS evaluate() result of type %s; "
            "returning no episode returns",
            type(res).__name__,
        )
        return {"episode_returns": []}
class DreamerAdapter(BaseAdapter):
    """Benchmark adapter that drives a ``DreamerAgent``.

    Keyword arguments read from ``kwargs``:
        algo: algorithm name (default ``"dreamerv1"``). A name containing
            ``"v2"`` (case-insensitive) selects ``"Dreamerv2"``; anything
            else selects ``"Dreamerv1"``.
        config: an existing config object; a fresh ``DreamerConfig()`` is
            created when it is missing or ``None``.
    """

    def __init__(self, env_spec: Any | None = None, seed: int = 0, **kwargs):
        """Prepare the config and construct the agent.

        Args:
            env_spec: a dict with ``"game"`` / ``"env_backend"`` keys, a bare
                game-name string, or any object exposing ``game`` /
                ``env_backend`` attributes.
            seed: written to ``cfg.seed``.
        """
        super().__init__(env_spec, seed)

        if isinstance(env_spec, dict):
            game = env_spec.get("game")
            env_backend = env_spec.get("env_backend")
        elif isinstance(env_spec, str):
            game = env_spec
            env_backend = None
        else:
            game = getattr(env_spec, "game", None)
            env_backend = getattr(env_spec, "env_backend", None)

        config = kwargs.get("config")
        # Work on a shallow copy: the fields below are overwritten, and a
        # config object shared by several adapters must not be clobbered.
        cfg = DreamerConfig() if config is None else copy.copy(config)

        if game:
            cfg.env = game
            # Gym-style ids ("ALE/Pong-v5", "CartPole-v1", ...) select the
            # gym backend. Any id containing "-v" matches, so the check is
            # broader than Atari ids alone.
            if isinstance(game, str) and ("ALE/" in game or "-v" in game):
                cfg.env_backend = "gym"
        # An explicitly requested backend always wins over the heuristic.
        if env_backend:
            cfg.env_backend = env_backend

        # Tolerate algo=None and select the version on the "v2" marker only,
        # so short names such as "v1" are not mapped to Dreamerv2.
        algo = str(kwargs.get("algo") or "dreamerv1").lower()
        cfg.algo = "Dreamerv2" if "v2" in algo else "Dreamerv1"
        cfg.seed = seed

        # The original comment notes that DreamerAgent builds its
        # environments internally.
        self.agent = DreamerAgent(config=cfg)

    def load_checkpoint(self, path: str):
        """Restore weights via ``agent.dreamer.restore_checkpoint(path)``.

        Raises:
            AttributeError: if the agent has no ``dreamer`` or the dreamer
                has no ``restore_checkpoint``. Errors from the restore call
                itself propagate.
        """
        dreamer = getattr(self.agent, "dreamer", None)
        if dreamer is None or not hasattr(dreamer, "restore_checkpoint"):
            raise AttributeError("Dreamer agent does not expose restore_checkpoint")
        dreamer.restore_checkpoint(path)

    def evaluate(self, num_episodes: int = 1, render: bool = False):
        """Evaluate on ``agent.test_env`` and standardize the output.

        Args:
            num_episodes: number of evaluation episodes (coerced to ``int``).
            render: forwarded to ``dreamer.evaluate``; when true the videos
                and latents are included in the result.

        Returns:
            ``{"episode_returns": List[float]}``, plus ``"videos"`` and
            ``"latents"`` when ``render`` is true.
        """
        episodes = int(num_episodes)

        # Best-effort: mirror the episode count on the agent's args when
        # that attribute exists and is writable.
        try:
            self.agent.args.test_episodes = episodes
        except (AttributeError, TypeError):
            logging.getLogger(__name__).debug(
                "Could not set test_episodes on Dreamer agent args",
                exc_info=True,
            )

        ep_rews, videos, latents = self.agent.dreamer.evaluate(
            self.agent.test_env, episodes, render=render
        )

        out = {"episode_returns": [float(r) for r in ep_rews]}
        if render:
            out["videos"] = videos
            out["latents"] = latents
        return out
class DreamerV1Adapter(DreamerAdapter):
    """``DreamerAdapter`` whose ``algo`` defaults to ``"dreamerv1"``."""

    def __init__(self, env_spec: Any | None = None, seed: int = 0, **kwargs):
        """Forward to ``DreamerAdapter``; a caller-supplied ``algo`` wins."""
        # Later entries override earlier ones, so an explicit "algo" in
        # kwargs replaces the default while the caller's dict stays untouched.
        options = {"algo": "dreamerv1", **kwargs}
        super().__init__(env_spec=env_spec, seed=seed, **options)
class DreamerV2Adapter(DreamerAdapter):
    """``DreamerAdapter`` whose ``algo`` defaults to ``"dreamerv2"``."""

    def __init__(self, env_spec: Any | None = None, seed: int = 0, **kwargs):
        """Forward to ``DreamerAdapter``; a caller-supplied ``algo`` wins."""
        # Later entries override earlier ones, so an explicit "algo" in
        # kwargs replaces the default while the caller's dict stays untouched.
        options = {"algo": "dreamerv2", **kwargs}
        super().__init__(env_spec=env_spec, seed=seed, **options)