API Reference

This reference is generated from source docstrings and grouped by workflow. Use the World Models Study Guide for conceptual explanations and this page for exact classes, functions, and module-level APIs.

Public package surface

These modules expose the most common imports and lazy constructors.

Primary module: torchwm. Implementation modules are documented below for API completeness.

Use torchwm for common workflows:

import torchwm
agent = torchwm.create_model("dreamer", env="walker-walk")

Friendly top-level package for TorchWM.

torchwm is the recommended public namespace for users.
class torchwm.EnvBackendSpec(name, factory_path, description='', aliases=())[source]#

Bases: NamedTuple

Metadata describing an environment backend available through make_env.

Parameters:
  • name (str)

  • factory_path (str)

  • description (str)

  • aliases (tuple[str, ...])

name: str#

Alias for field number 0

factory_path: str#

Alias for field number 1

description: str#

Alias for field number 2

aliases: tuple[str, ...]#

Alias for field number 3

class torchwm.ModelSpec(name, import_path, config_path=None, description='', aliases=())[source]#

Bases: NamedTuple

Metadata describing a model available through create_model().

Parameters:
  • name (str)

  • import_path (str)

  • config_path (str | None)

  • description (str)

  • aliases (tuple[str, ...])

name: str#

Alias for field number 0

import_path: str#

Alias for field number 1

config_path: str | None#

Alias for field number 2

description: str#

Alias for field number 3

aliases: tuple[str, ...]#

Alias for field number 4

torchwm.create_config(model, **overrides)[source]#

Create the default config object for model and apply overrides.

Examples

>>> cfg = create_config("dreamer", env="walker-walk", seed=7)
>>> cfg.env
'walker-walk'
Parameters:
  • model (str)

  • overrides (Any)

Return type:

Any

torchwm.create_model(model, config=None, **overrides)[source]#

Instantiate a model or agent from a simple string name.

config is optional for models that define a config class. Keyword overrides are applied to the config when possible, otherwise they are passed directly to the underlying constructor/factory.

Examples

>>> agent = create_model("dreamer", env="walker-walk", total_steps=1000)
>>> genie = create_model("genie-small", image_size=32)
Parameters:
  • model (str)

  • config (Any | None)

  • overrides (Any)

Return type:

Any

torchwm.get_env_backend_spec(name)[source]#

Return metadata for an environment backend name or alias.

Parameters:

name (str)

Return type:

EnvBackendSpec

torchwm.get_model_spec(name)[source]#

Return metadata for a model name or alias.

Parameters:

name (str)

Return type:

ModelSpec
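get_model_spec() accepts either a canonical name or an alias. The following is a minimal sketch of how such a lookup can work. It assumes a hypothetical in-memory registry: `_REGISTRY`, `lookup_spec`, the alias strings, the import paths, and case-insensitive matching are all illustrative assumptions, not TorchWM's actual registry. The local `ModelSpec` only mirrors the documented fields.

```python
from typing import NamedTuple, Optional, Tuple


class ModelSpec(NamedTuple):
    # Local stand-in with the same fields as the documented torchwm.ModelSpec.
    name: str
    import_path: str
    config_path: Optional[str] = None
    description: str = ""
    aliases: Tuple[str, ...] = ()


# Hypothetical registry: import paths and aliases are placeholders, not TorchWM's.
_REGISTRY = {
    "dreamer": ModelSpec("dreamer", "example.dreamer:DreamerAgent",
                         aliases=("dreamer-v1",)),
    "genie-small": ModelSpec("genie-small", "example.genie:create_genie_small",
                             aliases=("genie_small",)),
}


def lookup_spec(name: str) -> ModelSpec:
    """Resolve a canonical name or an alias (case-insensitive) to its spec."""
    key = name.strip().lower()
    if key in _REGISTRY:
        return _REGISTRY[key]
    for spec in _REGISTRY.values():
        if key in (alias.lower() for alias in spec.aliases):
            return spec
    raise KeyError(f"unknown model {name!r}; known: {sorted(_REGISTRY)}")


print(lookup_spec("Dreamer-V1").name)  # dreamer
```

Keeping aliases on the spec itself (rather than in a separate alias table) means list_models() can return only canonical names while lookups still accept every alias.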

torchwm.list_env_backends()[source]#

Return canonical backend names accepted by make_env().

Return type:

list[str]

torchwm.list_envs(model=None)[source]#

List known environment ids, optionally filtered by model family.

Parameters:

model (str | None)

Return type:

list[str] | dict[str, list[str]]

torchwm.list_models()[source]#

Return canonical model names accepted by create_model().

Return type:

list[str]

torchwm.make_env(env_id, backend='auto', **kwargs)[source]#

Create an environment with a consistent TorchWM entry point.

Parameters:
  • env_id (str) – Environment id, XML path, Unity executable path, or backend-specific id.

  • backend (str) – One of list_env_backends(); "auto" tries TorchWM’s compatibility helper.

  • **kwargs (Any) – Backend-specific options.

Return type:

Any

class torchwm.IRISAgent(config, action_size, device)[source]#

Bases: Module

Complete IRIS Agent with world model and policy.

Combines:
  • Discrete autoencoder (encoder + decoder)

  • Transformer world model

  • Actor-Critic for policy and value learning

Parameters:
  • config (IRISConfig)

  • action_size (int)

  • device (device)

forward_actor_critic(frames, hidden=None)[source]#

Forward pass through actor-critic.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W)

  • hidden (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state

Returns:
  • action_logits – shape (B, T, action_size)

  • values – shape (B, T)

  • hidden_state – LSTM state (h, c)

act(frame, epsilon=0.0, temperature=1.0)[source]#

Sample action from policy.

Parameters:
  • frame (Tensor) – Single frame (B, C, H, W)

  • epsilon (float) – Random action probability

  • temperature (float) – Action distribution temperature

Returns:

actions – Selected actions, shape (B,)
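act() exposes both an exploration probability (epsilon) and a temperature. One common way to combine them is epsilon-greedy exploration on top of a temperature-scaled softmax. The following plain-Python sketch shows that combination; `sample_action` is a hypothetical helper, and the assumption that IRISAgent.act applies the two knobs in exactly this order is not confirmed by the docstring.

```python
import math
import random
from typing import Optional, Sequence


def sample_action(logits: Sequence[float], epsilon: float = 0.0,
                  temperature: float = 1.0,
                  rng: Optional[random.Random] = None) -> int:
    """Epsilon-greedy sampling from a temperature-scaled softmax (hypothetical helper)."""
    rng = rng or random.Random()
    num_actions = len(logits)
    # With probability epsilon, ignore the policy and pick an action uniformly.
    if rng.random() < epsilon:
        return rng.randrange(num_actions)
    # Divide by the temperature: values < 1 sharpen the distribution, > 1 flatten it.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max before exp() for numerical stability
    weights = [math.exp(s - peak) for s in scaled]
    return rng.choices(range(num_actions), weights=weights, k=1)[0]


rng = random.Random(0)
print(sample_action([0.1, 2.0, 0.3], epsilon=0.1, temperature=0.5, rng=rng))
```

As the temperature approaches 0, this approaches greedy argmax selection. With epsilon=1.0, the policy is ignored entirely.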

imagine_rollout(initial_frame, horizon=20)[source]#

Generate imagined trajectories using world model.

Parameters:
  • initial_frame (Tensor) – Starting frame (B, C, H, W)

  • horizon (int) – Number of steps to imagine

Returns:

trajectory – Dictionary with imagined rollout data

update_autoencoder(frames)[source]#

Update discrete autoencoder.

Parameters:

frames (Tensor) – Training frames (B, C, H, W)

Returns:

losses – Dictionary of loss values

update_transformer(frames, actions, rewards, terminals)[source]#

Update transformer world model.

Parameters:
  • frames (Tensor) – Frame sequence

  • actions (Tensor) – Actions taken

  • rewards (Tensor) – Rewards received

  • terminals (Tensor) – Terminal flags

Returns:

losses – Dictionary of loss values

update_actor_critic(imagined_trajectory)[source]#

Update actor-critic in imagination.

Parameters:

imagined_trajectory (dict) – Dictionary from imagine_rollout

Returns:

losses – Dictionary of loss values

save(path)[source]#

Save agent state.

Parameters:

path (str)

load(path)[source]#

Load agent state.

Parameters:

path (str)

torchwm.compute_lambda_return(rewards, values, discounts, lambda_coef=0.95)[source]#

Compute λ-return target for value function training.

Parameters:
  • rewards (Tensor) – Rewards (B, T)

  • values (Tensor) – Value estimates (B, T+1)

  • discounts (Tensor) – Discount factors (B, T)

  • lambda_coef (float) – Lambda parameter for bootstrapping

Returns:

lambda_returns – λ-return targets, shape (B, T)
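The docstring does not spell out the recursion. A common formulation, used by the original Dreamer agent, is R^λ_t = r_t + γ_t [(1 − λ) v_{t+1} + λ R^λ_{t+1}], bootstrapped with R^λ_T = v_T. This is why values carries one more time step than rewards. The following plain-Python sketch handles a single trajectory; a batched tensor version would run the same backward loop along the time axis. `lambda_return` is a hypothetical name, and the assumption that compute_lambda_return uses exactly this indexing is not confirmed by the docstring.

```python
from typing import List, Sequence


def lambda_return(rewards: Sequence[float], values: Sequence[float],
                  discounts: Sequence[float], lambda_coef: float = 0.95) -> List[float]:
    """λ-returns for one trajectory; values has one more entry than rewards."""
    horizon = len(rewards)
    assert len(values) == horizon + 1 and len(discounts) == horizon
    returns = [0.0] * horizon
    next_return = values[horizon]  # bootstrap from the final value estimate v_T
    for t in reversed(range(horizon)):
        # Blend the one-step bootstrap v_{t+1} with the longer λ-return R_{t+1}.
        blended = (1.0 - lambda_coef) * values[t + 1] + lambda_coef * next_return
        next_return = rewards[t] + discounts[t] * blended
        returns[t] = next_return
    return returns


print(lambda_return([1.0, 1.0], [0.0, 0.0, 10.0], [0.5, 0.5], lambda_coef=1.0))  # [4.0, 6.0]
```

With lambda_coef=0 this reduces to one-step TD targets, r_t + γ_t v_{t+1}. With lambda_coef=1 it becomes the discounted reward sum plus the bootstrapped final value.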

class torchwm.ModularRSSM(encoder, decoder, backbone, reward_decoder=None)[source]#

Bases: Module

Modular RSSM with swappable encoder, decoder, and backbone.

This class allows researchers to experiment with different components:
  • Encoders: Conv, MLP, ViT

  • Decoders: Conv, MLP

  • Backbones: GRU, LSTM, Transformer

Example

>>> encoder = ConvEncoder((3, 64, 64), embed_size=1024)
>>> decoder = ConvDecoder(32, 200, (3, 64, 64))
>>> backbone = GRUBackbone(action_size=6, stoch_size=32, deter_size=200, hidden_size=200, embed_size=1024)
>>> rssm = ModularRSSM(encoder, decoder, backbone)
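Parameters:
  • encoder – Observation encoder (e.g., Conv, MLP, ViT)

  • decoder – Observation decoder (e.g., Conv, MLP)

  • backbone – Latent dynamics backbone (e.g., GRU, LSTM, Transformer)

  • reward_decoder – Optional reward decoder (default None)
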
property stoch_size: int#
property deter_size: int#
property embed_size: int#
init_state(batch_size, device)[source]#
Parameters:
  • batch_size (int)

  • device (device)

Return type:

Dict[str, Tensor]

get_dist(mean, std)[source]#
Parameters:
  • mean (Tensor)

  • std (Tensor)

Return type:

Distribution

observe_step(prev_state, prev_action, obs, nonterm=1.0)[source]#
Parameters:
  • prev_state (Dict[str, Tensor])

  • prev_action (Tensor)

  • obs (Tensor)

  • nonterm (Any)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

imagine_step(prev_state, prev_action, nonterm=1.0)[source]#
Parameters:
  • prev_state (Dict[str, Tensor])

  • prev_action (Tensor)

  • nonterm (Any)

Return type:

Dict[str, Tensor]

observe_rollout(obs, actions, nonterms, prev_state, horizon)[source]#
Parameters:
  • obs (Tensor)

  • actions (Tensor)

  • nonterms (Tensor)

  • prev_state (Dict[str, Tensor])

  • horizon (int)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

imagine_rollout(actor, prev_state, horizon)[source]#
Parameters:
  • actor (Module)

  • prev_state (Dict[str, Tensor])

  • horizon (int)

Return type:

Dict[str, Tensor]

decode_observation(features)[source]#
Parameters:

features (Tensor)

decode_reward(features)[source]#
Parameters:

features (Tensor)

detach_state(state)[source]#
Parameters:

state (Dict[str, Tensor])

Return type:

Dict[str, Tensor]

seq_to_batch(state)[source]#
Parameters:

state (Dict[str, Tensor])

Return type:

Dict[str, Tensor]

torchwm.create_modular_rssm(encoder_type='conv', decoder_type='conv', backbone_type='gru', obs_shape=(3, 64, 64), action_size=6, stoch_size=32, deter_size=200, embed_size=1024, hidden_size=200, activation='elu', **kwargs)[source]#

Factory function to create a modular RSSM with specified components.

Parameters:
  • encoder_type (str) – Type of encoder (“conv”, “mlp”, “vit”)

  • decoder_type (str) – Type of decoder (“conv”, “mlp”)

  • backbone_type (str) – Type of backbone (“gru”, “lstm”, “transformer”)

  • obs_shape (Tuple[int, int, int] | Tuple[int]) – Shape of observations (C, H, W) for images or (D,) for state

  • action_size (int) – Action space dimension

  • stoch_size (int) – Stochastic latent dimension

  • deter_size (int) – Deterministic hidden dimension

  • embed_size (int) – Encoder embedding dimension

  • hidden_size (int) – Hidden layer dimension

  • activation (str) – Activation function name

Returns:

Configured ModularRSSM instance

Return type:

ModularRSSM

class torchwm.Genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=512, tokenizer_decoder_dim=1024, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=1024, action_decoder_dim=1024, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, encoder_depth=12, decoder_depth=20, latent_action_depth=20, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#

Bases: Module

Genie: Generative Interactive Environment.

A generative model trained from video-only data that can be used as an interactive environment. It contains three key components:
  1. Video Tokenizer: Converts raw video frames into discrete tokens

  2. Latent Action Model (LAM): Infers latent actions between frames

  3. Dynamics Model: Predicts future frames given past frames and latent actions

Based on “Genie: Generative Interactive Environments” paper (arXiv:2402.15391).

Training follows two phases, as in the paper:
  1. Train the video tokenizer first (on raw video frames)

  2. Co-train the LAM (from pixels) and the dynamics model (on video tokens)

The LAM uses VQ-VAE training with:
  • Encoder: Takes x_{1:t} and x_{t+1} → outputs latent actions

  • Decoder: Takes x_{1:t-1} (masked) plus actions → reconstructs x_t

  • An auxiliary variance loss to prevent action collapse

At inference, latent actions are passed to the dynamics model with a stop-gradient applied.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_decoder_dim (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • latent_action_depth (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

forward(video, mask_prob=0.5, training_phase='all')[source]#

Full forward pass through all components.

Parameters:
  • video (Tensor) – (B, C, T, H, W) input video

  • mask_prob (float) – Probability for random masking in dynamics (0.5-1.0)

  • training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”

Returns:

Dictionary containing losses and predictions

Return type:

Dict[str, Tensor]
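mask_prob controls random masking of the dynamics-model inputs. The Genie paper draws the Bernoulli masking rate uniformly from [0.5, 1.0] during training. The following plain-Python sketch applies that masking to a flat list of token ids. `MASK_TOKEN` and `random_mask` are hypothetical, and how Genie.forward maps mask_prob onto a per-batch rate is not documented here.

```python
import random
from typing import List, Optional

MASK_TOKEN = -1  # hypothetical placeholder id for masked positions


def random_mask(tokens: List[int], rate: Optional[float] = None,
                rng: Optional[random.Random] = None) -> List[int]:
    """Replace each token with MASK_TOKEN independently with probability `rate`."""
    rng = rng or random.Random()
    if rate is None:
        # As described in the Genie paper: masking rate drawn uniformly from [0.5, 1.0].
        rate = rng.uniform(0.5, 1.0)
    return [MASK_TOKEN if rng.random() < rate else tok for tok in tokens]


print(random_mask(list(range(8)), rng=random.Random(0)))
```

Training on high masking rates matches inference, where MaskGIT decoding starts from a fully masked frame.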

training_step(video, mask_prob=0.5, training_phase='all')[source]#

Single training step computing all losses.

Parameters:
  • video (Tensor) – (B, C, T, H, W) input video

  • mask_prob (float) – Probability for random masking in dynamics

  • training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”

Returns:

Dictionary containing all losses for backpropagation

Return type:

Dict[str, Tensor]

encode_video(video)[source]#

Encode video to discrete tokens.

Parameters:

video (Tensor) – (B, C, T, H, W)

Returns:

video_tokens – Token indices, shape (B, T, H*W)

infer_actions(frames)[source]#

Infer latent actions from a sequence of frames.

Parameters:

frames (Tensor) – (B, C, T, H, W) video frames

Returns:

latent_actions – Inferred latent action indices, shape (B, T-1)

generate(prompt_frame, num_frames=16, actions=None, use_maskgit=True)[source]#

Generate video frames given a prompt frame and actions.

Parameters:
  • prompt_frame (Tensor) – (B, C, H, W) initial frame

  • num_frames (int) – Total number of frames to generate

  • actions (Tensor | None) – (B, num_frames-1) latent action indices, or None for random

  • use_maskgit (bool) – Whether to use MaskGIT sampling

Returns:

generated_video – Generated frames, shape (B, C, num_frames, H, W)

play(current_frame, action, current_frames=None)[source]#

Play step - generate next frame given current frame and action.

Parameters:
  • current_frame (Tensor) – (B, C, H, W) current frame

  • action (Tensor) – (B,) latent action indices

  • current_frames (Tensor | None) – (B, C, T, H, W) history frames, or None for first frame

Returns:

next_frame – Predicted next frame, shape (B, C, H, W)

get_num_parameters()[source]#

Return total number of parameters.

Return type:

int

class torchwm.LatentActionModel(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#

Bases: Module

Latent Action Model (LAM) for unsupervised action learning.

Learns discrete latent actions from unlabeled video frames using a VQ-VAE based objective. The model infers latent actions between frames that encode the most meaningful changes for future frame prediction.

Based on Genie paper - learns actions without action labels from Internet videos.

Components:
  • Encoder: Takes all previous frames x_{1:t} and the next frame x_{t+1} → outputs latent actions

  • Decoder: Takes previous frames x_{1:t-1} and latent actions a_{1:t-1} → predicts the next frame x_t

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

encode(x_prev, x_next)[source]#

Encode frames to latent actions.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W)

  • x_next (Tensor) – Next frame (B, C, H, W)

Returns:
  • latent_actions – Discrete latent action indices, shape (B, T)

  • z_q – Quantized embeddings, shape (B, T, embedding_dim)
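encode() returns both discrete indices and quantized embeddings z_q. In a VQ-VAE, these two outputs are linked by a nearest-neighbour lookup: each encoder output is replaced by the closest codebook row, and that row's index becomes the discrete latent action. The following plain-Python sketch uses a toy 2-D, 4-entry codebook (the documented defaults are vocab_size=8 and embedding_dim=32). `quantize` is a hypothetical helper, and the straight-through gradient and commitment loss (commitment_weight) are omitted.

```python
from typing import List, Sequence, Tuple


def quantize(vector: Sequence[float],
             codebook: Sequence[Sequence[float]]) -> Tuple[int, List[float]]:
    """Return (index, code) of the codebook entry closest in squared L2 distance."""
    best_index, best_dist = 0, float("inf")
    for index, code in enumerate(codebook):
        dist = sum((v - c) ** 2 for v, c in zip(vector, code))
        if dist < best_dist:
            best_index, best_dist = index, dist
    return best_index, list(codebook[best_index])


codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(quantize([0.9, 0.2], codebook))  # (1, [1.0, 0.0])
```

A small vocabulary (8 by default) constrains the latent actions to a handful of controllable choices, which is what lets a user "play" the model.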

decode(x_prev, z_q)[source]#

Decode latent actions to predict next frame.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W); all frames except the first are masked

  • z_q (Tensor) – Quantized action embeddings (B, T-1, embedding_dim)

Returns:

predicted_next_frame – shape (B, C, H, W)

forward(x_prev, x_next)[source]#

Full forward pass: encode to get actions, decode to reconstruct.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W)

  • x_next (Tensor) – Next frame (B, C, H, W)

Returns:

Dictionary with losses and outputs

Return type:

Dict[str, Tensor]

class torchwm.DynamicsModel(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: Module

Dynamics Model for action-controllable video generation.

A decoder-only transformer that predicts future frame tokens given past frame tokens and latent actions. Uses MaskGIT for training and sampling.

Based on Genie paper - uses cross-entropy loss with random masking during training, and MaskGIT iterative refinement at inference.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

forward(video_tokens, actions, mask_prob=0.0)[source]#

Forward pass for training.

Parameters:
  • video_tokens (Tensor) – (B, T, H*W) - token indices for frames 1 to T

  • actions (Tensor) – (B, T) - latent action indices for frames 1 to T

  • mask_prob (float) – Probability of masking input tokens (Bernoulli 0.5-1.0)

Returns:

logits – Token logits, shape (B, T, H*W, vocab_size)

sample(prompt_tokens, prompt_actions, num_frames, sampler=None)[source]#

Sample future frames using MaskGIT.

Parameters:
  • prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens

  • prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames

  • num_frames (int) – Total number of frames to generate

  • sampler (MaskGITSampler | None) – MaskGIT sampler instance

Returns:

generated_tokens – shape (B, num_frames, N)
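sample() performs MaskGIT decoding. All tokens of a new frame start masked; each iteration commits the most confident predictions and re-masks the rest. The schedule below is the cosine schedule from the MaskGIT paper. The assumption that MaskGITSampler uses this schedule and step count is not confirmed by the docstring, and `maskgit_schedule` and `positions_to_remask` are hypothetical names.

```python
import math
from typing import List


def maskgit_schedule(num_tokens: int, steps: int) -> List[int]:
    """Tokens still masked after each iteration under a cosine schedule."""
    remaining = []
    for t in range(steps):
        # Mask ratio gamma(r) = cos(pi/2 * r), evaluated at r = (t + 1) / steps.
        ratio = math.cos(math.pi / 2 * (t + 1) / steps)
        # floor() so the last step (ratio ~ 6e-17) leaves no tokens masked.
        remaining.append(math.floor(ratio * num_tokens))
    return remaining


def positions_to_remask(confidences: List[float], num_masked: int) -> List[int]:
    """Indices of the least-confident predictions, which are masked again."""
    order = sorted(range(len(confidences)), key=lambda i: confidences[i])
    return sorted(order[:num_masked])


print(maskgit_schedule(256, 8))  # [251, 236, 212, 181, 142, 97, 49, 0]
```

The cosine shape commits few tokens early, when predictions are least reliable, and many tokens late. This is why a handful of iterations can replace one forward pass per token in autoregressive_sample().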

autoregressive_sample(prompt_tokens, prompt_actions, num_frames, temperature=1.0)[source]#

Simple autoregressive sampling (token by token).

Parameters:
  • prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens

  • prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames

  • num_frames (int) – Total number of frames to generate

  • temperature (float) – Sampling temperature

Returns:

generated_tokens – shape (B, num_frames, N)

torchwm.create_genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, action_vocab_size=8, action_embedding_dim=32, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#

Factory function to create a Genie model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

torchwm.create_genie_small(num_frames=16, image_size=64, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#

Create a smaller Genie model for development/testing.

Parameters:
  • num_frames (int)

  • image_size (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

torchwm.create_genie_large(num_frames=16, image_size=64, use_bfloat16=True, action_pooling='mean', window_attention_heads=1)[source]#

Create the full-size Genie model (approximately 11B parameters).

Parameters:
  • num_frames (int)

  • image_size (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

torchwm.create_latent_action_model(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, action_pooling='mean', window_attention_heads=1)[source]#

Factory function to create a Latent Action Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

LatentActionModel

torchwm.create_dynamics_model(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4)[source]#

Factory function to create a Dynamics Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

Return type:

DynamicsModel

class torchwm.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)[source]#

Bases: Module

Recurrent State-Space Model used by Dreamer for latent dynamics learning.

The RSSM is the core world model component that learns compact representations of environment dynamics. It maintains a hybrid state consisting of:

  1. Deterministic State (h): A recurrent hidden state updated by a GRU, capturing sequential/temporal information and deterministic transitions.

  2. Stochastic State (s): A latent variable representing stochastic, multi-modal uncertainty in the environment (e.g., ambiguous observations).

The model operates in two modes:

  • Observe Mode: Updates states using actual observations from the environment. Uses the representation model: p(s_t | h_t, obs_t)

  • Imagine Mode: Predicts future states without observations. Uses the transition/prior model: p(s_t | h_t)

Architecture:

  • Input: Previous state (h_{t-1}, s_{t-1}) and action a_{t-1}

  • Process: A GRU updates the deterministic state; an MLP computes the stochastic prior/posterior

  • Output: Updated state (h_t, s_t) and distributions

State Representation:
  • deter (h): GRU hidden state, captures sequential context

  • stoch (s): Stochastic latent, multi-modal uncertainty

  • mean/std: Parameters of the stochastic distribution

Usage with DreamerAgent:

rssm = RSSM(
    action_size=action_dim,
    stoch_size=30,        # Stochastic state dimension
    deter_size=200,       # Deterministic (GRU) state dimension
    hidden_size=200,      # MLP hidden layer size
    obs_embed_size=256,   # Observation embedding from encoder
    activation='elu',
)

# Observe with actual observation: returns (posterior, prior)
posterior, prior = rssm.observe_step(prev_state, prev_action, obs_embed)

# Imagine future without observation
prior = rssm.imagine_step(current_state, action)

Training:

The RSSM is trained by maximizing the ELBO (Evidence Lower Bound):
  • KL divergence between prior and posterior encourages the prior to capture environment dynamics

  • Reconstruction loss from the decoder ensures the state captures observation information
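The KL term in the training objective compares the posterior q(s_t | h_t, obs_t) with the prior p(s_t | h_t). Because both are diagonal Gaussians (see get_dist), the KL has a closed form. The following plain-Python sketch sums it over latent dimensions; `kl_diag_gaussians` is a hypothetical name. Practical Dreamer variants add refinements such as free nats or KL balancing, and whether TorchWM applies them is not documented here.

```python
import math
from typing import Sequence


def kl_diag_gaussians(mean_q: Sequence[float], std_q: Sequence[float],
                      mean_p: Sequence[float], std_p: Sequence[float]) -> float:
    """KL(q || p) for diagonal Gaussians, summed over latent dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(mean_q, std_q, mean_p, std_p):
        # Per-dimension closed form: log(sp/sq) + (sq^2 + (mq - mp)^2) / (2 sp^2) - 1/2
        total += (math.log(sp / sq)
                  + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2)
                  - 0.5)
    return total


print(kl_diag_gaussians([1.0], [1.0], [0.0], [1.0]))  # 0.5
```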

Reference:

Hafner et al., 2020. Dream to Control: Learning Behaviors by Latent Imagination. https://arxiv.org/abs/1912.01603

init_state(batch_size, device)[source]#

Initialize RSSM state with zeros.

Parameters:
  • batch_size – Number of parallel sequences

  • device – torch device for tensors

Returns:

Dictionary containing zero-initialized state components:
  • mean, std: Stochastic distribution parameters

  • stoch: Stochastic state sample

  • deter: Deterministic GRU hidden state

get_dist(mean, std)[source]#

Create an Independent Normal distribution from mean and std.

Parameters:
  • mean – Location parameter

  • std – Scale parameter

Returns:

Independent Normal distribution with given parameters
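Wrapping the Normal in Independent makes log_prob sum over the latent dimensions, so it returns one value per sample instead of one per dimension. In PyTorch terms, this is torch.distributions.Independent(Normal(mean, std), 1).log_prob(x) for a (B, D) input; the choice of 1 reinterpreted dimension is an assumption for a flat latent. The following plain-Python sketch spells out that sum for a single sample; the helper name is hypothetical.

```python
import math
from typing import Sequence


def independent_normal_log_prob(x: Sequence[float], mean: Sequence[float],
                                std: Sequence[float]) -> float:
    """Log-density of x under a diagonal Gaussian: per-dimension terms are summed."""
    return sum(
        -0.5 * ((xi - mi) / si) ** 2 - math.log(si) - 0.5 * math.log(2 * math.pi)
        for xi, mi, si in zip(x, mean, std)
    )


print(independent_normal_log_prob([0.0, 0.0], [0.0, 0.0], [1.0, 1.0]))  # about -1.8379
```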

observe_step(prev_state, prev_action, obs_embed, nonterm=1.0)[source]#

Update state using actual observation (observe mode).

In observe mode, the RSSM first computes a transition prior from the previous state and action, then refines the stochastic state using the actual observation embedding to form the posterior.

Parameters:
  • prev_state – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})

  • prev_action – Previous action a_{t-1}, shape (B, action_size)

  • obs_embed – Observation embedding from encoder, shape (B, obs_embed_size)

  • nonterm – Termination mask (1.0 = continue, 0.0 = terminal)

Returns:

A tuple (posterior, prior) of state dictionaries. The posterior incorporates observation information; the prior is the transition prediction before observation. Both share the same deterministic state because the GRU is only advanced once per timestep.

imagine_step(prev_state, prev_action, nonterm=1.0)[source]#

Predict next state without observation (imagine mode).

In imagine mode, the RSSM predicts future states using only the prior distribution. This is used for planning and policy learning where actual observations are not available.

Parameters:
  • prev_state – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})

  • prev_action – Previous action a_{t-1}, shape (B, action_size)

  • nonterm – Termination mask (1.0 = continue, 0.0 = terminal)

Returns:

Dictionary with the predicted state, containing:
  • deter: Predicted deterministic state

  • mean, std, stoch: Prior stochastic state distribution

get_prior(prev_state, prev_action, nonterm=1.0)[source]#

Compute prior distribution over stochastic state.

The prior represents the model’s belief about the stochastic state before observing the actual outcome.

Parameters:
  • prev_state – Previous state dictionary

  • prev_action – Previous action

  • nonterm – Termination mask

Returns:

Dictionary with prior state (no observation information)

get_posterior(prev_state, prev_action, obs_embed, nonterm=1.0)[source]#

Compute posterior distribution over stochastic state.

The posterior incorporates observation information to produce a more accurate state estimate.

Parameters:
  • prev_state – Previous state dictionary

  • prev_action – Previous action

  • obs_embed – Observation embedding

  • nonterm – Termination mask

Returns:

Dictionary with posterior state (observation-informed). Note that the previous-state shape (B, ...) is preserved; the batch dimension is not flattened.

detach_state(state)[source]#

Detach state tensors from computation graph.

Used during DreamerV2 training to prevent gradient flow through the observation/update pathway.

Parameters:

state – State dictionary with tensor values

Returns:

Detached state dictionary

seq_to_batch(state_dict)[source]#

Convert sequence state to batch format.

Parameters:

state_dict – Dictionary with sequence-dimension tensors (T, B, …)

Returns:

Dictionary with batch-dimension tensors (B*T, …)
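seq_to_batch() merges the leading time and batch axes. For a contiguous PyTorch tensor, this corresponds to value.reshape(-1, *value.shape[2:]), which flattens in time-major order: every batch entry of step 0, then every batch entry of step 1, and so on. The following plain-Python sketch uses nested lists and assumes that ordering; it is an illustration, not the library code.

```python
from typing import Dict, List


def seq_to_batch(state: Dict[str, List[list]]) -> Dict[str, list]:
    """Merge the leading (T, B) axes of every entry into one axis of length T*B."""
    # Time-major: all batch entries of step 0 come first, then step 1, and so on.
    return {key: [item for step in seq for item in step] for key, seq in state.items()}


print(seq_to_batch({"deter": [[[1], [2]], [[3], [4]]]}))  # {'deter': [[1], [2], [3], [4]]}
```

Flattening lets decoders that expect (N, features) inputs process a whole rollout in one call. The result can be reshaped back to (T, B, …) afterwards.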

observe_rollout(obs_embed, actions, nonterms, init_state, seq_len)[source]#

Process a sequence of observations (observe mode rollout).

At each timestep we run observe_step once to obtain the transition prior (the prediction given the previous state and action) and the observation-informed posterior. The posterior is then used as the previous state for the next step, matching the standard Dreamer inference pattern.

Parameters:
  • obs_embed – Observation embeddings, shape (T+1, B, obs_embed_size)

  • actions – Actions, shape (T, B, action_size)

  • nonterms – Non-termination flags, shape (T, B, 1)

  • init_state – Initial state dictionary

  • seq_len – Sequence length T

Returns:
  • prior – Dictionary with prior states stacked along the time axis

  • posterior – Dictionary with posterior states stacked along the time axis

imagine_rollout(policy, init_state, horizon)[source]#

Generate imagined trajectory using policy (imagine mode rollout).

Parameters:
  • policy – Actor network that outputs actions from state features

  • init_state – Initial state dictionary

  • horizon – Number of steps to imagine

Returns:

Dictionary with imagined states for each step

forward(x, u)[source]#

Forward pass for training (computes sequence of states).

Parameters:
  • x – Observations, shape (B, T+1, C, H, W)

  • u – Actions, shape (B, T, action_size)

Returns:
  • states – List of state dictionaries for each timestep

  • priors – List of prior distributions (tuples of mean, std)

  • posteriors – List of posterior distributions (tuples of mean, std)

class torchwm.RecurrentStateSpaceModel(action_size, state_size=200, latent_size=30, hidden_size=200, embed_size=1024, activation_function='relu')[source]#

Bases: Module

A Recurrent State Space Model (RSSM) for modeling latent dynamics in sequential data.

get_init_state(enc, h_t=None, s_t=None, a_t=None, mean=False)[source]#

Returns the initial posterior given the observation.

deterministic_state_fwd(h_t, s_t, a_t)[source]#
Deterministic transition update that accepts:
  • a_t shaped [B, action_size]

  • a_t shaped [action_size] (unbatched) -> expanded to [B, action_size]

  • a_t shaped [B] or scalar -> reshaped appropriately

Ensures a_t is 2D and matches batch dimension of h_t before concatenation.

state_prior(h_t, sample=False)[source]#

Returns the prior distribution over the latent state given the deterministic state.

state_posterior(h_t, e_t, sample=False)[source]#

Returns the posterior distribution over the latent state given the deterministic state h_t and the observation embedding e_t.

pred_reward(h_t, s_t)[source]#
rollout_prior(act, h_t, s_t)[source]#
forward(x, u)[source]#

Forward through the RSSM for a batch of sequences.

Parameters:
  • x – Tensor [B, T+1, C, H, W] (observations including the initial frame)

  • u – Tensor [B, T, action_size] (actions for T steps)

Returns:
  • states – list[T] of tensors [B, state_size]

  • priors – list[T] of tuples (mean, std), each [B, latent_size]

  • posteriors – list[T] of tuples (mean, std), each [B, latent_size]

class torchwm.ConvEncoder(input_shape, embed_size, activation, depth=32)[source]#

Bases: Module

Convolutional observation encoder used by Dreamer world models.

This encoder transforms raw image observations (typically RGB frames from environments like Atari or DeepMind Control) into compact latent embeddings that can be processed by the RSSM (Recurrent State-Space Model).

Architecture:

  • Input: (B, C, H, W) images with pixel values in [-0.5, 0.5]

  • Process: 4 convolutional layers with stride 2, each roughly halving the spatial dimensions

  • Output: (B, embed_size) compact representation

The encoder uses a depth doubling pattern: 32 -> 64 -> 128 -> 256 channels. After convolutions, a fully connected layer projects from 1024 features to the desired embedding size.
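The 1024 flattened features follow from the convolution arithmetic. This sketch assumes kernel size 4, stride 2, and no padding, which is the configuration of the original Dreamer encoder but is not stated in this docstring. Under that assumption, a 64x64 input shrinks 64 → 31 → 14 → 6 → 2, and 256 channels × 2 × 2 = 1024. The helper names are hypothetical.

```python
def conv_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 0) -> int:
    """Spatial output size of a 2-D convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_feature_count(height: int = 64, depth: int = 32, layers: int = 4) -> int:
    """Flattened feature count after the stride-2 conv stack (square input assumed)."""
    size = height
    for _ in range(layers):
        size = conv_out(size)
    channels = depth * 2 ** (layers - 1)  # 32 -> 64 -> 128 -> 256
    return channels * size * size


print(encoder_feature_count())  # 1024
```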

Usage with Dreamer:

encoder = ConvEncoder(
    input_shape=(3, 64, 64),  # RGB 64x64 images
    embed_size=256,           # RSSM observation embedding size
    activation='relu',        # Activation function
)
obs_embedding = encoder(observation)  # (B, 256)

Parameters:
  • input_shape – Tuple (C, H, W) for input images, typically (3, 64, 64)

  • embed_size – Output embedding dimension, typically 256 or 1024

  • activation – Activation function name (‘relu’, ‘elu’, ‘tanh’, etc.)

  • depth – Base channel depth for first layer (default 32)

forward(inputs)[source]#
class torchwm.CNNEncoder(embedding_size, activation_function='relu')[source]#

Bases: Module

A Convolutional Neural Network (CNN) encoder for processing image inputs.

forward(observation)[source]#
class torchwm.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)[source]#

Bases: Module

Convolutional decoder for reconstructing observations from latent states.

Part of Dreamer’s world model, this decoder reconstructs image observations from the combined stochastic (s) and deterministic (h) RSSM states.

Architecture:

  • Input: Concatenated [stoch_state, deter_state], shape (B, stoch+deter)

  • Process: Dense projection + 4 transposed convolutions (upsampling 2x each)

  • Output: Independent Normal distribution over observation pixels

The decoder mirrors the ConvEncoder’s structure but in reverse (transposed convs instead of regular convs). This creates a symmetric autoencoder where the encoder and decoder can be trained jointly to learn compressed representations.
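Transposed-convolution arithmetic shows how a small feature map grows back to 64x64. This sketch assumes the original Dreamer decoder layout, which this docstring does not confirm: the dense output is reshaped to a 1x1 map, followed by four stride-2 transposed convolutions with kernel sizes 5, 5, 6, 6 and no padding. Under that assumption, the sizes are 1 → 5 → 13 → 30 → 64. The formula matches PyTorch's ConvTranspose2d output size with dilation 1, and the helper names are hypothetical.

```python
from typing import List, Sequence


def deconv_out(size: int, kernel: int, stride: int = 2, padding: int = 0,
               output_padding: int = 0) -> int:
    """Spatial output size of a transposed convolution along one axis (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def decoder_sizes(kernels: Sequence[int] = (5, 5, 6, 6), start: int = 1) -> List[int]:
    """Spatial size after each transposed convolution, starting from a start x start map."""
    sizes = [start]
    for kernel in kernels:
        sizes.append(deconv_out(sizes[-1], kernel))
    return sizes


print(decoder_sizes())  # [1, 5, 13, 30, 64]
```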

Output Distribution:

Returns torch.distributions.Independent(Normal(mean, std), len(shape)). This allows computing log_prob(observation) for the reconstruction loss.

Usage in Dreamer world model:

decoder = ConvDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(3, 64, 64),  # RGB images
    activation='relu',
)
obs_dist = decoder(latent_features)  # Returns a distribution
log_prob = obs_dist.log_prob(target_observation)

Training:

The reconstruction loss is -log_prob(observation), which encourages the RSSM to learn states that capture the information in the observation.
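Because Independent sums log-probabilities over the len(output_shape) event dimensions, this loss is a sum over every pixel. The plain-Python sketch below works on a flattened toy image and assumes a fixed unit standard deviation. That is a common Dreamer choice, but this page does not confirm it. The helper names are hypothetical:

```python
import math


def normal_log_prob(x: float, mean: float, std: float = 1.0) -> float:
    """Log-density of a univariate Normal."""
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def reconstruction_loss(observation, predicted_mean, std: float = 1.0) -> float:
    """-log_prob of an Independent Normal: per-pixel log-probs are summed, then negated."""
    return -sum(normal_log_prob(x, m, std) for x, m in zip(observation, predicted_mean))


obs = [0.2, -0.1, 0.5, 0.0]   # a flattened toy "image"
pred = [0.1, -0.1, 0.3, 0.1]  # decoder mean for the same pixels
print(reconstruction_loss(obs, pred))
```

With std = 1 the loss equals half the summed squared error plus a constant, so it trains like an MSE reconstruction loss.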

forward(features)[source]#
class torchwm.CNNDecoder(state_size, latent_size, embedding_size, activation_function='relu')[source]#

Bases: Module

A Convolutional Neural Network (CNN) decoder for reconstructing image outputs.

forward(latent, state)[source]#
class torchwm.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist, num_buckets=255, symlog_range=10.0)[source]#

Bases: Module

MLP decoder for reward/value/discount prediction from latent features.

Part of Dreamer’s world model, this decoder predicts scalar quantities (rewards, values, discount factors) from RSSM latent states.

Architecture:

  • Input: [stoch_state, deter_state] concatenated, shape (B, stoch+deter)

  • Process: MLP with configurable layers and hidden units

  • Output: The predicted quantity, as a distribution (‘normal’ or ‘binary’) or a raw tensor (‘none’)

Supports three output types:
  • ‘normal’: Gaussian distribution for regression (rewards, values)

  • ‘binary’: Bernoulli distribution for binary classification (discount)

  • ‘none’: Raw tensor for non-probabilistic outputs

Usage:
reward_decoder = DenseDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(1,),
    n_layers=2,
    units=400,
    activation='elu',
    dist='normal',
)
reward_dist = reward_decoder(latent_features)
reward_loss = -reward_dist.log_prob(target_reward)

For discount prediction (binary):
discount_decoder = DenseDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(1,),
    n_layers=2,
    units=400,
    activation='elu',
    dist='binary',  # Bernoulli for P(continue)
)
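For dist='binary', the training loss is the Bernoulli negative log-likelihood of the continue flag. This stdlib sketch assumes the decoder output is interpreted as a logit, which is the usual choice for Bernoulli heads but is not confirmed here. `bernoulli_nll` is a hypothetical helper:

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bernoulli_nll(logit: float, target: float) -> float:
    """-log p(target) for Bernoulli(logits=logit); target is 1.0 (continue) or 0.0 (terminal)."""
    # -[y*log(sigmoid(l)) + (1-y)*log(1-sigmoid(l))] simplifies to softplus(l) - y*l
    return softplus(logit) - target * logit


print(bernoulli_nll(3.0, 1.0), bernoulli_nll(3.0, 0.0))
```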

forward(features)[source]#
class torchwm.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)[source]#

Bases: Module

Dreamer actor head producing squashed continuous actions from latent features.

Outputs a tanh-transformed Gaussian policy. It provides an optional deterministic mode (forward(features, deter=True)) and a helper, add_exploration, that adds exploration noise to actions.

forward(features, deter=False)[source]#
add_exploration(action, action_noise=0.3)[source]#
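The defaults min_std=0.0001, init_std=5 and mean_scale=5 match the original Dreamer actor. The sketch below shows how such a head typically turns raw network outputs into policy parameters and exploration noise. It assumes torchwm follows that parametrization, which this page does not confirm. All names are hypothetical, and the tanh squashing by TanhBijector would be applied afterwards:

```python
import math
import random


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def action_params(raw_mean: float, raw_std: float, mean_scale: float = 5.0,
                  init_std: float = 5.0, min_std: float = 1e-4):
    """Map raw outputs to the pre-tanh Gaussian's mean and std (Dreamer-style, assumed)."""
    raw_init_std = math.log(math.exp(init_std) - 1.0)     # inverse softplus: raw_std=0 gives init_std
    mean = mean_scale * math.tanh(raw_mean / mean_scale)  # soft-clips the mean to (-mean_scale, mean_scale)
    std = softplus(raw_std + raw_init_std) + min_std      # strictly positive, never below min_std
    return mean, std


def add_exploration(action: float, action_noise: float = 0.3, rng=random) -> float:
    """Add Gaussian noise to an action and clip back into [-1, 1]."""
    return max(-1.0, min(1.0, rng.gauss(action, action_noise)))


mean, std = action_params(0.0, 0.0)
print(mean, std)  # mean 0.0, std close to init_std + min_std
```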
class torchwm.TanhBijector[source]#

Bases: Transform

Bijective tanh transform for squashing Gaussian distributions to [-1, 1].

This transformation is essential for Dreamer’s action policy. Raw neural network outputs are Gaussian distributions over R^n, but actions in continuous control environments are typically bounded in [-1, 1]. The tanh bijector provides:

  1. Bijective mapping: tanh is invertible (with atanh as inverse)

  2. Tractable log-det Jacobian: it has a closed form that can be computed stably during gradient-based training

  3. Bounded actions: samples always lie in [-1, 1], so no clipping is needed at inference

Math:

Forward: y = tanh(x)
Inverse: x = atanh(y) = 0.5 * log((1 + y) / (1 - y))
Log-det: log|dy/dx| = log(1 - tanh(x)^2) = 2 * (log(2) - x - softplus(-2x))
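The softplus form of the log-det is algebraically equal to log(1 - tanh(x)^2). Unlike the naive form, it stays finite for large |x|, where tanh(x) rounds to ±1 in floating point. A stdlib check, with hypothetical helper names:

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def tanh_log_abs_det_jacobian(x: float) -> float:
    """log|d tanh(x)/dx| via 2 * (log 2 - x - softplus(-2x))."""
    return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))


def atanh(y: float) -> float:
    """Inverse of tanh on (-1, 1)."""
    return 0.5 * math.log((1.0 + y) / (1.0 - y))


for x in (-3.0, -0.5, 0.0, 0.7, 4.0):
    naive = math.log(1.0 - math.tanh(x) ** 2)  # fine here, but fails once tanh(x) == ±1.0
    print(f"x={x:+.1f}  stable={tanh_log_abs_det_jacobian(x):.6f}  naive={naive:.6f}")

# Finite, whereas math.log(1 - math.tanh(20.0) ** 2) raises ValueError (tanh(20.0) == 1.0)
print(tanh_log_abs_det_jacobian(20.0))
```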

Usage with Dreamer ActionDecoder:
dist = TransformedDistribution(
    Normal(mean, std), TanhBijector()
)
action = dist.sample()  # Bounded to [-1, 1]

Reference:

Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor, Haarnoja et al., 2018 (Appendix C, enforcing action bounds).

property sign#
atanh(x)[source]#
log_abs_det_jacobian(x, y)[source]#
class torchwm.SampleDist(dist, samples=100)[source]#

Bases: object

Distribution wrapper that estimates statistics via Monte Carlo sampling.

Provides approximated mean, mode, and entropy helpers for transformed distributions where analytic forms may be inconvenient.

property name#
mean()[source]#
mode()[source]#
entropy()[source]#
sample()[source]#
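The wrapper's idea can be shown as a dependency-free sketch. It draws samples, averages them for the mean, keeps the most likely draw as the mode, and averages -log_prob for the entropy. The sketch assumes only that the wrapped object has sample() and log_prob(). SampleDistSketch and ToyGaussian are hypothetical names. In Dreamer the wrapped object would typically be the tanh-transformed action distribution:

```python
import math
import random


class SampleDistSketch:
    """Monte Carlo mean/mode/entropy for any object exposing sample() and log_prob(x)."""

    def __init__(self, dist, samples=100):
        self.dist = dist
        self.samples = samples

    def _draw(self):
        return [self.dist.sample() for _ in range(self.samples)]

    def mean(self):
        draws = self._draw()
        return sum(draws) / len(draws)

    def mode(self):
        # the drawn sample with the highest log-probability
        return max(self._draw(), key=self.dist.log_prob)

    def entropy(self):
        # H = -E[log p(x)], estimated by averaging over draws
        draws = self._draw()
        return -sum(self.dist.log_prob(x) for x in draws) / len(draws)


class ToyGaussian:
    """Stand-in distribution with a known entropy, used only to check the estimates."""

    def __init__(self, mu, sigma, rng):
        self.mu, self.sigma, self.rng = mu, sigma, rng

    def sample(self):
        return self.rng.gauss(self.mu, self.sigma)

    def log_prob(self, x):
        z = (x - self.mu) / self.sigma
        return -0.5 * z * z - math.log(self.sigma) - 0.5 * math.log(2 * math.pi)


dist = SampleDistSketch(ToyGaussian(1.0, 2.0, random.Random(0)), samples=20000)
print(dist.mean(), dist.entropy(), 0.5 * math.log(2 * math.pi * math.e * 2.0 ** 2))
```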
class torchwm.IRISEncoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, in_channels=3, base_channels=64, num_residual_blocks=2, frame_shape=(3, 64, 64))[source]#

Bases: Module

CNN Encoder for IRIS discrete autoencoder.

Encodes image observations into latent features, which are then quantized into discrete tokens using the VectorQuantizer.

Architecture:
  • 4 convolutional layers with residual blocks

  • Self-attention at 8x8 and 16x16 resolutions

  • Vector quantization to produce discrete tokens

Parameters:
  • vocab_size (int)

  • tokens_per_frame (int)

  • embedding_dim (int)

  • in_channels (int)

  • base_channels (int)

  • num_residual_blocks (int)

  • frame_shape (Tuple[int, int, int])

forward(x)[source]#

Encode images to discrete tokens.

Parameters:

x (Tensor) – Input images (B, C, H, W) - should be 64x64

Returns:

(z_q, indices, vq_loss)
  • z_q: Quantized tokens (B, C, H’, W’)

  • indices: Token indices (B, H’, W’)

  • vq_loss: Dictionary with VQ loss components

Return type:

tuple

encode_to_indices(x)[source]#

Encode directly to token indices (for world model).

Parameters:

x (Tensor)

Return type:

Tensor

decode_from_indices(indices)[source]#

Decode token indices to embeddings (for decoder).

Parameters:

indices (Tensor)

Return type:

Tensor

class torchwm.IRISDecoder(vocab_size=512, embedding_dim=512, base_channels=32, out_channels=3, frame_shape=(3, 64, 64))[source]#

Bases: Module

CNN Decoder for IRIS discrete autoencoder.

Decodes discrete tokens back into image observations. Uses transposed convolutions to upsample from 4x4 to 64x64.

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • base_channels (int)

  • out_channels (int)

  • frame_shape (Tuple[int, int, int])

forward(z)[source]#

Decode tokens to images.

Parameters:

z (Tensor) – Token embeddings (B, C, H, W) - e.g., (B, 512, 4, 4)

Returns:

Reconstructed images (B, C, H, W) - e.g., (B, 3, 64, 64)

Return type:

Tensor

decode_from_embeddings(z_flat)[source]#

Decode flattened token embeddings to images.

Parameters:

z_flat (Tensor) – Flattened tokens (B, H*W, C) or (B, C, H, W)

Returns:

Reconstructed images

Return type:

Tensor

decode_from_indices(indices)[source]#

Decode discrete token indices into images.

Parameters:

indices (Tensor) – Tensor of shape (B, H, W) or (B, H*W) containing integer token indices in the range [0, vocab_size).

Returns:

Reconstructed images (B, C, H, W)

Return type:

Tensor

class torchwm.VideoTokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, commitment_weight=0.25, use_ema=False, ema_decay=0.99)[source]#

Bases: Module

Video Tokenizer using VQ-VAE with Spatiotemporal Transformer.

This is a core component of Genie (Google DeepMind, 2024), used to compress raw video frames into discrete latent tokens that can be processed by downstream models like the LatentActionModel and DynamicsModel.

The tokenizer uses the Vector Quantized Variational Autoencoder (VQ-VAE) objective to learn a discrete codebook of video representations. Unlike a standard VQ-VAE, it uses a Spatiotemporal (ST) Transformer in both the encoder and the decoder to better capture temporal dynamics in videos.

Architecture:
  1. Patch Embedding: Convert (B, C, T, H, W) video to patch tokens

  2. Encoder ST-Transformer: Process spatial-temporal patches

  3. Vector Quantization: Discretize continuous embeddings to codebook entries

  4. Decoder ST-Transformer: Reconstruct video from quantized tokens

  5. Patch Unembedding: Convert tokens back to video frames

Key Features:
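  • Causal processing: Each frame’s encoding attends only to the current and previous frames

  • Discrete tokens: Enable autoregressive prediction with latent actions

  • Memory efficient: Factorizing attention into spatial (within-frame) and temporal (across-frame) steps reduces the cost of attending over all T·N tokens from O((T·N)²) to O(T·N² + N·T²), for T frames with N patches each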
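Plugging in the defaults (num_frames=16, image_size=64, patch_size=4) shows how large the saving is. The count ignores causal masking, which would roughly halve the temporal term. `attention_pairs` is a hypothetical helper:

```python
def attention_pairs(num_frames: int, patches_per_frame: int) -> dict:
    """Query-key pairs per attention layer (per head) for full vs. factorized attention."""
    t, n = num_frames, patches_per_frame
    full = (t * n) ** 2    # every token attends to every token in the clip
    spatial = t * n * n    # each frame: its N patches attend to each other
    temporal = n * t * t   # each patch position: its T copies attend across frames
    return {"full": full, "spatiotemporal": spatial + temporal}


patches = (64 // 4) ** 2  # image_size=64, patch_size=4 -> 256 patches per frame
costs = attention_pairs(num_frames=16, patches_per_frame=patches)
print(patches, costs["full"], costs["spatiotemporal"])  # 256 16777216 1114112
```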

Usage with Genie:
tokenizer = VideoTokenizer(
    num_frames=16,
    image_size=64,
    patch_size=4,
    vocab_size=1024,
    embedding_dim=32,
)
reconstructed, indices, loss_dict = tokenizer(video_frames)

# For discrete token input to the dynamics model:
token_embeddings = tokenizer.decode_indices(indices)

Training:

The tokenizer is trained with the VQ-VAE objective:
  • Reconstruction loss: MSE between the input and reconstructed video

  • Codebook (VQ) loss: Moves the selected codebook embeddings toward the encoder outputs

  • Commitment loss: Penalizes encoder outputs for drifting away from their assigned codebook embeddings (weighted by commitment_weight)

Reference:

Genie: Generative Interactive Environments, Bruce et al., Google DeepMind, 2024. https://arxiv.org/abs/2402.15391

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • use_ema (bool)

  • ema_decay (float)

encode(x)[source]#

Encode video to discrete tokens.

Parameters:

x (Tensor) – Video tensor (B, C, T, H, W)

Returns:

(z_q, indices, vq_loss)
  • z_q: Quantized embeddings (B, T, H’, W’, embedding_dim)

  • indices: Token indices (B, T, H’, W’)

  • vq_loss: Dictionary with VQ loss components

Return type:

tuple

decode_indices(indices)[source]#

Decode token indices to embeddings for video frames.

Parameters:

indices (Tensor) – Token indices (B, T, H’, W’) or (B, T, N) where N = H’*W’

Returns:

Quantized embeddings (B, T, H’, W’, embedding_dim)

Return type:

Tensor

decode(z_q)[source]#

Decode quantized embeddings to video frames.

Parameters:

z_q (Tensor) – Quantized embeddings (B, T, H’, W’, embedding_dim)

Returns:

Reconstructed video (B, C, T, H, W)

Return type:

Tensor

forward(x)[source]#

Full forward pass with VQ-VAE objective.

Parameters:

x (Tensor) – Video tensor (B, C, T, H, W)

Returns:

(reconstructed, indices, loss_dict)
  • reconstructed: Reconstructed video (B, C, T, H, W)

  • indices: Token indices (B, T, H’, W’)

  • loss_dict: Dictionary containing loss components

Return type:

tuple

torchwm.create_video_tokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False)[source]#

Factory function to create a Video Tokenizer.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • use_ema (bool)

Return type:

VideoTokenizer

class torchwm.VectorQuantizer(vocab_size=512, embedding_dim=512, commitment_weight=0.25)[source]#

Bases: Module

Vector Quantizer for discrete autoencoder.

Implements the VQ-VAE quantization from: “Neural Discrete Representation Learning” (Van Den Oord et al., 2017)

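Uses a straight-through estimator for gradient flow from the decoder to the encoder. In this class the codebook is learned by gradient descent on the codebook loss; VectorQuantizerEMA provides the exponential-moving-average alternative.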
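A plain-Python sketch of one quantization step: pick the nearest codebook entry, then compute the two VQ-VAE loss terms. It reports forward values only. In an autograd framework the codebook and commitment terms have the same value and differ only in where the stop-gradient sits. The straight-through output z_e + sg(z_q - z_e) equals z_q in the forward pass. Averaging over dimensions rather than summing is an assumption, and `quantize` is a hypothetical helper:

```python
def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(z_e, codebook, commitment_weight=0.25):
    """Nearest-neighbour lookup plus forward values of the two VQ-VAE loss terms."""
    index = min(range(len(codebook)), key=lambda k: squared_distance(z_e, codebook[k]))
    z_q = codebook[index]
    mse = squared_distance(z_e, z_q) / len(z_e)  # mean over embedding dimensions
    losses = {
        "codebook": mse,                        # ||sg[z_e] - e||^2: trains the codebook
        "commitment": commitment_weight * mse,  # beta * ||z_e - sg[e]||^2: trains the encoder
    }
    return z_q, index, losses


codebook = [[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]]
z_q, index, losses = quantize([0.9, 1.2], codebook)
print(index, z_q, losses)
```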

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

forward(z)[source]#

Quantize the input latents.

Parameters:

z (Tensor) – Input tensor of shape (B, C, H, W) or (B, C)

Returns:

(z_q, indices, loss)
  • z_q: Quantized tensor (same shape as input)

  • indices: Token indices for each position (B, H, W) or (B,)

  • loss: Dictionary containing VQ loss components

Return type:

tuple

decode_indices(indices)[source]#

Decode token indices back to embeddings.

Parameters:

indices (Tensor) – Token indices (B, H, W) or (B,)

Returns:

Embeddings (B, C, H, W) or (B, C)

Return type:

Tensor

class torchwm.VectorQuantizerEMA(vocab_size=512, embedding_dim=512, commitment_weight=0.25, ema_decay=0.99, epsilon=1e-05)[source]#

Bases: Module

Vector Quantizer with Exponential Moving Average updates.

Uses EMA updates for the codebook instead of gradient-based updates, which tends to make training more stable.
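The EMA variant replaces the codebook gradient with running averages of two quantities per code: how many encoder outputs were assigned to it, and their sum. Each code then becomes the ratio of the two. The stdlib sketch below follows the EMA update described by van den Oord et al. (2017). It adds Laplace smoothing of the counts via epsilon, a common implementation detail, so that rarely used codes never divide by zero. The function and argument names are hypothetical:

```python
def ema_update(codebook, cluster_size, embed_sum, batch, assignments,
               decay=0.99, epsilon=1e-5):
    """One EMA codebook update; mutates and returns codebook. Assumes a non-empty batch."""
    num_codes, dim = len(codebook), len(codebook[0])
    counts = [0] * num_codes
    sums = [[0.0] * dim for _ in range(num_codes)]
    for z, k in zip(batch, assignments):  # per-code counts and sums for this batch
        counts[k] += 1
        for d in range(dim):
            sums[k][d] += z[d]
    for k in range(num_codes):  # exponential moving averages
        cluster_size[k] = decay * cluster_size[k] + (1 - decay) * counts[k]
        for d in range(dim):
            embed_sum[k][d] = decay * embed_sum[k][d] + (1 - decay) * sums[k][d]
    total = sum(cluster_size)
    for k in range(num_codes):
        # Laplace smoothing keeps the denominator positive for unused codes
        smoothed = (cluster_size[k] + epsilon) / (total + num_codes * epsilon) * total
        codebook[k] = [s / smoothed for s in embed_sum[k]]
    return codebook


cb = ema_update([[0.0, 0.0], [5.0, 5.0]], [0.0, 0.0], [[0.0, 0.0], [0.0, 0.0]],
                batch=[[1.0, 1.0], [3.0, 3.0]], assignments=[0, 0], decay=0.0)
print(cb)  # code 0 moves to the mean of its assigned vectors
```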

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • ema_decay (float)

  • epsilon (float)

forward(z)[source]#

Quantize with EMA updates.

Parameters:

z (Tensor)

Return type:

Tuple[Tensor, Tensor, dict[str, Tensor]]

decode_indices(indices)[source]#

Decode token indices to embeddings.

Parameters:

indices (Tensor) – Token indices (B, H, W) or (B,)

Returns:

Embeddings (B, C, H, W) or (B, C)

Return type:

Tensor

class torchwm.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]#

Bases: object

Fixed-size replay buffer for Dreamer with image observations and transitions.

Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world-model training.

Key Features:
  • Ring buffer with fixed capacity (FIFO eviction when full)

  • Stores raw uint8 images to save memory

  • Samples sequences (not single transitions) for temporal modeling

  • Validates sampled sequences don’t span episode boundaries

Memory Layout:
  • observations: (capacity, C, H, W) uint8 images

  • actions: (capacity, action_dim) float32

  • rewards: (capacity,) float32

  • terminals: (capacity,) float32 (1.0 = terminal, 0.0 = continue)

Sampling Process:
  1. Random start index (avoiding episode boundaries)

  2. Collect sequence of length seq_len with wraparound

  3. Validate no terminal in middle of sequence

  4. Return batch of sequences
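The sampling steps above can be sketched in plain Python over a ring of capacity slots. This is an illustration under stated assumptions, not torchwm's implementation: it uses rejection sampling with a retry cap and allows a terminal only at the last step. `sample_sequence_indices` is a hypothetical helper:

```python
import random


def sample_sequence_indices(capacity, steps, terminals, seq_len, rng=random, max_tries=1000):
    """Pick seq_len contiguous ring-buffer slots with no terminal before the final step.

    capacity: ring size; steps: total transitions ever added; terminals: per-slot flags.
    """
    filled = min(steps, capacity)
    if filled < seq_len:
        raise ValueError("not enough transitions to sample a sequence")
    oldest = steps % capacity if steps > capacity else 0  # slot of the oldest stored transition
    for _ in range(max_tries):
        offset = rng.randrange(filled - seq_len + 1)  # never crosses the newest->oldest seam
        idxs = [(oldest + offset + i) % capacity for i in range(seq_len)]  # wraparound
        if not any(terminals[i] for i in idxs[:-1]):  # an episode may only end at the last step
            return idxs
    raise RuntimeError("no valid sequence found")


print(sample_sequence_indices(10, 25, [0.0] * 10, 5, random.Random(0)))
```

A batch is then batch_size such index lists, gathered from the storage arrays and stacked time-first to give the (seq_len, batch, ...) layout.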

Usage with Dreamer:
buffer = ReplayBuffer(
    size=100000,            # Max transitions to store
    obs_shape=(3, 64, 64),  # RGB images
    action_size=6,          # Continuous action dim
    seq_len=50,             # Sequence length for training
    batch_size=50,          # Parallel sequences per batch
)

# Add transitions during interaction (obs is a dict with an 'image' key)
buffer.add(obs, action, reward, done)

# Sample a batch for world-model training
obs_batch, action_batch, reward_batch, term_batch = buffer.sample()
# Shapes: (seq_len, batch, C, H, W), (seq_len, batch, action_dim), etc.

Memory Efficiency:
  • Uses uint8 for images (1 byte per pixel vs 4 for float32)

  • Sequences share observations (overlapping windows)

  • Configurable capacity based on available system memory
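The uint8 saving is easy to quantify for the configuration in the usage example (size=100000, obs_shape=(3, 64, 64)). The count covers image storage only and ignores actions, rewards and terminals. `buffer_bytes` is a hypothetical helper:

```python
def buffer_bytes(capacity, obs_shape, bytes_per_value):
    """Bytes needed to store capacity observations of obs_shape."""
    values = capacity
    for dim in obs_shape:
        values *= dim
    return values * bytes_per_value


GIB = 1024 ** 3
as_uint8 = buffer_bytes(100_000, (3, 64, 64), 1)    # 1 byte per pixel
as_float32 = buffer_bytes(100_000, (3, 64, 64), 4)  # 4 bytes per pixel
print(as_uint8, as_float32)                          # 1228800000 4915200000
print(round(as_uint8 / GIB, 2), round(as_float32 / GIB, 2))  # 1.14 4.58
```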

Note

The buffer stores observations as {“image”: …} dicts but returns just the image arrays for training efficiency.

Parameters:
  • size (int)

  • obs_shape (Tuple[int, ...])

  • action_size (int)

  • seq_len (int)

  • batch_size (int)

add(obs, ac, rew, done)[source]#

Add a transition to the buffer.

Parameters:
  • obs (dict) – Observation dict with ‘image’ key containing the observation

  • ac (ndarray) – Action taken, shape (action_size,)

  • rew (float) – Reward received, scalar

  • done (float) – Terminal flag, 1.0 if episode ended, 0.0 otherwise

Return type:

None

sample()[source]#

Sample a batch of sequences for training.

Returns:

(observations, actions, rewards, terminals)
  • observations: (seq_len, batch, C, H, W)

  • actions: (seq_len, batch, action_dim)

  • rewards: (seq_len, batch)

  • terminals: (seq_len, batch)

Return type:

tuple

class torchwm.Memory(size=None)[source]#

Bases: deque

Episode-based replay memory for PlaNet/RSSM training.

Stores episodes as variable-length trajectories and supports sampling sub-sequences for training. Implements a ring-buffer style eviction when capacity is reached.

Features:
  • Stores complete episodes as lists of transitions

  • Samples contiguous sub-sequences for sequence models

  • Supports time-major formatting (time-first) for RNN input

  • Memory usage estimation to prevent OOM errors

Parameters:

size (int, optional) – Maximum number of episodes to store. If None, deque grows without limit (useful for unpickling).

episodes#

Collection of Episode objects.

Type:

deque

eps_lengths#

Length of each episode.

Type:

deque

size#

Total number of transitions across all episodes.

Type:

int (read-only property)

Example

>>> memory = Memory(size=100)
>>> memory.append([episode1, episode2])
>>> obs, actions, rewards, terminals, lengths = memory.sample(batch_size=32, tracelen=50)
property size#
append(episodes)[source]#
Parameters:

episodes (list[Episode])

sample(batch_size, tracelen=1, time_first=False)[source]#

Sample random sub-sequences from stored episodes.

Randomly selects episodes and starting positions to create batches of contiguous sequences for training sequence models.

Parameters:
  • batch_size (int) – Number of sequences to sample.

  • tracelen (int) – Length of each sequence (default: 1).

  • time_first (bool) – If True, returns tensors with time dimension first (T, B, …) instead of batch first (B, T, …).

Returns:

(observations, actions, rewards, terminals, lengths)
  • observations: (batch, tracelen+1, *obs_shape) or (tracelen+1, batch, …)

  • actions: (batch, tracelen, action_dim) or (tracelen, batch, …)

  • rewards: (batch, tracelen) or (tracelen, batch)

  • terminals: (batch, tracelen) or (tracelen, batch)

  • lengths: (batch,) original episode lengths for each sample

Return type:

tuple

Raises:
  • ValueError – If memory is empty or no episodes meet minimum length.

  • MemoryError – If estimated memory usage exceeds 200 MiB threshold.
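Conceptually, sample draws a random episode and a random start for each batch element. It then slices tracelen + 1 observations alongside tracelen actions, rewards and terminals. The plain-Python sketch below works on nested lists. It assumes each episode stores x, u, r and t as in Episode, with one more observation than actions, matching the shapes listed above. The helper names are hypothetical, and the memory-usage check is omitted:

```python
import random


def transpose(rows):
    """(B, T) nested lists -> (T, B)."""
    return [list(col) for col in zip(*rows)]


def sample_subsequences(episodes, batch_size, tracelen=1, time_first=False, rng=random):
    """episodes: dicts with 'x' (T+1 observations) and 'u', 'r', 't' (T steps each)."""
    eligible = [ep for ep in episodes if len(ep["u"]) >= tracelen]
    if not eligible:
        raise ValueError("no episodes meet the minimum length")
    obs, act, rew, term, lengths = [], [], [], [], []
    for _ in range(batch_size):
        ep = rng.choice(eligible)                           # random episode
        start = rng.randrange(len(ep["u"]) - tracelen + 1)  # random start inside it
        obs.append(ep["x"][start:start + tracelen + 1])     # tracelen + 1 observations
        act.append(ep["u"][start:start + tracelen])
        rew.append(ep["r"][start:start + tracelen])
        term.append(ep["t"][start:start + tracelen])
        lengths.append(len(ep["u"]))                        # original episode length
    if time_first:
        obs, act, rew, term = transpose(obs), transpose(act), transpose(rew), transpose(term)
    return obs, act, rew, term, lengths


ep = {"x": list(range(6)), "u": list(range(5)), "r": [0.0] * 5, "t": [0.0] * 5}
print(sample_subsequences([ep], batch_size=2, tracelen=3, rng=random.Random(0)))
```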

class torchwm.Episode(postprocess_fn=None)[source]#

Bases: object

Records the agent’s interaction with the environment for a single episode.

Stores observations, actions, rewards, and terminal flags during a single trajectory. At termination, converts all lists to numpy arrays for efficient batch processing.

x#

Observations collected during the episode.

Type:

list or np.ndarray

u#

Actions taken.

Type:

list or np.ndarray

r#

Rewards received.

Type:

list or np.ndarray

t#

Terminal flags (0.0 = continue, 1.0 = terminal).

Type:

list or np.ndarray

info#

Additional episode metadata.

Type:

dict

Parameters:

postprocess_fn (callable, optional) – Function to apply to observations before storing (e.g., normalization). Default: identity function.

Example

>>> episode = Episode()
>>> episode.append(obs, action, reward, False)
>>> episode.append(obs, action, reward, True)
>>> episode.terminate(final_obs)
>>> print(episode.x.shape)  # Now a numpy array
property size#
append(obs, act, reward, terminal)[source]#
terminate(obs)[source]#
class torchwm.IRISReplayBuffer(size, obs_shape, action_size, seq_len=20, batch_size=64)[source]#

Bases: object

Replay buffer for IRIS (Imagination with auto-Regression over an Inner Speech) training.

Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world model training.

Features:
  • Ring buffer with fixed capacity (FIFO eviction when full)

  • Stores uint8 images for memory efficiency

  • Samples sequences with validation to avoid episode boundaries

  • Supports sequence sampling for temporal learning

Memory Layout:
  • observations: (capacity, C, H, W) uint8

  • actions: (capacity, action_size) float32

  • rewards: (capacity,) float32

  • terminals: (capacity,) float32

Parameters:
  • size (int) – Maximum number of transitions to store.

  • obs_shape (tuple) – Shape of observations as (C, H, W).

  • action_size (int) – Dimension of actions.

  • seq_len (int) – Length of sequences to sample (default: 20).

  • batch_size (int) – Number of sequences per batch (default: 64).

size#

Buffer capacity.

Type:

int

obs_shape#

Observation shape.

Type:

tuple

action_size#

Action dimension.

Type:

int

seq_len#

Sequence length.

Type:

int

batch_size#

Batch size.

Type:

int

steps#

Total transitions added.

Type:

int

episodes#

Number of episode terminations observed.

Type:

int

add(obs, action, reward, terminal)[source]#

Add a transition to the buffer.

Parameters:
  • obs (ndarray) – Observation array with shape (C, H, W).

  • action (ndarray) – Action array with shape (action_size,).

  • reward (float) – Scalar reward value.

  • terminal (bool) – Boolean indicating if episode terminated.

sample_sequence(seq_len=None)[source]#

Sample a batch of sequences for world model training.

Returns:

(observations, actions, rewards, terminals)
  • observations: (batch_size, seq_len+1, C, H, W)

  • actions: (batch_size, seq_len, action_size)

  • rewards: (batch_size, seq_len)

  • terminals: (batch_size, seq_len)

Return type:

tuple

Parameters:

seq_len (int | None)

sample_single()[source]#

Sample a single transition for online updates.

Return type:

Tuple[ndarray, ndarray, float, float]

property buffer_capacity#

Returns the total capacity of the buffer.

class torchwm.IRISOnPolicyBuffer(max_steps=1000)[source]#

Bases: object

On-policy buffer for collecting trajectories during environment interaction.

Used to store the current episode data before adding to the main replay buffer. Unlike the main replay buffer, this collects trajectories in a list-based structure that’s cleared after each episode.

Useful for:
  • Collecting complete episode trajectories

  • Storing data before batch processing

  • Temporary storage during environment interaction

Parameters:

max_steps (int) – Maximum number of steps to store (default: 1000).

max_steps#

Maximum buffer capacity.

Type:

int

observations#

List of observations.

Type:

list

actions#

List of actions.

Type:

list

rewards#

List of rewards.

Type:

list

terminals#

List of terminal flags.

Type:

list

add(obs, action, reward, terminal)[source]#
Parameters:
  • obs (ndarray)

  • action (ndarray)

  • reward (float)

  • terminal (bool)

clear()[source]#
get_arrays()[source]#
class torchwm.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]#

Bases: Module

Diffusion Transformer model for image denoising and generation.

The module maps noisy images and timesteps to predicted noise residuals and also provides a classmethod training entrypoint for common datasets.

forward(x, t)[source]#
classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]#
class torchwm.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]#

Bases: Module

Patchify an image into a sequence of learnable patch tokens.

Used in Vision Transformers (ViT) and DiT to convert 2D images into sequences of token embeddings that can be processed by transformers.

Process:
  1. Conv2d with kernel_size=stride=patch_size extracts non-overlapping patches

  2. Because each kernel application covers exactly one patch, this Conv2d is equivalent to a linear projection of each flattened patch to embed_dim

  3. Learnable positional embeddings are added for spatial information

Input: (B, C, H, W) images
Output: (B, N, embed_dim), where N = (H/patch_size) * (W/patch_size)

Parameters:
  • img_size – Image size (assumes square), e.g., 32 for CIFAR

  • patch_size – Size of each patch (typically 4, 8, or 16)

  • in_channels – Number of input channels (3 for RGB)

  • embed_dim – Output dimension for each patch token

Usage with DiT:

patch_embed = PatchEmbed(img_size=32, patch_size=4, in_channels=3, embed_dim=256)
tokens = patch_embed(images)  # (B, 64, 256) for a 32x32 image with patch_size=4

forward(x)[source]#
class torchwm.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]#

Bases: Module

Reconstruct image-like tensors from patch-token sequences.

The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.

forward(x)[source]#
class torchwm.DDPM(timesteps, beta_start, beta_end, device)[source]#

Bases: object

Utility class implementing forward and reverse DDPM diffusion steps.

Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).
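The forward (noising) step behind q_sample has the closed form x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, where alpha_bar_t = prod_{s<=t} (1 - beta_s). The stdlib sketch below assumes a linear beta schedule from beta_start to beta_end (as in Ho et al., 2020) and zero-based timesteps. The helper names are hypothetical:

```python
import math
import random


def linear_schedule(timesteps, beta_start, beta_end):
    """Linearly spaced betas and their cumulative products alpha_bar_t = prod(1 - beta_s)."""
    betas = [beta_start + (beta_end - beta_start) * i / (timesteps - 1) for i in range(timesteps)]
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return betas, alpha_bars


def q_sample(x_start, t, alpha_bars, noise=None, rng=random):
    """Noise a clean sample x_0 directly to step t in closed form."""
    if noise is None:
        noise = [rng.gauss(0.0, 1.0) for _ in x_start]
    a = alpha_bars[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * n for x, n in zip(x_start, noise)]


betas, alpha_bars = linear_schedule(1000, 1e-4, 0.02)
print(alpha_bars[0], alpha_bars[-1])  # close to 1 at t=0, close to 0 at t=999
```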

q_sample(x_start, t, noise=None)[source]#
p_sample(model, x_t, t)[source]#
sample(model, n, img_size, channels)[source]#
class torchwm.ActorCriticNetwork(obs_channels=3, action_dim=18, channels=(32, 32, 64, 64), lstm_dim=512)[source]#

Bases: Module

Actor-Critic network for DIAMOND RL training. Shared CNN-LSTM trunk with separate policy and value heads.

Parameters:
  • obs_channels (int)

  • action_dim (int)

  • channels (Tuple[int, ...])

  • lstm_dim (int)

forward(obs, hidden_state=None)[source]#

Forward pass of actor-critic network.

Parameters:
  • obs (Tensor) – Observations [B, T, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

Returns:

(policy_logits, values, hidden_state)
  • policy_logits: [B, T, action_dim]

  • values: [B, T, 1]

  • hidden_state: (h, c)

Return type:

tuple

get_action(obs, hidden_state=None, deterministic=False)[source]#

Get action from a single observation.

Parameters:
  • obs (Tensor) – Single observation [B, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

  • deterministic (bool) – If True, take argmax; else sample

Returns:

(action, hidden_state)
  • action: Selected action [B]

  • hidden_state: (h, c)

Return type:

tuple

get_actions(obs, hidden_state=None, deterministic=False)[source]#

Batched version of get_action.

Parameters:
  • obs (Tensor) – Tensor of shape [B, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state tuple matching batch size

  • deterministic (bool) – If True, take argmax; else sample from policy

Returns:

(actions, hidden_state)
  • actions: LongTensor of shape [B]

  • hidden_state: Updated LSTM hidden state tuple

Return type:

tuple

get_value(obs, hidden_state=None)[source]#

Get value for a single observation.

Parameters:
  • obs (Tensor)

  • hidden_state (Tuple[Tensor, Tensor] | None)

Return type:

Tuple[Tensor, Tuple[Tensor, Tensor] | None]

init_hidden(batch_size, device)[source]#

Initialize LSTM hidden states.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

get_hidden_size()[source]#

Get LSTM hidden size.

Return type:

int

class torchwm.RewardTerminationModel(obs_channels=3, action_dim=18, channels=(32, 32, 32, 32), lstm_dim=512, cond_dim=128)[source]#

Bases: Module

Reward and termination prediction model. CNN + LSTM architecture following DIAMOND paper specifications.

Parameters:
  • obs_channels (int) – Number of observation channels (3 for RGB)

  • action_dim (int) – Number of possible actions

  • channels (Tuple[int, ...]) – List of channel sizes for conv blocks

  • lstm_dim (int) – LSTM hidden dimension

  • cond_dim (int) – Conditioning dimension for adaptive norm

forward(obs, actions, hidden_state=None)[source]#

Forward pass of reward/termination model.

Parameters:
  • obs (Tensor) – Observations [B, T, C, H, W]

  • actions (Tensor) – Actions [B, T]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

Returns:

(reward_logits, termination_logits, hidden_state)
  • reward_logits: Reward predictions [B, T, 3] (for -1, 0, 1)

  • termination_logits: Termination predictions [B, T, 2]

  • hidden_state: Updated (h, c) hidden states

Return type:

tuple

predict(obs, actions, hidden_state=None)[source]#

Predict reward and termination for a single step.

Parameters:
  • obs (Tensor) – Single observation [B, C, H, W]

  • actions (Tensor) – Single action [B]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

Returns:

(reward, terminated, hidden_state)
  • reward: Predicted reward classes as a tensor (values -1, 0, 1)

  • terminated: Predicted termination (bool tensor)

  • hidden_state: Updated (h, c) hidden states

Return type:

tuple

init_hidden(batch_size, device)[source]#

Initialize LSTM hidden states.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

torchwm.sinusoidal_time_embedding(timesteps, dim)[source]#

Create sinusoidal timestep embeddings for diffusion conditioning.

This function generates positional-style embeddings for diffusion timesteps, following the same pattern as transformer positional encodings. The embeddings encode the noise level (t) and are used to condition the diffusion model.

Math:

embedding[t] = [sin(t/10000^(2i/d)), cos(t/10000^(2i/d))] for i in [0, d/2)

Parameters:
  • timesteps – Tensor of integer timesteps, shape (B,) or (B, 1)

  • dim – Embedding dimension (must be even)

Returns:

Tensor of shape (B, dim) with sinusoidal embeddings

Usage with DiT:

t = torch.tensor([0, 500, 1000])  # Timesteps
emb = sinusoidal_time_embedding(t, dim=256)  # (3, 256)

# Condition the model, e.g.:
# - feed the timestep embedding through an MLP
# - use the result in AdaLN for adaptive normalization
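The embedding formula translates directly into a few lines of plain Python. This sketch mirrors the [sin, cos] concatenation described for sinusoidal_time_embedding and is for illustration only. `sinusoidal_embedding_sketch` is a hypothetical name; the torchwm function operates on tensors:

```python
import math


def sinusoidal_embedding_sketch(timesteps, dim):
    """Return one [sin..., cos...] vector of length dim per integer timestep."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    freqs = [1.0 / 10000 ** (2 * i / dim) for i in range(half)]  # 1 / 10000^(2i/d)
    return [[math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]
            for t in timesteps]


emb = sinusoidal_embedding_sketch([0, 500, 999], dim=256)
print(len(emb), len(emb[0]))  # 3 256
```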

class torchwm.STTransformer(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]#

Bases: Module

Spatiotemporal Transformer for video modeling.

Stacks depth spatiotemporal blocks, each interleaving spatial attention (within a frame) and temporal attention (across frames).

Parameters:
  • num_frames (int)

  • num_patches_per_frame (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • qk_scale (float | None)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

  • norm_layer (type[Module])

forward(x)[source]#
Parameters:

x (Tensor) – (B, T*N, C) where T is num_frames, N is num_patches_per_frame

Returns:

(B, T*N, C)

Return type:

Tensor

class torchwm.MultiHeadSelfAttention(d, n_heads=2)[source]#

Bases: Module

Multi-head scaled dot-product self-attention over sequence tokens.

This module projects the input sequence into query/key/value heads, performs attention independently per head, and merges the heads back into the original feature dimension. It is used as a lightweight transformer attention block.

forward(x)[source]#
torchwm.MultiHeadAttention#

alias of MultiHeadSelfAttention

class torchwm.AdaLNNormalization(d_model, t_dim)[source]#

Bases: Module

Adaptive layer normalization conditioned on an external embedding.

The module applies RMS normalization and predicts per-channel scale/shift from a conditioning vector (for example diffusion timestep embeddings).

forward(x, t_emb)[source]#
class torchwm.RMSNorm(dim, eps=1e-06)[source]#

Bases: Module

Root Mean Square Layer Normalization with a learned gain parameter.

RMSNorm rescales activations using their RMS magnitude without centering, providing a lightweight normalization alternative to LayerNorm.
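The plain-Python sketch below shows RMSNorm, y = g * x / sqrt(mean(x^2) + eps), and the scale/shift modulation an AdaLN layer applies on top of it. The (1 + scale) form follows the DiT adaLN convention and is an assumption about torchwm. In the real module, scale and shift would be predicted from t_emb by a learned layer, which is omitted here. Helper names are hypothetical:

```python
import math


def rms_norm(x, gain, eps=1e-6):
    """Divide by the root-mean-square of x (no mean subtraction), then apply a per-channel gain."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for g, v in zip(gain, x)]


def adaln_modulate(x, scale, shift, eps=1e-6):
    """RMS-normalize without a learned gain, then apply conditioning-dependent scale/shift."""
    normed = rms_norm(x, [1.0] * len(x), eps)
    return [n * (1.0 + s) + b for n, s, b in zip(normed, scale, shift)]


y = rms_norm([3.0, 4.0], [1.0, 1.0], eps=0.0)
print(y)  # unit root-mean-square
```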

forward(x)[source]#
class torchwm.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]#

Bases: object

Model-predictive controller using Cross-Entropy Method (CEM) with RSSM.

Plans actions by optimizing a sequence of future actions in the RSSM’s latent space. Uses Cross-Entropy Method to refine action sequences based on predicted returns.

Algorithm:
  1. Initialize Gaussian distribution over action sequences

  2. Sample N candidate action sequences

  3. Rollout each sequence in RSSM latent space

  4. Score by predicted cumulative rewards

  5. Keep top K candidates, fit Gaussian to them

  6. Repeat for T iterations

  7. Execute first action from best sequence
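The loop above maps onto a short Cross-Entropy Method routine. In this stdlib sketch, a hypothetical toy_return score function stands in for the RSSM rollout, and actions are one-dimensional. `cem_plan` and `toy_return` are illustrative names, not torchwm APIs:

```python
import random
import statistics


def cem_plan(score_fn, horizon, num_candidates=1000, num_iterations=8,
             top_candidates=100, rng=random):
    """Cross-Entropy Method over open-loop action sequences (1-D actions for brevity)."""
    mean = [0.0] * horizon  # 1. per-step Gaussian over the action sequence
    std = [1.0] * horizon
    best = None
    for _ in range(num_iterations):  # 6. repeat
        # 2. sample candidate sequences from the current per-step Gaussians
        candidates = [[rng.gauss(m, s) for m, s in zip(mean, std)]
                      for _ in range(num_candidates)]
        # 3-5. score each sequence, keep the top K, refit the Gaussians to them
        elites = sorted(candidates, key=score_fn, reverse=True)[:top_candidates]
        best = elites[0]
        mean = [statistics.fmean(step) for step in zip(*elites)]
        std = [statistics.pstdev(step) + 1e-6 for step in zip(*elites)]
    return best[0]  # 7. first action of the best sequence from the last iteration


def toy_return(actions):
    """Stand-in for the sum of RSSM-predicted rewards: best when every action is 0.5."""
    return -sum((a - 0.5) ** 2 for a in actions)


print(cem_plan(toy_return, horizon=12, rng=random.Random(0)))
```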

Why latent space planning?
  • Images are high-dimensional; latent states are compact

  • Enables thousands of rollouts in parallel

  • Predicting compact latent states is typically easier than predicting future images pixel by pixel

Parameters:
  • model – RSSM instance for latent dynamics

  • planning_horizon – Number of future steps to plan (H)

  • num_candidates – Number of action sequences to sample (N)

  • num_iterations – CEM refinement iterations (T)

  • top_candidates – Number of best candidates to keep (K)

  • device – torch device

Usage with the PlaNet agent:
policy = RSSMPolicy(
    model=rssm,
    planning_horizon=12,
    num_candidates=1000,
    num_iterations=8,
    top_candidates=100,
    device='cuda',
)

policy.reset()
action = policy.poll(observation)  # (1, action_dim)

# For continuous control:
next_obs, reward, done, info = env.step(action)

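Comparison with Dreamer:
  • RSSMPolicy: Online planning; chooses actions by optimization at every step

  • DreamerActor: Trains an actor network to predict actions from latent states

  • Trade-off: Dreamer amortizes planning into the actor, so acting is cheap and it tends to be more sample-efficient on complex tasks; CEM needs no policy training and is more flexible, but must re-plan at every step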

reset()[source]#
poll(observation, explore=False)[source]#
class torchwm.IRISActor(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Actor network for the IRIS (Imagination with auto-Regression over an Inner Speech) policy.

Takes reconstructed frames as input and outputs action logits for policy control. Uses a CNN feature extractor followed by an LSTM for temporal processing. Supports a burn-in mechanism for initializing the hidden state with context frames.

Architecture:
  • CNN: Extracts features from input frames (3x64x64 -> 512)

  • LSTM: Processes temporal sequences with configurable layers

  • Linear: Maps hidden states to action logits

Parameters:
  • action_size (int) – Number of discrete actions.

  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

action_size#

Number of discrete actions.

Type:

int

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

forward(frames, hidden_state=None, burn_in_frames=None)[source]#

Forward pass through actor.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W) or (B, C, H, W)

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple for LSTM state

  • burn_in_frames (Tensor | None) – Frames to use for initializing hidden state

Returns:

A tuple (action_logits, hidden_state): action logits with shape (B, T, action_size) or (B, action_size), and the updated (h, c) tuple.

Return type:

tuple

init_hidden_state(batch_size, device)[source]#

Initialize LSTM hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

get_action(frame, temperature=1.0, deterministic=False)[source]#

Get action from a single frame.

Parameters:
  • frame (Tensor) – Single frame (B, C, H, W)

  • temperature (float) – Softmax temperature (higher = more random)

  • deterministic (bool) – If True, return argmax; else sample

Returns:

action – Selected action indices with shape (B,).
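The temperature and deterministic options of get_action follow the usual softmax-with-temperature pattern. The NumPy sketch below shows that pattern and assumes logits of shape (B, action_size). It is not the IRISActor source.

```python
import numpy as np

def select_action(logits, temperature=1.0, deterministic=False, rng=None):
    """Pick one action index per row of logits with shape (B, action_size)."""
    if deterministic:
        return logits.argmax(axis=-1)  # greedy; temperature has no effect
    if rng is None:
        rng = np.random.default_rng()
    scaled = logits / temperature  # higher temperature -> flatter distribution
    scaled = scaled - scaled.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Sample independently for each batch element
    return np.array([rng.choice(len(p), p=p) for p in probs])

logits = np.array([[2.0, 0.5, -1.0], [0.0, 3.0, 0.0]])
print(select_action(logits, deterministic=True))  # [0 1]
```

As the temperature approaches 0, sampling approaches the argmax. Large temperatures approach a uniform choice over actions.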

class torchwm.IRISCritic(hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Critic network for IRIS value estimation.

Estimates the value function for given frame sequences. Shares the CNN feature extractor and LSTM backbone with the actor for efficiency, but has a separate value head for estimating expected cumulative rewards.

Architecture:
  • CNN: Shared feature extractor with actor (3x64x64 -> 512)

  • LSTM: Temporal processing with same architecture as actor

  • Linear: Maps hidden states to scalar values

Parameters:
  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

Returns:

A tuple (values, hidden_state): value estimates with shape (B, T) and the updated LSTM hidden state as an (h, c) tuple.

Return type:

tuple

Parameters:
  • hidden_size (int)

  • num_layers (int)

  • frame_shape (Tuple[int, int, int])

forward(frames, hidden_state=None)[source]#

Forward pass through critic.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W)

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple

Returns:

A tuple (values, hidden_state): value estimates (B, T) and the updated (h, c) tuple.

Return type:

tuple

init_hidden_state(batch_size, device)[source]#

Initialize LSTM hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

class torchwm.IRISPolicy(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Combined policy module for IRIS (Imagination with auto-Regression over an Inner Speech).

Provides a unified interface for actor-only or actor-critic policies. Used in the IRIS algorithm where the actor generates actions from reconstructed frames and the critic estimates value functions for training.

Parameters:
  • action_size (int) – Number of discrete actions.

  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

actor#

The actor network for action selection.

Type:

IRISActor

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

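Example

>>> import torch
>>> from torchwm import IRISPolicy
>>> policy = IRISPolicy(
...     action_size=18,
...     hidden_size=512,
...     num_layers=4,
...     frame_shape=(3, 64, 64)
... )
>>> frame = torch.rand(1, 3, 64, 64)  # (B, C, H, W)
>>> action = policy.act(frame, temperature=1.0, deterministic=False)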
forward(frames)[source]#

Get action logits from frames.

Parameters:

frames (Tensor)

Return type:

Tensor

act(frame, temperature=1.0, deterministic=False)[source]#

Sample action from policy.

Parameters:
  • frame (Tensor)

  • temperature (float)

  • deterministic (bool)

Return type:

Tensor

init_hidden(batch_size, device)[source]#

Initialize hidden state.

Parameters:
  • batch_size (int)

  • device (device)

class torchwm.CNNFeatureExtractor(frame_shape=(3, 64, 64), output_size=512)[source]#

Bases: Module

CNN feature extractor shared between actor and critic networks.

Processes input frames through four stride-2 convolution + ReLU blocks, followed by a linear projection to output_size, to produce fixed-size feature vectors.

Architecture:
  • Conv layers: 3 -> 32 -> 64 -> 128 -> 256 channels

  • Each conv has stride=2 for spatial downsampling

  • Final linear layer projects to output_size

Parameters:
  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

  • output_size (int) – Size of output feature vector (default: 512).

frame_shape#

Input frame shape.

Type:

tuple

output_size#

Output feature dimension.

Type:

int

Returns:

features – Feature vectors with shape (B, output_size).

Parameters:
  • frame_shape (Tuple[int, int, int])

  • output_size (int)

forward(x)[source]#

Extract features from frames.

Parameters:

x (Tensor) – Frames (B, C, H, W)

Returns:

features – Feature vectors with shape (B, output_size).
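The flattened size that reaches the final linear layer depends on the kernel size and padding of the four stride-2 convolutions, and this reference does not state either. The arithmetic below uses the standard Conv2d output-size formula. The settings kernel=4, padding=0 and kernel=4, padding=1 are assumptions chosen only for illustration.

```python
def conv_out(size, kernel, stride=2, padding=0):
    """Spatial output size of a Conv2d (dilation 1): floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

def flattened_features(frame_shape=(3, 64, 64), channels=(32, 64, 128, 256),
                       kernel=4, padding=0):
    """Size of the flattened conv output that a final linear layer would project."""
    _, h, w = frame_shape
    for _ in channels:  # one stride-2 conv per output channel count
        h = conv_out(h, kernel, padding=padding)
        w = conv_out(w, kernel, padding=padding)
    return channels[-1] * h * w

print(flattened_features(kernel=4, padding=0))  # 64 -> 31 -> 14 -> 6 -> 2: 256*2*2 = 1024
print(flattened_features(kernel=4, padding=1))  # 64 -> 32 -> 16 -> 8 -> 4: 256*4*4 = 4096
```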

class torchwm.DreamerConfig[source]#

Bases: object

Configuration container for Dreamer training, evaluation, and environment setup.

This class centralizes environment backend selection (DMC/Gym/MuJoCo/Robotics/Unity/Brax), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.

class torchwm.JEPAConfig[source]#

Bases: object

Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.

to_dict()[source]#
Return type:

Dict[str, Dict[str, Any]]

class torchwm.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#

Bases: object

Default configuration values for Diffusion Transformer (DiT) training.

The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.

Parameters:
  • DATASET (str)

  • BATCH (int)

  • EPOCHS (int)

  • LR (float)

  • IMG_SIZE (int)

  • CHANNELS (int)

  • PATCH (int)

  • WIDTH (int)

  • DEPTH (int)

  • HEADS (int)

  • DROP (float)

  • BETA_START (float)

  • BETA_END (float)

  • TIMESTEPS (int)

  • EMA (bool)

  • EMA_DECAY (float)

  • WORKDIR (str)

  • ROOT_PATH (str)

DATASET: str = 'CIFAR10'#
BATCH: int = 128#
EPOCHS: int = 3#
LR: float = 0.0002#
IMG_SIZE: int = 32#
CHANNELS: int = 3#
PATCH: int = 4#
WIDTH: int = 384#
DEPTH: int = 6#
HEADS: int = 6#
DROP: float = 0.1#
BETA_START: float = 0.0001#
BETA_END: float = 0.02#
TIMESTEPS: int = 1000#
EMA: bool = True#
EMA_DECAY: float = 0.999#
WORKDIR: str = './dit_demo'#
ROOT_PATH: str = './data'#
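BETA_START, BETA_END and TIMESTEPS define a DDPM-style noise schedule. The sketch below assumes a linear beta schedule, because the reference does not say which schedule the training entrypoints use. It computes the cumulative product ᾱ_t = ∏(1 − β_s) and the closed-form forward noising step x_t = sqrt(ᾱ_t) · x_0 + sqrt(1 − ᾱ_t) · ε.

```python
import numpy as np

def linear_schedule(beta_start=1e-4, beta_end=0.02, timesteps=1000):
    """Linear betas and the cumulative signal fraction alpha_bar_t = prod(1 - beta_s)."""
    betas = np.linspace(beta_start, beta_end, timesteps)
    alpha_bars = np.cumprod(1.0 - betas)
    return betas, alpha_bars

def q_sample(x0, t, alpha_bars, rng):
    """Sample x_t ~ q(x_t | x_0) in closed form; also return the noise target."""
    noise = rng.standard_normal(x0.shape)
    xt = np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * noise
    return xt, noise

betas, alpha_bars = linear_schedule()
x0 = np.zeros((3, 32, 32))  # CHANNELS x IMG_SIZE x IMG_SIZE
xt, eps = q_sample(x0, t=500, alpha_bars=alpha_bars, rng=np.random.default_rng(0))
print(alpha_bars[0], alpha_bars[-1])  # 0.9999 at t=0, close to 0 at the last step
```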
torchwm.get_dit_config(**overrides)[source]#

Returns a DiTConfig instance with default values overridden by the provided keyword arguments.

Example usage:

cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)

class torchwm.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', seed: int = 0, obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, device: str = <factory>, log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, operator_state_dim: int = 32, operator_action_dim: int = 4)[source]#

Bases: object

Parameters:
  • preset (str | None)

  • game (str)

  • seed (int)

  • obs_size (int)

  • frameskip (int)

  • max_noop (int)

  • terminate_on_life_loss (bool)

  • reward_clip (List[int])

  • num_conditioning_frames (int)

  • diffusion_channels (List[int])

  • diffusion_res_blocks (int)

  • diffusion_cond_dim (int)

  • sigma_data (float)

  • sigma_min (float)

  • sigma_max (float)

  • rho (int)

  • p_mean (float)

  • p_std (float)

  • sampling_method (str)

  • num_sampling_steps (int)

  • reward_channels (List[int])

  • reward_res_blocks (int)

  • reward_cond_dim (int)

  • reward_lstm_dim (int)

  • burn_in_length (int)

  • actor_channels (List[int])

  • actor_res_blocks (int)

  • actor_lstm_dim (int)

  • num_epochs (int)

  • training_steps_per_epoch (int)

  • batch_size (int)

  • environment_steps_per_epoch (int)

  • epsilon_greedy (float)

  • imagination_horizon (int)

  • discount_factor (float)

  • entropy_weight (float)

  • lambda_returns (float)

  • learning_rate (float)

  • adam_epsilon (float)

  • weight_decay_diffusion (float)

  • weight_decay_reward (float)

  • weight_decay_actor (float)

  • device (str)

  • log_interval (int)

  • eval_interval (int)

  • save_interval (int)

  • operator_state_dim (int)

  • operator_action_dim (int)

preset: str | None = None#
game: str = 'Breakout-v5'#
seed: int = 0#
obs_size: int = 64#
frameskip: int = 4#
max_noop: int = 30#
terminate_on_life_loss: bool = True#
reward_clip: List[int]#
num_conditioning_frames: int = 4#
diffusion_channels: List[int]#
diffusion_res_blocks: int = 2#
diffusion_cond_dim: int = 256#
sigma_data: float = 0.5#
sigma_min: float = 0.002#
sigma_max: float = 80.0#
rho: int = 7#
p_mean: float = -0.4#
p_std: float = 1.2#
sampling_method: str = 'euler'#
num_sampling_steps: int = 3#
reward_channels: List[int]#
reward_res_blocks: int = 2#
reward_cond_dim: int = 128#
reward_lstm_dim: int = 512#
burn_in_length: int = 4#
actor_channels: List[int]#
actor_res_blocks: int = 1#
actor_lstm_dim: int = 512#
num_epochs: int = 1000#
training_steps_per_epoch: int = 400#
batch_size: int = 32#
environment_steps_per_epoch: int = 100#
epsilon_greedy: float = 0.01#
imagination_horizon: int = 15#
discount_factor: float = 0.985#
entropy_weight: float = 0.001#
lambda_returns: float = 0.95#
learning_rate: float = 0.0001#
adam_epsilon: float = 1e-08#
weight_decay_diffusion: float = 0.01#
weight_decay_reward: float = 0.01#
weight_decay_actor: float = 0.0#
device: str#
log_interval: int = 10#
eval_interval: int = 50#
save_interval: int = 100#
operator_state_dim: int = 32#
operator_action_dim: int = 4#
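sigma_min, sigma_max, rho, p_mean and p_std share their names with the parameters of the EDM diffusion formulation (Karras et al.), on which DIAMOND builds. The sketch below assumes these fields play their EDM roles. It shows the decreasing sampling noise levels and the log-normal distribution of training noise levels. It illustrates the formulas only and does not reproduce the code that consumes DiamondConfig.

```python
import numpy as np

def karras_sigmas(num_steps=3, sigma_min=0.002, sigma_max=80.0, rho=7):
    """Noise levels from sigma_max down to sigma_min, evenly spaced in sigma**(1/rho)."""
    ramp = np.linspace(0.0, 1.0, num_steps)
    lo, hi = sigma_min ** (1.0 / rho), sigma_max ** (1.0 / rho)
    return (hi + ramp * (lo - hi)) ** rho

def sample_training_sigma(n, p_mean=-0.4, p_std=1.2, rng=None):
    """Training noise levels with log(sigma) ~ Normal(p_mean, p_std)."""
    if rng is None:
        rng = np.random.default_rng()
    return np.exp(p_mean + p_std * rng.standard_normal(n))

sigmas = karras_sigmas()
print(sigmas)  # approx [80.0, 2.52, 0.002]
```

A larger rho concentrates the steps near sigma_min, where fine detail is resolved.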
class torchwm.IRISConfig[source]#

Bases: object

Configuration for IRIS (Imagination with auto-Regression over an Inner Speech).

Based on the paper “Transformers are Sample-Efficient World Models”. Implements a discrete autoencoder and an autoregressive Transformer for sample-efficient RL.

get_frame_shape()[source]#
get_autoencoder_config()[source]#
get_transformer_config()[source]#
get_rl_config()[source]#
class torchwm.GenieConfig(num_frames=8, image_size=32, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=256, action_encoder_depth=4, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#

Bases: object

Configuration for Genie model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • tokenizer_num_heads (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_num_heads (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

  • sample_temperature (float)

  • maskgit_steps (int)

num_frames: int = 8#
image_size: int = 32#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 256#
tokenizer_decoder_dim: int = 512#
tokenizer_encoder_depth: int = 4#
tokenizer_decoder_depth: int = 8#
tokenizer_num_heads: int = 8#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 256#
action_encoder_depth: int = 4#
action_num_heads: int = 8#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 4#
learning_rate: float = 3e-05#
weight_decay: float = 0.0001#
warmup_steps: int = 5000#
max_steps: int = 125000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
sample_temperature: float = 2.0#
maskgit_steps: int = 25#
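mask_prob_min and mask_prob_max bound the token masking rate used when training the MaskGIT-style dynamics model. The Genie paper samples this rate uniformly between 0.5 and 1. The NumPy sketch below draws one rate per sequence and applies Bernoulli masks over a (B, T, N) grid of token indices. The per-sequence granularity, the 64 tokens per frame and the mask token id are hypothetical choices.

```python
import numpy as np

def random_token_mask(batch, frames, tokens_per_frame, p_min=0.5, p_max=1.0, rng=None):
    """Boolean mask (True = masked) with one masking rate drawn per sequence."""
    if rng is None:
        rng = np.random.default_rng()
    rates = rng.uniform(p_min, p_max, size=(batch, 1, 1))
    return rng.random((batch, frames, tokens_per_frame)) < rates

rng = np.random.default_rng(0)
tokens = rng.integers(0, 1024, size=(4, 8, 64))  # tokenizer_vocab_size = 1024
mask = random_token_mask(4, 8, 64, rng=rng)
MASK_ID = 1024  # hypothetical id one past the vocabulary
inputs = np.where(mask, MASK_ID, tokens)
print(mask.mean())  # fraction of masked tokens; 0.75 in expectation
```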
class torchwm.GenieSmallConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=512, action_encoder_depth=8, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=2, learning_rate=0.0001, weight_decay=0.0001, warmup_steps=1000, max_steps=50000, mask_prob_min=0.5, mask_prob_max=1.0)[source]#

Bases: object

Small configuration for development/testing.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • tokenizer_num_heads (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_num_heads (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 256#
tokenizer_decoder_dim: int = 512#
tokenizer_encoder_depth: int = 4#
tokenizer_decoder_depth: int = 8#
tokenizer_num_heads: int = 8#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 512#
action_encoder_depth: int = 8#
action_num_heads: int = 8#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 2#
learning_rate: float = 0.0001#
weight_decay: float = 0.0001#
warmup_steps: int = 1000#
max_steps: int = 50000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
class torchwm.STTransformerConfig(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: object

Configuration for Spatiotemporal Transformer.

Parameters:
  • num_frames (int)

  • num_patches_per_frame (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

num_frames: int = 16#
num_patches_per_frame: int = 256#
dim: int = 768#
depth: int = 12#
num_heads: int = 12#
mlp_ratio: float = 4.0#
qkv_bias: bool = True#
drop_rate: float = 0.0#
attn_drop_rate: float = 0.0#
drop_path_rate: float = 0.0#
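With the defaults, a clip contains num_frames × num_patches_per_frame = 16 × 256 = 4096 tokens. Assume the factorized design of Genie's ST-transformer: each block attends within a frame (spatial), then across time at each patch position (temporal). Under that assumption, the number of query-key pairs per head compares as follows. Causal masking of the temporal layer is ignored.

```python
def attention_pairs(num_frames=16, num_patches_per_frame=256):
    """Query-key pairs per head for full vs factorized spatiotemporal attention."""
    t, p = num_frames, num_patches_per_frame
    full = (t * p) ** 2     # every token attends to every token
    spatial = t * p ** 2    # within each frame
    temporal = p * t ** 2   # across frames, per patch position
    return full, spatial + temporal

full, factorized = attention_pairs()
print(full, factorized)             # 16777216 1114112
print(round(full / factorized, 1))  # 15.1
```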
class torchwm.VideoTokenizerConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False, ema_decay=0.99, commitment_weight=0.25)[source]#

Bases: object

Configuration for Video Tokenizer.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • use_ema (bool)

  • ema_decay (float)

  • commitment_weight (float)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
encoder_dim: int = 512#
decoder_dim: int = 1024#
encoder_depth: int = 12#
decoder_depth: int = 20#
num_heads: int = 16#
patch_size: int = 4#
vocab_size: int = 1024#
embedding_dim: int = 32#
use_ema: bool = False#
ema_decay: float = 0.99#
commitment_weight: float = 0.25#
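vocab_size, embedding_dim and commitment_weight are the usual settings of a VQ-VAE quantizer. The NumPy sketch below performs nearest-codebook quantization and computes the standard VQ-VAE loss terms. It assumes the codebook is trained by the loss rather than by EMA updates (use_ema=False). In an autodiff framework, stop-gradients decide which side each term updates. Plain NumPy cannot express them, so the comments note where they belong.

```python
import numpy as np

def vector_quantize(z, codebook, commitment_weight=0.25):
    """Map each row of z (N, D) to its nearest row of codebook (K, D)."""
    # Squared Euclidean distance between every latent and every code: (N, K)
    d = (z ** 2).sum(1, keepdims=True) - 2 * z @ codebook.T + (codebook ** 2).sum(1)
    indices = d.argmin(axis=1)
    z_q = codebook[indices]
    codebook_loss = np.mean((z_q - z) ** 2)    # ||sg[z] - e||^2: updates the codebook
    commitment_loss = np.mean((z - z_q) ** 2)  # ||z - sg[e]||^2: updates the encoder
    loss = codebook_loss + commitment_weight * commitment_loss
    return z_q, indices, loss

rng = np.random.default_rng(0)
codebook = rng.standard_normal((1024, 32))  # vocab_size x embedding_dim
z = codebook[[3, 7]] + 1e-3 * rng.standard_normal((2, 32))
_, idx, _ = vector_quantize(z, codebook)
print(idx)  # [3 7]
```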
class torchwm.LatentActionModelConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=1024, encoder_depth=20, num_heads=16, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#

Bases: object

Configuration for Latent Action Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • encoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
encoder_dim: int = 1024#
encoder_depth: int = 20#
num_heads: int = 16#
patch_size: int = 16#
vocab_size: int = 8#
embedding_dim: int = 32#
commitment_weight: float = 1.0#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
class torchwm.DynamicsModelConfig(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: object

Configuration for Dynamics Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

num_frames: int = 16#
image_size: int = 64#
vocab_size: int = 1024#
embedding_dim: int = 32#
action_vocab_size: int = 8#
dim: int = 5120#
depth: int = 48#
num_heads: int = 36#
patch_size: int = 4#
mlp_ratio: float = 4.0#
qkv_bias: bool = True#
drop_rate: float = 0.0#
attn_drop_rate: float = 0.0#
drop_path_rate: float = 0.0#
class torchwm.OperatorABC[source]#

Bases: ABC

Abstract base class for operators that preprocess inputs for inference pipelines.

abstractmethod process(inputs)[source]#

Process raw inputs into standardized tensor format for model consumption.

Parameters:

inputs (Any) – Raw input data (dict, tensor, or other formats)

Returns:

Dict of processed tensors ready for model input

Return type:

Dict[str, Tensor]

class torchwm.DreamerOperator(image_size=64, action_dim=6)[source]#

Bases: OperatorABC

Operator for Dreamer model preprocessing: normalizes observations and encodes actions.

Parameters:
  • image_size (int)

  • action_dim (int)

process(inputs)[source]#

Process Dreamer inputs: image observation and action.

Expected inputs: {'image': PIL.Image or tensor, 'action': tensor or list}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

class torchwm.JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)[source]#

Bases: OperatorABC

Operator for JEPA model preprocessing: handles image/video masking and patch processing.

Parameters:
  • image_size (int)

  • patch_size (int)

  • mask_ratio (float)

process(inputs)[source]#

Process JEPA inputs: images with masking.

Expected inputs: {'images': list of PIL Images or tensors, 'mask': optional tensor}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]
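With the JEPAOperator defaults, a 224 × 224 image splits into (224 / 16)² = 196 patches. mask_ratio=0.75 hides 147 of them and leaves 49 visible. The sketch below assumes uniform random patch masking for illustration. The reference does not say which masking strategy process() applies; I-JEPA, for example, uses block masks.

```python
import numpy as np

def random_patch_mask(image_size=224, patch_size=16, mask_ratio=0.75, rng=None):
    """Return (masked_idx, visible_idx) over the flattened patch grid."""
    if rng is None:
        rng = np.random.default_rng()
    num_patches = (image_size // patch_size) ** 2
    num_masked = int(num_patches * mask_ratio)
    perm = rng.permutation(num_patches)
    return np.sort(perm[:num_masked]), np.sort(perm[num_masked:])

masked, visible = random_patch_mask(rng=np.random.default_rng(0))
print(len(masked), len(visible))  # 147 49
```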

class torchwm.IrisOperator(seq_length=512, vocab_size=32000)[source]#

Bases: OperatorABC

Operator for Iris transformer model: formats sequences and embeddings.

Parameters:
  • seq_length (int)

  • vocab_size (int)

process(inputs)[source]#

Process Iris inputs: token sequences and optional embeddings.

Expected inputs: {'tokens': list of ints or tensor, 'embeddings': optional tensor}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

class torchwm.PlaNetOperator(state_dim=32, action_dim=4)[source]#

Bases: OperatorABC

Operator for PlaNet model preprocessing: encodes environment states and transitions.

Parameters:
  • state_dim (int)

  • action_dim (int)

process(inputs)[source]#

Process PlaNet inputs: state observations and actions.

Expected inputs: {'obs': tensor or image, 'action': tensor, 'reward': float, 'done': bool}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

torchwm.get_operator(name, **kwargs)[source]#

Factory function to get inference operators by name.

Parameters:
  • name (str) – One of ‘dreamer’, ‘jepa’, ‘iris’, ‘planet’

  • **kwargs – Operator-specific configuration

Returns:

Configured OperatorABC instance

Example

>>> op = get_operator('dreamer', image_size=64, action_dim=6)
>>> processed = op.process({'image': image, 'action': action})
class torchwm.RewardModel(belief_size, state_size, hidden_size, activation_function='relu')[source]#

Bases: Module

Predict scalar rewards from Dreamer latent belief and state vectors.

Implemented as an MLP used for model-based reward supervision and imagined rollout return estimation.

forward(belief, state)[source]#
class torchwm.ValueModel(belief_size, state_size, hidden_size, activation_function='relu')[source]#

Bases: Module

Estimate scalar value from Dreamer latent belief and state vectors.

This MLP is trained on imagined returns and used for actor/value updates.

forward(belief, state)[source]#
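RewardModel and ValueModel both map the concatenated (belief, state) features to one scalar per batch element. The NumPy sketch below shows such a head. The two hidden layers, the random weights and the sizes belief_size=200, state_size=30 are hypothetical and do not describe the modules' actual layer layout. ReLU matches the documented default activation_function.

```python
import numpy as np

def init_mlp(in_size, hidden_size, rng, num_hidden=2):
    """Random (weight, bias) pairs for num_hidden layers plus a scalar output layer."""
    sizes = [in_size] + [hidden_size] * num_hidden + [1]
    return [(rng.standard_normal((a, b)) / np.sqrt(a), np.zeros(b))
            for a, b in zip(sizes[:-1], sizes[1:])]

def scalar_head(params, belief, state):
    """Predict one scalar (reward or value) per row from concat(belief, state)."""
    x = np.concatenate([belief, state], axis=-1)
    for w, b in params[:-1]:
        x = np.maximum(x @ w + b, 0.0)  # ReLU hidden layers
    w, b = params[-1]
    return (x @ w + b).squeeze(-1)      # (B, 1) -> (B,)

rng = np.random.default_rng(0)
belief_size, state_size, hidden_size = 200, 30, 200  # hypothetical sizes
params = init_mlp(belief_size + state_size, hidden_size, rng)
out = scalar_head(params, rng.standard_normal((5, belief_size)),
                  rng.standard_normal((5, state_size)))
print(out.shape)  # (5,)
```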

TorchWM public API.

This package keeps imports lightweight while still exposing a friendly top-level surface. Common workflows can use the small factory helpers:

import torchwm

cfg = torchwm.create_config("dreamer", env="walker-walk")
agent = torchwm.create_model("dreamer", cfg)
env = torchwm.make_env("CartPole-v1", backend="gym")

Lower-level research components remain available as lazy top-level exports, for example from torchwm import DreamerAgent, ConvEncoder, ReplayBuffer.
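Lazy top-level exports like these are typically implemented with a module-level __getattr__ (PEP 562). This function imports the defining submodule only when the name is first accessed. In a real package it is written as def __getattr__(name) in __init__.py. The self-contained sketch below builds such a module at runtime instead. Its export table maps names to standard-library modules purely for demonstration and is not TorchWM's actual table.

```python
import importlib
import types

def make_lazy_module(name, exports):
    """Build a module whose attributes are imported on first access (PEP 562)."""
    module = types.ModuleType(name)

    def __getattr__(attr):
        # Called only when normal attribute lookup on the module fails
        if attr not in exports:
            raise AttributeError(f"module {name!r} has no attribute {attr!r}")
        value = getattr(importlib.import_module(exports[attr]), attr)
        setattr(module, attr, value)  # cache so later lookups skip __getattr__
        return value

    module.__getattr__ = __getattr__
    module.__all__ = sorted(exports)
    return module

# Hypothetical export table: attribute name -> module that defines it
lazy = make_lazy_module("demo", {"OrderedDict": "collections", "sqrt": "math"})
print(lazy.sqrt(16.0))  # 4.0
```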

class world_models.EnvBackendSpec(name, factory_path, description='', aliases=())[source]#

Bases: NamedTuple

Metadata describing an environment backend available through make_env.

Parameters:
  • name (str)

  • factory_path (str)

  • description (str)

  • aliases (tuple[str, ...])

name: str#

Alias for field number 0

factory_path: str#

Alias for field number 1

description: str#

Alias for field number 2

aliases: tuple[str, ...]#

Alias for field number 3

class world_models.ModelSpec(name, import_path, config_path=None, description='', aliases=())[source]#

Bases: NamedTuple

Metadata describing a model available through create_model().

Parameters:
  • name (str)

  • import_path (str)

  • config_path (str | None)

  • description (str)

  • aliases (tuple[str, ...])

name: str#

Alias for field number 0

import_path: str#

Alias for field number 1

config_path: str | None#

Alias for field number 2

description: str#

Alias for field number 3

aliases: tuple[str, ...]#

Alias for field number 4

world_models.create_config(model, **overrides)[source]#

Create the default config object for model and apply overrides.

Examples

>>> cfg = create_config("dreamer", env="walker-walk", seed=7)
>>> cfg.env
'walker-walk'
Parameters:
  • model (str)

  • overrides (Any)

Return type:

Any

world_models.create_model(model, config=None, **overrides)[source]#

Instantiate a model or agent from a simple string name.

config is optional for models that define a config class. Keyword overrides are applied to the config when possible, otherwise they are passed directly to the underlying constructor/factory.

Examples

>>> agent = create_model("dreamer", env="walker-walk", total_steps=1000)
>>> genie = create_model("genie-small", image_size=32)
Parameters:
  • model (str)

  • config (Any | None)

  • overrides (Any)

Return type:

Any

world_models.get_env_backend_spec(name)[source]#

Return metadata for an environment backend name or alias.

Parameters:

name (str)

Return type:

EnvBackendSpec

world_models.get_model_spec(name)[source]#

Return metadata for a model name or alias.

Parameters:

name (str)

Return type:

ModelSpec
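get_model_spec accepts either a canonical name or an alias. The sketch below shows one way such a lookup can work over ModelSpec-style records. The MiniSpec type, the registry entries, the "dreamerv1" alias and the import paths are hypothetical; they only mirror a subset of the documented fields.

```python
from typing import NamedTuple

class MiniSpec(NamedTuple):
    """Hypothetical stand-in mirroring a subset of ModelSpec's fields."""
    name: str
    import_path: str
    aliases: tuple = ()

REGISTRY = [
    MiniSpec("dreamer", "example.dreamer:DreamerAgent", aliases=("dreamerv1",)),
    MiniSpec("genie-small", "example.genie:GenieSmall", aliases=("genie_small",)),
]

def lookup(name):
    """Resolve a canonical name or alias, case-insensitively."""
    key = name.strip().lower()
    for spec in REGISTRY:
        if key == spec.name or key in spec.aliases:
            return spec
    known = ", ".join(spec.name for spec in REGISTRY)
    raise KeyError(f"unknown model {name!r}; expected one of: {known}")

print(lookup("Genie_Small").name)  # genie-small
```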

world_models.list_env_backends()[source]#

Return canonical backend names accepted by make_env().

Return type:

list[str]

world_models.list_envs(model=None)[source]#

List known environment ids, optionally filtered by model family.

Parameters:

model (str | None)

Return type:

list[str] | dict[str, list[str]]

world_models.list_models()[source]#

Return canonical model names accepted by create_model().

Return type:

list[str]

world_models.make_env(env_id, backend='auto', **kwargs)[source]#

Create an environment with a consistent TorchWM entry point.

Parameters:
  • env_id (str) – Environment id, XML path, Unity executable path, or backend-specific id.

  • backend (str) – One of list_env_backends(); "auto" tries TorchWM’s compatibility helper.

  • **kwargs (Any) – Backend-specific options.

Return type:

Any

class world_models.IRISAgent(config, action_size, device)[source]#

Bases: Module

Complete IRIS Agent with world model and policy.

Combines:
  • Discrete autoencoder (encoder + decoder)

  • Transformer world model

  • Actor-critic for policy and value learning

Parameters:
  • config (IRISConfig)

  • action_size (int)

  • device (device)

forward_actor_critic(frames, hidden=None)[source]#

Forward pass through actor-critic.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W)

  • hidden (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state

Returns:

A tuple (action_logits, values, hidden_state): action logits (B, T, action_size), values (B, T), and the updated (h, c) LSTM state.

Return type:

tuple

act(frame, epsilon=0.0, temperature=1.0)[source]#

Sample action from policy.

Parameters:
  • frame (Tensor) – Single frame (B, C, H, W)

  • epsilon (float) – Random action probability

  • temperature (float) – Action distribution temperature

Returns:

actions – Selected actions with shape (B,).

imagine_rollout(initial_frame, horizon=20)[source]#

Generate imagined trajectories using world model.

Parameters:
  • initial_frame (Tensor) – Starting frame (B, C, H, W)

  • horizon (int) – Number of steps to imagine

Returns:

trajectory – Dictionary with imagined rollout data.

Return type:

dict

update_autoencoder(frames)[source]#

Update discrete autoencoder.

Parameters:

frames (Tensor) – Training frames (B, C, H, W)

Returns:

losses – Dictionary of loss values.

Return type:

dict

update_transformer(frames, actions, rewards, terminals)[source]#

Update transformer world model.

Parameters:
  • frames (Tensor) – Frame sequence

  • actions (Tensor) – Actions taken

  • rewards (Tensor) – Rewards received

  • terminals (Tensor) – Terminal flags

Returns:

losses – Dictionary of loss values.

Return type:

dict

update_actor_critic(imagined_trajectory)[source]#

Update actor-critic in imagination.

Parameters:

imagined_trajectory (dict) – Dictionary from imagine_rollout

Returns:

losses – Dictionary of loss values.

Return type:

dict

save(path)[source]#

Save agent state.

Parameters:

path (str)

load(path)[source]#

Load agent state.

Parameters:

path (str)

world_models.compute_lambda_return(rewards, values, discounts, lambda_coef=0.95)[source]#

Compute λ-return target for value function training.

Parameters:
  • rewards (Tensor) – Rewards (B, T)

  • values (Tensor) – Value estimates (B, T+1)

  • discounts (Tensor) – Discount factors (B, T)

  • lambda_coef (float) – Lambda parameter for bootstrapping

Returns:

lambda_returns – λ-return targets with shape (B, T).
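The λ-return blends one-step bootstrapped targets with longer multi-step returns. The sketch below assumes the common recursive definition used in Dreamer, bootstrapping from the extra value column (G_T = V_T):

G_t = r_t + γ_t · ((1 − λ) · V_{t+1} + λ · G_{t+1})

It uses the documented shapes. Whether compute_lambda_return follows exactly this convention is an assumption.

```python
import numpy as np

def lambda_return(rewards, values, discounts, lambda_coef=0.95):
    """rewards, discounts: (B, T); values: (B, T+1). Returns targets of shape (B, T)."""
    T = rewards.shape[1]
    returns = np.zeros_like(rewards, dtype=float)
    next_return = values[:, T]  # bootstrap from the final value estimate
    for t in reversed(range(T)):
        blended = (1.0 - lambda_coef) * values[:, t + 1] + lambda_coef * next_return
        next_return = rewards[:, t] + discounts[:, t] * blended
        returns[:, t] = next_return
    return returns

r = np.array([[1.0, 1.0]])
v = np.array([[0.0, 0.0, 10.0]])
g = np.array([[0.5, 0.5]])
print(lambda_return(r, v, g, lambda_coef=1.0))  # [[4. 6.]]  discounted return + bootstrap
print(lambda_return(r, v, g, lambda_coef=0.0))  # [[1. 6.]]  one-step TD targets
```

λ = 0 trusts the value estimates completely. λ = 1 uses the full discounted reward sum and relies only on the final value.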

class world_models.ModularRSSM(encoder, decoder, backbone, reward_decoder=None)[source]#

Bases: Module

Modular RSSM with swappable encoder, decoder, and backbone.

This class lets researchers experiment with different components:
  • Encoders: Conv, MLP, ViT

  • Decoders: Conv, MLP

  • Backbones: GRU, LSTM, Transformer

Example

>>> encoder = ConvEncoder((3, 64, 64), embed_size=1024)
>>> decoder = ConvDecoder(32, 200, (3, 64, 64))
>>> backbone = GRUBackbone(action_size=6, stoch_size=32, deter_size=200, hidden_size=200, embed_size=1024)
>>> rssm = ModularRSSM(encoder, decoder, backbone)
Parameters:
  • encoder

  • decoder

  • backbone

  • reward_decoder (optional)

property stoch_size: int#
property deter_size: int#
property embed_size: int#
init_state(batch_size, device)[source]#
Parameters:
  • batch_size (int)

  • device (device)

Return type:

Dict[str, Tensor]

get_dist(mean, std)[source]#
Parameters:
  • mean (Tensor)

  • std (Tensor)

Return type:

Distribution

observe_step(prev_state, prev_action, obs, nonterm=1.0)[source]#
Parameters:
  • prev_state (Dict[str, Tensor])

  • prev_action (Tensor)

  • obs (Tensor)

  • nonterm (Any)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

imagine_step(prev_state, prev_action, nonterm=1.0)[source]#
Parameters:
  • prev_state (Dict[str, Tensor])

  • prev_action (Tensor)

  • nonterm (Any)

Return type:

Dict[str, Tensor]

observe_rollout(obs, actions, nonterms, prev_state, horizon)[source]#
Parameters:
  • obs (Tensor)

  • actions (Tensor)

  • nonterms (Tensor)

  • prev_state (Dict[str, Tensor])

  • horizon (int)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

imagine_rollout(actor, prev_state, horizon)[source]#
Parameters:
  • actor (Module)

  • prev_state (Dict[str, Tensor])

  • horizon (int)

Return type:

Dict[str, Tensor]

decode_observation(features)[source]#
Parameters:

features (Tensor)

decode_reward(features)[source]#
Parameters:

features (Tensor)

detach_state(state)[source]#
Parameters:

state (Dict[str, Tensor])

Return type:

Dict[str, Tensor]

seq_to_batch(state)[source]#
Parameters:

state (Dict[str, Tensor])

Return type:

Dict[str, Tensor]

world_models.create_modular_rssm(encoder_type='conv', decoder_type='conv', backbone_type='gru', obs_shape=(3, 64, 64), action_size=6, stoch_size=32, deter_size=200, embed_size=1024, hidden_size=200, activation='elu', **kwargs)[source]#

Factory function to create a modular RSSM with specified components.

Parameters:
  • encoder_type (str) – Type of encoder (“conv”, “mlp”, “vit”)

  • decoder_type (str) – Type of decoder (“conv”, “mlp”)

  • backbone_type (str) – Type of backbone (“gru”, “lstm”, “transformer”)

  • obs_shape (Tuple[int, int, int] | Tuple[int]) – Shape of observations (C, H, W) for images or (D,) for state

  • action_size (int) – Action space dimension

  • stoch_size (int) – Stochastic latent dimension

  • deter_size (int) – Deterministic hidden dimension

  • embed_size (int) – Encoder embedding dimension

  • hidden_size (int) – Hidden layer dimension

  • activation (str) – Activation function name

Returns:

Configured ModularRSSM instance

Return type:

ModularRSSM

class world_models.Genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=512, tokenizer_decoder_dim=1024, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=1024, action_decoder_dim=1024, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, encoder_depth=12, decoder_depth=20, latent_action_depth=20, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#

Bases: Module

Genie: Generative Interactive Environment.

A generative model trained from video-only data that can be used as an interactive environment. It contains three key components:

  1. Video Tokenizer: converts raw video frames into discrete tokens

  2. Latent Action Model (LAM): infers latent actions between frames

  3. Dynamics Model: predicts future frames given past frames and latent actions

Based on the paper “Genie: Generative Interactive Environments” (arXiv:2402.15391).

Training follows the two phases described in the paper:

  1. Train the video tokenizer first (on raw video frames)

  2. Co-train the LAM (from pixels) and the dynamics model (on video tokens)

The LAM uses VQ-VAE training with:
  • Encoder: takes x_{1:t} and x_{t+1} and outputs latent actions

  • Decoder: takes x_{1:t-1} (masked) plus the latent actions and reconstructs x_t

  • An auxiliary variance loss to prevent action collapse

Latent actions pass through a stop-gradient before they reach the dynamics model, so the dynamics loss does not update the LAM.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_decoder_dim (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • latent_action_depth (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

forward(video, mask_prob=0.5, training_phase='all')[source]#

Full forward pass through all components.

Parameters:
  • video (Tensor) – (B, C, T, H, W) input video

  • mask_prob (float) – Probability for random masking in dynamics (0.5-1.0)

  • training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”

Returns:

Dictionary containing losses and predictions

Return type:

Dict[str, Tensor]

training_step(video, mask_prob=0.5, training_phase='all')[source]#

Single training step computing all losses.

Parameters:
  • video (Tensor) – (B, C, T, H, W) input video

  • mask_prob (float) – Probability for random masking in dynamics

  • training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”

Returns:

Dictionary containing all losses for backpropagation

Return type:

Dict[str, Tensor]
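The two-phase recipe can be driven by switching between the documented training_phase values "tokenizer" and "lam_dynamics". The sketch below assumes a simple step-count cutoff. `genie_training_phase` and `tokenizer_steps` are hypothetical names, and the real training scripts may schedule the phases differently (or use "all").

```python
def genie_training_phase(step: int, tokenizer_steps: int) -> str:
    """Hypothetical schedule: tokenizer first, then LAM + dynamics co-training."""
    if step < 0 or tokenizer_steps < 0:
        raise ValueError("step and tokenizer_steps must be non-negative")
    return "tokenizer" if step < tokenizer_steps else "lam_dynamics"


# How the phase could feed Genie.training_step (model, loader and optimizer are
# assumed to exist; how the returned losses are combined is also an assumption):
# for step, video in enumerate(loader):
#     losses = model.training_step(video, training_phase=genie_training_phase(step, 10_000))
print([genie_training_phase(s, 2) for s in range(4)])
# ['tokenizer', 'tokenizer', 'lam_dynamics', 'lam_dynamics']
```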

encode_video(video)[source]#

Encode video to discrete tokens.

Parameters:

video (Tensor) – (B, C, T, H, W)

Returns:

video_tokens – Token indices (B, T, H*W)

infer_actions(frames)[source]#

Infer latent actions from a sequence of frames.

Parameters:

frames (Tensor) – (B, C, T, H, W) video frames

Returns:

latent_actions – (B, T-1) inferred latent action indices

generate(prompt_frame, num_frames=16, actions=None, use_maskgit=True)[source]#

Generate video frames given a prompt frame and actions.

Parameters:
  • prompt_frame (Tensor) – (B, C, H, W) initial frame

  • num_frames (int) – Total number of frames to generate

  • actions (Tensor | None) – (B, num_frames-1) latent action indices, or None for random

  • use_maskgit (bool) – Whether to use MaskGIT sampling

Returns:

generated_video – (B, C, num_frames, H, W)

play(current_frame, action, current_frames=None)[source]#

Play step - generate next frame given current frame and action.

Parameters:
  • current_frame (Tensor) – (B, C, H, W) current frame

  • action (Tensor) – (B,) latent action indices

  • current_frames (Tensor | None) – (B, C, T, H, W) history frames, or None for first frame

Returns:

next_frame – (B, C, H, W)
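play() takes the frame history as current_frames (B, C, T, H, W), or None on the first step. One way to maintain that history is a bounded buffer that drops the oldest frame once it is full. The sketch below uses integers as stand-ins for frames, and FrameHistory is hypothetical. Stacking the buffered (B, C, H, W) tensors along the time axis, for example with torch.stack(frames, dim=2), is an assumption about how the result would be passed back to play().

```python
from collections import deque
from typing import Any, Deque, List


class FrameHistory:
    """Hypothetical helper: keeps the most recent `max_frames` frames."""

    def __init__(self, max_frames: int) -> None:
        if max_frames < 1:
            raise ValueError("max_frames must be >= 1")
        self._frames: Deque[Any] = deque(maxlen=max_frames)

    def push(self, frame: Any) -> None:
        # Once the buffer is full, the oldest frame is dropped automatically.
        self._frames.append(frame)

    def frames(self) -> List[Any]:
        # Oldest-to-newest order, i.e. the order of the T axis.
        return list(self._frames)


history = FrameHistory(max_frames=3)
for frame_id in range(5):
    history.push(frame_id)
print(history.frames())  # [2, 3, 4]
```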

get_num_parameters()[source]#

Return total number of parameters.

Return type:

int

class world_models.LatentActionModel(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#

Bases: Module

Latent Action Model (LAM) for unsupervised action learning.

Learns discrete latent actions from unlabeled video frames using a VQ-VAE-based objective. The inferred latent actions encode the changes between frames that matter most for predicting future frames.

Based on the Genie paper: learns actions from Internet videos without action labels.

Components:
  • Encoder: takes all previous frames x_{1:t} and the next frame x_{t+1} and outputs latent actions

  • Decoder: takes previous frames x_{1:t-1} and latent actions a_{1:t-1} and predicts the next frame x_t

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)
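The VQ-VAE step in the LAM maps each continuous encoder output to its nearest codebook entry. That entry's index is the discrete latent action, so there are vocab_size possible actions, each embedded in embedding_dim dimensions. The pure-Python sketch below shows the nearest-neighbour lookup and a commitment term scaled the way commitment_weight would scale it. It is illustrative only. The actual module works on batched tensors and may add a straight-through estimator and the auxiliary variance loss described for Genie.

```python
from typing import List, Sequence, Tuple


def quantize(vector: Sequence[float],
             codebook: Sequence[Sequence[float]]) -> Tuple[int, List[float]]:
    """Return (index, code) of the codebook entry nearest to `vector` (squared L2)."""
    if not codebook:
        raise ValueError("codebook must not be empty")
    best_index, best_dist = -1, float("inf")
    for index, code in enumerate(codebook):
        dist = sum((v - c) ** 2 for v, c in zip(vector, code))
        if dist < best_dist:
            best_index, best_dist = index, dist
    return best_index, list(codebook[best_index])


def commitment_loss(vector: Sequence[float], code: Sequence[float], weight: float) -> float:
    """weight * ||z_e - e||^2; an autograd version would stop gradients through `code`."""
    return weight * sum((v - c) ** 2 for v, c in zip(vector, code))


# Toy codebook: vocab_size=3 entries, embedding_dim=2.
codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]
index, code = quantize([0.9, 1.2], codebook)
print(index, code)  # 1 [1.0, 1.0]
print(commitment_loss([0.9, 1.2], code, weight=1.0))
```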

encode(x_prev, x_next)[source]#

Encode frames to latent actions.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W)

  • x_next (Tensor) – Next frame (B, C, H, W)

Returns:
  • latent_actions – Discrete latent action indices (B, T)

  • z_q – Quantized embeddings (B, T, embedding_dim)

decode(x_prev, z_q)[source]#

Decode latent actions to predict next frame.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W) - will mask all but first

  • z_q (Tensor) – Quantized action embeddings (B, T-1, embedding_dim)

Returns:

predicted_next_frame – (B, C, H, W)

forward(x_prev, x_next)[source]#

Full forward pass: encode to get actions, decode to reconstruct.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W)

  • x_next (Tensor) – Next frame (B, C, H, W)

Returns:

Dictionary with losses and outputs

Return type:

Dict[str, Tensor]

class world_models.DynamicsModel(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: Module

Dynamics Model for action-controllable video generation.

A decoder-only transformer that predicts future frame tokens given past frame tokens and latent actions. Uses MaskGIT for training and sampling.

Based on the Genie paper: uses a cross-entropy loss with random masking during training, and MaskGIT iterative refinement at inference.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

forward(video_tokens, actions, mask_prob=0.0)[source]#

Forward pass for training.

Parameters:
  • video_tokens (Tensor) – (B, T, H*W) - token indices for frames 1 to T

  • actions (Tensor) – (B, T) - latent action indices for frames 1 to T

  • mask_prob (float) – Per-token Bernoulli masking probability for input tokens (0.5-1.0 during training)

Returns:

logits – (B, T, H*W, vocab_size)
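The masking step can be sketched in pure Python. A masking rate is drawn from the documented 0.5-1.0 range, and then each token is independently replaced by a mask id with that probability. MASK_ID, the helper names, and the use of a dedicated mask token one past the vocabulary are assumptions for illustration. The real model may represent masked positions differently.

```python
import random
from typing import List


def sample_mask_rate(rng: random.Random) -> float:
    """Draw a per-batch masking rate uniformly from [0.5, 1.0]."""
    return rng.uniform(0.5, 1.0)


def mask_tokens(tokens: List[int], mask_prob: float, mask_id: int,
                rng: random.Random) -> List[int]:
    """Replace each token with `mask_id` independently with probability `mask_prob`."""
    if not 0.0 <= mask_prob <= 1.0:
        raise ValueError("mask_prob must be in [0, 1]")
    return [mask_id if rng.random() < mask_prob else tok for tok in tokens]


rng = random.Random(0)
tokens = [3, 7, 1, 9]
MASK_ID = 1024  # hypothetical id one past vocab_size=1024
print(mask_tokens(tokens, 0.0, MASK_ID, rng))  # [3, 7, 1, 9]
print(mask_tokens(tokens, 1.0, MASK_ID, rng))  # [1024, 1024, 1024, 1024]
print(mask_tokens(tokens, sample_mask_rate(rng), MASK_ID, rng))
```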

sample(prompt_tokens, prompt_actions, num_frames, sampler=None)[source]#

Sample future frames using MaskGIT.

Parameters:
  • prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens

  • prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames

  • num_frames (int) – Total number of frames to generate

  • sampler (MaskGITSampler | None) – MaskGIT sampler instance

Returns:

generated_tokens – (B, num_frames, N)
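MaskGIT predicts all tokens of a frame in parallel. At each iteration it keeps the most confident predictions and re-masks the rest. The sketch below uses the cosine schedule from the MaskGIT paper to compute how many of a frame's tokens remain masked after each iteration. It assumes 256 tokens per frame (64x64 frames, patch_size=4) and 4 iterations. Whether MaskGITSampler uses this exact schedule or step count is an assumption.

```python
import math
from typing import List


def maskgit_masked_counts(num_tokens: int, num_steps: int) -> List[int]:
    """Tokens still masked after each MaskGIT iteration under a cosine schedule."""
    counts = []
    for step in range(1, num_steps + 1):
        ratio = math.cos(math.pi / 2 * step / num_steps)  # falls from ~1 to 0
        counts.append(math.floor(num_tokens * ratio))
    return counts


print(maskgit_masked_counts(256, 4))  # [236, 181, 97, 0]
```

The schedule unmasks few tokens early, while predictions are uncertain, and many tokens late. The final iteration always leaves nothing masked.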

autoregressive_sample(prompt_tokens, prompt_actions, num_frames, temperature=1.0)[source]#

Simple autoregressive sampling (token by token).

Parameters:
  • prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens

  • prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames

  • num_frames (int) – Total number of frames to generate

  • temperature (float) – Sampling temperature

Returns:

generated_tokens – (B, num_frames, N)
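The temperature argument rescales logits before sampling. Values below 1 sharpen the distribution toward the most likely token, and values above 1 flatten it. The pure-Python sketch below shows one sampling step as an illustration of temperature sampling. It is not this method's implementation, and `sample_token` is a hypothetical name.

```python
import math
import random
from typing import Sequence


def sample_token(logits: Sequence[float], temperature: float, rng: random.Random) -> int:
    """Draw one token index from softmax(logits / temperature)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    peak = max(logits)
    # Subtracting the max before exponentiating keeps exp() from overflowing.
    weights = [math.exp((x - peak) / temperature) for x in logits]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


rng = random.Random(0)
print(sample_token([0.0, 5.0, 1.0], temperature=1e-3, rng=rng))  # 1 (near-greedy)
print(sample_token([0.0, 5.0, 1.0], temperature=1.0, rng=rng))
```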

world_models.create_genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, action_vocab_size=8, action_embedding_dim=32, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#

Factory function to create a Genie model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

world_models.create_genie_small(num_frames=16, image_size=64, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#

Create a smaller Genie model for development/testing.

Parameters:
  • num_frames (int)

  • image_size (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

world_models.create_genie_large(num_frames=16, image_size=64, use_bfloat16=True, action_pooling='mean', window_attention_heads=1)[source]#

Create the full 11B parameter Genie model (approximate).

Parameters:
  • num_frames (int)

  • image_size (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

world_models.create_latent_action_model(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, action_pooling='mean', window_attention_heads=1)[source]#

Factory function to create a Latent Action Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

LatentActionModel

world_models.create_dynamics_model(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4)[source]#

Factory function to create a Dynamics Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

Return type:

DynamicsModel

class world_models.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)[source]#

Bases: Module

Recurrent State-Space Model used by Dreamer for latent dynamics learning.

The RSSM is the core world model component that learns compact representations of environment dynamics. It maintains a hybrid state consisting of:

  1. Deterministic State (h): A recurrent hidden state updated by a GRU, capturing sequential/temporal information and deterministic transitions.

  2. Stochastic State (s): A latent variable representing stochastic, multi-modal uncertainty in the environment (e.g., ambiguous observations).

The model operates in two modes:

  • Observe Mode: Updates states using actual observations from the environment. Uses the representation model: p(s_t | h_t, obs_t)

  • Imagine Mode: Predicts future states without observations. Uses the transition/prior model: p(s_t | h_t)

Architecture:
  • Input: previous state (h_{t-1}, s_{t-1}) and action a_{t-1}

  • Process: the GRU updates the deterministic state; an MLP computes the stochastic prior/posterior

  • Output: updated state (h_t, s_t) and distributions

State Representation:
  • deter (h): GRU hidden state, captures sequential context

  • stoch (s): Stochastic latent, multi-modal uncertainty

  • mean/std: Parameters of the stochastic distribution

Usage with DreamerAgent:

rssm = RSSM(
    action_size=action_dim,
    stoch_size=30,       # Stochastic state dimension
    deter_size=200,      # Deterministic (GRU) state dimension
    hidden_size=200,     # MLP hidden layer size
    obs_embed_size=256,  # Observation embedding from encoder
    activation='elu',
)

# Observe with actual observation (returns posterior and prior)
posterior, prior = rssm.observe_step(prev_state, prev_action, obs_embed)

# Imagine future without observation
next_prior = rssm.imagine_step(current_state, action)

Training:

The RSSM is trained by maximizing the ELBO (Evidence Lower Bound):
  • KL divergence between prior and posterior encourages the prior to capture environment dynamics

  • Reconstruction loss from the decoder ensures the state captures observation information
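Written out for one sequence, the ELBO takes the standard Dreamer form below. Here o_t stands for obs_t. q(s_t | h_t, o_t) is the representation (posterior) model, written above as p(s_t | h_t, obs_t), and p(s_t | h_t) is the transition prior. The reward-prediction term that Dreamer also trains is omitted to match the two losses listed above.

```latex
\mathcal{L} = \sum_{t=1}^{T} \Big(
  \mathbb{E}_{q(s_t \mid h_t, o_t)}\big[\ln p(o_t \mid h_t, s_t)\big]
  - \mathrm{KL}\big(q(s_t \mid h_t, o_t) \,\big\|\, p(s_t \mid h_t)\big)
\Big)
```

Maximizing \mathcal{L} is equivalent to minimizing the reconstruction loss -\ln p(o_t \mid h_t, s_t) plus the KL term.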

Reference:

Dream to Control: Learning Behaviors by Latent Imagination, Hafner et al., ICLR 2020 - https://arxiv.org/abs/1912.01603

init_state(batch_size, device)[source]#

Initialize RSSM state with zeros.

Parameters:
  • batch_size – Number of parallel sequences

  • device – torch device for tensors

Returns:

Dictionary of zero-initialized state components:
  • mean, std: Stochastic distribution parameters

  • stoch: Stochastic state sample

  • deter: Deterministic GRU hidden state

get_dist(mean, std)[source]#

Create an Independent Normal distribution from mean and std.

Parameters:
  • mean – Location parameter

  • std – Scale parameter

Returns:

Independent Normal distribution with given parameters
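Because get_dist wraps a Normal in Independent, log-probabilities and KL terms are summed over the state dimensions. For two such diagonal Gaussians, the KL used in the training objective has a closed form, sketched below in pure Python. torch.distributions.kl_divergence computes the same quantity for two distributions returned by get_dist. `diag_gaussian_kl` is a hypothetical helper.

```python
import math
from typing import Sequence


def diag_gaussian_kl(mean_q: Sequence[float], std_q: Sequence[float],
                     mean_p: Sequence[float], std_p: Sequence[float]) -> float:
    """KL(q || p) for two diagonal Gaussians, summed over dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(mean_q, std_q, mean_p, std_p):
        # Per-dimension: log(sp / sq) + (sq^2 + (mq - mp)^2) / (2 sp^2) - 1/2
        total += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
    return total


print(diag_gaussian_kl([0.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0]))  # 0.0
print(diag_gaussian_kl([1.0], [1.0], [0.0], [1.0]))  # 0.5
```

In Dreamer the KL is taken from the posterior q to the prior p, so q's parameters go first.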

observe_step(prev_state, prev_action, obs_embed, nonterm=1.0)[source]#

Update state using actual observation (observe mode).

In observe mode, the RSSM first computes a transition prior from the previous state and action, then refines the stochastic state using the actual observation embedding to form the posterior.

Parameters:
  • prev_state – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})

  • prev_action – Previous action a_{t-1}, shape (B, action_size)

  • obs_embed – Observation embedding from encoder, shape (B, obs_embed_size)

  • nonterm – Termination mask (1.0 = continue, 0.0 = terminal)

Returns:

A tuple (posterior, prior) of state dictionaries. The posterior incorporates observation information; the prior is the transition prediction before observation. Both share the same deterministic state because the GRU is only advanced once per timestep.

imagine_step(prev_state, prev_action, nonterm=1.0)[source]#

Predict next state without observation (imagine mode).

In imagine mode, the RSSM predicts future states using only the prior distribution. This is used for planning and policy learning where actual observations are not available.

Parameters:
  • prev_state – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})

  • prev_action – Previous action a_{t-1}, shape (B, action_size)

  • nonterm – Termination mask (1.0 = continue, 0.0 = terminal)

Returns:

Dictionary with the predicted state, containing:
  • deter: Predicted deterministic state

  • mean, std, stoch: Prior stochastic state distribution

get_prior(prev_state, prev_action, nonterm=1.0)[source]#

Compute prior distribution over stochastic state.

The prior represents the model’s belief about the stochastic state before observing the actual outcome.

Parameters:
  • prev_state – Previous state dictionary

  • prev_action – Previous action

  • nonterm – Termination mask

Returns:

Dictionary with prior state (no observation information)

get_posterior(prev_state, prev_action, obs_embed, nonterm=1.0)[source]#

Compute posterior distribution over stochastic state.

The posterior incorporates observation information to produce a more accurate state estimate.

Parameters:
  • prev_state – Previous state dictionary

  • prev_action – Previous action

  • obs_embed – Observation embedding

  • nonterm – Termination mask

Returns:

Dictionary with posterior state (observation-informed). Note that the previous-state shape (B, ...) is preserved; the batch dimension is not flattened.

detach_state(state)[source]#

Detach state tensors from computation graph.

Used during DreamerV2 training to prevent gradient flow through the observation/update pathway.

Parameters:

state – State dictionary with tensor values

Returns:

Detached state dictionary

seq_to_batch(state_dict)[source]#

Convert sequence state to batch format.

Parameters:

state_dict – Dictionary with sequence-dimension tensors (T, B, …)

Returns:

Dictionary with batch-dimension tensors (B*T, …)

observe_rollout(obs_embed, actions, nonterms, init_state, seq_len)[source]#

Process a sequence of observations (observe mode rollout).

At each timestep we run observe_step once to obtain the transition prior (the prediction given the previous state and action) and the observation-informed posterior. The posterior is then used as the previous state for the next step, matching the standard Dreamer inference pattern.

Parameters:
  • obs_embed – Observation embeddings, shape (T+1, B, obs_embed_size)

  • actions – Actions, shape (T, B, action_size)

  • nonterms – Non-termination flags, shape (T, B, 1)

  • init_state – Initial state dictionary

  • seq_len – Sequence length T

Returns:
  • prior – Dictionary with prior states stacked along the time axis

  • posterior – Dictionary with posterior states stacked along the time axis
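The loop below mirrors the rollout pattern described above. The posterior from step t becomes the input state for step t + 1, and priors and posteriors are collected per step. It uses a scalar stand-in for observe_step so that it runs without PyTorch. toy_observe_step and observe_rollout_sketch are hypothetical. Pairing actions[t] with obs_embed[t + 1] is an assumption based on the (T+1, B, ...) and (T, B, ...) shapes.

```python
from typing import Callable, Dict, List, Tuple

State = Dict[str, float]


def observe_rollout_sketch(
    observe_step: Callable[[State, float, float, float], Tuple[State, State]],
    obs_embed: List[float],  # length T + 1
    actions: List[float],    # length T
    nonterms: List[float],   # length T
    init_state: State,
) -> Tuple[List[State], List[State]]:
    """Loop structure only: the posterior of step t is the input state of step t + 1."""
    priors, posteriors = [], []
    state = init_state
    for t in range(len(actions)):
        # Assumption: action t is paired with the observation that follows it.
        posterior, prior = observe_step(state, actions[t], obs_embed[t + 1], nonterms[t])
        priors.append(prior)
        posteriors.append(posterior)
        state = posterior
    return priors, posteriors


def toy_observe_step(prev: State, action: float, obs: float, nonterm: float) -> Tuple[State, State]:
    # Hypothetical scalar stand-in for RSSM.observe_step.
    prior = {"stoch": nonterm * prev["stoch"] + action}
    posterior = {"stoch": 0.5 * (prior["stoch"] + obs)}
    return posterior, prior


priors, posteriors = observe_rollout_sketch(
    toy_observe_step, [0.0, 2.0, 4.0], [1.0, 1.0], [1.0, 1.0], {"stoch": 0.0}
)
print([p["stoch"] for p in priors], [q["stoch"] for q in posteriors])  # [1.0, 2.5] [1.5, 3.25]
```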

imagine_rollout(policy, init_state, horizon)[source]#

Generate imagined trajectory using policy (imagine mode rollout).

Parameters:
  • policy – Actor network that outputs actions from state features

  • init_state – Initial state dictionary

  • horizon – Number of steps to imagine

Returns:

Dictionary with imagined states for each step

forward(x, u)[source]#

Forward pass for training (computes sequence of states).

Parameters:
  • x – Observations, shape (B, T+1, C, H, W)

  • u – Actions, shape (B, T, action_size)

Returns:
  • states – List of state dictionaries for each timestep

  • priors – List of prior distributions (tuples of mean, std)

  • posteriors – List of posterior distributions (tuples of mean, std)

class world_models.RecurrentStateSpaceModel(action_size, state_size=200, latent_size=30, hidden_size=200, embed_size=1024, activation_function='relu')[source]#

Bases: Module

A Recurrent State Space Model (RSSM) for modeling latent dynamics in sequential data.

get_init_state(enc, h_t=None, s_t=None, a_t=None, mean=False)[source]#

Returns the initial posterior given the observation.

deterministic_state_fwd(h_t, s_t, a_t)[source]#
Deterministic transition update that accepts:
  • a_t shaped [B, action_size]

  • a_t shaped [action_size] (unbatched) -> expanded to [B, action_size]

  • a_t shaped [B] or scalar -> reshaped appropriately

Ensures a_t is 2D and matches batch dimension of h_t before concatenation.

state_prior(h_t, sample=False)[source]#

Returns the prior distribution over the latent state given the deterministic state.

state_posterior(h_t, e_t, sample=False)[source]#

Returns the posterior distribution over the latent state given the deterministic state and the encoded observation e_t.

pred_reward(h_t, s_t)[source]#
rollout_prior(act, h_t, s_t)[source]#
forward(x, u)[source]#

Forward through the RSSM for a batch of sequences.

Inputs:
  • x: Tensor [B, T+1, C, H, W] (observations including initial frame)

  • u: Tensor [B, T, action_size] (actions for T steps)

Returns:
  • states – list[T] of tensors [B, state_size]

  • priors – list[T] of tuples (mean, std), each [B, latent_size]

  • posteriors – list[T] of tuples (mean, std), each [B, latent_size]

class world_models.ConvEncoder(input_shape, embed_size, activation, depth=32)[source]#

Bases: Module

Convolutional observation encoder used by Dreamer world models.

This encoder transforms raw image observations (typically RGB frames from environments like Atari or DeepMind Control) into compact latent embeddings that can be processed by the RSSM (Recurrent State-Space Model).

Architecture:

  • Input: (B, C, H, W) raw images, values in [-0.5, 0.5]

  • Process: 4 convolutional layers with stride 2, halving spatial dimensions

  • Output: (B, embed_size) compact representation

The encoder uses a depth doubling pattern: 32 -> 64 -> 128 -> 256 channels. After convolutions, a fully connected layer projects from 1024 features to the desired embedding size.
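The 1024 figure follows from the layer arithmetic if the convolutions use kernel size 4, stride 2, and no padding, as in the original Dreamer encoder. That kernel size is an assumption here, because it is not stated above. The check below reproduces the channel counts and spatial sizes for a 64x64 input with depth=32.

```python
from typing import List


def conv_output_sizes(size: int, num_layers: int = 4, kernel: int = 4, stride: int = 2) -> List[int]:
    """Spatial size after each unpadded conv layer: floor((in - kernel) / stride) + 1."""
    sizes = []
    for _ in range(num_layers):
        size = (size - kernel) // stride + 1
        sizes.append(size)
    return sizes


depth = 32
channels = [depth * 2 ** i for i in range(4)]  # depth doubling per layer
sizes = conv_output_sizes(64)
print(channels, sizes)                # [32, 64, 128, 256] [31, 14, 6, 2]
print(channels[-1] * sizes[-1] ** 2)  # 1024 flattened features
```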

Usage with Dreamer:

encoder = ConvEncoder(
    input_shape=(3, 64, 64),  # RGB 64x64 images
    embed_size=256,           # RSSM observation embedding size
    activation='relu',        # Activation function
)
obs_embedding = encoder(observation)  # (B, 256)

Parameters:
  • input_shape – Tuple (C, H, W) for input images, typically (3, 64, 64)

  • embed_size – Output embedding dimension, typically 256 or 1024

  • activation – Activation function name (‘relu’, ‘elu’, ‘tanh’, etc.)

  • depth – Base channel depth for first layer (default 32)

forward(inputs)[source]#
class world_models.CNNEncoder(embedding_size, activation_function='relu')[source]#

Bases: Module

A Convolutional Neural Network (CNN) encoder for processing image inputs.

forward(observation)[source]#
class world_models.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)[source]#

Bases: Module

Convolutional decoder for reconstructing observations from latent states.

Part of Dreamer’s world model, this decoder reconstructs image observations from the combined stochastic (s) and deterministic (h) RSSM states.

Architecture:

  • Input: concatenated [stoch_state, deter_state], shape (B, stoch+deter)

  • Process: dense projection + 4 transposed convolutions (upsampling 2x each)

  • Output: Independent Normal distribution over observation pixels

The decoder mirrors the ConvEncoder’s structure but in reverse (transposed convs instead of regular convs). This creates a symmetric autoencoder where the encoder and decoder can be trained jointly to learn compressed representations.

Output Distribution:

Returns torch.distributions.Independent(Normal(mean, std), len(shape)). This allows computing log_prob(observation) for the reconstruction loss.

Usage in Dreamer world model:

decoder = ConvDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(3, 64, 64),  # RGB images
    activation='relu',
)
obs_dist = decoder(latent_features)  # Returns a distribution
log_prob = obs_dist.log_prob(target_observation)

Training:

The reconstruction loss is -log_prob(observation), which encourages the RSSM to learn states that capture observation information.
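Because the decoder returns Independent(Normal(mean, std), len(shape)), log_prob sums the per-pixel Gaussian log-densities. With a fixed unit standard deviation, as in the original Dreamer decoder, the reconstruction loss equals half the summed squared error plus a constant. Whether this class fixes std to 1 is an assumption. The pure-Python check below uses a hypothetical helper `independent_normal_nll`.

```python
import math
from typing import Sequence


def independent_normal_nll(x: Sequence[float], mean: Sequence[float], std: float = 1.0) -> float:
    """-log p(x) under Independent(Normal(mean, std)), summed over all elements."""
    const = math.log(std) + 0.5 * math.log(2 * math.pi)
    return sum(0.5 * ((xi - mi) / std) ** 2 + const for xi, mi in zip(x, mean))


target = [0.1, -0.2, 0.3]
prediction = [0.0, 0.0, 0.0]
nll = independent_normal_nll(target, prediction)
sq_err = sum((t - p) ** 2 for t, p in zip(target, prediction))
# With std = 1 the loss is half the squared error plus 0.5 * log(2 * pi) per element.
print(math.isclose(nll, 0.5 * sq_err + 3 * 0.5 * math.log(2 * math.pi)))  # True
```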

forward(features)[source]#
class world_models.CNNDecoder(state_size, latent_size, embedding_size, activation_function='relu')[source]#

Bases: Module

A Convolutional Neural Network (CNN) decoder for reconstructing image outputs.

forward(latent, state)[source]#
class world_models.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist, num_buckets=255, symlog_range=10.0)[source]#

Bases: Module

MLP decoder for reward/value/discount prediction from latent features.

Part of Dreamer’s world model, this decoder predicts scalar quantities (rewards, values, discount factors) from RSSM latent states.

Architecture:

  • Input: [stoch_state, deter_state] concatenated, shape (B, stoch+deter)

  • Process: MLP with configurable layers and hidden units

  • Output: predicted quantity with distribution (normal, binary, or raw)

Supports three output types:
  • ‘normal’: Gaussian distribution for regression (rewards, values)

  • ‘binary’: Bernoulli distribution for binary classification (discount)

  • ‘none’: Raw tensor for non-probabilistic outputs

Usage:

reward_decoder = DenseDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(1,),
    n_layers=2,
    units=400,
    activation='elu',
    dist='normal',
)
reward_dist = reward_decoder(latent_features)
reward_loss = -reward_dist.log_prob(target_reward)

For discount prediction (binary):

discount_decoder = DenseDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(1,),
    n_layers=2,
    units=400,
    activation='elu',
    dist='binary',  # Bernoulli for P(continue)
)

forward(features)[source]#
class world_models.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)[source]#

Bases: Module

Dreamer actor head producing squashed continuous actions from latent features.

Outputs a transformed Gaussian policy with optional deterministic mode and utility for additive exploration noise.

forward(features, deter=False)[source]#
add_exploration(action, action_noise=0.3)[source]#
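add_exploration perturbs a policy action for data collection. A common Dreamer choice is to add zero-mean Gaussian noise with standard deviation action_noise and clip the result back into the [-1, 1] action range. That this method does exactly this is an assumption. The pure-Python sketch below uses the hypothetical name `add_exploration_noise`.

```python
import random
from typing import List, Optional, Sequence


def add_exploration_noise(action: Sequence[float], action_noise: float = 0.3,
                          rng: Optional[random.Random] = None) -> List[float]:
    """Add Gaussian noise with std `action_noise`, then clip each entry to [-1, 1]."""
    rng = rng or random.Random()
    return [max(-1.0, min(1.0, a + rng.gauss(0.0, action_noise))) for a in action]


noisy = add_exploration_noise([0.9, -0.2], action_noise=0.3, rng=random.Random(0))
print(noisy)  # each entry stays within [-1, 1]
```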
class world_models.TanhBijector[source]#

Bases: Transform

Bijective tanh transform for squashing Gaussian distributions to [-1, 1].

This transformation is essential for Dreamer’s action policy. Raw neural network outputs are Gaussian distributions over R^n, but actions in continuous control environments are typically bounded in [-1, 1]. The tanh bijector provides:

  1. Bijective mapping: tanh is invertible (with atanh as inverse)

  2. Stable log-det Jacobian: Computable for gradient-based training

  3. Bounded actions: sampled actions always lie in (-1, 1), so no clipping is needed

Math:

  • Forward: y = tanh(x)

  • Inverse: x = atanh(y) = 0.5 * log((1+y)/(1-y))

  • Log-det: log|dy/dx| = 2*(log(2) - x - softplus(-2x))
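The stable log-det form is algebraically equal to log(1 - tanh(x)^2). It avoids computing 1 - tanh(x)^2 directly, which rounds to zero for large |x|. The pure-Python check below compares the two forms. softplus here is a local helper, not a library function.

```python
import math


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def tanh_log_det(x: float) -> float:
    """log|d tanh(x)/dx| in the stable form 2 * (log 2 - x - softplus(-2x))."""
    return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))


for x in (-3.0, -0.5, 0.0, 0.5, 3.0):
    naive = math.log(1.0 - math.tanh(x) ** 2)
    assert math.isclose(tanh_log_det(x), naive, rel_tol=1e-9, abs_tol=1e-12)

# At x = 20, tanh(x)**2 rounds to 1.0, so the naive form would call log(0).
print(tanh_log_det(20.0))
```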

Usage with Dreamer ActionDecoder:

dist = TransformedDistribution(
    Normal(mean, std),
    TanhBijector(),
)
action = dist.sample()  # Bounded to [-1, 1]

Reference:

Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor, Haarnoja et al., 2018 (appendix on enforcing action bounds)

property sign#
atanh(x)[source]#
log_abs_det_jacobian(x, y)[source]#
class world_models.SampleDist(dist, samples=100)[source]#

Bases: object

Distribution wrapper that estimates statistics via Monte Carlo sampling.

Provides approximated mean, mode, and entropy helpers for transformed distributions where analytic forms may be inconvenient.

property name#
mean()[source]#
mode()[source]#
entropy()[source]#
sample()[source]#
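SampleDist is needed because a tanh-transformed policy distribution has no closed-form mean, mode, or entropy. The original Dreamer code estimates these from samples: the mean is the sample average, the mode is the sample with the highest log-probability, and the entropy is the negative mean log-probability. This class is assumed to follow the same rules. The sketch below applies them to a standard normal, whose exact entropy 0.5·log(2πe) is known. MonteCarloStats is a hypothetical name.

```python
import math
import random
from typing import Callable, List


class MonteCarloStats:
    """Hypothetical sample-based estimates in the style of SampleDist."""

    def __init__(self, sample_fn: Callable[[], float],
                 log_prob_fn: Callable[[float], float], samples: int = 100) -> None:
        self.sample_fn = sample_fn
        self.log_prob_fn = log_prob_fn
        self.samples = samples

    def _draw(self) -> List[float]:
        return [self.sample_fn() for _ in range(self.samples)]

    def mean(self) -> float:
        draws = self._draw()
        return sum(draws) / len(draws)

    def mode(self) -> float:
        # The draw with the highest log-probability approximates the mode.
        return max(self._draw(), key=self.log_prob_fn)

    def entropy(self) -> float:
        # Entropy H = -E[log p(x)], estimated by the sample average.
        draws = self._draw()
        return -sum(self.log_prob_fn(x) for x in draws) / len(draws)


def std_normal_log_prob(x: float) -> float:
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


rng = random.Random(0)
stats = MonteCarloStats(lambda: rng.gauss(0.0, 1.0), std_normal_log_prob, samples=20000)
exact_entropy = 0.5 * math.log(2.0 * math.pi * math.e)  # about 1.4189
print(stats.mean(), stats.mode(), stats.entropy(), exact_entropy)
```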
class world_models.IRISEncoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, in_channels=3, base_channels=64, num_residual_blocks=2, frame_shape=(3, 64, 64))[source]#

Bases: Module

CNN Encoder for IRIS discrete autoencoder.

Encodes image observations into latent features, which are then quantized into discrete tokens using the VectorQuantizer.

Architecture:
  • 4 convolutional layers with residual blocks

  • Self-attention at 8x8 and 16x16 resolutions

  • Vector quantization to produce discrete tokens

Parameters:
  • vocab_size (int)

  • tokens_per_frame (int)

  • embedding_dim (int)

  • in_channels (int)

  • base_channels (int)

  • num_residual_blocks (int)

  • frame_shape (Tuple[int, int, int])
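With the defaults frame_shape=(3, 64, 64) and tokens_per_frame=16, each frame becomes a 4x4 grid of tokens. This matches the 4x4 to 64x64 upsampling described for IRISDecoder. The arithmetic below assumes a square token grid and power-of-two downsampling, under which four stride-2 stages would be needed. `token_grid` is a hypothetical helper.

```python
import math
from typing import Tuple


def token_grid(frame_size: int, tokens_per_frame: int) -> Tuple[int, int, int]:
    """Return (grid_side, downsample_factor, stride-2 stages) for a square token grid."""
    side = math.isqrt(tokens_per_frame)
    if side * side != tokens_per_frame:
        raise ValueError("tokens_per_frame must be a perfect square")
    factor = frame_size // side
    stages = factor.bit_length() - 1  # log2(factor) for a power of two
    if 2 ** stages != factor:
        raise ValueError("downsampling factor must be a power of two")
    return side, factor, stages


print(token_grid(64, 16))  # (4, 16, 4)
```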

forward(x)[source]#

Encode images to discrete tokens.

Parameters:

x (Tensor) – Input images (B, C, H, W) - should be 64x64

Returns:
  • z_q – Quantized tokens (B, C, H’, W’)

  • indices – Token indices (B, H’, W’)

  • vq_loss – Dictionary with VQ loss components

encode_to_indices(x)[source]#

Encode directly to token indices (for world model).

Parameters:

x (Tensor)

Return type:

Tensor

decode_from_indices(indices)[source]#

Decode token indices to embeddings (for decoder).

Parameters:

indices (Tensor)

Return type:

Tensor

class world_models.IRISDecoder(vocab_size=512, embedding_dim=512, base_channels=32, out_channels=3, frame_shape=(3, 64, 64))[source]#

Bases: Module

CNN Decoder for IRIS discrete autoencoder.

Decodes discrete tokens back into image observations. Uses transposed convolutions to upsample from 4x4 to 64x64.

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • base_channels (int)

  • out_channels (int)

  • frame_shape (Tuple[int, int, int])

forward(z)[source]#

Decode tokens to images.

Parameters:

z (Tensor) – Token embeddings (B, C, H, W) - e.g., (B, 512, 4, 4)

Returns:

reconstructed – Reconstructed images (B, C, H, W), e.g. (B, 3, 64, 64)

decode_from_embeddings(z_flat)[source]#

Decode flattened token embeddings to images.

Parameters:

z_flat (Tensor) – Flattened tokens (B, H*W, C) or (B, C, H, W)

Returns:

Reconstructed images

Return type:

Tensor

decode_from_indices(indices)[source]#

Decode discrete token indices into images.

Parameters:

indices (Tensor) – Tensor of shape (B, H, W) or (B, H*W) containing integer token indices in the range [0, vocab_size).

Returns:

Reconstructed images (B, C, H, W)

Return type:

Tensor

class world_models.VideoTokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, commitment_weight=0.25, use_ema=False, ema_decay=0.99)[source]#

Bases: Module

Video Tokenizer using VQ-VAE with Spatiotemporal Transformer.

This is a core component of Genie (Google DeepMind, 2024), used to compress raw video frames into discrete latent tokens that can be processed by downstream models like the LatentActionModel and DynamicsModel.

The tokenizer uses a Vector Quantized Variational Autoencoder (VQ-VAE) objective to learn a discrete codebook of video representations. Unlike a standard VQ-VAE, it uses a Spatiotemporal (ST) Transformer in both the encoder and the decoder to better capture temporal dynamics in videos.

Architecture:
  1. Patch Embedding: Convert (B, C, T, H, W) video to patch tokens

  2. Encoder ST-Transformer: Process spatial-temporal patches

  3. Vector Quantization: Discretize continuous embeddings to codebook entries

  4. Decoder ST-Transformer: Reconstruct video from quantized tokens

  5. Patch Unembedding: Convert tokens back to video frames

Key Features:
  • Causal processing: each frame’s encoding uses only the current and previous frames

  • Discrete tokens: Enables autoregressive prediction with latent actions

  • Memory efficient: uses an ST-Transformer instead of a full ViT, reducing attention cost from O((T·N)²) over all tokens to O(T·N² + N·T²), where N is the number of tokens per frame
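For the default VideoTokenizer settings (num_frames=16, image_size=64, patch_size=4), the saving can be estimated by counting query-key pairs in one attention layer. The count below ignores causal masking, heads, and MLP cost. It assumes each ST block pairs one spatial attention layer (within a frame) with one temporal attention layer (across frames at the same position). `attention_pair_counts` is a hypothetical helper.

```python
from typing import Dict


def attention_pair_counts(num_frames: int, image_size: int, patch_size: int) -> Dict[str, int]:
    """Count query-key pairs for full attention vs. spatiotemporal (ST) attention."""
    tokens_per_frame = (image_size // patch_size) ** 2
    total_tokens = num_frames * tokens_per_frame
    full = total_tokens ** 2
    spatial = num_frames * tokens_per_frame ** 2   # attention within each frame
    temporal = tokens_per_frame * num_frames ** 2  # attention across frames per position
    return {"tokens_per_frame": tokens_per_frame, "full": full, "st": spatial + temporal}


print(attention_pair_counts(num_frames=16, image_size=64, patch_size=4))
# {'tokens_per_frame': 256, 'full': 16777216, 'st': 1114112}
```

Under these assumptions the ST layout uses roughly 15x fewer attention pairs per layer.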

Usage with Genie:

tokenizer = VideoTokenizer(
    num_frames=16,
    image_size=64,
    patch_size=4,
    vocab_size=1024,
    embedding_dim=32,
)
reconstructed, indices, loss_dict = tokenizer(video_frames)

# For discrete token input to dynamics model:
token_embeddings = tokenizer.decode_indices(indices)

Training:

The tokenizer is trained with a VQ-VAE objective:
  • Reconstruction loss: MSE between the input and reconstructed video

  • Codebook (VQ) loss: moves codebook embeddings toward the encoder outputs, so the codebook learns useful codes

  • Commitment loss: penalizes encoder outputs for drifting away from their codebook entries
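Written out, these three terms form the standard VQ-VAE objective, up to how each term is normalized. Here x̂ is the reconstruction, z_e(x) is the encoder output, e is its nearest codebook entry, sg[·] is the stop-gradient operator, and β corresponds to commitment_weight (default 0.25). When use_ema=True, implementations typically replace the codebook term with exponential-moving-average codebook updates controlled by ema_decay. That this class does so is an assumption.

```latex
\mathcal{L} = \lVert x - \hat{x} \rVert_2^2
  + \lVert \operatorname{sg}[z_e(x)] - e \rVert_2^2
  + \beta \, \lVert z_e(x) - \operatorname{sg}[e] \rVert_2^2
```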

Reference:

Genie: Generative Interactive Environments. Bruce et al., Google DeepMind, 2024. https://arxiv.org/abs/2402.15391

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • use_ema (bool)

  • ema_decay (float)

encode(x)[source]#

Encode video to discrete tokens.

Parameters:

x (Tensor) – Video tensor (B, C, T, H, W)

Returns:

(z_q, indices, vq_loss)
  • z_q: Quantized embeddings (B, T, H’, W’, embedding_dim)

  • indices: Token indices (B, T, H’, W’)

  • vq_loss: Dictionary with VQ loss components

Return type:

tuple

decode_indices(indices)[source]#

Decode token indices to embeddings for video frames.

Parameters:

indices (Tensor) – Token indices (B, T, H’, W’) or (B, T, N) where N = H’*W’

Returns:

Quantized embeddings (B, T, H’, W’, embedding_dim)

Return type:

Tensor

decode(z_q)[source]#

Decode quantized embeddings into video frames.

Parameters:

z_q (Tensor) – Quantized embeddings (B, T, H’, W’, embedding_dim)

Returns:

Reconstructed video (B, C, T, H, W)

Return type:

Tensor

forward(x)[source]#

Full forward pass with VQ-VAE objective.

Parameters:

x (Tensor) – Video tensor (B, C, T, H, W)

Returns:

(reconstructed, indices, loss_dict)
  • reconstructed: Reconstructed video (B, C, T, H, W)

  • indices: Token indices (B, T, H’, W’)

  • loss_dict: Dictionary containing loss components

Return type:

tuple

world_models.create_video_tokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False)[source]#

Factory function to create a Video Tokenizer.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • use_ema (bool)

Return type:

VideoTokenizer

class world_models.VectorQuantizer(vocab_size=512, embedding_dim=512, commitment_weight=0.25)[source]#

Bases: Module

Vector Quantizer for discrete autoencoder.

Implements the VQ-VAE quantization from “Neural Discrete Representation Learning” (van den Oord et al., 2017).

Learns the codebook through gradient updates (VectorQuantizerEMA is the exponential-moving-average variant) and uses a straight-through estimator so gradients flow through the quantization step.
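The nearest-neighbour lookup at the core of vector quantization can be sketched in NumPy as below. It assumes inputs already flattened to rows of length embedding_dim (a (B, C, H, W) tensor would first be permuted to channels-last and reshaped). quantize is a hypothetical helper, and because NumPy has no autograd, the straight-through estimator appears only as a comment.

```python
import numpy as np


def quantize(z: np.ndarray, codebook: np.ndarray):
    """Map each row of z (M, D) to its nearest codebook row (K, D).

    Returns the quantized vectors (M, D) and the chosen indices (M,).
    """
    # Squared Euclidean distance expanded as |z|^2 - 2 z.e + |e|^2, shape (M, K).
    dists = (
        (z ** 2).sum(axis=1, keepdims=True)
        - 2.0 * z @ codebook.T
        + (codebook ** 2).sum(axis=1)
    )
    indices = dists.argmin(axis=1)
    z_q = codebook[indices]  # decode_indices amounts to this table lookup
    # With autograd, the straight-through estimator would return
    # z + stop_gradient(z_q - z): the forward value is z_q, while gradients
    # reach z as if quantization were the identity.
    return z_q, indices


rng = np.random.default_rng(0)
codebook = rng.normal(size=(8, 4))  # vocab_size=8, embedding_dim=4
z = codebook[[3, 5, 1]] + 0.01 * rng.normal(size=(3, 4))
z_q, idx = quantize(z, codebook)
print(idx)
```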

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

forward(z)[source]#

Quantize the input latents.

Parameters:

z (Tensor) – Input tensor of shape (B, C, H, W) or (B, C)

Returns:

(z_q, indices, loss)
  • z_q: Quantized tensor (same shape as input)

  • indices: Token indices for each position (B, H, W) or (B,)

  • loss: Dictionary containing VQ loss components

Return type:

tuple

decode_indices(indices)[source]#

Decode token indices back to embeddings.

Parameters:

indices (Tensor) – Token indices (B, H, W) or (B,)

Returns:

Embeddings (B, C, H, W) or (B, C)

Return type:

Tensor

class world_models.VectorQuantizerEMA(vocab_size=512, embedding_dim=512, commitment_weight=0.25, ema_decay=0.99, epsilon=1e-05)[source]#

Bases: Module

Vector Quantizer with Exponential Moving Average updates.

Updates the codebook with exponential moving averages of the assigned encoder outputs instead of gradient descent, which typically makes codebook training more stable.
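A minimal sketch of an EMA codebook step follows, in the form of the EMA dictionary update from the VQ-VAE paper's appendix. The Laplace smoothing with epsilon and the names cluster_size and embed_sum are assumptions about how VectorQuantizerEMA might use its ema_decay and epsilon arguments, not a description of its code.

```python
import numpy as np


def ema_codebook_update(z, indices, cluster_size, embed_sum,
                        decay=0.99, epsilon=1e-5):
    """One EMA step for a codebook of K entries.

    z: (M, D) encoder outputs, indices: (M,) assigned codes,
    cluster_size: (K,) running counts, embed_sum: (K, D) running sums.
    Returns the updated (codebook, cluster_size, embed_sum).
    """
    k = cluster_size.shape[0]
    one_hot = np.eye(k)[indices]                 # (M, K)
    counts = one_hot.sum(axis=0)                 # assignments per code
    cluster_size = decay * cluster_size + (1 - decay) * counts
    embed_sum = decay * embed_sum + (1 - decay) * (one_hot.T @ z)
    # Laplace smoothing keeps unused codes from dividing by zero.
    n = cluster_size.sum()
    smoothed = (cluster_size + epsilon) / (n + k * epsilon) * n
    codebook = embed_sum / smoothed[:, None]
    return codebook, cluster_size, embed_sum


z = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [12.0, 10.0]])
indices = np.array([0, 0, 1, 1])
codebook, _, _ = ema_codebook_update(
    z, indices, cluster_size=np.zeros(2), embed_sum=np.zeros((2, 2)), decay=0.0
)
print(codebook.round(3))
```

With decay=0.0 and every code in use, each entry becomes the mean of the encoder outputs assigned to it; a decay close to 1 (such as the default 0.99) moves entries toward those means gradually.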

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • ema_decay (float)

  • epsilon (float)

forward(z)[source]#

Quantize with EMA updates.

Parameters:

z (Tensor)

Return type:

Tuple[Tensor, Tensor, dict[str, Tensor]]

decode_indices(indices)[source]#

Decode token indices to embeddings.

Parameters:

indices (Tensor) – Token indices (B, H, W) or (B,)

Returns:

Embeddings (B, C, H, W) or (B, C)

Return type:

Tensor

class world_models.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]#

Bases: object

Fixed-size replay buffer for Dreamer with image observations and transitions.

Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world-model training.

Key Features:
  • Ring buffer with fixed capacity (FIFO eviction when full)

  • Stores raw uint8 images to save memory

  • Samples sequences (not single transitions) for temporal modeling

  • Validates sampled sequences don’t span episode boundaries

Memory Layout:
  • observations: (capacity, C, H, W) uint8 images

  • actions: (capacity, action_dim) float32

  • rewards: (capacity,) float32

  • terminals: (capacity,) float32 (1.0 = terminal, 0.0 = continue)

Sampling Process:
  1. Random start index (avoiding episode boundaries)

  2. Collect sequence of length seq_len with wraparound

  3. Validate no terminal in middle of sequence

  4. Return batch of sequences
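These sampling steps can be sketched with a small NumPy ring buffer. RingSequenceBuffer is a hypothetical stand-in, not ReplayBuffer itself: it rejects windows that would cross the write pointer (where the newest data meets the oldest) or that contain a terminal before their final step, and returns time-major arrays to match the shapes documented for sample().

```python
import numpy as np


class RingSequenceBuffer:
    """Hypothetical minimal ring buffer that samples time-major sequences."""

    def __init__(self, capacity, obs_shape, action_size):
        self.capacity = capacity
        self.obs = np.zeros((capacity, *obs_shape), dtype=np.uint8)
        self.actions = np.zeros((capacity, action_size), dtype=np.float32)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.terminals = np.zeros(capacity, dtype=np.float32)
        self.idx = 0          # next slot to write (the oldest entry once full)
        self.full = False

    def add(self, obs, action, reward, done):
        self.obs[self.idx] = obs
        self.actions[self.idx] = action
        self.rewards[self.idx] = reward
        self.terminals[self.idx] = float(done)
        self.idx = (self.idx + 1) % self.capacity
        self.full = self.full or self.idx == 0

    def _window(self, start, seq_len):
        idxs = np.arange(start, start + seq_len) % self.capacity
        if self.idx in idxs[1:]:              # would jump from newest to oldest data
            return None
        if self.terminals[idxs[:-1]].any():   # episode ends mid-sequence
            return None
        return idxs

    def sample(self, batch_size, seq_len, rng):
        # Assumes enough valid windows exist; otherwise this loop never ends.
        high = self.capacity if self.full else self.idx
        windows = []
        while len(windows) < batch_size:
            idxs = self._window(int(rng.integers(0, high)), seq_len)
            if idxs is not None:
                windows.append(idxs)
        idxs = np.stack(windows, axis=1)      # (seq_len, batch): time-major
        return (self.obs[idxs], self.actions[idxs],
                self.rewards[idxs], self.terminals[idxs])


buf = RingSequenceBuffer(capacity=100, obs_shape=(3, 8, 8), action_size=2)
for step in range(150):                       # wraps around once
    buf.add(np.full((3, 8, 8), step, dtype=np.uint8), np.zeros(2), 1.0,
            done=(step % 30 == 29))
obs, act, rew, term = buf.sample(batch_size=4, seq_len=10,
                                 rng=np.random.default_rng(0))
print(obs.shape, act.shape, rew.shape, term.shape)
```

Because the demo writes the step counter into every pixel, each sampled window can be checked for consecutive steps, which confirms that no window jumps across the wrap-around seam.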

Usage with Dreamer:

buffer = ReplayBuffer(
    size=100000,            # Max transitions to store
    obs_shape=(3, 64, 64),  # RGB images
    action_size=6,          # Continuous action dim
    seq_len=50,             # Sequence length for training
    batch_size=50,          # Parallel sequences per batch
)

# Add transitions during interaction
buffer.add(obs, action, reward, done)

# Sample a batch for world-model training
obs_batch, action_batch, reward_batch, term_batch = buffer.sample()
# Shapes: (seq_len, batch, C, H, W), (seq_len, batch, action_dim), etc.

Memory Efficiency:
  • Uses uint8 for images (1 byte per channel value vs 4 bytes for float32)

  • Sequences share observations (overlapping windows)

  • Configurable capacity based on available system memory
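As a quick check on the uint8 saving, the arithmetic below sizes only the image array for the configuration in the usage example above (size=100000, obs_shape=(3, 64, 64)); actions, rewards and terminals add comparatively little. buffer_image_bytes is an illustrative helper.

```python
def buffer_image_bytes(capacity, obs_shape, bytes_per_value):
    """Bytes needed to store `capacity` observations of shape obs_shape."""
    values = capacity
    for dim in obs_shape:
        values *= dim
    return values * bytes_per_value


as_uint8 = buffer_image_bytes(100_000, (3, 64, 64), 1)
as_float32 = buffer_image_bytes(100_000, (3, 64, 64), 4)
print(f"{as_uint8 / 2**30:.2f} GiB vs {as_float32 / 2**30:.2f} GiB")
```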

Note

add() expects observations as {“image”: …} dicts, but the buffer stores only the image arrays, and sample() returns plain arrays for training efficiency.

Parameters:
  • size (int)

  • obs_shape (Tuple[int, ...])

  • action_size (int)

  • seq_len (int)

  • batch_size (int)

add(obs, ac, rew, done)[source]#

Add a transition to the buffer.

Parameters:
  • obs (dict) – Observation dict with ‘image’ key containing the observation

  • ac (ndarray) – Action taken, shape (action_size,)

  • rew (float) – Reward received, scalar

  • done (float) – Terminal flag, 1.0 if episode ended, 0.0 otherwise

Return type:

None

sample()[source]#

Sample a batch of sequences for training.

Returns:

(observations, actions, rewards, terminals)
  • observations: (seq_len, batch, C, H, W)

  • actions: (seq_len, batch, action_dim)

  • rewards: (seq_len, batch)

  • terminals: (seq_len, batch)

Return type:

tuple

class world_models.Memory(size=None)[source]#

Bases: deque

Episode-based replay memory for PlaNet/RSSM training.

Stores episodes as variable-length trajectories and supports sampling sub-sequences for training. Implements a ring-buffer style eviction when capacity is reached.

Features:
  • Stores complete episodes as lists of transitions

  • Samples contiguous sub-sequences for sequence models

  • Supports time-major formatting (time-first) for RNN input

  • Memory usage estimation to prevent OOM errors

Parameters:

size (int, optional) – Maximum number of episodes to store. If None, deque grows without limit (useful for unpickling).

episodes#

Collection of Episode objects.

Type:

deque

eps_lengths#

Length of each episode.

Type:

deque

size#

Total number of transitions across all episodes.

Type:

int (read-only property)

Example

>>> memory = Memory(size=100)
>>> memory.append([episode1, episode2])
>>> obs, actions, rewards, terminals, lengths = memory.sample(batch_size=32, tracelen=50)
property size#
append(episodes)[source]#
Parameters:

episodes (list[Episode])

sample(batch_size, tracelen=1, time_first=False)[source]#

Sample random sub-sequences from stored episodes.

Randomly selects episodes and starting positions to create batches of contiguous sequences for training sequence models.
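A NumPy sketch of this sampling logic follows. It assumes each episode is a dict of arrays with one more observation than actions, consistent with the tracelen+1 observation length in the Returns section. sample_subsequences and make_episode are hypothetical helpers, and Memory may choose episodes or starting positions with a different distribution.

```python
import numpy as np


def sample_subsequences(episodes, batch_size, tracelen, time_first, rng):
    """Sample contiguous sub-sequences from variable-length episodes.

    Each episode is a dict with 'x' (T+1, *obs_shape), 'u' (T, action_dim),
    'r' (T,) and 't' (T,). Observations get one extra step so every action
    has both a preceding and a following observation.
    """
    eligible = [ep for ep in episodes if len(ep["u"]) >= tracelen]
    if not eligible:
        raise ValueError("no episode is long enough for tracelen")
    x, u, r, t, lengths = [], [], [], [], []
    for _ in range(batch_size):
        ep = eligible[rng.integers(len(eligible))]
        start = rng.integers(0, len(ep["u"]) - tracelen + 1)
        x.append(ep["x"][start:start + tracelen + 1])
        u.append(ep["u"][start:start + tracelen])
        r.append(ep["r"][start:start + tracelen])
        t.append(ep["t"][start:start + tracelen])
        lengths.append(len(ep["u"]))
    batch = [np.stack(a) for a in (x, u, r, t)]        # batch-first (B, T, ...)
    if time_first:
        batch = [np.swapaxes(a, 0, 1) for a in batch]  # time-first (T, B, ...)
    return (*batch, np.array(lengths))


def make_episode(length, rng):
    return {"x": rng.random((length + 1, 3, 8, 8)), "u": rng.random((length, 2)),
            "r": rng.random(length), "t": np.zeros(length)}


rng = np.random.default_rng(0)
episodes = [make_episode(n, rng) for n in (30, 80, 5)]
obs, act, rew, term, lengths = sample_subsequences(
    episodes, batch_size=4, tracelen=20, time_first=True, rng=rng)
print(obs.shape, act.shape, rew.shape, lengths.shape)
```

The 5-step episode is skipped because it is shorter than tracelen; Memory.sample raises ValueError only when no episode qualifies, as documented below.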

Parameters:
  • batch_size (int) – Number of sequences to sample.

  • tracelen (int) – Length of each sequence (default: 1).

  • time_first (bool) – If True, returns tensors with time dimension first (T, B, …) instead of batch first (B, T, …).

Returns:

(observations, actions, rewards, terminals, lengths)
  • observations: (batch, tracelen+1, *obs_shape) or (tracelen+1, batch, …)

  • actions: (batch, tracelen, action_dim) or (tracelen, batch, …)

  • rewards: (batch, tracelen) or (tracelen, batch)

  • terminals: (batch, tracelen) or (tracelen, batch)

  • lengths: (batch,) original episode lengths for each sample

Return type:

tuple

Raises:
  • ValueError – If memory is empty or no episodes meet minimum length.

  • MemoryError – If estimated memory usage exceeds 200 MiB threshold.

class world_models.Episode(postprocess_fn=None)[source]#

Bases: object

Records the agent’s interaction with the environment for a single episode.

Stores observations, actions, rewards, and terminal flags during a single trajectory. At termination, converts all lists to numpy arrays for efficient batch processing.

x#

Observations collected during the episode.

Type:

list or np.ndarray

u#

Actions taken.

Type:

list or np.ndarray

r#

Rewards received.

Type:

list or np.ndarray

t#

Terminal flags (0.0 = continue, 1.0 = terminal).

Type:

list or np.ndarray

info#

Additional episode metadata.

Type:

dict

Parameters:

postprocess_fn (callable, optional) – Function to apply to observations before storing (e.g., normalization). Default: identity function.

Example

>>> episode = Episode()
>>> episode.append(obs, action, reward, False)
>>> episode.append(obs, action, reward, True)
>>> episode.terminate(final_obs)
>>> print(episode.x.shape)  # Now a numpy array
property size#
append(obs, act, reward, terminal)[source]#
terminate(obs)[source]#
class world_models.IRISReplayBuffer(size, obs_shape, action_size, seq_len=20, batch_size=64)[source]#

Bases: object

Replay buffer for IRIS (Imagination with auto-Regression over an Inner Speech) training.

Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world model training.

Features:
  • Ring buffer with fixed capacity (FIFO eviction when full)

  • Stores uint8 images for memory efficiency

  • Samples sequences with validation to avoid episode boundaries

  • Supports sequence sampling for temporal learning

Memory Layout:
  • observations: (capacity, C, H, W) uint8

  • actions: (capacity, action_size) float32

  • rewards: (capacity,) float32

  • terminals: (capacity,) float32

Parameters:
  • size (int) – Maximum number of transitions to store.

  • obs_shape (tuple) – Shape of observations as (C, H, W).

  • action_size (int) – Dimension of actions.

  • seq_len (int) – Length of sequences to sample (default: 20).

  • batch_size (int) – Number of sequences per batch (default: 64).

size#

Buffer capacity.

Type:

int

obs_shape#

Observation shape.

Type:

tuple

action_size#

Action dimension.

Type:

int

seq_len#

Sequence length.

Type:

int

batch_size#

Batch size.

Type:

int

steps#

Total transitions added.

Type:

int

episodes#

Number of episode terminations observed.

Type:

int

add(obs, action, reward, terminal)[source]#

Add a transition to the buffer.

Parameters:
  • obs (ndarray) – Observation array with shape (C, H, W).

  • action (ndarray) – Action array with shape (action_size,).

  • reward (float) – Scalar reward value.

  • terminal (bool) – Boolean indicating if episode terminated.

sample_sequence(seq_len=None)[source]#

Sample a batch of sequences for world model training.

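Returns:

(observations, actions, rewards, terminals)
  • observations: (batch_size, seq_len+1, C, H, W)

  • actions: (batch_size, seq_len, action_size)

  • rewards: (batch_size, seq_len)

  • terminals: (batch_size, seq_len)

Return type:

tuple

Parameters:

seq_len (int | None) – Sequence length to sample; if None, the buffer’s seq_len is used.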

sample_single()[source]#

Sample a single transition for online updates.

Return type:

Tuple[ndarray, ndarray, float, float]

property buffer_capacity#

Returns the total capacity of the buffer.

class world_models.IRISOnPolicyBuffer(max_steps=1000)[source]#

Bases: object

On-policy buffer for collecting trajectories during environment interaction.

Used to store the current episode data before adding to the main replay buffer. Unlike the main replay buffer, this collects trajectories in a list-based structure that’s cleared after each episode.

Useful for:
  • Collecting complete episode trajectories

  • Storing data before batch processing

  • Temporary storage during environment interaction

Parameters:

max_steps (int) – Maximum number of steps to store (default: 1000).

max_steps#

Maximum buffer capacity.

Type:

int

observations#

List of observations.

Type:

list

actions#

List of actions.

Type:

list

rewards#

List of rewards.

Type:

list

terminals#

List of terminal flags.

Type:

list

add(obs, action, reward, terminal)[source]#
Parameters:
  • obs (ndarray)

  • action (ndarray)

  • reward (float)

  • terminal (bool)

clear()[source]#
get_arrays()[source]#
class world_models.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]#

Bases: Module

Diffusion Transformer model for image denoising and generation.

The module maps noisy images and timesteps to predicted noise residuals and also provides a classmethod training entrypoint for common datasets.

forward(x, t)[source]#
classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]#
class world_models.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]#

Bases: Module

Patchify an image into a sequence of learnable patch tokens.

Used in Vision Transformers (ViT) and DiT to convert 2D images into sequences of token embeddings that can be processed by transformers.

Process:
  1. Conv2d with kernel_size=stride=patch_size extracts non-overlapping patches

  2. The same Conv2d linearly projects each patch to embed_dim

  3. Learnable positional embeddings are added for spatial information

Input: (B, C, H, W) images

Output: (B, N, embed_dim), where N = (H/patch_size) * (W/patch_size)

Parameters:
  • img_size – Image size (assumes square), e.g., 32 for CIFAR

  • patch_size – Size of each patch (typically 4, 8, or 16)

  • in_channels – Number of input channels (3 for RGB)

  • embed_dim – Output dimension for each patch token

Usage with DiT:

patch_embed = PatchEmbed(img_size=32, patch_size=4, in_channels=3, embed_dim=256)
tokens = patch_embed(images)  # (B, 64, 256) for a 32x32 image with patch_size=4
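The patch embedding can be sketched in NumPy as a reshape plus a matrix multiply: a Conv2d with kernel_size = stride = patch_size applies the same linear map to every non-overlapping patch, so flattening patches in row-major order and multiplying by a (C·p·p, embed_dim) weight gives the same tokens. patchify and patch_embed are illustrative names, and random weights stand in for learned ones.

```python
import numpy as np


def patchify(images, patch_size):
    """(B, C, H, W) -> (B, N, C*p*p) with N = (H/p)*(W/p), row-major patch order."""
    b, c, h, w = images.shape
    p = patch_size
    x = images.reshape(b, c, h // p, p, w // p, p)
    x = x.transpose(0, 2, 4, 1, 3, 5)          # (B, H/p, W/p, C, p, p)
    return x.reshape(b, (h // p) * (w // p), c * p * p)


def patch_embed(images, weight, bias, pos_embed, patch_size):
    """Linear projection of each flattened patch plus positional embeddings.

    Equivalent to Conv2d(kernel_size=stride=patch_size) followed by flattening.
    """
    tokens = patchify(images, patch_size) @ weight + bias   # (B, N, embed_dim)
    return tokens + pos_embed                               # broadcast over B


rng = np.random.default_rng(0)
img_size, patch, channels, embed_dim = 32, 4, 3, 256
n_patches = (img_size // patch) ** 2
weight = rng.normal(scale=0.02, size=(channels * patch * patch, embed_dim))
bias = np.zeros(embed_dim)
pos_embed = rng.normal(scale=0.02, size=(1, n_patches, embed_dim))
images = rng.random((2, channels, img_size, img_size))
tokens = patch_embed(images, weight, bias, pos_embed, patch)
print(tokens.shape)
```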

forward(x)[source]#
class world_models.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]#

Bases: Module

Reconstruct image-like tensors from patch-token sequences.

The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.
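The inverse mapping can be sketched in NumPy, assuming a square patch grid and row-major patch order. With kernel_size = stride = patch_size, a transposed convolution is equivalent to a per-token linear map to C·p·p pixel values followed by folding the patches back into place. unpatchify and patch_unembed are illustrative names; the identity weight in the demo simply places each token's values into its patch.

```python
import numpy as np


def unpatchify(tokens, patch_size, channels, grid):
    """(B, N, C*p*p) -> (B, C, H, W); inverse of row-major patch flattening."""
    b = tokens.shape[0]
    p = patch_size
    x = tokens.reshape(b, grid, grid, channels, p, p)
    x = x.transpose(0, 3, 1, 4, 2, 5)            # (B, C, H/p, p, W/p, p)
    return x.reshape(b, channels, grid * p, grid * p)


def patch_unembed(tokens, weight, bias, patch_size, channels):
    """Project tokens (B, N, D) to patch pixels and fold them into an image."""
    grid = int(round(tokens.shape[1] ** 0.5))   # assumes a square patch grid
    pixels = tokens @ weight + bias              # (B, N, C*p*p)
    return unpatchify(pixels, patch_size, channels, grid)


rng = np.random.default_rng(0)
channels, patch, grid = 3, 4, 8
dim = channels * patch * patch
tokens = rng.random((2, grid * grid, dim))
image = patch_unembed(tokens, np.eye(dim), np.zeros(dim), patch, channels)
print(image.shape)  # (2, 3, 32, 32)
```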

forward(x)[source]#
class world_models.DDPM(timesteps, beta_start, beta_end, device)[source]#

Bases: object

Utility class implementing forward and reverse DDPM diffusion steps.

Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).
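The schedule terms a DDPM helper typically precomputes, together with the forward-noising and reverse-step formulas from Ho et al. (2020), can be sketched as follows. LinearDDPMSchedule is hypothetical: it assumes a linear beta schedule (using the beta_start and beta_end defaults from DiTConfig) and the variance choice sigma_t² = beta_t. The DDPM class here may differ in either respect.

```python
import numpy as np


class LinearDDPMSchedule:
    """Hypothetical sketch of the quantities a DDPM helper precomputes."""

    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        self.betas = np.linspace(beta_start, beta_end, timesteps)
        self.alphas = 1.0 - self.betas
        self.alpha_bars = np.cumprod(self.alphas)

    def q_sample(self, x_start, t, noise):
        """Forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps."""
        ab = self.alpha_bars[t]
        return np.sqrt(ab) * x_start + np.sqrt(1.0 - ab) * noise

    def p_mean(self, x_t, t, eps_pred):
        """Mean of p(x_{t-1} | x_t) given the model's noise prediction."""
        coef = self.betas[t] / np.sqrt(1.0 - self.alpha_bars[t])
        return (x_t - coef * eps_pred) / np.sqrt(self.alphas[t])

    def p_sample(self, x_t, t, eps_pred, rng):
        """One reverse step with sigma_t^2 = beta_t; no noise is added at t=0."""
        mean = self.p_mean(x_t, t, eps_pred)
        if t == 0:
            return mean
        return mean + np.sqrt(self.betas[t]) * rng.normal(size=x_t.shape)


sched = LinearDDPMSchedule()
x0 = np.ones(4)
eps = np.random.default_rng(0).normal(size=4)
x_t = sched.q_sample(x0, 999, eps)
print(f"alpha_bar at the last step: {sched.alpha_bars[-1]:.1e}")
```

With a perfect noise prediction, the reverse step at t = 0 recovers x_0 exactly, and the final alpha_bar is close to zero, so x_T is almost pure Gaussian noise.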

q_sample(x_start, t, noise=None)[source]#
p_sample(model, x_t, t)[source]#
sample(model, n, img_size, channels)[source]#
class world_models.ActorCriticNetwork(obs_channels=3, action_dim=18, channels=(32, 32, 64, 64), lstm_dim=512)[source]#

Bases: Module

Actor-Critic network for DIAMOND RL training. Shared CNN-LSTM trunk with separate policy and value heads.

Parameters:
  • obs_channels (int)

  • action_dim (int)

  • channels (Tuple[int, ...])

  • lstm_dim (int)

forward(obs, hidden_state=None)[source]#

Forward pass of actor-critic network.

Parameters:
  • obs (Tensor) – Observations [B, T, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

Returns:

(policy_logits, values, hidden_state)
  • policy_logits: [B, T, action_dim]

  • values: [B, T, 1]

  • hidden_state: (h, c)

Return type:

tuple

get_action(obs, hidden_state=None, deterministic=False)[source]#

Get action from a single observation.

Parameters:
  • obs (Tensor) – Single observation [B, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

  • deterministic (bool) – If True, take argmax; else sample

Returns:

(action, hidden_state)
  • action: Selected action [B]

  • hidden_state: (h, c)

Return type:

tuple

get_actions(obs, hidden_state=None, deterministic=False)[source]#

Batched version of get_action.

Parameters:
  • obs (Tensor) – Tensor of shape [B, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state tuple matching batch size

  • deterministic (bool) – If True, take argmax; else sample from policy

Returns:

(actions, hidden_state)
  • actions: LongTensor of shape [B]

  • hidden_state: Updated LSTM hidden state tuple

Return type:

tuple

get_value(obs, hidden_state=None)[source]#

Get value for a single observation.

Parameters:
  • obs (Tensor)

  • hidden_state (Tuple[Tensor, Tensor] | None)

Return type:

Tuple[Tensor, Tuple[Tensor, Tensor] | None]

init_hidden(batch_size, device)[source]#

Initialize LSTM hidden states.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

get_hidden_size()[source]#

Get LSTM hidden size.

Return type:

int

class world_models.RewardTerminationModel(obs_channels=3, action_dim=18, channels=(32, 32, 32, 32), lstm_dim=512, cond_dim=128)[source]#

Bases: Module

Reward and termination prediction model. CNN + LSTM architecture following DIAMOND paper specifications.

Parameters:
  • obs_channels (int) – Number of observation channels (3 for RGB)

  • action_dim (int) – Number of possible actions

  • channels (Tuple[int, ...]) – List of channel sizes for conv blocks

  • lstm_dim (int) – LSTM hidden dimension

  • cond_dim (int) – Conditioning dimension for adaptive norm

forward(obs, actions, hidden_state=None)[source]#

Forward pass of reward/termination model.

Parameters:
  • obs (Tensor) – Observations [B, T, C, H, W]

  • actions (Tensor) – Actions [B, T]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

Returns:

(reward_logits, termination_logits, hidden_state)
  • reward_logits: Reward predictions [B, T, 3], one logit per reward class (-1, 0, 1)

  • termination_logits: Termination predictions [B, T, 2]

  • hidden_state: Updated (h, c) hidden states

Return type:

tuple

predict(obs, actions, hidden_state=None)[source]#

Predict reward and termination for a single step.

Parameters:
  • obs (Tensor) – Single observation [B, C, H, W]

  • actions (Tensor) – Single action [B]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

Returns:

(reward, terminated, hidden_state)
  • reward: Predicted reward tensor with values in {-1, 0, 1}

  • terminated: Predicted termination flags (bool tensor)

  • hidden_state: Updated (h, c) hidden states

Return type:

tuple

init_hidden(batch_size, device)[source]#

Initialize LSTM hidden states.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]
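The three reward classes can be mapped to and from scalar rewards as sketched below. The sketch assumes rewards are clipped to their sign and that logit indices 0, 1, 2 correspond to -1, 0, 1, the order listed for forward(). Treating termination class 1 as "terminated" is also an assumption, and all three helpers are hypothetical.

```python
import numpy as np


def reward_to_class(rewards):
    """Clip rewards to their sign and map {-1, 0, 1} to class ids {0, 1, 2}."""
    return (np.sign(rewards) + 1).astype(np.int64)


def class_to_reward(reward_logits):
    """Map the argmax over the 3 reward logits back to a reward in {-1, 0, 1}."""
    return reward_logits.argmax(axis=-1) - 1


def logits_to_terminated(termination_logits):
    """Treat class 1 of the 2 termination logits as 'episode ended'."""
    return termination_logits.argmax(axis=-1) == 1


rewards = np.array([-3.0, 0.0, 0.5, 7.0])
print(reward_to_class(rewards))   # [0 1 2 2]
```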

world_models.sinusoidal_time_embedding(timesteps, dim)[source]#

Create sinusoidal timestep embeddings for diffusion conditioning.

This function generates positional-style embeddings for diffusion timesteps, following the same pattern as transformer positional encodings. The embeddings encode the noise level (t) and are used to condition the diffusion model.

Math:

embedding[t] = [sin(t/10000^(2i/d)), cos(t/10000^(2i/d))] for i in [0, d/2)

Parameters:
  • timesteps – Tensor of integer timesteps, shape (B,) or (B, 1)

  • dim – Embedding dimension (must be even)

Returns:

Tensor of shape (B, dim) with sinusoidal embeddings

Usage with DiT:

t = torch.tensor([0, 500, 1000])             # Timesteps
emb = sinusoidal_time_embedding(t, dim=256)  # (3, 256)

# Condition the model:
# - Pass the embedding through an MLP to form the timestep conditioning vector
# - Use that vector for adaptive normalization (AdaLN)
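A NumPy version of the formula above is sketched below. It places all sine components before all cosine components, a common layout; whether sinusoidal_time_embedding concatenates or interleaves them is not stated, so treat the ordering as an assumption.

```python
import numpy as np


def sinusoidal_embedding(timesteps, dim):
    """(B,) or (B, 1) integer timesteps -> (B, dim) embeddings; dim must be even.

    First half: sin(t * w_i), second half: cos(t * w_i),
    with w_i = 10000 ** (-2i / dim) for i in [0, dim/2).
    """
    t = np.asarray(timesteps, dtype=np.float64).reshape(-1, 1)  # (B, 1)
    half = dim // 2
    freqs = 10000.0 ** (-2.0 * np.arange(half) / dim)          # (half,)
    angles = t * freqs                                          # (B, half)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)


emb = sinusoidal_embedding([0, 500, 1000], dim=256)
print(emb.shape)  # (3, 256)
```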

class world_models.STTransformer(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]#

Bases: Module

Spatiotemporal Transformer for video modeling.

Stacks depth spatiotemporal blocks; each block interleaves spatial attention (over the patches within a frame) with temporal attention (over the frames at each patch position).
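The factorization can be sketched in NumPy for a single block: tokens are reshaped to (B, T, N, C), attention runs over the N patches of each frame, then over the T frames at each patch position with a causal mask. To keep it short, the sketch uses identity query/key/value projections and omits the MLP, normalization and multiple heads a real block would have; st_block and attend are illustrative names.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attend(x, causal=False):
    """Single-head attention with identity Q/K/V projections over axis -2."""
    scores = x @ np.swapaxes(x, -1, -2) / np.sqrt(x.shape[-1])  # (..., L, L)
    if causal:
        length = x.shape[-2]
        mask = np.triu(np.ones((length, length), dtype=bool), k=1)
        scores = np.where(mask, -np.inf, scores)  # hide future positions
    return softmax(scores) @ x


def st_block(x, num_frames, causal_time=True):
    """x: (B, T*N, C) -> (B, T*N, C) via a spatial then a temporal pass."""
    b, tn, c = x.shape
    n = tn // num_frames
    x = x.reshape(b, num_frames, n, c)
    x = x + attend(x)                          # spatial: over N within each frame
    xt = np.swapaxes(x, 1, 2)                  # (B, N, T, C)
    xt = xt + attend(xt, causal=causal_time)   # temporal: over T per position
    return np.swapaxes(xt, 1, 2).reshape(b, tn, c)


rng = np.random.default_rng(0)
T, N, C = 4, 9, 8
x = rng.normal(size=(2, T * N, C))
y = st_block(x, num_frames=T)
x2 = x.copy()
x2[:, -N:] += 1.0                              # perturb only the last frame
y2 = st_block(x2, num_frames=T)
print(y.shape, np.allclose(y[:, :-N], y2[:, :-N]))  # (2, 36, 8) True
```

Perturbing only the last frame leaves the outputs for earlier frames unchanged, which is the causal property the temporal mask provides.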

Parameters:
  • num_frames (int)

  • num_patches_per_frame (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • qk_scale (float | None)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

  • norm_layer (type[Module])

forward(x)[source]#
Parameters:

x (Tensor) – (B, T*N, C) where T is num_frames, N is num_patches_per_frame

Returns:

(B, T*N, C)

Return type:

Tensor

class world_models.MultiHeadSelfAttention(d, n_heads=2)[source]#

Bases: Module

Multi-head scaled dot-product self-attention over sequence tokens.

This module projects the input sequence into query/key/value heads, performs attention independently per head, and merges the heads back into the original feature dimension. It is used as a lightweight transformer attention block.
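The split-attend-merge pattern described above is sketched below in NumPy. The fused (d, 3d) QKV weight, the absence of biases and the output projection are assumptions about the layout; only d and n_heads come from the signature.

```python
import numpy as np


def multi_head_self_attention(x, w_qkv, w_out, n_heads):
    """x: (B, L, d) -> (B, L, d). w_qkv: (d, 3d), w_out: (d, d); d % n_heads == 0."""
    b, length, d = x.shape
    head_dim = d // n_heads
    q, k, v = np.split(x @ w_qkv, 3, axis=-1)  # each (B, L, d)

    def split_heads(t):  # (B, L, d) -> (B, H, L, d/H)
        return t.reshape(b, length, n_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)  # (B, H, L, L)
    scores -= scores.max(axis=-1, keepdims=True)               # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    heads = weights @ v                                          # (B, H, L, d/H)
    merged = heads.transpose(0, 2, 1, 3).reshape(b, length, d)   # concatenate heads
    return merged @ w_out


rng = np.random.default_rng(0)
d, n_heads = 16, 2
x = rng.normal(size=(3, 5, d))
w_qkv = rng.normal(size=(d, 3 * d)) / np.sqrt(d)
w_out = rng.normal(size=(d, d)) / np.sqrt(d)
out = multi_head_self_attention(x, w_qkv, w_out, n_heads)
print(out.shape)  # (3, 5, 16)
```

Without positional information, self-attention is permutation-equivariant: reordering the input tokens reorders the outputs in the same way.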

forward(x)[source]#
world_models.MultiHeadAttention#

alias of MultiHeadSelfAttention

class world_models.AdaLNNormalization(d_model, t_dim)[source]#

Bases: Module

Adaptive layer normalization conditioned on an external embedding.

The module applies RMS normalization and predicts per-channel scale/shift from a conditioning vector (for example diffusion timestep embeddings).
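One way to combine RMS normalization with a conditioning vector is sketched below: a linear layer maps t_emb to per-channel scale and shift, and the normalized activations become x_hat * (1 + scale) + shift. The (1 + scale) parameterization and zero-initialized modulation weights follow the adaLN-Zero idea from the DiT paper; AdaLNNormalization may parameterize the modulation differently, and ada_rms_norm is a hypothetical helper.

```python
import numpy as np


def ada_rms_norm(x, t_emb, w_mod, b_mod, eps=1e-6):
    """x: (B, L, D), t_emb: (B, t_dim) -> (B, L, D).

    A linear map of the conditioning vector yields per-channel (scale, shift);
    the normalized activations are modulated as x_hat * (1 + scale) + shift.
    """
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    x_hat = x / rms
    scale, shift = np.split(t_emb @ w_mod + b_mod, 2, axis=-1)  # each (B, D)
    return x_hat * (1.0 + scale[:, None, :]) + shift[:, None, :]


rng = np.random.default_rng(0)
B, L, D, t_dim = 2, 4, 8, 16
x = rng.normal(size=(B, L, D))
t_emb = rng.normal(size=(B, t_dim))
w_mod = np.zeros((t_dim, 2 * D))  # zero init: starts as plain RMS normalization
b_mod = np.zeros(2 * D)
y = ada_rms_norm(x, t_emb, w_mod, b_mod)
print(y.shape)
```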

forward(x, t_emb)[source]#
class world_models.RMSNorm(dim, eps=1e-06)[source]#

Bases: Module

Root Mean Square Layer Normalization with a learned gain parameter.

RMSNorm rescales activations using their RMS magnitude without centering, providing a lightweight normalization alternative to LayerNorm.
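The normalization is small enough to write out directly. In this sketch eps sits inside the square root (placement varies between implementations), and a fixed gain vector stands in for the learned parameter; rms_norm is an illustrative name.

```python
import numpy as np


def rms_norm(x, gain, eps=1e-6):
    """y = x / sqrt(mean(x^2) + eps) * gain, computed over the last axis.

    Unlike LayerNorm, no mean is subtracted and there is no bias term.
    """
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * gain


x = np.array([[3.0, -4.0, 0.0, 0.0]])  # mean(x^2) = 25 / 4, so RMS = 2.5
y = rms_norm(x, gain=np.ones(4))
print(y.round(3))
```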

forward(x)[source]#
class world_models.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]#

Bases: object

Model-predictive controller using Cross-Entropy Method (CEM) with RSSM.

Plans actions by optimizing a sequence of future actions in the RSSM’s latent space. Uses Cross-Entropy Method to refine action sequences based on predicted returns.

Algorithm:
  1. Initialize Gaussian distribution over action sequences

  2. Sample N candidate action sequences

  3. Rollout each sequence in RSSM latent space

  4. Score by predicted cumulative rewards

  5. Keep top K candidates, fit Gaussian to them

  6. Repeat for T iterations

  7. Execute first action from best sequence
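The loop above can be sketched in NumPy with a toy scoring function standing in for "roll out in RSSM latent space and sum predicted rewards". cem_plan, toy_score and target are hypothetical. The sketch clips actions to [-1, 1], refits a diagonal Gaussian to the top candidates each iteration, and returns the first step of the final mean, one common reading of step 7; RSSMPolicy might instead return the first action of the single best sequence.

```python
import numpy as np


def cem_plan(score_fn, horizon, action_dim, num_candidates=1000,
             num_iterations=8, top_candidates=100, rng=None):
    """Return the first action of the refined plan.

    score_fn maps candidates (N, H, action_dim) to predicted returns (N,).
    """
    rng = np.random.default_rng() if rng is None else rng
    mean = np.zeros((horizon, action_dim))
    std = np.ones((horizon, action_dim))
    for _ in range(num_iterations):
        noise = rng.normal(size=(num_candidates, horizon, action_dim))
        candidates = np.clip(mean + std * noise, -1.0, 1.0)  # assume actions in [-1, 1]
        returns = score_fn(candidates)                        # (N,)
        elite = candidates[np.argsort(returns)[-top_candidates:]]  # top K
        mean, std = elite.mean(axis=0), elite.std(axis=0)     # refit the Gaussian
    return mean[0]


# Toy return: highest when every action in the sequence equals `target`.
target = np.array([0.5, -0.3])


def toy_score(candidates):
    return -((candidates - target) ** 2).sum(axis=(1, 2))


action = cem_plan(toy_score, horizon=3, action_dim=2, num_iterations=10,
                  rng=np.random.default_rng(0))
print(action.round(2))
```

The demo uses a short horizon so the toy problem converges quickly; with a learned model, score_fn would be the expensive part, which is why planning in a compact latent space matters.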

Why latent space planning?
  • Images are high-dimensional; latent states are compact

  • Enables thousands of rollouts in parallel

  • Rewards are predicted directly from latent states, so candidate rollouts never have to be decoded back into images

Parameters:
  • model – RSSM instance for latent dynamics

  • planning_horizon – Number of future steps to plan (H)

  • num_candidates – Number of action sequences to sample (N)

  • num_iterations – CEM refinement iterations (T)

  • top_candidates – Number of best candidates to keep (K)

  • device – torch device

Usage with the PlaNet agent:

policy = RSSMPolicy(
    model=rssm,
    planning_horizon=12,
    num_candidates=1000,
    num_iterations=8,
    top_candidates=100,
    device='cuda',
)

policy.reset()
action = policy.poll(observation)  # (1, action_dim)

# For continuous control:
next_obs, reward, done, info = env.step(action)

Comparison with Dreamer:
  • RSSMPolicy: Online planning; chooses each action by running CEM at every step

  • DreamerActor: Trains an actor network to predict actions from latent states

  • Dreamer tends to be more sample-efficient on complex tasks and acts with a single forward pass; CEM needs no policy training but repeats its search for every action

reset()[source]#
poll(observation, explore=False)[source]#
class world_models.IRISActor(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Actor network for the IRIS (Imagination with auto-Regression over an Inner Speech) policy.

Takes reconstructed frames as input and outputs action logits for policy control. Uses a CNN feature extractor followed by an LSTM for temporal processing. Supports a burn-in mechanism for initializing the hidden state with context frames.

Architecture:
  • CNN: Extracts features from input frames (3x64x64 -> 512)

  • LSTM: Processes temporal sequences with configurable layers

  • Linear: Maps hidden states to action logits

Parameters:
  • action_size (int) – Number of discrete actions.

  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

action_size#

Number of discrete actions.

Type:

int

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

forward(frames, hidden_state=None, burn_in_frames=None)[source]#

Forward pass through actor.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W) or (B, C, H, W)

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple for LSTM state

  • burn_in_frames (Tensor | None) – Frames to use for initializing hidden state

Returns:

(action_logits, hidden_state)
  • action_logits: Action logits (B, T, action_size) or (B, action_size)

  • hidden_state: Updated (h, c) tuple

Return type:

tuple

init_hidden_state(batch_size, device)[source]#

Initialize LSTM hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

get_action(frame, temperature=1.0, deterministic=False)[source]#

Get action from a single frame.

Parameters:
  • frame (Tensor) – Single frame (B, C, H, W)

  • temperature (float) – Softmax temperature (higher = more random)

  • deterministic (bool) – If True, return argmax; else sample

Returns:

Selected action indices (B,)

Return type:

Tensor
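The effect of temperature and deterministic can be sketched on raw logits as below. select_action is a hypothetical helper; it assumes temperature divides the logits before the softmax, the usual convention, and samples with one inverse-CDF draw per batch row.

```python
import numpy as np


def select_action(logits, temperature=1.0, deterministic=False, rng=None):
    """logits: (B, action_size) -> action indices (B,)."""
    if deterministic:
        return logits.argmax(axis=-1)             # temperature has no effect
    rng = np.random.default_rng() if rng is None else rng
    scaled = logits / temperature                 # >1 flattens, <1 sharpens
    scaled -= scaled.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Inverse-CDF sampling: count how many cumulative probabilities lie below u.
    u = rng.random((logits.shape[0], 1))
    idx = (probs.cumsum(axis=-1) < u).sum(axis=-1)
    return np.minimum(idx, logits.shape[-1] - 1)  # guard against rounding at 1.0


logits = np.array([[2.0, 0.5, -1.0], [0.0, 0.0, 3.0]])
print(select_action(logits, deterministic=True))  # [0 2]
```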

class world_models.IRISCritic(hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Critic network for IRIS value estimation.

Estimates the value function for given frame sequences. Shares the CNN feature extractor and LSTM backbone with the actor for efficiency, but has a separate value head for estimating expected cumulative rewards.

Architecture:
  • CNN: Shared feature extractor with actor (3x64x64 -> 512)

  • LSTM: Temporal processing with same architecture as actor

  • Linear: Maps hidden states to scalar values

Parameters:
  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple


Parameters:
  • hidden_size (int)

  • num_layers (int)

  • frame_shape (Tuple[int, int, int])

forward(frames, hidden_state=None)[source]#

Forward pass through critic.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W)

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple

Returns:

(values, hidden_state)
  • values: Value estimates (B, T)

  • hidden_state: Updated (h, c) tuple

Return type:

tuple

init_hidden_state(batch_size, device)[source]#

Initialize LSTM hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

class world_models.IRISPolicy(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Combined policy module for IRIS (Imagination with auto-Regression over an Inner Speech).

Provides a unified interface for actor-only or actor-critic policies. Used in the IRIS algorithm where the actor generates actions from reconstructed frames and the critic estimates value functions for training.

Parameters:
  • action_size (int) – Number of discrete actions.

  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

actor#

The actor network for action selection.

Type:

IRISActor

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

Example

>>> policy = IRISPolicy(
...     action_size=18,
...     hidden_size=512,
...     num_layers=4,
...     frame_shape=(3, 64, 64)
... )
>>> action = policy.act(frame, temperature=1.0, deterministic=False)
forward(frames)[source]#

Get action logits from frames.

Parameters:

frames (Tensor)

Return type:

Tensor

act(frame, temperature=1.0, deterministic=False)[source]#

Sample action from policy.

Parameters:
  • frame (Tensor)

  • temperature (float)

  • deterministic (bool)

Return type:

Tensor

init_hidden(batch_size, device)[source]#

Initialize hidden state.

Parameters:
  • batch_size (int)

  • device (device)

class world_models.CNNFeatureExtractor(frame_shape=(3, 64, 64), output_size=512)[source]#

Bases: Module

CNN feature extractor shared between actor and critic networks.

Processes input frames through four stride-2 convolutional layers, each followed by a ReLU, to produce fixed-size feature vectors; a final linear layer projects the flattened features to output_size.

Architecture:
  • Conv layers: 32 -> 64 -> 128 -> 256 channels

  • Each conv has stride=2 for spatial downsampling

  • Final linear layer projects to desired output dimension
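The size of the flattened vector fed to the final linear layer depends on kernel size and padding, which are not documented here. The arithmetic below applies the standard Conv2d output-size formula to 64x64 frames for two hypothetical settings (kernel 4 with padding 1 or 0); conv2d_out and flattened_features are illustrative helpers.

```python
def conv2d_out(size, kernel, stride, padding):
    """Spatial output size of a Conv2d: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(frame_size, channels, kernel, stride, padding):
    """Return (flattened feature count, final spatial size) after the conv stack."""
    size = frame_size
    for _ in channels:
        size = conv2d_out(size, kernel, stride, padding)
    return channels[-1] * size * size, size


channels = (32, 64, 128, 256)
print(flattened_features(64, channels, kernel=4, stride=2, padding=1))  # (4096, 4)
print(flattened_features(64, channels, kernel=4, stride=2, padding=0))  # (1024, 2)
```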

Parameters:
  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

  • output_size (int) – Size of output feature vector (default: 512).

frame_shape#

Input frame shape.

Type:

tuple

output_size#

Output feature dimension.

Type:

int


Parameters:
  • frame_shape (Tuple[int, int, int])

  • output_size (int)

forward(x)[source]#

Extract features from frames.

Parameters:

x (Tensor) – Frames (B, C, H, W)

Returns:

Feature vectors (B, output_size)

Return type:

Tensor

class world_models.DreamerConfig[source]#

Bases: object

Configuration container for Dreamer training, evaluation, and environment setup.

This class centralizes environment backend selection (DMC/Gym/MuJoCo/Robotics/Unity/Brax), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.

class world_models.JEPAConfig[source]#

Bases: object

Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.

to_dict()[source]#
Return type:

Dict[str, Dict[str, Any]]

class world_models.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#

Bases: object

Default configuration values for Diffusion Transformer (DiT) training.

The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.

Parameters:
  • DATASET (str)

  • BATCH (int)

  • EPOCHS (int)

  • LR (float)

  • IMG_SIZE (int)

  • CHANNELS (int)

  • PATCH (int)

  • WIDTH (int)

  • DEPTH (int)

  • HEADS (int)

  • DROP (float)

  • BETA_START (float)

  • BETA_END (float)

  • TIMESTEPS (int)

  • EMA (bool)

  • EMA_DECAY (float)

  • WORKDIR (str)

  • ROOT_PATH (str)

DATASET: str = 'CIFAR10'#
BATCH: int = 128#
EPOCHS: int = 3#
LR: float = 0.0002#
IMG_SIZE: int = 32#
CHANNELS: int = 3#
PATCH: int = 4#
WIDTH: int = 384#
DEPTH: int = 6#
HEADS: int = 6#
DROP: float = 0.1#
BETA_START: float = 0.0001#
BETA_END: float = 0.02#
TIMESTEPS: int = 1000#
EMA: bool = True#
EMA_DECAY: float = 0.999#
WORKDIR: str = './dit_demo'#
ROOT_PATH: str = './data'#
world_models.get_dit_config(**overrides)[source]#

Returns a DiTConfig instance with default values overridden by the provided keyword arguments.

Example usage:

cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)
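get_dit_config is described as applying keyword overrides on top of the DiTConfig defaults. Assuming DiTConfig is a dataclass (its signature suggests this, but it is an assumption), that behavior can be sketched with dataclasses.replace, which also rejects misspelled field names. The DiTConfig below is a trimmed local stand-in, and get_dit_config_sketch and num_patch_tokens are hypothetical helpers. The sketch also checks a derived quantity: with IMG_SIZE=32 and PATCH=4, the transformer sees (32 / 4)² = 64 patch tokens per image, and WIDTH=384 split across HEADS=6 gives 64 channels per head.

```python
from dataclasses import dataclass, replace


@dataclass
class DiTConfig:
    # Trimmed stand-in mirroring a few of the documented defaults.
    BATCH: int = 128
    EPOCHS: int = 3
    LR: float = 0.0002
    IMG_SIZE: int = 32
    PATCH: int = 4
    WIDTH: int = 384
    HEADS: int = 6


def get_dit_config_sketch(**overrides) -> DiTConfig:
    # dataclasses.replace raises TypeError for unknown field names,
    # so a typo such as BATCHSIZE=64 fails loudly instead of being ignored.
    return replace(DiTConfig(), **overrides)


def num_patch_tokens(cfg: DiTConfig) -> int:
    # Non-overlapping PATCH x PATCH patches over a square IMG_SIZE image.
    if cfg.IMG_SIZE % cfg.PATCH != 0:
        raise ValueError("IMG_SIZE must be divisible by PATCH")
    return (cfg.IMG_SIZE // cfg.PATCH) ** 2


cfg = get_dit_config_sketch(BATCH=64, EPOCHS=10, LR=1e-3)
print(cfg.BATCH, cfg.EPOCHS, cfg.LR)  # 64 10 0.001
print(num_patch_tokens(cfg))          # 64
print(cfg.WIDTH // cfg.HEADS)         # 64 channels per attention head
```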

class world_models.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', seed: int = 0, obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, device: str = <factory>, log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, operator_state_dim: int = 32, operator_action_dim: int = 4)[source]#

Bases: object

Parameters:
  • preset (str | None)

  • game (str)

  • seed (int)

  • obs_size (int)

  • frameskip (int)

  • max_noop (int)

  • terminate_on_life_loss (bool)

  • reward_clip (List[int])

  • num_conditioning_frames (int)

  • diffusion_channels (List[int])

  • diffusion_res_blocks (int)

  • diffusion_cond_dim (int)

  • sigma_data (float)

  • sigma_min (float)

  • sigma_max (float)

  • rho (int)

  • p_mean (float)

  • p_std (float)

  • sampling_method (str)

  • num_sampling_steps (int)

  • reward_channels (List[int])

  • reward_res_blocks (int)

  • reward_cond_dim (int)

  • reward_lstm_dim (int)

  • burn_in_length (int)

  • actor_channels (List[int])

  • actor_res_blocks (int)

  • actor_lstm_dim (int)

  • num_epochs (int)

  • training_steps_per_epoch (int)

  • batch_size (int)

  • environment_steps_per_epoch (int)

  • epsilon_greedy (float)

  • imagination_horizon (int)

  • discount_factor (float)

  • entropy_weight (float)

  • lambda_returns (float)

  • learning_rate (float)

  • adam_epsilon (float)

  • weight_decay_diffusion (float)

  • weight_decay_reward (float)

  • weight_decay_actor (float)

  • device (str)

  • log_interval (int)

  • eval_interval (int)

  • save_interval (int)

  • operator_state_dim (int)

  • operator_action_dim (int)

preset: str | None = None#
game: str = 'Breakout-v5'#
seed: int = 0#
obs_size: int = 64#
frameskip: int = 4#
max_noop: int = 30#
terminate_on_life_loss: bool = True#
reward_clip: List[int]#
num_conditioning_frames: int = 4#
diffusion_channels: List[int]#
diffusion_res_blocks: int = 2#
diffusion_cond_dim: int = 256#
sigma_data: float = 0.5#
sigma_min: float = 0.002#
sigma_max: float = 80.0#
rho: int = 7#
p_mean: float = -0.4#
p_std: float = 1.2#
sampling_method: str = 'euler'#
num_sampling_steps: int = 3#
reward_channels: List[int]#
reward_res_blocks: int = 2#
reward_cond_dim: int = 128#
reward_lstm_dim: int = 512#
burn_in_length: int = 4#
actor_channels: List[int]#
actor_res_blocks: int = 1#
actor_lstm_dim: int = 512#
num_epochs: int = 1000#
training_steps_per_epoch: int = 400#
batch_size: int = 32#
environment_steps_per_epoch: int = 100#
epsilon_greedy: float = 0.01#
imagination_horizon: int = 15#
discount_factor: float = 0.985#
entropy_weight: float = 0.001#
lambda_returns: float = 0.95#
learning_rate: float = 0.0001#
adam_epsilon: float = 1e-08#
weight_decay_diffusion: float = 0.01#
weight_decay_reward: float = 0.01#
weight_decay_actor: float = 0.0#
device: str#
log_interval: int = 10#
eval_interval: int = 50#
save_interval: int = 100#
operator_state_dim: int = 32#
operator_action_dim: int = 4#
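The sigma_data, sigma_min, sigma_max, rho, p_mean and p_std fields share their names with the parameters of the EDM diffusion formulation (Karras et al., 2022). Assuming DiamondConfig follows EDM (an inference from the field names, not something this reference confirms), training noise levels are drawn log-normally with ln σ ~ N(p_mean, p_std²), and the sampler visits num_sampling_steps noise levels spaced by the rho-warped schedule sketched below, finishing at σ = 0. Both function names are illustrative.

```python
import math
import random
from typing import List, Optional


def karras_sigmas(n_steps: int, sigma_min: float = 0.002,
                  sigma_max: float = 80.0, rho: float = 7.0) -> List[float]:
    """EDM-style noise levels from sigma_max down to sigma_min, then 0."""
    if n_steps == 1:
        return [sigma_max, 0.0]
    max_inv_rho = sigma_max ** (1.0 / rho)
    min_inv_rho = sigma_min ** (1.0 / rho)
    sigmas = [
        # Interpolate linearly in sigma**(1/rho) space, then map back.
        (max_inv_rho + i / (n_steps - 1) * (min_inv_rho - max_inv_rho)) ** rho
        for i in range(n_steps)
    ]
    return sigmas + [0.0]  # EDM samplers end at sigma = 0


def sample_training_sigma(p_mean: float = -0.4, p_std: float = 1.2,
                          rng: Optional[random.Random] = None) -> float:
    """Log-normal training noise level: ln(sigma) ~ N(p_mean, p_std**2)."""
    rng = rng or random.Random()
    return math.exp(rng.gauss(p_mean, p_std))


sigmas = karras_sigmas(3)  # num_sampling_steps = 3
print([round(s, 4) for s in sigmas])
print(sample_training_sigma(rng=random.Random(0)))
```

A larger rho concentrates steps near sigma_min, where fine detail is resolved; with only three steps the schedule jumps quickly from sigma_max to the low-noise regime.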
class world_models.IRISConfig[source]#

Bases: object

Configuration for IRIS (Imagination with auto-Regression over an Inner Speech).

Based on the paper “Transformers are Sample-Efficient World Models”. Implements a discrete autoencoder and an autoregressive Transformer world model for sample-efficient RL.

get_frame_shape()[source]#
get_autoencoder_config()[source]#
get_transformer_config()[source]#
get_rl_config()[source]#
class world_models.GenieConfig(num_frames=8, image_size=32, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=256, action_encoder_depth=4, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#

Bases: object

Configuration for Genie model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • tokenizer_num_heads (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_num_heads (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

  • sample_temperature (float)

  • maskgit_steps (int)

num_frames: int = 8#
image_size: int = 32#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 256#
tokenizer_decoder_dim: int = 512#
tokenizer_encoder_depth: int = 4#
tokenizer_decoder_depth: int = 8#
tokenizer_num_heads: int = 8#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 256#
action_encoder_depth: int = 4#
action_num_heads: int = 8#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 4#
learning_rate: float = 3e-05#
weight_decay: float = 0.0001#
warmup_steps: int = 5000#
max_steps: int = 125000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
sample_temperature: float = 2.0#
maskgit_steps: int = 25#
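mask_prob_min and mask_prob_max bound the fraction of video tokens hidden from the dynamics model during MaskGIT-style training. A minimal sketch, assuming one masking rate is drawn uniformly per sequence from [mask_prob_min, mask_prob_max) and each token is then masked independently at that rate; the function name and the (batch, frames, tokens) layout are assumptions, not this library's implementation.

```python
from typing import Optional

import torch


def sample_token_mask(batch: int, frames: int, tokens_per_frame: int,
                      mask_prob_min: float = 0.5, mask_prob_max: float = 1.0,
                      generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Boolean mask of shape (batch, frames, tokens); True marks a masked token."""
    # One masking rate per sequence, uniform in [mask_prob_min, mask_prob_max).
    u = torch.rand(batch, 1, 1, generator=generator)
    rates = mask_prob_min + (mask_prob_max - mask_prob_min) * u
    # Each token is masked independently with its sequence's rate.
    noise = torch.rand(batch, frames, tokens_per_frame, generator=generator)
    return noise < rates


g = torch.Generator().manual_seed(0)
mask = sample_token_mask(2, 8, 16, generator=g)
print(mask.shape)                     # torch.Size([2, 8, 16])
print(mask.float().mean(dim=(1, 2)))  # per-sequence masked fraction
```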
class world_models.GenieSmallConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=512, action_encoder_depth=8, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=2, learning_rate=0.0001, weight_decay=0.0001, warmup_steps=1000, max_steps=50000, mask_prob_min=0.5, mask_prob_max=1.0)[source]#

Bases: object

Small configuration for development/testing.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • tokenizer_num_heads (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_num_heads (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 256#
tokenizer_decoder_dim: int = 512#
tokenizer_encoder_depth: int = 4#
tokenizer_decoder_depth: int = 8#
tokenizer_num_heads: int = 8#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 512#
action_encoder_depth: int = 8#
action_num_heads: int = 8#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 2#
learning_rate: float = 0.0001#
weight_decay: float = 0.0001#
warmup_steps: int = 1000#
max_steps: int = 50000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
class world_models.STTransformerConfig(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: object

Configuration for Spatiotemporal Transformer.

Parameters:
  • num_frames (int)

  • num_patches_per_frame (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

num_frames: int = 16#
num_patches_per_frame: int = 256#
dim: int = 768#
depth: int = 12#
num_heads: int = 12#
mlp_ratio: float = 4.0#
qkv_bias: bool = True#
drop_rate: float = 0.0#
attn_drop_rate: float = 0.0#
drop_path_rate: float = 0.0#
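STTransformerConfig describes a Transformer over num_frames × num_patches_per_frame tokens (16 × 256 = 4,096 by default). A common way to keep this tractable, and the one described for Genie's ST-transformer, is to factorize attention: spatial attention among the patches of a single frame, then causal temporal attention across frames at each patch position. The block below is a minimal sketch of that idea built on torch.nn.MultiheadAttention; FactorizedSTBlock is a hypothetical name, it is not this library's implementation, and it omits dropout, drop-path and qkv_bias handling.

```python
import torch
from torch import nn


class FactorizedSTBlock(nn.Module):
    """Pre-norm block: spatial attention, causal temporal attention, then an MLP."""

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm_s = nn.LayerNorm(dim)
        self.attn_s = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm_t = nn.LayerNorm(dim)
        self.attn_t = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm_m = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, p, d = x.shape  # (batch, frames, patches per frame, dim)
        # Spatial attention: each frame attends over its own patches.
        s = self.norm_s(x).reshape(b * t, p, d)
        s, _ = self.attn_s(s, s, s, need_weights=False)
        x = x + s.reshape(b, t, p, d)
        # Temporal attention: each patch position attends to current and past frames.
        tt = self.norm_t(x).permute(0, 2, 1, 3).reshape(b * p, t, d)
        causal = torch.ones(t, t, device=x.device).triu(diagonal=1).bool()  # True = blocked
        tt, _ = self.attn_t(tt, tt, tt, attn_mask=causal, need_weights=False)
        x = x + tt.reshape(b, p, t, d).permute(0, 2, 1, 3)
        # Position-wise feed-forward network.
        return x + self.mlp(self.norm_m(x))


block = FactorizedSTBlock(dim=64, num_heads=4)
x = torch.randn(2, 4, 16, 64)
print(block(x).shape)  # torch.Size([2, 4, 16, 64])
```

Factorizing reduces the attention cost per block from one (T·P)² interaction to T·P² + P·T², which is what makes 4,096-token clips practical.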
class world_models.VideoTokenizerConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False, ema_decay=0.99, commitment_weight=0.25)[source]#

Bases: object

Configuration for Video Tokenizer.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • use_ema (bool)

  • ema_decay (float)

  • commitment_weight (float)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
encoder_dim: int = 512#
decoder_dim: int = 1024#
encoder_depth: int = 12#
decoder_depth: int = 20#
num_heads: int = 16#
patch_size: int = 4#
vocab_size: int = 1024#
embedding_dim: int = 32#
use_ema: bool = False#
ema_decay: float = 0.99#
commitment_weight: float = 0.25#
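vocab_size, embedding_dim and commitment_weight are the usual VQ-VAE quantizer settings: each encoder output is snapped to its nearest codebook vector, and a commitment term scaled by commitment_weight keeps encoder outputs close to their assigned codes. A minimal sketch of the non-EMA path (use_ema=False) with a straight-through gradient estimator; the class name and the loss composition follow the original VQ-VAE recipe and are assumptions about how this library implements it.

```python
import torch
import torch.nn.functional as F
from torch import nn


class VectorQuantizer(nn.Module):
    def __init__(self, vocab_size: int = 1024, embedding_dim: int = 32,
                 commitment_weight: float = 0.25):
        super().__init__()
        self.codebook = nn.Embedding(vocab_size, embedding_dim)
        nn.init.uniform_(self.codebook.weight, -1.0 / vocab_size, 1.0 / vocab_size)
        self.commitment_weight = commitment_weight

    def forward(self, z: torch.Tensor):
        # z: (..., embedding_dim) continuous encoder outputs.
        flat = z.reshape(-1, z.shape[-1])
        # Index of the nearest codebook vector under Euclidean distance.
        indices = torch.cdist(flat, self.codebook.weight).argmin(dim=-1)
        z_q = self.codebook(indices).view_as(z)
        # Codebook loss moves codes toward encoder outputs; the commitment
        # loss moves encoder outputs toward their (frozen) codes.
        codebook_loss = F.mse_loss(z_q, z.detach())
        commitment_loss = F.mse_loss(z, z_q.detach())
        loss = codebook_loss + self.commitment_weight * commitment_loss
        # Straight-through: the forward pass uses z_q, gradients flow to z unchanged.
        z_q = z + (z_q - z).detach()
        return z_q, indices.view(z.shape[:-1]), loss


vq = VectorQuantizer()
z = torch.randn(2, 16, 32, requires_grad=True)
z_q, idx, loss = vq(z)
loss.backward()
print(z_q.shape, idx.shape)  # torch.Size([2, 16, 32]) torch.Size([2, 16])
```

With use_ema=True, implementations typically drop the codebook loss and instead update codes as exponential moving averages (decay ema_decay) of their assigned encoder outputs.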
class world_models.LatentActionModelConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=1024, encoder_depth=20, num_heads=16, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#

Bases: object

Configuration for Latent Action Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • encoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
encoder_dim: int = 1024#
encoder_depth: int = 20#
num_heads: int = 16#
patch_size: int = 16#
vocab_size: int = 8#
embedding_dim: int = 32#
commitment_weight: float = 1.0#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
class world_models.DynamicsModelConfig(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: object

Configuration for Dynamics Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

num_frames: int = 16#
image_size: int = 64#
vocab_size: int = 1024#
embedding_dim: int = 32#
action_vocab_size: int = 8#
dim: int = 5120#
depth: int = 48#
num_heads: int = 36#
patch_size: int = 4#
mlp_ratio: float = 4.0#
qkv_bias: bool = True#
drop_rate: float = 0.0#
attn_drop_rate: float = 0.0#
drop_path_rate: float = 0.0#
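The video-model configs above share image_size=64 and num_frames=16 but patchify at different resolutions, which fixes how many tokens each component processes. The hypothetical helper below (not part of the API) makes the arithmetic explicit: the tokenizer and dynamics model use patch_size=4, giving a 16 × 16 grid of 256 tokens per frame, which matches STTransformerConfig.num_patches_per_frame=256, while the latent action model uses patch_size=16 and sees only 16 patches per frame.

```python
def tokens_per_frame(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping patch_size x patch_size patches in a square frame."""
    if image_size % patch_size != 0:
        raise ValueError(f"image_size={image_size} is not divisible by patch_size={patch_size}")
    side = image_size // patch_size
    return side * side


# (image_size, patch_size) defaults copied from the configs above.
defaults = {
    "VideoTokenizerConfig": (64, 4),
    "LatentActionModelConfig": (64, 16),
    "DynamicsModelConfig": (64, 4),
}
num_frames = 16
for name, (image_size, patch_size) in defaults.items():
    per_frame = tokens_per_frame(image_size, patch_size)
    print(f"{name}: {per_frame} tokens/frame, {per_frame * num_frames} per clip")
# VideoTokenizerConfig: 256 tokens/frame, 4096 per clip
# LatentActionModelConfig: 16 tokens/frame, 256 per clip
# DynamicsModelConfig: 256 tokens/frame, 4096 per clip
```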
class world_models.OperatorABC[source]#

Bases: ABC

Abstract base class for operators that preprocess inputs for inference pipelines.

abstractmethod process(inputs)[source]#

Process raw inputs into standardized tensor format for model consumption.

Parameters:

inputs (Any) – Raw input data (dict, tensor, or other formats)

Returns:

Dict of processed tensors ready for model input

Return type:

Dict[str, Tensor]
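A new preprocessing step plugs in by subclassing OperatorABC and implementing process. The sketch below uses a local stand-in with the same abstract method so that it runs without TorchWM installed; with the library available you would subclass world_models.OperatorABC instead. VectorObsOperator and its 'obs' key are hypothetical.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict

import torch


class OperatorABC(ABC):
    """Local stand-in mirroring the documented interface."""

    @abstractmethod
    def process(self, inputs: Any) -> Dict[str, torch.Tensor]:
        ...


class VectorObsOperator(OperatorABC):
    """Hypothetical operator that standardizes a flat observation vector."""

    def __init__(self, eps: float = 1e-6):
        self.eps = eps

    def process(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        obs = torch.as_tensor(inputs["obs"], dtype=torch.float32)
        if obs.dim() == 1:
            obs = obs.unsqueeze(0)  # add a batch dimension
        # Zero mean and unit (sample) standard deviation per row.
        mean = obs.mean(dim=-1, keepdim=True)
        std = obs.std(dim=-1, keepdim=True)
        return {"obs": (obs - mean) / (std + self.eps)}


out = VectorObsOperator().process({"obs": [1.0, 2.0, 3.0]})
print(out["obs"].shape)  # torch.Size([1, 3])
```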

class world_models.DreamerOperator(image_size=64, action_dim=6)[source]#

Bases: OperatorABC

Operator for Dreamer model preprocessing: normalizes observations and encodes actions.

Parameters:
  • image_size (int)

  • action_dim (int)

process(inputs)[source]#

Process Dreamer inputs: image observation and action.

Expected inputs: {‘image’: PIL.Image or tensor, ‘action’: tensor or list}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

class world_models.JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)[source]#

Bases: OperatorABC

Operator for JEPA model preprocessing: handles image/video masking and patch processing.

Parameters:
  • image_size (int)

  • patch_size (int)

  • mask_ratio (float)

process(inputs)[source]#

Process JEPA inputs: images with masking.

Expected inputs: {‘images’: list of PIL Images or tensors, ‘mask’: optional tensor}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]
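With the defaults image_size=224 and patch_size=16, an image splits into (224 / 16)² = 196 patches, and mask_ratio=0.75 hides int(196 × 0.75) = 147 of them, leaving 49 visible. A minimal sketch of uniform random patch masking under those settings; JEPA variants often use block-structured masks instead, and how JEPAOperator actually chooses patches is not specified in this reference. random_patch_mask is a hypothetical helper.

```python
from typing import Optional

import torch


def random_patch_mask(image_size: int = 224, patch_size: int = 16,
                      mask_ratio: float = 0.75, batch: int = 1,
                      generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Boolean mask of shape (batch, num_patches); True marks a hidden patch."""
    num_patches = (image_size // patch_size) ** 2
    num_masked = int(num_patches * mask_ratio)
    # Rank random scores per sample and hide the num_masked lowest-scoring patches.
    scores = torch.rand(batch, num_patches, generator=generator)
    masked_idx = scores.argsort(dim=1)[:, :num_masked]
    mask = torch.zeros(batch, num_patches, dtype=torch.bool)
    mask[torch.arange(batch).unsqueeze(1), masked_idx] = True
    return mask


mask = random_patch_mask(batch=2)
print(mask.shape, mask.sum(dim=1).tolist())  # torch.Size([2, 196]) [147, 147]
```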

class world_models.IrisOperator(seq_length=512, vocab_size=32000)[source]#

Bases: OperatorABC

Operator for Iris transformer model: formats sequences and embeddings.

Parameters:
  • seq_length (int)

  • vocab_size (int)

process(inputs)[source]#

Process Iris inputs: token sequences and optional embeddings.

Expected inputs: {‘tokens’: list of ints or tensor, ‘embeddings’: optional tensor}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]
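IrisOperator is configured with a fixed seq_length=512 and vocab_size=32000, so token inputs of arbitrary length have to be brought to a fixed shape and checked against the vocabulary. A minimal sketch, assuming right-padding with pad id 0, truncation that keeps the most recent tokens, and a boolean attention mask marking real tokens; the pad id, the truncation side, the output keys and the name format_tokens are all assumptions.

```python
import torch


def format_tokens(tokens, seq_length: int = 512, vocab_size: int = 32000,
                  pad_id: int = 0) -> dict:
    ids = torch.as_tensor(tokens, dtype=torch.long).flatten()
    if ids.numel() and (ids.min() < 0 or ids.max() >= vocab_size):
        raise ValueError("token id outside [0, vocab_size)")
    ids = ids[-seq_length:]  # keep the most recent tokens when too long
    n = ids.numel()
    padded = torch.full((seq_length,), pad_id, dtype=torch.long)
    padded[:n] = ids
    attention_mask = torch.zeros(seq_length, dtype=torch.bool)
    attention_mask[:n] = True  # True marks a real (non-padding) token
    return {"tokens": padded.unsqueeze(0), "attention_mask": attention_mask.unsqueeze(0)}


out = format_tokens([5, 6, 7], seq_length=8)
print(out["tokens"])  # tensor([[5, 6, 7, 0, 0, 0, 0, 0]])
```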

class world_models.PlaNetOperator(state_dim=32, action_dim=4)[source]#

Bases: OperatorABC

Operator for PlaNet model preprocessing: encodes environment states and transitions.

Parameters:
  • state_dim (int)

  • action_dim (int)

process(inputs)[source]#

Process PlaNet inputs: state observations and actions.

Expected inputs: {‘obs’: tensor or image, ‘action’: tensor, ‘reward’: float, ‘done’: bool}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

world_models.get_operator(name, **kwargs)[source]#

Factory function to get inference operators by name.

Parameters:
  • name (str) – One of ‘dreamer’, ‘jepa’, ‘iris’, ‘planet’

  • **kwargs – Operator-specific configuration

Returns:

Configured OperatorABC instance

Example

>>> op = get_operator('dreamer', image_size=64, action_dim=6)
>>> processed = op.process({'image': image, 'action': action})
class world_models.RewardModel(belief_size, state_size, hidden_size, activation_function='relu')[source]#

Bases: Module

Predict scalar rewards from Dreamer latent belief and state vectors.

Implemented as an MLP used for model-based reward supervision and imagined rollout return estimation.

forward(belief, state)[source]#
class world_models.ValueModel(belief_size, state_size, hidden_size, activation_function='relu')[source]#

Bases: Module

Estimate scalar value from Dreamer latent belief and state vectors.

This MLP is trained on imagined returns and used for actor/value updates.

forward(belief, state)[source]#
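RewardModel and ValueModel share one shape contract: an MLP over the concatenated belief (deterministic) and state (stochastic) vectors that produces one scalar per sample. A minimal sketch of that contract; ScalarHead is a hypothetical name, and the depth (two hidden layers) and the activation lookup via getattr on torch.nn.functional are assumptions.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ScalarHead(nn.Module):
    """MLP mapping (belief, state) to a scalar, e.g. a reward or value estimate."""

    def __init__(self, belief_size: int, state_size: int, hidden_size: int,
                 activation_function: str = "relu"):
        super().__init__()
        self.act = getattr(F, activation_function)  # e.g. F.relu, F.elu
        self.fc1 = nn.Linear(belief_size + state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, 1)

    def forward(self, belief: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        x = torch.cat([belief, state], dim=-1)
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x))
        return self.out(x).squeeze(-1)  # one scalar per leading index


reward_head = ScalarHead(belief_size=200, state_size=30, hidden_size=200)
r = reward_head(torch.randn(8, 200), torch.randn(8, 30))
print(r.shape)  # torch.Size([8])
```

Because only the last dimension is consumed, the same head also works on imagined trajectories shaped (horizon, batch, ...).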

User-facing convenience APIs for TorchWM.

The lower-level modules remain available for research workflows, but this module collects the common discovery and construction paths behind small, predictable factory functions.

class world_models.api.EnvBackendSpec(name, factory_path, description='', aliases=())[source]#

Bases: NamedTuple

Metadata describing an environment backend available through make_env.

Parameters:
  • name (str)

  • factory_path (str)

  • description (str)

  • aliases (tuple[str, ...])

name: str#

Alias for field number 0

factory_path: str#

Alias for field number 1

description: str#

Alias for field number 2

aliases: tuple[str, ...]#

Alias for field number 3

class world_models.api.ModelSpec(name, import_path, config_path=None, description='', aliases=())[source]#

Bases: NamedTuple

Metadata describing a model available through create_model().

Parameters:
  • name (str)

  • import_path (str)

  • config_path (str | None)

  • description (str)

  • aliases (tuple[str, ...])

name: str#

Alias for field number 0

import_path: str#

Alias for field number 1

config_path: str | None#

Alias for field number 2

description: str#

Alias for field number 3

aliases: tuple[str, ...]#

Alias for field number 4

world_models.api.create_config(model, **overrides)[source]#

Create the default config object for model and apply overrides.

Examples

>>> cfg = create_config("dreamer", env="walker-walk", seed=7)
>>> cfg.env
'walker-walk'
Parameters:
  • model (str)

  • overrides (Any)

Return type:

Any

world_models.api.create_model(model, config=None, **overrides)[source]#

Instantiate a model or agent from a simple string name.

config is optional for models that define a config class. Keyword overrides are applied to the config when possible, otherwise they are passed directly to the underlying constructor/factory.

Examples

>>> agent = create_model("dreamer", env="walker-walk", total_steps=1000)
>>> genie = create_model("genie-small", image_size=32)
Parameters:
  • model (str)

  • config (Any | None)

  • overrides (Any)

Return type:

Any

world_models.api.get_env_backend_spec(name)[source]#

Return metadata for an environment backend name or alias.

Parameters:

name (str)

Return type:

EnvBackendSpec

world_models.api.get_model_spec(name)[source]#

Return metadata for a model name or alias.

Parameters:

name (str)

Return type:

ModelSpec
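get_model_spec accepts either a canonical name or an alias. One way such a lookup can work over ModelSpec records is sketched below; the registry contents, import paths and case-insensitive matching are hypothetical illustrations, not TorchWM's actual catalog.

```python
from typing import Dict, NamedTuple, Optional, Tuple


class ModelSpec(NamedTuple):
    name: str
    import_path: str
    config_path: Optional[str] = None
    description: str = ""
    aliases: Tuple[str, ...] = ()


# Hypothetical entries for illustration only.
_REGISTRY: Dict[str, ModelSpec] = {
    spec.name: spec
    for spec in (
        ModelSpec("dreamer", "example.models:DreamerAgent", aliases=("dreamerv1",)),
        ModelSpec("genie-small", "example.models:create_genie_small", aliases=("genie_small",)),
    )
}
# Map every canonical name and alias (lower-cased) to its canonical name.
_LOOKUP = {
    key.lower(): spec.name
    for spec in _REGISTRY.values()
    for key in (spec.name, *spec.aliases)
}


def get_model_spec(name: str) -> ModelSpec:
    try:
        return _REGISTRY[_LOOKUP[name.lower()]]
    except KeyError:
        known = ", ".join(sorted(_REGISTRY))
        raise KeyError(f"unknown model {name!r}; known models: {known}") from None


print(get_model_spec("DreamerV1").name)  # dreamer
```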

world_models.api.list_env_backends()[source]#

Return canonical backend names accepted by make_env().

Return type:

list[str]

world_models.api.list_envs(model=None)[source]#

List known environment ids, optionally filtered by model family.

Parameters:

model (str | None)

Return type:

list[str] | dict[str, list[str]]

world_models.api.list_models()[source]#

Return canonical model names accepted by create_model().

Return type:

list[str]

world_models.api.make_env(env_id, backend='auto', **kwargs)[source]#

Create an environment with a consistent TorchWM entry point.

Parameters:
  • env_id (str) – Environment id, XML path, Unity executable path, or backend-specific id.

  • backend (str) – One of list_env_backends(); "auto" tries TorchWM’s compatibility helper.

  • **kwargs (Any) – Backend-specific options.

Return type:

Any

Models sub-module - Core world model implementations.

Exported Components:
Agents (High-level training wrappers):
  • DreamerAgent: High-level Dreamer training API

  • JEPAAgent: JEPA agent for self-supervised learning

  • Planet: PlaNet planning agent

  • VisionTransformer: Vision Transformer for image encoding

  • ModularRSSM: Modular RSSM with swappable components

  • Genie: Generative Interactive Environment model

Core Models:
  • Dreamer: Core Dreamer implementation with RSSM, actor, critic

  • RSSM: Recurrent State-Space Model (Dreamer-style)

  • RecurrentStateSpaceModel: PlaNet-style RSSM

  • LatentActionModel: Latent action learning for Genie

  • DynamicsModel: Future frame prediction for Genie

Factory Functions:
  • create_genie, create_genie_small, create_genie_large

  • create_modular_rssm

  • create_latent_action_model, create_dynamics_model

Model catalog#

Core model families#

Key classes: Dreamer, DreamerAgent, RSSM, RecurrentStateSpaceModel, Planet, ModularRSSM, JEPAAgent, VisionTransformer, IRISAgent, IRISTransformer, IRISWorldModel, Genie, LatentActionModel, and DynamicsModel.

world_models.models.dreamer.get_available_memory()[source]#

Get available physical memory in bytes.

world_models.models.dreamer.make_env(args)[source]#

Construct a Dreamer-compatible environment from DreamerConfig options.

Supports DMC, Gym/Gymnasium, MuJoCo, Gymnasium Robotics, Brax, and Unity ML-Agents backends and applies the standard wrapper stack: action repeat, action normalization, and time limit.

world_models.models.dreamer.preprocess_obs(obs)[source]#

Convert raw uint8 image observations to Dreamer float input space.

Images are scaled from [0, 255] to roughly [-0.5, 0.5], matching the normalization expected by Dreamer encoders.
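A minimal sketch of this normalization, assuming the common Dreamer convention x / 255 − 0.5 applied element-wise; whether TorchWM's preprocess_obs also reorders channels or adds noise is not stated here, and preprocess_obs_sketch is a hypothetical name.

```python
import numpy as np


def preprocess_obs_sketch(obs: np.ndarray) -> np.ndarray:
    """Map uint8 pixels in [0, 255] to float32 values in [-0.5, 0.5]."""
    return obs.astype(np.float32) / 255.0 - 0.5


frame = np.array([[0, 128, 255]], dtype=np.uint8)
print(preprocess_obs_sketch(frame))
```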

class world_models.models.dreamer.Dreamer(args, obs_shape, action_size, device, restore=False)[source]#

Bases: object

Core Dreamer training system combining world model, actor, and value nets.

This class owns model construction, replay sampling, imagination rollouts, loss computation, optimization steps, evaluation loops, and checkpoint I/O.

world_model_loss(obs, acs, rews, nonterms)[source]#
actor_loss()[source]#
value_loss()[source]#
train_one_batch()[source]#
act_with_world_model(obs, prev_state, prev_action, explore=False)[source]#
act_and_collect_data(env, collect_steps)[source]#
evaluate(env, eval_episodes, render=False)[source]#
collect_random_episodes(env, seed_steps)[source]#
save(save_path)[source]#
restore_checkpoint(ckpt_path)[source]#
class world_models.models.dreamer.DreamerAgent(config=None, **kwargs)[source]#

Bases: object

High-level user API for running Dreamer experiments end to end.

It builds environments from config, initializes seeds and logging, instantiates Dreamer, and exposes simple train() / evaluate() methods.

train(total_steps=None)[source]#
evaluate()[source]#
class world_models.models.dreamer_rssm.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)[source]#

Bases: Module

Recurrent State-Space Model used by Dreamer for latent dynamics learning.

The RSSM is the core world model component that learns compact representations of environment dynamics. It maintains a hybrid state consisting of:

  1. Deterministic State (h): A recurrent hidden state updated by a GRU, capturing sequential/temporal information and deterministic transitions.

  2. Stochastic State (s): A latent variable representing stochastic, multi-modal uncertainty in the environment (e.g., ambiguous observations).

The model operates in two modes:

  • Observe Mode: Updates states using actual observations from the environment. Uses the representation model: p(s_t | h_t, obs_t)

  • Imagine Mode: Predicts future states without observations. Uses the transition/prior model: p(s_t | h_t)

Architecture:
  • Input: previous state (h_{t-1}, s_{t-1}) and action a_{t-1}

  • Process: a GRU updates the deterministic state; an MLP computes the stochastic prior/posterior

  • Output: updated state (h_t, s_t) and its distributions

State Representation:
  • deter (h): GRU hidden state, captures sequential context

  • stoch (s): Stochastic latent, multi-modal uncertainty

  • mean/std: Parameters of the stochastic distribution

Usage:

rssm = RSSM(
    action_size=action_dim,
    stoch_size=30,       # Stochastic state dimension
    deter_size=200,      # Deterministic (GRU) state dimension
    hidden_size=200,     # MLP hidden layer size
    obs_embed_size=256,  # Observation embedding from encoder
    activation='elu',
)

# Observe with the actual observation; observe_step returns (posterior, prior)
posterior, prior = rssm.observe_step(prev_state, prev_action, obs_embed)

# Imagine the next state without an observation
imagined = rssm.imagine_step(posterior, action)

Training:

The RSSM is trained by maximizing the ELBO (Evidence Lower Bound), which combines two terms:
  • KL divergence between prior and posterior, which encourages the prior to capture environment dynamics

  • Reconstruction loss from the decoder, which ensures the state captures observation information
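Written out per time step in the notation above (deterministic state h_t, stochastic state s_t, observation o_t), the objective takes the following form, with q(s_t | h_t, o_t) denoting the posterior (written p(s_t | h_t, obs_t) above) and p(s_t | h_t) the prior. Dreamer adds a reward log-likelihood term of the same form as the reconstruction term; it is omitted here to match the two terms listed.

```latex
\mathcal{L}_{\mathrm{ELBO}}
  = \sum_{t=1}^{T} \Big(
      \mathbb{E}_{q(s_t \mid h_t, o_t)}\big[ \ln p(o_t \mid h_t, s_t) \big]
      - \mathrm{KL}\big[\, q(s_t \mid h_t, o_t) \,\big\|\, p(s_t \mid h_t) \,\big]
    \Big)
```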

Reference:

Dreamer: Scalable Reinforcement Learning Using World Models Hafner et al., 2020 - https://arxiv.org/abs/1912.01603

init_state(batch_size, device)[source]#

Initialize RSSM state with zeros.

Parameters:
  • batch_size – Number of parallel sequences

  • device – torch device for tensors

Returns:

Dictionary containing zero-initialized state components:
  • mean, std: Stochastic distribution parameters

  • stoch: Stochastic state sample

  • deter: Deterministic GRU hidden state

get_dist(mean, std)[source]#

Create an Independent Normal distribution from mean and std.

Parameters:
  • mean – Location parameter

  • std – Scale parameter

Returns:

Independent Normal distribution with given parameters
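get_dist wraps a diagonal Normal in Independent so that log-probabilities and KL divergences are summed over the latent dimension, yielding one value per batch element. A minimal sketch using torch.distributions, which is presumably what the method builds; reinterpreted_batch_ndims=1 assumes a (batch, stoch_size) layout.

```python
import torch
from torch.distributions import Independent, Normal, kl_divergence


def get_dist(mean: torch.Tensor, std: torch.Tensor) -> Independent:
    # Treat the last dimension as the event, so it is summed over.
    return Independent(Normal(mean, std), reinterpreted_batch_ndims=1)


batch, stoch_size = 4, 30
posterior = get_dist(torch.randn(batch, stoch_size), torch.rand(batch, stoch_size) + 0.1)
prior = get_dist(torch.zeros(batch, stoch_size), torch.ones(batch, stoch_size))

kl = kl_divergence(posterior, prior)  # KL[posterior || prior], one value per sample
print(kl.shape)  # torch.Size([4])
```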

observe_step(prev_state, prev_action, obs_embed, nonterm=1.0)[source]#

Update state using actual observation (observe mode).

In observe mode, the RSSM first computes a transition prior from the previous state and action, then refines the stochastic state using the actual observation embedding to form the posterior.

Parameters:
  • prev_state – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})

  • prev_action – Previous action a_{t-1}, shape (B, action_size)

  • obs_embed – Observation embedding from encoder, shape (B, obs_embed_size)

  • nonterm – Termination mask (1.0 = continue, 0.0 = terminal)

Returns:

A tuple (posterior, prior) of state dictionaries. The posterior incorporates observation information; the prior is the transition prediction before observation. Both share the same deterministic state because the GRU is only advanced once per timestep.

imagine_step(prev_state, prev_action, nonterm=1.0)[source]#

Predict next state without observation (imagine mode).

In imagine mode, the RSSM predicts future states using only the prior distribution. This is used for planning and policy learning where actual observations are not available.

Parameters:
  • prev_state – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})

  • prev_action – Previous action a_{t-1}, shape (B, action_size)

  • nonterm – Termination mask (1.0 = continue, 0.0 = terminal)

Returns:

Dictionary with the predicted state, containing:
  • deter: Predicted deterministic state

  • mean, std, stoch: Prior stochastic state distribution

get_prior(prev_state, prev_action, nonterm=1.0)[source]#

Compute prior distribution over stochastic state.

The prior represents the model’s belief about the stochastic state before observing the actual outcome.

Parameters:
  • prev_state – Previous state dictionary

  • prev_action – Previous action

  • nonterm – Termination mask

Returns:

Dictionary with prior state (no observation information)

get_posterior(prev_state, prev_action, obs_embed, nonterm=1.0)[source]#

Compute posterior distribution over stochastic state.

The posterior incorporates observation information to produce a more accurate state estimate.

Parameters:
  • prev_state – Previous state dictionary

  • prev_action – Previous action

  • obs_embed – Observation embedding

  • nonterm – Termination mask

Returns:

Dictionary with posterior state (observation-informed). Note that the previous-state shape (B, ...) is preserved; the batch dimension is not flattened.

detach_state(state)[source]#

Detach state tensors from computation graph.

Used during DreamerV2 training to prevent gradient flow through the observation/update pathway.

Parameters:

state – State dictionary with tensor values

Returns:

Detached state dictionary

seq_to_batch(state_dict)[source]#

Convert sequence state to batch format.

Parameters:

state_dict – Dictionary with sequence-dimension tensors (T, B, …)

Returns:

Dictionary with batch-dimension tensors (B*T, …)

observe_rollout(obs_embed, actions, nonterms, init_state, seq_len)[source]#

Process a sequence of observations (observe mode rollout).

At each timestep we run observe_step once to obtain the transition prior (the prediction given the previous state and action) and the observation-informed posterior. The posterior is then used as the previous state for the next step, matching the standard Dreamer inference pattern.

Parameters:
  • obs_embed – Observation embeddings, shape (T+1, B, obs_embed_size)

  • actions – Actions, shape (T, B, action_size)

  • nonterms – Non-termination flags, shape (T, B, 1)

  • init_state – Initial state dictionary

  • seq_len – Sequence length T

Returns:
  • prior: Dictionary with prior states stacked along the time axis

  • posterior: Dictionary with posterior states stacked along the time axis

imagine_rollout(policy, init_state, horizon)[source]#

Generate imagined trajectory using policy (imagine mode rollout).

Parameters:
  • policy – Actor network that outputs actions from state features

  • init_state – Initial state dictionary

  • horizon – Number of steps to imagine

Returns:

Dictionary with imagined states for each step

forward(x, u)[source]#

Forward pass for training (computes sequence of states).

Parameters:
  • x – Observations, shape (B, T+1, C, H, W)

  • u – Actions, shape (B, T, action_size)

Returns:
  • states: List of state dictionaries for each timestep

  • priors: List of prior distributions (tuples of mean, std)

  • posteriors: List of posterior distributions (tuples of mean, std)

class world_models.models.rssm.RecurrentStateSpaceModel(action_size, state_size=200, latent_size=30, hidden_size=200, embed_size=1024, activation_function='relu')[source]#

Bases: Module

A Recurrent State Space Model (RSSM) for modeling latent dynamics in sequential data.

get_init_state(enc, h_t=None, s_t=None, a_t=None, mean=False)[source]#

Returns the initial posterior given the observation.

deterministic_state_fwd(h_t, s_t, a_t)[source]#
Deterministic transition update that accepts:
  • a_t shaped [B, action_size]

  • a_t shaped [action_size] (unbatched) -> expanded to [B, action_size]

  • a_t shaped [B] or scalar -> reshaped appropriately

Ensures a_t is 2D and matches batch dimension of h_t before concatenation.

state_prior(h_t, sample=False)[source]#

Returns the prior distribution over the latent state given the deterministic state.

state_posterior(h_t, e_t, sample=False)[source]#

Returns the posterior distribution over the latent state given the deterministic state h_t and the encoded observation e_t.

pred_reward(h_t, s_t)[source]#
rollout_prior(act, h_t, s_t)[source]#
forward(x, u)[source]#

Forward through the RSSM for a batch of sequences.

Inputs:
  • x: Tensor [B, T+1, C, H, W] (observations including the initial frame)

  • u: Tensor [B, T, action_size] (actions for T steps)

Returns:
  • states: list[T] of tensors [B, state_size]

  • priors: list[T] of tuples (mean, std), each [B, latent_size]

  • posteriors: list[T] of tuples (mean, std), each [B, latent_size]

class world_models.models.planet.Planet(env, bit_depth=5, device=None, state_size=200, latent_size=30, embedding_size=1024, memory_size=100, policy_cfg=None, headless=False, max_episode_steps=None, action_repeats=1, results_dir=None)[source]#

Bases: object

High-level Planet wrapper.

Usage example:

from world_models.models.planet import Planet

p = Planet(env='CartPole-v1', bit_depth=5)
p.train(epochs=50)
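The bit_depth argument refers to PlaNet-style observation preprocessing, in which 8-bit images are quantized to fewer bits, centered around zero, and given uniform dequantization noise. A minimal sketch of that transform as it is commonly implemented for PlaNet; whether this wrapper applies exactly these steps is an assumption, and quantize_obs is a hypothetical name.

```python
import torch


def quantize_obs(obs: torch.Tensor, bit_depth: int = 5) -> torch.Tensor:
    """uint8 pixels to floats in [-0.5, 0.5] at the given bit depth."""
    x = obs.float()
    x = torch.floor(x / 2 ** (8 - bit_depth))       # keep the top bit_depth bits
    x = x / 2 ** bit_depth - 0.5                    # center around zero
    return x + torch.rand_like(x) / 2 ** bit_depth  # uniform dequantization noise


obs = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
x = quantize_obs(obs)
print(x.min().item(), x.max().item())
```

With bit_depth=5 each pixel lands on one of 32 levels, which discards fine intensity detail the reconstruction loss would otherwise spend capacity on.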

warmup(n_episodes=1, random_policy=True)[source]#

Collect n_episodes of rollouts into memory (used as warmup).

train(epochs=100, steps_per_epoch=150, batch_size=32, H=50, beta=1.0, save_every=25, record_grads=False, results_dir=None, scheduler_type='step', scheduler_kwargs=None)[source]#

High-level training loop. Delegates single-step training to the existing train function.

Parameters:
  • scheduler_type (str) – Type of scheduler to use (“step”, “cosine”, “exponential”, “plateau”, None)

  • scheduler_kwargs (dict) – Additional arguments for the scheduler

Modular RSSM with swappable encoder/decoder/backbone components.

This module provides a flexible architecture for world model research, allowing researchers to easily swap different encoder, decoder, and backbone implementations for ablations and experimentation.

class world_models.models.modular_rssm.EncoderBase(*args, **kwargs)[source]#

Bases: Module, ABC

Abstract base class for observation encoders.

Parameters:
  • args (Any)

  • kwargs (Any)

abstractmethod forward(obs)[source]#

Encode observations to embeddings.

Parameters:

obs (Tensor)

Return type:

Tensor

embed_size: int#
get_embed_size()[source]#

Return the embedding size. Override in subclasses.

Return type:

int

class world_models.models.modular_rssm.DecoderBase(*args, **kwargs)[source]#

Bases: Module, ABC

Abstract base class for observation decoders.

Parameters:
  • args (Any)

  • kwargs (Any)

abstractmethod forward(features)[source]#

Decode latent features to observation distributions.

Parameters:

features (Tensor)

Return type:

Any

class world_models.models.modular_rssm.BackboneBase(*args, **kwargs)[source]#

Bases: Module, ABC

Abstract base class for recurrent dynamics backbones.

Parameters:
  • args (Any)

  • kwargs (Any)

abstractmethod forward(state, action, obs_embed=None, nonterm=1.0)[source]#

Process one step of dynamics. Returns (prior, posterior).

Parameters:
  • state (Dict[str, Tensor])

  • action (Tensor)

  • obs_embed (Tensor | None)

  • nonterm (float)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

abstractmethod init_state(batch_size, device)[source]#

Initialize hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Dict[str, Tensor]

stoch_size: int#
deter_size: int#
class world_models.models.modular_rssm.ConvEncoder(input_shape, embed_size, activation='elu', depth=32)[source]#

Bases: EncoderBase

Convolutional encoder from Dreamer (image observations).

Parameters:
  • input_shape (Tuple[int, int, int])

  • embed_size (int)

  • activation (str)

  • depth (int)

forward(obs)[source]#
Parameters:

obs (Tensor)

Return type:

Tensor

class world_models.models.modular_rssm.MLPEncoder(input_dim, embed_size, hidden_sizes=[256, 256], activation='elu')[source]#

Bases: EncoderBase

MLP encoder for state-based observations.

Parameters:
  • input_dim (int)

  • embed_size (int)

  • hidden_sizes (List[int])

  • activation (str)

forward(obs)[source]#
Parameters:

obs (Tensor)

Return type:

Tensor

class world_models.models.modular_rssm.ViTEncoder(input_shape, embed_size, patch_size=8, depth=6, num_heads=8, mlp_ratio=4.0, activation='gelu')[source]#

Bases: EncoderBase

Vision Transformer encoder for image observations.

Parameters:
  • input_shape (Tuple[int, int, int])

  • embed_size (int)

  • patch_size (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • activation (str)

forward(obs)[source]#
Parameters:

obs (Tensor)

Return type:

Tensor

class world_models.models.modular_rssm.TransformerBlock(embed_size, num_heads, mlp_ratio, activation)[source]#

Bases: Module

Transformer block for ViT encoder.

Parameters:
  • embed_size (int)

  • num_heads (int)

  • mlp_ratio (float)

  • activation (str)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.models.modular_rssm.ConvDecoder(stoch_size, deter_size, output_shape, activation='elu', depth=32)[source]#

Bases: DecoderBase

Convolutional decoder for image observations.

Parameters:
  • stoch_size (int)

  • deter_size (int)

  • output_shape (Tuple[int, int, int])

  • activation (str)

  • depth (int)

forward(features)[source]#
Parameters:

features (Tensor)

Return type:

Any

class world_models.models.modular_rssm.MLPDecoder(stoch_size, deter_size, output_dim, hidden_sizes=[256, 256], activation='elu', dist='normal')[source]#

Bases: DecoderBase

MLP decoder for state-based observations.

Parameters:
  • stoch_size (int)

  • deter_size (int)

  • output_dim (int)

  • hidden_sizes (List[int])

  • activation (str)

  • dist (str)

forward(features)[source]#
Parameters:

features (Tensor)

Return type:

Any

class world_models.models.modular_rssm.GRUBackbone(action_size, stoch_size, deter_size, hidden_size, embed_size, activation='elu')[source]#

Bases: BackboneBase

GRU-based recurrent dynamics backbone (standard RSSM).

Parameters:
  • action_size (int)

  • stoch_size (int)

  • deter_size (int)

  • hidden_size (int)

  • embed_size (int)

  • activation (str)

property embedding_size: int#
init_state(batch_size, device)[source]#
Parameters:
  • batch_size (int)

  • device (device)

Return type:

Dict[str, Tensor]

forward(state, action, obs_embed=None, nonterm=1.0)[source]#
Parameters:
  • state (Dict[str, Tensor])

  • action (Tensor)

  • obs_embed (Tensor | None)

  • nonterm (float)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

class world_models.models.modular_rssm.LSTMBackbone(action_size, stoch_size, deter_size, hidden_size, embed_size, activation='elu')[source]#

Bases: BackboneBase

LSTM-based recurrent dynamics backbone.

Parameters:
  • action_size (int)

  • stoch_size (int)

  • deter_size (int)

  • hidden_size (int)

  • embed_size (int)

  • activation (str)

property embedding_size: int#
init_state(batch_size, device)[source]#
Parameters:
  • batch_size (int)

  • device (device)

Return type:

Dict[str, Tensor]

forward(state, action, obs_embed=None, nonterm=1.0)[source]#
Parameters:
  • state (Dict[str, Tensor])

  • action (Tensor)

  • obs_embed (Tensor | None)

  • nonterm (float)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

class world_models.models.modular_rssm.TransformerBackbone(action_size, stoch_size, deter_size, embed_size, num_heads=4, num_layers=2, activation='gelu')[source]#

Bases: BackboneBase

Transformer-based dynamics backbone for long-range dependencies.

Parameters:
  • action_size (int)

  • stoch_size (int)

  • deter_size (int)

  • embed_size (int)

  • num_heads (int)

  • num_layers (int)

  • activation (str)

property embedding_size: int#
init_state(batch_size, device)[source]#
Parameters:
  • batch_size (int)

  • device (device)

Return type:

Dict[str, Tensor]

forward(state, action, obs_embed=None, nonterm=1.0)[source]#
Parameters:
  • state (Dict[str, Tensor])

  • action (Tensor)

  • obs_embed (Tensor | None)

  • nonterm (float)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

class world_models.models.modular_rssm.ModularRSSM(encoder, decoder, backbone, reward_decoder=None)[source]#

Bases: Module

Modular RSSM with swappable encoder, decoder, and backbone.

This class lets researchers experiment with different combinations of:
  • Encoders: Conv, MLP, ViT

  • Decoders: Conv, MLP

  • Backbones: GRU, LSTM, Transformer

Example

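>>> from world_models.models.modular_rssm import ConvDecoder, ConvEncoder, GRUBackbone, ModularRSSM
>>> encoder = ConvEncoder((3, 64, 64), embed_size=1024)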
>>> decoder = ConvDecoder(32, 200, (3, 64, 64))
>>> backbone = GRUBackbone(action_size=6, stoch_size=32, deter_size=200, hidden_size=200, embed_size=1024)
>>> rssm = ModularRSSM(encoder, decoder, backbone)
Parameters:
  • encoder

  • decoder

  • backbone

  • reward_decoder (optional)
property stoch_size: int#
property deter_size: int#
property embed_size: int#
init_state(batch_size, device)[source]#
Parameters:
  • batch_size (int)

  • device (device)

Return type:

Dict[str, Tensor]

get_dist(mean, std)[source]#
Parameters:
  • mean (Tensor)

  • std (Tensor)

Return type:

Distribution

observe_step(prev_state, prev_action, obs, nonterm=1.0)[source]#
Parameters:
  • prev_state (Dict[str, Tensor])

  • prev_action (Tensor)

  • obs (Tensor)

  • nonterm (Any)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

imagine_step(prev_state, prev_action, nonterm=1.0)[source]#
Parameters:
  • prev_state (Dict[str, Tensor])

  • prev_action (Tensor)

  • nonterm (Any)

Return type:

Dict[str, Tensor]

observe_rollout(obs, actions, nonterms, prev_state, horizon)[source]#
Parameters:
  • obs (Tensor)

  • actions (Tensor)

  • nonterms (Tensor)

  • prev_state (Dict[str, Tensor])

  • horizon (int)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

imagine_rollout(actor, prev_state, horizon)[source]#
Parameters:
  • actor (Module)

  • prev_state (Dict[str, Tensor])

  • horizon (int)

Return type:

Dict[str, Tensor]

decode_observation(features)[source]#
Parameters:

features (Tensor)

decode_reward(features)[source]#
Parameters:

features (Tensor)

detach_state(state)[source]#
Parameters:

state (Dict[str, Tensor])

Return type:

Dict[str, Tensor]

seq_to_batch(state)[source]#
Parameters:

state (Dict[str, Tensor])

Return type:

Dict[str, Tensor]

world_models.models.modular_rssm.create_modular_rssm(encoder_type='conv', decoder_type='conv', backbone_type='gru', obs_shape=(3, 64, 64), action_size=6, stoch_size=32, deter_size=200, embed_size=1024, hidden_size=200, activation='elu', **kwargs)[source]#

Factory function to create a modular RSSM with specified components.

Parameters:
  • encoder_type (str) – Type of encoder (“conv”, “mlp”, “vit”)

  • decoder_type (str) – Type of decoder (“conv”, “mlp”)

  • backbone_type (str) – Type of backbone (“gru”, “lstm”, “transformer”)

  • obs_shape (Tuple[int, int, int] | Tuple[int]) – Shape of observations (C, H, W) for images or (D,) for state

  • action_size (int) – Action space dimension

  • stoch_size (int) – Stochastic latent dimension

  • deter_size (int) – Deterministic hidden dimension

  • embed_size (int) – Encoder embedding dimension

  • hidden_size (int) – Hidden layer dimension

  • activation (str) – Activation function name

Returns:

Configured ModularRSSM instance

Return type:

ModularRSSM

class world_models.models.jepa_agent.JEPAAgent(config=None, **kwargs)[source]#

Bases: object

Convenience interface for configuring and launching JEPA training runs.

Accepts a JEPAConfig plus keyword overrides, prepares output folders, and delegates execution to the JEPA training entrypoint.

Parameters:

config (JEPAConfig | None)

train()[source]#
world_models.models.vit.get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)[source]#

Generate fixed 2D sine/cosine positional embeddings on a square patch grid.

Returns NumPy embeddings used to initialize non-trainable transformer position encodings, with an optional prepended class-token embedding.

world_models.models.vit.get_2d_sincos_pos_embed_from_grid(embed_dim, grid)[source]#

Build 2D sine/cosine embeddings from precomputed meshgrid coordinates.

The final embedding concatenates independent encodings for vertical and horizontal coordinates.

world_models.models.vit.get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)[source]#

Generate 1D sine/cosine positional embeddings for integer positions.

Useful for sequence-style positional encoding and as a building block for 2D embedding construction.

world_models.models.vit.get_1d_sincos_pos_embed_from_grid(embed_dim, pos)[source]#

Generate 1D sine/cosine positional embeddings from explicit positions.

Positions are projected onto a log-frequency basis and encoded with sine and cosine components.
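These helpers appear to follow the fixed sine/cosine scheme used in the MAE and I-JEPA reference code. The NumPy sketch below restates that scheme compactly. The function names are hypothetical. Two details are assumptions rather than documented facts about this module:

- The vertical half is concatenated before the horizontal half.
- The optional class token gets an all-zero embedding.

```python
import numpy as np


def sincos_1d(embed_dim, pos):
    """Encode positions `pos` (any shape, flattened) as an (M, embed_dim) array."""
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")
    # Log-spaced frequencies omega_i = 1 / 10000**(i / (embed_dim / 2)).
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0))
    angles = np.outer(np.asarray(pos, dtype=np.float64).reshape(-1), omega)  # (M, D/2)
    # First half of the channels: sines; second half: cosines.
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)


def sincos_2d(embed_dim, grid_size, cls_token=False):
    """Embedding for a grid_size x grid_size patch grid, flattened row by row."""
    coords = np.arange(grid_size, dtype=np.float64)
    grid_w, grid_h = np.meshgrid(coords, coords)  # column index, row index
    # Half of the channels encode the vertical coordinate, half the horizontal one.
    emb = np.concatenate(
        [sincos_1d(embed_dim // 2, grid_h), sincos_1d(embed_dim // 2, grid_w)], axis=1
    )
    if cls_token:
        # Assumption: the class token gets an all-zero positional embedding.
        emb = np.concatenate([np.zeros((1, embed_dim)), emb], axis=0)
    return emb


print(sincos_2d(64, 14, cls_token=True).shape)  # (197, 64)
```

A 14 x 14 grid (224-pixel images with 16-pixel patches) gives 196 positions, plus one for the class token. In this sketch, embed_dim must be divisible by 4 so that each half can itself be split into sines and cosines.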

world_models.models.vit.drop_path(x, drop_prob=0.0, training=False)[source]#

Apply stochastic depth (DropPath) regularization to residual branches.

Randomly drops entire residual paths per sample during training and scales the surviving activations to preserve expected magnitude.

Parameters:
  • drop_prob (float)

  • training (bool)
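For reference, the stochastic depth described above fits in a few lines. This NumPy version is a sketch of the common timm-style formulation, not a copy of the module's PyTorch code. The extra `rng` argument is a hypothetical addition for reproducibility.

```python
import numpy as np


def drop_path_np(x, drop_prob=0.0, training=False, rng=None):
    """Stochastic depth on a residual-branch output x of shape (B, ...)."""
    if not training or not drop_prob:
        # Identity at evaluation time or when dropping is disabled.
        return x
    rng = np.random.default_rng() if rng is None else rng
    keep_prob = 1.0 - drop_prob
    # One Bernoulli(keep_prob) draw per sample, broadcast over all other dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep = rng.random(mask_shape) < keep_prob
    # Rescale survivors so the expected output equals x.
    return x / keep_prob * keep


x = np.ones((1000, 4))
y = drop_path_np(x, drop_prob=0.25, training=True, rng=np.random.default_rng(0))
```

Each row of `y` is either all zeros or all 1 / 0.75, so whole samples lose the branch rather than individual activations.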

class world_models.models.vit.DropPath(drop_prob=None)[source]#

Bases: Module

Module wrapper around the functional drop_path stochastic depth utility.

forward(x)[source]#
class world_models.models.vit.MLP(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>, drop=0.0)[source]#

Bases: Module

Two-layer feed-forward network used inside transformer blocks.

Applies linear projection, activation, dropout, and output projection in the standard Vision Transformer MLP pattern.

forward(x)[source]#
class world_models.models.vit.Attention(dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)[source]#

Bases: Module

Multi-head self-attention block for token sequences.

Computes QKV projections, scaled dot-product attention, and output projection with configurable dropout.

forward(x)[source]#
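The attention computation described above is sketched below in NumPy for a single forward pass. It rests on several assumptions:

- Biases and dropout are omitted (qkv_bias=False is the default).
- The hypothetical matrices `w_qkv` and `w_proj` stand in for the module's learned projections.
- The scale head_dim ** -0.5 used when qk_scale is None is the usual convention, not something this page states.

```python
import numpy as np


def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_attention(x, w_qkv, w_proj, num_heads, qk_scale=None):
    """x: (B, N, C); w_qkv: (C, 3C); w_proj: (C, C). No biases, no dropout."""
    B, N, C = x.shape
    head_dim = C // num_heads
    scale = qk_scale if qk_scale is not None else head_dim ** -0.5
    # Project once, then split into q, k, v, each (B, heads, N, head_dim).
    qkv = (x @ w_qkv).reshape(B, N, 3, num_heads, head_dim).transpose(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    attn = softmax(q @ k.transpose(0, 1, 3, 2) * scale)  # (B, heads, N, N)
    # Weighted sum of values, then merge heads back to (B, N, C).
    out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, N, C)
    return out @ w_proj


rng = np.random.default_rng(0)
C = 8
x = rng.standard_normal((2, 5, C))
w_qkv = rng.standard_normal((C, 3 * C))
w_proj = rng.standard_normal((C, C))
out = multi_head_attention(x, w_qkv, w_proj, num_heads=2)  # (2, 5, 8)
```

A useful sanity check is a single-token input. There, each attention weight is 1, so the output reduces to the value projection followed by `w_proj`.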
class world_models.models.vit.Block(dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]#

Bases: Module

Transformer encoder block combining attention and MLP residual branches.

Each branch uses pre-normalization and optional stochastic depth.

forward(x)[source]#
class world_models.models.vit.PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)[source]#

Bases: Module

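Image-to-patch embedding: splits an image into non-overlapping patch_size × patch_size patches and projects each patch to embed_dim.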

forward(x)[source]#
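With the default arguments, a non-overlapping split yields (224 / 16)^2 = 196 patch tokens per image. The NumPy sketch below shows only the patch split. The learned projection to embed_dim is commonly implemented as a Conv2d with kernel_size and stride equal to patch_size; here a hypothetical zero weight matrix stands in for it.

```python
import numpy as np


def patchify(img, patch_size):
    """(C, H, W) image -> (num_patches, C * patch_size**2), patches in row-major order."""
    C, H, W = img.shape
    if H % patch_size or W % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    gh, gw = H // patch_size, W // patch_size
    patches = img.reshape(C, gh, patch_size, gw, patch_size)
    # -> (gh, gw, C, p, p): one entry per patch, channels kept together.
    patches = patches.transpose(1, 3, 0, 2, 4)
    return patches.reshape(gh * gw, C * patch_size * patch_size)


img = np.random.default_rng(0).standard_normal((3, 224, 224))
patches = patchify(img, 16)              # (196, 768)
w_embed = np.zeros((3 * 16 * 16, 768))   # hypothetical projection to embed_dim=768
tokens = patches @ w_embed               # (196, 768)
```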
class world_models.models.vit.ConvEmbed(channels, strides, img_size=224, in_chans=3, batch_norm=True)[source]#

Bases: Module

3x3 convolutional stem for ViT, following the ViTC models.

forward(x)[source]#
class world_models.models.vit.VisionTransformerPredictor(num_patches, embed_dim=768, predictor_embed_dim=384, depth=6, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, init_std=0.02, **kwargs)[source]#

Bases: Module

Vision Transformer predictor used by JEPA to predict target-block representations from context-block representations.

fix_init_weight()[source]#
forward(x, masks_x, masks)[source]#
class world_models.models.vit.VisionTransformer(img_size=[224], patch_size=16, in_chans=3, embed_dim=768, predictor_embed_dim=384, depth=12, predictor_depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, init_std=0.02, **kwargs)[source]#

Bases: Module

Vision Transformer

fix_init_weight()[source]#
forward(x, masks=None)[source]#
interpolate_pos_encoding(x, pos_embed)[source]#
world_models.models.vit.vit_predictor(**kwargs)[source]#

Factory for a JEPA predictor transformer with sensible defaults.

world_models.models.vit.vit_tiny(patch_size=16, **kwargs)[source]#

Factory for a tiny Vision Transformer encoder backbone.

world_models.models.vit.vit_small(patch_size=16, **kwargs)[source]#

Factory for a small Vision Transformer encoder backbone.

world_models.models.vit.vit_base(patch_size=16, **kwargs)[source]#

Factory for a base Vision Transformer encoder backbone.

world_models.models.vit.vit_large(patch_size=16, **kwargs)[source]#

Factory for a large Vision Transformer encoder backbone.

world_models.models.vit.vit_huge(patch_size=16, **kwargs)[source]#

Factory for a huge Vision Transformer encoder backbone.

world_models.models.vit.vit_giant(patch_size=16, **kwargs)[source]#

Factory for a giant Vision Transformer encoder backbone.

world_models.models.iris_agent.compute_lambda_return(rewards, values, discounts, lambda_coef=0.95)[source]#

Compute λ-return target for value function training.

Parameters:
  • rewards (Tensor) – Rewards (B, T)

  • values (Tensor) – Value estimates (B, T+1)

  • discounts (Tensor) – Discount factors (B, T)

  • lambda_coef (float) – Lambda parameter for bootstrapping

Returns:

lambda_returns – λ-return targets (B, T)
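A NumPy sketch of the standard backward recursion for λ-returns follows. It computes G_t = r_t + d_t((1 - λ)V_{t+1} + λG_{t+1}), starting from G_T = V_T. The docstring does not say so, but the sketch assumes that `discounts` already folds in the discount factor and any termination masks. The function name is hypothetical.

```python
import numpy as np


def lambda_returns(rewards, values, discounts, lam=0.95):
    """rewards, discounts: (B, T); values: (B, T+1). Returns (B, T) targets."""
    rewards = np.asarray(rewards, dtype=np.float64)
    values = np.asarray(values, dtype=np.float64)
    discounts = np.asarray(discounts, dtype=np.float64)
    B, T = rewards.shape
    if values.shape != (B, T + 1) or discounts.shape != (B, T):
        raise ValueError("expected values (B, T+1) and discounts (B, T)")
    returns = np.empty((B, T))
    next_return = values[:, T]  # bootstrap from the final value estimate
    for t in reversed(range(T)):
        # Blend the one-step bootstrap V_{t+1} with the longer return G_{t+1}.
        next_return = rewards[:, t] + discounts[:, t] * (
            (1.0 - lam) * values[:, t + 1] + lam * next_return
        )
        returns[:, t] = next_return
    return returns


rewards = np.array([[1.0, 1.0]])
values = np.array([[0.0, 0.0, 10.0]])  # last entry is the bootstrap value
discounts = np.array([[0.5, 0.5]])
print(lambda_returns(rewards, values, discounts, lam=1.0))  # [[4. 6.]]
print(lambda_returns(rewards, values, discounts, lam=0.0))  # [[1. 6.]]
```

The two calls show the extremes of the recursion. λ = 1 gives a discounted Monte Carlo return bootstrapped only at the horizon, and λ = 0 gives one-step TD targets.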

class world_models.models.iris_agent.IRISAgent(config, action_size, device)[source]#

Bases: Module

Complete IRIS Agent with world model and policy.

Combines:
  • Discrete autoencoder (encoder + decoder)

  • Transformer world model

  • Actor-critic for policy and value learning

Parameters:
  • config (IRISConfig)

  • action_size (int)

  • device (device)

forward_actor_critic(frames, hidden=None)[source]#

Forward pass through actor-critic.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W)

  • hidden (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state

Returns:
  • action_logits – (B, T, action_size)

  • values – (B, T)

  • hidden_state – (h, c)

act(frame, epsilon=0.0, temperature=1.0)[source]#

Sample action from policy.

Parameters:
  • frame (Tensor) – Single frame (B, C, H, W)

  • epsilon (float) – Random action probability

  • temperature (float) – Action distribution temperature

Returns:

actions – selected actions (B,)

imagine_rollout(initial_frame, horizon=20)[source]#

Generate imagined trajectories using world model.

Parameters:
  • initial_frame (Tensor) – Starting frame (B, C, H, W)

  • horizon (int) – Number of steps to imagine

Returns:

trajectory – dictionary with imagined rollout data

update_autoencoder(frames)[source]#

Update discrete autoencoder.

Parameters:

frames (Tensor) – Training frames (B, C, H, W)

Returns:

losses – dictionary of loss values

update_transformer(frames, actions, rewards, terminals)[source]#

Update transformer world model.

Parameters:
  • frames (Tensor) – Frame sequence

  • actions (Tensor) – Actions taken

  • rewards (Tensor) – Rewards received

  • terminals (Tensor) – Terminal flags

Returns:

losses – dictionary of loss values

update_actor_critic(imagined_trajectory)[source]#

Update actor-critic in imagination.

Parameters:

imagined_trajectory (dict) – Dictionary from imagine_rollout

Returns:

losses – dictionary of loss values

save(path)[source]#

Save agent state.

Parameters:

path (str)

load(path)[source]#

Load agent state.

Parameters:

path (str)

class world_models.models.iris_transformer.IRISTransformer(vocab_size=512, tokens_per_frame=16, action_size=18, embed_dim=256, num_layers=10, num_heads=4, dropout=0.1)[source]#

Bases: Module

Autoregressive Transformer for world modeling.

Models the dynamics of the environment by predicting:
  • Next frame tokens (transition model)

  • Rewards

  • Episode termination

The Transformer operates on sequences of interleaved frame tokens and actions.

Parameters:
  • vocab_size (int)

  • tokens_per_frame (int)

  • action_size (int)

  • embed_dim (int)

  • num_layers (int)

  • num_heads (int)

  • dropout (float)
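The interleaved layout can be pictured with plain Python lists. The sketch assumes an ordering; it is not a description of IRISTransformer's internals. Each timestep contributes its K frame tokens followed by one action token, so T timesteps yield T * (K + 1) positions. For example, T = 20 timesteps with the default tokens_per_frame=16 gives 20 * (16 + 1) = 340 positions. The helper `interleave` and its `is_action` flags are hypothetical; the flags would let a model embed frame and action tokens with separate tables.

```python
def interleave(frame_tokens, actions):
    """Flatten T frames of K tokens plus T actions into one sequence.

    frame_tokens: list of T lists, each holding K token ids.
    actions: list of T action ids (assumed: the action taken after each frame).
    Returns (ids, is_action), two flat lists of length T * (K + 1).
    """
    if len(frame_tokens) != len(actions):
        raise ValueError("expected exactly one action per frame")
    ids, is_action = [], []
    for tokens, action in zip(frame_tokens, actions):
        ids.extend(tokens)                    # K frame tokens for this timestep
        is_action.extend([False] * len(tokens))
        ids.append(action)                    # followed by the action token
        is_action.append(True)
    return ids, is_action


ids, is_action = interleave([[5, 9, 2], [7, 7, 1]], [0, 3])
print(ids)  # [5, 9, 2, 0, 7, 7, 1, 3]
```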

forward(tokens, actions, mask=None)[source]#

Forward pass through the Transformer world model.

Parameters:
  • tokens (Tensor) – Frame tokens (B, T, K) where T is timesteps

  • actions (Tensor) – Actions (B, T)

  • mask (Tensor | None) – Optional attention mask

Returns:
  • token_logits – next-token predictions (B, T, K, vocab_size)

  • rewards – predicted rewards (B, T)

  • terminations – predicted terminations (B, T, 2)

predict_next_tokens(tokens, actions)[source]#

Predict the next frame tokens autoregressively.

Used during imagination rollouts.

Parameters:
  • tokens (Tensor) – Current frame tokens (B, K)

  • actions (Tensor) – Actions taken (B,)

Returns:
  • token_logits – next-frame token predictions (B, K, vocab_size)

  • action_hidden – hidden states for reward prediction (B, embed_dim)

sample_next_tokens(tokens, actions, temperature=1.0)[source]#

Sample next tokens from the distribution.

Parameters:
  • tokens (Tensor) – Current frame tokens (B, K)

  • actions (Tensor) – Actions taken (B,)

  • temperature (float) – Sampling temperature (higher = more random)

Returns:
  • sampled_tokens – sampled token indices (B, K)

  • log_probs – log probabilities of the sampled tokens (B, K)
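Temperature sampling of this kind can be sketched as follows. This is an illustration, not the method's code. It assumes that temperature divides the logits before the softmax, and that the returned log-probabilities come from the tempered distribution. Samples are drawn with the Gumbel-max trick, one of several equivalent ways to sample from a categorical distribution.

```python
import numpy as np


def sample_with_temperature(logits, temperature=1.0, rng=None):
    """logits: (..., vocab_size) -> (indices, log_probs), both of shape logits.shape[:-1]."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    rng = np.random.default_rng() if rng is None else rng
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    # Numerically stable log-softmax of the tempered logits.
    scaled = scaled - scaled.max(axis=-1, keepdims=True)
    log_probs = scaled - np.log(np.exp(scaled).sum(axis=-1, keepdims=True))
    # Gumbel-max trick: argmax(log p + Gumbel noise) is a sample from p.
    u = rng.uniform(np.finfo(np.float64).tiny, 1.0, size=log_probs.shape)
    indices = np.argmax(log_probs - np.log(-np.log(u)), axis=-1)
    chosen = np.take_along_axis(log_probs, indices[..., None], axis=-1)[..., 0]
    return indices, chosen


rng = np.random.default_rng(0)
logits = rng.standard_normal((2, 16, 512))  # (B, K, vocab_size)
tokens, log_probs = sample_with_temperature(logits, temperature=0.8, rng=rng)
```

Higher temperatures flatten the distribution, making the sampled tokens more random. As the temperature approaches zero, sampling approaches argmax decoding.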

class world_models.models.iris_transformer.IRISWorldModel(encoder, decoder, transformer)[source]#

Bases: Module

Complete IRIS World Model combining autoencoder and transformer.

This is the core component: it learns the environment dynamics in a discrete latent token space, where the agent can then be trained entirely in imagination.

Parameters:
  • encoder

  • decoder

  • transformer
forward(observations, actions)[source]#

Full world model forward pass.

Parameters:
  • observations (Tensor) – Image sequence (B, T+1, C, H, W)

  • actions (Tensor) – Actions (B, T)

Returns:
  • predictions – dictionary with predicted tokens, rewards, and terminations

  • losses – dictionary with loss components

imagine(initial_tokens, policy, horizon=20, temperature=1.0)[source]#

Generate imagined trajectories.

Parameters:
  • initial_tokens (Tensor) – Initial frame tokens (B, K)

  • policy (Module) – Policy network to sample actions

  • horizon (int) – Number of steps to imagine

  • temperature (float) – Sampling temperature for token prediction

Returns:

imagined – dictionary with imagined trajectories

class world_models.models.genie.Genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=512, tokenizer_decoder_dim=1024, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=1024, action_decoder_dim=1024, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, encoder_depth=12, decoder_depth=20, latent_action_depth=20, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#

Bases: Module

Genie: Generative Interactive Environment.

A generative model, trained from video-only data, that can be used as an interactive environment. It contains three key components:
  1. Video tokenizer: converts raw video frames into discrete tokens.

  2. Latent Action Model (LAM): infers latent actions between frames.

  3. Dynamics model: predicts future frames given past frames and latent actions.

Based on the paper “Genie: Generative Interactive Environments” (arXiv:2402.15391).

Training follows the two phases described in the paper:
  1. Train the video tokenizer first.

  2. Co-train the LAM (on pixels) and the dynamics model (on video tokens).

The LAM is trained with a VQ-VAE objective:
  • Encoder: takes x_{1:t} and x_{t+1} and outputs latent actions.

  • Decoder: takes x_{1:t-1} (masked) plus the latent actions and reconstructs x_t.

  • An auxiliary variance loss prevents action collapse.

Latent actions are passed to the dynamics model through a stop-gradient.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_decoder_dim (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • latent_action_depth (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

forward(video, mask_prob=0.5, training_phase='all')[source]#

Full forward pass through all components.

Parameters:
  • video (Tensor) – (B, C, T, H, W) input video

  • mask_prob (float) – Probability for random masking in dynamics (0.5-1.0)

  • training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”

Returns:

Dictionary containing losses and predictions

Return type:

Dict[str, Tensor]

training_step(video, mask_prob=0.5, training_phase='all')[source]#

Single training step computing all losses.

Parameters:
  • video (Tensor) – (B, C, T, H, W) input video

  • mask_prob (float) – Probability for random masking in dynamics

  • training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”

Returns:

Dictionary containing all losses for backpropagation

Return type:

Dict[str, Tensor]

encode_video(video)[source]#

Encode video to discrete tokens.

Parameters:

video (Tensor) – (B, C, T, H, W)

Returns:

video_tokens – (B, T, H*W)

infer_actions(frames)[source]#

Infer latent actions from a sequence of frames.

Parameters:

frames (Tensor) – (B, C, T, H, W) video frames

Returns:

latent_actions – (B, T-1) inferred latent action indices

generate(prompt_frame, num_frames=16, actions=None, use_maskgit=True)[source]#

Generate video frames given a prompt frame and actions.

Parameters:
  • prompt_frame (Tensor) – (B, C, H, W) initial frame

  • num_frames (int) – Total number of frames to generate

  • actions (Tensor | None) – (B, num_frames-1) latent action indices, or None for random

  • use_maskgit (bool) – Whether to use MaskGIT sampling

Returns:

generated_video – (B, C, num_frames, H, W)

play(current_frame, action, current_frames=None)[source]#

Play one step: generate the next frame given the current frame and a latent action.

Parameters:
  • current_frame (Tensor) – (B, C, H, W) current frame

  • action (Tensor) – (B,) latent action indices

  • current_frames (Tensor | None) – (B, C, T, H, W) history frames, or None for first frame

Returns:

next_frame – (B, C, H, W)

get_num_parameters()[source]#

Return total number of parameters.

Return type:

int

world_models.models.genie.create_genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, action_vocab_size=8, action_embedding_dim=32, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#

Factory function to create a Genie model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

world_models.models.genie.create_genie_small(num_frames=16, image_size=64, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#

Create a smaller Genie model for development/testing.

Parameters:
  • num_frames (int)

  • image_size (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

world_models.models.genie.create_genie_large(num_frames=16, image_size=64, use_bfloat16=True, action_pooling='mean', window_attention_heads=1)[source]#

Create a Genie model at approximately the paper's full 11B-parameter scale.

Parameters:
  • num_frames (int)

  • image_size (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

class world_models.models.latent_action_model.LatentActionModel(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#

Bases: Module

Latent Action Model (LAM) for unsupervised action learning.

Learns discrete latent actions from unlabeled video frames using a VQ-VAE based objective. The model infers latent actions between frames that encode the most meaningful changes for future frame prediction.

Based on the Genie paper: learns actions from Internet videos without action labels.

Components:
  • Encoder: takes all previous frames x_{1:t} and the next frame x_{t+1} and outputs latent actions.

  • Decoder: takes previous frames x_{1:t-1} and latent actions a_{1:t-1} and predicts the next frame x_t.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

encode(x_prev, x_next)[source]#

Encode frames to latent actions.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W)

  • x_next (Tensor) – Next frame (B, C, H, W)

Returns:
  • latent_actions – discrete latent action indices (B, T)

  • z_q – quantized embeddings (B, T, embedding_dim)

decode(x_prev, z_q)[source]#

Decode latent actions to predict next frame.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W); all frames except the first are masked

  • z_q (Tensor) – Quantized action embeddings (B, T-1, embedding_dim)

Returns:

predicted_next_frame – (B, C, H, W)

forward(x_prev, x_next)[source]#

Full forward pass: encode to get actions, decode to reconstruct.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W)

  • x_next (Tensor) – Next frame (B, C, H, W)

Returns:

Dictionary with losses and outputs

Return type:

Dict[str, Tensor]

world_models.models.latent_action_model.create_latent_action_model(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, action_pooling='mean', window_attention_heads=1)[source]#

Factory function to create a Latent Action Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

LatentActionModel

class world_models.models.dynamics_model.MaskGITSampler(num_steps=25, temperature=2.0, mask_schedule='cosine')[source]#

Bases: object

MaskGIT sampling for token-based video generation.

Uses iterative refinement with a mask schedule to progressively reveal tokens during generation.

Parameters:
  • num_steps (int)

  • temperature (float)

  • mask_schedule (str)

get_mask_prob(step)[source]#

Get mask probability for given step.

Parameters:

step (int)

Return type:

float

sample(logits, mask, step)[source]#

Sample tokens from logits with mask.

Parameters:
  • logits (Tensor) – (B, T, vocab_size)

  • mask (Tensor) – (B, T) - 1 for tokens to predict, 0 for fixed tokens

  • step (int) – Current step in refinement

Returns:
  • sampled_tokens – (B, T)

  • new_mask – (B, T)
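Below is a sketch of one MaskGIT refinement step for a single sequence of N tokens. It follows the recipe of the MaskGIT paper (Chang et al., 2022) with the cosine mask schedule γ(r) = cos(πr / 2). Several details are assumptions, and the function names are hypothetical:

- How MaskGITSampler applies `temperature`; here it simply divides the logits.
- How the number of still-masked tokens is rounded.
- That at least one token is revealed per step.

```python
import math

import numpy as np


def cosine_mask_ratio(step, num_steps):
    """Fraction of the N tokens that stays masked after 0-based step `step`."""
    return math.cos(math.pi / 2 * (step + 1) / num_steps)


def maskgit_step(logits, mask, step, num_steps, temperature=1.0, rng=None):
    """One refinement step for a single sequence.

    logits: (N, V) model outputs; mask: (N,) bool, True = token still to predict.
    Returns (sampled, new_mask); `sampled` is only meaningful where mask is True.
    """
    rng = np.random.default_rng() if rng is None else rng
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    sampled = np.array([rng.choice(probs.shape[1], p=p) for p in probs])
    # Confidence = probability of the sampled token; fixed tokens are never re-masked.
    confidence = np.where(mask, probs[np.arange(len(sampled)), sampled], np.inf)
    num_masked = int(mask.sum())
    # Re-mask the least confident predictions, but always reveal at least one token.
    target = math.floor(cosine_mask_ratio(step, num_steps) * mask.size)
    keep_masked = max(0, min(num_masked - 1, target))
    new_mask = np.zeros_like(mask)
    new_mask[np.argsort(confidence)[:keep_masked]] = True
    return sampled, new_mask


rng = np.random.default_rng(0)
N, V, steps = 16, 32, 4
tokens = np.zeros(N, dtype=int)
mask = np.ones(N, dtype=bool)
history = []
for step in range(steps):
    logits = rng.standard_normal((N, V))  # stand-in for the dynamics model's output
    sampled, new_mask = maskgit_step(logits, mask, step, steps, rng=rng)
    revealed = mask & ~new_mask
    tokens[revealed] = sampled[revealed]  # commit newly revealed tokens
    mask = new_mask
    history.append(int(mask.sum()))
print(history)  # [14, 11, 6, 0]
```

The cosine schedule keeps most tokens masked early and reveals them quickly at the end. The final step always reaches a ratio of zero, so every token is committed after num_steps iterations.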

class world_models.models.dynamics_model.DynamicsModel(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: Module

Dynamics Model for action-controllable video generation.

A decoder-only transformer that predicts future frame tokens given past frame tokens and latent actions. Uses MaskGIT for training and sampling.

Based on the Genie paper: trained with a cross-entropy loss on randomly masked tokens, and sampled with MaskGIT iterative refinement at inference.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

forward(video_tokens, actions, mask_prob=0.0)[source]#

Forward pass for training.

Parameters:
  • video_tokens (Tensor) – (B, T, H*W) - token indices for frames 1 to T

  • actions (Tensor) – (B, T) - latent action indices for frames 1 to T

  • mask_prob (float) – Probability of masking input tokens (Bernoulli 0.5-1.0)

Returns:

logits – (B, T, H*W, vocab_size)

sample(prompt_tokens, prompt_actions, num_frames, sampler=None)[source]#

Sample future frames using MaskGIT.

Parameters:
  • prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens

  • prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames

  • num_frames (int) – Total number of frames to generate

  • sampler (MaskGITSampler | None) – MaskGIT sampler instance

Returns:

generated_tokens – (B, num_frames, N)

autoregressive_sample(prompt_tokens, prompt_actions, num_frames, temperature=1.0)[source]#

Simple autoregressive sampling (token by token).

Parameters:
  • prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens

  • prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames

  • num_frames (int) – Total number of frames to generate

  • temperature (float) – Sampling temperature

Returns:

generated_tokens – (B, num_frames, N)

world_models.models.dynamics_model.create_dynamics_model(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4)[source]#

Factory function to create a Dynamics Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

Return type:

DynamicsModel

Diffusion and DIAMOND components#

Key classes: DDPM, DiT, DiffusionUNet, EDMPreconditioner, EulerSampler, RewardTerminationModel, and ActorCriticNetwork.

Diffusion sub-module - Diffusion model components for world models.

Exported Components:
  • DiT: Diffusion Transformer model

  • PatchEmbed: Image patch embedding

  • PatchUnEmbed: Patch unembedding (decode tokens to image)

  • DDPM: Denoising Diffusion Probabilistic Model implementation

  • ActorCriticNetwork: DIAMOND actor-critic network

  • RewardTerminationModel: Reward/termination prediction model

  • sinusoidal_time_embedding: Time embedding for diffusion models

class world_models.models.diffusion.DDPM.DDPM(timesteps, beta_start, beta_end, device)[source]#

Bases: object

Utility class implementing forward and reverse DDPM diffusion steps.

Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).

q_sample(x_start, t, noise=None)[source]#
p_sample(model, x_t, t)[source]#
sample(model, n, img_size, channels)[source]#
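The constructor takes `timesteps`, `beta_start`, and `beta_end`. Assuming these define the usual linear beta schedule from Ho et al. (2020), the closed-form forward process that `q_sample` is documented to provide can be sketched as follows. The function names are illustrative, not the class's API.

```python
import torch


def linear_alpha_bars(timesteps, beta_start, beta_end):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule."""
    betas = torch.linspace(beta_start, beta_end, timesteps)
    return torch.cumprod(1.0 - betas, dim=0)


def q_sample_sketch(x_start, t, alpha_bars, noise=None):
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)."""
    if noise is None:
        noise = torch.randn_like(x_start)
    # Broadcast the per-sample coefficient over all non-batch dimensions.
    a = alpha_bars[t].view(-1, *([1] * (x_start.dim() - 1)))
    return a.sqrt() * x_start + (1.0 - a).sqrt() * noise


alpha_bars = linear_alpha_bars(1000, 1e-4, 0.02)
x0 = torch.ones(2, 3, 32, 32)
t = torch.tensor([0, 999])
xt = q_sample_sketch(x0, t, alpha_bars, noise=torch.zeros_like(x0))
```

At t=0 the sample is almost exactly x_0; by t=999 almost none of the signal remains.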
world_models.models.diffusion.DiT.sinusoidal_time_embedding(timesteps, dim)[source]#

Create sinusoidal timestep embeddings for diffusion conditioning.

This function generates positional-style embeddings for diffusion timesteps, following the same pattern as transformer positional encodings. The embeddings encode the noise level (t) and are used to condition the diffusion model.

Math:

embedding[t] = [sin(t/10000^(2i/d)), cos(t/10000^(2i/d))] for i in [0, d/2)

Parameters:
  • timesteps – Tensor of integer timesteps, shape (B,) or (B, 1)

  • dim – Embedding dimension (must be even)

Returns:

Tensor of shape (B, dim) with sinusoidal embeddings

Usage with DiT:

t = torch.tensor([0, 500, 1000])  # Timesteps
emb = sinusoidal_time_embedding(t, dim=256)  # (3, 256)

# Condition the model:
# - Add the timestep embedding to the MLP input
# - Use AdaLN for adaptive normalization
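The formula above translates directly to PyTorch. This is a sketch, not the library's source. It assumes the sine half and cosine half are concatenated rather than interleaved, and the name `sinusoidal_time_embedding_sketch` is hypothetical.

```python
import math

import torch


def sinusoidal_time_embedding_sketch(timesteps, dim):
    """Return (B, dim) embeddings [sin(t * w_i), cos(t * w_i)] with w_i = 10000^(-2i/dim)."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    t = timesteps.reshape(-1).float()  # accepts (B,) or (B, 1)
    half = dim // 2
    i = torch.arange(half, dtype=torch.float32)
    freqs = torch.exp(-math.log(10000.0) * 2.0 * i / dim)  # 1 / 10000^(2i/d)
    args = t[:, None] * freqs[None, :]  # (B, dim/2)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)


emb = sinusoidal_time_embedding_sketch(torch.tensor([0, 500, 1000]), dim=256)
```

At t=0 every sine component is 0 and every cosine component is 1, which makes a quick sanity check.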

class world_models.models.diffusion.DiT.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]#

Bases: Module

Patchify an image into a sequence of learnable patch tokens.

Used in Vision Transformers (ViT) and DiT to convert 2D images into sequences of token embeddings that can be processed by transformers.

Process:
  1. Conv2d with kernel_size=stride=patch_size extracts non-overlapping patches

  2. Each patch is projected to embed_dim via linear layer (Conv2d)

  3. Learnable positional embeddings are added for spatial information

Input: (B, C, H, W) images
Output: (B, N, embed_dim), where N = (H/patch_size) * (W/patch_size)

Parameters:
  • img_size – Image size (assumes square), e.g., 32 for CIFAR

  • patch_size – Size of each patch (typically 4, 8, or 16)

  • in_channels – Number of input channels (3 for RGB)

  • embed_dim – Output dimension for each patch token

Usage with DiT:

patch_embed = PatchEmbed(img_size=32, patch_size=4, in_channels=3, embed_dim=256)
tokens = patch_embed(images)  # (B, 64, 256) for a 32x32 image with patch_size=4

forward(x)[source]#
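The three-step process above fits in a few lines of PyTorch. Here is a minimal sketch. The class name `PatchEmbedSketch` is hypothetical, and zero-initialized positional embeddings are an assumption.

```python
import torch
from torch import nn


class PatchEmbedSketch(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, embed_dim):
        super().__init__()
        # kernel_size == stride == patch_size: each output position sees exactly one patch.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        num_patches = (img_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

    def forward(self, x):
        x = self.proj(x)  # (B, embed_dim, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)  # (B, N, embed_dim)
        return x + self.pos_embed  # learnable positional information


patch_embed = PatchEmbedSketch(img_size=32, patch_size=4, in_channels=3, embed_dim=256)
tokens = patch_embed(torch.randn(8, 3, 32, 32))
```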
class world_models.models.diffusion.DiT.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]#

Bases: Module

Reconstruct image-like tensors from patch-token sequences.

The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.

forward(x)[source]#
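Here is a sketch of that inverse mapping. It assumes a square patch grid and a transposed convolution with kernel_size == stride == patch_size, and `PatchUnEmbedSketch` is a hypothetical name.

```python
import torch
from torch import nn


class PatchUnEmbedSketch(nn.Module):
    def __init__(self, img_size, patch_size, embed_dim, out_channels):
        super().__init__()
        self.grid = img_size // patch_size
        self.proj = nn.ConvTranspose2d(embed_dim, out_channels, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        b, n, d = x.shape
        # (B, N, D) -> (B, D, grid, grid): undo the flatten + transpose of patch embedding.
        x = x.transpose(1, 2).reshape(b, d, self.grid, self.grid)
        return self.proj(x)  # (B, out_channels, img_size, img_size)


unembed = PatchUnEmbedSketch(img_size=32, patch_size=4, embed_dim=256, out_channels=3)
images = unembed(torch.randn(8, 64, 256))
```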
class world_models.models.diffusion.DiT.TransformerBlock(d_model, n_heads, mlp_ratio, drop, t_dim)[source]#

Bases: Module

Conditioned transformer block used inside the DiT backbone.

Each block applies adaptive layer-normalized self-attention and MLP residual updates conditioned on timestep embeddings.

forward(x, t_emb)[source]#
class world_models.models.diffusion.DiT.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]#

Bases: Module

Diffusion Transformer model for image denoising and generation.

The module maps noisy images and timesteps to predicted noise residuals and also provides a classmethod training entrypoint for common datasets.

forward(x, t)[source]#
classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]#
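DiT.forward(x, t) returns predicted noise, so a training step typically regresses that prediction onto the noise used to corrupt the input. The sketch below assumes the standard DDPM epsilon-prediction MSE objective and a linear beta schedule. `diffusion_loss_sketch` is a hypothetical helper, and any callable with the `model(x, t)` signature can stand in for DiT.

```python
import torch
import torch.nn.functional as F


def diffusion_loss_sketch(model, x0, alpha_bars, generator=None):
    """Epsilon-prediction loss || eps - model(x_t, t) ||^2 with t drawn uniformly."""
    b = x0.shape[0]
    t = torch.randint(0, alpha_bars.numel(), (b,), generator=generator)
    noise = torch.randn(x0.shape, generator=generator)
    a = alpha_bars[t].view(b, *([1] * (x0.dim() - 1)))
    x_t = a.sqrt() * x0 + (1.0 - a).sqrt() * noise
    return F.mse_loss(model(x_t, t), noise)


alpha_bars = torch.cumprod(1.0 - torch.linspace(1e-4, 0.02, 1000), dim=0)
zero_model = lambda x, t: torch.zeros_like(x)  # stand-in for a DiT instance
loss = diffusion_loss_sketch(zero_model, torch.randn(4, 3, 32, 32), alpha_bars, torch.Generator().manual_seed(0))
```

A model that always predicts zero noise scores about 1.0, the variance of the target noise. This gives a useful baseline when reading training curves.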
class world_models.models.diffusion.diamond_diffusion.AdaptiveGroupNorm(num_groups, num_channels, cond_dim)[source]#

Bases: Module

Adaptive Group Normalization that conditions on actions and diffusion time.

Parameters:
  • num_groups (int)

  • num_channels (int)

  • cond_dim (int)

forward(x, cond)[source]#
Parameters:
  • x (Tensor) – Input tensor [B, C, H, W]

  • cond (Tensor) – Conditioning tensor [B, cond_dim]

Return type:

Tensor
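One common way to implement this kind of conditioning is a parameter-free GroupNorm followed by a per-channel scale and shift predicted from the conditioning vector. The sketch below follows that pattern. `AdaGNSketch` is hypothetical, and the `1 + scale` form is one convention among several, so it may not match AdaptiveGroupNorm exactly.

```python
import torch
from torch import nn


class AdaGNSketch(nn.Module):
    def __init__(self, num_groups, num_channels, cond_dim):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, num_channels, affine=False)
        self.to_scale_shift = nn.Linear(cond_dim, 2 * num_channels)

    def forward(self, x, cond):
        # x: [B, C, H, W], cond: [B, cond_dim] (e.g. action + diffusion-time embedding)
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=-1)  # each [B, C]
        return self.norm(x) * (1 + scale[:, :, None, None]) + shift[:, :, None, None]


ada = AdaGNSketch(num_groups=8, num_channels=64, cond_dim=256)
out = ada(torch.randn(2, 64, 16, 16), torch.randn(2, 256))
```

With the conditioning projection zeroed, the block reduces to plain GroupNorm. This is why such projections are often zero-initialized.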

class world_models.models.diffusion.diamond_diffusion.ResBlock(in_channels, out_channels, cond_dim, dropout=0.0)[source]#

Bases: Module

Residual block with adaptive group normalization.

Parameters:
  • in_channels (int)

  • out_channels (int)

  • cond_dim (int)

  • dropout (float)

forward(x, cond)[source]#
Parameters:
  • x (Tensor)

  • cond (Tensor)

Return type:

Tensor

class world_models.models.diffusion.diamond_diffusion.AttentionBlock(channels, cond_dim)[source]#

Bases: Module

Self-attention block for U-Net.

Parameters:
  • channels (int)

  • cond_dim (int)

forward(x, cond)[source]#
Parameters:
  • x (Tensor)

  • cond (Tensor)

Return type:

Tensor

class world_models.models.diffusion.diamond_diffusion.TimestepEmbedding(dim, freq_dim=256)[source]#

Bases: Module

Sinusoidal timestep embedding.

Parameters:
  • dim (int)

  • freq_dim (int)

forward(t)[source]#
Parameters:

t (Tensor)

Return type:

Tensor

class world_models.models.diffusion.diamond_diffusion.DownBlock(in_channels, out_channels, cond_dim, num_res_blocks=2, attention=False)[source]#

Bases: Module

Downsampling block for U-Net encoder.

Parameters:
  • in_channels (int)

  • out_channels (int)

  • cond_dim (int)

  • num_res_blocks (int)

  • attention (bool)

forward(x, cond)[source]#
Parameters:
  • x (Tensor)

  • cond (Tensor)

Return type:

Tensor

class world_models.models.diffusion.diamond_diffusion.UpBlock(in_channels, out_channels, cond_dim, num_res_blocks=2, attention=False)[source]#

Bases: Module

Upsampling block for U-Net decoder.

Parameters:
  • in_channels (int)

  • out_channels (int)

  • cond_dim (int)

  • num_res_blocks (int)

  • attention (bool)

forward(x, cond)[source]#
Parameters:
  • x (Tensor)

  • cond (Tensor)

Return type:

Tensor

class world_models.models.diffusion.diamond_diffusion.DiffusionUNet(obs_channels=3, num_conditioning_frames=4, base_channels=64, channel_multipliers=(1, 1, 1, 1), num_res_blocks=2, cond_dim=256, action_dim=18)[source]#

Bases: Module

U-Net architecture for EDM diffusion world model. Uses frame stacking for observation conditioning and adaptive group norm for action conditioning.

Parameters:
  • obs_channels (int)

  • num_conditioning_frames (int)

  • base_channels (int)

  • channel_multipliers (Tuple[int, ...])

  • num_res_blocks (int)

  • cond_dim (int)

  • action_dim (int)

forward(x, t, obs_history, actions)[source]#

Forward pass of the diffusion model.

Parameters:
  • x (Tensor) – Noised observation at timestep t [B, C, H, W]

  • t (Tensor) – Diffusion timestep [B]

  • obs_history (Tensor) – Past observations for conditioning [B, L, C, H, W]

  • actions (Tensor) – Past actions [B, L]

Returns:

Predicted clean observation [B, C, H, W]

Return type:

Tensor
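"Frame stacking" usually means the conditioning frames are concatenated with the noised frame along the channel axis before the first convolution. Here is a shape-level sketch under that assumption. `stack_conditioning_frames` is a hypothetical helper, and the real module may order channels differently.

```python
import torch


def stack_conditioning_frames(x, obs_history):
    """x: [B, C, H, W]; obs_history: [B, L, C, H, W] -> [B, (L + 1) * C, H, W]."""
    b, l, c, h, w = obs_history.shape
    history = obs_history.reshape(b, l * c, h, w)  # fold time into channels
    return torch.cat([x, history], dim=1)


x = torch.randn(2, 3, 64, 64)
obs_history = torch.randn(2, 4, 3, 64, 64)  # num_conditioning_frames=4
stacked = stack_conditioning_frames(x, obs_history)
```

With obs_channels=3 and num_conditioning_frames=4, the first convolution would see 15 input channels under this layout.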

class world_models.models.diffusion.diamond_diffusion.EDMPreconditioner(sigma_data=0.5, p_mean=-0.4, p_std=1.2)[source]#

Bases: object

EDM preconditioner following Karras et al. (2022).

Parameters:
  • sigma_data (float)

  • p_mean (float)

  • p_std (float)

get_preconditioners(sigma)[source]#

Compute EDM preconditioners for given noise levels.

Returns:

Dictionary with c_skip, c_out, c_in, c_noise

Parameters:

sigma (Tensor)

Return type:

dict

sample_noise_level(batch_size, device)[source]#

Sample noise level from log-normal distribution.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tensor

denoise(model, x, sigma, **kwargs)[source]#

Apply EDM denoising with preconditioners.

Parameters:
  • model – Diffusion model

  • x (Tensor) – Noised input [B, C, H, W]

  • sigma (Tensor) – Noise level [B]

  • **kwargs – Additional conditioning (obs_history, actions)

Returns:

Denoised prediction [B, C, H, W]

Return type:

Tensor
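For reference, Karras et al. (2022) define the EDM preconditioners below. They sample noise levels log-normally, with ln(sigma) ~ N(p_mean, p_std^2), and define the denoiser as D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise). This sketch follows the paper's formulas. Whether EDMPreconditioner matches them exactly is an assumption, and the function names are hypothetical.

```python
import torch


def edm_preconditioners(sigma, sigma_data=0.5):
    """EDM preconditioning terms from Karras et al. (2022)."""
    denom = sigma ** 2 + sigma_data ** 2
    return {
        "c_skip": sigma_data ** 2 / denom,
        "c_out": sigma * sigma_data / denom.sqrt(),
        "c_in": 1.0 / denom.sqrt(),
        "c_noise": sigma.log() / 4.0,
    }


def sample_noise_level(batch_size, p_mean=-0.4, p_std=1.2, generator=None):
    """ln(sigma) ~ N(p_mean, p_std^2)."""
    return (p_mean + p_std * torch.randn(batch_size, generator=generator)).exp()


def edm_denoise(model, x, sigma, sigma_data=0.5, **cond):
    c = edm_preconditioners(sigma, sigma_data)
    view = lambda v: v.view(-1, *([1] * (x.dim() - 1)))  # broadcast [B] over [B, C, H, W]
    f_out = model(view(c["c_in"]) * x, c["c_noise"], **cond)
    return view(c["c_skip"]) * x + view(c["c_out"]) * f_out
```

At sigma = sigma_data the skip and network paths are weighted equally (c_skip = 0.5). As sigma approaches 0, c_skip approaches 1 and the denoiser passes the input through almost unchanged.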

class world_models.models.diffusion.diamond_diffusion.EulerSampler(sigma_min=0.002, sigma_max=80.0, rho=7, num_steps=3, edm_precond=None)[source]#

Bases: object

Euler method sampler for reverse diffusion.

Parameters:
  • sigma_min (float)

  • sigma_max (float)

  • rho (int)

  • num_steps (int)

  • edm_precond (EDMPreconditioner | None)

sample(model, shape, device, obs_history=None, actions=None)[source]#

Generate samples using Euler method.

Parameters:
  • model (Module) – Diffusion model

  • shape (Tuple[int, ...]) – Output shape [B, C, H, W]

  • device (device) – Device to run on

  • obs_history (Tensor | None) – Conditioning observations [B, L, C, H, W]

  • actions (Tensor | None) – Conditioning actions [B, L]

Returns:

Generated samples [B, C, H, W]

Return type:

Tensor
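The sigma_min, sigma_max, and rho arguments suggest the Karras et al. (2022) noise schedule. Assuming EulerSampler uses it, the schedule and a deterministic Euler integration look like this. `karras_sigmas` and `euler_sample_sketch` are hypothetical names, and `denoiser(x, sigma)` stands for a preconditioned denoiser such as EDMPreconditioner.denoise bound to a model.

```python
import torch


def karras_sigmas(num_steps=3, sigma_min=0.002, sigma_max=80.0, rho=7):
    """sigma_i = (s_max^(1/rho) + i/(N-1) * (s_min^(1/rho) - s_max^(1/rho)))^rho, then a final 0."""
    ramp = torch.linspace(0, 1, num_steps)
    lo, hi = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    sigmas = (hi + ramp * (lo - hi)) ** rho
    return torch.cat([sigmas, torch.zeros(1)])


@torch.no_grad()
def euler_sample_sketch(denoiser, shape, sigmas, generator=None):
    x = torch.randn(shape, generator=generator) * sigmas[0]
    for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoiser(x, s_cur)) / s_cur  # dx/dsigma of the probability-flow ODE
        x = x + (s_next - s_cur) * d
    return x
```

Because the last step goes to sigma = 0, the final Euler update returns the denoiser's prediction. This is one reason a handful of steps (num_steps=3 by default) can still give clean samples.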

class world_models.models.diffusion.reward_termination.ConvBlock(in_channels, out_channels, cond_dim, stride=2)[source]#

Bases: Module

Convolutional block with adaptive group normalization.

Parameters:
  • in_channels (int)

  • out_channels (int)

  • cond_dim (int)

  • stride (int)

forward(x, cond)[source]#
Parameters:
  • x (Tensor)

  • cond (Tensor)

Return type:

Tensor

class world_models.models.diffusion.reward_termination.RewardTerminationModel(obs_channels=3, action_dim=18, channels=(32, 32, 32, 32), lstm_dim=512, cond_dim=128)[source]#

Bases: Module

Reward and termination prediction model. CNN + LSTM architecture following DIAMOND paper specifications.

Parameters:
  • obs_channels (int) – Number of observation channels (3 for RGB)

  • action_dim (int) – Number of possible actions

  • channels (Tuple[int, ...]) – List of channel sizes for conv blocks

  • lstm_dim (int) – LSTM hidden dimension

  • cond_dim (int) – Conditioning dimension for adaptive norm

forward(obs, actions, hidden_state=None)[source]#

Forward pass of reward/termination model.

Parameters:
  • obs (Tensor) – Observations [B, T, C, H, W]

  • actions (Tensor) – Actions [B, T]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

Returns:
  • reward_logits – Reward predictions [B, T, 3] (classes for rewards -1, 0, 1)

  • termination_logits – Termination predictions [B, T, 2]

  • hidden_state – Updated (h, c) hidden states

predict(obs, actions, hidden_state=None)[source]#

Predict reward and termination for a single step.

Parameters:
  • obs (Tensor) – Single observation [B, C, H, W]

  • actions (Tensor) – Single action [B]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

Returns:
  • reward – Predicted reward classes as a tensor (values -1, 0, 1)

  • terminated – Predicted termination flags (bool tensor)

  • hidden_state – Updated (h, c) hidden states

init_hidden(batch_size, device)[source]#

Initialize LSTM hidden states.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

class world_models.models.diffusion.reward_termination.RewardTerminationLoss[source]#

Bases: Module

Loss function for reward and termination prediction.

forward(reward_logits, termination_logits, rewards, terminated)[source]#

Compute loss for reward and termination predictions.

Parameters:
  • reward_logits (Tensor) – [B, T, 3]

  • termination_logits (Tensor) – [B, T, 2]

  • rewards (Tensor) – Rewards as class indices [B, T] (values -1, 0, 1 mapped to 0, 1, 2)

  • terminated (Tensor) – Termination flags [B, T]

Returns:

total_loss, reward_loss, termination_loss

Return type:

Tuple[Tensor, Tensor, Tensor]
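RewardTerminationLoss expects rewards as class indices. If raw environment rewards are not already in {-1, 0, 1}, one plausible preprocessing step is to take the sign and shift by one. This is an assumption, in the spirit of DIAMOND's reward clipping. The loss can then be ordinary cross-entropy over both heads. `reward_to_class` and `reward_termination_loss_sketch` are hypothetical helpers. Unlike the documented method, the sketch takes raw rewards, and summing the two terms into total_loss is an assumption.

```python
import torch
import torch.nn.functional as F


def reward_to_class(rewards):
    """Map rewards to classes: negative -> 0, zero -> 1, positive -> 2."""
    return (torch.sign(rewards) + 1).long()


def reward_termination_loss_sketch(reward_logits, termination_logits, rewards, terminated):
    # reward_logits: [B, T, 3]; termination_logits: [B, T, 2]; rewards, terminated: [B, T]
    reward_loss = F.cross_entropy(reward_logits.flatten(0, 1), reward_to_class(rewards).flatten())
    termination_loss = F.cross_entropy(termination_logits.flatten(0, 1), terminated.long().flatten())
    return reward_loss + termination_loss, reward_loss, termination_loss
```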

class world_models.models.diffusion.actor_critic.ActorCriticNetwork(obs_channels=3, action_dim=18, channels=(32, 32, 64, 64), lstm_dim=512)[source]#

Bases: Module

Actor-Critic network for DIAMOND RL training. Shared CNN-LSTM trunk with separate policy and value heads.

Parameters:
  • obs_channels (int)

  • action_dim (int)

  • channels (Tuple[int, ...])

  • lstm_dim (int)

forward(obs, hidden_state=None)[source]#

Forward pass of actor-critic network.

Parameters:
  • obs (Tensor) – Observations [B, T, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

Returns:
  • policy_logits – [B, T, action_dim]

  • values – [B, T, 1]

  • hidden_state – (h, c)

get_action(obs, hidden_state=None, deterministic=False)[source]#

Get action from a single observation.

Parameters:
  • obs (Tensor) – Single observation [B, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

  • deterministic (bool) – If True, take argmax; else sample

Returns:
  • action – Selected action [B]

  • hidden_state – (h, c)

get_actions(obs, hidden_state=None, deterministic=False)[source]#

Batched version of get_action.

Parameters:
  • obs (Tensor) – Tensor of shape [B, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state tuple matching batch size

  • deterministic (bool) – If True, take argmax; else sample from policy

Returns:
  • actions – LongTensor of shape [B]

  • hidden_state – Updated LSTM hidden state tuple

get_value(obs, hidden_state=None)[source]#

Get value for a single observation.

Parameters:
  • obs (Tensor)

  • hidden_state (Tuple[Tensor, Tensor] | None)

Return type:

Tuple[Tensor, Tuple[Tensor, Tensor] | None]

init_hidden(batch_size, device)[source]#

Initialize LSTM hidden states.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

get_hidden_size()[source]#

Get LSTM hidden size.

Return type:

int

class world_models.models.diffusion.actor_critic.RLLoss(discount_factor=0.985, lambda_returns=0.95, entropy_weight=0.001)[source]#

Bases: Module

RL loss functions for DIAMOND. Implements REINFORCE with value baseline and λ-returns.

Parameters:
  • discount_factor (float)

  • lambda_returns (float)

  • entropy_weight (float)

compute_lambda_returns(rewards, values, dones)[source]#

Compute λ-returns.

Parameters:
  • rewards (Tensor) – [B, T]

  • values (Tensor) – [B, T+1]

  • dones (Tensor) – [B, T]

Returns:

lambda_returns – [B, T]
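A common recursive form of the λ-return bootstraps from the final value V_T = values[:, T] and cuts the bootstrap at episode ends: G_t = r_t + γ(1 - d_t)[(1 - λ)V_{t+1} + λG_{t+1}], with G_T = V_T. The sketch below implements that recursion with the class defaults. Whether compute_lambda_returns handles dones exactly this way is an assumption, and the function name is hypothetical.

```python
import torch


def lambda_returns_sketch(rewards, values, dones, gamma=0.985, lam=0.95):
    """rewards, dones: [B, T]; values: [B, T+1] -> returns: [B, T]."""
    T = rewards.shape[1]
    not_done = 1.0 - dones.float()
    returns = torch.empty_like(rewards)
    next_return = values[:, T]  # G_T = V_T
    for t in reversed(range(T)):
        next_return = rewards[:, t] + gamma * not_done[:, t] * (
            (1 - lam) * values[:, t + 1] + lam * next_return
        )
        returns[:, t] = next_return
    return returns
```

With λ = 0 this reduces to one-step TD targets. With λ = 1 it becomes the discounted return bootstrapped from V_T.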

policy_loss(policy_logits, actions, lambda_returns, values)[source]#

Compute policy loss with REINFORCE and entropy regularization.

Parameters:
  • policy_logits (Tensor) – [B, T, A]

  • actions (Tensor) – [B, T]

  • lambda_returns (Tensor) – [B, T]

  • values (Tensor) – [B, T+1]

Returns:

policy_loss – scalar

value_loss(values, lambda_returns)[source]#

Compute value loss (MSE between value and lambda returns).

Parameters:
  • values (Tensor)

  • lambda_returns (Tensor)

Return type:

Tensor
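Here is a sketch of REINFORCE with a value baseline and an entropy bonus, using the shapes documented above. It makes several assumptions:
  • The advantage is the λ-return minus the value for the same step (values[:, :-1]).
  • The advantage is detached, so the policy gradient does not flow into the critic.
  • Losses are averaged over batch and time.
  • value_loss receives values of shape [B, T+1].

Function names are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical


def policy_loss_sketch(policy_logits, actions, lambda_returns, values, entropy_weight=0.001):
    # policy_logits: [B, T, A]; actions, lambda_returns: [B, T]; values: [B, T+1]
    dist = Categorical(logits=policy_logits)
    advantage = (lambda_returns - values[:, :-1]).detach()
    reinforce = -(dist.log_prob(actions) * advantage).mean()
    return reinforce - entropy_weight * dist.entropy().mean()


def value_loss_sketch(values, lambda_returns):
    # Regress V(s_t) onto fixed lambda-return targets.
    return F.mse_loss(values[:, :-1], lambda_returns.detach())
```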

Vision, tokenization, and layers#

Key classes: ConvEncoder, ConvDecoder, DenseDecoder, ActionDecoder, CNNEncoder, CNNDecoder, IRISEncoder, IRISDecoder, DiscreteAutoencoder, VectorQuantizer, VectorQuantizerEMA, VideoTokenizer, MultiHeadSelfAttention, and STTransformer.

class world_models.vision.dreamer_encoder.ConvEncoder(input_shape, embed_size, activation, depth=32)[source]#

Bases: Module

Convolutional observation encoder used by Dreamer world models.

This encoder transforms raw image observations (typically RGB frames from environments like Atari or DeepMind Control) into compact latent embeddings that can be processed by the RSSM (Recurrent State-Space Model).

Architecture:

Input: (B, C, H, W) raw images, values in [-0.5, 0.5]
Process: 4 convolutional layers with stride 2, each roughly halving the spatial dimensions
Output: (B, embed_size) compact representation

The encoder uses a depth doubling pattern: 32 -> 64 -> 128 -> 256 channels. After convolutions, a fully connected layer projects from 1024 features to the desired embedding size.

Usage with Dreamer:

encoder = ConvEncoder(
    input_shape=(3, 64, 64),  # RGB 64x64 images
    embed_size=256,  # RSSM observation embedding size
    activation='relu',  # Activation function
)
obs_embedding = encoder(observation)  # (B, 256)

Parameters:
  • input_shape – Tuple (C, H, W) for input images, typically (3, 64, 64)

  • embed_size – Output embedding dimension, typically 256 or 1024

  • activation – Activation function name (‘relu’, ‘elu’, ‘tanh’, etc.)

  • depth – Base channel depth for first layer (default 32)

forward(inputs)[source]#
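The 1024-feature figure follows from the convolution arithmetic if each layer uses kernel_size=4, stride=2, and no padding, as in the original Dreamer encoder. That layer configuration is an assumption about this implementation. Each layer's output size is floor((H + 2*padding - kernel) / stride) + 1:

```python
def conv_out(size, kernel=4, stride=2, padding=0):
    """Spatial size after one Conv2d layer."""
    return (size + 2 * padding - kernel) // stride + 1


def dreamer_encoder_shapes(image_size=64, depth=32, num_layers=4):
    sizes, size = [], image_size
    for _ in range(num_layers):
        size = conv_out(size)
        sizes.append(size)
    channels = depth * 2 ** (num_layers - 1)  # 32 -> 64 -> 128 -> 256
    return sizes, channels * sizes[-1] * sizes[-1]


sizes, flat = dreamer_encoder_shapes()
print(sizes, flat)  # [31, 14, 6, 2] 1024
```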
class world_models.vision.dreamer_decoder.TanhBijector[source]#

Bases: Transform

Bijective tanh transform for squashing Gaussian distributions to [-1, 1].

This transformation is essential for Dreamer’s action policy. Raw neural network outputs are Gaussian distributions over R^n, but actions in continuous control environments are typically bounded in [-1, 1]. The tanh bijector provides:

  1. Bijective mapping: tanh is invertible (with atanh as inverse)

  2. Stable log-det Jacobian: Computable for gradient-based training

  3. Bounded actions: sampled actions always lie within [-1, 1], so no explicit clipping is needed at inference time

Math:

Forward: y = tanh(x)
Inverse: x = atanh(y) = 0.5 * log((1+y)/(1-y))
Log-det: log|dy/dx| = 2*(log(2) - x - softplus(-2x))

Usage with Dreamer ActionDecoder:

dist = TransformedDistribution(
    Normal(mean, std), TanhBijector()
)
action = dist.sample()  # Bounded to [-1, 1]

Reference:

Building a Scalable Deep RL Library by Learning from Mistakes, Haarnoja et al.

property sign#
atanh(x)[source]#
log_abs_det_jacobian(x, y)[source]#
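The log-det expression above is a numerically stable rewrite of log(1 - tanh(x)^2). The check below compares the two in float64. It uses PyTorch's built-in torch.distributions.transforms.TanhTransform to illustrate the same squashed-Gaussian construction. Treating it as interchangeable with TanhBijector is an assumption.

```python
import math

import torch
import torch.nn.functional as F
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform


def tanh_log_abs_det(x):
    # log|d tanh(x)/dx| = 2 * (log 2 - x - softplus(-2x))
    return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))


x = torch.linspace(-5.0, 5.0, 11, dtype=torch.float64)
naive = torch.log(1.0 - torch.tanh(x) ** 2)  # loses precision as |x| grows
stable = tanh_log_abs_det(x)

# A tanh-squashed Gaussian policy; cache_size=1 lets log_prob reuse the pre-tanh sample.
policy = TransformedDistribution(Normal(torch.zeros(4), torch.ones(4)), [TanhTransform(cache_size=1)])
action = policy.sample()
log_prob = policy.log_prob(action)
```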
class world_models.vision.dreamer_decoder.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)[source]#

Bases: Module

Convolutional decoder for reconstructing observations from latent states.

Part of Dreamer’s world model, this decoder reconstructs image observations from the combined stochastic (s) and deterministic (h) RSSM states.

Architecture:

Input: Concatenated [stoch_state, deter_state], shape (B, stoch+deter)
Process: Dense projection + 4 transposed convolutions (upsampling 2x each)
Output: Independent Normal distribution over observation pixels

The decoder mirrors the ConvEncoder’s structure but in reverse (transposed convs instead of regular convs). This creates a symmetric autoencoder where the encoder and decoder can be trained jointly to learn compressed representations.

Output Distribution:

Returns torch.distributions.Independent(Normal(mean, std), len(shape)). This allows computing log_prob(observation) for the reconstruction loss.

Usage in Dreamer world model:

decoder = ConvDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(3, 64, 64),  # RGB images
    activation='relu',
)
obs_dist = decoder(latent_features)  # Returns a distribution
log_prob = obs_dist.log_prob(target_observation)

Training:

The reconstruction loss is -log_prob(observation). This encourages the RSSM to learn states that capture observation information.

forward(features)[source]#
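Wrapping the per-pixel Normal in Independent(..., 3) makes log_prob sum over (C, H, W), giving one log-likelihood per batch element. Here is a sketch of the resulting reconstruction loss. It assumes a unit standard deviation as in the original Dreamer decoder, and `mean` stands in for the decoder's predicted image.

```python
import torch
from torch.distributions import Independent, Normal

B, obs_shape = 4, (3, 64, 64)
mean = torch.zeros(B, *obs_shape)  # stand-in for the decoder's predicted mean image
target = torch.zeros(B, *obs_shape)

obs_dist = Independent(Normal(mean, 1.0), len(obs_shape))
log_prob = obs_dist.log_prob(target)  # shape (B,): summed over C, H, W
recon_loss = -log_prob.mean()
```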
class world_models.vision.dreamer_decoder.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist, num_buckets=255, symlog_range=10.0)[source]#

Bases: Module

MLP decoder for reward/value/discount prediction from latent features.

Part of Dreamer’s world model, this decoder predicts scalar quantities (rewards, values, discount factors) from RSSM latent states.

Architecture:

Input: [stoch_state, deter_state] concatenated, shape (B, stoch+deter)
Process: MLP with configurable layers and hidden units
Output: Predicted quantity with distribution (normal, binary, or raw)

Supports three output types:
  • ‘normal’: Gaussian distribution for regression (rewards, values)

  • ‘binary’: Bernoulli distribution for binary classification (discount)

  • ‘none’: Raw tensor for non-probabilistic outputs

Usage:

reward_decoder = DenseDecoder(
    stoch_size=30, deter_size=200, output_shape=(1,),
    n_layers=2, units=400, activation='elu', dist='normal',
)
reward_dist = reward_decoder(latent_features)
reward_loss = -reward_dist.log_prob(target_reward)

For discount prediction (binary):

discount_decoder = DenseDecoder(
    stoch_size=30, deter_size=200, output_shape=(1,),
    n_layers=2, units=400, activation='elu',
    dist='binary',  # Bernoulli for P(continue)
)

forward(features)[source]#
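Here is a sketch of how the three `dist` modes could wrap the MLP output. It assumes a unit-variance Normal for 'normal' and logits for 'binary'; both are assumptions about this implementation, and `wrap_output` is hypothetical.

```python
import torch
from torch.distributions import Bernoulli, Independent, Normal


def wrap_output(x, dist, event_dims=1):
    """x: raw MLP output of shape (B, *output_shape)."""
    if dist == "normal":
        return Independent(Normal(x, 1.0), event_dims)
    if dist == "binary":
        return Independent(Bernoulli(logits=x), event_dims)
    if dist == "none":
        return x
    raise NotImplementedError(dist)


out = torch.zeros(5, 1)
reward_loss = -wrap_output(out, "normal").log_prob(torch.ones(5, 1)).mean()
discount_loss = -wrap_output(out, "binary").log_prob(torch.ones(5, 1)).mean()
```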
class world_models.vision.dreamer_decoder.SampleDist(dist, samples=100)[source]#

Bases: object

Distribution wrapper that estimates statistics via Monte Carlo sampling.

Provides approximated mean, mode, and entropy helpers for transformed distributions where analytic forms may be inconvenient.

property name#
mean()[source]#
mode()[source]#
entropy()[source]#
sample()[source]#
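The Monte Carlo estimates can be sketched in the style of the original Dreamer helper:
  • mean is the sample average.
  • mode is the sample with the highest log-probability.
  • entropy is the negative mean log-probability.

`SampleDistSketch` is a hypothetical stand-in, and it assumes the wrapped distribution supports rsample and log_prob.

```python
import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform


class SampleDistSketch:
    def __init__(self, dist, samples=100):
        self._dist = dist
        self._samples = samples

    def mean(self):
        return self._dist.rsample((self._samples,)).mean(0)

    def mode(self):
        sample = self._dist.rsample((self._samples,))  # (S, B, A)
        logprob = self._dist.log_prob(sample)  # (S, B)
        best = logprob.argmax(0)  # (B,) index of the most likely sample per batch row
        return sample[best, torch.arange(sample.shape[1])]  # (B, A)

    def entropy(self):
        sample = self._dist.rsample((self._samples,))
        return -self._dist.log_prob(sample).mean(0)

    def sample(self):
        return self._dist.sample()


base = TransformedDistribution(Normal(torch.zeros(2, 3), torch.ones(2, 3)), [TanhTransform(cache_size=1)])
sd = SampleDistSketch(Independent(base, 1))
```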
class world_models.vision.dreamer_decoder.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)[source]#

Bases: Module

Dreamer actor head producing squashed continuous actions from latent features.

Outputs a transformed Gaussian policy with optional deterministic mode and utility for additive exploration noise.

forward(features, deter=False)[source]#
add_exploration(action, action_noise=0.3)[source]#
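The parameters min_std, init_std, and mean_scale match the DreamerV1 actor. There, the raw mean is soft-clamped with tanh, and the standard deviation passes through a shifted softplus so that it equals init_std at initialization. Below is a sketch of that parameterization and of clipped additive exploration noise. Whether ActionDecoder uses exactly these formulas is an assumption, and both function names are hypothetical.

```python
import math

import torch
import torch.nn.functional as F


def actor_mean_std(raw_mean, raw_std, min_std=1e-4, init_std=5.0, mean_scale=5.0):
    raw_init_std = math.log(math.exp(init_std) - 1.0)  # inverse softplus of init_std
    mean = mean_scale * torch.tanh(raw_mean / mean_scale)  # keep pre-tanh mean in (-5, 5)
    std = F.softplus(raw_std + raw_init_std) + min_std
    return mean, std


def add_exploration_sketch(action, action_noise=0.3):
    # Gaussian noise around the action, clipped back into the valid [-1, 1] range.
    return torch.clamp(action + action_noise * torch.randn_like(action), -1.0, 1.0)
```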
class world_models.vision.planet_encoder.CNNEncoder(embedding_size, activation_function='relu')[source]#

Bases: Module

A Convolutional Neural Network (CNN) encoder for processing image inputs.

forward(observation)[source]#
class world_models.vision.planet_decoder.CNNDecoder(state_size, latent_size, embedding_size, activation_function='relu')[source]#

Bases: Module

A Convolutional Neural Network (CNN) decoder for reconstructing image outputs.

forward(latent, state)[source]#
class world_models.vision.iris_encoder.IRISEncoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, in_channels=3, base_channels=64, num_residual_blocks=2, frame_shape=(3, 64, 64))[source]#

Bases: Module

CNN Encoder for IRIS discrete autoencoder.

Encodes image observations into latent features, which are then quantized into discrete tokens using the VectorQuantizer.

Architecture:
  • 4 convolutional layers with residual blocks

  • Self-attention at 8x8 and 16x16 resolutions

  • Vector quantization to produce discrete tokens

Parameters:
  • vocab_size (int)

  • tokens_per_frame (int)

  • embedding_dim (int)

  • in_channels (int)

  • base_channels (int)

  • num_residual_blocks (int)

  • frame_shape (Tuple[int, int, int])

forward(x)[source]#

Encode images to discrete tokens.

Parameters:

x (Tensor) – Input images (B, C, H, W) - should be 64x64

Returns:
  • z_q – Quantized tokens (B, C, H’, W’)

  • indices – Token indices (B, H’, W’)

  • vq_loss – Dictionary with VQ loss components

encode_to_indices(x)[source]#

Encode directly to token indices (for world model).

Parameters:

x (Tensor)

Return type:

Tensor

decode_from_indices(indices)[source]#

Decode token indices to embeddings (for decoder).

Parameters:

indices (Tensor)

Return type:

Tensor

class world_models.vision.iris_encoder.ResidualBlock(channels)[source]#

Bases: Module

Residual block for encoder.

Parameters:

channels (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.vision.iris_encoder.SelfAttentionBlock(channels)[source]#

Bases: Module

Self-attention block for encoder.

Applies spatial self-attention to capture long-range dependencies.

Parameters:

channels (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.vision.iris_decoder.IRISDecoder(vocab_size=512, embedding_dim=512, base_channels=32, out_channels=3, frame_shape=(3, 64, 64))[source]#

Bases: Module

CNN Decoder for IRIS discrete autoencoder.

Decodes discrete tokens back into image observations. Uses transposed convolutions to upsample from 4x4 to 64x64.

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • base_channels (int)

  • out_channels (int)

  • frame_shape (Tuple[int, int, int])

forward(z)[source]#

Decode tokens to images.

Parameters:

z (Tensor) – Token embeddings (B, C, H, W) - e.g., (B, 512, 4, 4)

Returns:

reconstructed – Reconstructed images (B, C, H, W), e.g., (B, 3, 64, 64)

decode_from_embeddings(z_flat)[source]#

Decode flattened token embeddings to images.

Parameters:

z_flat (Tensor) – Flattened tokens (B, H*W, C) or (B, C, H, W)

Returns:

Reconstructed images

Return type:

Tensor

decode_from_indices(indices)[source]#

Decode discrete token indices into images.

Parameters:

indices (Tensor) – Tensor of shape (B, H, W) or (B, H*W) containing integer token indices in the range [0, vocab_size).

Returns:

Reconstructed images (B, C, H, W)

Return type:

Tensor

class world_models.vision.iris_decoder.UpsampleBlock(in_channels, mid_channels, out_channels)[source]#

Bases: Module

Upsampling block with optional residual connection.

Parameters:
  • in_channels (int)

  • mid_channels (int)

  • out_channels (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.vision.iris_decoder.ResidualBlock(channels)[source]#

Bases: Module

Residual block for decoder.

Parameters:

channels (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.vision.iris_decoder.DiscreteAutoencoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, base_channels=64, frame_shape=(3, 64, 64))[source]#

Bases: Module

Complete Discrete Autoencoder combining encoder and decoder.

Used for training the VQVAE component of IRIS.

Parameters:
  • vocab_size (int)

  • tokens_per_frame (int)

  • embedding_dim (int)

  • base_channels (int)

  • frame_shape (Tuple[int, int, int])

forward(x)[source]#

Full encode-decode forward pass.

Parameters:

x (Tensor) – Input images (B, C, H, W)

Returns:
  • reconstruction – Reconstructed images

  • indices – Token indices (B, H’, W’)

  • loss_dict – Dictionary with loss components

encode(x)[source]#

Encode to token indices.

Parameters:

x (Tensor)

Return type:

Tensor

decode(indices)[source]#

Decode token indices to images.

Parameters:

indices (Tensor)

Return type:

Tensor

class world_models.vision.vq_layer.VectorQuantizer(vocab_size=512, embedding_dim=512, commitment_weight=0.25)[source]#

Bases: Module

Vector Quantizer for discrete autoencoder.

Implements the VQ-VAE quantization from “Neural Discrete Representation Learning” (van den Oord et al., 2017).

Uses a straight-through estimator for gradient flow. For exponential-moving-average codebook updates, see VectorQuantizerEMA.

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

forward(z)[source]#

Quantize the input latents.

Parameters:

z (Tensor) – Input tensor of shape (B, C, H, W) or (B, C)

Returns:
  • z_q – Quantized tensor (same shape as input)

  • indices – Token indices for each position (B, H, W) or (B,)

  • loss – Dictionary containing VQ loss components

decode_indices(indices)[source]#

Decode token indices back to embeddings.

Parameters:

indices (Tensor) – Token indices (B, H, W) or (B,)

Returns:

Embeddings (B, C, H, W) or (B, C)

Return type:

Tensor
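VQ-VAE quantization has three parts: a nearest-neighbour lookup in the codebook, two auxiliary losses, and a straight-through gradient. Here is a minimal sketch of the (B, C) case following van den Oord et al. (2017). `vector_quantize_sketch` is hypothetical, and its loss-dictionary keys are illustrative, not the keys VectorQuantizer returns.

```python
import torch
import torch.nn.functional as F


def vector_quantize_sketch(z, codebook, commitment_weight=0.25):
    """z: (B, C) latents; codebook: (vocab_size, C) embeddings."""
    distances = torch.cdist(z, codebook)  # (B, vocab_size) Euclidean distances
    indices = distances.argmin(dim=1)  # nearest code per latent
    z_q = codebook[indices]
    losses = {
        # Pull codebook entries toward the (frozen) encoder outputs.
        "codebook": F.mse_loss(z_q, z.detach()),
        # Keep encoder outputs close to their assigned (frozen) codes.
        "commitment": commitment_weight * F.mse_loss(z, z_q.detach()),
    }
    # Straight-through estimator: forward value is z_q, gradient is copied to z.
    z_q = z + (z_q - z).detach()
    return z_q, indices, losses
```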

class world_models.vision.vq_layer.VectorQuantizerEMA(vocab_size=512, embedding_dim=512, commitment_weight=0.25, ema_decay=0.99, epsilon=1e-05)[source]#

Bases: Module

Vector Quantizer with Exponential Moving Average updates.

Uses EMA updates for the codebook instead of gradient-based updates, which leads to more stable training.

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • ema_decay (float)

  • epsilon (float)

forward(z)[source]#

Quantize with EMA updates.

Parameters:

z (Tensor)

Return type:

Tuple[Tensor, Tensor, dict[str, Tensor]]

decode_indices(indices)[source]#

Decode token indices to embeddings.

Parameters:

indices (Tensor) – Token indices (B, H, W) or (B,)

Returns:

Embeddings (B, C, H, W) or (B, C)

Return type:

Tensor
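The EMA variant replaces codebook gradients with running averages of code usage counts and assigned-vector sums. Here is a sketch of the usual update from the VQ-VAE paper's appendix, with Laplace smoothing controlled by epsilon. The function name and argument layout are hypothetical.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def ema_codebook_update(z, indices, cluster_size, embed_sum, decay=0.99, epsilon=1e-5):
    """z: (N, C) encoder outputs; indices: (N,) assigned codes.

    cluster_size: (K,) and embed_sum: (K, C) running statistics, updated in place.
    Returns the new codebook of shape (K, C).
    """
    k = cluster_size.shape[0]
    one_hot = F.one_hot(indices, k).type_as(z)  # (N, K)
    cluster_size.mul_(decay).add_(one_hot.sum(0), alpha=1 - decay)
    embed_sum.mul_(decay).add_(one_hot.t() @ z, alpha=1 - decay)
    # Laplace smoothing keeps unused codes from dividing by zero.
    n = cluster_size.sum()
    smoothed = (cluster_size + epsilon) / (n + k * epsilon) * n
    return embed_sum / smoothed.unsqueeze(1)
```

With decay=0 each code simply becomes the mean of the vectors assigned to it, which is one k-means step.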

class world_models.vision.video_tokenizer.VideoTokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, commitment_weight=0.25, use_ema=False, ema_decay=0.99)[source]#

Bases: Module

Video Tokenizer using VQ-VAE with Spatiotemporal Transformer.

This is a core component of Genie (Google DeepMind, 2024), used to compress raw video frames into discrete latent tokens that can be processed by downstream models like the LatentActionModel and DynamicsModel.

The tokenizer uses the Vector Quantized Variational Autoencoder (VQ-VAE) objective to learn a discrete codebook of video representations. Unlike a standard VQ-VAE, it uses a Spatiotemporal (ST) Transformer in both the encoder and decoder to better capture temporal dynamics in videos.

Architecture:
  1. Patch Embedding: Convert (B, C, T, H, W) video to patch tokens

  2. Encoder ST-Transformer: Process spatial-temporal patches

  3. Vector Quantization: Discretize continuous embeddings to codebook entries

  4. Decoder ST-Transformer: Reconstruct video from quantized tokens

  5. Patch Unembedding: Convert tokens back to video frames

Key Features:
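  • Causal processing: Each frame’s encoding uses only the current and previous frames

  • Discrete tokens: Enables autoregressive prediction with latent actions

  • Memory efficient: For T frames of N tokens each, spatial attention runs within each frame and temporal attention runs across frames at each position, instead of full ViT attention over all T·N tokens, reducing the attention cost from O((T·N)²) to O(T·N² + N·T²)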
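To make the complexity claim concrete, take the defaults num_frames=16, image_size=64, and patch_size=4: each frame has (64/4)^2 = 256 tokens. The sketch counts query-key pairs per attention layer. This is a rough proxy for cost that ignores heads and constant factors.

```python
def attention_pairs(num_frames=16, image_size=64, patch_size=4):
    n = (image_size // patch_size) ** 2  # tokens per frame
    full = (num_frames * n) ** 2  # joint attention over all T*N tokens
    spatial = num_frames * n ** 2  # each frame attends within itself
    temporal = n * num_frames ** 2  # each position attends across time
    return n, full, spatial + temporal


n, full, st = attention_pairs()
print(n, full, st)  # 256 16777216 1114112
```

Under these assumptions the ST factorization needs about 15x fewer query-key pairs per layer.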

Usage with Genie:

tokenizer = VideoTokenizer(
    num_frames=16, image_size=64, patch_size=4, vocab_size=1024, embedding_dim=32
)
reconstructed, indices, loss_dict = tokenizer(video_frames)

# For discrete token input to the dynamics model:
token_embeddings = tokenizer.decode_indices(indices)

Training:

The tokenizer is trained with a VQ-VAE objective:
  • Reconstruction loss: MSE between the input and reconstructed video

  • VQ (codebook) loss: Moves codebook embeddings toward the encoder outputs (encourages learning useful codes)

  • Commitment loss: Penalizes encoder outputs drifting away from their assigned codebook entries

Reference:

Genie: Generative Interactive Environments. Bruce et al., Google DeepMind, 2024. https://arxiv.org/abs/2402.15391

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • use_ema (bool)

  • ema_decay (float)

encode(x)[source]#

Encode video to discrete tokens.

Parameters:

x (Tensor) – Video tensor (B, C, T, H, W)

Returns:
  • z_q – Quantized embeddings (B, T, H’, W’, embedding_dim)

  • indices – Token indices (B, T, H’, W’)

  • vq_loss – Dictionary with VQ loss components

decode_indices(indices)[source]#

Decode token indices to embeddings for video frames.

Parameters:

indices (Tensor) – Token indices (B, T, H’, W’) or (B, T, N) where N = H’*W’

Returns:

z_q: Quantized embeddings (B, T, H’, W’, embedding_dim)

Return type:

Tensor
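decode_indices accepts either a (B, T, H’, W’) grid or a flattened (B, T, N) layout with N = H’*W’. A minimal sketch of recovering the grid for one frame, assuming square frames and row-major flattening (both assumptions); unflatten_frame is a hypothetical helper.

```python
import math


def unflatten_frame(indices, grid_side=None):
    """Reshape one frame's flat token list (length N) into an H' x W' grid, row-major."""
    n = len(indices)
    side = grid_side if grid_side is not None else math.isqrt(n)
    if side * side != n:
        raise ValueError(f"{n} tokens do not form a square {side}x{side} grid")
    return [indices[row * side:(row + 1) * side] for row in range(side)]


print(unflatten_frame([0, 1, 2, 3, 4, 5, 6, 7, 8]))  # → [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```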

decode(z_q)[source]#

Decode quantized embeddings back to video frames.

Parameters:

z_q (Tensor) – Quantized embeddings (B, T, H’, W’, embedding_dim)

Returns:

Reconstructed video (B, C, T, H, W)

Return type:

Tensor

forward(x)[source]#

Full forward pass with VQ-VAE objective.

Parameters:

x (Tensor) – Video tensor (B, C, T, H, W)

Returns:

A tuple (reconstructed, indices, loss_dict):
  • reconstructed: Reconstructed video (B, C, T, H, W)

  • indices: Token indices (B, T, H’, W’)

  • loss_dict: Dictionary containing loss components

Return type:

tuple[Tensor, Tensor, dict]

world_models.vision.video_tokenizer.create_video_tokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False)[source]#

Factory function to create a Video Tokenizer.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • use_ema (bool)

Return type:

VideoTokenizer

Blocks sub-module - Transformer blocks and attention mechanisms.

Exported Components:
Transformers:
  • STTransformer: Spatiotemporal Transformer for video processing

  • STSpatialAttention: Spatial attention layer

  • STTemporalAttention: Temporal attention layer

  • STBlock: Combined spatiotemporal transformer block

Attention:
  • MultiHeadSelfAttention: Multi-head self-attention

  • MultiHeadAttention: Generic multi-head attention (alias)

  • Attention: Attention mechanism

Normalization:
  • RMSNorm: Root Mean Square Layer Normalization

  • AdaLNNormalization: Adaptive Layer Normalization

class world_models.blocks.mhsa.MultiHeadSelfAttention(d, n_heads=2)[source]#

Bases: Module

Multi-head scaled dot-product self-attention over sequence tokens.

This module projects the input sequence into query/key/value heads, performs attention independently per head, and merges the heads back into the original feature dimension. It is used as a lightweight transformer attention block.
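A minimal pure-Python sketch of the split, attend, and merge pattern described above, assuming identity query/key/value and output projections (the real module learns linear projections). The function names are illustrative, not part of world_models.blocks.mhsa.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v):
    """Scaled dot-product attention for one head; q, k, v are lists of vectors."""
    d_k = len(q[0])
    out = []
    for qi in q:
        weights = softmax([sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_k) for kj in k])
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


def multi_head_self_attention(x, n_heads):
    """Split each token's features into n_heads chunks, attend per head, concatenate back."""
    d = len(x[0])
    if d % n_heads != 0:
        raise ValueError("feature dim must be divisible by n_heads")
    head_dim = d // n_heads
    heads = []
    for h in range(n_heads):
        chunk = [token[h * head_dim:(h + 1) * head_dim] for token in x]
        heads.append(attention(chunk, chunk, chunk))
    # Merge heads: concatenate each token's per-head outputs in head order
    return [[value for h in range(n_heads) for value in heads[h][t]] for t in range(len(x))]


print(multi_head_self_attention([[1.0, 2.0, 3.0, 4.0]], n_heads=2))  # → [[1.0, 2.0, 3.0, 4.0]]
```

With a single token each head attends only to itself, so the output equals the input; with more tokens, every output feature is a convex combination of that feature across tokens.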

forward(x)[source]#
class world_models.blocks.st_transformer.STSpatialAttention(dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)[source]#

Bases: Module

Spatial attention layer for spatiotemporal transformer.

Processes video tokens by attending over spatial positions (H*W) within each time step independently. Captures within-frame spatial relationships.

Input: (B, T, N, C), with B batches, T time steps, N spatial positions (H*W), and C channels

Output: (B, T, N, C), the same shape, with spatially attended features

Architecture:

  1. QKV projection: Linear(dim, dim*3)

  2. Reshape to multi-head attention format

  3. Attention: softmax(Q @ K^T / sqrt(d_k)) @ V

  4. Output projection

Usage in ST-Transformer:

Applied to video tokens of shape (B, T, N, C) to capture within-frame spatial structure (e.g., object positions).
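Spatial and temporal attention differ only in which tokens share a sequence. In a tensor implementation this is typically a reshape of (B, T, N, C) to (B*T, N, C) for spatial attention and to (B*N, T, C) for temporal attention; that reshape is an assumption about this module's internals. The sketch below shows the grouping for a single clip (batch dimension omitted) using nested lists; the helper names are hypothetical.

```python
def spatial_sequences(video):
    """video[t][n] holds the token for frame t, patch n. One sequence per frame."""
    return [list(frame) for frame in video]


def temporal_sequences(video):
    """One sequence per patch position, running across frames."""
    num_frames, num_patches = len(video), len(video[0])
    return [[video[t][n] for t in range(num_frames)] for n in range(num_patches)]


# Label tokens by (frame, patch) so the grouping is visible: T=2 frames, N=3 patches
video = [[(t, n) for n in range(3)] for t in range(2)]
print(spatial_sequences(video))   # → [[(0, 0), (0, 1), (0, 2)], [(1, 0), (1, 1), (1, 2)]]
print(temporal_sequences(video))  # → [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(0, 2), (1, 2)]]
```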

Parameters:
  • dim (int)

  • num_heads (int)

  • qkv_bias (bool)

  • qk_scale (float | None)

  • attn_drop (float)

  • proj_drop (float)

forward(x)[source]#
Parameters:

x (Tensor) – (B, T, N, C) where T is temporal dim, N is spatial dim (H*W)

Returns:

(B, T, N, C)

Return type:

Tensor

class world_models.blocks.st_transformer.STTemporalAttention(dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)[source]#

Bases: Module

Temporal attention layer with causal masking for spatiotemporal transformer.

Processes video tokens by attending over time steps (T) at each spatial position. Uses causal masking so that each frame attends only to itself and earlier frames (important for autoregressive video generation).

Input: (B, T, N, C), with B batches, T time steps, N spatial positions, and C channels

Output: (B, T, N, C), the same shape, with temporally attended features

Key Feature: Causal masking
  • Frame t can only attend to frames 0…t

  • Prevents information leakage from future frames

  • Essential for autoregressive video generation models
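A sketch of the standard additive causal mask, assuming the usual convention that frame t also attends to itself (the diagonal is unmasked). Masked scores become -inf and therefore receive zero weight after the softmax. The function names are illustrative.

```python
import math


def causal_mask(num_frames):
    """Additive mask: 0.0 where frame t may attend to frame s (s <= t), -inf otherwise."""
    return [[0.0 if s <= t else -math.inf for s in range(num_frames)] for t in range(num_frames)]


def masked_softmax(scores, mask_row):
    """Softmax after adding the mask; exp(-inf) contributes 0."""
    masked = [s + m for s, m in zip(scores, mask_row)]
    top = max(masked)
    exps = [math.exp(s - top) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]


mask = causal_mask(3)
print(masked_softmax([1.0, 1.0, 1.0], mask[1]))  # → [0.5, 0.5, 0.0]
```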

Usage in Genie VideoTokenizer:

Applied after STSpatialAttention to model temporal dynamics. The causal mask ensures generation is autoregressive.

Parameters:
  • dim (int)

  • num_heads (int)

  • qkv_bias (bool)

  • qk_scale (float | None)

  • attn_drop (float)

  • proj_drop (float)

forward(x, causal=True)[source]#
Parameters:
  • x (Tensor) – (B, T, N, C) where T is temporal dim, N is spatial dim (H*W)

  • causal (bool) – whether to apply causal masking

Returns:

(B, T, N, C)

Return type:

Tensor

class world_models.blocks.st_transformer.STMLP(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>, drop=0.0)[source]#

Bases: Module

MLP for ST-Transformer block.

Parameters:
  • in_features (int)

  • hidden_features (int | None)

  • out_features (int | None)

  • act_layer (type[Module])

  • drop (float)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.blocks.st_transformer.STTransformerBlock(dim, num_heads=8, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]#

Bases: Module

Combined spatiotemporal transformer block with interleaved attention.

A single block applies:
  1. Spatial attention (within each time frame)

  2. Temporal attention (across frames with causal mask)

  3. MLP projection

The order is: x -> + SpatialAttn -> + TemporalAttn -> + MLP -> x

This interleaved design captures both spatial structure and temporal dynamics efficiently, and is used in Genie’s VideoTokenizer and DynamicsModel.

Parameters:
  • dim (int) – Feature dimension (must match patch embedding dimension)

  • num_heads (int) – Number of attention heads

  • mlp_ratio (float) – MLP hidden dim = dim * mlp_ratio

  • drop (float) – Dropout rate for the attention output projection and MLP

  • attn_drop (float) – Dropout rate for the attention weights

  • drop_path (float) – Stochastic depth rate for drop path regularization

  • norm_layer (type[Module]) – Normalization layer class (default: nn.LayerNorm)

  • qkv_bias (bool)

  • qk_scale (float | None)

  • act_layer (type[Module])

Usage in Genie:

# VideoTokenizer encoder (12 layers)
encoder = STTransformer(
    num_frames=16,
    num_patches_per_frame=256,  # 16x16 for 64x64 images with patch_size=4
    dim=512,
    depth=12,
    num_heads=16,
)
encoded = encoder(tokens)  # (B, T*N, C)

# Dynamics model decoder (24 layers)
decoder = STTransformer(
    num_frames=16, num_patches_per_frame=256, dim=1024, depth=24, num_heads=16
)
decoded = decoder(tokens)

forward(x)[source]#
Parameters:

x (Tensor) – (B, T, N, C) or (B, T*H*W, C)

Returns:

Same shape as input

Return type:

Tensor

class world_models.blocks.st_transformer.DropPath(drop_prob=0.0)[source]#

Bases: Module

Randomly drops the residual path per sample (stochastic depth).
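Stochastic depth zeroes an entire residual-branch output for randomly chosen samples during training and rescales the survivors so the expected output is unchanged. A pure-Python sketch of that rule, assuming the common formulation (per-sample keep decision, scaling by 1 / (1 - drop_prob), identity at evaluation time); drop_path here is a standalone illustration, not this module's forward.

```python
import random


def drop_path(batch, drop_prob=0.0, training=True, rng=random):
    """Stochastic depth on a batch of residual-branch outputs (one list per sample).

    Each sample's whole branch output is zeroed with probability drop_prob;
    survivors are scaled by 1 / (1 - drop_prob) so the expected value is unchanged.
    """
    if drop_prob == 0.0 or not training:
        return batch
    keep_prob = 1.0 - drop_prob
    out = []
    for sample in batch:
        if rng.random() < keep_prob:
            out.append([v / keep_prob for v in sample])
        else:
            out.append([0.0 for _ in sample])
    return out


print(drop_path([[1.0, 2.0], [3.0, 4.0]], drop_prob=0.5, rng=random.Random(0)))
```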

Parameters:

drop_prob (float)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.blocks.st_transformer.STTransformer(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]#

Bases: Module

Spatiotemporal Transformer for video modeling.

Stacks depth spatiotemporal blocks, each applying interleaved spatial and temporal attention.

Parameters:
  • num_frames (int)

  • num_patches_per_frame (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • qk_scale (float | None)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

  • norm_layer (type[Module])

forward(x)[source]#
Parameters:

x (Tensor) – (B, T*N, C) where T is num_frames, N is num_patches_per_frame

Returns:

(B, T*N, C)

Return type:

Tensor

world_models.blocks.st_transformer.create_st_transformer(num_frames=16, patch_size=4, img_size=64, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Factory function to create an ST-Transformer.

Parameters:
  • num_frames (int)

  • patch_size (int)

  • img_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

Return type:

STTransformer
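create_st_transformer takes image and patch sizes, while STTransformer takes num_patches_per_frame, so the factory presumably derives the patch count as (img_size // patch_size) ** 2. This matches the defaults (64 / 4 = 16 patches per side, 256 per frame) but is an assumption. The hypothetical helper below performs that translation and shows the sequence length that forward() then expects.

```python
def st_transformer_kwargs(num_frames=16, patch_size=4, img_size=64, **rest):
    """Hypothetical mapping from create_st_transformer arguments to STTransformer arguments."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    patches_per_side = img_size // patch_size
    return {"num_frames": num_frames, "num_patches_per_frame": patches_per_side ** 2, **rest}


kwargs = st_transformer_kwargs(dim=768, depth=12)
print(kwargs["num_patches_per_frame"])                         # → 256
print(kwargs["num_frames"] * kwargs["num_patches_per_frame"])  # → 4096 tokens per clip
```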

Configuration objects#

Lazy config exports.

Configuration modules can have optional training dependencies, so the package initializer avoids importing every config eagerly.
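Lazy exports like these are usually built on a module-level __getattr__ (PEP 562), which Python calls only when a name is not already defined in the module. A minimal sketch of that pattern; make_lazy_getattr is hypothetical, and the package's actual initializer may differ.

```python
import importlib
from typing import Any, Callable, Dict


def make_lazy_getattr(exports: Dict[str, str], namespace: Dict[str, Any]) -> Callable[[str], Any]:
    """Build a PEP 562 module __getattr__ that imports exports[name] on first access.

    `exports` maps a public name to the absolute module path that defines it; resolved
    objects are cached in `namespace` (normally the package's globals()) so later
    lookups find the name directly and skip this hook.
    """
    def __getattr__(name: str) -> Any:
        if name not in exports:
            raise AttributeError(f"module has no attribute {name!r}")
        value = getattr(importlib.import_module(exports[name]), name)
        namespace[name] = value
        return value

    return __getattr__


# Hypothetical use in a package __init__.py:
#     __getattr__ = make_lazy_getattr(
#         {"DreamerConfig": "world_models.configs.dreamer_config"}, globals()
#     )

# Demonstration with a standard-library module:
demo_namespace = {}
lazy_getattr = make_lazy_getattr({"OrderedDict": "collections"}, demo_namespace)
print(lazy_getattr("OrderedDict").__name__)  # → OrderedDict
```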

class world_models.configs.DreamerConfig[source]#

Bases: object

Configuration container for Dreamer training, evaluation, and environment setup.

This class centralizes environment backend selection (DMC/Gym/MuJoCo/Robotics/Unity/Brax), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.

class world_models.configs.JEPAConfig[source]#

Bases: object

Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.

to_dict()[source]#
Return type:

Dict[str, Dict[str, Any]]

class world_models.configs.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#

Bases: object

Default configuration values for Diffusion Transformer (DiT) training.

The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.
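DiTConfig stores only the endpoints and length of the noise schedule. A sketch assuming the linear, DDPM-style schedule these names usually denote (betas evenly spaced from BETA_START to BETA_END over TIMESTEPS steps); the function names are illustrative, not part of world_models.

```python
def linear_beta_schedule(beta_start=0.0001, beta_end=0.02, timesteps=1000):
    """Evenly spaced betas from beta_start to beta_end, inclusive."""
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def alpha_bars(betas):
    """Cumulative product of (1 - beta_t): the fraction of signal variance kept after t steps."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


betas = linear_beta_schedule()
abar = alpha_bars(betas)
print(len(betas), f"{abar[-1]:.2e}")
```

With the defaults, a 32 × 32 image split into 4 × 4 patches gives (32 / 4)² = 64 tokens, and WIDTH / HEADS = 384 / 6 = 64 features per attention head.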

Parameters:
  • DATASET (str)

  • BATCH (int)

  • EPOCHS (int)

  • LR (float)

  • IMG_SIZE (int)

  • CHANNELS (int)

  • PATCH (int)

  • WIDTH (int)

  • DEPTH (int)

  • HEADS (int)

  • DROP (float)

  • BETA_START (float)

  • BETA_END (float)

  • TIMESTEPS (int)

  • EMA (bool)

  • EMA_DECAY (float)

  • WORKDIR (str)

  • ROOT_PATH (str)

DATASET: str = 'CIFAR10'#
BATCH: int = 128#
EPOCHS: int = 3#
LR: float = 0.0002#
IMG_SIZE: int = 32#
CHANNELS: int = 3#
PATCH: int = 4#
WIDTH: int = 384#
DEPTH: int = 6#
HEADS: int = 6#
DROP: float = 0.1#
BETA_START: float = 0.0001#
BETA_END: float = 0.02#
TIMESTEPS: int = 1000#
EMA: bool = True#
EMA_DECAY: float = 0.999#
WORKDIR: str = './dit_demo'#
ROOT_PATH: str = './data'#
world_models.configs.get_dit_config(**overrides)[source]#

Returns a DiTConfig instance with default values overridden by the provided keyword arguments.

Example usage:

cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)
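A minimal sketch of the override pattern, assuming DiTConfig behaves like a standard dataclass (its signature and field list suggest so). MiniDiTConfig and get_config are hypothetical stand-ins; this page does not say how get_dit_config treats unknown keys, so this version rejects them explicitly.

```python
from dataclasses import dataclass, fields, replace


@dataclass
class MiniDiTConfig:
    """Hypothetical three-field stand-in for DiTConfig, for illustration only."""
    BATCH: int = 128
    EPOCHS: int = 3
    LR: float = 0.0002


def get_config(**overrides) -> MiniDiTConfig:
    """Return defaults with overrides applied, rejecting keys that are not fields."""
    known = {f.name for f in fields(MiniDiTConfig)}
    unknown = set(overrides) - known
    if unknown:
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    return replace(MiniDiTConfig(), **overrides)


cfg = get_config(BATCH=64, EPOCHS=10, LR=1e-3)
print(cfg)  # → MiniDiTConfig(BATCH=64, EPOCHS=10, LR=0.001)
```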

class world_models.configs.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', seed: int = 0, obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, device: str = <factory>, log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, operator_state_dim: int = 32, operator_action_dim: int = 4)[source]#

Bases: object

Parameters:
  • preset (str | None)

  • game (str)

  • seed (int)

  • obs_size (int)

  • frameskip (int)

  • max_noop (int)

  • terminate_on_life_loss (bool)

  • reward_clip (List[int])

  • num_conditioning_frames (int)

  • diffusion_channels (List[int])

  • diffusion_res_blocks (int)

  • diffusion_cond_dim (int)

  • sigma_data (float)

  • sigma_min (float)

  • sigma_max (float)

  • rho (int)

  • p_mean (float)

  • p_std (float)

  • sampling_method (str)

  • num_sampling_steps (int)

  • reward_channels (List[int])

  • reward_res_blocks (int)

  • reward_cond_dim (int)

  • reward_lstm_dim (int)

  • burn_in_length (int)

  • actor_channels (List[int])

  • actor_res_blocks (int)

  • actor_lstm_dim (int)

  • num_epochs (int)

  • training_steps_per_epoch (int)

  • batch_size (int)

  • environment_steps_per_epoch (int)

  • epsilon_greedy (float)

  • imagination_horizon (int)

  • discount_factor (float)

  • entropy_weight (float)

  • lambda_returns (float)

  • learning_rate (float)

  • adam_epsilon (float)

  • weight_decay_diffusion (float)

  • weight_decay_reward (float)

  • weight_decay_actor (float)

  • device (str)

  • log_interval (int)

  • eval_interval (int)

  • save_interval (int)

  • operator_state_dim (int)

  • operator_action_dim (int)

preset: str | None = None#
game: str = 'Breakout-v5'#
seed: int = 0#
obs_size: int = 64#
frameskip: int = 4#
max_noop: int = 30#
terminate_on_life_loss: bool = True#
reward_clip: List[int]#
num_conditioning_frames: int = 4#
diffusion_channels: List[int]#
diffusion_res_blocks: int = 2#
diffusion_cond_dim: int = 256#
sigma_data: float = 0.5#
sigma_min: float = 0.002#
sigma_max: float = 80.0#
rho: int = 7#
p_mean: float = -0.4#
p_std: float = 1.2#
sampling_method: str = 'euler'#
num_sampling_steps: int = 3#
reward_channels: List[int]#
reward_res_blocks: int = 2#
reward_cond_dim: int = 128#
reward_lstm_dim: int = 512#
burn_in_length: int = 4#
actor_channels: List[int]#
actor_res_blocks: int = 1#
actor_lstm_dim: int = 512#
num_epochs: int = 1000#
training_steps_per_epoch: int = 400#
batch_size: int = 32#
environment_steps_per_epoch: int = 100#
epsilon_greedy: float = 0.01#
imagination_horizon: int = 15#
discount_factor: float = 0.985#
entropy_weight: float = 0.001#
lambda_returns: float = 0.95#
learning_rate: float = 0.0001#
adam_epsilon: float = 1e-08#
weight_decay_diffusion: float = 0.01#
weight_decay_reward: float = 0.01#
weight_decay_actor: float = 0.0#
device: str#
log_interval: int = 10#
eval_interval: int = 50#
save_interval: int = 100#
operator_state_dim: int = 32#
operator_action_dim: int = 4#
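The diffusion fields (sigma_data, sigma_min, sigma_max, rho, p_mean, p_std) match the EDM parameterization of Karras et al. (2022). A sketch assuming DiamondConfig uses them that way, which this page does not confirm: the sampler visits num_sampling_steps noise levels spaced by rho, and training noise is drawn log-normally. Function names are illustrative.

```python
import math
import random


def karras_sigmas(num_steps=3, sigma_min=0.002, sigma_max=80.0, rho=7):
    """EDM sampling noise levels from sigma_max down to sigma_min, followed by a final 0."""
    if num_steps == 1:
        return [sigma_max, 0.0]
    inv_rho_max = sigma_max ** (1.0 / rho)
    inv_rho_min = sigma_min ** (1.0 / rho)
    sigmas = [
        (inv_rho_max + i / (num_steps - 1) * (inv_rho_min - inv_rho_max)) ** rho
        for i in range(num_steps)
    ]
    return sigmas + [0.0]


def sample_training_sigma(p_mean=-0.4, p_std=1.2, rng=random):
    """Training noise level with log(sigma) ~ Normal(p_mean, p_std)."""
    return math.exp(rng.gauss(p_mean, p_std))


print([round(s, 3) for s in karras_sigmas()])
print(sample_training_sigma(rng=random.Random(0)))
```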
class world_models.configs.IRISConfig[source]#

Bases: object

Configuration for IRIS (Imagination with auto-Regression over an Inner Speech).

Based on the paper “Transformers are Sample-Efficient World Models”. Implements a discrete autoencoder and an autoregressive Transformer for sample-efficient RL.

get_frame_shape()[source]#
get_autoencoder_config()[source]#
get_transformer_config()[source]#
get_rl_config()[source]#
class world_models.configs.GenieConfig(num_frames=8, image_size=32, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=256, action_encoder_depth=4, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#

Bases: object

Configuration for Genie model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • tokenizer_num_heads (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_num_heads (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

  • sample_temperature (float)

  • maskgit_steps (int)

num_frames: int = 8#
image_size: int = 32#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 256#
tokenizer_decoder_dim: int = 512#
tokenizer_encoder_depth: int = 4#
tokenizer_decoder_depth: int = 8#
tokenizer_num_heads: int = 8#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 256#
action_encoder_depth: int = 4#
action_num_heads: int = 8#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 4#
learning_rate: float = 3e-05#
weight_decay: float = 0.0001#
warmup_steps: int = 5000#
max_steps: int = 125000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
sample_temperature: float = 2.0#
maskgit_steps: int = 25#
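The Genie paper describes training the dynamics model MaskGIT-style with a masking rate sampled uniformly between 0.5 and 1, which appears to match the mask_prob_min and mask_prob_max defaults. A sketch of that sampling step, assuming the fields are used this way; sample_token_mask is a hypothetical name.

```python
import random


def sample_token_mask(num_tokens, mask_prob_min=0.5, mask_prob_max=1.0, rng=random):
    """Draw a masking rate in [mask_prob_min, mask_prob_max] and mask that fraction of tokens.

    Returns a list of booleans; True marks a token the model must predict.
    """
    rate = rng.uniform(mask_prob_min, mask_prob_max)
    num_masked = round(rate * num_tokens)
    masked = set(rng.sample(range(num_tokens), num_masked))
    return [i in masked for i in range(num_tokens)]


mask = sample_token_mask(64, rng=random.Random(0))
print(sum(mask), "of", len(mask), "tokens masked")
```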
class world_models.configs.GenieSmallConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=512, action_encoder_depth=8, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=2, learning_rate=0.0001, weight_decay=0.0001, warmup_steps=1000, max_steps=50000, mask_prob_min=0.5, mask_prob_max=1.0)[source]#

Bases: object

Small configuration for development/testing.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • tokenizer_num_heads (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_num_heads (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 256#
tokenizer_decoder_dim: int = 512#
tokenizer_encoder_depth: int = 4#
tokenizer_decoder_depth: int = 8#
tokenizer_num_heads: int = 8#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 512#
action_encoder_depth: int = 8#
action_num_heads: int = 8#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 2#
learning_rate: float = 0.0001#
weight_decay: float = 0.0001#
warmup_steps: int = 1000#
max_steps: int = 50000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
class world_models.configs.STTransformerConfig(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: object

Configuration for Spatiotemporal Transformer.

Parameters:
  • num_frames (int)

  • num_patches_per_frame (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

num_frames: int = 16#
num_patches_per_frame: int = 256#
dim: int = 768#
depth: int = 12#
num_heads: int = 12#
mlp_ratio: float = 4.0#
qkv_bias: bool = True#
drop_rate: float = 0.0#
attn_drop_rate: float = 0.0#
drop_path_rate: float = 0.0#
class world_models.configs.VideoTokenizerConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False, ema_decay=0.99, commitment_weight=0.25)[source]#

Bases: object

Configuration for Video Tokenizer.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • use_ema (bool)

  • ema_decay (float)

  • commitment_weight (float)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
encoder_dim: int = 512#
decoder_dim: int = 1024#
encoder_depth: int = 12#
decoder_depth: int = 20#
num_heads: int = 16#
patch_size: int = 4#
vocab_size: int = 1024#
embedding_dim: int = 32#
use_ema: bool = False#
ema_decay: float = 0.99#
commitment_weight: float = 0.25#
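The use_ema and ema_decay fields suggest the codebook can be updated by an exponential moving average instead of by the codebook loss. A sketch of the EMA update commonly used with VQ-VAEs (per-code running counts and sums, with Laplace smoothing), assuming, not confirming, that VideoTokenizer's use_ema path works this way.

```python
def ema_codebook_update(codebook, counts, sums, assignments, vectors, decay=0.99, eps=1e-5):
    """One exponential-moving-average codebook step; returns (codebook, counts, sums).

    counts[i] and sums[i] are running averages of how many encoder vectors were assigned
    to code i and of their sum; assignments[j] is the code index chosen for vectors[j].
    """
    num_codes, dim = len(codebook), len(codebook[0])
    batch_counts = [0.0] * num_codes
    batch_sums = [[0.0] * dim for _ in range(num_codes)]
    for code, vec in zip(assignments, vectors):
        batch_counts[code] += 1.0
        for c in range(dim):
            batch_sums[code][c] += vec[c]

    new_counts = [decay * n + (1 - decay) * b for n, b in zip(counts, batch_counts)]
    new_sums = [
        [decay * s + (1 - decay) * b for s, b in zip(s_row, b_row)]
        for s_row, b_row in zip(sums, batch_sums)
    ]
    # Laplace smoothing so that rarely used codes never divide by ~0
    total = sum(new_counts)
    smoothed = [(n + eps) / (total + num_codes * eps) * total for n in new_counts]
    new_codebook = [[s / n for s in s_row] for s_row, n in zip(new_sums, smoothed)]
    return new_codebook, new_counts, new_sums


new_codebook, _, _ = ema_codebook_update(
    [[0.0, 0.0], [5.0, 5.0]],
    counts=[1.0, 1.0],
    sums=[[0.0, 0.0], [5.0, 5.0]],
    assignments=[0, 0],
    vectors=[[1.0, 1.0], [3.0, 3.0]],
)
print([[round(v, 3) for v in code] for code in new_codebook])  # → [[0.04, 0.04], [5.0, 5.0]]
```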
class world_models.configs.LatentActionModelConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=1024, encoder_depth=20, num_heads=16, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#

Bases: object

Configuration for Latent Action Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • encoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
encoder_dim: int = 1024#
encoder_depth: int = 20#
num_heads: int = 16#
patch_size: int = 16#
vocab_size: int = 8#
embedding_dim: int = 32#
commitment_weight: float = 1.0#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
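With the defaults, each 64 × 64 frame yields (64 / 16)² = 16 patches, and vocab_size=8 means each transition is summarized as one of eight discrete latent actions. A sketch of what action_pooling='mean' plausibly means, assuming patch embeddings are averaged into one vector that is then matched to the nearest action code; both helpers are hypothetical, and the 'windowed_attention' option is not sketched.

```python
def mean_pool(patch_embeddings):
    """Average a frame's patch embeddings (N vectors of size D) into one D-vector."""
    n = len(patch_embeddings)
    return [sum(column) / n for column in zip(*patch_embeddings)]


def nearest_action(vector, action_codebook):
    """Index of the closest latent-action code by squared Euclidean distance."""
    def distance(code):
        return sum((a - b) ** 2 for a, b in zip(vector, code))
    return min(range(len(action_codebook)), key=lambda i: distance(action_codebook[i]))


pooled = mean_pool([[1.0, 2.0], [3.0, 4.0]])
print(pooled, nearest_action(pooled, [[0.0, 0.0], [2.0, 2.9], [5.0, 5.0]]))  # → [2.0, 3.0] 1
```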
class world_models.configs.DynamicsModelConfig(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: object

Configuration for Dynamics Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

num_frames: int = 16#
image_size: int = 64#
vocab_size: int = 1024#
embedding_dim: int = 32#
action_vocab_size: int = 8#
dim: int = 5120#
depth: int = 48#
num_heads: int = 36#
patch_size: int = 4#
mlp_ratio: float = 4.0#
qkv_bias: bool = True#
drop_rate: float = 0.0#
attn_drop_rate: float = 0.0#
drop_path_rate: float = 0.0#
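If attention heads are sized as dim // num_heads (the usual convention, assumed here), the DynamicsModelConfig defaults deserve a check: 5120 / 36 ≈ 142.2 is not an integer, whereas the other configs on this page divide evenly. This page does not say how the dynamics model sizes its heads, so the sketch below is a quick pre-flight check rather than a statement about the implementation; head_dim is a hypothetical helper.

```python
def head_dim(dim: int, num_heads: int) -> int:
    """Per-head width when a model splits `dim` evenly across `num_heads` heads."""
    if dim % num_heads != 0:
        raise ValueError(f"dim={dim} is not divisible by num_heads={num_heads}")
    return dim // num_heads


for name, dim, heads in [
    ("STTransformerConfig", 768, 12),
    ("VideoTokenizerConfig encoder", 512, 16),
    ("LatentActionModelConfig", 1024, 16),
    ("DynamicsModelConfig", 5120, 36),
]:
    try:
        print(name, head_dim(dim, heads))
    except ValueError as err:
        print(name, "->", err)
```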
class world_models.configs.dreamer_config.DreamerConfig[source]#

Bases: object

Configuration container for Dreamer training, evaluation, and environment setup.

This class centralizes environment backend selection (DMC/Gym/MuJoCo/Robotics/Unity/Brax), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.

class world_models.configs.jepa_config.JEPAConfig[source]#

Bases: object

Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.

to_dict()[source]#
Return type:

Dict[str, Dict[str, Any]]

class world_models.configs.iris_config.IRISConfig[source]#

Bases: object

Configuration for IRIS (Imagination with auto-Regression over an Inner Speech).

Based on the paper “Transformers are Sample-Efficient World Models”. Implements a discrete autoencoder and an autoregressive Transformer for sample-efficient RL.

get_frame_shape()[source]#
get_autoencoder_config()[source]#
get_transformer_config()[source]#
get_rl_config()[source]#
class world_models.configs.genie_config.GenieConfig(num_frames=8, image_size=32, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=256, action_encoder_depth=4, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#

Bases: object

Configuration for Genie model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • tokenizer_num_heads (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_num_heads (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

  • sample_temperature (float)

  • maskgit_steps (int)

num_frames: int = 8#
image_size: int = 32#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 256#
tokenizer_decoder_dim: int = 512#
tokenizer_encoder_depth: int = 4#
tokenizer_decoder_depth: int = 8#
tokenizer_num_heads: int = 8#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 256#
action_encoder_depth: int = 4#
action_num_heads: int = 8#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 4#
learning_rate: float = 3e-05#
weight_decay: float = 0.0001#
warmup_steps: int = 5000#
max_steps: int = 125000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
sample_temperature: float = 2.0#
maskgit_steps: int = 25#
class world_models.configs.genie_config.GenieSmallConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=512, action_encoder_depth=8, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=2, learning_rate=0.0001, weight_decay=0.0001, warmup_steps=1000, max_steps=50000, mask_prob_min=0.5, mask_prob_max=1.0)[source]#

Bases: object

Small configuration for development/testing.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • tokenizer_num_heads (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_num_heads (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 256#
tokenizer_decoder_dim: int = 512#
tokenizer_encoder_depth: int = 4#
tokenizer_decoder_depth: int = 8#
tokenizer_num_heads: int = 8#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 512#
action_encoder_depth: int = 8#
action_num_heads: int = 8#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 2#
learning_rate: float = 0.0001#
weight_decay: float = 0.0001#
warmup_steps: int = 1000#
max_steps: int = 50000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
class world_models.configs.genie_config.STTransformerConfig(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: object

Configuration for Spatiotemporal Transformer.

Parameters:
  • num_frames (int)

  • num_patches_per_frame (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

num_frames: int = 16#
num_patches_per_frame: int = 256#
dim: int = 768#
depth: int = 12#
num_heads: int = 12#
mlp_ratio: float = 4.0#
qkv_bias: bool = True#
drop_rate: float = 0.0#
attn_drop_rate: float = 0.0#
drop_path_rate: float = 0.0#
class world_models.configs.genie_config.VideoTokenizerConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False, ema_decay=0.99, commitment_weight=0.25)[source]#

Bases: object

Configuration for Video Tokenizer.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • use_ema (bool)

  • ema_decay (float)

  • commitment_weight (float)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
encoder_dim: int = 512#
decoder_dim: int = 1024#
encoder_depth: int = 12#
decoder_depth: int = 20#
num_heads: int = 16#
patch_size: int = 4#
vocab_size: int = 1024#
embedding_dim: int = 32#
use_ema: bool = False#
ema_decay: float = 0.99#
commitment_weight: float = 0.25#
class world_models.configs.genie_config.LatentActionModelConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=1024, encoder_depth=20, num_heads=16, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#

Bases: object

Configuration for Latent Action Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • encoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
encoder_dim: int = 1024#
encoder_depth: int = 20#
num_heads: int = 16#
patch_size: int = 16#
vocab_size: int = 8#
embedding_dim: int = 32#
commitment_weight: float = 1.0#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
class world_models.configs.genie_config.DynamicsModelConfig(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: object

Configuration for Dynamics Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

num_frames: int = 16#
image_size: int = 64#
vocab_size: int = 1024#
embedding_dim: int = 32#
action_vocab_size: int = 8#
dim: int = 5120#
depth: int = 48#
num_heads: int = 36#
patch_size: int = 4#
mlp_ratio: float = 4.0#
qkv_bias: bool = True#
drop_rate: float = 0.0#
attn_drop_rate: float = 0.0#
drop_path_rate: float = 0.0#
class world_models.configs.dit_config.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#

Bases: object

Default configuration values for Diffusion Transformer (DiT) training.

The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.

Parameters:
  • DATASET (str)

  • BATCH (int)

  • EPOCHS (int)

  • LR (float)

  • IMG_SIZE (int)

  • CHANNELS (int)

  • PATCH (int)

  • WIDTH (int)

  • DEPTH (int)

  • HEADS (int)

  • DROP (float)

  • BETA_START (float)

  • BETA_END (float)

  • TIMESTEPS (int)

  • EMA (bool)

  • EMA_DECAY (float)

  • WORKDIR (str)

  • ROOT_PATH (str)

DATASET: str = 'CIFAR10'#
BATCH: int = 128#
EPOCHS: int = 3#
LR: float = 0.0002#
IMG_SIZE: int = 32#
CHANNELS: int = 3#
PATCH: int = 4#
WIDTH: int = 384#
DEPTH: int = 6#
HEADS: int = 6#
DROP: float = 0.1#
BETA_START: float = 0.0001#
BETA_END: float = 0.02#
TIMESTEPS: int = 1000#
EMA: bool = True#
EMA_DECAY: float = 0.999#
WORKDIR: str = './dit_demo'#
ROOT_PATH: str = './data'#
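BETA_START, BETA_END, and TIMESTEPS describe the diffusion noise schedule. The sketch below assumes a DDPM-style linear beta schedule; the defaults 1e-4, 0.02, and 1000 match the standard DDPM values. linear_beta_schedule and noise_scalar are hypothetical helpers, not part of world_models.

```python
import math
from typing import List, Tuple


def linear_beta_schedule(
    beta_start: float = 1e-4, beta_end: float = 0.02, timesteps: int = 1000
) -> Tuple[List[float], List[float]]:
    """Return (betas, alpha_bars) for a linear DDPM-style schedule (hypothetical helper)."""
    step = (beta_end - beta_start) / (timesteps - 1)
    betas = [beta_start + i * step for i in range(timesteps)]
    alpha_bars: List[float] = []
    prod = 1.0
    for beta in betas:
        prod *= 1.0 - beta  # cumulative product of (1 - beta_t)
        alpha_bars.append(prod)
    return betas, alpha_bars


def noise_scalar(x0: float, eps: float, t: int, alpha_bars: List[float]) -> float:
    """Forward process q(x_t | x_0) for one scalar: sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps."""
    a_bar = alpha_bars[t]
    return math.sqrt(a_bar) * x0 + math.sqrt(1.0 - a_bar) * eps


betas, alpha_bars = linear_beta_schedule()
print(len(betas), alpha_bars[-1])
```

Because alpha_bars shrinks monotonically, large t values produce samples that are almost pure noise. This is why the schedule ends near a beta of 0.02.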
world_models.configs.dit_config.get_dit_config(**overrides)[source]#

Returns a DiTConfig instance with default values overridden by the provided keyword arguments.

Example usage:

cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)
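Override-style factories such as get_dit_config are often built on dataclasses.replace. The minimal sketch below assumes DiTConfig behaves like a dataclass. TinyDiTConfig and get_config are hypothetical stand-ins, and the real function may validate its arguments differently.

```python
from dataclasses import dataclass, replace


@dataclass
class TinyDiTConfig:
    """Hypothetical stand-in for DiTConfig with a few of its fields."""
    BATCH: int = 128
    EPOCHS: int = 3
    LR: float = 0.0002


def get_config(**overrides) -> TinyDiTConfig:
    # replace() copies the defaults; unknown field names raise TypeError.
    return replace(TinyDiTConfig(), **overrides)


cfg = get_config(BATCH=64, EPOCHS=10, LR=1e-3)
print(cfg)
```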

world_models.configs.diamond_config.get_default_device()[source]#
Return type:

str

class world_models.configs.diamond_config.ModelPreset(diffusion_channels, diffusion_res_blocks, diffusion_cond_dim, reward_channels, reward_lstm_dim, actor_channels, actor_lstm_dim)[source]#

Bases: object

Model architecture preset for different hardware tiers.

Parameters:
  • diffusion_channels (List[int])

  • diffusion_res_blocks (int)

  • diffusion_cond_dim (int)

  • reward_channels (List[int])

  • reward_lstm_dim (int)

  • actor_channels (List[int])

  • actor_lstm_dim (int)

diffusion_channels: List[int]#
diffusion_res_blocks: int#
diffusion_cond_dim: int#
reward_channels: List[int]#
reward_lstm_dim: int#
actor_channels: List[int]#
actor_lstm_dim: int#
class world_models.configs.diamond_config.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', seed: int = 0, obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, device: str = <factory>, log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, operator_state_dim: int = 32, operator_action_dim: int = 4)[source]#

Bases: object

Parameters:
  • preset (str | None)

  • game (str)

  • seed (int)

  • obs_size (int)

  • frameskip (int)

  • max_noop (int)

  • terminate_on_life_loss (bool)

  • reward_clip (List[int])

  • num_conditioning_frames (int)

  • diffusion_channels (List[int])

  • diffusion_res_blocks (int)

  • diffusion_cond_dim (int)

  • sigma_data (float)

  • sigma_min (float)

  • sigma_max (float)

  • rho (int)

  • p_mean (float)

  • p_std (float)

  • sampling_method (str)

  • num_sampling_steps (int)

  • reward_channels (List[int])

  • reward_res_blocks (int)

  • reward_cond_dim (int)

  • reward_lstm_dim (int)

  • burn_in_length (int)

  • actor_channels (List[int])

  • actor_res_blocks (int)

  • actor_lstm_dim (int)

  • num_epochs (int)

  • training_steps_per_epoch (int)

  • batch_size (int)

  • environment_steps_per_epoch (int)

  • epsilon_greedy (float)

  • imagination_horizon (int)

  • discount_factor (float)

  • entropy_weight (float)

  • lambda_returns (float)

  • learning_rate (float)

  • adam_epsilon (float)

  • weight_decay_diffusion (float)

  • weight_decay_reward (float)

  • weight_decay_actor (float)

  • device (str)

  • log_interval (int)

  • eval_interval (int)

  • save_interval (int)

  • operator_state_dim (int)

  • operator_action_dim (int)

preset: str | None = None#
game: str = 'Breakout-v5'#
seed: int = 0#
obs_size: int = 64#
frameskip: int = 4#
max_noop: int = 30#
terminate_on_life_loss: bool = True#
reward_clip: List[int]#
num_conditioning_frames: int = 4#
diffusion_channels: List[int]#
diffusion_res_blocks: int = 2#
diffusion_cond_dim: int = 256#
sigma_data: float = 0.5#
sigma_min: float = 0.002#
sigma_max: float = 80.0#
rho: int = 7#
p_mean: float = -0.4#
p_std: float = 1.2#
sampling_method: str = 'euler'#
num_sampling_steps: int = 3#
reward_channels: List[int]#
reward_res_blocks: int = 2#
reward_cond_dim: int = 128#
reward_lstm_dim: int = 512#
burn_in_length: int = 4#
actor_channels: List[int]#
actor_res_blocks: int = 1#
actor_lstm_dim: int = 512#
num_epochs: int = 1000#
training_steps_per_epoch: int = 400#
batch_size: int = 32#
environment_steps_per_epoch: int = 100#
epsilon_greedy: float = 0.01#
imagination_horizon: int = 15#
discount_factor: float = 0.985#
entropy_weight: float = 0.001#
lambda_returns: float = 0.95#
learning_rate: float = 0.0001#
adam_epsilon: float = 1e-08#
weight_decay_diffusion: float = 0.01#
weight_decay_reward: float = 0.01#
weight_decay_actor: float = 0.0#
device: str#
log_interval: int = 10#
eval_interval: int = 50#
save_interval: int = 100#
operator_state_dim: int = 32#
operator_action_dim: int = 4#
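The sigma_data, sigma_min, sigma_max, rho, p_mean, and p_std fields have the same names as the parameters of the EDM diffusion formulation (Karras et al., 2022), on which DIAMOND builds. The sketch below uses the standard EDM formulas and assumes DiamondAgent follows them; the helper names are hypothetical.

```python
import math
import random
from typing import List, Optional, Tuple


def karras_sigmas(n: int, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: float = 7.0) -> List[float]:
    """EDM sampling noise levels, from sigma_max down to sigma_min (requires n >= 2)."""
    hi = sigma_max ** (1.0 / rho)
    lo = sigma_min ** (1.0 / rho)
    return [(hi + i / (n - 1) * (lo - hi)) ** rho for i in range(n)]


def sample_training_sigma(p_mean: float = -0.4, p_std: float = 1.2, rng: Optional[random.Random] = None) -> float:
    """Training noise level with ln(sigma) ~ Normal(p_mean, p_std**2)."""
    rng = rng or random.Random()
    return math.exp(rng.gauss(p_mean, p_std))


def edm_preconditioning(sigma: float, sigma_data: float = 0.5) -> Tuple[float, float, float, float]:
    """Return (c_skip, c_out, c_in, c_noise) as defined by EDM preconditioning."""
    denom = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / denom
    c_out = sigma * sigma_data / math.sqrt(denom)
    c_in = 1.0 / math.sqrt(denom)
    c_noise = 0.25 * math.log(sigma)
    return c_skip, c_out, c_in, c_noise


# Noise levels for num_sampling_steps=3 with the DiamondConfig defaults
print(karras_sigmas(3))
```

In the EDM sampler, a final noise level of 0 is appended after these values. The denoiser output is then c_skip * x + c_out * F(c_in * x, c_noise), where F is the network.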

Training entry points#

world_models.training.train_jepa.main(args, resume_preempt=False)[source]#

Run JEPA training using a nested config dict or JEPAConfig instance.

This entrypoint initializes distributed context, data pipeline, masking, models, optimizers/schedulers, checkpointing, and the full epoch loop.

world_models.training.train_jepa.sweep_train()[source]#

Entry point called by a Weights & Biases (W&B) sweep agent.

class world_models.training.train_iris.IRISTrainer(game='ALE/Pong-v5', device='cuda', seed=42, config=None)[source]#

Bases: object

Training loop for IRIS on the Atari 100k benchmark.

Parameters:
  • game (str)

  • device (str)

  • seed (int)

  • config (IRISConfig | None)

preprocess_frame(frame)[source]#

Preprocess frame: resize to 64x64 and normalize.

Parameters:

frame (ndarray)

Return type:

ndarray

collect_experience(num_steps, epsilon=0.01)[source]#

Collect experience from environment.

Parameters:
  • num_steps (int) – Number of steps to collect

  • epsilon (float) – Random action probability

Returns:

Mean episode return

Return type:

float

train_epoch(epoch)[source]#

Train for one epoch.

Parameters:

epoch (int) – Current epoch number

Returns:

Dictionary of metrics

Return type:

dict

get_epsilon(epoch)[source]#

Get exploration epsilon with decay.

Parameters:

epoch (int)

Return type:

float

evaluate(num_episodes=100, render=False)[source]#

Evaluate agent performance.

Parameters:
  • num_episodes (int) – Number of evaluation episodes

  • render (bool) – If True, also return video frames and per-step latent vectors

Returns:

If render is False (default), a dict of evaluation metrics. If render is True, a tuple (episode_returns_array, videos_list, latents_array).

Return type:

dict | tuple

train(total_epochs=None, eval_interval=50, save_dir='checkpoints/iris')[source]#

Full training loop.

Parameters:
  • total_epochs (int | None) – Total training epochs

  • eval_interval (int) – Evaluate every N epochs

  • save_dir (str) – Directory to save checkpoints

world_models.training.train_iris.main()[source]#

Run IRIS training on a single Atari game.

class world_models.training.train_genie.GenieConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=512, tokenizer_decoder_dim=1024, tokenizer_encoder_depth=12, tokenizer_decoder_depth=20, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=1024, action_encoder_depth=20, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#

Bases: object

Configuration for Genie training.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

  • sample_temperature (float)

  • maskgit_steps (int)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 512#
tokenizer_decoder_dim: int = 1024#
tokenizer_encoder_depth: int = 12#
tokenizer_decoder_depth: int = 20#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 1024#
action_encoder_depth: int = 20#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 4#
learning_rate: float = 3e-05#
weight_decay: float = 0.0001#
warmup_steps: int = 5000#
max_steps: int = 125000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
sample_temperature: float = 2.0#
maskgit_steps: int = 25#
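mask_prob_min and mask_prob_max bound the token masking rate used to train the dynamics model. The sketch below assumes the scheme described in the Genie paper: draw one masking rate uniformly from [mask_prob_min, mask_prob_max] per sample, then mask each token independently with that probability. sample_token_mask is a hypothetical helper. The token count assumes the tokenizer's default patch_size=4 from VideoTokenizerConfig.

```python
import random
from typing import List, Optional


def sample_token_mask(
    num_tokens: int,
    mask_prob_min: float = 0.5,
    mask_prob_max: float = 1.0,
    rng: Optional[random.Random] = None,
) -> List[bool]:
    """Return one Bernoulli mask (True = masked) for a single training sample."""
    rng = rng or random.Random()
    rate = rng.uniform(mask_prob_min, mask_prob_max)  # per-sample masking rate
    return [rng.random() < rate for _ in range(num_tokens)]


# A 64x64 frame with 4x4 patches gives 16 * 16 = 256 tokens
mask = sample_token_mask(256, rng=random.Random(0))
print(sum(mask), "of", len(mask), "tokens masked")
```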
class world_models.training.train_genie.VideoDataset(video_paths, num_frames=16, image_size=64)[source]#

Bases: Dataset

Dataset for video data.

Parameters:
  • video_paths (list)

  • num_frames (int)

  • image_size (int)

class world_models.training.train_genie.GenieTrainer(model, config, device=None)[source]#

Bases: object

Trainer for Genie model.

Parameters:
  • model (Module)

  • config (GenieConfig)

  • device (device | None)

train_step(batch)[source]#

Single training step.

Parameters:

batch (Tensor) – (B, C, T, H, W) video batch

Returns:

Dictionary of losses

Return type:

Dict[str, Tensor]

validate(val_batch)[source]#

Validation step.

Parameters:

val_batch (Tensor) – (B, C, T, H, W) validation video batch

Returns:

Dictionary of validation metrics

Return type:

Dict[str, Tensor]

train(train_dataloader, val_dataloader=None, num_steps=None, log_interval=100, val_interval=1000)[source]#

Full training loop.

Parameters:
  • train_dataloader (DataLoader) – Training data loader

  • val_dataloader (DataLoader | None) – Validation data loader (optional)

  • num_steps (int | None) – Number of training steps (uses config.max_steps if None)

  • log_interval (int) – Logging frequency

  • val_interval (int) – Validation frequency

save_checkpoint(path)[source]#

Save model checkpoint.

Parameters:

path (str)

load_checkpoint(path)[source]#

Load model checkpoint.

Parameters:

path (str)

world_models.training.train_genie.create_genie_trainer(config=None, device=None)[source]#

Factory function that creates a Genie model and its trainer, returned as (trainer, model).

Return type:

Tuple[GenieTrainer, Module]

world_models.training.train_planet.train(memory, rssm, optimizer, device, N=32, H=50, beta=1.0, grads=False)[source]#

Training implementation following Learning Latent Dynamics for Planning from Pixels (arXiv:1811.04551), using method (a), the standard variational bound, which relies only on single-step predictions.
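For reference, the standard variational bound from the PlaNet paper is reproduced below. Only the observation term is shown; a reward log-likelihood term enters the objective in the same way. This page does not document how train() weights the KL term. The beta argument plausibly scales it, but that is an assumption.

```latex
\ln p(o_{1:T} \mid a_{1:T}) \ge \sum_{t=1}^{T} \Big(
  \mathbb{E}_{q(s_t \mid o_{\le t}, a_{<t})}\big[\ln p(o_t \mid s_t)\big]
  - \mathbb{E}_{q(s_{t-1} \mid o_{\le t-1}, a_{<t-1})}\big[
      \mathrm{KL}\big[q(s_t \mid o_{\le t}, a_{<t}) \,\|\, p(s_t \mid s_{t-1}, a_{t-1})\big]
    \big]
\Big)
```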

world_models.training.train_planet.main()[source]#

Example PlaNet/RSSM training script with rollout collection and evaluation.

Builds environment/model/policy objects, iteratively trains on replayed episodes, and periodically saves videos and checkpoints.

world_models.training.train_rssm.train_rssm(memory, model, optimizer, record_grads=True)[source]#

Train an RSSM on replayed trajectories for one optimization phase.

Samples batches from memory, computes reconstruction and KL objectives across rollout steps, and returns aggregated loss metrics.

world_models.training.train_rssm.evaluate(memory, model, path, eps)[source]#

Run one RSSM reconstruction/prediction evaluation and save visual outputs.

Decodes priors/posteriors for a sampled sequence and writes frame grids for qualitative inspection.

world_models.training.train_rssm.main()[source]#

Standalone training loop for RSSM with generated replay fallback support.

Initializes environment/policy/memory, trains over episodes, logs metrics, and periodically evaluates and checkpoints the model.

class world_models.training.train_diamond.DiamondAgent(config)[source]#

Bases: object

DIAMOND: DIffusion As a Model Of eNvironment Dreams

RL agent trained entirely within a diffusion world model.

Parameters:

config (DiamondConfig)

train()[source]#

Main training loop, following Algorithm 1 of the DIAMOND paper.

evaluate(num_episodes=1)[source]#

Evaluate the agent.

Parameters:

num_episodes (int)

Return type:

float

save_checkpoint(path=None)[source]#

Save model checkpoint.

Parameters:

path (str | PathLike | None) – Optional destination for the checkpoint. If path is None, the checkpoint is written to checkpoints/diamond/checkpoint.pt (the legacy default). If path is a bare filename, the file is written to checkpoints/diamond/<filename>. If path contains a directory component (absolute or relative), it is used as given.

load_checkpoint(path=None)[source]#

Load model checkpoint.

Parameters:

path (str | None) – Optional path to the checkpoint. If None, the default checkpoints/diamond/checkpoint.pt is loaded. If a bare filename is provided, checkpoints/diamond/<filename> is tried. A path with directory components is used directly.
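The path rules documented for save_checkpoint and load_checkpoint can be expressed as a single resolver. resolve_checkpoint_path is a hypothetical helper, not part of DiamondAgent. It ignores any fallback the real load_checkpoint may perform when a bare filename is not found.

```python
import os
from pathlib import Path
from typing import Optional, Union

DEFAULT_DIR = Path("checkpoints") / "diamond"


def resolve_checkpoint_path(path: Optional[Union[str, os.PathLike]] = None) -> Path:
    """Hypothetical helper mirroring the documented checkpoint path rules."""
    if path is None:
        return DEFAULT_DIR / "checkpoint.pt"  # legacy default location
    if os.path.dirname(os.fspath(path)) == "":
        return DEFAULT_DIR / os.fspath(path)  # bare filename goes into the default dir
    return Path(path)  # directory component present: use as given
```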

world_models.training.train_diamond.train_diamond(game, seed=0, preset=None, device='cpu')[source]#

Train DIAMOND on a specific game.

Parameters:
  • game (str)

  • seed (int)

  • preset (str | None)

  • device (str)

class world_models.training.rl_harness.ActorCritic(obs_shape, action_dim, hidden_dim=256)[source]#

Bases: Module

Simple actor-critic network for RL harness.

Parameters:
  • obs_shape (tuple)

  • action_dim (int)

  • hidden_dim (int)

forward(obs)[source]#

Forward pass through CNN, then actor and critic heads.

Parameters:

obs (Tensor)

Return type:

tuple[Tensor, Tensor]

get_action(obs)[source]#

Sample action from policy.

Parameters:

obs (Tensor)

Return type:

tuple[Tensor, Tensor, Tensor]

class world_models.training.rl_harness.PPOTrainer(vec_env, device='cpu', lr=0.0003, gamma=0.99, gae_lambda=0.95, clip_ratio=0.2, num_epochs=10, batch_size=64, max_grad_norm=0.5, entropy_coeff=0.01, value_coeff=0.5)[source]#

Bases: object

Simple PPO trainer for testing vectorized environments.

Parameters:
  • vec_env (TorchVectorizedEnv)

  • device (str)

  • lr (float)

  • gamma (float)

  • gae_lambda (float)

  • clip_ratio (float)

  • num_epochs (int)

  • batch_size (int)

  • max_grad_norm (float)

  • entropy_coeff (float)

  • value_coeff (float)

collect_trajectories(num_steps)[source]#

Collect trajectories using the vectorized environment.

Parameters:

num_steps (int)

Return type:

Dict[str, Tensor]

compute_gae(rewards, values, dones)[source]#

Compute Generalized Advantage Estimation.

Parameters:
  • rewards (Tensor)

  • values (Tensor)

  • dones (Tensor)

Return type:

Tensor
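Generalized Advantage Estimation computes the TD residual delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t) and accumulates it backwards with decay gamma * gae_lambda. The sketch below works on plain lists for a single trajectory. The last_value bootstrap argument is an assumption, because the documented signature does not show how PPOTrainer bootstraps the final step.

```python
from typing import List, Sequence


def compute_gae(
    rewards: Sequence[float],
    values: Sequence[float],
    dones: Sequence[float],
    last_value: float = 0.0,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> List[float]:
    """GAE over one trajectory; dones[t] = 1.0 if the episode ended at step t."""
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value  # bootstrap value for the step after the last one
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t]  # cut bootstrapping at episode ends
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * gae_lambda * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    return advantages
```

Value targets for the critic are usually advantages[t] + values[t].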

train_step(trajectories)[source]#

Perform one training step using PPO.

Parameters:

trajectories (Dict[str, Tensor])

train(total_timesteps, log_interval=1000)[source]#

Main training loop.

Parameters:
  • total_timesteps (int)

  • log_interval (int)

world_models.training.rl_harness.create_rl_harness_example()[source]#

Example function that creates and runs the RL harness. It takes no arguments; to use your own environment, adapt the environment factory used inside the example.

Memory, controllers, and inference operators#

class world_models.memory.dreamer_memory.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]#

Bases: object

Fixed-size replay buffer for Dreamer with image observations and transitions.

Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world-model training.

Key Features:
  • Ring buffer with fixed capacity (FIFO eviction when full)

  • Stores raw uint8 images to save memory

  • Samples sequences (not single transitions) for temporal modeling

  • Validates sampled sequences don’t span episode boundaries

Memory Layout:
  • observations: (capacity, C, H, W) uint8 images

  • actions: (capacity, action_dim) float32

  • rewards: (capacity,) float32

  • terminals: (capacity,) float32 (1.0 = terminal, 0.0 = continue)

Sampling Process:
  1. Random start index (avoiding episode boundaries)

  2. Collect sequence of length seq_len with wraparound

  3. Validate no terminal in middle of sequence

  4. Return batch of sequences
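The sampling process above can be sketched as a small ring buffer that rejects windows that cross the write head or contain a terminal before their last step. SequenceRingBuffer is a hypothetical, simplified illustration. It stores Python tuples rather than uint8 arrays and returns a single sequence rather than a batch.

```python
import random
from typing import Any, List, Optional, Tuple

Transition = Tuple[Any, float, float]  # (observation, reward, terminal)


class SequenceRingBuffer:
    """Hypothetical minimal ring buffer showing contiguous-sequence sampling."""

    def __init__(self, capacity: int, seq_len: int):
        self.capacity = capacity
        self.seq_len = seq_len
        self.data: List[Optional[Transition]] = [None] * capacity
        self.idx = 0  # next write position (the oldest entry once full)
        self.full = False

    def add(self, obs: Any, reward: float, terminal: float) -> None:
        self.data[self.idx] = (obs, reward, terminal)  # overwrites the oldest when full
        self.idx = (self.idx + 1) % self.capacity
        self.full = self.full or self.idx == 0

    def __len__(self) -> int:
        return self.capacity if self.full else self.idx

    def _is_valid(self, start: int) -> bool:
        idxs = [(start + i) % self.capacity for i in range(self.seq_len)]
        # Crossing the write head would join the newest and oldest entries.
        if self.full and self.idx in idxs[1:]:
            return False
        # A terminal is allowed only on the last step of the window.
        return all(self.data[i][2] == 0.0 for i in idxs[:-1])

    def sample_sequence(self, max_tries: int = 100) -> List[Transition]:
        if len(self) < self.seq_len:
            raise ValueError("not enough transitions stored")
        for _ in range(max_tries):
            if self.full:
                start = random.randrange(self.capacity)  # windows may wrap around
            else:
                start = random.randint(0, len(self) - self.seq_len)
            if self._is_valid(start):
                return [self.data[(start + i) % self.capacity] for i in range(self.seq_len)]
        raise RuntimeError("no valid sequence found")
```

Calling sample_sequence batch_size times and stacking the results gives the time-major batches described for sample().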

Usage with Dreamer:

from world_models.memory.dreamer_memory import ReplayBuffer

buffer = ReplayBuffer(
    size=100000,            # Max transitions to store
    obs_shape=(3, 64, 64),  # RGB images
    action_size=6,          # Continuous action dim
    seq_len=50,             # Sequence length for training
    batch_size=50,          # Parallel sequences per batch
)

# Add transitions during interaction; obs is a {"image": array} dict
buffer.add(obs, action, reward, done)

# Sample a batch for world-model training
obs_batch, action_batch, reward_batch, term_batch = buffer.sample()
# Shapes: (seq_len, batch, C, H, W), (seq_len, batch, action_dim), etc.

Memory Efficiency:
  • Uses uint8 for images (1 byte per pixel vs 4 for float32)

  • Sequences share observations (overlapping windows)

  • Configurable capacity based on available system memory

Note

add() accepts observations as {"image": …} dicts, but the buffer keeps only the image arrays and returns them directly for training efficiency.

Parameters:
  • size (int)

  • obs_shape (Tuple[int, ...])

  • action_size (int)

  • seq_len (int)

  • batch_size (int)

add(obs, ac, rew, done)[source]#

Add a transition to the buffer.

Parameters:
  • obs (dict) – Observation dict with ‘image’ key containing the observation

  • ac (ndarray) – Action taken, shape (action_size,)

  • rew (float) – Reward received, scalar

  • done (float) – Terminal flag, 1.0 if episode ended, 0.0 otherwise

Return type:

None

sample()[source]#

Sample a batch of sequences for training.

Returns:

(observations, actions, rewards, terminals)
  • observations: (seq_len, batch, C, H, W)

  • actions: (seq_len, batch, action_dim)

  • rewards: (seq_len, batch)

  • terminals: (seq_len, batch)

Return type:

tuple

class world_models.memory.dreamer_memory.Memory(capacity=10000)[source]#

Bases: object

Simple deque-based memory for storing transitions.

Used by PlaNet for online planning. Stores recent transitions and provides random sampling for policy updates.

Parameters:

capacity (int) – Maximum number of transitions to store

Usage:

memory = Memory(capacity=10000)
memory.append(obs, action, reward, done, info)
batch = memory.sample(batch_size=32)

append(*args)[source]#

Append a transition to memory.

Parameters:

*args – Variable length argument list containing transition data. Typically (observation, action, reward, done, info).

Return type:

None

sample(batch_size)[source]#

Sample random batch of transitions from memory.

Parameters:

batch_size (int) – Number of transitions to sample.

Returns:

List of sampled transitions.

Return type:

list
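A deque with maxlen gives this kind of memory its FIFO eviction, and random.sample draws a batch without replacement. DequeMemory is a hypothetical minimal stand-in, not the library class.

```python
import random
from collections import deque
from typing import Any, List


class DequeMemory:
    """Hypothetical stand-in for a deque-backed transition memory."""

    def __init__(self, capacity: int = 10000):
        self.buffer: deque = deque(maxlen=capacity)  # oldest items drop off automatically

    def append(self, *args: Any) -> None:
        self.buffer.append(args)  # store one transition as a tuple

    def sample(self, batch_size: int) -> List[tuple]:
        # Copy to a list so sampling uses O(1) indexing; sampled without replacement.
        return random.sample(list(self.buffer), batch_size)

    def __len__(self) -> int:
        return len(self.buffer)
```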

class world_models.memory.dreamer_memory.Episode(observation, action=None, reward=None, terminal=None, info=None)[source]#

Bases: object

Stores a single episode for PlaNet’s imagination and planning.

An episode is a sequence of (observation, action, reward) tuples collected during environment interaction. Episodes are used for computing returns and training value functions.

Parameters:
  • observation – Initial observation

  • action – First action (optional)

  • reward – Initial reward (optional)

  • terminal – Initial terminal flag (optional)

  • info – Additional info dict (optional)

Usage:

episodes = []
episode = Episode(obs, info=info)
episode.append(action, obs, reward, done, info)
episodes.append(episode)  # one Episode object per collected rollout

# Use with Planet agent for planning
imag_state, imag_reward, imag_action = planet.imagine(episodes)

append(action, observation, reward, terminal, info=None)[source]#
Return type:

None

class world_models.memory.planet_memory.Episode(postprocess_fn=None)[source]#

Bases: object

Records the agent’s interaction with the environment for a single episode.

Stores observations, actions, rewards, and terminal flags during a single trajectory. At termination, converts all lists to numpy arrays for efficient batch processing.

x#

Observations collected during the episode.

Type:

list or np.ndarray

u#

Actions taken.

Type:

list or np.ndarray

r#

Rewards received.

Type:

list or np.ndarray

t#

Terminal flags (0.0 = continue, 1.0 = terminal).

Type:

list or np.ndarray

info#

Additional episode metadata.

Type:

dict

Parameters:

postprocess_fn (callable, optional) – Function to apply to observations before storing (e.g., normalization). Default: identity function.

Example

>>> episode = Episode()
>>> episode.append(obs, action, reward, False)
>>> episode.append(obs, action, reward, True)
>>> episode.terminate(final_obs)
>>> print(episode.x.shape)  # Now a numpy array
property size#
append(obs, act, reward, terminal)[source]#
terminate(obs)[source]#
class world_models.memory.planet_memory.Memory(size=None)[source]#

Bases: deque

Episode-based replay memory for PlaNet/RSSM training.

Stores episodes as variable-length trajectories and supports sampling sub-sequences for training. Implements a ring-buffer style eviction when capacity is reached.

Features:
  • Stores complete episodes as lists of transitions

  • Samples contiguous sub-sequences for sequence models

  • Supports time-major formatting (time-first) for RNN input

  • Memory usage estimation to prevent OOM errors

Parameters:

size (int, optional) – Maximum number of episodes to store. If None, deque grows without limit (useful for unpickling).

episodes#

Collection of Episode objects.

Type:

deque

eps_lengths#

Length of each episode.

Type:

deque

size#

Total number of transitions across all episodes.

Type:

int (read-only property)

Example

>>> memory = Memory(size=100)
>>> memory.append([episode1, episode2])
>>> obs, actions, rewards, terminals, lengths = memory.sample(batch_size=32, tracelen=50)
property size#
append(episodes)[source]#
Parameters:

episodes (list[Episode])

sample(batch_size, tracelen=1, time_first=False)[source]#

Sample random sub-sequences from stored episodes.

Randomly selects episodes and starting positions to create batches of contiguous sequences for training sequence models.

Parameters:
  • batch_size (int) – Number of sequences to sample.

  • tracelen (int) – Length of each sequence (default: 1).

  • time_first (bool) – If True, returns tensors with time dimension first (T, B, …) instead of batch first (B, T, …).

Returns:

(observations, actions, rewards, terminals, lengths)
  • observations: (batch, tracelen+1, *obs_shape) or (tracelen+1, batch, …)

  • actions: (batch, tracelen, action_dim) or (tracelen, batch, …)

  • rewards: (batch, tracelen) or (tracelen, batch)

  • terminals: (batch, tracelen) or (tracelen, batch)

  • lengths: (batch,) original episode lengths for each sample

Return type:

tuple

Raises:
  • ValueError – If memory is empty or no episodes meet minimum length.

  • MemoryError – If estimated memory usage exceeds 200 MiB threshold.
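Sub-sequence sampling from variable-length episodes follows these steps: keep only episodes with at least tracelen actions, pick a random start, slice tracelen + 1 observations and tracelen actions, and optionally transpose to time-first. sample_subsequences is a hypothetical list-based sketch. It omits rewards, terminals, tensor conversion, and the memory check.

```python
import random
from typing import List, Optional, Sequence, Tuple

EpisodeData = Tuple[Sequence, Sequence]  # (observations, actions), len(obs) == len(actions) + 1


def sample_subsequences(
    episodes: List[EpisodeData],
    batch_size: int,
    tracelen: int = 1,
    time_first: bool = False,
    rng: Optional[random.Random] = None,
):
    rng = rng or random.Random()
    eligible = [ep for ep in episodes if len(ep[1]) >= tracelen]
    if not eligible:
        raise ValueError("no episode meets the minimum length")
    obs_batch, act_batch, lengths = [], [], []
    for _ in range(batch_size):
        obs, acts = rng.choice(eligible)
        start = rng.randint(0, len(acts) - tracelen)
        obs_batch.append(list(obs[start:start + tracelen + 1]))  # tracelen + 1 observations
        act_batch.append(list(acts[start:start + tracelen]))     # tracelen actions
        lengths.append(len(acts))  # original episode length (in actions)
    if time_first:  # (B, T, ...) -> (T, B, ...)
        obs_batch = [list(step) for step in zip(*obs_batch)]
        act_batch = [list(step) for step in zip(*act_batch)]
    return obs_batch, act_batch, lengths
```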

class world_models.memory.iris_memory.IRISReplayBuffer(size, obs_shape, action_size, seq_len=20, batch_size=64)[source]#

Bases: object

Replay buffer for IRIS (Imagination with auto-Regression over an Inner Speech) training.

Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world model training.

Features:
  • Ring buffer with fixed capacity (FIFO eviction when full)

  • Stores uint8 images for memory efficiency

  • Samples sequences with validation to avoid episode boundaries

  • Supports sequence sampling for temporal learning

Memory Layout:
  • observations: (capacity, C, H, W) uint8

  • actions: (capacity, action_size) float32

  • rewards: (capacity,) float32

  • terminals: (capacity,) float32

Parameters:
  • size (int) – Maximum number of transitions to store.

  • obs_shape (tuple) – Shape of observations as (C, H, W).

  • action_size (int) – Dimension of actions.

  • seq_len (int) – Length of sequences to sample (default: 20).

  • batch_size (int) – Number of sequences per batch (default: 64).

size#

Buffer capacity.

Type:

int

obs_shape#

Observation shape.

Type:

tuple

action_size#

Action dimension.

Type:

int

seq_len#

Sequence length.

Type:

int

batch_size#

Batch size.

Type:

int

steps#

Total transitions added.

Type:

int

episodes#

Number of episode terminations observed.

Type:

int

add(obs, action, reward, terminal)[source]#

Add a transition to the buffer.

Parameters:
  • obs (ndarray) – Observation array with shape (C, H, W).

  • action (ndarray) – Action array with shape (action_size,).

  • reward (float) – Scalar reward value.

  • terminal (bool) – Boolean indicating if episode terminated.

sample_sequence(seq_len=None)[source]#

Sample a batch of sequences for world model training.

Returns:

(observations, actions, rewards, terminals)
  • observations: (batch_size, seq_len+1, C, H, W)

  • actions: (batch_size, seq_len, action_size)

  • rewards: (batch_size, seq_len)

  • terminals: (batch_size, seq_len)

Return type:

tuple

Parameters:

seq_len (int | None)

sample_single()[source]#

Sample a single transition for online updates.

Return type:

Tuple[ndarray, ndarray, float, float]

property buffer_capacity#

Returns the total capacity of the buffer.

class world_models.memory.iris_memory.IRISOnPolicyBuffer(max_steps=1000)[source]#

Bases: object

On-policy buffer for collecting trajectories during environment interaction.

Used to store the current episode data before adding to the main replay buffer. Unlike the main replay buffer, this collects trajectories in a list-based structure that’s cleared after each episode.

Useful for:
  • Collecting complete episode trajectories

  • Storing data before batch processing

  • Temporary storage during environment interaction

Parameters:

max_steps (int) – Maximum number of steps to store (default: 1000).

max_steps#

Maximum buffer capacity.

Type:

int

observations#

List of observations.

Type:

list

actions#

List of actions.

Type:

list

rewards#

List of rewards.

Type:

list

terminals#

List of terminal flags.

Type:

list

add(obs, action, reward, terminal)[source]#
Parameters:
  • obs (ndarray)

  • action (ndarray)

  • reward (float)

  • terminal (bool)

clear()[source]#
get_arrays()[source]#
class world_models.controller.rssm_policy.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]#

Bases: object

Model-predictive controller using Cross-Entropy Method (CEM) with RSSM.

Plans actions by optimizing a sequence of future actions in the RSSM’s latent space. Uses Cross-Entropy Method to refine action sequences based on predicted returns.

Algorithm:
  1. Initialize Gaussian distribution over action sequences

  2. Sample N candidate action sequences

  3. Rollout each sequence in RSSM latent space

  4. Score by predicted cumulative rewards

  5. Keep top K candidates, fit Gaussian to them

  6. Repeat for T iterations

  7. Execute first action from best sequence
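The CEM loop above can be sketched for one-dimensional actions. Here a toy score function stands in for "roll out in RSSM latent space and sum predicted rewards". cem_first_action and the toy score are hypothetical; RSSMPolicy operates on batched tensors and may return a different action (for example, the first action of the final mean).

```python
import random
import statistics
from typing import Callable, List, Sequence


def cem_first_action(
    score_fn: Callable[[Sequence[float]], float],
    horizon: int,
    num_candidates: int = 200,
    num_iterations: int = 8,
    top_candidates: int = 20,
    seed: int = 0,
) -> float:
    rng = random.Random(seed)
    mean = [0.0] * horizon  # 1. initial Gaussian over action sequences
    std = [1.0] * horizon
    best: List[float] = mean
    for _ in range(num_iterations):  # 6. repeat for T iterations
        # 2. sample N candidate sequences
        candidates = [[rng.gauss(m, s) for m, s in zip(mean, std)] for _ in range(num_candidates)]
        candidates.sort(key=score_fn, reverse=True)  # 3-4. score (stand-in for latent rollout)
        elite = candidates[:top_candidates]  # 5. keep top K and refit the Gaussian
        best = elite[0]
        mean = [statistics.fmean(step) for step in zip(*elite)]
        std = [statistics.pstdev(step) + 1e-6 for step in zip(*elite)]
    return best[0]  # 7. first action of the best sequence


def toy_score(seq: Sequence[float]) -> float:
    # Hypothetical return: highest when every action equals 0.5
    return -sum((a - 0.5) ** 2 for a in seq)


print(cem_first_action(toy_score, horizon=5))
```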

Why latent space planning?
  • Images are high-dimensional; latent states are compact

  • Enables thousands of rollouts in parallel

  • Dynamics are rolled out in the compact latent space without generating images at each imagined step

Parameters:
  • model – RSSM instance for latent dynamics

  • planning_horizon – Number of future steps to plan (H)

  • num_candidates – Number of action sequences to sample (N)

  • num_iterations – CEM refinement iterations (T)

  • top_candidates – Number of best candidates to keep (K)

  • device – torch device

Usage with Planet agent:
policy = RSSMPolicy(
    model=rssm,
    planning_horizon=12,
    num_candidates=1000,
    num_iterations=8,
    top_candidates=100,
    device='cuda',
)

policy.reset()
action = policy.poll(observation)  # (1, action_dim)

# For continuous control:
next_obs, reward, done, info = env.step(action)

Comparison with Dreamer:
  • RSSMPolicy: Online planning, chooses actions by optimization at each step

  • DreamerActor: Train actor network to predict actions from states

  • Dreamer is more sample-efficient for complex tasks; CEM is more flexible

reset()[source]#
poll(observation, explore=False)[source]#
class world_models.controller.iris_policy.IRISActor(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Actor network for IRIS (Imagination with auto-Regression over an Inner Speech) policy.

Takes reconstructed frames as input and outputs action logits for policy control. Uses a CNN feature extractor followed by an LSTM for temporal processing. Supports a burn-in mechanism for initializing the hidden state with context frames.

Architecture:
  • CNN: Extracts features from input frames (3x64x64 -> 512)

  • LSTM: Processes temporal sequences with configurable layers

  • Linear: Maps hidden states to action logits

Parameters:
  • action_size (int) – Number of discrete actions.

  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

action_size#

Number of discrete actions.

Type:

int

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

forward(frames, hidden_state=None, burn_in_frames=None)[source]#

Forward pass through actor.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W) or (B, C, H, W)

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple for LSTM state

  • burn_in_frames (Tensor | None) – Frames to use for initializing hidden state

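Returns:

A tuple (action_logits, hidden_state): action_logits has shape (B, T, action_size) for sequence input or (B, action_size) for single-frame input, and hidden_state is the updated LSTM (h, c) tuple.

Return type:

Tuple[Tensor, Tuple[Tensor, Tensor]]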

init_hidden_state(batch_size, device)[source]#

Initialize LSTM hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

get_action(frame, temperature=1.0, deterministic=False)[source]#

Get action from a single frame.

Parameters:
  • frame (Tensor) – Single frame (B, C, H, W)

  • temperature (float) – Softmax temperature (higher = more random)

  • deterministic (bool) – If True, return argmax; else sample

Returns:

Selected action indices (B,)

Return type:

Tensor
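The temperature and deterministic options of get_action can be illustrated without PyTorch. The sketch below is a plain-Python approximation for one frame's logits; `action_probs` and `select_action` are hypothetical helpers, not part of TorchWM, and only show how dividing logits by the temperature flattens or sharpens the softmax before sampling.

```python
import math
import random


def action_probs(logits, temperature=1.0):
    """Softmax of logits / temperature (hypothetical helper)."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    total = sum(weights)
    return [w / total for w in weights]


def select_action(logits, temperature=1.0, deterministic=False, rng=random):
    """Return one action index: argmax if deterministic, else a sample."""
    if deterministic:
        return max(range(len(logits)), key=logits.__getitem__)
    probs = action_probs(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [0.1, 2.0, -1.0]
print(select_action(logits, deterministic=True))  # 1 (largest logit)
print([round(p, 3) for p in action_probs(logits, temperature=1.0)])
print([round(p, 3) for p in action_probs(logits, temperature=10.0)])  # flatter
```

Higher temperatures push the probabilities toward uniform (more exploration); with deterministic=True the temperature has no effect because the argmax of the logits is unchanged by scaling.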

class world_models.controller.iris_policy.IRISCritic(hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Critic network for IRIS value estimation.

Estimates the value function for given frame sequences. Shares the CNN feature extractor and LSTM backbone with the actor for efficiency, but has a separate value head for estimating expected cumulative rewards.

Architecture:
  • CNN: Shared feature extractor with actor (3x64x64 -> 512)

  • LSTM: Temporal processing with same architecture as actor

  • Linear: Maps hidden states to scalar values

Parameters:
  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

Returns:

A tuple (values, hidden_state): values has shape (B, T), and hidden_state is the updated LSTM (h, c) tuple.

Return type:

Tuple[Tensor, Tuple[Tensor, Tensor]]

Parameters:
  • hidden_size (int)

  • num_layers (int)

  • frame_shape (Tuple[int, int, int])

forward(frames, hidden_state=None)[source]#

Forward pass through critic.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W)

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple

Returns:

A tuple (values, hidden_state): values has shape (B, T), and hidden_state is the updated (h, c) tuple.

Return type:

Tuple[Tensor, Tuple[Tensor, Tensor]]

init_hidden_state(batch_size, device)[source]#

Initialize LSTM hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

class world_models.controller.iris_policy.CNNFeatureExtractor(frame_shape=(3, 64, 64), output_size=512)[source]#

Bases: Module

CNN feature extractor shared between actor and critic networks.

Processes input frames through a stack of convolutional layers to produce fixed-size feature vectors: four stride-2 Conv2d + ReLU blocks (3 -> 32 -> 64 -> 128 -> 256 channels), followed by a linear projection to output_size.

Architecture:
  • Conv layers: 32 -> 64 -> 128 -> 256 channels

  • Each conv has stride=2 for spatial downsampling

  • Final linear layer projects to desired output dimension
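Kernel size and padding are not documented here, so the number of features entering the final linear layer depends on them. The hypothetical helper below applies the standard convolution output-size formula, out = floor((in + 2 * padding - kernel) / stride) + 1, to the four stride-2 layers; the kernel=4/padding=0 and kernel=3/padding=1 settings are assumed examples, not the extractor's actual configuration.

```python
def conv_out(size, kernel, stride=2, padding=0):
    # Standard Conv2d output size: floor((in + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1


def flattened_size(frame_shape=(3, 64, 64), channels=(32, 64, 128, 256),
                   kernel=4, padding=0):
    """Number of features entering the final linear layer (hypothetical helper)."""
    _, h, w = frame_shape
    for _ in channels:  # one stride-2 conv per channel stage
        h = conv_out(h, kernel, 2, padding)
        w = conv_out(w, kernel, 2, padding)
    return channels[-1] * h * w


print(flattened_size(kernel=4, padding=0))  # 256 * 2 * 2 = 1024
print(flattened_size(kernel=3, padding=1))  # 256 * 4 * 4 = 4096
```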

Parameters:
  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

  • output_size (int) – Size of output feature vector (default: 512).

frame_shape#

Input frame shape.

Type:

tuple

output_size#

Output feature dimension.

Type:

int

Returns:

Feature vectors with shape (B, output_size).

Return type:

Tensor

Parameters:
  • frame_shape (Tuple[int, int, int])

  • output_size (int)

forward(x)[source]#

Extract features from frames.

Parameters:

x (Tensor) – Frames (B, C, H, W)

Returns:

Feature vectors (B, output_size)

Return type:

Tensor

class world_models.controller.iris_policy.IRISPolicy(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Combined policy module for IRIS (Imagination with auto-Regression over an Inner Speech).

Provides a unified interface for actor-only or actor-critic policies. Used in the IRIS algorithm where the actor generates actions from reconstructed frames and the critic estimates value functions for training.

Parameters:
  • action_size (int) – Number of discrete actions.

  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

actor#

The actor network for action selection.

Type:

IRISActor

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

Example

>>> import torch
>>> policy = IRISPolicy(
...     action_size=18,
...     hidden_size=512,
...     num_layers=4,
...     frame_shape=(3, 64, 64)
... )
>>> frame = torch.zeros(1, 3, 64, 64)  # dummy (B, C, H, W) frame
>>> action = policy.act(frame, temperature=1.0, deterministic=False)
forward(frames)[source]#

Get action logits from frames.

Parameters:

frames (Tensor)

Return type:

Tensor

act(frame, temperature=1.0, deterministic=False)[source]#

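Select an action from the policy: sample from the temperature-scaled action distribution, or take the argmax when deterministic=True.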

Parameters:
  • frame (Tensor)

  • temperature (float)

  • deterministic (bool)

Return type:

Tensor

init_hidden(batch_size, device)[source]#

Initialize hidden state.

Parameters:
  • batch_size (int)

  • device (device)

class world_models.controller.rollout_generator.RolloutGenerator(env, device, policy=None, max_episode_steps=None, episode_gen=None, name=None, enable_streaming_video=False, streaming_video_path=None, streaming_video_fps=20, streaming_video_format='mp4')[source]#

Bases: object

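Collects environment rollouts with a given policy (or random actions) and returns them as Episode objects, optionally streaming rollout video to streaming_video_path.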

rollout_once(random_policy=False, explore=False)[source]#

Performs a single rollout of the environment with the given policy and returns an Episode instance.

Return type:

Episode

rollout_n(n=1, random_policy=False)[source]#

Performs n rollouts and returns the resulting episodes as a list.

Return type:

list[Episode]
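The rollout_once/rollout_n pattern can be sketched with a toy environment. Everything here is hypothetical and for illustration only: `Episode` is a stand-in dataclass (TorchWM's Episode fields are not documented on this page), `CountdownEnv` is a toy environment, and the optional truncation mirrors the max_episode_steps constructor argument.

```python
import random
from dataclasses import dataclass, field


@dataclass
class Episode:
    """Hypothetical stand-in for TorchWM's Episode container."""
    observations: list = field(default_factory=list)
    actions: list = field(default_factory=list)
    rewards: list = field(default_factory=list)


class CountdownEnv:
    """Toy environment: reward 1.0 per step, terminates after `length` steps."""

    def __init__(self, length=5):
        self.length = length
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= self.length, {}


def rollout_once(env, policy=None, max_episode_steps=None, rng=random):
    episode = Episode()
    obs = env.reset()
    episode.observations.append(obs)
    done = False
    while not done:
        # Random actions when no policy is given (analogous to random_policy=True).
        action = rng.choice([0, 1]) if policy is None else policy(obs)
        obs, reward, done, _ = env.step(action)
        episode.actions.append(action)
        episode.rewards.append(reward)
        episode.observations.append(obs)
        if max_episode_steps is not None and len(episode.actions) >= max_episode_steps:
            break
    return episode


def rollout_n(env, n=1, **kwargs):
    return [rollout_once(env, **kwargs) for _ in range(n)]


episodes = rollout_n(CountdownEnv(length=5), n=3)
print([sum(ep.rewards) for ep in episodes])  # [5.0, 5.0, 5.0]
```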

rollout_eval_n(n)[source]#
rollout_eval(collect_latents=False)[source]#
class world_models.inference.operators.OperatorABC[source]

Bases: ABC

Abstract base class for operators that preprocess inputs for inference pipelines.

abstractmethod process(inputs)[source]

Process raw inputs into standardized tensor format for model consumption.

Parameters:

inputs (Any) – Raw input data (dict, tensor, or other formats)

Returns:

Dict of processed tensors ready for model input

Return type:

Dict[str, Tensor]

class world_models.inference.operators.DreamerOperator(image_size=64, action_dim=6)[source]

Bases: OperatorABC

Operator for Dreamer model preprocessing: normalizes observations and encodes actions.

Parameters:
  • image_size (int)

  • action_dim (int)

process(inputs)[source]

Process Dreamer inputs: image observation and action.

Expected inputs: {'image': PIL.Image or tensor, 'action': tensor or list}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]
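The "normalize observations, encode actions" step can be sketched with plain lists. This is a hypothetical illustration, not DreamerOperator's code: it assumes the common Dreamer convention of mapping uint8 pixels to [-0.5, 0.5] and assumes that integer actions are one-hot encoded to action_dim entries; resizing to image_size and tensor conversion are omitted.

```python
def normalize_image(pixels):
    # uint8 values in [0, 255] -> floats in [-0.5, 0.5] (assumed convention)
    return [p / 255.0 - 0.5 for p in pixels]


def one_hot(index, action_dim=6):
    if not 0 <= index < action_dim:
        raise ValueError(f"action index {index} out of range for action_dim={action_dim}")
    vec = [0.0] * action_dim
    vec[index] = 1.0
    return vec


def process(inputs, action_dim=6):
    """Hypothetical plain-list version of {'image': ..., 'action': ...} preprocessing."""
    action = inputs["action"]
    if isinstance(action, int):  # discrete index -> one-hot vector
        action = one_hot(action, action_dim)
    return {
        "image": normalize_image(inputs["image"]),
        "action": [float(a) for a in action],
    }


out = process({"image": [0, 255], "action": 2})
print(out["image"])   # [-0.5, 0.5]
print(out["action"])  # [0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
```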

class world_models.inference.operators.JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)[source]

Bases: OperatorABC

Operator for JEPA model preprocessing: handles image/video masking and patch processing.

Parameters:
  • image_size (int)

  • patch_size (int)

  • mask_ratio (float)

process(inputs)[source]

Process JEPA inputs: images with masking.

Expected inputs: {'images': list of PIL Images or tensors, 'mask': optional tensor}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]
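With the default image_size=224 and patch_size=16, each image splits into a 14 x 14 grid of 196 patches, and mask_ratio=0.75 masks 147 of them. Whether JEPAOperator masks uniformly at random or in contiguous blocks is not stated here; the hypothetical helper below assumes uniform random masking purely to show the arithmetic.

```python
import random


def random_patch_mask(image_size=224, patch_size=16, mask_ratio=0.75, seed=0):
    """Return one boolean per patch (True = masked), row-major over the patch grid."""
    patches_per_side = image_size // patch_size   # 224 // 16 = 14
    num_patches = patches_per_side ** 2           # 14 * 14 = 196
    num_masked = int(num_patches * mask_ratio)    # int(196 * 0.75) = 147
    rng = random.Random(seed)
    masked = set(rng.sample(range(num_patches), num_masked))
    return [i in masked for i in range(num_patches)]


mask = random_patch_mask()
print(len(mask), sum(mask))  # 196 147
```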

class world_models.inference.operators.IrisOperator(seq_length=512, vocab_size=32000)[source]

Bases: OperatorABC

Operator for Iris transformer model: formats sequences and embeddings.

Parameters:
  • seq_length (int)

  • vocab_size (int)

process(inputs)[source]

Process Iris inputs: token sequences and optional embeddings.

Expected inputs: {'tokens': list of ints or tensor, 'embeddings': optional tensor}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

class world_models.inference.operators.PlaNetOperator(state_dim=32, action_dim=4)[source]

Bases: OperatorABC

Operator for PlaNet model preprocessing: encodes environment states and transitions.

Parameters:
  • state_dim (int)

  • action_dim (int)

process(inputs)[source]

Process PlaNet inputs: state observations and actions.

Expected inputs: {'obs': tensor or image, 'action': tensor, 'reward': float, 'done': bool}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

world_models.inference.operators.get_operator(name, **kwargs)[source]

Factory function to get inference operators by name.

Parameters:
  • name (str) – One of 'dreamer', 'jepa', 'iris', 'planet'

  • **kwargs – Operator-specific configuration

Returns:

Configured OperatorABC instance

Example

>>> op = get_operator('dreamer', image_size=64, action_dim=6)
>>> processed = op.process({'image': image, 'action': action})
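The OperatorABC/get_operator pairing is a name-to-class factory. The sketch below shows that pattern in plain Python; `ScaleOperator` and the `_OPERATORS` registry are hypothetical and only illustrate how an abstract process method plus a lookup table lets callers construct operators by name with forwarded keyword arguments.

```python
from abc import ABC, abstractmethod


class OperatorABC(ABC):
    @abstractmethod
    def process(self, inputs):
        """Convert raw inputs into a dict of model-ready values."""


class ScaleOperator(OperatorABC):
    """Hypothetical operator that multiplies inputs['x'] by a constant factor."""

    def __init__(self, factor=1.0):
        self.factor = factor

    def process(self, inputs):
        return {"x": [v * self.factor for v in inputs["x"]]}


_OPERATORS = {"scale": ScaleOperator}  # name -> operator class


def get_operator(name, **kwargs):
    try:
        cls = _OPERATORS[name]
    except KeyError:
        raise ValueError(f"unknown operator {name!r}; expected one of {sorted(_OPERATORS)}") from None
    return cls(**kwargs)  # kwargs are forwarded to the operator constructor


op = get_operator("scale", factor=2.0)
print(op.process({"x": [1, 2]}))  # {'x': [2.0, 4.0]}
```

Because process is abstract, instantiating OperatorABC directly raises TypeError, which forces every concrete operator to define its own preprocessing.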
class world_models.inference.operators.base.OperatorABC[source]#

Bases: ABC

Abstract base class for operators that preprocess inputs for inference pipelines.

abstractmethod process(inputs)[source]#

Process raw inputs into standardized tensor format for model consumption.

Parameters:

inputs (Any) – Raw input data (dict, tensor, or other formats)

Returns:

Dict of processed tensors ready for model input

Return type:

Dict[str, Tensor]

class world_models.inference.operators.dreamer_operator.DreamerOperator(image_size=64, action_dim=6)[source]#

Bases: OperatorABC

Operator for Dreamer model preprocessing: normalizes observations and encodes actions.

Parameters:
  • image_size (int)

  • action_dim (int)

process(inputs)[source]#

Process Dreamer inputs: image observation and action.

Expected inputs: {'image': PIL.Image or tensor, 'action': tensor or list}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

class world_models.inference.operators.planet_operator.PlaNetOperator(state_dim=32, action_dim=4)[source]#

Bases: OperatorABC

Operator for PlaNet model preprocessing: encodes environment states and transitions.

Parameters:
  • state_dim (int)

  • action_dim (int)

process(inputs)[source]#

Process PlaNet inputs: state observations and actions.

Expected inputs: {'obs': tensor or image, 'action': tensor, 'reward': float, 'done': bool}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

class world_models.inference.operators.iris_operator.IrisOperator(seq_length=512, vocab_size=32000)[source]#

Bases: OperatorABC

Operator for Iris transformer model: formats sequences and embeddings.

Parameters:
  • seq_length (int)

  • vocab_size (int)

process(inputs)[source]#

Process Iris inputs: token sequences and optional embeddings.

Expected inputs: {'tokens': list of ints or tensor, 'embeddings': optional tensor}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

class world_models.inference.operators.jepa_operator.JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)[source]#

Bases: OperatorABC

Operator for JEPA model preprocessing: handles image/video masking and patch processing.

Parameters:
  • image_size (int)

  • patch_size (int)

  • mask_ratio (float)

process(inputs)[source]#

Process JEPA inputs: images with masking.

Expected inputs: {'images': list of PIL Images or tensors, 'mask': optional tensor}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

Datasets, environments, and transforms#

Environment adapters#

The environment APIs below mirror the dedicated environment guide pages: DMC, Gym/Gymnasium, Atari/ALE, MuJoCo, Unity ML-Agents, and vectorization utilities. DIAMOND-style Atari support is intentionally not listed as an environment adapter because it is Atari preprocessing rather than a separate environment family.

world_models.envs.make_atari_env(env_id, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, **kwargs)[source]#

Create an Atari environment from the Arcade Learning Environment (ALE).

Parameters:
  • env_id (str) – The id of the Atari environment to create.

  • obs_type (str) – The type of observation to return (“rgb” or “ram”).

  • frameskip (int) – The number of frames to skip between actions.

  • repeat_action_probability (float) – The probability of repeating the last action.

  • full_action_space (bool) – Whether to use the full action space.

  • max_episode_steps (Optional[int]) – Maximum number of steps per episode.

  • **kwargs – Additional keyword arguments for environment configuration.

Returns:

The created Atari environment.

Return type:

gym.Env
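The repeat_action_probability parameter controls "sticky actions": with probability p the previously executed action is used instead of the agent's choice. The helper below is a simplified, hypothetical sketch that applies this once per agent step; ALE performs the check internally, so this only illustrates the semantics rather than reproducing ALE's exact per-frame behavior.

```python
import random


def sticky_action(agent_action, previous_action, repeat_action_probability=0.25, rng=random):
    """With probability p, execute the previous action instead of the agent's choice."""
    if previous_action is not None and rng.random() < repeat_action_probability:
        return previous_action
    return agent_action


rng = random.Random(0)
previous = None
executed = []
for agent_action in [1, 2, 3, 4, 5, 6, 7, 8]:
    previous = sticky_action(agent_action, previous, 0.25, rng)
    executed.append(previous)
print(executed)  # some entries may repeat the action executed on the step before
```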

world_models.envs.list_available_atari_envs()[source]#

Get a list of all available Atari environments in the Arcade Learning Environment (ALE).

Returns:

List of available Atari environment IDs.

Return type:

list[str]

world_models.envs.make_atari_vector_env(game, num_envs, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, seed=None, **kwargs)[source]#

Create vectorized Atari environments using ALE’s native AtariVectorEnv.

Parameters:
  • game (str) – The name of the Atari game (e.g., “pong”, “breakout”).

  • num_envs (int) – Number of parallel environments.

  • obs_type (str) – The type of observation to return (“rgb” or “ram”).

  • frameskip (int) – The number of frames to skip between actions.

  • repeat_action_probability (float) – The probability of repeating the last action.

  • full_action_space (bool) – Whether to use the full action space.

  • max_episode_steps (Optional[int]) – Maximum number of steps per episode.

  • seed (Optional[int]) – Random seed for reproducibility.

  • **kwargs – Additional keyword arguments for environment configuration.

Returns:

The vectorized Atari environment.

Return type:

AtariVectorEnv

class world_models.envs.MuJoCoImageEnv(xml_path=None, *, xml_string=None, binary_path=None, assets=None, seed=0, size=(64, 64), camera=None, reward_fn=None, terminal_fn=None, frame_skip=1, reset_noise_scale=0.0, default_control_range=(-1.0, 1.0))[source]#

Bases: object

Native MuJoCo environment adapter for pixel-based world-model training.

The adapter uses the low-level mujoco Python package directly: models are compiled from MJCF XML strings/files or MJB binaries via mujoco.MjModel; simulation state lives in mujoco.MjData; actions are written to data.ctrl; and images are produced with mujoco.Renderer. Observations follow TorchWM’s Dreamer-style contract: {"image": uint8[C, H, W]}.

Native MuJoCo models do not define task rewards or episode termination by themselves, so callers can supply reward_fn and terminal_fn callbacks. By default, rewards are 0.0 and episodes terminate only through external wrappers such as TimeLimit.

Parameters:
  • xml_path (str | Path | None)

  • xml_string (str | None)

  • binary_path (str | Path | None)

  • assets (dict[str, bytes] | None)

  • seed (int)

  • size (tuple[int, int])

  • camera (str | int | None)

  • reward_fn (RewardFn | None)

  • terminal_fn (TerminalFn | None)

  • frame_skip (int)

  • reset_noise_scale (float)

  • default_control_range (tuple[float, float])

property observation_space#
property action_space#
reset(seed=None)[source]#
Parameters:

seed (int | None)

step(action)[source]#
render()[source]#
close()[source]#
world_models.envs.make_mujoco_env(model=None, *, backend='auto', seed=0, size=(64, 64), render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#

Create a MuJoCo image environment through a single factory, from either a Gymnasium MuJoCo task id or a native MJCF/MJB model.

Parameters:
  • model (str | Path | None) – Either a Gymnasium MuJoCo task id such as "Humanoid-v4", an MJCF XML path/string, or an MJB binary path.

  • backend (str) – "auto" infers native vs Gymnasium task mode. Use "native" for MJCF/MJB, "gymnasium" for task ids, or "robotics" for Gymnasium Robotics registrations.

  • seed (int) – Seed forwarded to the image wrapper.

  • size (tuple[int, int]) – Target (height, width) image size.

  • render_mode (str) – Render mode used for Gymnasium MuJoCo task ids.

  • gym_kwargs (dict[str, Any] | None) – Optional keyword arguments forwarded to gymnasium.make in task-id mode. Extra **kwargs are also forwarded there.

  • **kwargs – Native MuJoCoImageEnv options for MJCF/MJB mode, or environment-constructor options for Gymnasium task-id mode.

Returns:

A TorchWM image environment returning {"image": uint8[C, H, W]}.

world_models.envs.make_mujoco_env_from_config(args, size)[source]#

Build a MuJoCo image environment from a DreamerConfig-like object.

Parameters:

size (tuple[int, int])

world_models.envs.list_gymnasium_robotics_envs()[source]#

List all Gymnasium Robotics ids registered by the installed package.

Returns an empty list when the optional dependency is not installed. When it is installed, the list is derived from Gymnasium’s registry rather than a hand-maintained subset, so newly added Robotics environments are exposed automatically.

Return type:

list[str]

world_models.envs.make_robotics_env(env, *, seed=0, size=(64, 64), render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#

Create a TorchWM image wrapper for a Gymnasium Robotics environment.

Parameters:
  • env (str) – Any environment id registered by gymnasium-robotics.

  • seed (int) – Seed forwarded to GymImageEnv.

  • size (tuple[int, int]) – Target (height, width) image size.

  • render_mode (str) – Render mode forwarded to gymnasium.make.

  • gym_kwargs (dict[str, Any] | None) – Optional keyword arguments forwarded to gymnasium.make.

  • **kwargs – Additional keyword arguments forwarded to gymnasium.make.

Returns:

A GymImageEnv that emits {"image": uint8[C, H, W]} observations.

world_models.envs.register_gymnasium_robotics_envs()[source]#

Import Gymnasium Robotics so its environments are registered with Gymnasium.

Gymnasium moved legacy MuJoCo v2/v3 task registrations into the external gymnasium-robotics package. Current Gymnasium Robotics versions register environments during import, while older plugin-style installations may rely on gymnasium.register_envs; this helper supports both paths.

class world_models.envs.GymImageEnv(env, seed=0, size=(64, 64), render_mode='rgb_array')[source]#

Bases: object

Gym-like environment wrapper that always returns image observations.

This wrapper normalizes diverse environment interfaces to return consistent image-based observations suitable for pixel-based world models like Dreamer.

Features:
  • Supports environment IDs (string) and pre-built environment objects.

  • Synthesizes RGB images from vector observations for pixel-based training.

  • Exposes continuous action spaces mapped to [-1, 1] range.

  • Converts discrete actions to one-hot vectors.

  • Returns observations as dict {“image”: (C, H, W)} with uint8 values.

Parameters:
  • env – Either a string environment ID (e.g., “Pendulum-v1”) or a pre-built gym environment instance.

  • seed (int) – Random seed for environment reset (default: 0).

  • size (tuple) – Target image size as (height, width) (default: (64, 64)).

  • render_mode (str) – Render mode for environment (default: “rgb_array”).

observation_space#

Dict space with “image” key containing (C, H, W) Box.

action_space#

Box space with actions in [-1, 1] range.

max_episode_steps#

Maximum steps per episode (default: 1000).

property observation_space#
property action_space#
property max_episode_steps#
reset()[source]#
step(action)[source]#
render(*args, **kwargs)[source]#
close()[source]#
world_models.envs.make_gym_env(env, **kwargs)[source]#

Create a GymImageEnv wrapper for generic Gym/Gymnasium environments.

Parameters:
  • env – Either a string environment ID (e.g., “Pendulum-v1”) or a pre-built gym environment instance.

  • **kwargs – Additional keyword arguments passed to GymImageEnv, including:
      - seed (int): Random seed for environment (default: 0)
      - size (tuple): Target image size as (height, width) (default: (64, 64))
      - render_mode (str): Render mode for environment (default: "rgb_array")

Returns:

A wrapper that always returns image observations in the format {"image": (C, H, W)}, suitable for pixel-based world models.

Return type:

GymImageEnv

class world_models.envs.UnityMLAgentsEnv(file_name, behavior_name=None, seed=0, size=(64, 64), worker_id=0, base_port=5005, no_graphics=True, time_scale=20.0, quality_level=1, max_episode_steps=1000)[source]#

Bases: object

Gym-like wrapper for Unity ML-Agents environments.

Provides a unified interface for Unity-based environments, converting observations to image format compatible with pixel-based world models.

Features:
  • Supports single-agent control with continuous action spaces.

  • Returns observations as {“image”: (C, H, W)} with uint8 values.

  • Normalizes actions to [-1, 1] range.

  • Includes rendered frames in observations for visual policies.

Parameters:
  • file_name (str) – Path to the Unity environment binary.

  • behavior_name (str, optional) – Name of the behavior to use. If None, uses the first available behavior.

  • seed (int) – Random seed for environment (default: 0).

  • size (tuple) – Target image size as (height, width) (default: (64, 64)).

  • worker_id (int) – Worker ID for multi-environment setup (default: 0).

  • base_port (int) – Base port for Unity environment communication (default: 5005).

  • no_graphics (bool) – Disable graphics rendering for faster simulation (default: True).

  • time_scale (float) – Simulation time scale multiplier (default: 20.0).

  • quality_level (int) – Graphics quality level 0-5 (default: 1).

  • max_episode_steps (int) – Maximum steps per episode (default: 1000).

observation_space#

Dict space with “image” key containing (3, H, W) Box.

action_space#

Box space with actions in [-1, 1] range.

max_episode_steps#

Maximum steps per episode.

Raises:
  • ValueError – If no behaviors found or action space is not continuous.

  • RuntimeError – If no agents available after reset.

property observation_space#
property action_space#
property max_episode_steps#
reset()[source]#
step(action)[source]#
render(*args, **kwargs)[source]#
close()[source]#
world_models.envs.make_unity_mlagents_env(env_id=None, **kwargs)[source]#

Create a Unity ML-Agents environment wrapper.

Factory function that instantiates a UnityMLAgentsEnv with the provided keyword arguments. Suitable for integrating Unity-based environments with Dreamer-style world model pipelines.

Parameters:
  • **kwargs – Keyword arguments passed to UnityMLAgentsEnv, including:
      - file_name (str): Path to the Unity environment binary.
      - behavior_name (str, optional): Name of the behavior to use.
      - seed (int): Random seed (default: 0).
      - size (tuple): Image size as (height, width) (default: (64, 64)).
      - worker_id (int): Worker ID for multi-environment setup (default: 0).
      - base_port (int): Base port for communication (default: 5005).
      - no_graphics (bool): Disable graphics rendering (default: True).
      - time_scale (float): Simulation time scale (default: 20.0).
      - quality_level (int): Graphics quality level (default: 1).
      - max_episode_steps (int): Max steps per episode (default: 1000).

  • env_id (str | None)

Returns:

A Gym-compatible wrapper for Unity environments.

Return type:

UnityMLAgentsEnv

class world_models.envs.DeepMindControlEnv(name, seed, size=(64, 64), camera=None)[source]#

Bases: object

Gym-style adapter for DeepMind Control Suite tasks.

The wrapper exposes DMC observations and actions through Gym spaces and adds a rendered RGB image to each observation dict so image-based world model pipelines can train consistently across backends.

Features:
  • Parses domain-task names (e.g., “cheetah-run” -> domain=”cheetah”, task=”run”)

  • Automatically handles special cases like “cup” -> “ball_in_cup”

  • Renders RGB images at configurable resolution

  • Returns observations as dict with both state vectors and images
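The name handling described above can be sketched as a small parser. This is a hypothetical helper, not the wrapper's code: it assumes the name is split at the first hyphen (DMC domain and task names use underscores, such as point_mass), expands the documented "cup" shorthand, and encodes the documented camera defaults.

```python
def parse_dmc_name(name):
    """Split 'domain-task' into (domain, task), expanding the 'cup' shorthand."""
    if "-" not in name:
        raise ValueError(f"expected 'domain-task', got {name!r}")
    domain, task = name.split("-", 1)
    if domain == "cup":
        domain = "ball_in_cup"  # DMC's registered domain name
    return domain, task


def default_camera(domain):
    # Documented default: camera 2 for quadruped, 0 otherwise.
    return 2 if domain == "quadruped" else 0


print(parse_dmc_name("cheetah-run"))  # ('cheetah', 'run')
print(parse_dmc_name("cup-catch"))    # ('ball_in_cup', 'catch')
```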

Parameters:
  • name (str) – Environment name in format “domain-task” (e.g., “cheetah-run”).

  • seed (int) – Random seed for environment initialization.

  • size (tuple) – Target image size as (height, width) (default: (64, 64)).

  • camera (int, optional) – Camera ID for rendering. Defaults to 0 for most domains, 2 for quadruped.

observation_space#

Dict space with state keys and “image”.

Type:

gym.spaces.Dict

action_space#

Continuous action space from DMC spec.

Type:

gym.spaces.Box

Example

>>> env = DeepMindControlEnv("cheetah-run", seed=0, size=(64, 64))
>>> obs = env.reset()
>>> print(obs.keys())  # dict_keys(['position', 'velocity', 'image'])
property observation_space#
property action_space#
step(action)[source]#
reset()[source]#
render(*args, **kwargs)[source]#
class world_models.envs.BraxImageEnv(env, seed=0, size=(64, 64), backend=None, episode_length=None, auto_reset=False, jit=True, suppress_warp_warnings=True, **env_kwargs)[source]#

Bases: object

Gym-like adapter for training TorchWM world models on Brax tasks.

Brax environments are functional JAX environments: reset consumes a PRNG key and returns a state, while step consumes the previous state plus an action and returns the next state. This adapter stores the Brax state between calls and converts state observations into image observations compatible with pixel-based TorchWM agents such as Dreamer.

If a Brax renderer is not available, vector observations are rendered as deterministic feature-band images so training code can still consume a pixel stream. The original vector observation is also exposed through info["vector_observation"] after step for diagnostics.

Parameters:
  • seed (int)

  • size (tuple[int, int])

  • backend (str | None)

  • episode_length (int | None)

  • auto_reset (bool)

  • jit (bool)

  • suppress_warp_warnings (bool)

property observation_space#
property action_space#
property max_episode_steps#
reset()[source]#
step(action)[source]#
render(*args, **kwargs)[source]#
close()[source]#
world_models.envs.make_brax_env(env, **kwargs)[source]#

Create a TorchWM image wrapper for Brax environments.

Parameters:
  • env – Brax environment name (for example, "ant") or a pre-built Brax environment object exposing reset(rng) and step(state, action).

  • **kwargs – Additional keyword arguments passed to BraxImageEnv.

Returns:

A Gym-like wrapper that returns {"image": (C, H, W)} observations and exposes continuous actions in the Brax [-1, 1] range.

Return type:

BraxImageEnv

class world_models.envs.TimeLimit(env, duration)[source]#

Bases: object

Terminate episodes after a fixed number of wrapper steps.

If the wrapped environment does not provide a discount flag at timeout, the wrapper injects a default discount of 1.0 for downstream learners.
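A time-limit wrapper of this kind can be sketched as follows. The class below is a simplified, hypothetical re-implementation for illustration (with a toy `EndlessEnv`), assuming the 4-tuple step API shown elsewhere on this page; it forces done=True after `duration` steps and fills in discount=1.0 only when the wrapped environment did not already report one.

```python
class EndlessEnv:
    """Toy environment that never terminates on its own."""

    def reset(self):
        return 0

    def step(self, action):
        return 0, 1.0, False, {}


class TimeLimit:
    """Sketch: force done=True after `duration` wrapper steps."""

    def __init__(self, env, duration):
        self._env = env
        self._duration = duration
        self._step = None

    def reset(self):
        self._step = 0
        return self._env.reset()

    def step(self, action):
        assert self._step is not None, "call reset() before step()"
        obs, reward, done, info = self._env.step(action)
        self._step += 1
        if self._step >= self._duration:
            done = True
            info = dict(info)
            info.setdefault("discount", 1.0)  # a timeout is not a true terminal state
            self._step = None
        return obs, reward, done, info


env = TimeLimit(EndlessEnv(), duration=3)
env.reset()
print([env.step(0)[2] for _ in range(3)])  # [False, False, True]
```

Keeping the discount at 1.0 on timeout tells a learner to keep bootstrapping from the last state, since the episode was cut short rather than actually ending.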

step(action)[source]#
reset()[source]#
class world_models.envs.ActionRepeat(env, amount)[source]#

Bases: object

Repeat each action for a fixed number of environment steps.

Rewards are accumulated and the loop stops early if the environment terminates, mirroring common action-repeat behavior in world model papers.
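The accumulate-and-stop-early behavior can be sketched as below. This is a hypothetical, simplified re-implementation with a toy `FiniteEnv`, assuming the 4-tuple step API and amount >= 1; attribute access such as reset() is forwarded to the wrapped environment.

```python
class FiniteEnv:
    """Toy environment: reward 1.0 per step, terminates after `length` steps."""

    def __init__(self, length):
        self.length = length
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= self.length, {}


class ActionRepeat:
    """Sketch: repeat each action `amount` times, summing rewards, stopping early on done."""

    def __init__(self, env, amount):
        self._env = env
        self._amount = amount

    def __getattr__(self, name):
        return getattr(self._env, name)  # forward reset() and other attributes

    def step(self, action):
        total_reward, done, repeats = 0.0, False, 0
        while repeats < self._amount and not done:
            obs, reward, done, info = self._env.step(action)
            total_reward += reward
            repeats += 1
        return obs, total_reward, done, info


env = ActionRepeat(FiniteEnv(length=5), amount=4)
env.reset()
print(env.step(0)[1:3])  # (4.0, False)
print(env.step(0)[1:3])  # (1.0, True): stopped after one inner step
```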

step(action)[source]#
class world_models.envs.NormalizeActions(env)[source]#

Bases: object

Expose a normalized [-1, 1] action space for bounded continuous controls.

Incoming normalized actions are mapped back to the wrapped environment action bounds before stepping the environment.
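The mapping back to the environment's bounds is the affine map low + (a + 1) / 2 * (high - low), applied per dimension. The helper below is a hypothetical sketch of that formula; leaving dimensions with infinite bounds unchanged follows the original Dreamer wrapper and is an assumption for TorchWM.

```python
import math


def to_env_action(normalized, low, high):
    """Map a [-1, 1] action to [low, high] per dimension; unbounded dims pass through."""
    out = []
    for a, lo, hi in zip(normalized, low, high):
        if math.isfinite(lo) and math.isfinite(hi):
            out.append(lo + (a + 1.0) / 2.0 * (hi - lo))
        else:
            out.append(a)  # assumption: infinite bounds are left unnormalized
    return out


print(to_env_action([-1.0, 0.0, 1.0], low=[0.0, 0.0, 0.0], high=[10.0, 10.0, 10.0]))
# [0.0, 5.0, 10.0]
print(to_env_action([0.5], low=[-math.inf], high=[math.inf]))  # [0.5]
```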

property action_space#
step(action)[source]#
class world_models.envs.ObsDict(env, key='obs')[source]#

Bases: object

Convert scalar/array observations into a dictionary observation format.

This harmonizes outputs for code paths that expect keyed observations (for example {“image”: …} style world model inputs).

property observation_space#
property action_space#
step(action)[source]#
reset()[source]#
class world_models.envs.OneHotAction(env)[source]#

Bases: object

Wrap discrete-action environments to accept one-hot action vectors.

The wrapper validates one-hot inputs and converts them to integer action indices before forwarding to the underlying environment.
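The validate-then-convert step can be sketched as a small helper. `one_hot_to_index` is hypothetical: it rejects vectors of the wrong length or with anything other than a single 1 among 0s, then returns the index passed to the underlying discrete environment.

```python
def one_hot_to_index(action, num_actions):
    """Validate a one-hot vector and return the index of its 1 entry."""
    if len(action) != num_actions:
        raise ValueError(f"expected {num_actions} entries, got {len(action)}")
    index = max(range(num_actions), key=lambda i: action[i])
    reference = [0.0] * num_actions
    reference[index] = 1.0
    if list(action) != reference:
        raise ValueError(f"invalid one-hot action: {action!r}")
    return index


print(one_hot_to_index([0, 0, 1, 0], 4))  # 2
```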

property action_space#
step(action)[source]#
reset()[source]#
class world_models.envs.RewardObs(env)[source]#

Bases: object

Augment observations with the latest scalar reward under obs[“reward”].

Useful for agents that consume reward as part of the observation stream during model learning or recurrent policy inference.

property observation_space#
step(action)[source]#
reset()[source]#
class world_models.envs.ResizeImage(env, size=(64, 64))[source]#

Bases: object

Resize image-like observation entries to a target spatial size.

The wrapper discovers image keys from env.obs_space, applies nearest-neighbor resizing, and updates the advertised observation space shapes.
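Nearest-neighbor resizing copies, for each output pixel, the value of one source pixel without interpolation. The hypothetical helper below uses floor index scaling on a single-channel grid; libraries differ in exactly which source pixel they pick, and ResizeImage's precise rule is not stated here.

```python
def resize_nearest(image, out_h, out_w):
    """Nearest-neighbor resize of an H x W grid (list of rows) by index scaling."""
    in_h, in_w = len(image), len(image[0])
    return [
        [image[r * in_h // out_h][c * in_w // out_w] for c in range(out_w)]
        for r in range(out_h)
    ]


small = [[1, 2],
         [3, 4]]
for row in resize_nearest(small, 4, 4):
    print(row)
# [1, 1, 2, 2]
# [1, 1, 2, 2]
# [3, 3, 4, 4]
# [3, 3, 4, 4]
```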

property obs_space#
step(action)[source]#
reset()[source]#
class world_models.envs.RenderImage(env, key='image')[source]#

Bases: object

Inject RGB renders from env.render(“rgb_array”) into observations.

This is useful when the base environment returns non-image observations but a rendered camera view is needed for world-model training.

property obs_space#
step(action)[source]#
reset()[source]#
class world_models.envs.SelectAction(env, key)[source]#

Bases: Wrapper

Gym wrapper for dictionary actions that forwards a selected key only.

This enables integration with policies that emit action dicts while the environment expects a single tensor/array action payload.

step(action)[source]#
world_models.envs.make_env(env_id, **kwargs)[source]#

Compatibility helper: create an environment by delegating to package-specific factories when available, falling back to gym.make.

This preserves older callers that expect make_env to exist.

Parameters:

env_id (str)

class world_models.envs.dmc.DeepMindControlEnv(name, seed, size=(64, 64), camera=None)[source]#

Bases: object

Gym-style adapter for DeepMind Control Suite tasks.

The wrapper exposes DMC observations and actions through Gym spaces and adds a rendered RGB image to each observation dict so image-based world model pipelines can train consistently across backends.

Features:
  • Parses domain-task names (e.g., “cheetah-run” -> domain=”cheetah”, task=”run”)

  • Automatically handles special cases like “cup” -> “ball_in_cup”

  • Renders RGB images at configurable resolution

  • Returns observations as dict with both state vectors and images

Parameters:
  • name (str) – Environment name in format “domain-task” (e.g., “cheetah-run”).

  • seed (int) – Random seed for environment initialization.

  • size (tuple) – Target image size as (height, width) (default: (64, 64)).

  • camera (int, optional) – Camera ID for rendering. Defaults to 0 for most domains, 2 for quadruped.

observation_space#

Dict space with state keys and “image”.

Type:

gym.spaces.Dict

action_space#

Continuous action space from DMC spec.

Type:

gym.spaces.Box

Example

>>> env = DeepMindControlEnv("cheetah-run", seed=0, size=(64, 64))
>>> obs = env.reset()
>>> print(obs.keys())  # dict_keys(['position', 'velocity', 'image'])
property observation_space#
property action_space#
step(action)[source]#
reset()[source]#
render(*args, **kwargs)[source]#
world_models.envs.gym_env.make_gym_env(env, **kwargs)[source]#

Create a GymImageEnv wrapper for generic Gym/Gymnasium environments.

Parameters:
  • env – Either a string environment ID (e.g., “Pendulum-v1”) or a pre-built gym environment instance.

  • **kwargs – Additional keyword arguments passed to GymImageEnv, including:
      - seed (int): Random seed for environment (default: 0)
      - size (tuple): Target image size as (height, width) (default: (64, 64))
      - render_mode (str): Render mode for environment (default: "rgb_array")

Returns:

A wrapper that always returns image observations in the format {"image": (C, H, W)}, suitable for pixel-based world models.

Return type:

GymImageEnv

class world_models.envs.gym_env.GymImageEnv(env, seed=0, size=(64, 64), render_mode='rgb_array')[source]#

Bases: object

Gym-like environment wrapper that always returns image observations.

This wrapper normalizes diverse environment interfaces to return consistent image-based observations suitable for pixel-based world models like Dreamer.

Features:
  • Supports environment IDs (string) and pre-built environment objects.

  • Synthesizes RGB images from vector observations for pixel-based training.

  • Exposes continuous action spaces mapped to the [-1, 1] range.

  • Converts discrete actions to one-hot vectors.

  • Returns observations as dict {“image”: (C, H, W)} with uint8 values.
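The observation contract above amounts to moving the channel axis first and storing uint8 pixels. A minimal NumPy sketch, assuming rendered frames arrive as (H, W, 3) arrays that are either uint8 or floats in [0, 1]. `to_image_obs` is a hypothetical helper, and resizing to `size` is omitted.

```python
import numpy as np


def to_image_obs(frame: np.ndarray) -> dict:
    """Hypothetical helper: wrap an (H, W, 3) RGB frame as {"image": uint8[C, H, W]}."""
    frame = np.asarray(frame)
    if frame.dtype != np.uint8:
        # Assumption: non-uint8 frames are floats in [0, 1].
        frame = (np.clip(frame, 0.0, 1.0) * 255).round().astype(np.uint8)
    # (H, W, C) -> (C, H, W), copied into contiguous memory.
    return {"image": np.ascontiguousarray(frame.transpose(2, 0, 1))}


obs = to_image_obs(np.zeros((64, 48, 3), dtype=np.uint8))
print(obs["image"].shape)  # (3, 64, 48)
```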

Parameters:
  • env – Either a string environment ID (e.g., “Pendulum-v1”) or a pre-built gym environment instance.

  • seed (int) – Random seed for environment reset (default: 0).

  • size (tuple) – Target image size as (height, width) (default: (64, 64)).

  • render_mode (str) – Render mode for environment (default: “rgb_array”).

observation_space#

Dict space with “image” key containing (C, H, W) Box.

action_space#

Box space with actions in [-1, 1] range.

max_episode_steps#

Maximum steps per episode (default: 1000).

property observation_space#
property action_space#
property max_episode_steps#
reset()[source]#
step(action)[source]#
render(*args, **kwargs)[source]#
close()[source]#
world_models.envs.ale_atari_env.make_atari_env(env_id, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, **kwargs)[source]#

Create any Atari environment from the Arcade Learning Environment (ALE).

Parameters:
  • env_id (str) – The id of the Atari environment to create.

  • obs_type (str) – The type of observation to return (“rgb” or “ram”).

  • frameskip (int) – The number of frames to skip between actions.

  • repeat_action_probability (float) – The probability of repeating the last action.

  • full_action_space (bool) – Whether to use the full action space.

  • max_episode_steps (Optional[int]) – Maximum number of steps per episode.

  • **kwargs – Additional keyword arguments for environment configuration.

Returns:

The created Atari environment.

Return type:

gym.Env
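To make `frameskip` and `repeat_action_probability` concrete, here is a simplified model of how ALE is commonly described as combining them. On each emulator frame, the previously executed action is kept with probability `repeat_action_probability` ("sticky actions"), and rewards are summed over the skipped frames. `step_with_frameskip` and the toy emulator are hypothetical, and the sketch ignores episode ends inside the skip window.

```python
import random
from typing import Callable, Tuple


def step_with_frameskip(
    emulate: Callable[[int], float],
    action: int,
    prev_action: int,
    frameskip: int,
    repeat_action_probability: float,
    rng: random.Random,
) -> Tuple[float, int]:
    """Hypothetical sketch: return (summed reward, action executed on the last frame)."""
    total_reward = 0.0
    executed = prev_action
    for _ in range(frameskip):
        # Sticky actions: switch to the agent's action unless the old one "sticks".
        if rng.random() >= repeat_action_probability:
            executed = action
        total_reward += emulate(executed)
    return total_reward, executed


# Toy emulator: action 1 earns 1.0 per frame, anything else earns 0.0.
toy_emulate = lambda a: 1.0 if a == 1 else 0.0
reward, executed = step_with_frameskip(
    toy_emulate, action=1, prev_action=0, frameskip=4,
    repeat_action_probability=0.25, rng=random.Random(0),
)
```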

world_models.envs.ale_atari_env.list_available_atari_envs()[source]#

Get a list of all Atari environments available in the Arcade Learning Environment (ALE).

Returns:

List of available Atari environment IDs.

Return type:

list[str]

world_models.envs.ale_atari_vector_env.make_atari_vector_env(game, num_envs, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, seed=None, **kwargs)[source]#

Create vectorized Atari environments using ALE’s native AtariVectorEnv.

Parameters:
  • game (str) – The name of the Atari game (e.g., “pong”, “breakout”).

  • num_envs (int) – Number of parallel environments.

  • obs_type (str) – The type of observation to return (“rgb” or “ram”).

  • frameskip (int) – The number of frames to skip between actions.

  • repeat_action_probability (float) – The probability of repeating the last action.

  • full_action_space (bool) – Whether to use the full action space.

  • max_episode_steps (Optional[int]) – Maximum number of steps per episode.

  • seed (Optional[int]) – Random seed for reproducibility.

  • **kwargs – Additional keyword arguments for environment configuration.

Returns:

The vectorized Atari environment.

Return type:

AtariVectorEnv

world_models.envs.mujoco_env.make_mujoco_env_from_config(args, size)[source]#

Build a MuJoCo image environment from a DreamerConfig-like object.

Parameters:

size (tuple[int, int])

class world_models.envs.mujoco_env.MuJoCoImageEnv(xml_path=None, *, xml_string=None, binary_path=None, assets=None, seed=0, size=(64, 64), camera=None, reward_fn=None, terminal_fn=None, frame_skip=1, reset_noise_scale=0.0, default_control_range=(-1.0, 1.0))[source]#

Bases: object

Native MuJoCo environment adapter for pixel-based world-model training.

The adapter uses the low-level mujoco Python package directly: models are compiled from MJCF XML strings/files or MJB binaries via mujoco.MjModel; simulation state lives in mujoco.MjData; actions are written to data.ctrl; and images are produced with mujoco.Renderer. Observations follow TorchWM’s Dreamer-style contract: {"image": uint8[C, H, W]}.

Native MuJoCo models do not define task rewards or episode termination by themselves, so callers can supply reward_fn and terminal_fn callbacks. By default, rewards are 0.0 and episodes terminate only through external wrappers such as TimeLimit.

Parameters:
  • xml_path (str | Path | None)

  • xml_string (str | None)

  • binary_path (str | Path | None)

  • assets (dict[str, bytes] | None)

  • seed (int)

  • size (tuple[int, int])

  • camera (str | int | None)

  • reward_fn (RewardFn | None)

  • terminal_fn (TerminalFn | None)

  • frame_skip (int)

  • reset_noise_scale (float)

  • default_control_range (tuple[float, float])
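Native mode consumes plain MJCF. The sketch below shows a minimal one-hinge MJCF model of the kind that could be passed as `xml_string`, plus a hypothetical helper that reads each actuator's `ctrlrange`. Based only on the parameter name, it assumes `default_control_range` acts as the fallback for actuators without an explicit range; the actual rule in `MuJoCoImageEnv` may differ.

```python
import xml.etree.ElementTree as ET
from typing import List, Tuple

PENDULUM_XML = """
<mujoco model="pendulum">
  <worldbody>
    <light pos="0 0 3"/>
    <body name="pole" pos="0 0 1">
      <joint name="hinge" type="hinge" axis="0 1 0"/>
      <geom type="capsule" fromto="0 0 0 0 0 -0.5" size="0.05"/>
    </body>
  </worldbody>
  <actuator>
    <motor name="torque" joint="hinge" ctrlrange="-2 2"/>
    <motor name="unbounded" joint="hinge"/>
  </actuator>
</mujoco>
"""


def actuator_ranges(
    xml: str, default: Tuple[float, float] = (-1.0, 1.0)
) -> List[Tuple[float, float]]:
    """Hypothetical helper: per-actuator (low, high), using `default` when ctrlrange is absent."""
    root = ET.fromstring(xml.strip())
    actuator = root.find("actuator")
    ranges = []
    for act in (actuator if actuator is not None else []):
        spec = act.get("ctrlrange")
        if spec is None:
            ranges.append(default)
        else:
            low, high = (float(v) for v in spec.split())
            ranges.append((low, high))
    return ranges


print(actuator_ranges(PENDULUM_XML))  # [(-2.0, 2.0), (-1.0, 1.0)]
```

With the mujoco package installed, the same string can be compiled with mujoco.MjModel.from_xml_string(PENDULUM_XML).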

property observation_space#
property action_space#
reset(seed=None)[source]#
Parameters:

seed (int | None)

step(action)[source]#
render()[source]#
close()[source]#
world_models.envs.mujoco_env.make_mujoco_env(model=None, *, backend='auto', seed=0, size=(64, 64), render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#

Single factory that creates a MuJoCo image environment from either a Gymnasium task id or an MJCF/MJB model.

Parameters:
  • model (str | Path | None) – Either a Gymnasium MuJoCo task id such as "Humanoid-v4", an MJCF XML path/string, or an MJB binary path.

  • backend (str) – "auto" infers native vs Gymnasium task mode. Use "native" for MJCF/MJB, "gymnasium" for task ids, or "robotics" for Gymnasium Robotics registrations.

  • seed (int) – Seed forwarded to the image wrapper.

  • size (tuple[int, int]) – Target (height, width) image size.

  • render_mode (str) – Render mode used for Gymnasium MuJoCo task ids.

  • gym_kwargs (dict[str, Any] | None) – Optional keyword arguments forwarded to gymnasium.make in task-id mode. Extra **kwargs are also forwarded there.

  • **kwargs – Native MuJoCoImageEnv options for MJCF/MJB mode, or environment-constructor options for Gymnasium task-id mode.

Returns:

A TorchWM image environment returning {"image": uint8[C, H, W]}.
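The backend="auto" rule is described only as inferring native versus Gymnasium task mode. One plausible heuristic is written below as a hypothetical `infer_backend` helper. It treats inline XML, `.xml` paths, and `.mjb` paths as native models and everything else as a Gymnasium task id. The real inference in make_mujoco_env may check additional cases.

```python
from pathlib import Path
from typing import Union


def infer_backend(model: Union[str, Path]) -> str:
    """Hypothetical heuristic for backend="auto"."""
    text = str(model).strip()
    # Inline MJCF strings start with an XML tag.
    if text.startswith("<"):
        return "native"
    # Source (.xml) or compiled (.mjb) model files.
    if Path(text).suffix.lower() in {".xml", ".mjb"}:
        return "native"
    # Anything else, e.g. "Humanoid-v4", is treated as a Gymnasium task id.
    return "gymnasium"
```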

world_models.envs.robotics_env.is_moved_mujoco_error(exc)[source]#

Return whether exc is the error Gymnasium raises for legacy MuJoCo v2/v3 environments that moved to the gymnasium-robotics package.

Parameters:

exc (BaseException)

Return type:

bool

world_models.envs.robotics_env.register_gymnasium_robotics_envs()[source]#

Import Gymnasium Robotics so its environments are registered with Gymnasium.

Gymnasium moved legacy MuJoCo v2/v3 task registrations into the external gymnasium-robotics package. Current Gymnasium Robotics versions register environments during import, while older plugin-style installations may rely on gymnasium.register_envs; this helper supports both paths.

world_models.envs.robotics_env.list_gymnasium_robotics_envs()[source]#

List all Gymnasium Robotics ids registered by the installed package.

Returns an empty list when the optional dependency is not installed. When it is installed, the list is derived from Gymnasium’s registry rather than a hand-maintained subset, so newly added Robotics environments are exposed automatically.

Return type:

list[str]
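Deriving the Robotics ids from Gymnasium's registry instead of a hand-maintained list could look like the sketch below. It assumes Robotics registrations can be recognised by an entry_point string that starts with `gymnasium_robotics`. `robotics_ids` is hypothetical. In practice the registry would be `gymnasium.envs.registry` (a mapping from id to EnvSpec); here a stand-in with placeholder ids keeps the sketch runnable without Gymnasium.

```python
from types import SimpleNamespace
from typing import Any, List, Mapping


def robotics_ids(registry: Mapping[str, Any]) -> List[str]:
    """Hypothetical helper: ids whose entry point lives in gymnasium_robotics."""
    ids = []
    for env_id, spec in registry.items():
        entry_point = getattr(spec, "entry_point", None)
        # Entry points may be "module:Class" strings or callables; only strings are checked.
        if isinstance(entry_point, str) and entry_point.startswith("gymnasium_robotics"):
            ids.append(env_id)
    return sorted(ids)


# Stand-in registry; the "Example*" ids are placeholders, not real registrations.
fake_registry = {
    "CartPole-v1": SimpleNamespace(entry_point="gymnasium.envs.classic_control.cartpole:CartPoleEnv"),
    "ExampleMaze-v0": SimpleNamespace(entry_point="gymnasium_robotics.envs.example_maze:ExampleMazeEnv"),
    "ExampleArm-v0": SimpleNamespace(entry_point="gymnasium_robotics.envs.example_arm:ExampleArmEnv"),
    "Callable-v0": SimpleNamespace(entry_point=lambda: None),
}
print(robotics_ids(fake_registry))  # ['ExampleArm-v0', 'ExampleMaze-v0']
```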

world_models.envs.robotics_env.make_gymnasium_env_with_robotics_fallback(env, *, render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#

Create a Gymnasium env and retry after Robotics registration if needed.

Parameters:
  • env (str)

  • render_mode (str)

  • gym_kwargs (dict[str, Any] | None)
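The retry behaviour follows a try, register, retry pattern. In this sketch, the `make`, `register`, and `is_moved_error` callables are injected so it runs without Gymnasium. In TorchWM they presumably correspond to gymnasium.make, register_gymnasium_robotics_envs, and is_moved_mujoco_error, but that mapping is an assumption, and `make_with_fallback` and `fake_make` are hypothetical.

```python
from typing import Any, Callable, List


def make_with_fallback(
    make: Callable[[str], Any],
    env_id: str,
    register: Callable[[], None],
    is_moved_error: Callable[[BaseException], bool],
) -> Any:
    """Hypothetical sketch: retry once after registering Robotics environments."""
    try:
        return make(env_id)
    except Exception as exc:
        # Only the "moved to gymnasium-robotics" failure triggers a retry.
        if not is_moved_error(exc):
            raise
        register()
        return make(env_id)


registered: List[bool] = []


def fake_make(env_id: str) -> str:
    # Pretend legacy ids only resolve after Robotics registration.
    if env_id.startswith("Legacy") and not registered:
        raise RuntimeError(f"{env_id} moved to gymnasium-robotics")
    return f"env:{env_id}"


env = make_with_fallback(
    fake_make,
    "LegacyTask-v3",
    register=lambda: registered.append(True),
    is_moved_error=lambda exc: "gymnasium-robotics" in str(exc),
)
```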

world_models.envs.robotics_env.make_robotics_env(env, *, seed=0, size=(64, 64), render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#

Create a TorchWM image wrapper for a Gymnasium Robotics environment.

Parameters:
  • env (str) – Any environment id registered by gymnasium-robotics.

  • seed (int) – Seed forwarded to GymImageEnv.

  • size (tuple[int, int]) – Target (height, width) image size.

  • render_mode (str) – Render mode forwarded to gymnasium.make.

  • gym_kwargs (dict[str, Any] | None) – Optional keyword arguments forwarded to gymnasium.make.

  • **kwargs – Additional keyword arguments forwarded to gymnasium.make.

Returns:

A GymImageEnv that emits {"image": uint8[C, H, W]} observations.

world_models.envs.unity_env.make_unity_mlagents_env(env_id=None, **kwargs)[source]#

Create a Unity ML-Agents environment wrapper.

Factory function that instantiates a UnityMLAgentsEnv with the provided keyword arguments. Suitable for integrating Unity-based environments with Dreamer-style world model pipelines.

Parameters:
  • **kwargs – Keyword arguments passed to UnityMLAgentsEnv, including:
      - file_name (str): Path to the Unity environment binary.
      - behavior_name (str, optional): Name of the behavior to use.
      - seed (int): Random seed (default: 0).
      - size (tuple): Image size as (height, width) (default: (64, 64)).
      - worker_id (int): Worker ID for multi-environment setups (default: 0).
      - base_port (int): Base port for communication (default: 5005).
      - no_graphics (bool): Disable graphics rendering (default: True).
      - time_scale (float): Simulation time scale (default: 20.0).
      - quality_level (int): Graphics quality level (default: 1).
      - max_episode_steps (int): Max steps per episode (default: 1000).

  • env_id (str | None)

Returns:

A Gym-compatible wrapper for Unity environments.

Return type:

UnityMLAgentsEnv

class world_models.envs.unity_env.UnityMLAgentsEnv(file_name, behavior_name=None, seed=0, size=(64, 64), worker_id=0, base_port=5005, no_graphics=True, time_scale=20.0, quality_level=1, max_episode_steps=1000)[source]#

Bases: object

Gym-like wrapper for Unity ML-Agents environments.

Provides a unified interface for Unity-based environments, converting observations to image format compatible with pixel-based world models.

Features:
  • Supports single-agent control with continuous action spaces.

  • Returns observations as {“image”: (C, H, W)} with uint8 values.

  • Normalizes actions to [-1, 1] range.

  • Includes rendered frames in observations for visual policies.

Parameters:
  • file_name (str) – Path to the Unity environment binary.

  • behavior_name (str, optional) – Name of the behavior to use. If None, uses the first available behavior.

  • seed (int) – Random seed for environment (default: 0).

  • size (tuple) – Target image size as (height, width) (default: (64, 64)).

  • worker_id (int) – Worker ID for multi-environment setup (default: 0).

  • base_port (int) – Base port for Unity environment communication (default: 5005).

  • no_graphics (bool) – Disable graphics rendering for faster simulation (default: True).

  • time_scale (float) – Simulation time scale multiplier (default: 20.0).

  • quality_level (int) – Graphics quality level 0-5 (default: 1).

  • max_episode_steps (int) – Maximum steps per episode (default: 1000).

observation_space#

Dict space with “image” key containing (3, H, W) Box.

action_space#

Box space with actions in [-1, 1] range.

max_episode_steps#

Maximum steps per episode.

Raises:
  • ValueError – If no behaviors found or action space is not continuous.

  • RuntimeError – If no agents available after reset.

property observation_space#
property action_space#
property max_episode_steps#
reset()[source]#
step(action)[source]#
render(*args, **kwargs)[source]#
close()[source]#
class world_models.envs.vector_env.SimWorker(worker_id, env_factory, num_envs, command_queue, result_queue, seed=None)[source]#

Bases: Process

Worker process that manages a batch of environment instances. Handles batched stepping for parallel rollouts.

Parameters:
  • worker_id (int)

  • env_factory (Callable)

  • num_envs (int)

  • command_queue (Queue)

  • result_queue (Queue)

  • seed (Optional[int])

run()[source]#

Main worker loop.

class world_models.envs.vector_env.VectorizedEnv(env_factory, num_workers=2, envs_per_worker=4, seed=None)[source]#

Bases: ABC

Abstract base class for vectorized environments. Manages multiple worker processes for parallel simulation.

Parameters:
  • env_factory (Callable)

  • num_workers (int)

  • envs_per_worker (int)

  • seed (Optional[int])

abstractmethod step_batch(actions)[source]#

Step all environments with batched actions.

Parameters:

actions (Tensor)

Return type:

Dict[str, Any]

abstractmethod reset_batch()[source]#

Reset all environments.

Return type:

Dict[str, Any]

render_batch()[source]#

Render all environments.

Return type:

List[ndarray]

close()[source]#

Shutdown all workers.

class world_models.envs.vector_env.TorchVectorizedEnv(*args, **kwargs)[source]#

Bases: VectorizedEnv

TorchWM-compatible vectorized environment. Returns batched tensors suitable for PyTorch training.

step_batch(actions)[source]#

Step all environments with batched actions.

Parameters:

actions (Tensor) – Tensor of shape (total_envs, action_dim)

Returns:

Dict with ‘obs’, ‘reward’, ‘done’, ‘info’ tensors

Return type:

Dict[str, Any]
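step_batch takes one action row per environment, num_workers * envs_per_worker rows in total. The sketch below assumes each worker owns a contiguous block of envs_per_worker rows; that routing is an assumption, not documented behaviour. Plain lists stand in for the (total_envs, action_dim) tensor, and the helpers are hypothetical, not TorchWM APIs.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_for_workers(actions: Sequence[T], num_workers: int, envs_per_worker: int) -> List[List[T]]:
    """Hypothetical helper: worker i receives rows [i * envs_per_worker, (i + 1) * envs_per_worker)."""
    expected = num_workers * envs_per_worker
    if len(actions) != expected:
        raise ValueError(f"expected {expected} action rows, got {len(actions)}")
    return [
        list(actions[i * envs_per_worker:(i + 1) * envs_per_worker])
        for i in range(num_workers)
    ]


def merge_worker_results(chunks: Sequence[Sequence[T]]) -> List[T]:
    """Concatenate per-worker results back into one row per environment."""
    return [row for chunk in chunks for row in chunk]


# Defaults from VectorizedEnv: 2 workers x 4 envs = 8 rows, action_dim = 1.
actions = [[float(i)] for i in range(8)]
chunks = split_for_workers(actions, num_workers=2, envs_per_worker=4)
```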

reset_batch()[source]#

Reset all environments and return initial observations.

Return type:

Dict[str, Any]

class world_models.envs.wrappers.TimeLimit(env, duration)[source]#

Bases: object

Terminate episodes after a fixed number of wrapper steps.

If the wrapped environment does not provide a discount flag at timeout, the wrapper injects a default discount of 1.0 for downstream learners.
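A minimal sketch of the timeout logic, written against a toy environment with the 4-tuple (obs, reward, done, info) step API. `TimeLimitSketch` and `CountingEnv` are illustrative stand-ins, not TorchWM classes, and the real wrapper may store the discount differently (for example as a NumPy scalar).

```python
from typing import Any, Dict, Optional, Tuple


class CountingEnv:
    """Toy environment that never terminates on its own."""

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: Any) -> Tuple[int, float, bool, Dict[str, Any]]:
        self.t += 1
        return self.t, 1.0, False, {}


class TimeLimitSketch:
    """Ends the episode after `duration` wrapper steps and defaults info["discount"] to 1.0."""

    def __init__(self, env: Any, duration: int) -> None:
        self._env = env
        self._duration = duration
        self._step: Optional[int] = None

    def reset(self) -> Any:
        self._step = 0
        return self._env.reset()

    def step(self, action: Any):
        if self._step is None:
            raise RuntimeError("call reset() before step()")
        obs, reward, done, info = self._env.step(action)
        self._step += 1
        if self._step >= self._duration:
            done = True
            # A timeout is not a true terminal state, so downstream learners can still bootstrap.
            info.setdefault("discount", 1.0)
            self._step = None
        return obs, reward, done, info


env = TimeLimitSketch(CountingEnv(), duration=3)
env.reset()
transitions = [env.step(0) for _ in range(3)]
```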

step(action)[source]#
reset()[source]#
class world_models.envs.wrappers.ActionRepeat(env, amount)[source]#

Bases: object

Repeat each action for a fixed number of environment steps.

Rewards are accumulated and the loop stops early if the environment terminates, mirroring common action-repeat behavior in world model papers.
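The accumulate-and-stop-early loop can be sketched as follows. `action_repeat_step` is a hypothetical helper using the 4-tuple step API, and `ThreeStepEnv` is a toy environment that terminates on its third step.

```python
from typing import Any, Dict, Tuple


def action_repeat_step(env: Any, action: Any, amount: int) -> Tuple[Any, float, bool, Dict[str, Any]]:
    """Hypothetical helper: repeat `action` up to `amount` times, summing rewards."""
    total_reward = 0.0
    done = False
    steps = 0
    obs, info = None, {}
    # Stop early as soon as the environment reports termination.
    while steps < amount and not done:
        obs, reward, done, info = env.step(action)
        total_reward += reward
        steps += 1
    return obs, total_reward, done, info


class ThreeStepEnv:
    """Toy environment: reward 1.0 per step, terminates after three steps."""

    def __init__(self) -> None:
        self.t = 0

    def step(self, action: Any):
        self.t += 1
        return self.t, 1.0, self.t >= 3, {}
```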

step(action)[source]#
class world_models.envs.wrappers.NormalizeActions(env)[source]#

Bases: object

Expose a normalized [-1, 1] action space for bounded continuous controls.

Incoming normalized actions are mapped back to the wrapped environment action bounds before stepping the environment.
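Mapping a normalized action back to the environment bounds is the affine map original = low + (action + 1) / 2 * (high - low), applied per dimension. A pure-Python sketch follows. Passing dimensions with non-finite bounds through unchanged is an assumption about how "bounded continuous controls" are handled, and `denormalize_action` is hypothetical.

```python
import math
from typing import List, Sequence


def denormalize_action(
    action: Sequence[float], low: Sequence[float], high: Sequence[float]
) -> List[float]:
    """Hypothetical helper: map a normalized action in [-1, 1] to [low, high] per dimension."""
    out = []
    for a, lo, hi in zip(action, low, high):
        if math.isfinite(lo) and math.isfinite(hi):
            out.append(lo + (a + 1.0) / 2.0 * (hi - lo))
        else:
            # Unbounded dimension: nothing to rescale, pass the value through.
            out.append(a)
    return out


print(denormalize_action([-1.0, 0.0, 1.0], [0.0, -2.0, 0.0], [10.0, 2.0, 4.0]))  # [0.0, 0.0, 4.0]
```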

property action_space#
step(action)[source]#
class world_models.envs.wrappers.ObsDict(env, key='obs')[source]#

Bases: object

Convert scalar/array observations into a dictionary observation format.

This harmonizes outputs for code paths that expect keyed observations (for example {“image”: …} style world model inputs).

property observation_space#
property action_space#
step(action)[source]#
reset()[source]#
class world_models.envs.wrappers.OneHotAction(env)[source]#

Bases: object

Wrap discrete-action environments to accept one-hot action vectors.

The wrapper validates one-hot inputs and converts them to integer action indices before forwarding to the underlying environment.
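A sketch of the validate-then-convert step, assuming one-hot inputs must match exactly up to a small tolerance. `one_hot_to_index` is hypothetical, and the real wrapper's tolerance and error message may differ.

```python
from typing import Sequence


def one_hot_to_index(action: Sequence[float], atol: float = 1e-6) -> int:
    """Hypothetical helper: validate a one-hot vector and return its index."""
    # The largest entry is the candidate index (the first one wins on ties).
    index = max(range(len(action)), key=lambda i: action[i])
    for i, value in enumerate(action):
        expected = 1.0 if i == index else 0.0
        if abs(value - expected) > atol:
            raise ValueError(f"Invalid one-hot action: {list(action)}")
    return index
```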

property action_space#
step(action)[source]#
reset()[source]#
class world_models.envs.wrappers.RewardObs(env)[source]#

Bases: object

Augment observations with the latest scalar reward under obs["reward"].

Useful for agents that consume reward as part of the observation stream during model learning or recurrent policy inference.
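A minimal sketch, assuming dict observations and that the observation returned by reset carries a reward of 0.0. The documentation does not say what the wrapper reports before the first step. `RewardObsSketch` and `ToyEnv` are hypothetical.

```python
from typing import Any, Dict


class RewardObsSketch:
    """Adds the latest reward to dict observations (0.0 right after reset, by assumption)."""

    def __init__(self, env: Any) -> None:
        self._env = env

    def reset(self) -> Dict[str, Any]:
        obs = self._env.reset()
        obs["reward"] = 0.0
        return obs

    def step(self, action: Any):
        obs, reward, done, info = self._env.step(action)
        obs["reward"] = reward
        return obs, reward, done, info


class ToyEnv:
    def reset(self) -> Dict[str, Any]:
        return {"image": "frame0"}

    def step(self, action: Any):
        return {"image": "frame1"}, 0.5, False, {}
```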

property observation_space#
step(action)[source]#
reset()[source]#
class world_models.envs.wrappers.ResizeImage(env, size=(64, 64))[source]#

Bases: object

Resize image-like observation entries to a target spatial size.

The wrapper discovers image keys from env.obs_space, applies nearest-neighbor resizing, and updates the advertised observation space shapes.
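Nearest-neighbor resizing gives each output pixel the value of the input pixel at the proportionally scaled index, so it never introduces new pixel values. A pure-Python sketch on nested lists, using floor-based index mapping and a (height, width) target. Other implementations may align pixel centres slightly differently, and `resize_nearest` is hypothetical.

```python
from typing import Any, List, Sequence, Tuple


def resize_nearest(image: Sequence[Sequence[Any]], size: Tuple[int, int]) -> List[List[Any]]:
    """Hypothetical helper: nearest-neighbor resize of an (H, W[, C]) nested list to (height, width)."""
    in_h, in_w = len(image), len(image[0])
    out_h, out_w = size
    # Output pixel (y, x) copies input pixel (floor(y * in_h / out_h), floor(x * in_w / out_w)).
    return [
        [image[(y * in_h) // out_h][(x * in_w) // out_w] for x in range(out_w)]
        for y in range(out_h)
    ]
```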

property obs_space#
step(action)[source]#
reset()[source]#
class world_models.envs.wrappers.RenderImage(env, key='image')[source]#

Bases: object

Inject RGB renders from env.render("rgb_array") into observations.

This is useful when the base environment returns non-image observations but a rendered camera view is needed for world-model training.

property obs_space#
step(action)[source]#
reset()[source]#
class world_models.envs.wrappers.UUID(env)[source]#

Bases: Wrapper

Gym wrapper that tracks a unique run identifier per environment reset.

The ID combines timestamp and UUID and can be used to tag episodes or artifacts generated during data collection.
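One way to combine a timestamp and a UUID into a run identifier is shown below. The exact format TorchWM uses is not documented, so the "YYYYmmddTHHMMSS-<hex>" layout and `make_run_id` are assumptions.

```python
import datetime
import uuid
from typing import Optional


def make_run_id(now: Optional[datetime.datetime] = None) -> str:
    """Hypothetical helper: "<YYYYmmddTHHMMSS>-<32 hex chars>"."""
    now = now or datetime.datetime.now()
    return f"{now.strftime('%Y%m%dT%H%M%S')}-{uuid.uuid4().hex}"
```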

reset()[source]#
class world_models.envs.wrappers.SelectAction(env, key)[source]#

Bases: Wrapper

Gym wrapper for dictionary actions that forwards a selected key only.

This enables integration with policies that emit action dicts while the environment expects a single tensor/array action payload.

step(action)[source]#

Atari preprocessing helpers#

These helpers wrap Atari environments for specific training recipes. They are not separate environment families.

class world_models.envs.diamond_atari.DiamondAtariWrapper(env, frameskip=4, max_noop=30, terminate_on_life_loss=True, reward_clip=True, resize=(64, 64))[source]#

Bases: Wrapper

Atari wrapper for DIAMOND following the paper specifications:
  • frameskip: number of frames to skip (default: 4)

  • max_noop: maximum number of no-op actions at reset (default: 30)

  • terminate_on_life_loss: terminate the episode when a life is lost (default: True)

  • reward_clip: clip rewards to {-1, 0, 1} (default: True)

  • resize: resize observations to the specified size (default: 64x64)

Parameters:
  • env (Env)

  • frameskip (int)

  • max_noop (int)

  • terminate_on_life_loss (bool)

  • reward_clip (bool)

  • resize (Tuple[int, int] | None)

step(action)[source]#

Step the environment.

For backwards compatibility with older gym APIs this wrapper returns a 4-tuple: (obs, reward, done, info). Internally it supports gymnasium’s 5-tuple and collapses (terminated, truncated) into a single done bool.

Parameters:

action (int)

Return type:

Any
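The step contract above (collapse Gymnasium's 5-tuple into a 4-tuple) combined with reward clipping to the three values -1, 0, and 1 can be sketched as follows. Reading the clipping as taking the reward's sign is an assumption, and `collapse_step` and `sign_clip` are hypothetical helpers.

```python
from typing import Any, Dict, Tuple


def sign_clip(reward: float) -> float:
    """Clip a reward to {-1.0, 0.0, 1.0} by its sign."""
    return float((reward > 0) - (reward < 0))


def collapse_step(result: Tuple[Any, ...], reward_clip: bool = True) -> Tuple[Any, float, bool, Dict[str, Any]]:
    """Hypothetical helper: normalize a 4- or 5-tuple step result into (obs, reward, done, info)."""
    if len(result) == 5:
        obs, reward, terminated, truncated, info = result
        done = bool(terminated or truncated)
    else:
        obs, reward, done, info = result
        done = bool(done)
    if reward_clip:
        reward = sign_clip(reward)
    return obs, reward, done, info
```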

reset(**kwargs)[source]#
Return type:

Tuple[Any, Dict[str, Any]]

world_models.envs.diamond_atari.make_diamond_atari_env(game, frameskip=4, max_noop=30, terminate_on_life_loss=True, reward_clip=True, resize=(64, 64), seed=None)[source]#

Create a DIAMOND-compatible Atari environment.

Parameters:
  • game (str) – Atari game name (e.g., “Breakout-v5”)

  • frameskip (int) – Number of frames to skip between actions

  • max_noop (int) – Maximum number of noop actions at reset

  • terminate_on_life_loss (bool) – Whether to terminate on life loss

  • reward_clip (bool) – Whether to clip rewards to {-1, 0, 1}

  • resize (Tuple[int, int]) – Target size for observations

  • seed (int | None) – Random seed

Returns:

Configured Atari environment

Return type:

DiamondAtariWrapper

Datasets and transforms#

class world_models.datasets.video_datasets.DatasetConfig(num_frames=16, image_size=64, batch_size=4, num_workers=4, pin_memory=True, shuffle=True)[source]#

Bases: object

Base configuration for datasets.

Parameters:
  • num_frames (int)

  • image_size (int)

  • batch_size (int)

  • num_workers (int)

  • pin_memory (bool)

  • shuffle (bool)

num_frames: int = 16#
image_size: int = 64#
batch_size: int = 4#
num_workers: int = 4#
pin_memory: bool = True#
shuffle: bool = True#
class world_models.datasets.video_datasets.VideoDatasetBase(data_source, num_frames=16, image_size=64, transform=None, normalize=True)[source]#

Bases: Dataset

Base class for video datasets.

All video datasets should inherit from this class and implement the _load_video method.

Parameters:
  • data_source (str | Path | List[str] | List[Path])

  • num_frames (int)

  • image_size (int)

  • transform (Callable | None)

  • normalize (bool)

data_source: str | Path | List[str] | List[Path]#
video_paths: Sequence[Path | int]#
class world_models.datasets.video_datasets.VideoFolderDataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, extensions=('.mp4', '.avi', '.mkv', '.webm', '.mov'), recursive=True)[source]#

Bases: VideoDatasetBase

Dataset that loads videos from a folder.

By default, loads .mp4, .avi, .mkv, .webm, and .mov files (configurable via extensions).

Usage:

dataset = VideoFolderDataset(
    data_source="/path/to/videos", num_frames=16, image_size=64
)

Parameters:
  • data_source (str | Path | List[str] | List[Path])

  • num_frames (int)

  • image_size (int)

  • transform (Callable | None)

  • normalize (bool)

  • extensions (Tuple[str, ...])

  • recursive (bool)

class world_models.datasets.video_datasets.ImageFolderDataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, extensions=('.jpg', '.jpeg', '.png', '.bmp'), image_sort_key=None)[source]#

Bases: VideoDatasetBase

Dataset that loads image sequences from folders.

Each subfolder is treated as a video sequence.

Usage:

dataset = ImageFolderDataset(
    data_source="/path/to/images", num_frames=16, image_size=64
)

Parameters:
  • data_source (str | Path | List[str] | List[Path])

  • num_frames (int)

  • image_size (int)

  • transform (Callable | None)

  • normalize (bool)

  • extensions (Tuple[str, ...])

  • image_sort_key (Callable | None)

class world_models.datasets.video_datasets.NumPyDataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, key=None)[source]#

Bases: VideoDatasetBase

Dataset that loads videos from numpy files.

Supports .npy and .npz files.

Usage:

dataset = NumPyDataset(
    data_source="/path/to/videos.npy", num_frames=16, image_size=64
)

Parameters:
  • data_source (str | Path)

  • num_frames (int)

  • image_size (int)

  • transform (Callable | None)

  • normalize (bool)

  • key (str | None)

class world_models.datasets.video_datasets.RLEnvironmentDataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, obs_key='observations')[source]#

Bases: VideoDatasetBase

Dataset for RL environment recordings.

Loads trajectories stored as either:
  • .npz files with 'observations' and 'actions' keys

  • a directory of episode folders

Usage:

dataset = RLEnvironmentDataset(
    data_source="/path/to/rl_episodes", num_frames=16, image_size=64
)

Parameters:
  • data_source (str | Path)

  • num_frames (int)

  • image_size (int)

  • transform (Callable | None)

  • normalize (bool)

  • obs_key (str)

class world_models.datasets.video_datasets.HDF5Dataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, key='videos', memmap=False)[source]#

Bases: VideoDatasetBase

Dataset that loads videos from HDF5 files.

Supports pre-processed video datasets stored in HDF5 format. Expected structure: HDF5 file with ‘videos’ dataset of shape (N, T, H, W, C) or (N, T, C, H, W).

Usage:

dataset = HDF5Dataset(
    data_source="/path/to/videos.h5", num_frames=16, image_size=64
)

Parameters:
  • data_source (str | Path)

  • num_frames (int)

  • image_size (int)

  • transform (Callable | None)

  • normalize (bool)

  • key (str)

  • memmap (bool)
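Because the videos dataset may be stored as (N, T, H, W, C) or (N, T, C, H, W), a loader has to decide which axis holds the channels. The hypothetical heuristic below treats an axis of size 1 or 3 as the channel axis and returns the permutation to (N, T, C, H, W). HDF5Dataset may detect the layout differently.

```python
from typing import Tuple


def to_channel_first_order(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Hypothetical heuristic: axis order turning a 5-D video array into (N, T, C, H, W)."""
    if len(shape) != 5:
        raise ValueError(f"expected a 5-D videos dataset, got shape {shape}")
    if shape[2] in (1, 3) and shape[4] not in (1, 3):
        return (0, 1, 2, 3, 4)  # already (N, T, C, H, W)
    if shape[4] in (1, 3):
        # (N, T, H, W, C) -> (N, T, C, H, W); also the fallback when both axes look like channels.
        return (0, 1, 4, 2, 3)
    raise ValueError(f"cannot tell which axis holds channels in {shape}")
```

The returned order can be passed directly to numpy.transpose on a slice read from the file.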

world_models.datasets.video_datasets.create_video_dataloader(dataset_type, data_source, num_frames=16, image_size=64, batch_size=4, num_workers=4, shuffle=True, pin_memory=True, **kwargs)[source]#

Factory function to create video dataloaders.

Parameters:
  • dataset_type (str) – Type of dataset (“video_folder”, “image_folder”, “numpy”, “rl”)

  • data_source (str | Path | List[str]) – Path or list of paths to data

  • num_frames (int) – Number of frames per video

  • image_size (int) – Target image size (height and width)

  • batch_size (int) – Batch size for dataloader

  • num_workers (int) – Number of workers for data loading

  • shuffle (bool) – Whether to shuffle data

  • pin_memory (bool) – Whether to pin memory for faster GPU transfer

  • **kwargs – Additional arguments for specific dataset types

Returns:

Tuple of (dataset, dataloader)

Return type:

Tuple[Dataset, DataLoader]

Usage:

dataset, loader = create_video_dataloader(
    dataset_type="video_folder", data_source="/path/to/videos",
    num_frames=16, image_size=64, batch_size=4,
)

class world_models.datasets.video_datasets.VideoDatasetConfig(num_frames=16, image_size=64, batch_size=4, num_workers=4, pin_memory=True, shuffle=True, dataset_type='video_folder', data_source='', extensions=('.mp4', '.avi', '.mkv'), recursive=True, obs_key='observations')[source]#

Bases: DatasetConfig

Configuration for video datasets.

Parameters:
  • num_frames (int)

  • image_size (int)

  • batch_size (int)

  • num_workers (int)

  • pin_memory (bool)

  • shuffle (bool)

  • dataset_type (str)

  • data_source (str)

  • extensions (Tuple[str, ...])

  • recursive (bool)

  • obs_key (str)

dataset_type: str = 'video_folder'#
data_source: str = ''#
extensions: Tuple[str, ...] = ('.mp4', '.avi', '.mkv')#
recursive: bool = True#
obs_key: str = 'observations'#
world_models.datasets.video_datasets.create_video_dataset_from_config(config)[source]#

Create video dataset and dataloader from config.

Parameters:

config (VideoDatasetConfig)

Return type:

Tuple[Dataset, DataLoader]

TinyWorlds Dataset Loaders

Loads pre-processed video datasets from HuggingFace for training Genie-style world models. Based on: AlmondGod/tinyworlds

Available datasets:
  • PICO_DOOM: Minimal Doom gameplay

  • PONG: Classic Pong

  • ZELDA: Zelda Ocarina of Time (2D)

  • POLE_POSITION: Racing game

  • SONIC: Sonic the Hedgehog

class world_models.datasets.tinyworlds.TinyWorldsConfig(dataset_name='SONIC', num_frames=16, image_size=64, batch_size=4, num_workers=4, cache_dir=None, split='train')[source]#

Bases: object

Configuration for TinyWorlds datasets.

Parameters:
  • dataset_name (str)

  • num_frames (int)

  • image_size (int)

  • batch_size (int)

  • num_workers (int)

  • cache_dir (str | None)

  • split (str)

dataset_name: str = 'SONIC'#
num_frames: int = 16#
image_size: int = 64#
batch_size: int = 4#
num_workers: int = 4#
cache_dir: str | None = None#
split: str = 'train'#
class world_models.datasets.tinyworlds.TinyWorldsDataset(dataset_name='SONIC', num_frames=16, image_size=64, split='train', cache_dir=None, download=True, data_file=None)[source]#

Bases: Dataset

Dataset for TinyWorlds game video data.

Loads pre-processed frames from a HuggingFace datasets repository.

Parameters:
  • dataset_name (str)

  • num_frames (int)

  • image_size (int)

  • split (str)

  • cache_dir (str | None)

  • download (bool)

  • data_file (str | None)

DATASET_CONFIGS = {'PICO_DOOM': {'description': 'Minimal Doom gameplay', 'filename': 'picodoom_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}, 'POLE_POSITION': {'description': 'Racing game', 'filename': 'pole_position_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}, 'PONG': {'description': 'Classic Pong', 'filename': 'pong_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}, 'SONIC': {'description': 'Sonic the Hedgehog', 'filename': 'sonic_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}, 'ZELDA': {'description': 'Zelda Ocarina of Time (2D)', 'filename': 'zelda_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}}#
get_info()[source]#

Return dataset information.

Return type:

Dict

class world_models.datasets.tinyworlds.TinyWorldsDataLoader[source]#

Bases: object

Factory class for creating TinyWorlds dataloaders.

DATASET_NAMES = ['PICO_DOOM', 'PONG', 'ZELDA', 'POLE_POSITION', 'SONIC']#
static create_dataloader(dataset_name='SONIC', num_frames=16, image_size=64, batch_size=4, num_workers=4, shuffle=True, cache_dir=None, download=True, data_file=None)[source]#

Create a dataloader for TinyWorlds dataset.

Parameters:
  • dataset_name (str) – Name of the game dataset (PICO_DOOM, PONG, ZELDA, POLE_POSITION, SONIC)

  • num_frames (int) – Number of frames per video sequence

  • image_size (int) – Target image size (will resize frames)

  • batch_size (int) – Batch size

  • num_workers (int) – Number of data loading workers

  • shuffle (bool) – Whether to shuffle the data

  • cache_dir (str | None) – Directory to cache downloaded datasets

  • download (bool) – Whether to download if not cached

  • data_file (str | None)

Returns:

Tuple of (dataset, dataloader)

Return type:

Tuple[TinyWorldsDataset, DataLoader]

Usage:

dataset, loader = TinyWorldsDataLoader.create_dataloader(
    dataset_name="SONIC", num_frames=16, image_size=64, batch_size=4
)

static list_available_datasets()[source]#

List all available dataset names.

Return type:

List[str]

static get_dataset_info(dataset_name)[source]#

Get information about a specific dataset without downloading.

Parameters:

dataset_name (str)

Return type:

Dict
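Dataset metadata can be resolved from DATASET_CONFIGS without any network access. The sketch below copies two entries from that table and turns a name into keyword arguments for huggingface_hub.hf_hub_download. Whether TorchWM downloads this way, and the repo_type="dataset" value, are assumptions, and `hub_download_kwargs` is hypothetical.

```python
from typing import Any, Dict, Optional

# Two of the five entries copied from TinyWorldsDataset.DATASET_CONFIGS.
DATASET_CONFIGS = {
    "PONG": {"description": "Classic Pong", "filename": "pong_frames.h5", "repo_id": "AlmondGod/tinyworlds"},
    "SONIC": {"description": "Sonic the Hedgehog", "filename": "sonic_frames.h5", "repo_id": "AlmondGod/tinyworlds"},
}


def hub_download_kwargs(dataset_name: str, cache_dir: Optional[str] = None) -> Dict[str, Any]:
    """Hypothetical helper: map a dataset name to hf_hub_download arguments."""
    if dataset_name not in DATASET_CONFIGS:
        raise KeyError(f"unknown dataset {dataset_name!r}; choose from {sorted(DATASET_CONFIGS)}")
    cfg = DATASET_CONFIGS[dataset_name]
    return {
        "repo_id": cfg["repo_id"],
        "filename": cfg["filename"],
        "repo_type": "dataset",  # assumption: hosted as a Hugging Face dataset repo
        "cache_dir": cache_dir,
    }
```

With huggingface_hub installed and network access, hf_hub_download(**hub_download_kwargs("SONIC")) would return a local file path.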

world_models.datasets.tinyworlds.create_tinyworlds_dataloader(dataset_name='SONIC', num_frames=16, image_size=64, batch_size=4, num_workers=4, shuffle=True, cache_dir=None, download=True, data_file=None)[source]#

Factory function to create TinyWorlds dataloaders.

Parameters:
  • dataset_name (str) – Name of the game dataset (PICO_DOOM, PONG, ZELDA, POLE_POSITION, SONIC)

  • num_frames (int) – Number of frames per video sequence

  • image_size (int) – Target image size

  • batch_size (int) – Batch size

  • num_workers (int) – Number of data loading workers

  • shuffle (bool) – Whether to shuffle

  • cache_dir (str | None) – Cache directory for datasets

  • download (bool) – Download if not cached

  • data_file (str | None)

Returns:

Tuple of (dataset, dataloader)

Return type:

Tuple[TinyWorldsDataset, DataLoader]

Usage:

dataset, loader = create_tinyworlds_dataloader(
    dataset_name="SONIC", num_frames=16, batch_size=4
)

for batch in loader:
    # batch shape: (B, T, C, H, W)
    ...

world_models.datasets.tinyworlds.download_all_datasets(cache_dir=None)[source]#

Download all available TinyWorlds datasets.

Parameters:

cache_dir (str | None) – Directory to cache downloaded datasets

Returns:

Dictionary mapping dataset names to local file paths

Return type:

Dict[str, str | None]

class world_models.datasets.diamond_dataset.ReplayBuffer(capacity=1000, obs_shape=(64, 64, 3), action_dim=1, device='cpu')[source]#

Bases: object

Replay buffer for storing environment interactions. Stores (observation, action, reward, done, next_observation) tuples.

Parameters:
  • capacity (int)

  • obs_shape (Tuple[int, int, int])

  • action_dim (int)

  • device (str)

add(obs, action, reward, done, next_obs)[source]#

Add a transition to the buffer.

Parameters:
  • obs (ndarray)

  • action (int)

  • reward (float)

  • done (bool)

  • next_obs (ndarray)

sample(batch_size)[source]#

Sample a random batch of transitions.

Parameters:

batch_size (int)

Return type:

Dict[str, Tensor]

sample_sequence(batch_size, sequence_length, burn_in=0)[source]#

Sample a sequence of transitions for training.

Parameters:
  • batch_size (int) – Number of sequences to sample

  • sequence_length (int) – Total sequence length (burn_in + horizon)

  • burn_in (int) – Number of initial frames to use for conditioning

Returns:

Dictionary with tensors of shape (batch_size, sequence_length, …)

Return type:

Dict[str, Tensor]
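In a circular buffer, a sampled sequence must be contiguous in time, which means it may not straddle the write pointer. The sketch below counts offsets from the oldest stored entry to guarantee this. The first burn_in items of each returned sequence would serve as conditioning frames. `SequenceRingBuffer` is hypothetical: it returns plain lists rather than tensors and does not prevent sequences from crossing episode boundaries.

```python
import random
from typing import Any, List


class SequenceRingBuffer:
    """Hypothetical ring buffer that samples chronologically contiguous sequences."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.data: List[Any] = [None] * capacity
        self.ptr = 0   # next write position
        self.size = 0  # number of stored transitions

    def add(self, transition: Any) -> None:
        self.data[self.ptr] = transition
        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample_sequence(self, batch_size: int, sequence_length: int, rng: random.Random) -> List[List[Any]]:
        if self.size < sequence_length:
            raise ValueError("not enough transitions for one sequence")
        oldest = (self.ptr - self.size) % self.capacity
        batch = []
        for _ in range(batch_size):
            # Offsets count from the oldest entry, so sequences never cross the write pointer.
            start = rng.randint(0, self.size - sequence_length)
            batch.append([self.data[(oldest + start + k) % self.capacity] for k in range(sequence_length)])
        return batch


buf = SequenceRingBuffer(capacity=5)
for t in range(8):  # overwrites 0, 1, 2; the buffer now holds 3..7
    buf.add(t)
```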

is_ready(min_size)[source]#

Check if buffer has enough samples.

Parameters:

min_size (int)

Return type:

bool

state_dict()[source]#

Return a serializable state dict for checkpointing.

Contains numpy arrays and scalar metadata so it can be saved with torch.save or numpy.save.

Return type:

dict

load_state_dict(state)[source]#

Load state previously produced by state_dict().

This will resize internal arrays if the saved capacity differs from the current buffer capacity.

Parameters:

state (dict)

class world_models.datasets.diamond_dataset.SequenceDataset(replay_buffer, sequence_length=5, burn_in=4)[source]#

Bases: Dataset

PyTorch Dataset for sampling sequences from the replay buffer. Used for training the diffusion world model.

Parameters:
  • replay_buffer (ReplayBuffer)

  • sequence_length (int)

  • burn_in (int)

world_models.datasets.diamond_dataset.collate_fn(batch)[source]#

Collate function for the dataloader.

Parameters:

batch (List[Dict[str, Tensor]])

Return type:

Dict[str, Tensor]

world_models.datasets.cifar10.make_cifar10(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, drop_last=True, train=True, download=False)[source]#

Create CIFAR-10 dataset and distributed dataloader.

Factory function that creates a CIFAR-10 dataset with the provided transforms and returns a tuple of (dataset, dataloader, sampler) for use in JEPA or diffusion training pipelines.

Parameters:
  • transform – Transforms to apply to images (e.g., RandomCrop, ColorJitter).

  • batch_size (int) – Number of samples per batch.

  • collator (callable, optional) – Custom collate function for batching (e.g., mask collator for JEPA).

  • pin_mem (bool) – Whether to pin memory for faster GPU transfer (default: True).

  • num_workers (int) – Number of data loading workers (default: 8).

  • world_size (int) – Number of distributed processes (default: 1).

  • rank (int) – Rank of current process in distributed setting (default: 0).

  • root_path (str, optional) – Path to store/load CIFAR-10 data.

  • drop_last (bool) – Whether to drop incomplete final batch (default: True).

  • train (bool) – Whether to load train or test split (default: True).

  • download (bool) – Whether to download dataset if not present (default: False).

Returns:

(dataset, dataloader, sampler)
  • dataset: torchvision.datasets.CIFAR10 instance

  • dataloader: torch.utils.data.DataLoader with distributed sampling

  • sampler: torch.utils.data.distributed.DistributedSampler

Return type:

tuple

Example

>>> transform = make_transforms(crop_size=224)
>>> dataset, loader, sampler = make_cifar10(
...     transform=transform,
...     batch_size=256,
...     root_path="./data",
...     download=True
... )
world_models.datasets.imagenet1k.make_imagenet1k(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, training=True, copy_data=False, drop_last=True, subset_file=None)[source]#

Build an ImageNet-1K dataset and dataloader with distributed sampling.

Factory function that creates an ImageNet dataset and returns a tuple of (dataset, dataloader, sampler) for use in JEPA or other self-supervised training pipelines.

Supports:
  • Optional data staging from network storage to local scratch

  • Subset filtering via text file listing allowed image IDs

  • Distributed sampling for multi-GPU training

Parameters:
  • transform (Any) – Transforms to apply to images.

  • batch_size (int) – Number of samples per batch.

  • collator (callable, optional) – Custom collate function (e.g., mask collator).

  • pin_mem (bool) – Whether to pin memory for GPU transfer (default: True).

  • num_workers (int) – Number of data loading workers (default: 8).

  • world_size (int) – Number of distributed processes (default: 1).

  • rank (int) – Rank of current process (default: 0).

  • root_path (str, optional) – Root path containing ImageNet data.

  • image_folder (str, optional) – Subfolder containing ImageNet data.

  • training (bool) – Load train or validation split (default: True).

  • copy_data (bool) – Copy data locally for faster loading (default: False).

  • drop_last (bool) – Drop incomplete final batch (default: True).

  • subset_file (str, optional) – Path to file listing allowed image IDs.

Returns:

(dataset, dataloader, sampler)
  • dataset: ImageNet dataset instance

  • dataloader: DataLoader with distributed sampling

  • sampler: DistributedSampler instance

Return type:

tuple

class world_models.datasets.imagenet1k.ImageNet(root, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', transform=None, train=True, job_id=None, local_rank=None, copy_data=True, index_targets=False)[source]#

Bases: ImageFolder

ImageNet dataset wrapper with optional local copy/extract workflow.

Extends torchvision.datasets.ImageFolder to support data staging from network storage to local scratch space for faster multi-process training on cluster environments (e.g., SLURM).

Features:
  • Optional data copying from network storage to local /scratch

  • Extracts tar archives automatically on first access

  • Supports train/validation splits

  • Optional target indexing for balanced sampling

class world_models.datasets.imagenet1k.ImageNetSubset(dataset, subset_file)[source]#

Bases: object

View over an ImageNet dataset filtered by an explicit image-id list.

The subset file contains target image names; only matching samples are kept while preserving transforms and label mapping from the base dataset.

filter_dataset_(subset_file)[source]#

Filter self.dataset down to the samples listed in subset_file.

property classes#
world_models.datasets.imagenet1k.copy_imgnt_locally(root, suffix, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', job_id=None, local_rank=None)[source]#

Copy and extract ImageNet archives to per-job local scratch storage.

In SLURM environments this reduces network filesystem pressure by unpacking once per job and synchronizing worker processes with a signal file.

world_models.datasets.imagenet1k.make_imagefolder(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, drop_last=True, val_split=None)[source]#

Create an ImageFolder dataset loader for custom folder-structured datasets.

Supports optional train/validation split and distributed sampling, making it a drop-in replacement for ImageNet loaders in training scripts.

Parameters:
  • transform (Any)

  • batch_size (int)

  • collator (Any)

  • pin_mem (bool)

  • num_workers (int)

  • world_size (int)

  • rank (int)

  • root_path (str | None)

  • image_folder (str | None)

  • drop_last (bool)

  • val_split (float | None)

Return type:

Tuple[Dataset, DataLoader, DistributedSampler]

world_models.transforms.transforms.make_transforms(crop_size=224, crop_scale=(0.3, 1.0), color_jitter=1.0, horizontal_flip=False, color_distortion=False, gaussian_blur=False, normalization=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))[source]#

Compose image augmentations and normalization for vision model training.

Supports random crops, optional flip/color distortion/blur, and returns a torchvision.transforms.Compose pipeline.

class world_models.transforms.transforms.GaussianBlur(p=0.5, radius_min=0.1, radius_max=2.0)[source]#

Bases: object

Probabilistic Gaussian blur augmentation for PIL images.

With probability p, blurs the image using a radius drawn at random from [radius_min, radius_max].

Masking and JEPA helpers#

Masks sub-module - Masking strategies for JEPA and masked training.

This package provides various masking collator classes for generating encoder/predictor masks during masked representation learning.

Usage:

from world_models.masks import MaskCollator, DefaultCollator
collator = MaskCollator(input_size=(64, 64), patch_size=8)

world_models.masks.MultiblockMaskCollator#

alias of MaskCollator

world_models.masks.RandomMaskCollator#

alias of MaskCollator

class world_models.masks.DefaultCollator[source]#

Bases: object

Simple collator that returns batch data and no masking metadata.

This is used when training code expects the JEPA-style collator return shape (batch, masks_enc, masks_pred) but masking is disabled.

class world_models.masks.default.DefaultCollator[source]#

Bases: object

Simple collator that returns batch data and no masking metadata.

This is used when training code expects the JEPA-style collator return shape (batch, masks_enc, masks_pred) but masking is disabled.

class world_models.masks.multiblock.MaskCollator(input_size=(224, 224), patch_size=16, enc_mask_scale=(0.2, 0.8), pred_mask_scale=(0.5, 1.0), aspect_ratio=(0.3, 3.0), nenc=1, npred=2, min_keep=4, allow_overlap=False)[source]#

Bases: object

Generate multi-block encoder and predictor masks for JEPA training.

For each sample, this collator samples predictor target blocks and context encoder blocks (optionally non-overlapping), then returns masked patch indices aligned across the batch.

step()[source]#
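The block sampling described above can be sketched on a grid of patch indices. This is a simplified illustration, not the collator's code: sample_block and sample_masks are hypothetical helpers, the scale and aspect-ratio ranges are illustrative rather than the class defaults, and the real collator additionally enforces min_keep and aligns mask lengths across the batch.

```python
import math
import random
from typing import List, Optional, Tuple


def sample_block(grid_h: int, grid_w: int, scale: Tuple[float, float],
                 aspect_ratio: Tuple[float, float], rng: random.Random) -> List[int]:
    """Return flattened patch indices of one random rectangular block."""
    area = grid_h * grid_w * rng.uniform(*scale)  # target number of patches
    ar = rng.uniform(*aspect_ratio)               # block height / width
    h = max(1, min(grid_h, round(math.sqrt(area * ar))))
    w = max(1, min(grid_w, round(math.sqrt(area / ar))))
    top = rng.randrange(grid_h - h + 1)
    left = rng.randrange(grid_w - w + 1)
    return [r * grid_w + c for r in range(top, top + h) for c in range(left, left + w)]


def sample_masks(grid_h: int, grid_w: int, npred: int = 2,
                 rng: Optional[random.Random] = None) -> Tuple[List[int], List[List[int]]]:
    """Sample npred target blocks plus one context block that avoids them."""
    rng = rng or random.Random()
    targets = [sample_block(grid_h, grid_w, (0.15, 0.2), (0.75, 1.5), rng) for _ in range(npred)]
    context = sample_block(grid_h, grid_w, (0.85, 1.0), (1.0, 1.0), rng)
    # Non-overlapping case (allow_overlap=False): drop context patches inside any target.
    covered = set().union(*targets)
    return [i for i in context if i not in covered], targets


# 224 / 16 = 14 patches per side.
context, targets = sample_masks(14, 14, npred=2, rng=random.Random(0))
```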
class world_models.masks.random.MaskCollator(ratio=(0.4, 0.6), input_size=(224, 224), patch_size=16)[source]#

Bases: object

Generate random context/prediction patch splits for masked training.

A random permutation of patch indices is sampled per image; a configurable fraction is assigned to context and the remainder to prediction targets.

step()[source]#
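A minimal sketch of the random split for a single image, assuming ratio is the fraction of patches assigned to prediction targets (the docstring above does not pin this down). random_context_split is a hypothetical helper.

```python
import random
from typing import List, Optional, Tuple


def random_context_split(num_patches: int, ratio: Tuple[float, float] = (0.4, 0.6),
                         rng: Optional[random.Random] = None) -> Tuple[List[int], List[int]]:
    """Randomly partition patch indices into context and prediction sets."""
    rng = rng or random.Random()
    pred_fraction = rng.uniform(*ratio)
    num_keep = int(num_patches * (1.0 - pred_fraction))  # patches visible to the encoder
    perm = list(range(num_patches))
    rng.shuffle(perm)  # random permutation of patch indices
    return sorted(perm[:num_keep]), sorted(perm[num_keep:])


context, prediction = random_context_split(14 * 14, rng=random.Random(0))
```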
world_models.helpers.jepa_helper.load_checkpoint(device, r_path, encoder, predictor, target_encoder, opt, scaler)[source]#

Load JEPA training state from disk into model and optimizer objects.

Restores encoder, predictor, optional target encoder, optimizer state, and optional AMP scaler, returning the resumed epoch for training restart.

world_models.helpers.jepa_helper.init_model(device, patch_size=16, model_name='vit_base', crop_size=224, pred_depth=6, pred_emb_dim=384)[source]#

Initialize JEPA encoder and predictor modules with ViT backbones.

Applies truncated-normal parameter initialization, moves modules to the requested device, and returns (encoder, predictor).

world_models.helpers.jepa_helper.init_opt(encoder, predictor, iterations_per_epoch, start_lr, ref_lr, warmup, num_epochs, wd=1e-06, final_wd=1e-06, final_lr=0.0, use_bfloat16=False, ipe_scale=1.25)[source]#

Build optimizer, AMP scaler, LR scheduler, and weight-decay scheduler for JEPA.

Parameters are grouped to exclude bias/norm tensors from weight decay, matching typical transformer training best practices.
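The bias/norm exclusion can be sketched as a split over (name, shape) pairs. The rule used here (biases and 1-D tensors such as LayerNorm weights get no weight decay) is a common transformer recipe and an assumption about init_opt's exact test; split_param_groups is a hypothetical helper. The WD_exclude flag is the key that CosineWDSchedule checks.

```python
from typing import Dict, List, Sequence, Tuple


def split_param_groups(named_shapes: Sequence[Tuple[str, Tuple[int, ...]]]) -> List[Dict]:
    """Split parameter names into a decayed group and a WD-excluded group."""
    decay, no_decay = [], []
    for name, shape in named_shapes:
        # Biases and 1-D tensors (norm scales/shifts) are excluded from weight decay.
        if name.split(".")[-1] == "bias" or len(shape) == 1:
            no_decay.append(name)
        else:
            decay.append(name)
    return [
        {"params": decay},
        {"params": no_decay, "WD_exclude": True, "weight_decay": 0.0},
    ]


groups = split_param_groups([
    ("blocks.0.attn.qkv.weight", (768, 256)),
    ("blocks.0.attn.qkv.bias", (768,)),
    ("norm.weight", (256,)),
])
```

With a real model, the pairs would come from model.named_parameters(), and the groups would hold the parameter tensors rather than their names.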

Benchmarks and reports#

Benchmarks sub-module - Benchmark runners and adapters for world models.

This package provides tools for running standardized evaluations of world models (Dreamer, IRIS, DIAMOND) across multiple seeds and computing aggregate metrics.

Usage:

from world_models.benchmarks import BenchmarkRunner, DiamondAdapter
runner = BenchmarkRunner(adapter_cls=DiamondAdapter, …)

class world_models.benchmarks.BenchmarkRunner(adapter_cls, out_dir='results')[source]#

Bases: object

Run evaluations for adapters across seeds and export results.

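Usage:

runner = BenchmarkRunner(adapter_cls=adapters.DiamondAdapter)
results = runner.run(env_spec="Breakout-v5", seeds=[0, 1], num_episodes=5)

Parameters:
  • adapter_cls – Adapter class to evaluate (e.g., DiamondAdapter).

  • out_dir (str) – Directory where results are exported (default: 'results').
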
run(env_spec=None, seeds=None, num_episodes=5, checkpoint=None, extra_kwargs=None)[source]#

Run the benchmark over the requested seeds.

Returns a results dict with per-seed episode returns and computed metrics.

Parameters:
  • env_spec (Any | None)

  • seeds (List[int] | None)

  • num_episodes (int)

  • checkpoint (str | None)

  • extra_kwargs (Dict[str, Any] | None)

Return type:

Dict[str, Any]

class world_models.benchmarks.MultiAgentBenchmarkRunner(adapters, out_dir='results')[source]#

Bases: object

Run evaluations for multiple adapters on the same environment.

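Usage:

runner = MultiAgentBenchmarkRunner(adapters=[adapters.DiamondAdapter, adapters.IRISAdapter])
results = runner.run_all(env_spec={"game": "Breakout-v5"}, seeds=[0, 1], num_episodes=5)  # env_spec keys depend on the adapters

Parameters:
  • adapters – List of adapter classes to evaluate on the same environment.

  • out_dir (str) – Directory where results are exported (default: 'results').
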
run_all(env_spec, seeds=None, num_episodes=5, checkpoints=None, extra_kwargs=None, train_epochs=None)[source]#

Run benchmarks for all adapters on the same environment.

Returns a results dict with results for each adapter.

Parameters:
  • env_spec (Dict[str, Any])

  • seeds (List[int] | None)

  • num_episodes (int)

  • checkpoints (Dict[str, str] | None)

  • extra_kwargs (Dict[str, Any] | None)

  • train_epochs (int | None)

Return type:

Dict[str, Any]

class world_models.benchmarks.BaseAdapter(env_spec=None, seed=0, **kwargs)[source]#

Bases: object

Parameters:
  • env_spec (Any | None)

  • seed (int)

load_checkpoint(path)[source]#
Parameters:

path (str)

evaluate(num_episodes=1, render=False)[source]#

Run num_episodes evaluation episodes and return the results in a standardized format, preferably a dict mapping 'episode_returns' to a List[float].

Parameters:
  • num_episodes (int)

  • render (bool)

class world_models.benchmarks.DreamerAdapter(env_spec=None, seed=0, **kwargs)[source]#

Bases: BaseAdapter

Parameters:
  • env_spec (Any | None)

  • seed (int)

load_checkpoint(path)[source]#
Parameters:

path (str)

evaluate(num_episodes=1, render=False)[source]#
Parameters:
  • num_episodes (int)

  • render (bool)

class world_models.benchmarks.IRISAdapter(env_spec=None, seed=0, **kwargs)[source]#

Bases: BaseAdapter

Parameters:
  • env_spec (Any | None)

  • seed (int)

load_checkpoint(path)[source]#
Parameters:

path (str)

evaluate(num_episodes=1, render=False)[source]#
Parameters:
  • num_episodes (int)

  • render (bool)

class world_models.benchmarks.DiamondAdapter(env_spec=None, seed=0, **kwargs)[source]#

Bases: BaseAdapter

Parameters:
  • env_spec (Any | None)

  • seed (int)

load_checkpoint(path)[source]#
Parameters:

path (str)

evaluate(num_episodes=1, render=False)[source]#
Parameters:
  • num_episodes (int)

  • render (bool)

world_models.benchmarks.compute_aggregate_metrics(per_seed_means)[source]#
Parameters:

per_seed_means (Iterable[float])

Return type:

Dict[str, float]

world_models.benchmarks.bootstrap_ci(values, num_samples=1000, alpha=0.05)[source]#

Compute a simple bootstrap (1 - alpha) confidence interval for the mean.

Parameters:
  • values (List[float])

  • num_samples (int)

  • alpha (float)
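A percentile-bootstrap sketch of the interval described above, using only the standard library. The resampling scheme and the seed argument are assumptions for illustration; the library function may differ in details such as how percentiles are interpolated. bootstrap_mean_ci is a hypothetical name.

```python
import random
import statistics
from typing import List, Tuple


def bootstrap_mean_ci(values: List[float], num_samples: int = 1000,
                      alpha: float = 0.05, seed: int = 0) -> Tuple[float, float]:
    """Percentile bootstrap (1 - alpha) confidence interval for the mean."""
    rng = random.Random(seed)
    n = len(values)
    # Resample with replacement and record the mean of each resample.
    means = sorted(statistics.fmean(rng.choices(values, k=n)) for _ in range(num_samples))
    lo = means[int((alpha / 2) * num_samples)]
    hi = means[min(num_samples - 1, int((1 - alpha / 2) * num_samples))]
    return lo, hi


lower, upper = bootstrap_mean_ci([float(v) for v in range(1, 11)])
```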

world_models.benchmarks.iqm_of_array(values)[source]#

Compute the Interquartile Mean (IQM) of an array of values.

IQM is the mean of values that lie between the 25th and 75th percentiles (inclusive). This is a robust central tendency measure used in RL benchmark reporting.

Parameters:

values (Iterable[float])

Return type:

float
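A standard-library sketch of the definition above. Quartiles are computed with linear interpolation (statistics.quantiles with method="inclusive", which matches NumPy's default percentile rule), and the values inside [Q1, Q3] are averaged. Whether iqm_of_array uses the same interpolation is an assumption; iqm is a hypothetical name.

```python
import statistics
from typing import Iterable


def iqm(values: Iterable[float]) -> float:
    """Mean of the values lying within [Q1, Q3], inclusive."""
    data = sorted(values)
    if len(data) < 3:
        # Too few points for a meaningful interquartile range; fall back to the mean.
        return statistics.fmean(data)
    q1, _, q3 = statistics.quantiles(data, n=4, method="inclusive")
    return statistics.fmean(v for v in data if q1 <= v <= q3)


robust = iqm([1, 2, 3, 4, 5, 6, 7, 8, 9, 100])  # → 5.5: only 4, 5, 6, 7 lie in [3.25, 7.75]
```

The outlier 100 falls outside [Q1, Q3] and is ignored, which is why the IQM is favoured for aggregating RL returns.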

class world_models.benchmarks.runner.BenchmarkRunner(adapter_cls, out_dir='results')[source]#

Bases: object

Run evaluations for adapters across seeds and export results.

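Usage:

runner = BenchmarkRunner(adapter_cls=adapters.DiamondAdapter)
results = runner.run(env_spec="Breakout-v5", seeds=[0, 1], num_episodes=5)

Parameters:
  • adapter_cls – Adapter class to evaluate (e.g., DiamondAdapter).

  • out_dir (str) – Directory where results are exported (default: 'results').
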
run(env_spec=None, seeds=None, num_episodes=5, checkpoint=None, extra_kwargs=None)[source]#

Run the benchmark over the requested seeds.

Returns a results dict with per-seed episode returns and computed metrics.

Parameters:
  • env_spec (Any | None)

  • seeds (List[int] | None)

  • num_episodes (int)

  • checkpoint (str | None)

  • extra_kwargs (Dict[str, Any] | None)

Return type:

Dict[str, Any]

class world_models.benchmarks.runner.MultiAgentBenchmarkRunner(adapters, out_dir='results')[source]#

Bases: object

Run evaluations for multiple adapters on the same environment.

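Usage:

runner = MultiAgentBenchmarkRunner(adapters=[adapters.DiamondAdapter, adapters.IRISAdapter])
results = runner.run_all(env_spec={"game": "Breakout-v5"}, seeds=[0, 1], num_episodes=5)  # env_spec keys depend on the adapters

Parameters:
  • adapters – List of adapter classes to evaluate on the same environment.

  • out_dir (str) – Directory where results are exported (default: 'results').
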
run_all(env_spec, seeds=None, num_episodes=5, checkpoints=None, extra_kwargs=None, train_epochs=None)[source]#

Run benchmarks for all adapters on the same environment.

Returns a results dict with results for each adapter.

Parameters:
  • env_spec (Dict[str, Any])

  • seeds (List[int] | None)

  • num_episodes (int)

  • checkpoints (Dict[str, str] | None)

  • extra_kwargs (Dict[str, Any] | None)

  • train_epochs (int | None)

Return type:

Dict[str, Any]

class world_models.benchmarks.adapters.BaseAdapter(env_spec=None, seed=0, **kwargs)[source]#

Bases: object

Parameters:
  • env_spec (Any | None)

  • seed (int)

load_checkpoint(path)[source]#
Parameters:

path (str)

evaluate(num_episodes=1, render=False)[source]#

Run num_episodes evaluation episodes and return the results in a standardized format, preferably a dict mapping 'episode_returns' to a List[float].

Parameters:
  • num_episodes (int)

  • render (bool)

class world_models.benchmarks.adapters.DiamondAdapter(env_spec=None, seed=0, **kwargs)[source]#

Bases: BaseAdapter

Parameters:
  • env_spec (Any | None)

  • seed (int)

load_checkpoint(path)[source]#
Parameters:

path (str)

evaluate(num_episodes=1, render=False)[source]#
Parameters:
  • num_episodes (int)

  • render (bool)

class world_models.benchmarks.adapters.IRISAdapter(env_spec=None, seed=0, **kwargs)[source]#

Bases: BaseAdapter

Parameters:
  • env_spec (Any | None)

  • seed (int)

load_checkpoint(path)[source]#
Parameters:

path (str)

evaluate(num_episodes=1, render=False)[source]#
Parameters:
  • num_episodes (int)

  • render (bool)

class world_models.benchmarks.adapters.DreamerAdapter(env_spec=None, seed=0, **kwargs)[source]#

Bases: BaseAdapter

Parameters:
  • env_spec (Any | None)

  • seed (int)

load_checkpoint(path)[source]#
Parameters:

path (str)

evaluate(num_episodes=1, render=False)[source]#
Parameters:
  • num_episodes (int)

  • render (bool)

class world_models.benchmarks.adapters.DreamerV1Adapter(env_spec=None, seed=0, **kwargs)[source]#

Bases: DreamerAdapter

Parameters:
  • env_spec (Any | None)

  • seed (int)

class world_models.benchmarks.adapters.DreamerV2Adapter(env_spec=None, seed=0, **kwargs)[source]#

Bases: DreamerAdapter

Parameters:
  • env_spec (Any | None)

  • seed (int)

world_models.benchmarks.metrics.compute_aggregate_metrics(per_seed_means)[source]#
Parameters:

per_seed_means (Iterable[float])

Return type:

Dict[str, float]

world_models.benchmarks.metrics.bootstrap_ci(values, num_samples=1000, alpha=0.05)[source]#

Compute a simple bootstrap (1 - alpha) confidence interval for the mean.

Parameters:
  • values (List[float])

  • num_samples (int)

  • alpha (float)

world_models.benchmarks.metrics.iqm_of_array(values)[source]#

Compute the Interquartile Mean (IQM) of an array of values.

IQM is the mean of values that lie between the 25th and 75th percentiles (inclusive). This is a robust central tendency measure used in RL benchmark reporting.

Parameters:

values (Iterable[float])

Return type:

float

world_models.benchmarks.metrics.bootstrap_iqm_ci(values, num_samples=1000, alpha=0.05)[source]#

Bootstrap a confidence interval for the IQM.

Returns (lower, upper) percentiles of the bootstrap IQM distribution.

Parameters:
  • values (List[float])

  • num_samples (int)

  • alpha (float)

world_models.benchmarks.reporting.export_csv(results, path)[source]#
Parameters:
  • results (Dict[str, Any])

  • path (str)

world_models.benchmarks.reporting.export_markdown(results, path)[source]#
Parameters:
  • results (Dict[str, Any])

  • path (str)

world_models.benchmarks.reporting.export_latex(results, path, caption='Benchmark results')[source]#
Parameters:
  • results (Dict[str, Any])

  • path (str)

  • caption (str)

Utilities#

world_models.utils.dreamer_utils.symlog(x)[source]#

Symmetric log transform used by Dreamer V2 for reward/value targets.

Defined as sign(x) * log(1 + |x|). This compresses large positive or negative values into a range that is easier to predict with a categorical distribution over a bounded set of buckets.

Parameters:

x (Tensor)

Return type:

Tensor

world_models.utils.dreamer_utils.symexp(x)[source]#

Inverse of symlog().

Defined as sign(x) * (exp(|x|) - 1).

Parameters:

x (Tensor)

Return type:

Tensor
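Scalar versions of the two formulas, written with math.log1p and math.expm1 for accuracy near zero. These plain-float helpers illustrate the math only; they are not the Tensor implementations documented above.

```python
import math


def symlog(x: float) -> float:
    """sign(x) * log(1 + |x|)."""
    return math.copysign(math.log1p(abs(x)), x)


def symexp(x: float) -> float:
    """sign(x) * (exp(|x|) - 1), the inverse of symlog."""
    return math.copysign(math.expm1(abs(x)), x)


compressed = symlog(1000.0)      # about 6.9: large magnitudes are squashed
restored = symexp(compressed)    # back to 1000 up to float rounding
```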

class world_models.utils.dreamer_utils.TwoHotEncoder(num_buckets=255, symlog_range=10.0)[source]#

Bases: object

Two-hot encoding for symlog targets (Dreamer V2 reward/value heads).

A target value is softly assigned to the two nearest buckets on a uniform grid spanning [-symlog_range, symlog_range]. The categorical logits produced by a network can then be decoded back into a real value by computing the expected bucket center.

Parameters:
  • num_buckets (int) – Number of buckets in the categorical distribution.

  • symlog_range (float) – Maximum absolute value (in symlog space) covered by the grid. Values outside the range are clipped to the boundary buckets.

register_buffers()[source]#

Allocate the bucket-center buffer on CPU; call to() to move it to another device.

Return type:

None

to(device)[source]#
Parameters:

device (device)

Return type:

TwoHotEncoder

encode(target)[source]#

Two-hot encode a real-valued target into soft bucket probabilities.

Parameters:

target (Tensor) – Tensor of arbitrary shape containing real-valued targets.

Returns:

Tensor with an extra final dimension of size num_buckets containing the soft two-hot distribution. The encoding assumes the target is already in symlog space, matching Dreamer V2.

Return type:

Tensor

decode(logits)[source]#

Decode categorical logits into the expected real-valued prediction.

The logits are passed through a softmax, and the resulting probabilities weight the bucket centers to give an expected value in symlog space. That value is then passed through symexp() to invert the symlog transform.

Parameters:

logits (Tensor) – Tensor with a final dimension of num_buckets.

Returns:

Tensor with the shape of logits without its last dimension.

Return type:

Tensor
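The encode/decode round trip can be sketched with plain Python lists. This assumes evenly spaced bucket centers from -symlog_range to symlog_range inclusive and linear weighting between the two neighbouring buckets; make_buckets, two_hot and decode are hypothetical helpers. As documented for decode(), the expected bucket value is passed through symexp at the end.

```python
import math
from typing import List


def make_buckets(num_buckets: int = 255, symlog_range: float = 10.0) -> List[float]:
    """Evenly spaced bucket centers covering [-symlog_range, symlog_range]."""
    step = 2 * symlog_range / (num_buckets - 1)
    return [-symlog_range + i * step for i in range(num_buckets)]


def two_hot(x: float, buckets: List[float]) -> List[float]:
    """Split the mass of x (already in symlog space) between its two nearest buckets."""
    x = min(max(x, buckets[0]), buckets[-1])                 # clip to the grid
    step = buckets[1] - buckets[0]
    k = min(int((x - buckets[0]) / step), len(buckets) - 2)  # index of lower neighbour
    w_upper = (x - buckets[k]) / step
    probs = [0.0] * len(buckets)
    probs[k] = 1.0 - w_upper
    probs[k + 1] = w_upper
    return probs


def decode(logits: List[float], buckets: List[float]) -> float:
    """Softmax the logits, take the expected bucket value, then apply symexp."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    expected = sum(e / total * b for e, b in zip(exps, buckets))
    return math.copysign(math.expm1(abs(expected)), expected)


buckets = make_buckets(num_buckets=5, symlog_range=1.0)  # [-1.0, -0.5, 0.0, 0.5, 1.0]
probs = two_hot(0.3, buckets)                            # mass split between 0.0 and 0.5
```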

world_models.utils.dreamer_utils.get_parameters(modules)[source]#

Given an iterable of torch modules, return a single list containing all of their parameters.

Parameters:

modules (Iterable[Module])

class world_models.utils.dreamer_utils.FreezeParameters(modules)[source]#

Bases: object

Context manager that temporarily disables gradients for given modules.

Useful during imagination or target-network forward passes where gradients through certain components should be blocked for speed and correctness.

Parameters:

modules (Iterable[Module])
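A duck-typed sketch of such a context manager. It works with any objects exposing parameters() whose items carry a requires_grad attribute, which includes torch.nn.Module. Restoring the saved flags in reverse order keeps a parameter shared between modules correct. This illustrates the pattern and is not the library's source.

```python
from typing import Any, Iterable, List, Tuple


class FreezeParameters:
    """Temporarily set requires_grad=False on every parameter of the given modules."""

    def __init__(self, modules: Iterable[Any]) -> None:
        self.modules = list(modules)
        self._saved: List[Tuple[Any, bool]] = []

    def __enter__(self) -> "FreezeParameters":
        self._saved = []
        for module in self.modules:
            for param in module.parameters():
                self._saved.append((param, param.requires_grad))
                param.requires_grad = False
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        # Restore in reverse order so a parameter shared by several modules
        # ends up with the flag it had before the first module froze it.
        for param, flag in reversed(self._saved):
            param.requires_grad = flag
        self._saved = []
        return False  # never swallow exceptions
```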

class world_models.utils.dreamer_utils.Logger(log_dir, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', video_format='gif', video_fps=20)[source]#

Bases: object

Experiment logger for scalars and GIF rollouts using WandB.

Provides helpers to write scalar metrics, dump pickle snapshots, and save video previews during Dreamer training/evaluation.

log_scalar(scalar, name, step_)[source]#
log_scalars(scalar_dict, step)[source]#
log_videos(videos, step, max_videos_to_save=1, fps=None, video_title='video')[source]#
dump_scalars_to_pickle(metrics, step, log_title=None)[source]#
flush()[source]#
world_models.utils.dreamer_utils.compute_return(rewards, values, discounts, td_lam, last_value)[source]#

Compute TD(lambda) returns from imagined rewards, values, and discounts.

Implements backward recursion used by Dreamer actor/value objectives.
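The backward recursion can be written out for plain lists. This assumes values[t] estimates the state at step t, last_value bootstraps the state after the final step, and discounts[t] already includes any continuation probability, giving R_t = r_t + γ_t((1 - λ)V_{t+1} + λR_{t+1}) with the recursion seeded by last_value. compute_lambda_returns is a hypothetical list-based stand-in for the Tensor version.

```python
from typing import List, Sequence


def compute_lambda_returns(rewards: Sequence[float], values: Sequence[float],
                           discounts: Sequence[float], td_lam: float,
                           last_value: float) -> List[float]:
    """TD(lambda) returns via backward recursion (hypothetical list-based helper)."""
    horizon = len(rewards)
    # V(s_{t+1}) for every step; the final one is bootstrapped from last_value.
    next_values = list(values[1:]) + [last_value]
    returns = [0.0] * horizon
    running = last_value
    for t in reversed(range(horizon)):
        running = rewards[t] + discounts[t] * ((1 - td_lam) * next_values[t] + td_lam * running)
        returns[t] = running
    return returns


mc_like = compute_lambda_returns([1, 1, 1], [0, 0, 0], [0.5, 0.5, 0.5], td_lam=1.0, last_value=0.0)
one_step = compute_lambda_returns([1, 1, 1], [10, 20, 30], [1, 1, 1], td_lam=0.0, last_value=40.0)
```

With td_lam=1 the result is the discounted return bootstrapped only at the end; with td_lam=0 each entry is the one-step TD target r_t + γ_t V_{t+1}.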

world_models.utils.jepa_utils.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2, b=2.0)[source]#

Initialize a tensor in-place from a truncated normal distribution.

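Values are drawn from a normal distribution with the given mean and std, truncated to [a, b]: samples are generated inside the bounds rather than clipped to them, so no probability mass piles up at a or b.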
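Inverse-CDF sampling is one standard way to draw from a truncated normal, and can be sketched with the standard library's statistics.NormalDist. This illustrates the technique on Python floats; trunc_normal_samples is a hypothetical helper, not the in-place Tensor routine.

```python
import random
from statistics import NormalDist
from typing import List, Optional


def trunc_normal_samples(n: int, mean: float = 0.0, std: float = 1.0, a: float = -2.0,
                         b: float = 2.0, rng: Optional[random.Random] = None) -> List[float]:
    """Draw n samples from N(mean, std) truncated to [a, b] by inverse-CDF sampling."""
    rng = rng or random.Random()
    dist = NormalDist(mean, std)
    lo, hi = dist.cdf(a), dist.cdf(b)      # CDF values at the bounds
    samples = []
    for _ in range(n):
        u = rng.uniform(lo, hi)            # uniform over the allowed probability range
        x = dist.inv_cdf(u)                # map back through the inverse CDF
        samples.append(min(max(x, a), b))  # guard against floating-point round-off
    return samples


draws = trunc_normal_samples(1000, rng=random.Random(0))
```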

world_models.utils.jepa_utils.repeat_interleave_batch(x, B, repeat)[source]#

Repeat each batch chunk multiple times while preserving chunk ordering.

Used in JEPA masking code to align context and target token batches.
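On plain lists the operation can be sketched as follows, assuming x is split into consecutive chunks of size B and each chunk is emitted repeat times before moving to the next. repeat_interleave_batch here is a hypothetical list version, not the Tensor function.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def repeat_interleave_batch(x: Sequence[T], B: int, repeat: int) -> List[T]:
    """Repeat each chunk of B consecutive items `repeat` times, keeping chunk order."""
    if len(x) % B:
        raise ValueError("len(x) must be a multiple of B")
    out: List[T] = []
    for start in range(0, len(x), B):
        chunk = x[start:start + B]
        for _ in range(repeat):
            out.extend(chunk)
    return out


tiled = repeat_interleave_batch([1, 2, 3, 4], B=2, repeat=2)  # [1, 2, 1, 2, 3, 4, 3, 4]
```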

class world_models.utils.jepa_utils.WarmupCosineSchedule(optimizer, warmup_steps, start_lr, ref_lr, T_max, last_epoch=-1, final_lr=0.0)[source]#

Bases: object

Learning-rate schedule with linear warmup followed by cosine decay.

Updates optimizer parameter-group LRs on each call to step().

step()[source]#
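The shape of the schedule can be sketched as a pure function of the step count. Here decay_steps counts the steps after warmup; how WarmupCosineSchedule derives this from T_max is an assumption, as is holding the rate at final_lr once decay ends. warmup_cosine_lr is a hypothetical helper.

```python
import math


def warmup_cosine_lr(step: int, warmup_steps: int, start_lr: float, ref_lr: float,
                     decay_steps: int, final_lr: float = 0.0) -> float:
    """Learning rate after `step` updates: linear warmup, then cosine decay."""
    if step < warmup_steps:
        # Linear ramp from start_lr to ref_lr.
        return start_lr + (step / warmup_steps) * (ref_lr - start_lr)
    # Cosine decay from ref_lr to final_lr over decay_steps, then held at final_lr.
    progress = min(1.0, (step - warmup_steps) / max(1, decay_steps))
    return final_lr + (ref_lr - final_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


lrs = [warmup_cosine_lr(s, warmup_steps=10, start_lr=0.0, ref_lr=1.0, decay_steps=100)
       for s in (0, 5, 10, 60, 110)]
```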
class world_models.utils.jepa_utils.CosineWDSchedule(optimizer, ref_wd, T_max, final_wd=0.0)[source]#

Bases: object

Cosine scheduler for optimizer weight decay values.

Skips parameter groups flagged with WD_exclude to keep bias/norm decay at zero.

step()[source]#
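A stateless sketch of one scheduler update on optimizer-style param_groups (a list of dicts). It assumes weight decay moves from ref_wd at step 0 to final_wd at total_steps along a half cosine; cosine_wd_step is a hypothetical helper.

```python
import math
from typing import Dict, List


def cosine_wd_step(param_groups: List[Dict], step: int, ref_wd: float,
                   total_steps: int, final_wd: float = 0.0) -> float:
    """Set a cosine-interpolated weight decay on every non-excluded group."""
    progress = min(1.0, step / max(1, total_steps))
    wd = final_wd + (ref_wd - final_wd) * 0.5 * (1.0 + math.cos(math.pi * progress))
    for group in param_groups:
        if not group.get("WD_exclude", False):  # excluded groups are left untouched
            group["weight_decay"] = wd
    return wd


groups = [{"weight_decay": 0.0}, {"weight_decay": 0.0, "WD_exclude": True}]
wd_start = cosine_wd_step(groups, 0, ref_wd=0.04, total_steps=100, final_wd=0.4)
```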
world_models.utils.jepa_utils.gpu_timer(closure, log_timings=True)[source]#

Measure CUDA execution time for a closure and return (result, elapsed_ms).

Falls back to -1 elapsed time when CUDA timing is unavailable.

class world_models.utils.jepa_utils.CSVLogger(fname, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', *argv)[source]#

Bases: object

Lightweight CSV logger with per-column printf-style formatting and WandB support.

log(step, *argv)[source]#
class world_models.utils.jepa_utils.AverageMeter[source]#

Bases: object

Track running statistics (val, avg, min, max, sum, count) for metrics.

reset()[source]#
update(val, n=1)[source]#
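A minimal sketch of the bookkeeping, assuming update(val, n) treats val as the mean of n samples (so it contributes val * n to the sum) and that min/max track individual updates. This is an illustration, not the library's class.

```python
class AverageMeter:
    """Running statistics for a scalar metric."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0
        self.max = float("-inf")
        self.min = float("inf")

    def update(self, val: float, n: int = 1) -> None:
        self.val = val
        self.sum += val * n  # val is treated as the mean of n samples
        self.count += n
        self.avg = self.sum / self.count
        self.max = max(self.max, val)
        self.min = min(self.min, val)


meter = AverageMeter()
meter.update(2.0)
meter.update(4.0, n=3)
```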
world_models.utils.jepa_utils.grad_logger(named_params)[source]#

Aggregate gradient norm statistics over model parameters for logging.

Also exposes first/last qkv-layer gradient norms when available.

world_models.utils.jepa_utils.init_distributed(port=40112, rank_and_world_size=(None, None))[source]#

Initialize torch distributed process groups when environment supports it.

Returns (world_size, rank) and gracefully falls back to single-process mode.

class world_models.utils.jepa_utils.AllGather(*args, **kwargs)[source]#

Bases: Function

Autograd-aware all-gather operation across distributed workers.

Forward concatenates worker tensors; backward reduces and slices gradients.

static forward(ctx, x)[source]#
static backward(ctx, grads)[source]#
class world_models.utils.jepa_utils.AllReduceSum(*args, **kwargs)[source]#

Bases: Function

Autograd function that sums tensors across distributed workers in forward pass.

static forward(ctx, x)[source]#
static backward(ctx, grads)[source]#
class world_models.utils.jepa_utils.AllReduce(*args, **kwargs)[source]#

Bases: Function

Autograd function that all-reduces and averages tensors across workers.

Used to synchronize scalar losses for consistent distributed logging/training.

static forward(ctx, x)[source]#
static backward(ctx, grads)[source]#
world_models.utils.data_utils.create_efficient_dataloader(dataset, batch_size, num_workers=None, pin_memory=True, prefetch_factor=2, persistent_workers=True)[source]#

Create a DataLoader tuned for throughput and memory use via pinned memory, prefetching, and persistent workers.

Parameters:
  • dataset (Dataset)

  • batch_size (int)

  • num_workers (int | None)

  • pin_memory (bool)

  • prefetch_factor (int)

  • persistent_workers (bool)

Return type:

DataLoader

world_models.utils.data_utils.prefetch_iterator(iterator, buffer_size=3)[source]#

Wrap an iterator so that up to buffer_size items are fetched ahead of consumption.

Parameters:
  • iterator (Iterator)

  • buffer_size (int)
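One common way to implement prefetching is a background thread feeding a bounded queue; the sketch below shows that approach with the standard library. The threading design and the error forwarding are assumptions, not a description of prefetch_iterator's internals.

```python
import queue
import threading
from typing import Iterator, TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel marking normal exhaustion


def prefetch_iterator(iterator: Iterator[T], buffer_size: int = 3) -> Iterator[T]:
    """Yield items from iterator while a daemon thread fetches up to buffer_size ahead."""
    buf: "queue.Queue" = queue.Queue(maxsize=buffer_size)

    def producer() -> None:
        try:
            for item in iterator:
                buf.put((True, item))  # blocks while the buffer is full
        except Exception as exc:       # forward errors to the consuming thread
            buf.put((False, exc))
            return
        buf.put((True, _DONE))

    # The thread starts when the generator is first advanced.
    threading.Thread(target=producer, daemon=True).start()
    while True:
        ok, item = buf.get()
        if not ok:
            raise item
        if item is _DONE:
            return
        yield item
```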

world_models.utils.jit_utils.jit_compile_function(func)[source]#

JIT compile a function for performance.

Parameters:

func (Callable)

Return type:

Callable

world_models.utils.jit_utils.jit_compile_module(module)[source]#

JIT compile a PyTorch module.

Parameters:

module (Module)

Return type:

Module

world_models.utils.memory_utils.apply_gradient_checkpointing(model, checkpoint_ratio=0.5)[source]#

Apply gradient checkpointing to reduce memory usage during training.

Parameters:
  • model (Module)

  • checkpoint_ratio (float)

world_models.utils.memory_utils.enable_mixed_precision(model, scaler=None)[source]#

Enable mixed precision training.

Parameters:
  • model (Module)

  • scaler (GradScaler | None)

world_models.utils.memory_utils.optimize_memory_efficient_ops()[source]#

Configure PyTorch settings for memory-efficient operations.

world_models.utils.logging_utils.setup_logging(name, level='INFO', log_file=None)[source]#

Set up consistent logging for torchwm modules.

Parameters:
  • name (str)

  • level (str)

  • log_file (str | None)

Return type:

Logger

class world_models.utils.utils.AttrDict[source]#

Bases: dict

world_models.utils.utils.load_yml_config(path)[source]#
world_models.utils.utils.to_tensor_obs(image)[source]#

Convert an input NumPy image to a channel-first 64x64 torch image.

world_models.utils.utils.postprocess_img(image, depth)[source]#

Postprocess an image observation for storage, converting a float32 NumPy array in [-0.5, 0.5] to a uint8 NumPy array in [0, 255].

world_models.utils.utils.preprocess_img(image, depth)[source]#

Preprocess an observation in place, mapping a float32 Tensor in [0, 255] to [-0.5, 0.5]. Note that this also adds random noise to the observation.
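A scalar sketch of the two mappings, assuming depth is the bit depth used to quantize pixels and that the added noise is uniform over one quantization bin. preprocess_pixel and postprocess_pixel are hypothetical helpers that operate on single pixel values rather than arrays.

```python
import math
import random
from typing import Optional


def preprocess_pixel(x: float, bit_depth: int, noise: Optional[float] = None) -> float:
    """Map a pixel in [0, 255] into [-0.5, 0.5) (hypothetical scalar helper)."""
    if noise is None:
        noise = random.random()  # uniform in [0, 1)
    levels = 2 ** bit_depth
    # Quantize to `levels` bins and centre around zero.
    quantized = math.floor(x / 2 ** (8 - bit_depth)) / levels - 0.5
    # Spread the value uniformly within its bin.
    return quantized + noise / levels


def postprocess_pixel(y: float, bit_depth: int) -> int:
    """Map a value in [-0.5, 0.5] back to an integer pixel in [0, 255]."""
    levels = 2 ** bit_depth
    pixel = math.floor((y + 0.5) * levels) * 2 ** (8 - bit_depth)
    return min(max(pixel, 0), 255)


restored = postprocess_pixel(preprocess_pixel(203, 5, noise=0.5), 5)  # 203 quantized to 200
```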

world_models.utils.utils.bottle(func, *tensors)[source]#

Apply func, which expects inputs of shape (N, D), to tensors of shape (N, T, D): the first two dimensions are merged before the call and split again afterwards.

world_models.utils.utils.get_combined_params(*models)[source]#

Return the combined parameter list of all models given as input.

world_models.utils.utils.save_video(frames, path, name)[source]#

Saves a video containing frames.

Accepts frames in either:
  • (T, C, H, W) float in [0,1]

  • (T, H, W, C) float in [0,1]

Produces {path}/{name}.mp4 and a debug PNG {path}/{name}_debug_frame.png with per-channel statistics printed to stdout.

world_models.utils.utils.combine_videos(video_dir, output_name='combined.mp4', pattern='vid_*.mp4', fps=25, resize=True)[source]#

Combine all videos matching pattern in video_dir into a single MP4 file. Returns the output filepath (string).

Example

combine_videos("results/planet", output_name="all_training.mp4")

world_models.utils.utils.ensure_results_dir_exists(results_dir)[source]#

Check that a results directory exists; raises FileNotFoundError if it does not.

world_models.utils.utils.save_frames(target, pred_prior, pred_posterior, name, n_rows=5)[source]#

Save side-by-side target, prior-prediction, and posterior-prediction frames.

The function accepts tensors with optional time dimension and writes a PNG grid to {name}.png. Spatial sizes are aligned per timestep before concatenation and values are normalized to [0, 1] when needed.

world_models.utils.utils.get_mask(tensor, lengths)[source]#

Build a batch-first validity mask from sequence lengths.

tensor may be a tensor/array with shape (N, T, ...) or (N,). The returned mask marks valid timesteps with ones up to each element in lengths and preserves device/dtype conventions from the input.

world_models.utils.utils.load_memory(path, device)[source]#

Loads an experience replay buffer (backwards-compatible with older pickle formats). Converts legacy list/.data formats into the current Memory(episodes) object.

world_models.utils.utils.flatten_dict(data, sep='.', prefix='')[source]#

Flattens a nested dict into a single-level dict.

Example

{'a': 2, 'b': {'c': 20}} -> {'a': 2, 'b.c': 20}
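A recursive sketch that reproduces the example above. It assumes keys are joined with sep and that prefix is prepended to every top-level key; note that an empty nested dict contributes no keys. This is an illustration, not the library's source.

```python
from typing import Any, Dict


def flatten_dict(data: Dict[str, Any], sep: str = ".", prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts into a single level, joining keys with sep."""
    flat: Dict[str, Any] = {}
    for key, value in data.items():
        full_key = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten_dict(value, sep=sep, prefix=full_key))
        else:
            flat[full_key] = value
    return flat


flat = flatten_dict({"a": 2, "b": {"c": 20}})  # {'a': 2, 'b.c': 20}
```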

world_models.utils.utils.normalize_frames_for_saving(frames)[source]#

Ensure frames are in shape (T, H, W, 3) with float values in [0,1]. Handles inputs in (T, C, H, W) or (T, H, W, C), repeats single-channel -> RGB, drops alpha if present, and maps [-0.5,0.5] -> [0,1] when detected.

class world_models.utils.utils.TensorBoardMetrics(path)[source]#

Bases: object

Plots and (optionally) stores metrics for an experiment.

assign_type(key, val)[source]#
update(metrics)[source]#
Parameters:

metrics (dict)

world_models.utils.utils.apply_model(model, inputs, ignore_dim=None)[source]#

Placeholder helper for generic model application across input structures.

Currently not implemented; kept as an extension hook for future utility code.

world_models.utils.utils.plot_metrics(metrics, path, prefix)[source]#

Render and save line plots for each metric series in a dictionary.

world_models.utils.utils.lineplot(xs, ys, title, path='', xaxis='episode')[source]#

Create a Plotly line plot for scalar, dict, or ensemble-series data.

Supports uncertainty-band plotting when ys is a 2D array.

class world_models.utils.utils.TorchImageEnvWrapper(env, bit_depth, observation_shape=None, act_rep=2)[source]#

Bases: object

Wrap a Gym environment so that interactions use torch Tensors and observations are returned as images.

reset()[source]#
step(u)[source]#
render()[source]#
close()[source]#
property observation_size#
property action_size#
sample_random_action()[source]#
property max_episode_steps#

Return environment max episode steps (compatible with TimeLimit/spec).

world_models.utils.utils.apply_masks(x, masks)[source]#

Gather token subsets from patch sequences using index masks.

Each mask selects token positions from x; selected groups are concatenated along the batch dimension.
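With nested lists standing in for tensors, the gather can be sketched as follows: x has shape (B, N, D), each mask holds per-sample token indices of shape (B, K), and the outputs for all masks are stacked along the batch dimension, giving shape (len(masks) * B, K, D). The list representation is for illustration only.

```python
from typing import List, Sequence


def apply_masks(x: Sequence[Sequence[Sequence[float]]],
                masks: Sequence[Sequence[Sequence[int]]]) -> List[List[Sequence[float]]]:
    """Gather tokens selected by each mask and stack the results along the batch axis."""
    out = []
    for mask in masks:  # mask: (B, K) token indices
        for b, indices in enumerate(mask):
            out.append([x[b][i] for i in indices])  # (K, D) tokens of sample b
    return out


x = [[[0], [1], [2]], [[10], [11], [12]]]  # B=2 samples, N=3 tokens, D=1
masks = [[[0, 2], [1, 2]], [[1, 0], [2, 1]]]  # two masks, K=2 tokens each
gathered = apply_masks(x, masks)              # shape (4, 2, 1)
```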

world_models.utils.utils.visualize_latent_tsne(latents, labels=None, save_path=None, perplexity=30)[source]#

Visualize latent representations using t-SNE.

Parameters:
  • latents – torch.Tensor of shape (N, D) or numpy array

  • labels – optional list or array of labels for coloring

  • save_path – path to save the plot (HTML for plotly)

  • perplexity – t-SNE perplexity parameter

world_models.utils.utils.visualize_latent_umap(latents, labels=None, save_path=None, n_neighbors=15)[source]#

Visualize latent representations using UMAP.

Parameters:
  • latents – torch.Tensor of shape (N, D) or numpy array

  • labels – optional list or array of labels for coloring

  • save_path – path to save the plot (HTML for plotly)

  • n_neighbors – UMAP n_neighbors parameter

class world_models.utils.utils.StreamingVideoWriter(path, fps=20, frame_shape=None, format='mp4')[source]#

Bases: object

Write video frames to a file incrementally, as they are produced.

Parameters:
  • path – output video file path

  • fps – frames per second

  • frame_shape – (height, width) of frames

  • format – ‘mp4’ or ‘avi’

write_frame(frame)[source]#

Write a single frame to the video.

Parameters:

frame – numpy array of shape (H, W, C) or (H, W), uint8 or float in [0,1]

close()[source]#