API Reference#

This reference is generated from source docstrings and grouped by subsystem.

Core Public APIs#

class world_models.configs.DreamerConfig[source]#

Bases: object

Configuration container for Dreamer training, evaluation, and environment setup.

This class centralizes environment backend selection (DMC/Gym/Unity), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.

class world_models.configs.JEPAConfig[source]#

Bases: object

Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.

to_dict()[source]#
Return type:

Dict[str, Dict[str, Any]]
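
Example usage (a minimal sketch, assuming JEPAConfig is default-constructible, as the container description above suggests):

from world_models.configs import JEPAConfig

cfg = JEPAConfig()
nested = cfg.to_dict()  # nested dict consumed by train_jepa.main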

class world_models.configs.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#

Bases: object

Default configuration values for Diffusion Transformer (DiT) training.

The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.

Parameters:
  • DATASET (str)

  • BATCH (int)

  • EPOCHS (int)

  • LR (float)

  • IMG_SIZE (int)

  • CHANNELS (int)

  • PATCH (int)

  • WIDTH (int)

  • DEPTH (int)

  • HEADS (int)

  • DROP (float)

  • BETA_START (float)

  • BETA_END (float)

  • TIMESTEPS (int)

  • EMA (bool)

  • EMA_DECAY (float)

  • WORKDIR (str)

  • ROOT_PATH (str)

DATASET: str = 'CIFAR10'#
BATCH: int = 128#
EPOCHS: int = 3#
LR: float = 0.0002#
IMG_SIZE: int = 32#
CHANNELS: int = 3#
PATCH: int = 4#
WIDTH: int = 384#
DEPTH: int = 6#
HEADS: int = 6#
DROP: float = 0.1#
BETA_START: float = 0.0001#
BETA_END: float = 0.02#
TIMESTEPS: int = 1000#
EMA: bool = True#
EMA_DECAY: float = 0.999#
WORKDIR: str = './dit_demo'#
ROOT_PATH: str = './data'#
world_models.configs.get_dit_config(**overrides)[source]#

Returns a DiTConfig instance with default values overridden by the provided keyword arguments.

Example usage:

cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)

class world_models.configs.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, device: str = 'cuda', log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, num_seeds: int = 5, seed: int = 0)[source]#

Bases: object

Parameters:
  • preset (str | None)

  • game (str)

  • obs_size (int)

  • frameskip (int)

  • max_noop (int)

  • terminate_on_life_loss (bool)

  • reward_clip (List[int])

  • num_conditioning_frames (int)

  • diffusion_channels (List[int])

  • diffusion_res_blocks (int)

  • diffusion_cond_dim (int)

  • sigma_data (float)

  • sigma_min (float)

  • sigma_max (float)

  • rho (int)

  • p_mean (float)

  • p_std (float)

  • sampling_method (str)

  • num_sampling_steps (int)

  • reward_channels (List[int])

  • reward_res_blocks (int)

  • reward_cond_dim (int)

  • reward_lstm_dim (int)

  • burn_in_length (int)

  • actor_channels (List[int])

  • actor_res_blocks (int)

  • actor_lstm_dim (int)

  • num_epochs (int)

  • training_steps_per_epoch (int)

  • batch_size (int)

  • environment_steps_per_epoch (int)

  • epsilon_greedy (float)

  • imagination_horizon (int)

  • discount_factor (float)

  • entropy_weight (float)

  • lambda_returns (float)

  • learning_rate (float)

  • adam_epsilon (float)

  • weight_decay_diffusion (float)

  • weight_decay_reward (float)

  • weight_decay_actor (float)

  • device (str)

  • log_interval (int)

  • eval_interval (int)

  • save_interval (int)

  • num_seeds (int)

  • seed (int)

preset: str | None = None#
game: str = 'Breakout-v5'#
obs_size: int = 64#
frameskip: int = 4#
max_noop: int = 30#
terminate_on_life_loss: bool = True#
reward_clip: List[int]#
num_conditioning_frames: int = 4#
diffusion_channels: List[int]#
diffusion_res_blocks: int = 2#
diffusion_cond_dim: int = 256#
sigma_data: float = 0.5#
sigma_min: float = 0.002#
sigma_max: float = 80.0#
rho: int = 7#
p_mean: float = -0.4#
p_std: float = 1.2#
sampling_method: str = 'euler'#
num_sampling_steps: int = 3#
reward_channels: List[int]#
reward_res_blocks: int = 2#
reward_cond_dim: int = 128#
reward_lstm_dim: int = 512#
burn_in_length: int = 4#
actor_channels: List[int]#
actor_res_blocks: int = 1#
actor_lstm_dim: int = 512#
num_epochs: int = 1000#
training_steps_per_epoch: int = 400#
batch_size: int = 32#
environment_steps_per_epoch: int = 100#
epsilon_greedy: float = 0.01#
imagination_horizon: int = 15#
discount_factor: float = 0.985#
entropy_weight: float = 0.001#
lambda_returns: float = 0.95#
learning_rate: float = 0.0001#
adam_epsilon: float = 1e-08#
weight_decay_diffusion: float = 0.01#
weight_decay_reward: float = 0.01#
weight_decay_actor: float = 0.0#
device: str = 'cuda'#
log_interval: int = 10#
eval_interval: int = 50#
save_interval: int = 100#
num_seeds: int = 5#
seed: int = 0#
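
Since DiamondConfig exposes keyword defaults, individual fields can be overridden at construction; a small sketch:

from world_models.configs import DiamondConfig

cfg = DiamondConfig(game='Breakout-v5', batch_size=16, device='cpu')
assert cfg.num_sampling_steps == 3  # unchanged default
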
world_models.envs.make_atari_env(env_id, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, **kwargs)[source]#

Create any Atari environment from the Arcade Learning Environment (ALE).

Parameters:
  • env_id (str) – The id of the Atari environment to create.

  • obs_type (str) – The type of observation to return (“rgb” or “ram”).

  • frameskip (int) – The number of frames to skip between actions.

  • repeat_action_probability (float) – The probability of repeating the last action.

  • full_action_space (bool) – Whether to use the full action space.

  • max_episode_steps (Optional[int]) – Maximum number of steps per episode.

  • **kwargs – Additional keyword arguments for environment configuration.

Returns:

The created Atari environment.

Return type:

gym.Env
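
Example usage with a standard interaction loop (a sketch assuming the Gymnasium five-tuple step API; the env_id is illustrative):

from world_models.envs import make_atari_env

env = make_atari_env('ALE/Pong-v5', obs_type='rgb', frameskip=4)
obs, info = env.reset()
for _ in range(100):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        obs, info = env.reset()
env.close()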

world_models.envs.list_available_atari_envs()[source]#

Get a list of all available Atari environments in the Arcade Learning Environment (ALE).

Returns:

List of available Atari environment IDs.

Return type:

list[str]

world_models.envs.make_atari_vector_env(game, num_envs, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, seed=None, **kwargs)[source]#

Create vectorized Atari environments using ALE’s native AtariVectorEnv.

Parameters:
  • game (str) – The name of the Atari game (e.g., “pong”, “breakout”).

  • num_envs (int) – Number of parallel environments.

  • obs_type (str) – The type of observation to return (“rgb” or “ram”).

  • frameskip (int) – The number of frames to skip between actions.

  • repeat_action_probability (float) – The probability of repeating the last action.

  • full_action_space (bool) – Whether to use the full action space.

  • max_episode_steps (Optional[int]) – Maximum number of steps per episode.

  • seed (Optional[int]) – Random seed for reproducibility.

  • **kwargs – Additional keyword arguments for environment configuration.

Returns:

The vectorized Atari environment.

Return type:

AtariVectorEnv

world_models.envs.make_humanoid_env(version='v4', xml_file='humanoid.xml', forward_reward_weight=1.25, ctrl_cost_weight=0.1, contact_cost_weight=5e-07, healthy_reward=5.0, terminate_when_unhealthy=True, healthy_z_range=(1.0, 2.0), reset_noise_scale=0.01, exclude_current_positions_from_observation=True, include_cinert_in_observation=True, include_cvel_in_observation=True, include_qfrc_actuator_in_observation=True, include_cfrc_ext_in_observation=True)[source]#

Create a Humanoid environment with customizable parameters.

Parameters:
  • version (str) – The version of the Humanoid environment (e.g., “v4”).

  • xml_file (str) – The XML file defining the Humanoid model.

  • forward_reward_weight (float) – Weight for the forward reward.

  • ctrl_cost_weight (float) – Weight for the control cost.

  • contact_cost_weight (float) – Weight for the contact cost.

  • healthy_reward (float) – Reward for being in a healthy state.

  • terminate_when_unhealthy (bool) – Whether to terminate the episode when unhealthy.

  • healthy_z_range (Tuple[float, float]) – The range of z-values considered healthy.

  • reset_noise_scale (float) – Scale of noise added during environment reset.

  • exclude_current_positions_from_observation (bool) – Whether to exclude current positions from observations.

  • include_cinert_in_observation (bool) – Whether to include inertia in observations.

  • include_cvel_in_observation (bool) – Whether to include velocity in observations.

  • include_qfrc_actuator_in_observation (bool) – Whether to include actuator forces in observations.

  • include_cfrc_ext_in_observation (bool) – Whether to include external forces in observations.

Returns:

The created Humanoid environment.

Return type:

gym.Env

world_models.envs.make_half_cheetah_env(version='v4', forward_reward_weight=0.1, reset_noise_scale=0.1, exclude_current_positions_from_observation=True, render_mode='rgb_array')[source]#

Create a HalfCheetah environment with customizable parameters.

Parameters:
  • version (str) – The version of the HalfCheetah environment (e.g., “v4”).

  • forward_reward_weight (float) – Weight for the forward reward.

  • reset_noise_scale (float) – Scale of noise added during environment reset.

  • exclude_current_positions_from_observation (bool) – Whether to exclude current positions from observations.

  • render_mode (str) – The render mode for the environment.

Returns:

The created HalfCheetah environment.

Return type:

gym.Env
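
Both MuJoCo helpers return ready-to-step environments; a brief sketch (again assuming the Gymnasium reset/step API):

from world_models.envs import make_half_cheetah_env

env = make_half_cheetah_env(version='v4', forward_reward_weight=0.1)
obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
frame = env.render()  # rgb_array frame, per the default render_mode
env.close()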

class world_models.envs.GymImageEnv(env, seed=0, size=(64, 64), render_mode='rgb_array')[source]#

Bases: object

Gym-like environment wrapper that always returns image observations.

  • Supports env IDs (string) and prebuilt env objects.

  • For vector observations, it synthesizes an RGB image so pixel-based world models can still train.

  • For discrete actions, it exposes a vector action space and maps by argmax.

property observation_space#
property action_space#
property max_episode_steps#
reset()[source]#
step(action)[source]#
render(*args, **kwargs)[source]#
close()[source]#
world_models.envs.make_gym_env(env, **kwargs)[source]#

Factory helper for generic Gym/Gymnasium environments.

class world_models.envs.UnityMLAgentsEnv(file_name, behavior_name=None, seed=0, size=(64, 64), worker_id=0, base_port=5005, no_graphics=True, time_scale=20.0, quality_level=1, max_episode_steps=1000)[source]#

Bases: object

Gym-like wrapper for a Unity ML-Agents environment.

Notes:
  • Supports single-agent control.

  • Supports continuous action spaces.

  • Returns channel-first uint8 images in obs[“image”] for Dreamer-style pipelines.

property observation_space#
property action_space#
property max_episode_steps#
reset()[source]#
step(action)[source]#
render(*args, **kwargs)[source]#
close()[source]#
world_models.envs.make_unity_mlagents_env(**kwargs)[source]#

Factory helper for Unity ML-Agents environments.

class world_models.envs.DeepMindControlEnv(name, seed, size=(64, 64), camera=None)[source]#

Bases: object

Gym-style adapter for DeepMind Control Suite tasks.

The wrapper exposes DMC observations and actions through Gym spaces and adds a rendered RGB image to each observation dict so image-based world model pipelines can train consistently across backends.

property observation_space#
property action_space#
step(action)[source]#
reset()[source]#
render(*args, **kwargs)[source]#
class world_models.envs.TimeLimit(env, duration)[source]#

Bases: object

Terminate episodes after a fixed number of wrapper steps.

If the wrapped environment does not provide a discount flag at timeout, the wrapper injects a default discount of 1.0 for downstream learners.

step(action)[source]#
reset()[source]#
class world_models.envs.ActionRepeat(env, amount)[source]#

Bases: object

Repeat each action for a fixed number of environment steps.

Rewards are accumulated and the loop stops early if the environment terminates, mirroring common action-repeat behavior in world model papers.

step(action)[source]#
class world_models.envs.NormalizeActions(env)[source]#

Bases: object

Expose a normalized [-1, 1] action space for bounded continuous controls.

Incoming normalized actions are mapped back to the wrapped environment action bounds before stepping the environment.

property action_space#
step(action)[source]#
class world_models.envs.ObsDict(env, key='obs')[source]#

Bases: object

Convert scalar/array observations into a dictionary observation format.

This harmonizes outputs for code paths that expect keyed observations (for example {“image”: …} style world model inputs).

property observation_space#
property action_space#
step(action)[source]#
reset()[source]#
class world_models.envs.OneHotAction(env)[source]#

Bases: object

Wrap discrete-action environments to accept one-hot action vectors.

The wrapper validates one-hot inputs and converts them to integer action indices before forwarding to the underlying environment.

property action_space#
step(action)[source]#
reset()[source]#
class world_models.envs.RewardObs(env)[source]#

Bases: object

Augment observations with the latest scalar reward under obs[“reward”].

Useful for agents that consume reward as part of the observation stream during model learning or recurrent policy inference.

property observation_space#
step(action)[source]#
reset()[source]#
class world_models.envs.ResizeImage(env, size=(64, 64))[source]#

Bases: object

Resize image-like observation entries to a target spatial size.

The wrapper discovers image keys from env.obs_space, applies nearest neighbor resizing, and updates the advertised observation space shapes.

property obs_space#
step(action)[source]#
reset()[source]#
class world_models.envs.RenderImage(env, key='image')[source]#

Bases: object

Inject RGB renders from env.render(“rgb_array”) into observations.

This is useful when the base environment returns non-image observations but a rendered camera view is needed for world-model training.

property obs_space#
step(action)[source]#
reset()[source]#
class world_models.envs.SelectAction(env, key)[source]#

Bases: Wrapper

Gym wrapper for dictionary actions that forwards a selected key only.

This enables integration with policies that emit action dicts while the environment expects a single tensor/array action payload.

step(action)[source]#
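
The wrappers above compose in the usual Dreamer order; a hedged sketch (the DMC task name format is an assumption):

from world_models.envs import (ActionRepeat, DeepMindControlEnv,
                               NormalizeActions, RewardObs, TimeLimit)

env = DeepMindControlEnv('walker_walk', seed=0, size=(64, 64))
env = ActionRepeat(env, amount=2)   # repeat each action twice
env = NormalizeActions(env)         # expose a [-1, 1] action space
env = TimeLimit(env, duration=500)  # cap wrapper steps per episode
env = RewardObs(env)                # add obs['reward'] to observations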

Dreamer#

world_models.models.dreamer.make_env(args)[source]#

Construct a Dreamer-compatible environment from DreamerConfig options.

Supports DMC, Gym/Gymnasium, and Unity ML-Agents backends and applies the standard wrapper stack: action repeat, action normalization, and time limit.

world_models.models.dreamer.preprocess_obs(obs)[source]#

Convert raw uint8 image observations to Dreamer float input space.

Images are scaled from [0, 255] to roughly [-0.5, 0.5], matching the normalization expected by Dreamer encoders.

class world_models.models.dreamer.Dreamer(args, obs_shape, action_size, device, restore=False)[source]#

Bases: object

Core Dreamer training system combining world model, actor, and value nets.

This class owns model construction, replay sampling, imagination rollouts, loss computation, optimization steps, evaluation loops, and checkpoint I/O.

world_model_loss(obs, acs, rews, nonterms)[source]#
actor_loss()[source]#
value_loss()[source]#
train_one_batch()[source]#
act_with_world_model(obs, prev_state, prev_action, explore=False)[source]#
act_and_collect_data(env, collect_steps)[source]#
evaluate(env, eval_episodes, render=False)[source]#
collect_random_episodes(env, seed_steps)[source]#
save(save_path)[source]#
restore_checkpoint(ckpt_path)[source]#
class world_models.models.dreamer.DreamerAgent(config=None, **kwargs)[source]#

Bases: object

High-level user API for running Dreamer experiments end to end.

It builds environments from config, initializes seeds and logging, instantiates Dreamer, and exposes simple train() / evaluate() methods.

train(total_steps=None)[source]#
evaluate()[source]#
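
End-to-end usage sketch (config fields are left at their defaults; total_steps is illustrative):

from world_models.configs import DreamerConfig
from world_models.models.dreamer import DreamerAgent

agent = DreamerAgent(config=DreamerConfig())
agent.train(total_steps=100_000)  # builds envs, collects data, optimizes
agent.evaluate()
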
class world_models.models.dreamer_rssm.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)[source]#

Bases: Module

Recurrent State-Space Model used by Dreamer latent dynamics learning.

The RSSM maintains deterministic recurrent state and stochastic latent state, and provides transition/posterior updates plus rollout helpers.

init_state(batch_size, device)[source]#
get_dist(mean, std)[source]#
observe_step(prev_state, prev_action, obs_embed, nonterm=1.0)[source]#
imagine_step(prev_state, prev_action, nonterm=1.0)[source]#
observe_rollout(obs_embed, actions, nonterms, prev_state, horizon)[source]#
imagine_rollout(actor, prev_state, horizon)[source]#
stack_states(states, dim=0)[source]#
detach_state(state)[source]#
seq_to_batch(state)[source]#
class world_models.vision.dreamer_encoder.ConvEncoder(input_shape, embed_size, activation, depth=32)[source]#

Bases: Module

Convolutional observation encoder used by Dreamer world models.

Encodes image observations into compact embeddings consumed by the RSSM posterior update network.

forward(inputs)[source]#
class world_models.vision.dreamer_decoder.TanhBijector[source]#

Bases: Transform

Bijective tanh transform used to squash Gaussian actions to [-1, 1].

Provides stable inverse and log-determinant Jacobian computations for transformed-action distributions.

property sign#
atanh(x)[source]#
log_abs_det_jacobian(x, y)[source]#
class world_models.vision.dreamer_decoder.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)[source]#

Bases: Module

Convolutional Dreamer decoder from latent features to image distributions.

Maps concatenated stochastic/deterministic states into transposed-conv outputs and returns a factorized Normal distribution over pixels.

forward(features)[source]#
class world_models.vision.dreamer_decoder.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist)[source]#

Bases: Module

MLP decoder for reward/value/discount prediction from latent features.

The output distribution type is configurable (normal, binary, or raw tensor).

forward(features)[source]#
class world_models.vision.dreamer_decoder.SampleDist(dist, samples=100)[source]#

Bases: object

Distribution wrapper that estimates statistics via Monte Carlo sampling.

Provides approximated mean, mode, and entropy helpers for transformed distributions where analytic forms may be inconvenient.

property name#
mean()[source]#
mode()[source]#
entropy()[source]#
sample()[source]#
class world_models.vision.dreamer_decoder.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)[source]#

Bases: Module

Dreamer actor head producing squashed continuous actions from latent features.

Outputs a transformed Gaussian policy with optional deterministic mode and utility for additive exploration noise.

forward(features, deter=False)[source]#
add_exploration(action, action_noise=0.3)[source]#

JEPA and ViT#

world_models.models.vit.get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)[source]#

Generate fixed 2D sine/cosine positional embeddings on a square patch grid.

Returns NumPy embeddings used to initialize non-trainable transformer position encodings, with optional prepended class-token embedding.

world_models.models.vit.get_2d_sincos_pos_embed_from_grid(embed_dim, grid)[source]#

Build 2D sine/cosine embeddings from precomputed meshgrid coordinates.

The final embedding concatenates independent encodings for vertical and horizontal coordinates.

world_models.models.vit.get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)[source]#

Generate 1D sine/cosine positional embeddings for integer positions.

Useful for sequence-style positional encoding and as a building block for 2D embedding construction.

world_models.models.vit.get_1d_sincos_pos_embed_from_grid(embed_dim, pos)[source]#

Generate 1D sine/cosine positional embeddings from explicit positions.

Positions are projected onto a log-frequency basis and encoded with sine and cosine components.

world_models.models.vit.drop_path(x, drop_prob=0.0, training=False)[source]#

Apply stochastic depth (DropPath) regularization to residual branches.

Randomly drops entire residual paths per sample during training and scales the surviving activations to preserve expected magnitude.

Parameters:
  • drop_prob (float)

  • training (bool)

class world_models.models.vit.DropPath(drop_prob=None)[source]#

Bases: Module

Module wrapper around the functional drop_path stochastic depth utility.

forward(x)[source]#
class world_models.models.vit.MLP(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>, drop=0.0)[source]#

Bases: Module

Two-layer feed-forward network used inside transformer blocks.

Applies linear projection, activation, dropout, and output projection in the standard Vision Transformer MLP pattern.

forward(x)[source]#
class world_models.models.vit.Attention(dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)[source]#

Bases: Module

Multi-head self-attention block for token sequences.

Computes QKV projections, scaled dot-product attention, and output projection with configurable dropout.

forward(x)[source]#
class world_models.models.vit.Block(dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]#

Bases: Module

Transformer encoder block combining attention and MLP residual branches.

Each branch uses pre-normalization and optional stochastic depth.

forward(x)[source]#
class world_models.models.vit.PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)[source]#

Bases: Module

Image to Patch Embedding

forward(x)[source]#
class world_models.models.vit.ConvEmbed(channels, strides, img_size=224, in_chans=3, batch_norm=True)[source]#

Bases: Module

3x3 convolutional stem for ViT, following the ViTC models.

forward(x)[source]#
class world_models.models.vit.VisionTransformerPredictor(num_patches, embed_dim=768, predictor_embed_dim=384, depth=6, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, init_std=0.02, **kwargs)[source]#

Bases: Module

Vision Transformer predictor.

fix_init_weight()[source]#
forward(x, masks_x, masks)[source]#
class world_models.models.vit.VisionTransformer(img_size=[224], patch_size=16, in_chans=3, embed_dim=768, predictor_embed_dim=384, depth=12, predictor_depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, init_std=0.02, **kwargs)[source]#

Bases: Module

Vision Transformer

fix_init_weight()[source]#
forward(x, masks=None)[source]#
interpolate_pos_encoding(x, pos_embed)[source]#
world_models.models.vit.vit_predictor(**kwargs)[source]#

Factory for a JEPA predictor transformer with sensible defaults.

world_models.models.vit.vit_tiny(patch_size=16, **kwargs)[source]#

Factory for a tiny Vision Transformer encoder backbone.

world_models.models.vit.vit_small(patch_size=16, **kwargs)[source]#

Factory for a small Vision Transformer encoder backbone.

world_models.models.vit.vit_base(patch_size=16, **kwargs)[source]#

Factory for a base Vision Transformer encoder backbone.

world_models.models.vit.vit_large(patch_size=16, **kwargs)[source]#

Factory for a large Vision Transformer encoder backbone.

world_models.models.vit.vit_huge(patch_size=16, **kwargs)[source]#

Factory for a huge Vision Transformer encoder backbone.

world_models.models.vit.vit_giant(patch_size=16, **kwargs)[source]#

Factory for a giant Vision Transformer encoder backbone.

class world_models.masks.multiblock.MaskCollator(input_size=(224, 224), patch_size=16, enc_mask_scale=(0.2, 0.8), pred_mask_scale=(0.5, 1.0), aspect_ratio=(0.3, 3.0), nenc=1, npred=2, min_keep=4, allow_overlap=False)[source]#

Bases: object

Generate multi-block encoder and predictor masks for JEPA training.

For each sample, this collator samples predictor target blocks and context encoder blocks (optionally non-overlapping), then returns masked patch indices aligned across the batch.

step()[source]#
class world_models.masks.random.MaskCollator(ratio=(0.4, 0.6), input_size=(224, 224), patch_size=16)[source]#

Bases: object

Generate random context/prediction patch splits for masked training.

A random permutation of patch indices is sampled per image; a configurable fraction is assigned to context and the remainder to prediction targets.

step()[source]#
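
Either collator is intended to be passed as a dataloader collate function; a sketch with the multiblock variant (dataset stands in for any torchvision-style image dataset, and the batch unpacking is an assumption):

import torch
from world_models.masks.multiblock import MaskCollator

collator = MaskCollator(input_size=(224, 224), patch_size=16,
                        nenc=1, npred=2, allow_overlap=False)
loader = torch.utils.data.DataLoader(dataset, batch_size=64,
                                     collate_fn=collator)
for batch, masks_enc, masks_pred in loader:  # assumed unpacking
    ...  # feed context/target masks to the JEPA encoder/predictor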

IRIS (Sample-Efficient World Models)#

IRIS implements “Transformers are Sample-Efficient World Models”, a method that achieves human-level performance on Atari with only 100k environment steps (~2 hours of gameplay) by learning entirely in the imagination of a world model.

Architecture:
  • Discrete autoencoder (VQVAE) compresses frames to tokens

  • Autoregressive Transformer models dynamics

  • Actor-Critic trains entirely in imagined trajectories

class world_models.configs.iris_config.IRISConfig[source]#

Bases: object

Configuration for IRIS (Imagination with auto-Regression over an Inner Speech)

Based on the paper “Transformers are Sample-Efficient World Models”. Implements a discrete autoencoder + autoregressive Transformer for sample-efficient RL.

get_frame_shape()[source]#
get_autoencoder_config()[source]#
get_transformer_config()[source]#
get_rl_config()[source]#
world_models.models.iris_agent.compute_lambda_return(rewards, values, discounts, lambda_coef=0.95)[source]#

Compute λ-return target for value function training.

Parameters:
  • rewards (Tensor) – Rewards (B, T)

  • values (Tensor) – Value estimates (B, T+1)

  • discounts (Tensor) – Discount factors (B, T)

  • lambda_coef (float) – Lambda parameter for bootstrapping

Returns:

lambda_returns – λ-return targets (B, T)

Return type:

Tensor
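
For reference, the computation follows the standard λ-return backward recursion (a reconstruction consistent with the argument shapes above, with d_t the per-step discounts):

G_t^λ = r_t + d_t [ (1 - λ) V_{t+1} + λ G_{t+1}^λ ],  bootstrapped at the horizon with G_{T-1}^λ = r_{T-1} + d_{T-1} V_T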

class world_models.models.iris_agent.IRISAgent(config, action_size, device)[source]#

Bases: Module

Complete IRIS Agent with world model and policy.

Combines:
  • Discrete autoencoder (encoder + decoder)

  • Transformer world model

  • Actor-Critic for policy and value learning

Parameters:
  • config (IRISConfig)

  • action_size (int)

  • device (device)

forward_actor_critic(frames, hidden=None)[source]#

Forward pass through actor-critic.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W)

  • hidden (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state

Returns:
  • action_logits – Action logits (B, T, action_size)

  • values – Value estimates (B, T)

  • hidden_state – Updated LSTM state (h, c)

act(frame, epsilon=0.0, temperature=1.0)[source]#

Sample action from policy.

Parameters:
  • frame (Tensor) – Single frame (B, C, H, W)

  • epsilon (float) – Random action probability

  • temperature (float) – Action distribution temperature

Returns:

actions – Selected actions (B,)

Return type:

Tensor

imagine_rollout(initial_frame, horizon=20)[source]#

Generate imagined trajectories using world model.

Parameters:
  • initial_frame (Tensor) – Starting frame (B, C, H, W)

  • horizon (int) – Number of steps to imagine

Returns:

trajectory – Dictionary with imagined rollout data

Return type:

dict

update_autoencoder(frames)[source]#

Update discrete autoencoder.

Parameters:

frames (Tensor) – Training frames (B, C, H, W)

Returns:

losses – Dictionary of loss values

Return type:

dict

update_transformer(frames, actions, rewards, terminals)[source]#

Update transformer world model.

Parameters:
  • frames (Tensor) – Frame sequence

  • actions (Tensor) – Actions taken

  • rewards (Tensor) – Rewards received

  • terminals (Tensor) – Terminal flags

Returns:

losses – Dictionary of loss values

Return type:

dict

update_actor_critic(imagined_trajectory)[source]#

Update actor-critic in imagination.

Parameters:

imagined_trajectory (dict) – Dictionary from imagine_rollout

Returns:

losses – Dictionary of loss values

Return type:

dict

save(path)[source]#

Save agent state.

Parameters:

path (str)

load(path)[source]#

Load agent state.

Parameters:

path (str)
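
Putting the agent-level API together (a sketch: construction values and the zero frame are illustrative; shapes follow the docstrings above):

import torch
from world_models.configs.iris_config import IRISConfig
from world_models.models.iris_agent import IRISAgent

agent = IRISAgent(IRISConfig(), action_size=6, device=torch.device('cpu'))
frame = torch.zeros(1, 3, 64, 64)         # single 64x64 RGB frame
actions = agent.act(frame, epsilon=0.05)  # (B,) action indices
trajectory = agent.imagine_rollout(frame, horizon=20)
losses = agent.update_actor_critic(trajectory)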

class world_models.models.iris_transformer.IRISTransformer(vocab_size=512, tokens_per_frame=16, action_size=18, embed_dim=256, num_layers=10, num_heads=4, dropout=0.1)[source]#

Bases: Module

Autoregressive Transformer for world modeling.

Models the dynamics of the environment by predicting:
  • Next frame tokens (transition model)

  • Rewards

  • Episode termination

The Transformer operates on sequences of interleaved frame tokens and actions.

Parameters:
  • vocab_size (int)

  • tokens_per_frame (int)

  • action_size (int)

  • embed_dim (int)

  • num_layers (int)

  • num_heads (int)

  • dropout (float)

forward(tokens, actions, mask=None)[source]#

Forward pass through the Transformer world model.

Parameters:
  • tokens (Tensor) – Frame tokens (B, T, K) where T is timesteps

  • actions (Tensor) – Actions (B, T)

  • mask (Tensor | None) – Optional attention mask

Returns:
  • token_logits – Next token predictions (B, T, K, vocab_size)

  • rewards – Predicted rewards (B, T)

  • terminations – Predicted terminations (B, T, 2)

predict_next_tokens(tokens, actions)[source]#

Predict the next frame tokens autoregressively.

Used during imagination rollouts.

Parameters:
  • tokens (Tensor) – Current frame tokens (B, K)

  • actions (Tensor) – Actions taken (B,)

Returns:
  • token_logits – Next frame token predictions (B, K, vocab_size)

  • action_hidden – Hidden states for reward prediction (B, embed_dim)

sample_next_tokens(tokens, actions, temperature=1.0)[source]#

Sample next tokens from the distribution.

Parameters:
  • tokens (Tensor) – Current frame tokens (B, K)

  • actions (Tensor) – Actions taken (B,)

  • temperature (float) – Sampling temperature (higher = more random)

Returns:
  • sampled_tokens – Sampled token indices (B, K)

  • log_probs – Log probabilities of sampled tokens (B, K)

class world_models.models.iris_transformer.IRISWorldModel(encoder, decoder, transformer)[source]#

Bases: Module

Complete IRIS World Model combining autoencoder and transformer.

This is the core component that learns environment dynamics entirely in the “imaginary” latent space.

Parameters:
  • encoder (Module)

  • decoder (Module)

  • transformer (Module)

forward(observations, actions)[source]#

Full world model forward pass.

Parameters:
  • observations (Tensor) – Image sequence (B, T+1, C, H, W)

  • actions (Tensor) – Actions (B, T)

Returns:
  • predictions – Dictionary with predicted tokens, rewards, terminations

  • losses – Dictionary with loss components

imagine(initial_tokens, policy, horizon=20, temperature=1.0)[source]#

Generate imagined trajectories.

Parameters:
  • initial_tokens (Tensor) – Initial frame tokens (B, K)

  • policy (Module) – Policy network to sample actions

  • horizon (int) – Number of steps to imagine

  • temperature (float) – Sampling temperature for token prediction

Returns:

imagined – Dictionary with imagined trajectories

Return type:

dict

class world_models.vision.iris_encoder.IRISEncoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, in_channels=3, base_channels=64, num_residual_blocks=2, frame_shape=(3, 64, 64))[source]#

Bases: Module

CNN Encoder for IRIS discrete autoencoder.

Encodes image observations into latent features, which are then quantized into discrete tokens using the VectorQuantizer.

Architecture:
  • 4 convolutional layers with residual blocks

  • Self-attention at 8x8 and 16x16 resolutions

  • Vector quantization to produce discrete tokens

Parameters:
  • vocab_size (int)

  • tokens_per_frame (int)

  • embedding_dim (int)

  • in_channels (int)

  • base_channels (int)

  • num_residual_blocks (int)

  • frame_shape (Tuple[int, int, int])

forward(x)[source]#

Encode images to discrete tokens.

Parameters:

x (Tensor) – Input images (B, C, H, W) - should be 64x64

Returns:
  • z_q – Quantized tokens (B, C, H’, W’)

  • indices – Token indices (B, H’, W’)

  • vq_loss – Dictionary with VQ loss components

encode_to_indices(x)[source]#

Encode directly to token indices (for world model).

Parameters:

x (Tensor)

Return type:

Tensor

decode_from_indices(indices)[source]#

Decode token indices to embeddings (for decoder).

Parameters:

indices (Tensor)

Return type:

Tensor

class world_models.vision.iris_encoder.ResidualBlock(channels)[source]#

Bases: Module

Residual block for encoder.

Parameters:

channels (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.vision.iris_encoder.SelfAttentionBlock(channels)[source]#

Bases: Module

Self-attention block for encoder.

Applies spatial self-attention to capture long-range dependencies.

Parameters:

channels (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.vision.iris_decoder.IRISDecoder(vocab_size=512, embedding_dim=512, base_channels=32, out_channels=3, frame_shape=(3, 64, 64))[source]#

Bases: Module

CNN Decoder for IRIS discrete autoencoder.

Decodes discrete tokens back into image observations. Uses transposed convolutions to upsample from 4x4 to 64x64.

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • base_channels (int)

  • out_channels (int)

  • frame_shape (Tuple[int, int, int])

forward(z)[source]#

Decode tokens to images.

Parameters:

z (Tensor) – Token embeddings (B, C, H, W) - e.g., (B, 512, 4, 4)

Returns:

reconstructed – Reconstructed images (B, C, H, W) - e.g., (B, 3, 64, 64)

Return type:

Tensor

decode_from_embeddings(z_flat)[source]#

Decode flattened token embeddings to images.

Parameters:

z_flat (Tensor) – Flattened tokens (B, H*W, C) or (B, C, H, W)

Returns:

Reconstructed images

Return type:

Tensor

class world_models.vision.iris_decoder.UpsampleBlock(in_channels, mid_channels, out_channels)[source]#

Bases: Module

Upsampling block with optional residual connection.

Parameters:
  • in_channels (int)

  • mid_channels (int)

  • out_channels (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.vision.iris_decoder.ResidualBlock(channels)[source]#

Bases: Module

Residual block for decoder.

Parameters:

channels (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.vision.iris_decoder.DiscreteAutoencoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, base_channels=64, frame_shape=(3, 64, 64))[source]#

Bases: Module

Complete Discrete Autoencoder combining encoder and decoder.

Used for training the VQVAE component of IRIS.

Parameters:
  • vocab_size (int)

  • tokens_per_frame (int)

  • embedding_dim (int)

  • base_channels (int)

  • frame_shape (Tuple[int, int, int])

forward(x)[source]#

Full encode-decode forward pass.

Parameters:

x (Tensor) – Input images (B, C, H, W)

Returns:
  • reconstruction – Reconstructed images

  • indices – Token indices (B, H’, W’)

  • loss_dict – Dictionary with loss components

encode(x)[source]#

Encode to token indices.

Parameters:

x (Tensor)

Return type:

Tensor

decode(indices)[source]#

Decode token indices to images.

Parameters:

indices (Tensor)

Return type:

Tensor
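
A round-trip sketch through the discrete autoencoder (shapes per the docstrings above):

import torch
from world_models.vision.iris_decoder import DiscreteAutoencoder

ae = DiscreteAutoencoder(vocab_size=512, tokens_per_frame=16,
                         embedding_dim=512, frame_shape=(3, 64, 64))
x = torch.rand(2, 3, 64, 64)                # batch of images
reconstruction, indices, loss_dict = ae(x)  # full encode-decode pass
tokens = ae.encode(x)                       # token indices only
decoded = ae.decode(tokens)                 # images back from indices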

class world_models.vision.vq_layer.VectorQuantizer(vocab_size=512, embedding_dim=512, commitment_weight=0.25)[source]#

Bases: Module

Vector Quantizer for discrete autoencoder.

Implements the VQ-VAE quantization from: “Neural Discrete Representation Learning” (Van Den Oord et al., 2017)

Uses exponential moving averages for codebook updates and straight-through estimator for gradient flow.

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

forward(z)[source]#

Quantize the input latents.

Parameters:

z (Tensor) – Input tensor of shape (B, C, H, W) or (B, C)

Returns:
  • z_q – Quantized tensor (same shape as input)

  • indices – Token indices for each position (B, H, W) or (B,)

  • loss – Dictionary containing VQ loss components

decode_indices(indices)[source]#

Decode token indices back to embeddings.

Parameters:

indices (Tensor) – Token indices (B, H, W) or (B,)

Returns:

Embeddings (B, C, H, W) or (B, C)

Return type:

Tensor

class world_models.vision.vq_layer.VectorQuantizerEMA(vocab_size=512, embedding_dim=512, commitment_weight=0.25, ema_decay=0.99, epsilon=1e-05)[source]#

Bases: Module

Vector Quantizer with Exponential Moving Average updates.

Uses EMA updates for the codebook instead of gradient-based updates, which leads to more stable training.

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • ema_decay (float)

  • epsilon (float)

forward(z)[source]#

Quantize with EMA updates.

Parameters:

z (Tensor)

Return type:

Tuple[Tensor, Tensor, dict]

decode_indices(indices)[source]#

Decode token indices to embeddings.

Parameters:

indices (Tensor) – Token indices (B, H, W) or (B,)

Returns:

Embeddings (B, C, H, W) or (B, C)

Return type:

Tensor

class world_models.memory.iris_memory.IRISReplayBuffer(size, obs_shape, action_size, seq_len=20, batch_size=64)[source]#

Bases: object

Replay buffer for IRIS training.

Stores (observation, action, reward, terminal) tuples and supports sampling contiguous sequences used by the world model and actor-critic.

Parameters:
  • size (int)

  • obs_shape (Tuple[int, int, int])

  • action_size (int)

  • seq_len (int)

  • batch_size (int)

add(obs, action, reward, terminal)[source]#

Add a transition to the buffer.

Parameters:
  • obs (ndarray)

  • action (ndarray)

  • reward (float)

  • terminal (bool)

sample_sequence(seq_len=None)[source]#

Sample a batch of sequences for world model training.

Parameters:

seq_len (int | None)

Returns:
  • observations – (batch_size, seq_len+1, C, H, W)

  • actions – (batch_size, seq_len, action_size)

  • rewards – (batch_size, seq_len)

  • terminals – (batch_size, seq_len)

sample_single()[source]#

Sample a single transition for online updates.

Return type:

Tuple[ndarray, ndarray, float, float]

property buffer_capacity#

Returns the total capacity of the buffer.
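
A usage sketch for the replay buffer (the action array layout is an assumption consistent with action_size; the sample unpacking follows sample_sequence's documented returns):

import numpy as np
from world_models.memory.iris_memory import IRISReplayBuffer

buffer = IRISReplayBuffer(size=10_000, obs_shape=(3, 64, 64),
                          action_size=6, seq_len=20, batch_size=64)
obs = np.zeros((3, 64, 64), dtype=np.uint8)
action = np.zeros(6, dtype=np.float32)  # e.g. a one-hot action vector
buffer.add(obs, action, reward=0.0, terminal=False)
observations, actions, rewards, terminals = buffer.sample_sequence()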

class world_models.memory.iris_memory.IRISOnPolicyBuffer(max_steps=1000)[source]#

Bases: object

Buffer for collecting trajectories during environment interaction.

Used to store current episode data before adding to main replay buffer.

Parameters:

max_steps (int)

add(obs, action, reward, terminal)[source]#
Parameters:
  • obs (ndarray)

  • action (ndarray)

  • reward (float)

  • terminal (bool)

clear()[source]#
get_arrays()[source]#
class world_models.controller.iris_policy.IRISActor(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Actor network for IRIS policy.

Takes reconstructed frames as input and outputs action logits. Uses CNN + LSTM architecture with burn-in mechanism.

Parameters:
  • action_size (int)

  • hidden_size (int)

  • num_layers (int)

  • frame_shape (Tuple[int, int, int])

forward(frames, hidden_state=None, burn_in_frames=None)[source]#

Forward pass through actor.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W) or (B, C, H, W)

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple for LSTM state

  • burn_in_frames (Tensor | None) – Frames to use for initializing hidden state

Returns:
  • action_logits – Action logits (B, T, action_size) or (B, action_size)

  • hidden_state – Updated (h, c) tuple

init_hidden_state(batch_size, device)[source]#

Initialize LSTM hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

get_action(frame, temperature=1.0, deterministic=False)[source]#

Get action from a single frame.

Parameters:
  • frame (Tensor) – Single frame (B, C, H, W)

  • temperature (float) – Softmax temperature (higher = more random)

  • deterministic (bool) – If True, return argmax; else sample

Returns:

action – Selected action indices (B,)

Return type:

Tensor

class world_models.controller.iris_policy.IRISCritic(hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Critic network for IRIS value estimation.

Shares CNN and LSTM with actor, but has separate value head.

Parameters:
  • hidden_size (int)

  • num_layers (int)

  • frame_shape (Tuple[int, int, int])

forward(frames, hidden_state=None)[source]#

Forward pass through critic.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W)

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple

Returns:
  • values – Value estimates (B, T)

  • hidden_state – Updated (h, c) tuple

init_hidden_state(batch_size, device)[source]#

Initialize LSTM hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

class world_models.controller.iris_policy.CNNFeatureExtractor(frame_shape=(3, 64, 64), output_size=512)[source]#

Bases: Module

CNN feature extractor shared between actor and critic.

Processes input frames into feature vectors.

Parameters:
  • frame_shape (Tuple[int, int, int])

  • output_size (int)

forward(x)[source]#

Extract features from frames.

Parameters:

x (Tensor) – Frames (B, C, H, W)

Returns:

features – Feature vectors (B, output_size)

Return type:

Tensor

class world_models.controller.iris_policy.IRISPolicy(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Combined policy module for IRIS.

Wraps actor and optionally critic for convenience.

Parameters:
  • action_size (int)

  • hidden_size (int)

  • num_layers (int)

  • frame_shape (Tuple[int, int, int])

forward(frames)[source]#

Get action logits from frames.

Parameters:

frames (Tensor)

Return type:

Tensor

act(frame, temperature=1.0, deterministic=False)[source]#

Sample action from policy.

Parameters:
  • frame (Tensor)

  • temperature (float)

  • deterministic (bool)

Return type:

Tensor

init_hidden(batch_size, device)[source]#

Initialize hidden state.

Parameters:
  • batch_size (int)

  • device (device)

class world_models.training.train_iris.IRISTrainer(game='ALE/Pong-v5', device='cuda', seed=42, config=None)[source]#

Bases: object

Training loop for IRIS on Atari 100k benchmark.

Parameters:
  • game (str)

  • device (str)

  • seed (int)

  • config (IRISConfig)

preprocess_frame(frame)[source]#

Preprocess frame: resize to 64x64 and normalize.

Parameters:

frame (ndarray)

Return type:

ndarray

collect_experience(num_steps, epsilon=0.01)[source]#

Collect experience from environment.

Parameters:
  • num_steps (int) – Number of steps to collect

  • epsilon (float) – Random action probability

Returns:

Mean episode return

Return type:

float

train_epoch(epoch)[source]#

Train for one epoch.

Parameters:

epoch (int) – Current epoch number

Returns:

Dictionary of metrics

Return type:

dict

get_epsilon(epoch)[source]#

Get exploration epsilon with decay.

Parameters:

epoch (int)

Return type:

float

evaluate(num_episodes=100, render=False)[source]#

Evaluate agent performance.

Parameters:
  • num_episodes (int) – Number of evaluation episodes

  • render (bool) – If True, also return video frames and per-step latent vectors

Returns:

If render is False (default): dict with evaluation metrics. If render is True: tuple (episode_returns_array, videos_list, latents_array).

train(total_epochs=None, eval_interval=50, save_dir='checkpoints/iris')[source]#

Full training loop.

Parameters:
  • total_epochs (int) – Total training epochs

  • eval_interval (int) – Evaluate every N epochs

  • save_dir (str) – Directory to save checkpoints

world_models.training.train_iris.main()[source]#

Run IRIS training on a single Atari game.

Benchmarks#

IRIS Atari 100k Benchmark Runner

Runs IRIS on all 26 Atari games from the Atari 100k benchmark. Computes human-normalized scores and compares to baselines.

Based on paper: “Transformers are Sample-Efficient World Models”

benchmarks.atari_100k.compute_human_normalized_score(score, game)[source]#

Compute human-normalized score for a game.

Formula: (score - random) / (human - random)

Parameters:
  • score (float)

  • game (str)

Return type:

float
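
For example, with the commonly reported Pong baselines (random ≈ -20.7, human ≈ 14.6), a raw score of 10 normalizes to (10 - (-20.7)) / (14.6 - (-20.7)) ≈ 0.87:

from benchmarks.atari_100k import compute_human_normalized_score

hns = compute_human_normalized_score(10.0, 'Pong')  # ≈ 0.87; the game key format is an assumption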

benchmarks.atari_100k.run_single_game(game, config=None, device='cuda', num_seeds=5)[source]#

Run IRIS on a single game with multiple seeds.

Parameters:
  • game (str) – Game name (e.g., “ALE/Pong-v5”)

  • config (IRISConfig) – IRIS config

  • device (str) – Device to run on

  • num_seeds (int) – Number of random seeds to run

Returns:

Dictionary with results

Return type:

Dict

benchmarks.atari_100k.run_atari_100k(games=None, config=None, device='cuda', output_file='results/iris_atari100k.json', num_seeds=5)[source]#

Run IRIS on all Atari 100k games.

Parameters:
  • games (List[str]) – List of games to run (default: all 26)

  • config (IRISConfig) – IRIS config

  • device (str) – Device to run on

  • output_file (str) – Output JSON file for results

  • num_seeds (int) – Number of seeds per game

benchmarks.atari_100k.print_results_table(results)[source]#

Print a formatted table of benchmark results.

Parameters:

results (List[Dict])

benchmarks.atari_100k.main()[source]#

Run full Atari 100k benchmark.

Diffusion#

class world_models.models.diffusion.DDPM.DDPM(timesteps, beta_start, beta_end, device)[source]#

Bases: object

Utility class implementing forward and reverse DDPM diffusion steps.

Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).

q_sample(x_start, t, noise=None)[source]#
p_sample(model, x_t, t)[source]#
sample(model, n, img_size, channels)[source]#
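
A sketch of the DDPM helpers on a toy batch (model would be any noise-prediction network compatible with p_sample; the shapes are illustrative):

import torch
from world_models.models.diffusion.DDPM import DDPM

ddpm = DDPM(timesteps=1000, beta_start=1e-4, beta_end=0.02, device='cpu')
x0 = torch.rand(8, 3, 32, 32)
t = torch.randint(0, 1000, (8,))
x_t = ddpm.q_sample(x0, t)  # forward noising of training inputs
# images = ddpm.sample(model, n=8, img_size=32, channels=3)  # reverse sampling
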
world_models.models.diffusion.DiT.sinusoidal_time_embedding(timesteps, dim)[source]#

Create sinusoidal timestep embeddings for diffusion conditioning.

Embeddings are scaled relative to configured diffusion timesteps and are consumed by the DiT conditioning MLP.

class world_models.models.diffusion.DiT.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]#

Bases: Module

Patchify an image into a sequence of learnable patch tokens.

A strided convolution performs patch extraction and projects each patch to the transformer hidden dimension with additive positional embeddings.

forward(x)[source]#
class world_models.models.diffusion.DiT.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]#

Bases: Module

Reconstruct image-like tensors from patch-token sequences.

The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.

forward(x)[source]#
class world_models.models.diffusion.DiT.TransformerBlock(d_model, n_heads, mlp_ratio, drop, t_dim)[source]#

Bases: Module

Conditioned transformer block used inside the DiT backbone.

Each block applies adaptive layer-normalized self-attention and MLP residual updates conditioned on timestep embeddings.

forward(x, t_emb)[source]#
class world_models.models.diffusion.DiT.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]#

Bases: Module

Diffusion Transformer model for image denoising and generation.

The module maps noisy images and timesteps to predicted noise residuals and also provides a classmethod training entrypoint for common datasets.

forward(x, t)[source]#
classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]#
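
The classmethod entrypoint mirrors the DiTConfig defaults; a minimal CIFAR-10 run might look like:

from world_models.models.diffusion.DiT import DiT

DiT.train(epochs=3, dataset='CIFAR10', batch_size=128, lr=2e-4,
          img_size=32, channels=3, workdir='./dit_demo', root_path='./data')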

Datasets and Transforms#

world_models.datasets.cifar10.make_cifar10(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, drop_last=True, train=True, download=False)[source]#

Create CIFAR-10 dataset objects and a distributed-capable dataloader.

Returns the dataset, sampler, and loader configured with the provided transform/collator so callers can plug the loader directly into JEPA or diffusion training loops.

world_models.datasets.imagenet1k.make_imagenet1k(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, training=True, copy_data=False, drop_last=True, subset_file=None)[source]#

Build an ImageNet-1K dataset and dataloader with distributed sampling support.

This helper optionally restricts data to a subset file and returns the (dataset, dataloader, sampler) tuple used by training scripts.

class world_models.datasets.imagenet1k.ImageNet(root, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', transform=None, train=True, job_id=None, local_rank=None, copy_data=True, index_targets=False)[source]#

Bases: ImageFolder

ImageNet dataset wrapper with optional local copy/extract workflow.

The class extends torchvision.datasets.ImageFolder and can stage data from shared storage into local scratch space for faster multi-process training on cluster environments.

class world_models.datasets.imagenet1k.ImageNetSubset(dataset, subset_file)[source]#

Bases: object

View over an ImageNet dataset filtered by an explicit image-id list.

The subset file contains target image names; only matching samples are kept while preserving transforms and label mapping from the base dataset.

filter_dataset_(subset_file)[source]#

Filter self.dataset to the subset listed in subset_file.

property classes#
world_models.datasets.imagenet1k.copy_imgnt_locally(root, suffix, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', job_id=None, local_rank=None)[source]#

Copy and extract ImageNet archives to per-job local scratch storage.

In SLURM environments this reduces network filesystem pressure by unpacking once per job and synchronizing worker processes with a signal file.

world_models.datasets.imagenet1k.make_imagefolder(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, drop_last=True, val_split=None)[source]#

Create an ImageFolder dataset loader for custom folder-structured datasets.

Supports optional train/validation split and distributed sampling, making it a drop-in replacement for ImageNet loaders in training scripts.

Parameters:

val_split (float | None)

world_models.transforms.transforms.make_transforms(crop_size=224, crop_scale=(0.3, 1.0), color_jitter=1.0, horizontal_flip=False, color_distortion=False, gaussian_blur=False, normalization=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))[source]#

Compose image augmentations and normalization for vision model training.

Supports random crops, optional flip/color distortion/blur, and returns a torchvision.transforms.Compose pipeline.
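
Combining the transform factory with the CIFAR-10 loader; the unpacking order below follows the description above and is an assumption:

from world_models.datasets.cifar10 import make_cifar10
from world_models.transforms.transforms import make_transforms

transform = make_transforms(crop_size=32, crop_scale=(0.3, 1.0),
                            horizontal_flip=True)
dataset, sampler, loader = make_cifar10(transform, batch_size=128,
                                        root_path='./data', download=True)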

class world_models.transforms.transforms.GaussianBlur(p=0.5, radius_min=0.1, radius_max=2.0)[source]#

Bases: object

Probabilistic Gaussian blur augmentation for PIL images.

Applies blur with random radius in a configurable range when sampled.

Memory and Controllers#

class world_models.memory.dreamer_memory.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]#

Bases: object

Fixed-size replay buffer for Dreamer image observations and transitions.

Stores (observation image, action, reward, terminal) tuples and supports sampling contiguous sequences used by world-model unroll training.

add(obs, ac, rew, done)[source]#
sample()[source]#
class world_models.memory.planet_memory.Episode(postprocess_fn=None)[source]#

Bases: object

Records the agent’s interaction with the environment for a single episode. At termination, it converts all the data to NumPy arrays.

property size#
append(obs, act, reward, terminal)[source]#
terminate(obs)[source]#
class world_models.memory.planet_memory.Memory(size=None)[source]#

Bases: deque

Episode-based replay memory for PlaNet/RSSM training.

Episodes are stored as variable-length trajectories and sampled as sub-sequences with optional time-major formatting for sequence models.

property size#
append(episodes)[source]#
Parameters:

episodes (list[Episode])

sample(batch_size, tracelen=1, time_first=False)[source]#
class world_models.controller.rssm_policy.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]#

Bases: object

Model-predictive controller that plans actions with the RSSM latent model.

The policy uses a Cross-Entropy Method style loop: it samples candidate action sequences, rolls them forward in latent space, scores predicted returns, and refits a Gaussian proposal to top-performing candidates.

reset()[source]#
poll(observation, explore=False)[source]#

Utilities#

world_models.utils.dreamer_utils.get_parameters(modules)[source]#

Given a list of torch modules, returns a list of their parameters.

Parameters:

modules (Iterable[Module])

class world_models.utils.dreamer_utils.FreezeParameters(modules)[source]#

Bases: object

Context manager that temporarily disables gradients for given modules.

Useful during imagination or target-network forward passes where gradients through certain components should be blocked for speed and correctness.

Parameters:

modules (Iterable[Module])
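
Typical usage during imagination rollouts, where gradients must not flow into the world model (a sketch; the module names are illustrative):

from world_models.utils.dreamer_utils import FreezeParameters

world_model_modules = [encoder, rssm, decoder]  # illustrative module list
with FreezeParameters(world_model_modules):
    states = rssm.imagine_rollout(actor, prev_state, horizon=15)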

class world_models.utils.dreamer_utils.Logger(log_dir, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', video_format='gif', video_fps=20)[source]#

Bases: object

Experiment logger for scalars and GIF rollouts using WandB.

Provides helpers to write scalar metrics, dump pickle snapshots, and save video previews during Dreamer training/evaluation.

log_scalar(scalar, name, step_)[source]#
log_scalars(scalar_dict, step)[source]#
log_videos(videos, step, max_videos_to_save=1, fps=None, video_title='video')[source]#
dump_scalars_to_pickle(metrics, step, log_title=None)[source]#
flush()[source]#
world_models.utils.dreamer_utils.compute_return(rewards, values, discounts, td_lam, last_value)[source]#

Compute TD(lambda) returns from imagined rewards, values, and discounts.

Implements backward recursion used by Dreamer actor/value objectives.

world_models.utils.jepa_utils.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2, b=2.0)[source]#

Initialize a tensor in-place from a truncated normal distribution.

Values are sampled from N(mean, std) and clipped to [a, b].

world_models.utils.jepa_utils.repeat_interleave_batch(x, B, repeat)[source]#

Repeat each batch chunk multiple times while preserving chunk ordering.

Used in JEPA masking code to align context and target token batches.

class world_models.utils.jepa_utils.WarmupCosineSchedule(optimizer, warmup_steps, start_lr, ref_lr, T_max, last_epoch=-1, final_lr=0.0)[source]#

Bases: object

Learning-rate schedule with linear warmup followed by cosine decay.

Updates optimizer parameter-group LRs on each call to step().

step()[source]#
class world_models.utils.jepa_utils.CosineWDSchedule(optimizer, ref_wd, T_max, final_wd=0.0)[source]#

Bases: object

Cosine scheduler for optimizer weight decay values.

Skips parameter groups flagged with WD_exclude to keep bias/norm decay at zero.

step()[source]#
world_models.utils.jepa_utils.gpu_timer(closure, log_timings=True)[source]#

Measure CUDA execution time for a closure and return (result, elapsed_ms).

Falls back to -1 elapsed time when CUDA timing is unavailable.

class world_models.utils.jepa_utils.CSVLogger(fname, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', *argv)[source]#

Bases: object

Lightweight CSV logger with per-column printf-style formatting and WandB support.

log(step, *argv)[source]#
class world_models.utils.jepa_utils.AverageMeter[source]#

Bases: object

Track running statistics (val, avg, min, max, sum, count) for metrics.

reset()[source]#
update(val, n=1)[source]#
world_models.utils.jepa_utils.grad_logger(named_params)[source]#

Aggregate gradient norm statistics over model parameters for logging.

Also exposes first/last qkv-layer gradient norms when available.

world_models.utils.jepa_utils.init_distributed(port=40112, rank_and_world_size=(None, None))[source]#

Initialize torch distributed process groups when environment supports it.

Returns (world_size, rank) and gracefully falls back to single-process mode.

class world_models.utils.jepa_utils.AllGather(*args, **kwargs)[source]#

Bases: Function

Autograd-aware all-gather operation across distributed workers.

Forward concatenates worker tensors; backward reduces and slices gradients.

static forward(ctx, x)[source]#
static backward(ctx, grads)[source]#
class world_models.utils.jepa_utils.AllReduceSum(*args, **kwargs)[source]#

Bases: Function

Autograd function that sums tensors across distributed workers in forward pass.

static forward(ctx, x)[source]#
static backward(ctx, grads)[source]#
class world_models.utils.jepa_utils.AllReduce(*args, **kwargs)[source]#

Bases: Function

Autograd function that all-reduces and averages tensors across workers.

Used to synchronize scalar losses for consistent distributed logging/training.

static forward(ctx, x)[source]#
static backward(ctx, grads)[source]#
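
A sketch of the loss-synchronization pattern described above (the surrounding distributed training step is assumed):

loss = AllReduce.apply(loss)   # autograd-aware average across workers
loss.backward()
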
class world_models.utils.utils.AttrDict[source]#

Bases: dict

world_models.utils.utils.load_yml_config(path)[source]#
world_models.utils.utils.to_tensor_obs(image)[source]#

Converts an input NumPy image to a channel-first 64x64 torch image tensor.

world_models.utils.utils.postprocess_img(image, depth)[source]#

Postprocesses an image observation for storage: converts a float32 NumPy array in [-0.5, 0.5] to a uint8 NumPy array in [0, 255].

world_models.utils.utils.preprocess_img(image, depth)[source]#

Preprocesses an observation in place: converts a float32 tensor in [0, 255] to [-0.5, 0.5] and adds noise to the observations.
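
A round-trip sketch of the two conversions (the depth value is illustrative):

preprocess_img(obs, depth=5)                    # in place: float32 [0, 255] -> [-0.5, 0.5], plus noise
stored = postprocess_img(obs.numpy(), depth=5)  # float32 [-0.5, 0.5] -> uint8 [0, 255]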

world_models.utils.utils.bottle(func, *tensors)[source]#

Evaluates a function that expects inputs of shape (N, D) on inputs of shape (N, T, D) by folding the time dimension into the batch dimension and restoring it afterwards.
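
Example usage (a sketch; the decoder and state tensor are hypothetical):

# decoder expects (N, D) inputs; states has shape (N, T, D)
recon = bottle(decoder, states)   # time dimension folded into the batch, then restored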

world_models.utils.utils.get_combined_params(*models)[source]#

Returns the combined parameter list of all the models given as input.

world_models.utils.utils.save_video(frames, path, name)[source]#

Saves a sequence of frames as a video.

Accepts frames in either:
  • (T, C, H, W) float in [0,1]

  • (T, H, W, C) float in [0,1]

Produces {path}/{name}.mp4 and a debug PNG {path}/{name}_debug_frame.png with per-channel statistics printed to stdout.

world_models.utils.utils.combine_videos(video_dir, output_name='combined.mp4', pattern='vid_*.mp4', fps=25, resize=True)[source]#

Combine all videos matching pattern in video_dir into a single MP4 file. Returns the output filepath (string).

Example

combine_videos("results/planet", output_name="all_training.mp4")

world_models.utils.utils.ensure_results_dir_exists(results_dir)[source]#

Simple helper that validates a results directory exists; raises FileNotFoundError if it is not present.

world_models.utils.utils.save_frames(target, pred_prior, pred_posterior, name, n_rows=5)[source]#

Save side-by-side target, prior-prediction, and posterior-prediction frames.

The function accepts tensors with optional time dimension and writes a PNG grid to {name}.png. Spatial sizes are aligned per timestep before concatenation and values are normalized to [0, 1] when needed.

world_models.utils.utils.get_mask(tensor, lengths)[source]#

Build a batch-first validity mask from sequence lengths.

tensor may be a tensor/array with shape (N, T, ...) or (N,). The returned mask marks valid timesteps with ones up to each element in lengths and preserves device/dtype conventions from the input.
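
Example usage (a minimal sketch):

# batch: (N, T, ...) padded sequences; lengths: (N,) valid lengths per element
mask = get_mask(batch, lengths)   # mask[i, t] == 1 for t < lengths[i], else 0
masked_loss = (per_step_loss * mask).sum() / mask.sum()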

world_models.utils.utils.load_memory(path, device)[source]#

Loads an experience replay buffer (backwards-compatible with older pickle formats). Converts legacy list/.data formats into the current Memory(episodes) object.

world_models.utils.utils.flatten_dict(data, sep='.', prefix='')[source]#

Flattens a nested dict into a single-level dict.

Example

{'a': 2, 'b': {'c': 20}} -> {'a': 2, 'b.c': 20}

world_models.utils.utils.normalize_frames_for_saving(frames)[source]#

Ensure frames are in shape (T, H, W, 3) with float values in [0,1]. Handles inputs in (T, C, H, W) or (T, H, W, C), repeats single-channel -> RGB, drops alpha if present, and maps [-0.5,0.5] -> [0,1] when detected.

class world_models.utils.utils.TensorBoardMetrics(path)[source]#

Bases: object

Plots and (optionally) stores metrics for an experiment.

assign_type(key, val)[source]#
update(metrics)[source]#
Parameters:

metrics (dict)

world_models.utils.utils.apply_model(model, inputs, ignore_dim=None)[source]#

Placeholder helper for generic model application across input structures.

Currently not implemented; kept as an extension hook for future utility code.

world_models.utils.utils.plot_metrics(metrics, path, prefix)[source]#

Render and save line plots for each metric series in a dictionary.

world_models.utils.utils.lineplot(xs, ys, title, path='', xaxis='episode')[source]#

Create a Plotly line plot for scalar, dict, or ensemble-series data.

Supports uncertainty-band plotting when ys is a 2D array.

class world_models.utils.utils.TorchImageEnvWrapper(env, bit_depth, observation_shape=None, act_rep=2)[source]#

Bases: object

Torch environment wrapper around a Gym env that performs all interactions with tensors and returns observations in image form.

reset()[source]#
step(u)[source]#
render()[source]#
close()[source]#
property observation_size#
property action_size#
sample_random_action()[source]#
property max_episode_steps#

Returns the environment's max episode steps (compatible with TimeLimit/spec).
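
Example usage (the env id and the exact structure of step()'s return are assumptions):

import gym
env = TorchImageEnvWrapper(gym.make('Pendulum-v1'), bit_depth=5)
obs = env.reset()                                # image observation as a torch tensor
result = env.step(env.sample_random_action())    # unpack per the wrapper's return convention
env.close()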

world_models.utils.utils.apply_masks(x, masks)[source]#

Gather token subsets from patch sequences using index masks.

Each mask selects token positions from x; selected groups are concatenated along the batch dimension.
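
Example usage (shapes follow common JEPA masking code and are assumptions):

# x: (B, N, D) patch tokens; each mask: (B, K) long tensor of token indices
kept = apply_masks(x, masks)   # shape (len(masks) * B, K, D)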

world_models.utils.utils.visualize_latent_tsne(latents, labels=None, save_path=None, perplexity=30)[source]#

Visualize latent representations using t-SNE.

Parameters:
  • latents – torch.Tensor of shape (N, D) or numpy array

  • labels – optional list or array of labels for coloring

  • save_path – path to save the plot (HTML for plotly)

  • perplexity – t-SNE perplexity parameter

world_models.utils.utils.visualize_latent_umap(latents, labels=None, save_path=None, n_neighbors=15)[source]#

Visualize latent representations using UMAP.

Parameters:
  • latents – torch.Tensor of shape (N, D) or numpy array

  • labels – optional list or array of labels for coloring

  • save_path – path to save the plot (HTML for plotly)

  • n_neighbors – UMAP n_neighbors parameter

class world_models.utils.utils.StreamingVideoWriter(path, fps=20, frame_shape=None, format='mp4')[source]#

Bases: object

Streaming video writer that saves frames to a video file in real time.

Parameters:
  • path – output video file path

  • fps – frames per second

  • frame_shape – (height, width) of frames

  • format – ‘mp4’ or ‘avi’

write_frame(frame)[source]#

Write a single frame to the video.

Parameters:

frame – numpy array of shape (H, W, C) or (H, W), uint8 or float in [0,1]

close()[source]#
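
Example usage (a minimal sketch):

writer = StreamingVideoWriter('rollout.mp4', fps=20)
for frame in frames:           # frames: (H, W, C) uint8 arrays or floats in [0, 1]
    writer.write_frame(frame)
writer.close()                 # finalize and close the video file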