API Reference

This reference is generated from source docstrings and grouped by workflow. Use the World Models Study Guide for conceptual explanations and this page for the exact classes, functions, and module-level APIs.

Public package surface

These modules expose the most common imports and lazy constructors.

Primary module: torchwm. Implementation modules are documented below for API completeness.

Use torchwm for common workflows:

import torchwm
agent = torchwm.create_model("dreamer", env="walker-walk")

Friendly top-level package for TorchWM.

torchwm is the recommended public namespace for users.
class torchwm.EnvBackendSpec(name, factory_path, description='', aliases=())

Bases: NamedTuple

Metadata describing an environment backend available through make_env.

Parameters:
  • name (str)

  • factory_path (str)

  • description (str)

  • aliases (tuple[str, ...])

name: str

Alias for field number 0

factory_path: str

Alias for field number 1

description: str

Alias for field number 2

aliases: tuple[str, ...]

Alias for field number 3

class torchwm.ModelSpec(name, import_path, config_path=None, description='', aliases=())

Bases: NamedTuple

Metadata describing a model available through create_model().

Parameters:
  • name (str)

  • import_path (str)

  • config_path (str | None)

  • description (str)

  • aliases (tuple[str, ...])

name: str

Alias for field number 0

import_path: str

Alias for field number 1

config_path: str | None

Alias for field number 2

description: str

Alias for field number 3

aliases: tuple[str, ...]

Alias for field number 4

torchwm.create_config(model, **overrides)

Create the default config object for model and apply overrides.

Examples

>>> cfg = create_config("dreamer", env="walker-walk", seed=7)
>>> cfg.env
'walker-walk'

Parameters:
  • model (str)

  • overrides (Any)

Return type:

Any
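The "default config plus overrides" pattern can be sketched with a frozen dataclass. This is a minimal illustration under an assumption: `DreamerConfigSketch` and its fields are hypothetical and are not TorchWM's actual config class, so check the real config type before relying on field names.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class DreamerConfigSketch:
    # Hypothetical fields for illustration only; not TorchWM's config.
    env: str = "walker-walk"
    seed: int = 0
    total_steps: int = 1_000_000


def create_config_sketch(**overrides):
    """Build the default config, then apply keyword overrides."""
    base = DreamerConfigSketch()
    # dataclasses.replace raises TypeError for unknown field names,
    # so a misspelled override fails loudly instead of being ignored.
    return replace(base, **overrides)


cfg = create_config_sketch(seed=7)
print(cfg.env, cfg.seed)  # walker-walk 7
```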

torchwm.create_model(model, config=None, **overrides)

Instantiate a model or agent from a simple string name.

For models that define a config class, config is optional. Keyword overrides are applied to the config when possible; otherwise they are passed directly to the underlying constructor or factory.

Examples

>>> agent = create_model("dreamer", env="walker-walk", total_steps=1000)
>>> genie = create_model("genie-small", image_size=32)

Parameters:
  • model (str)

  • config (Any | None)

  • overrides (Any)

Return type:

Any
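The override routing described above (config fields first, everything else to the constructor) can be sketched as follows. This is an illustration of the stated behavior, not TorchWM's code. `ConfigSketch` and `split_overrides` are hypothetical names, and the sketch assumes a dataclass config.

```python
from dataclasses import dataclass, fields, replace


@dataclass(frozen=True)
class ConfigSketch:
    # Hypothetical config fields for illustration only.
    env: str = "walker-walk"
    total_steps: int = 1_000_000


def split_overrides(config, overrides):
    """Route overrides: known config fields vs. extra constructor kwargs."""
    names = {f.name for f in fields(config)}
    config_updates = {k: v for k, v in overrides.items() if k in names}
    ctor_kwargs = {k: v for k, v in overrides.items() if k not in names}
    return replace(config, **config_updates), ctor_kwargs


cfg, extra = split_overrides(ConfigSketch(), {"total_steps": 1000, "device": "cpu"})
print(cfg.total_steps, extra)  # 1000 {'device': 'cpu'}
```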

torchwm.get_env_backend_spec(name)

Return metadata for an environment backend name or alias.

Parameters:

name (str)

Return type:

EnvBackendSpec

torchwm.get_model_spec(name)

Return metadata for a model name or alias.

Parameters:

name (str)

Return type:

ModelSpec
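Name-or-alias lookup over NamedTuple specs can be sketched as below. The sketch mirrors the documented ModelSpec field order, which is why spec[1] is the import path ("Alias for field number 1"). Several details are assumptions: the registry entries (including the "dreamerv1" alias and the import paths), the exact-match rule, and raising KeyError. TorchWM may normalize names or raise a different error.

```python
from typing import NamedTuple, Optional


class ModelSpecSketch(NamedTuple):
    # Same field order as the documented torchwm.ModelSpec.
    name: str
    import_path: str
    config_path: Optional[str] = None
    description: str = ""
    aliases: tuple = ()


# Hypothetical registry entries; import paths are placeholders.
_REGISTRY = [
    ModelSpecSketch("dreamer", "example.dreamer:DreamerAgent", aliases=("dreamerv1",)),
    ModelSpecSketch("genie-small", "example.genie:create_genie_small"),
]


def get_model_spec_sketch(name: str) -> ModelSpecSketch:
    """Resolve a canonical name or alias to its spec (exact match assumed)."""
    for spec in _REGISTRY:
        if name == spec.name or name in spec.aliases:
            return spec
    raise KeyError(f"unknown model {name!r}")


spec = get_model_spec_sketch("dreamerv1")
print(spec.name, spec[1])  # dreamer example.dreamer:DreamerAgent
```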

torchwm.list_env_backends()

Return canonical backend names accepted by make_env().

Return type:

list[str]

torchwm.list_envs(model=None)

List known environment ids, optionally filtered by model family.

Parameters:

model (str | None)

Return type:

list[str] | dict[str, list[str]]

torchwm.list_models()

Return canonical model names accepted by create_model().

Return type:

list[str]

torchwm.make_env(env_id, backend='auto', **kwargs)

Create an environment with a consistent TorchWM entry point.

Parameters:
  • env_id (str) – Environment id, XML path, Unity executable path, or backend-specific id.

  • backend (str) – One of list_env_backends(); "auto" tries TorchWM’s compatibility helper.

  • **kwargs (Any) – Backend-specific options.

Return type:

Any

torchwm.export_any(obj, path, format='onnx', *, example_inputs=None, target=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)

Export any TorchWM model/agent or a target module contained by it.

Parameters:
  • obj (Any)

  • path (str | Path)

  • format (str)

  • example_inputs (Any | None)

  • target (str | None)

  • input_names (list[str] | None)

  • output_names (list[str] | None)

  • dynamic_axes (dict[str, dict[int, str]] | None)

  • opset_version (int)

  • kwargs (Any)

Return type:

Path

torchwm.export_model(module, path, format='onnx', *, example_inputs=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)

Export a torch.nn.Module to ONNX, TorchScript, or TensorRT.

Parameters:
  • module (Module)

  • path (str | Path)

  • format (str)

  • example_inputs (Any | None)

  • input_names (list[str] | None)

  • output_names (list[str] | None)

  • dynamic_axes (dict[str, dict[int, str]] | None)

  • opset_version (int)

  • kwargs (Any)

Return type:

Path
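The dynamic_axes type, dict[str, dict[int, str]], follows the same convention as torch.onnx.export. It maps each input or output name to {axis index: symbolic dimension name}. The helper below builds a mapping that marks the batch axis as dynamic. The name `batch_dynamic_axes` is hypothetical, as are the example tensor names "frames" and "action_logits"; they must match the names you pass as input_names and output_names for your module.

```python
def batch_dynamic_axes(input_names, output_names, axis_name="batch"):
    """Mark axis 0 of every named input/output as a dynamic dimension.

    The result has the dict[str, dict[int, str]] shape expected by the
    dynamic_axes parameter (the torch.onnx.export convention).
    """
    return {name: {0: axis_name} for name in [*input_names, *output_names]}


axes = batch_dynamic_axes(["frames"], ["action_logits"])
print(axes)  # {'frames': {0: 'batch'}, 'action_logits': {0: 'batch'}}
```

You would then pass the same input_names, output_names, and dynamic_axes=axes to export_model(module, path, format='onnx', ...) so the exported graph accepts any batch size.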

class torchwm.ExportableAgentMixin

Bases: object

Mixin for non-nn.Module agents that delegates to the shared exporter.

export(path, format='onnx', *, example_inputs=None, target=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)

Export this agent or one of its contained modules for deployment.

Parameters:
  • path (str | Path)

  • format (str)

  • example_inputs (Any | None)

  • target (str | None)

  • input_names (list[str] | None)

  • output_names (list[str] | None)

  • dynamic_axes (dict[str, dict[int, str]] | None)

  • opset_version (int)

  • kwargs (Any)

Return type:

Path

class torchwm.IRISAgent(config, action_size, device)

Bases: Module

Complete IRIS agent with world model and policy.

Combines:

  • Discrete autoencoder (encoder + decoder)

  • Transformer world model

  • Actor-critic for policy and value learning

Parameters:
  • config (IRISConfig)

  • action_size (int)

  • device (device)

classmethod from_config(config=None, *, action_size, device=None, **overrides)

Build an IRIS agent from a config object, dict, YAML file, or YAML string.

Parameters:
  • config (IRISConfig | dict[str, Any] | str | Path | None)

  • action_size (int)

  • device (device | str | None)

  • overrides (Any)

Return type:

IRISAgent

classmethod from_pretrained(pretrained_model_name_or_path, *, action_size=None, device=None, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, **overrides)

Load an IRIS agent checkpoint from a local path or directory, or from the Hugging Face Hub.

Parameters:
  • pretrained_model_name_or_path (str | Path)

  • action_size (int | None)

  • device (device | str | None)

  • config (IRISConfig | dict[str, Any] | str | Path | None)

  • checkpoint_filename (str | None)

  • config_filename (str)

  • repo_type (str | None)

  • revision (str | None)

  • overrides (Any)

Return type:

IRISAgent

parameter_count(trainable_only=False)
Parameters:

trainable_only (bool)

Return type:

int

summary()
Return type:

dict[str, Any]

forward_actor_critic(frames, hidden=None)

Forward pass through the actor-critic.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W)

  • hidden (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state

Returns:

A tuple (action_logits, values, hidden_state): action_logits has shape (B, T, action_size), values has shape (B, T), and hidden_state is the LSTM state (h, c).

Return type:

tuple

act(frame, epsilon=0.0, temperature=1.0)

Sample action from policy.

Parameters:
  • frame (Tensor) – Single frame (B, C, H, W)

  • epsilon (float) – Random action probability

  • temperature (float) – Action distribution temperature

Returns:

actions – Selected actions (B,)

Return type:

Tensor
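One common way to combine an epsilon random-action probability with a temperature is sketched below for a single list of logits. This is an assumption about how the two parameters interact, not TorchWM's implementation: act() works on batched frame tensors and may differ. `sample_action` is a hypothetical name.

```python
import math
import random


def sample_action(logits, epsilon=0.0, temperature=1.0, rng=random):
    """Pick an action index from logits with epsilon-random exploration.

    Illustrative only. With probability epsilon a uniformly random action
    is returned; otherwise the action is drawn from softmax(logits / T).
    """
    if rng.random() < epsilon:
        return rng.randrange(len(logits))
    scaled = [x / temperature for x in logits]
    top = max(scaled)
    # Subtract the max before exponentiating for numerical stability.
    weights = [math.exp(x - top) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


rng = random.Random(0)
print(sample_action([0.1, 2.0, -1.0], temperature=1e-3, rng=rng))  # 1
```

A very low temperature makes the sample effectively greedy, while epsilon=1.0 ignores the logits entirely.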

imagine_rollout(initial_frame, horizon=20)

Generate imagined trajectories using world model.

Parameters:
  • initial_frame (Tensor) – Starting frame (B, C, H, W)

  • horizon (int) – Number of steps to imagine

Returns:

trajectory – Dictionary with imagined rollout data

Return type:

dict

update_autoencoder(frames)

Update discrete autoencoder.

Parameters:

frames (Tensor) – Training frames (B, C, H, W)

Returns:

losses – Dictionary of loss values

Return type:

dict

update_transformer(frames, actions, rewards, terminals)

Update transformer world model.

Parameters:
  • frames (Tensor) – Frame sequence

  • actions (Tensor) – Actions taken

  • rewards (Tensor) – Rewards received

  • terminals (Tensor) – Terminal flags

Returns:

losses – Dictionary of loss values

Return type:

dict

update_actor_critic(imagined_trajectory)

Update actor-critic in imagination.

Parameters:

imagined_trajectory (dict) – Dictionary from imagine_rollout

Returns:

losses – Dictionary of loss values

Return type:

dict

save(path)

Save agent state.

Parameters:

path (str)

Return type:

None

load(path)

Load agent state.

Parameters:

path (str)

Return type:

None

torchwm.compute_lambda_return(rewards, values, discounts, lambda_coef=0.95)

Compute λ-return target for value function training.

Parameters:
  • rewards (Tensor) – Rewards (B, T)

  • values (Tensor) – Value estimates (B, T+1)

  • discounts (Tensor) – Discount factors (B, T)

  • lambda_coef (float) – Lambda parameter for bootstrapping

Returns:

lambda_returns – λ-return targets (B, T)

Return type:

Tensor
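The shapes above (values has one more step than rewards) follow the standard backward λ-return recursion. Below is a plain-Python sketch for a single trajectory that makes the indexing visible. It assumes the common bootstrap convention G_T = V_T and that discounts already fold in termination. `lambda_returns` is a hypothetical name, and the batched torchwm function may differ in detail.

```python
def lambda_returns(rewards, values, discounts, lam=0.95):
    """Recursive lambda-return targets for one trajectory (sketch).

    rewards, discounts: length T; values: length T + 1.
    G_t = r_t + d_t * ((1 - lam) * V_{t+1} + lam * G_{t+1}), with G_T = V_T.
    """
    T = len(rewards)
    assert len(values) == T + 1 and len(discounts) == T
    returns = [0.0] * T
    next_return = values[T]  # bootstrap from the final value estimate
    for t in reversed(range(T)):
        next_return = rewards[t] + discounts[t] * (
            (1 - lam) * values[t + 1] + lam * next_return
        )
        returns[t] = next_return
    return returns


print(lambda_returns([1.0, 1.0], [0.0, 0.0, 10.0], [0.5, 0.5], lam=1.0))  # [4.0, 6.0]
```

With lam=1.0 the targets reduce to discounted Monte Carlo returns bootstrapped from V_T. With lam=0.0 they reduce to one-step TD targets r_t + d_t * V_{t+1}.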

class torchwm.ModularRSSM(encoder, decoder, backbone, reward_decoder=None)

Bases: Module

Modular RSSM with swappable encoder, decoder, and backbone.

This class lets researchers experiment with different components:

  • Encoders: Conv, MLP, ViT

  • Decoders: Conv, MLP

  • Backbones: GRU, LSTM, Transformer

Example

>>> encoder = ConvEncoder((3, 64, 64), embed_size=1024)
>>> decoder = ConvDecoder(32, 200, (3, 64, 64))
>>> backbone = GRUBackbone(action_size=6, stoch_size=32, deter_size=200, hidden_size=200, embed_size=1024)
>>> rssm = ModularRSSM(encoder, decoder, backbone)

Parameters:
  • encoder

  • decoder

  • backbone

  • reward_decoder
property stoch_size: int
property deter_size: int
property embed_size: int
init_state(batch_size, device)
Parameters:
  • batch_size (int)

  • device (device)

Return type:

Dict[str, Tensor]

get_dist(mean, std)
Parameters:
  • mean (Tensor)

  • std (Tensor)

Return type:

Distribution

observe_step(prev_state, prev_action, obs, nonterm=1.0)
Parameters:
  • prev_state (Dict[str, Tensor])

  • prev_action (Tensor)

  • obs (Tensor)

  • nonterm (Any)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

imagine_step(prev_state, prev_action, nonterm=1.0)
Parameters:
  • prev_state (Dict[str, Tensor])

  • prev_action (Tensor)

  • nonterm (Any)

Return type:

Dict[str, Tensor]

observe_rollout(obs, actions, nonterms, prev_state, horizon)
Parameters:
  • obs (Tensor)

  • actions (Tensor)

  • nonterms (Tensor)

  • prev_state (Dict[str, Tensor])

  • horizon (int)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

imagine_rollout(actor, prev_state, horizon)
Parameters:
  • actor (Module)

  • prev_state (Dict[str, Tensor])

  • horizon (int)

Return type:

Dict[str, Tensor]

decode_observation(features)
Parameters:

features (Tensor)

Return type:

Tensor

decode_reward(features)
Parameters:

features (Tensor)

Return type:

Tensor

detach_state(state)
Parameters:

state (Dict[str, Tensor])

Return type:

Dict[str, Tensor]

seq_to_batch(state)
Parameters:

state (Dict[str, Tensor])

Return type:

Dict[str, Tensor]

torchwm.create_modular_rssm(encoder_type='conv', decoder_type='conv', backbone_type='gru', obs_shape=(3, 64, 64), action_size=6, stoch_size=32, deter_size=200, embed_size=1024, hidden_size=200, activation='elu', **kwargs)

Factory function to create a modular RSSM with specified components.

Parameters:
  • encoder_type (str) – Type of encoder (“conv”, “mlp”, “vit”)

  • decoder_type (str) – Type of decoder (“conv”, “mlp”)

  • backbone_type (str) – Type of backbone (“gru”, “lstm”, “transformer”)

  • obs_shape (Tuple[int, int, int] | Tuple[int]) – Shape of observations (C, H, W) for images or (D,) for state

  • action_size (int) – Action space dimension

  • stoch_size (int) – Stochastic latent dimension

  • deter_size (int) – Deterministic hidden dimension

  • embed_size (int) – Encoder embedding dimension

  • hidden_size (int) – Hidden layer dimension

  • activation (str) – Activation function name

  • kwargs (Any)

Returns:

Configured ModularRSSM instance

Return type:

ModularRSSM

class torchwm.Genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=512, tokenizer_decoder_dim=1024, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=1024, action_decoder_dim=1024, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, encoder_depth=12, decoder_depth=20, latent_action_depth=20, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)

Bases: Module

Genie: Generative Interactive Environment.

A generative model trained from video-only data that can be used as an interactive environment. It contains three key components:

  1. Video Tokenizer – Converts raw video frames into discrete tokens.

  2. Latent Action Model (LAM) – Infers latent actions between frames.

  3. Dynamics Model – Predicts future frames given past frames and latent actions.

Based on the “Genie: Generative Interactive Environments” paper (arXiv:2402.15391).

Training follows the two phases described in the paper:

  1. Train the video tokenizer first.

  2. Co-train the LAM (from pixels) and the dynamics model (on video tokens).

The LAM uses VQ-VAE training with:

  • Encoder – Takes x_{1:t} and x_{t+1} and outputs latent actions.

  • Decoder – Takes x_{1:t-1} (masked) plus the actions and reconstructs x_t.

  • An auxiliary variance loss to prevent action collapse.

At inference, latent actions are passed to the dynamics model with a stop-gradient.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_decoder_dim (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • latent_action_depth (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

classmethod from_config(config=None, **overrides)

Build Genie from a config object, dict, YAML file, or YAML string.

Parameters:
  • config

  • overrides

Return type:

Genie

classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, map_location=None, **overrides)

Load Genie weights from a local path or directory, or from the Hugging Face Hub.

Parameters:
  • pretrained_model_name_or_path (str | Path)

  • config (GenieConfig | dict[str, Any] | str | Path | None)

  • checkpoint_filename (str | None)

  • config_filename (str)

  • repo_type (str | None)

  • revision (str | None)

  • map_location (str | device | None)

  • overrides (Any)

Return type:

Genie

save_pretrained(path)

Save Genie weights and config in a from_pretrained-compatible format.

Parameters:

path (str | Path)

Return type:

None

parameter_count(trainable_only=False)
Parameters:

trainable_only (bool)

Return type:

int

summary()
Return type:

dict[str, Any]

forward(video, mask_prob=0.5, training_phase='all')

Full forward pass through all components.

Parameters:
  • video (Tensor) – (B, C, T, H, W) input video

  • mask_prob (float) – Probability for random masking in dynamics (0.5-1.0)

  • training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”

Returns:

Dictionary containing losses and predictions

Return type:

Dict[str, Tensor]

training_step(video, mask_prob=0.5, training_phase='all')

Single training step computing all losses.

Parameters:
  • video (Tensor) – (B, C, T, H, W) input video

  • mask_prob (float) – Probability for random masking in dynamics

  • training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”

Returns:

Dictionary containing all losses for backpropagation

Return type:

Dict[str, Tensor]

encode_video(video)

Encode video to discrete tokens.

Parameters:

video (Tensor) – (B, C, T, H, W)

Returns:

video_tokens – (B, T, H*W)

Return type:

Tensor

infer_actions(frames)

Infer latent actions from a sequence of frames.

Parameters:

frames (Tensor) – (B, C, T, H, W) video frames

Returns:

latent_actions – (B, T-1) inferred latent action indices

Return type:

Tensor

generate(prompt_frame, num_frames=16, actions=None, use_maskgit=True)

Generate video frames given a prompt frame and actions.

Parameters:
  • prompt_frame (Tensor) – (B, C, H, W) initial frame

  • num_frames (int) – Total number of frames to generate

  • actions (Tensor | None) – (B, num_frames-1) latent action indices, or None for random

  • use_maskgit (bool) – Whether to use MaskGIT sampling

Returns:

generated_video – (B, C, num_frames, H, W)

Return type:

Tensor

play(current_frame, action, current_frames=None)

Play step: generate the next frame given the current frame and an action.

Parameters:
  • current_frame (Tensor) – (B, C, H, W) current frame

  • action (Tensor) – (B,) latent action indices

  • current_frames (Tensor | None) – (B, C, T, H, W) history frames, or None for first frame

Returns:

next_frame – (B, C, H, W)

Return type:

Tensor

get_num_parameters()

Return total number of parameters.

Return type:

int

class torchwm.LatentActionModel(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)

Bases: Module

Latent Action Model (LAM) for unsupervised action learning.

Learns discrete latent actions from unlabeled video frames using a VQ-VAE-based objective. The model infers latent actions between frames that encode the most meaningful changes for future-frame prediction.

Based on the Genie paper: learns actions from Internet videos without action labels.

Components:

  • Encoder – Takes all previous frames x_{1:t} and the next frame x_{t+1} and outputs latent actions.

  • Decoder – Takes previous frames x_{1:t-1} and latent actions a_{1:t-1} and predicts the next frame x_t.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

encode(x_prev, x_next)

Encode frames to latent actions.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W)

  • x_next (Tensor) – Next frame (B, C, H, W)

Returns:

A tuple (latent_actions, z_q): latent_actions are the discrete latent action indices (B, T), and z_q are the quantized embeddings (B, T, embedding_dim).

Return type:

tuple
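The discrete indices come from vector quantization. Each encoder output is snapped to its nearest codebook entry, and the LAM has vocab_size entries (8 by default). A plain-Python sketch of that standard VQ-VAE lookup follows. It omits the commitment loss (commitment_weight) and the straight-through gradient a real implementation needs. `quantize` is a hypothetical name, and the 2-D codebook is illustrative.

```python
def quantize(vectors, codebook):
    """Map each vector to the index and entry of its nearest codebook row.

    Standard VQ-VAE nearest-neighbour lookup using squared Euclidean
    distance; not TorchWM's tensor implementation.
    """
    indices, quantized = [], []
    for v in vectors:
        dists = [sum((a - b) ** 2 for a, b in zip(v, code)) for code in codebook]
        k = min(range(len(codebook)), key=dists.__getitem__)
        indices.append(k)
        quantized.append(codebook[k])
    return indices, quantized


codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]
idx, z_q = quantize([[0.9, 1.2], [-0.8, 1.7]], codebook)
print(idx)  # [1, 2]
```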

decode(x_prev, z_q)

Decode latent actions to predict the next frame.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W); all but the first frame are masked

  • z_q (Tensor) – Quantized action embeddings (B, T-1, embedding_dim)

Returns:

predicted_next_frame – (B, C, H, W)

Return type:

Tensor

forward(x_prev, x_next)

Full forward pass: encode to get actions, decode to reconstruct.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W)

  • x_next (Tensor) – Next frame (B, C, H, W)

Returns:

Dictionary with losses and outputs

Return type:

Dict[str, Tensor]

class torchwm.DynamicsModel(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, gradient_checkpointing=True)

Bases: Module

Dynamics Model for action-controllable video generation.

A decoder-only transformer that predicts future frame tokens given past frame tokens and latent actions. Uses MaskGIT for training and sampling.

Based on the Genie paper: uses a cross-entropy loss with random masking during training and MaskGIT iterative refinement at inference.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

  • gradient_checkpointing (bool)

forward(video_tokens, actions, mask_prob=0.0)

Forward pass for training.

Parameters:
  • video_tokens (Tensor) – (B, T, H*W) - token indices for frames 1 to T

  • actions (Tensor) – (B, T) - latent action indices for frames 1 to T

  • mask_prob (float) – Probability of masking input tokens (Bernoulli 0.5-1.0)

Returns:

logits – (B, T, H*W, vocab_size)

Return type:

Tensor
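The "Bernoulli 0.5-1.0" note on mask_prob refers to Genie-style masking. A masking rate in that range is chosen, and each input token is independently replaced by a mask token with that probability. This sketch draws the rate inside the function, which is one reading of the range. In DynamicsModel.forward the rate is passed in as mask_prob. `random_mask` is a hypothetical name, and using index vocab_size (1024) as the mask token is an assumption.

```python
import random


def random_mask(tokens, mask_id, low=0.5, high=1.0, rng=random):
    """Genie-style training mask (sketch): draw a rate, then mask per token.

    A rate r ~ U(low, high) is sampled once; each token is then replaced
    by mask_id independently with probability r (a Bernoulli draw).
    """
    rate = rng.uniform(low, high)
    masked = [mask_id if rng.random() < rate else tok for tok in tokens]
    return masked, rate


# Hypothetical: 1024 used as an extra "mask" index beyond vocab_size - 1.
masked, rate = random_mask([5, 9, 2, 7], mask_id=1024, rng=random.Random(0))
```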

sample(prompt_tokens, prompt_actions, num_frames, sampler=None)

Sample future frames using MaskGIT.

Parameters:
  • prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens

  • prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames

  • num_frames (int) – Total number of frames to generate

  • sampler (MaskGITSampler | None) – MaskGIT sampler instance

Returns:

generated_tokens – (B, num_frames, N)

Return type:

Tensor
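MaskGIT decodes a frame's N tokens in a fixed number of steps rather than one token at a time. At each step it predicts every masked token, keeps the most confident predictions, and re-masks the rest according to a schedule. The sketch below computes the cosine schedule from the MaskGIT paper. It assumes, without confirmation, that MaskGITSampler uses the same schedule. `maskgit_schedule` is a hypothetical name.

```python
import math


def maskgit_schedule(num_tokens, steps):
    """Tokens still masked after each of `steps` MaskGIT iterations.

    Cosine schedule: after step t the masked fraction is
    cos(pi/2 * t / steps), which reaches zero at the final step.
    """
    return [
        math.floor(num_tokens * math.cos(math.pi / 2 * t / steps))
        for t in range(1, steps + 1)
    ]


print(maskgit_schedule(256, 8))
```

Early steps commit only a few tokens, and later steps commit progressively more. This is why MaskGIT needs far fewer forward passes than autoregressive_sample.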

autoregressive_sample(prompt_tokens, prompt_actions, num_frames, temperature=1.0)

Simple autoregressive sampling (token by token).

Parameters:
  • prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens

  • prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames

  • num_frames (int) – Total number of frames to generate

  • temperature (float) – Sampling temperature

Returns:

generated_tokens – (B, num_frames, N)

Return type:

Tensor

torchwm.create_genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, action_vocab_size=8, action_embedding_dim=32, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)

Factory function to create a Genie model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

torchwm.create_genie_small(num_frames=16, image_size=64, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)

Create a smaller Genie model for development/testing.

Parameters:
  • num_frames (int)

  • image_size (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

torchwm.create_genie_large(num_frames=16, image_size=64, use_bfloat16=True, action_pooling='mean', window_attention_heads=1)

Create an approximation of the full 11B-parameter Genie model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

torchwm.create_latent_action_model(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, action_pooling='mean', window_attention_heads=1)

Factory function to create a Latent Action Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

LatentActionModel

torchwm.create_dynamics_model(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4)

Factory function to create a Dynamics Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

Return type:

DynamicsModel

class torchwm.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)

Bases: Module

Recurrent State-Space Model used by Dreamer for latent dynamics learning.

The RSSM is the core world model component that learns compact representations of environment dynamics. It maintains a hybrid state consisting of:

  1. Deterministic State (h) – A recurrent hidden state updated by a GRU, capturing sequential/temporal information and deterministic transitions.

  2. Stochastic State (s) – A latent variable representing stochastic, multi-modal uncertainty in the environment (e.g., ambiguous observations).

The model operates in two modes:

  • Observe Mode – Updates states using actual observations from the environment. Uses the representation model: p(s_t | h_t, obs_t)

  • Imagine Mode – Predicts future states without observations. Uses the transition/prior model: p(s_t | h_t)

Architecture

  • Input: Previous state (h_{t-1}, s_{t-1}) and action a_{t-1}

  • Process: GRU updates deterministic state, MLP computes stochastic prior/posterior

  • Output: Updated state (h_t, s_t) and distributions

State Representation

  • deter (h): GRU hidden state, captures sequential context

  • stoch (s): Stochastic latent, multi-modal uncertainty

  • mean/std: Parameters of the stochastic distribution

Usage with DreamerAgent:

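import torch
from torchwm import RSSM

batch_size, action_dim = 4, 6
device = torch.device("cpu")

rssm = RSSM(
    action_size=action_dim,
    stoch_size=30,       # Stochastic state dimension
    deter_size=200,      # Deterministic (GRU) state dimension
    hidden_size=200,     # MLP hidden layer size
    obs_embed_size=256,  # Observation embedding from encoder
    activation='elu'
)

prev_state = rssm.init_state(batch_size, device)
prev_action = torch.zeros(batch_size, action_dim)
obs_embed = torch.randn(batch_size, 256)  # Stand-in for encoder output

# Observe with actual observation: returns (posterior, prior)
posterior, prior = rssm.observe_step(prev_state, prev_action, obs_embed)

# Imagine the next state without an observation
next_state = rssm.imagine_step(posterior, prev_action)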

Training

The RSSM is trained by maximizing the ELBO (Evidence Lower Bound):

  • A KL divergence between prior and posterior encourages the prior to capture environment dynamics.

  • A reconstruction loss from the decoder ensures the state captures observation information.
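The KL term compares two diagonal Gaussians: the observation-informed posterior and the transition prior. Each is treated as an Independent Normal over latent dimensions, as get_dist does, so the per-dimension KLs add up. The plain-Python sketch below shows the closed form. It leaves out refinements some Dreamer variants add, such as free nats or KL balancing. `kl_diag_gaussians` is a hypothetical name.

```python
import math


def kl_diag_gaussians(mean_q, std_q, mean_p, std_p):
    """KL(q || p) for diagonal Gaussians, summed over latent dimensions.

    Per dimension:
        log(std_p / std_q) + (std_q**2 + (mean_q - mean_p)**2) / (2 * std_p**2) - 1/2
    Summing matches an Independent Normal over the latent vector.
    """
    total = 0.0
    for mq, sq, mp, sp in zip(mean_q, std_q, mean_p, std_p):
        total += math.log(sp / sq) + (sq**2 + (mq - mp) ** 2) / (2 * sp**2) - 0.5
    return total


print(kl_diag_gaussians([1.0], [1.0], [0.0], [1.0]))  # 0.5
```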

Reference:

Dream to Control: Learning Behaviors by Latent Imagination. Hafner et al., 2020 - https://arxiv.org/abs/1912.01603

Parameters:
  • action_size (int)

  • stoch_size (int)

  • deter_size (int)

  • hidden_size (int)

  • obs_embed_size (int)

  • activation (str)

init_state(batch_size, device)

Initialize RSSM state with zeros.

Parameters:
  • batch_size (int) – Number of parallel sequences

  • device (device) – torch device for tensors

Returns:

Dictionary containing zero-initialized state components:

  • mean, std: Stochastic distribution parameters

  • stoch: Stochastic state sample

  • deter: Deterministic GRU hidden state

Return type:

dict

get_dist(mean, std)

Create an Independent Normal distribution from mean and std.

Parameters:
  • mean (Tensor) – Location parameter

  • std (Tensor) – Scale parameter

Returns:

Independent Normal distribution with given parameters

Return type:

Independent

observe_step(prev_state, prev_action, obs_embed, nonterm=tensor(1.))

Update state using actual observation (observe mode).

In observe mode, the RSSM first computes a transition prior from the previous state and action, then refines the stochastic state using the actual observation embedding to form the posterior.

Parameters:
  • prev_state (dict) – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})

  • prev_action (Tensor) – Previous action a_{t-1}, shape (B, action_size)

  • obs_embed (Tensor) – Observation embedding from encoder, shape (B, obs_embed_size)

  • nonterm (Tensor) – Termination mask (1.0 = continue, 0.0 = terminal)

Returns:

A tuple (posterior, prior) of state dictionaries. The posterior incorporates observation information; the prior is the transition prediction before observation. Both share the same deterministic state because the GRU is only advanced once per timestep.

Return type:

Tuple[dict, dict]

imagine_step(prev_state, prev_action, nonterm=tensor(1.))

Predict next state without observation (imagine mode).

In imagine mode, the RSSM predicts future states using only the prior distribution. This is used for planning and policy learning where actual observations are not available.

Parameters:
  • prev_state (dict) – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})

  • prev_action (Tensor) – Previous action a_{t-1}, shape (B, action_size)

  • nonterm (Tensor) – Termination mask (1.0 = continue, 0.0 = terminal)

Returns:

Dictionary with the predicted state, containing:

  • deter: Predicted deterministic state

  • mean, std, stoch: Prior stochastic state distribution

Return type:

dict

get_prior(prev_state, prev_action, nonterm=tensor(1.))
Parameters:
  • prev_state (dict)

  • prev_action (Tensor)

  • nonterm (Tensor)

Return type:

dict

get_posterior(prev_state, prev_action, obs_embed, nonterm=tensor(1.))

Compute posterior distribution over stochastic state.

The posterior incorporates observation information to produce a more accurate state estimate.

Parameters:
  • prev_state (dict) – Previous state dictionary

  • prev_action (Tensor) – Previous action

  • obs_embed (Tensor) – Observation embedding

  • nonterm (Tensor) – Termination mask

Returns:

Dictionary with posterior state (observation-informed). Note that the previous-state shape (B, ...) is preserved; the batch dimension is not flattened.

Return type:

dict

detach_state(state)

Detach state tensors from computation graph.

Used during DreamerV2 training to prevent gradient flow through the observation/update pathway.

Parameters:

state (dict) – State dictionary with tensor values

Returns:

Detached state dictionary

Return type:

dict

seq_to_batch(state_dict)

Convert sequence state to batch format.

Parameters:

state_dict (dict) – Dictionary with sequence-dimension tensors (T, B, …)

Returns:

Dictionary with batch-dimension tensors (B*T, …)

Return type:

dict

observe_rollout(obs_embed, actions, nonterms, init_state, seq_len)

Process a sequence of observations (observe mode rollout).

At each timestep we run observe_step once to obtain the transition prior (the prediction given the previous state and action) and the observation-informed posterior. The posterior is then used as the previous state for the next step, matching the standard Dreamer inference pattern.

Parameters:
  • obs_embed (Tensor) – Observation embeddings, shape (T+1, B, obs_embed_size)

  • actions (Tensor) – Actions, shape (T, B, action_size)

  • nonterms (Tensor) – Non-termination flags, shape (T, B, 1)

  • init_state (dict) – Initial state dictionary

  • seq_len (int) – Sequence length T

Returns:

A tuple (prior, posterior): dictionaries of prior and posterior states, each stacked along the time axis.

Return type:

Tuple[dict, dict]

imagine_rollout(policy, init_state, horizon)

Generate imagined trajectory using policy (imagine mode rollout).

Parameters:
  • policy (Module) – Actor network that outputs actions from state features

  • init_state (dict) – Initial state dictionary

  • horizon (int) – Number of steps to imagine

Returns:

Dictionary with imagined states for each step

Return type:

dict

forward(x, u)

Forward pass for training (computes sequence of states).

Parameters:
  • x (Tensor) – Observations, shape (B, T+1, C, H, W)

  • u (Tensor) – Actions, shape (B, T, action_size)

Returns:

(states, priors, posteriors)
  • states: List of state dictionaries, one per timestep

  • priors: List of prior distributions as (mean, std) tuples

  • posteriors: List of posterior distributions as (mean, std) tuples

Return type:

tuple

class torchwm.RecurrentStateSpaceModel(action_size, state_size=200, latent_size=30, hidden_size=200, embed_size=1024, activation_function='relu')[source]#

Bases: Module

A Recurrent State Space Model (RSSM) for modeling latent dynamics in sequential data.

Parameters:
  • action_size (int)

  • state_size (int)

  • latent_size (int)

  • hidden_size (int)

  • embed_size (int)

  • activation_function (str)

get_init_state(enc, h_t=None, s_t=None, a_t=None, mean=False)[source]#

Returns the initial posterior given the observation.

Parameters:
  • enc (Tensor)

  • h_t (Tensor | None)

  • s_t (Tensor | None)

  • a_t (Tensor | None)

  • mean (bool)

Return type:

tuple[Tensor, Tensor]

deterministic_state_fwd(h_t, s_t, a_t)[source]#

Deterministic transition update.

Ensures a_t is 2D and matches batch dimension of h_t before concatenation. Accepts a_t shaped [B, action_size], [action_size] (expanded to [B, action_size]), or [B]/scalar (reshaped appropriately).

Parameters:
  • h_t (Tensor)

  • s_t (Tensor)

  • a_t (Tensor)

Return type:

Tensor

state_prior(h_t, sample=False)[source]#

Returns the prior distribution over the latent state given the deterministic state.

Parameters:
  • h_t (Tensor)

  • sample (bool)

Return type:

tuple[Tensor, Tensor] | Tensor

state_posterior(h_t, e_t, sample=False)[source]#

Returns the posterior distribution over the latent state given the deterministic state and the observation embedding.

Parameters:
  • h_t (Tensor)

  • e_t (Tensor)

  • sample (bool)

Return type:

tuple[Tensor, Tensor] | Tensor

pred_reward(h_t, s_t)[source]#
Parameters:
  • h_t (Tensor)

  • s_t (Tensor)

Return type:

Tensor

rollout_prior(act, h_t, s_t)[source]#
Parameters:
  • act (Tensor)

  • h_t (Tensor)

  • s_t (Tensor)

Return type:

tuple[Tensor, Tensor]

forward(x, u)[source]#

Forward through the RSSM for a batch of sequences.

Parameters:
  • x (Tensor) – Tensor [B, T+1, C, H, W] (observations including initial frame)

  • u (Tensor) – Tensor [B, T, action_size] (actions for T steps)

Returns:

(states, priors, posteriors)
  • states: list[T] of tensors, each [B, state_size]

  • priors: list[T] of (mean, std) tuples, each [B, latent_size]

  • posteriors: list[T] of (mean, std) tuples, each [B, latent_size]

Return type:

tuple

class torchwm.ConvEncoder(input_shape, embed_size, activation, depth=32)[source]#

Bases: Module

Convolutional observation encoder used by Dreamer world models.

This encoder transforms raw image observations (typically RGB frames from environments like Atari or DeepMind Control) into compact latent embeddings that can be processed by the RSSM (Recurrent State-Space Model).

  • Input: (B, C, H, W) images with pixel values scaled to [-0.5, 0.5]

  • Process: 4 convolutional layers with stride 2, each roughly halving the spatial dimensions

  • Output: (B, embed_size) compact representation

With the default depth=32, the channel count doubles at each layer: 32 -> 64 -> 128 -> 256. For 64x64 inputs, the flattened convolutional output has 1024 features. A fully connected layer then projects these features to embed_size.
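The 1024-feature figure can be reproduced with the standard Conv2d output-size formula. This assumes kernel size 4, stride 2, and no padding, as in the original Dreamer encoder; this reference does not confirm those layer settings. Under that assumption, a 64x64 input shrinks to 2x2 after four layers, and 256 channels x 2 x 2 = 1024.

```python
def conv_out(size, kernel=4, stride=2, padding=0):
    """Output size of a Conv2d along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

depth = 32            # default base channel depth
size = 64             # 64x64 input frames
sizes = []
for layer in range(4):
    size = conv_out(size)
    sizes.append(size)
channels = depth * 2 ** 3  # 32 -> 64 -> 128 -> 256
print(sizes, channels * size * size)  # [31, 14, 6, 2] 1024
```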

Usage with Dreamer:

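import torch
from torchwm import ConvEncoder

encoder = ConvEncoder(
    input_shape=(3, 64, 64),  # RGB 64x64 images
    embed_size=256,           # RSSM observation embedding size
    activation='relu',        # Activation function
)
observation = torch.rand(8, 3, 64, 64) - 0.5  # batch of images in [-0.5, 0.5]
obs_embedding = encoder(observation)          # (8, 256)
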
Parameters:
  • input_shape (tuple) – Tuple (C, H, W) for input images, typically (3, 64, 64)

  • embed_size (int) – Output embedding dimension, typically 256 or 1024

  • activation (str) – Activation function name (‘relu’, ‘elu’, ‘tanh’, etc.)

  • depth (int) – Base channel depth for first layer (default 32)

forward(inputs)[source]#
Parameters:

inputs (Tensor)

Return type:

Tensor

class torchwm.CNNEncoder(embedding_size, activation_function='relu')[source]#

Bases: Module

A Convolutional Neural Network (CNN) encoder for processing image inputs.

Parameters:
  • embedding_size (int)

  • activation_function (str)

forward(observation)[source]#
Parameters:

observation (Tensor)

Return type:

Tensor

class torchwm.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)[source]#

Bases: Module

Convolutional decoder for reconstructing observations from latent states.

Part of Dreamer’s world model, this decoder reconstructs image observations from the combined stochastic (s) and deterministic (h) RSSM states.

  • Input: Concatenated [stoch_state, deter_state], shape (B, stoch+deter)

  • Process: Dense projection + 4 transposed convolutions (upsampling 2x each)

  • Output: Independent Normal distribution over observation pixels

The decoder mirrors the ConvEncoder’s structure but in reverse (transposed convs instead of regular convs). This creates a symmetric autoencoder where the encoder and decoder can be trained jointly to learn compressed representations.

Returns torch.distributions.Independent(Normal(mean, std), len(output_shape)). Its log_prob(observation) sums the per-pixel log-densities into one value per batch element, which is what the reconstruction loss needs.

Usage in Dreamer world model:

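import torch
from torchwm import ConvDecoder

decoder = ConvDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(3, 64, 64),  # RGB images
    activation='relu',
)
latent_features = torch.randn(8, 30 + 200)           # concatenated [stoch, deter]
target_observation = torch.rand(8, 3, 64, 64) - 0.5
obs_dist = decoder(latent_features)  # Returns distribution
log_prob = obs_dist.log_prob(target_observation)     # (8,)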

The reconstruction loss is -log_prob(observation), which encourages the RSSM to learn states that capture observation information.
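You can check directly with torch.distributions what wrapping Normal in Independent does. With one reinterpreted dimension per output dimension, log_prob returns one value per batch element, equal to the sum of the per-pixel log-densities. The unit standard deviation is an assumption for illustration, and the tensors are random stand-ins for decoder outputs and targets.

```python
import torch
from torch.distributions import Independent, Normal

output_shape = (3, 64, 64)
mean = torch.zeros(8, *output_shape)            # stand-in for decoder output
target = torch.rand(8, *output_shape) - 0.5     # images in [-0.5, 0.5]

obs_dist = Independent(Normal(mean, 1.0), len(output_shape))
log_prob = obs_dist.log_prob(target)            # one value per image
per_pixel_sum = Normal(mean, 1.0).log_prob(target).sum(dim=(1, 2, 3))
recon_loss = -log_prob.mean()                   # scalar reconstruction loss
print(log_prob.shape)  # torch.Size([8])
```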

Parameters:
  • stoch_size (int)

  • deter_size (int)

  • output_shape (tuple[int, ...])

  • activation (str)

  • depth (int)

forward(features)[source]#
Parameters:

features (Tensor)

Return type:

Independent

class torchwm.CNNDecoder(state_size, latent_size, embedding_size, activation_function='relu')[source]#

Bases: Module

A Convolutional Neural Network (CNN) decoder for reconstructing image outputs.

Parameters:
  • state_size (int)

  • latent_size (int)

  • embedding_size (int)

  • activation_function (str)

forward(latent, state)[source]#
Parameters:
  • latent (Tensor)

  • state (Tensor)

Return type:

Tensor

class torchwm.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist, num_buckets=255, symlog_range=10.0)[source]#

Bases: Module

MLP decoder for reward/value/discount prediction from latent features.

Part of Dreamer’s world model, this decoder predicts scalar quantities (rewards, values, discount factors) from RSSM latent states.

  • Input: [stoch_state, deter_state] concatenated, shape (B, stoch+deter)

  • Process: MLP with configurable layers and hidden units

  • Output: Predicted quantity with distribution (normal, binary, or raw)

Supports three output types:
  • 'normal': Gaussian distribution for regression (rewards, values)

  • 'binary': Bernoulli distribution for binary classification (discount)

  • 'none': Raw tensor for non-probabilistic outputs
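The sketch below shows how the three dist options might map onto torch.distributions, following the pattern of the original Dreamer code. Several details are assumptions, and the torchwm class may differ:

- The unit-variance Normal.
- Bernoulli parameterized by logits.
- The helper name make_output_dist.
- How num_buckets and symlog_range are used, which this sketch does not cover.

```python
import torch
from torch.distributions import Bernoulli, Independent, Normal

def make_output_dist(out, dist, output_shape):
    """Hypothetical helper: wrap raw MLP outputs according to `dist`."""
    event_dims = len(output_shape)
    if dist == "normal":
        # Gaussian with fixed unit std for regression targets (rewards, values).
        return Independent(Normal(out, 1.0), event_dims)
    if dist == "binary":
        # Bernoulli over logits for discount / continuation prediction.
        return Independent(Bernoulli(logits=out), event_dims)
    if dist == "none":
        return out  # raw tensor, no distribution
    raise ValueError(f"unknown dist: {dist!r}")

out = torch.randn(16, 1)  # stand-in for an MLP output with output_shape=(1,)
reward_dist = make_output_dist(out, "normal", (1,))
discount_dist = make_output_dist(out, "binary", (1,))
print(reward_dist.log_prob(torch.zeros(16, 1)).shape)  # torch.Size([16])
```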

Usage:

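import torch
from torchwm import DenseDecoder

reward_decoder = DenseDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(1,),
    n_layers=2,
    units=400,
    activation='elu',
    dist='normal'
)
latent_features = torch.randn(16, 30 + 200)  # concatenated [stoch, deter]
target_reward = torch.randn(16, 1)
reward_dist = reward_decoder(latent_features)
reward_loss = -reward_dist.log_prob(target_reward).mean()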

For discount prediction (binary):

discount_decoder = DenseDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(1,),
    n_layers=2,
    units=400,
    activation='elu',
    dist='binary'
)
Parameters:
  • stoch_size (int)

  • deter_size (int)

  • output_shape (tuple[int, ...])

  • n_layers (int)

  • units (int)

  • activation (str)

  • dist (str)

  • num_buckets (int)

  • symlog_range (float)

forward(features)[source]#
Parameters:

features (Tensor)

Return type:

Any

class torchwm.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)[source]#

Bases: Module

Dreamer actor head producing squashed continuous actions from latent features.

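Outputs actions from a tanh-squashed Gaussian policy. forward() returns a deterministic action when deter=True and a sampled action otherwise. add_exploration() perturbs an action with additive noise scaled by action_noise.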
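The constructor arguments (min_std, init_std, mean_scale, with the same defaults) match the actor parameterization in the original Dreamer implementation. The sketch below shows that parameterization; whether this class applies exactly these formulas is an assumption. It works as follows:

- The raw network output is split into a mean and a pre-softplus std.
- The mean is soft-clipped to [-mean_scale, mean_scale].
- The std is offset so that a zero network output gives std close to init_std.

PyTorch's TanhTransform stands in for TanhBijector.

```python
import math
import torch
import torch.nn.functional as F
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

def action_dist_sketch(raw, min_std=1e-4, init_std=5.0, mean_scale=5.0):
    """raw: (B, 2 * action_size) network output (hypothetical layout)."""
    raw_init_std = math.log(math.exp(init_std) - 1.0)  # inverse softplus of init_std
    mean, std = raw.chunk(2, dim=-1)
    mean = mean_scale * torch.tanh(mean / mean_scale)  # soft-clip the mean
    std = F.softplus(std + raw_init_std) + min_std      # positive and >= min_std
    base = Independent(Normal(mean, std), 1)
    return TransformedDistribution(base, TanhTransform(cache_size=1)), mean, std

dist, mean, std = action_dist_sketch(torch.zeros(4, 2 * 6))
action = dist.rsample()  # (4, 6), squashed into [-1, 1]
```

The large initial std (about 5) makes early actions spread across the whole squashed range, which encourages exploration before the actor has learned anything.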

Parameters:
  • action_size (int)

  • stoch_size (int)

  • deter_size (int)

  • n_layers (int)

  • units (int)

  • activation (str)

  • min_std (float)

  • init_std (float)

  • mean_scale (float)

forward(features, deter=False)[source]#
Parameters:
  • features (Tensor)

  • deter (bool)

Return type:

Tensor

add_exploration(action, action_noise=0.3)[source]#
Parameters:
  • action (Tensor)

  • action_noise (float)

Return type:

Tensor

class torchwm.TanhBijector[source]#

Bases: Transform

Bijective tanh transform for squashing Gaussian distributions to [-1, 1].

This transformation is essential for Dreamer’s action policy. Raw neural network outputs are Gaussian distributions over R^n, but actions in continuous control environments are typically bounded in [-1, 1]. The tanh bijector provides:

  1. Bijective mapping: tanh is invertible (with atanh as inverse)

  2. Stable log-det Jacobian: Computable for gradient-based training

  3. Bounded actions: samples always lie within [-1, 1], so no explicit clipping is needed

  • Forward: y = tanh(x)

  • Inverse: x = atanh(y) = 0.5 * log((1+y)/(1-y))

  • Log-det: log|dy/dx| = 2*(log(2) - x - softplus(-2x))
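The log-det formula above is an algebraic rewrite of log(1 - tanh(x)^2). Since 1 - tanh(x)^2 = 4 e^(-2x) / (1 + e^(-2x))^2, taking logs gives 2 * (log 2 - x - softplus(-2x)). The following check compares both forms. The helper name tanh_log_det is for illustration only.

```python
import math
import torch
import torch.nn.functional as F

def tanh_log_det(x):
    """Closed form from the docstring: log|dy/dx| = 2 * (log 2 - x - softplus(-2x))."""
    return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))

x = torch.linspace(-4.0, 4.0, steps=9, dtype=torch.float64)
direct = torch.log1p(-torch.tanh(x) ** 2)   # log(1 - tanh(x)^2), the naive form
print(torch.allclose(tanh_log_det(x), direct))  # True

big = torch.tensor([20.0], dtype=torch.float64)
print(tanh_log_det(big))  # tensor([-38.6137], dtype=torch.float64)
```

In the naive form, 1 - tanh(x)**2 rounds to 0 once |x| is large (around 20 in float64, and much sooner in float32), so its log becomes -inf. The softplus form avoids that cancellation.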

Usage with Dreamer ActionDecoder:

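import torch
from torch.distributions import Normal, TransformedDistribution
from torchwm import TanhBijector
mean, std = torch.zeros(4, 6), torch.ones(4, 6)  # e.g. actor outputs
dist = TransformedDistribution(
    Normal(mean, std),
    TanhBijector()
)
action = dist.sample()  # Bounded to [-1, 1]
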
Reference:

Building a Scalable Deep RL Library by Learning from Mistakes, Haarnoja et al.

property sign: int#
atanh(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

log_abs_det_jacobian(x, y)[source]#
Parameters:
  • x (Tensor)

  • y (Tensor)

Return type:

Tensor

class torchwm.SampleDist(dist, samples=100)[source]#

Bases: object

Distribution wrapper that estimates statistics via Monte Carlo sampling.

Provides Monte Carlo approximations of the mean, mode, and entropy for distributions whose analytic forms are unavailable or inconvenient, such as tanh-transformed Gaussians.
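The sketch below computes Monte Carlo statistics in the style of the SampleDist helper from the original Dreamer code, using PyTorch's built-in TanhTransform in place of TanhBijector. It rests on these assumptions about how SampleDist works:

- The mode is estimated as the most likely draw.
- The entropy is estimated as the mean negative log-probability.

mc_stats is a hypothetical name, and the sketch assumes a one-dimensional batch shape.

```python
import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

def mc_stats(dist, samples=100):
    """Monte Carlo mean, mode, and entropy for a distribution with a 1-D batch shape."""
    draws = dist.rsample((samples,))       # (samples, B, *event_shape)
    log_probs = dist.log_prob(draws)       # (samples, B)
    mean = draws.mean(dim=0)
    best = log_probs.argmax(dim=0)         # index of the most likely draw per batch row
    mode = draws[best, torch.arange(draws.shape[1])]
    entropy = -log_probs.mean(dim=0)       # estimate of E[-log p(x)]
    return mean, mode, entropy

torch.manual_seed(0)
base = Independent(Normal(torch.zeros(4, 2), torch.ones(4, 2)), 1)
squashed = TransformedDistribution(base, TanhTransform(cache_size=1))
mean, mode, entropy = mc_stats(squashed)
print(mean.shape, mode.shape, entropy.shape)  # torch.Size([4, 2]) torch.Size([4, 2]) torch.Size([4])
```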

Parameters:
  • dist (Any)

  • samples (int)

property name: str#
mean()[source]#
Return type:

Tensor

mode()[source]#
Return type:

Tensor

entropy()[source]#
Return type:

Tensor

sample()[source]#
Return type:

Tensor

class torchwm.IRISEncoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, in_channels=3, base_channels=64, num_residual_blocks=2, frame_shape=(3, 64, 64))[source]#

Bases: Module

CNN Encoder for IRIS discrete autoencoder.

Encodes image observations into latent features, which are then quantized into discrete tokens using the VectorQuantizer.

Architecture:
  • 4 convolutional layers with residual blocks

  • Self-attention at 8x8 and 16x16 resolutions

  • Vector quantization to produce discrete tokens

Parameters:
  • vocab_size (int)

  • tokens_per_frame (int)

  • embedding_dim (int)

  • in_channels (int)

  • base_channels (int)

  • num_residual_blocks (int)

  • frame_shape (Tuple[int, int, int])

forward(x)[source]#

Encode images to discrete tokens.

Parameters:

x (Tensor) – Input images (B, C, H, W) - should be 64x64

Returns:

(z_q, indices, vq_loss)
  • z_q: Quantized tokens (B, C, H’, W’)

  • indices: Token indices (B, H’, W’)

  • vq_loss: Dictionary with VQ loss components

Return type:

tuple

encode_to_indices(x)[source]#

Encode directly to token indices (for world model).

Parameters:

x (Tensor)

Return type:

Tensor

decode_from_indices(indices)[source]#

Decode token indices to embeddings (for decoder).

Parameters:

indices (Tensor)

Return type:

Tensor

class torchwm.IRISDecoder(vocab_size=512, embedding_dim=512, base_channels=32, out_channels=3, frame_shape=(3, 64, 64))[source]#

Bases: Module

CNN Decoder for IRIS discrete autoencoder.

Decodes discrete tokens back into image observations. Uses transposed convolutions to upsample from 4x4 to 64x64.

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • base_channels (int)

  • out_channels (int)

  • frame_shape (Tuple[int, int, int])

forward(z)[source]#

Decode tokens to images.

Parameters:

z (Tensor) – Token embeddings (B, C, H, W) - e.g., (B, 512, 4, 4)

Returns:

Reconstructed images (B, C, H, W) - e.g., (B, 3, 64, 64)

Return type:

Tensor

decode_from_embeddings(z_flat)[source]#

Decode flattened token embeddings to images.

Parameters:

z_flat (Tensor) – Flattened tokens (B, H*W, C) or (B, C, H, W)

Returns:

Reconstructed images

Return type:

Tensor

decode_from_indices(indices)[source]#

Decode discrete token indices into images.

Parameters:

indices (Tensor) – Tensor of shape (B, H, W) or (B, H*W) containing integer token indices in the range [0, vocab_size).

Returns:

Reconstructed images (B, C, H, W)

Return type:

Tensor

class torchwm.VideoTokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, commitment_weight=0.25, use_ema=False, ema_decay=0.99)[source]#

Bases: Module

Video Tokenizer using VQ-VAE with Spatiotemporal Transformer.

This is a core component of Genie (Google DeepMind, 2024), used to compress raw video frames into discrete latent tokens that can be processed by downstream models like the LatentActionModel and DynamicsModel.

The tokenizer uses a Vector Quantized Variational Autoencoder (VQ-VAE) objective to learn a discrete codebook of video representations. Unlike a standard VQ-VAE, it uses a Spatiotemporal (ST) Transformer in both the encoder and the decoder to better capture temporal dynamics in videos.

Architecture

  1. Patch Embedding: Convert (B, C, T, H, W) video to patch tokens

  2. Encoder ST-Transformer: Process spatial-temporal patches

  3. Vector Quantization: Discretize continuous embeddings to codebook entries

  4. Decoder ST-Transformer: Reconstruct video from quantized tokens

  5. Patch Unembedding: Convert tokens back to video frames

Key Features

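  • Causal processing: each frame's encoding depends only on the current and previous frames

  • Discrete tokens: Enables autoregressive prediction with latent actions

  • Memory efficient: the ST-Transformer factorizes attention into spatial attention within each frame and temporal attention across frames, instead of full ViT attention over all tokens, so the dominant (spatial) attention cost grows linearly rather than quadratically with the number of frames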
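To see the savings, count query-key pairs per attention layer for the default configuration: num_frames=16, image_size=64, and patch_size=4, which gives 16 x 16 = 256 patches per frame. The count assumes:

- One token per patch per frame.
- Spatial attention within each frame.
- Temporal attention across frames at each patch position, as described in the Genie paper.

The causal temporal mask and any implementation details of this class are ignored.

```python
num_frames, image_size, patch_size = 16, 64, 4
patches = (image_size // patch_size) ** 2        # tokens per frame: 256
tokens = num_frames * patches                    # 4096 tokens per video

full_attention = tokens ** 2                     # every token attends to every token
spatial = num_frames * patches ** 2              # within-frame attention
temporal = patches * num_frames ** 2             # across frames, per patch position
factorized = spatial + temporal

print(patches, full_attention, factorized)  # 256 16777216 1114112
```

The factorized layer needs roughly 15x fewer pairs here. Its spatial term, which dominates, grows linearly with num_frames.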

Usage with Genie:

import torch
from torchwm import VideoTokenizer

tokenizer = VideoTokenizer(
    num_frames=16,
    image_size=64,
    patch_size=4,
    vocab_size=1024,
    embedding_dim=32
)
video_frames = torch.randn(2, 3, 16, 64, 64)  # (B, C, T, H, W)
reconstructed, indices, loss_dict = tokenizer(video_frames)

# For discrete token input to dynamics model:
token_embeddings = tokenizer.decode_indices(indices)

The tokenizer is trained with a VQ-VAE objective:
  • Reconstruction loss: MSE between the input and reconstructed video

  • Codebook (VQ) loss: moves codebook embeddings toward the encoder outputs

  • Commitment loss: penalizes encoder outputs for drifting away from their assigned codebook entries (scaled by commitment_weight)

Reference:

Genie: Generative Interactive Environments. Bruce et al., Google DeepMind, 2024. https://arxiv.org/abs/2402.15391

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • use_ema (bool)

  • ema_decay (float)

encode(x)[source]#

Encode video to discrete tokens.

Parameters:

x (Tensor) – Video tensor (B, C, T, H, W)

Returns:

(z_q, indices, vq_loss)
  • z_q: Quantized embeddings (B, T, H’, W’, embedding_dim)

  • indices: Token indices (B, T, H’, W’)

  • vq_loss: Dictionary with VQ loss components

Return type:

tuple

decode_indices(indices)[source]#

Decode token indices to embeddings for video frames.

Parameters:

indices (Tensor) – Token indices (B, T, H’, W’) or (B, T, N) where N = H’ x W’

Returns:

Quantized embeddings (B, T, H’, W’, embedding_dim)

Return type:

Tensor

decode(z_q)[source]#

Decode discrete tokens to video frames.

Parameters:

z_q (Tensor) – Quantized embeddings (B, T, H’, W’, embedding_dim)

Returns:

Reconstructed video (B, C, T, H, W)

Return type:

Tensor

forward(x)[source]#

Full forward pass with VQ-VAE objective.

Parameters:

x (Tensor) – Video tensor (B, C, T, H, W)

Returns:

(reconstructed, indices, loss_dict)
  • reconstructed: Reconstructed video (B, C, T, H, W)

  • indices: Token indices (B, T, H’, W’)

  • loss_dict: Dictionary containing loss components

Return type:

tuple

torchwm.create_video_tokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False)[source]#

Factory function to create a Video Tokenizer.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • use_ema (bool)

Return type:

VideoTokenizer

class torchwm.VectorQuantizer(vocab_size=512, embedding_dim=512, commitment_weight=0.25)[source]#

Bases: Module

Vector Quantizer for discrete autoencoder.

Implements the VQ-VAE quantization from “Neural Discrete Representation Learning” (van den Oord et al., 2017).

Uses a straight-through estimator for gradient flow. For codebook updates via exponential moving averages, see VectorQuantizerEMA.
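Below is a minimal sketch of standard VQ-VAE quantization (van den Oord et al., 2017) on flattened vectors. It covers:

- Nearest-codebook lookup.
- The codebook loss and the commitment loss.
- The straight-through estimator.

vector_quantize is a hypothetical name. The exact loss weighting and reduction inside this class are not confirmed here.

```python
import torch
import torch.nn.functional as F

def vector_quantize(z_e, codebook, commitment_weight=0.25):
    """z_e: (N, D) encoder outputs; codebook: (K, D). Returns z_q, indices, losses."""
    indices = torch.cdist(z_e, codebook).argmin(dim=1)   # nearest code per vector
    z_q = codebook[indices]                              # (N, D)
    # Codebook loss pulls codes toward the (detached) encoder outputs.
    codebook_loss = F.mse_loss(z_q, z_e.detach())
    # Commitment loss keeps encoder outputs near their (detached) codes.
    commitment_loss = commitment_weight * F.mse_loss(z_e, z_q.detach())
    # Straight-through estimator: forward uses z_q, backward copies grads to z_e.
    z_q_st = z_e + (z_q - z_e).detach()
    return z_q_st, indices, {"codebook": codebook_loss, "commitment": commitment_loss}

z_e = torch.tensor([[0.1, 0.0], [0.9, 1.0]], requires_grad=True)
codebook = torch.tensor([[0.0, 0.0], [1.0, 1.0]], requires_grad=True)
z_q, indices, losses = vector_quantize(z_e, codebook)
(z_q.sum() + losses["codebook"] + losses["commitment"]).backward()
print(indices.tolist())  # [0, 1]
```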

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

forward(z)[source]#

Quantize the input latents.

Parameters:

z (Tensor) – Input tensor of shape (B, C, H, W) or (B, C)

Returns:

(z_q, indices, loss)
  • z_q: Quantized tensor (same shape as input)

  • indices: Token indices for each position (B, H, W) or (B,)

  • loss: Dictionary containing VQ loss components

Return type:

tuple

decode_indices(indices)[source]#

Decode token indices back to embeddings.

Parameters:

indices (Tensor) – Token indices (B, H, W) or (B,)

Returns:

Embeddings (B, C, H, W) or (B, C)

Return type:

Tensor

class torchwm.VectorQuantizerEMA(vocab_size=512, embedding_dim=512, commitment_weight=0.25, ema_decay=0.99, epsilon=1e-05)[source]#

Bases: Module

Vector Quantizer with Exponential Moving Average updates.

Uses EMA updates for the codebook instead of gradient-based updates, which often makes training more stable.
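Here is a sketch of the EMA codebook update from the VQ-VAE paper. Per-code assignment counts and summed assigned vectors are tracked as moving averages, and each code becomes their ratio. The Laplace smoothing with epsilon is a common implementation detail, assumed here rather than confirmed for this class. ema_codebook_update is a hypothetical name.

```python
import torch
import torch.nn.functional as F

def ema_codebook_update(z, codebook, cluster_size, embed_sum, decay=0.99, eps=1e-5):
    """z: (N, D) encoder outputs; codebook: (K, D). Returns updated state and indices."""
    K = codebook.shape[0]
    indices = torch.cdist(z, codebook).argmin(dim=1)          # nearest code per vector
    onehot = F.one_hot(indices, K).type_as(z)                 # (N, K)
    # Moving averages of assignment counts and of summed assigned vectors.
    cluster_size = decay * cluster_size + (1 - decay) * onehot.sum(0)
    embed_sum = decay * embed_sum + (1 - decay) * onehot.t() @ z
    # Laplace smoothing keeps rarely used codes from dividing by zero.
    n = cluster_size.sum()
    smoothed = (cluster_size + eps) / (n + K * eps) * n
    codebook = embed_sum / smoothed.unsqueeze(1)
    return codebook, cluster_size, embed_sum, indices

z = torch.tensor([[0.0, 0.0], [0.0, 2.0], [10.0, 10.0], [10.0, 12.0]])
codebook = torch.tensor([[0.0, 0.0], [10.0, 10.0]])
# decay=0.0 makes the update instantaneous: each code jumps to the mean of its inputs.
codebook, cluster_size, embed_sum, indices = ema_codebook_update(
    z, codebook, torch.ones(2), codebook.clone(), decay=0.0)
print(indices.tolist())  # [0, 0, 1, 1]
```

With the default decay=0.99, each code instead moves about 1% of the way toward the mean of its assigned vectors per update.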

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • ema_decay (float)

  • epsilon (float)

forward(z)[source]#

Quantize with EMA updates.

Parameters:

z (Tensor)

Return type:

Tuple[Tensor, Tensor, dict[str, Tensor]]

decode_indices(indices)[source]#

Decode token indices to embeddings.

Parameters:

indices (Tensor) – Token indices (B, H, W) or (B,)

Returns:

Embeddings (B, C, H, W) or (B, C)

Return type:

Tensor

class torchwm.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]#

Bases: object

Fixed-size replay buffer for Dreamer with image observations and transitions.

Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world-model training.

Key Features

  • Ring buffer with fixed capacity (FIFO eviction when full)

  • Stores raw uint8 images to save memory

  • Samples sequences (not single transitions) for temporal modeling

  • Validates sampled sequences don’t span episode boundaries

Memory Layout

  • observations: (capacity, C, H, W) uint8 images

  • actions: (capacity, action_dim) float32

  • rewards: (capacity,) float32

  • terminals: (capacity,) float32 (1.0 = terminal, 0.0 = continue)

Sampling Process

  1. Random start index (avoiding episode boundaries)

  2. Collect sequence of length seq_len with wraparound

  3. Validate no terminal in middle of sequence

  4. Return batch of sequences
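The following numpy sketch illustrates the sampling process above. It returns the start indices of valid windows, and it relies on several assumptions that are not taken from the torchwm source:

- The function name sample_window_starts.
- The check that rejects windows crossing the write pointer.
- The rule that a terminal may appear only on a window's last step.

```python
import numpy as np

def sample_window_starts(terminals, write_idx, full, seq_len, batch_size, rng):
    """Hypothetical sketch: start indices of valid length-seq_len windows.

    terminals: (capacity,) array of 0.0/1.0 flags; write_idx: next slot to write;
    full: whether the ring buffer has wrapped around at least once.
    """
    capacity = len(terminals)
    high = capacity if full else write_idx - seq_len + 1
    if high <= 0:
        raise ValueError("not enough transitions stored yet")
    starts = []
    while len(starts) < batch_size:
        start = int(rng.integers(0, high))
        window = (start + np.arange(seq_len)) % capacity  # wraparound indexing
        # Reject windows that run across the write pointer (newest -> oldest data).
        if full and write_idx in window[1:]:
            continue
        # Reject windows with a terminal before the last step (episode boundary).
        if terminals[window[:-1]].any():
            continue
        starts.append(start)
    return np.array(starts)

rng = np.random.default_rng(0)
terminals = np.zeros(20)
terminals[4] = 1.0  # an episode ends at index 4
starts = sample_window_starts(terminals, write_idx=12, full=False, seq_len=3,
                              batch_size=8, rng=rng)
```

Rejection sampling keeps the sketch short. It loops forever if no valid window exists, so a production version would need a guard for that case.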

Usage with Dreamer:

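import numpy as np
from torchwm import ReplayBuffer

buffer = ReplayBuffer(
    size=100000,           # Max transitions to store
    obs_shape=(3, 64, 64), # RGB images
    action_size=6,         # Continuous action dim
    seq_len=50,            # Sequence length for training
    batch_size=50          # Parallel sequences per batch
)

# Add transitions during interaction (obs is a dict with an 'image' key)
obs = {"image": np.zeros((3, 64, 64), dtype=np.uint8)}
action = np.zeros(6, dtype=np.float32)
buffer.add(obs, action, 0.0, 0.0)

# Sample batch for world model training (after enough transitions are stored)
obs_batch, action_batch, reward_batch, term_batch = buffer.sample()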

Memory Efficiency

  • Uses uint8 for images (1 byte per pixel vs 4 for float32)

  • Sequences share observations (overlapping windows)

  • Configurable capacity based on available system memory
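For the configuration in the usage example above (size=100000, obs_shape=(3, 64, 64)), the uint8 savings work out as follows. The calculation counts observation storage only. Actions, rewards, and terminals add comparatively little; for example, 100000 x 6 float32 actions take 2.4 MB.

```python
import math

capacity = 100_000
obs_shape = (3, 64, 64)
values_per_obs = math.prod(obs_shape)        # 12,288 pixel values

uint8_bytes = capacity * values_per_obs      # 1 byte per value
float32_bytes = 4 * uint8_bytes              # 4 bytes per value
print(round(uint8_bytes / 2**30, 2), round(float32_bytes / 2**30, 2))  # 1.14 4.58
```

Storing uint8 frames needs about 1.14 GiB, compared with about 4.58 GiB for float32. The conversion to float for training can happen per sampled batch instead.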

Note

add() accepts observations as {“image”: …} dicts, but the buffer stores and returns only the image arrays for training efficiency.

Parameters:
  • size (int)

  • obs_shape (Tuple[int, ...])

  • action_size (int)

  • seq_len (int)

  • batch_size (int)

add(obs, ac, rew, done)[source]#

Add a transition to the buffer.

Parameters:
  • obs (dict) – Observation dict with ‘image’ key containing the observation

  • ac (ndarray) – Action taken, shape (action_size,)

  • rew (float) – Reward received, scalar

  • done (float) – Terminal flag, 1.0 if episode ended, 0.0 otherwise

Return type:

None

sample()[source]#

Sample a batch of sequences for training.

Returns:

(observations, actions, rewards, terminals)
  • observations: (seq_len, batch, C, H, W)

  • actions: (seq_len, batch, action_dim)

  • rewards: (seq_len, batch)

  • terminals: (seq_len, batch)

Return type:

tuple

class torchwm.Memory(size=None)[source]#

Bases: deque

Episode-based replay memory for PlaNet/RSSM training.

Stores episodes as variable-length trajectories and supports sampling sub-sequences for training. Implements a ring-buffer style eviction when capacity is reached.

  • Stores complete episodes as lists of transitions

  • Samples contiguous sub-sequences for sequence models

  • Supports time-major formatting (time-first) for RNN input

  • Memory usage estimation to prevent OOM errors

Parameters:

size (int, optional) – Maximum number of episodes to store. If None, deque grows without limit (useful for unpickling).

episodes#

Collection of Episode objects.

Type:

deque

eps_lengths#

Length of each episode.

Type:

deque

size#

Total number of transitions across all episodes.

Type:

property

Example:

memory = Memory(size=100)
memory.append([episode1, episode2])  # terminated Episode objects
obs, actions, rewards, terminals, lengths = memory.sample(batch_size=32, tracelen=50)

property size: int#
append(episodes)[source]#
Parameters:

episodes (list[Episode])

Return type:

None

sample(batch_size, tracelen=1, time_first=False)[source]#

Sample random sub-sequences from stored episodes.

Randomly selects episodes and starting positions to create batches of contiguous sequences for training sequence models.

Parameters:
  • batch_size (int) – Number of sequences to sample.

  • tracelen (int) – Length of each sequence (default: 1).

  • time_first (bool) – If True, returns tensors with time dimension first (T, B, …) instead of batch first (B, T, …).

Returns:

(observations, actions, rewards, terminals, lengths)
  • observations: (batch, tracelen+1, *obs_shape) or (tracelen+1, batch, …)

  • actions: (batch, tracelen, action_dim) or (tracelen, batch, …)

  • rewards: (batch, tracelen) or (tracelen, batch)

  • terminals: (batch, tracelen) or (tracelen, batch)

  • lengths: (batch,) original episode lengths for each sample

Return type:

tuple

Raises:
  • ValueError – If memory is empty or no episodes meet minimum length.

  • MemoryError – If estimated memory usage exceeds 200 MiB threshold.

class torchwm.Episode(postprocess_fn=None)[source]#

Bases: object

Records the agent’s interaction with the environment for a single episode.

Stores observations, actions, rewards, and terminal flags during a single trajectory. At termination, converts all lists to numpy arrays for efficient batch processing.

x#

Observations collected during the episode.

Type:

list or np.ndarray

u#

Actions taken.

Type:

list or np.ndarray

r#

Rewards received.

Type:

list or np.ndarray

t#

Terminal flags (0.0 = continue, 1.0 = terminal).

Type:

list or np.ndarray

info#

Additional episode metadata.

Type:

dict

Parameters:

postprocess_fn (callable, optional) – Function to apply to observations before storing (e.g., normalization). Default: identity function.

Example:

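import numpy as np
from torchwm import Episode

obs, action, reward = np.zeros((3, 64, 64)), np.zeros(6), 0.0
episode = Episode()
episode.append(obs, action, reward, False)
episode.append(obs, action, reward, True)
episode.terminate(obs)  # final observation
print(episode.x.shape)  # Now a numpy array
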
property size: int#
append(obs, act, reward, terminal)[source]#
Parameters:
  • obs (Any)

  • act (Any)

  • reward (Any)

  • terminal (Any)

Return type:

None

terminate(obs)[source]#
Parameters:

obs (Any)

Return type:

None

class torchwm.IRISReplayBuffer(size, obs_shape, action_size, seq_len=20, batch_size=64)[source]#

Bases: object

Replay buffer for IRIS (Imagination with auto-Regression over an Inner Speech) training.

Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world model training.

Features:
  • Ring buffer with fixed capacity (FIFO eviction when full)

  • Stores uint8 images for memory efficiency

  • Samples sequences with validation to avoid episode boundaries

  • Supports sequence sampling for temporal learning

Memory Layout:
  • observations: (capacity, C, H, W) uint8

  • actions: (capacity, action_size) float32

  • rewards: (capacity,) float32

  • terminals: (capacity,) float32

Parameters:
  • size (int) – Maximum number of transitions to store.

  • obs_shape (tuple) – Shape of observations as (C, H, W).

  • action_size (int) – Dimension of actions.

  • seq_len (int) – Length of sequences to sample (default: 20).

  • batch_size (int) – Number of sequences per batch (default: 64).

size#

Buffer capacity.

Type:

int

obs_shape#

Observation shape.

Type:

tuple

action_size#

Action dimension.

Type:

int

seq_len#

Sequence length.

Type:

int

batch_size#

Batch size.

Type:

int

steps#

Total transitions added.

Type:

int

episodes#

Number of episode terminations observed.

Type:

int

add(obs, action, reward, terminal)[source]#

Add a transition to the buffer.

Parameters:
  • obs (ndarray) – Observation array with shape (C, H, W).

  • action (ndarray) – Action array with shape (action_size,).

  • reward (float) – Scalar reward value.

  • terminal (bool) – Boolean indicating if episode terminated.

Return type:

None

sample_sequence(seq_len=None)[source]#

Sample a batch of sequences for world model training.

Returns:

(observations, actions, rewards, terminals)
  • observations: (batch_size, seq_len+1, C, H, W)

  • actions: (batch_size, seq_len, action_size)

  • rewards: (batch_size, seq_len)

  • terminals: (batch_size, seq_len)

Return type:

tuple

Parameters:

seq_len (int | None)

sample_single()[source]#

Sample a single transition for online updates.

Return type:

Tuple[ndarray, ndarray, float, float]

property buffer_capacity: int#

Returns the total capacity of the buffer.

class torchwm.IRISOnPolicyBuffer(max_steps=1000)[source]#

Bases: object

On-policy buffer for collecting trajectories during environment interaction.

Used to store the current episode's data before it is added to the main replay buffer. Unlike the main replay buffer, it collects trajectories in plain Python lists, which are emptied with clear() after each episode.

Useful for:
  • Collecting complete episode trajectories

  • Storing data before batch processing

  • Temporary storage during environment interaction

Parameters:

max_steps (int) – Maximum number of steps to store (default: 1000).

max_steps#

Maximum buffer capacity.

Type:

int

observations#

List of observations.

Type:

list

actions#

List of actions.

Type:

list

rewards#

List of rewards.

Type:

list

terminals#

List of terminal flags.

Type:

list

add(obs, action, reward, terminal)[source]#
Parameters:
  • obs (ndarray)

  • action (ndarray)

  • reward (float)

  • terminal (bool)

Return type:

None

clear()[source]#
Return type:

None

get_arrays()[source]#
Return type:

tuple[ndarray, ndarray, ndarray, ndarray]

class torchwm.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]#

Bases: Module

Diffusion Transformer model for image denoising and generation.

The module maps noisy images and timesteps to predicted noise residuals and also provides a classmethod training entrypoint for common datasets.

Parameters:
  • img_size (int)

  • patch_size (int)

  • in_channels (int)

  • d_model (int)

  • depth (int)

  • heads (int)

  • drop (float)

  • t_dim (int)

forward(x, t)[source]#
Parameters:
  • x (Tensor)

  • t (Tensor)

Return type:

Tensor

classmethod from_config(config=None, **overrides)[source]#

Build DiT from a config object, dict, YAML file, or YAML string.

Parameters:
  • config (DiTConfig | dict[str, Any] | str | Path | None)

  • overrides (Any)

Return type:

DiT

classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, map_location=None, **overrides)[source]#

Load DiT weights from a local path/directory or HF Hub.

Parameters:
  • pretrained_model_name_or_path (str | Path)

  • config (DiTConfig | dict[str, Any] | str | Path | None)

  • checkpoint_filename (str | None)

  • config_filename (str)

  • repo_type (str | None)

  • revision (str | None)

  • map_location (str | device | None)

  • overrides (Any)

Return type:

DiT

save_pretrained(path)[source]#

Save DiT weights and config in a from_pretrained-compatible format.

Parameters:

path (str | Path)

Return type:

None

parameter_count(trainable_only=False)[source]#
Parameters:

trainable_only (bool)

Return type:

int

summary()[source]#
Return type:

dict[str, Any]

classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]#
Parameters:
  • epochs (int)

  • dataset (Any)

  • batch_size (int)

  • lr (float)

  • img_size (int)

  • channels (int)

  • patch (int)

  • width (int)

  • depth (int)

  • heads (int)

  • drop (float)

  • timesteps (int)

  • beta_start (float)

  • beta_end (float)

  • ema (bool)

  • ema_decay (float)

  • workdir (str)

  • root_path (str)

  • image_folder (str | None)

  • crop_size (int)

  • download (bool)

  • copy_data (bool)

  • subset_file (str | None)

  • val_split (float | None)

Return type:

None

torchwm.create_dit(config=None, **overrides)[source]#

Create a DiT from a DiTConfig or keyword overrides.

The public factory API works with config objects, while DiT itself has a compact constructor. This adapter keeps the lower-level model constructor unchanged and maps the public config fields onto the expected arguments.

Parameters:
  • config (Any)

  • overrides (Any)

Return type:

DiT

class torchwm.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]#

Bases: Module

Patchify an image into a sequence of learnable patch tokens.

Used in Vision Transformers (ViT) and DiT to convert 2D images into sequences of token embeddings that can be processed by transformers.

Process:
  1. A Conv2d with kernel_size = stride = patch_size splits the image into non-overlapping patches and linearly projects each patch to embed_dim in a single operation

  2. The resulting grid is flattened into a token sequence

  3. Learnable positional embeddings are added to encode spatial position

  • Input: (B, C, H, W) images

  • Output: (B, N, embed_dim), where N = (H/patch_size) * (W/patch_size)

Parameters:
  • img_size (int) – Image size (assumes square), e.g., 32 for CIFAR

  • patch_size (int) – Size of each patch (typically 4, 8, or 16)

  • in_channels (int) – Number of input channels (3 for RGB)

  • embed_dim (int) – Output dimension for each patch token

Usage with DiT:

patch_embed = PatchEmbed(img_size=32, patch_size=4, in_channels=3, embed_dim=256)
tokens = patch_embed(images)  # (B, 64, 256) for 32x32 images with patch_size=4
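Here is a minimal PyTorch sketch of the patchify step. The class name PatchEmbedSketch and the zero-initialized positional embeddings are assumptions; the torchwm class may initialize or structure these differently.

```python
import torch
import torch.nn as nn

class PatchEmbedSketch(nn.Module):
    """Hypothetical stand-in for PatchEmbed: conv patchify + learned positions."""

    def __init__(self, img_size, patch_size, in_channels, embed_dim):
        super().__init__()
        # kernel_size == stride == patch_size: one linear projection per patch.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        num_patches = (img_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

    def forward(self, x):
        x = self.proj(x)                   # (B, embed_dim, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)   # (B, N, embed_dim)
        return x + self.pos_embed          # add learned positional embeddings

sketch = PatchEmbedSketch(img_size=32, patch_size=4, in_channels=3, embed_dim=256)
tokens = sketch(torch.randn(2, 3, 32, 32))
print(tokens.shape)  # torch.Size([2, 64, 256])
```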

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class torchwm.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]#

Bases: Module

Reconstruct image-like tensors from patch-token sequences.

The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.

Parameters:
  • img_size (int)

  • patch_size (int)

  • embed_dim (int)

  • out_channels (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class torchwm.DDPM(timesteps, beta_start, beta_end)[source]#

Bases: Module

Utility module implementing forward and reverse DDPM diffusion steps.

Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).
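The forward process that q_sample implements has a closed form. The sketch below is plain Python rather than torchwm code, and the helper names are hypothetical. It assumes a linear beta schedule from beta_start to beta_end, the schedule used in the original DDPM paper. The actual schedule inside DDPM is not stated on this page.

```python
import math
from typing import List


def linear_beta_schedule(timesteps: int, beta_start: float, beta_end: float) -> List[float]:
    """Evenly spaced noise variances beta_0 .. beta_{T-1} (assumed linear schedule)."""
    if timesteps < 2:
        return [beta_start] * timesteps
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def cumulative_alphas(betas: List[float]) -> List[float]:
    """alpha_bar_t = prod_{s <= t} (1 - beta_s)."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def q_sample(x_start: List[float], t: int, noise: List[float],
             alpha_bars: List[float]) -> List[float]:
    """Closed-form noising: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps."""
    a_bar = alpha_bars[t]
    return [math.sqrt(a_bar) * x + math.sqrt(1.0 - a_bar) * eps
            for x, eps in zip(x_start, noise)]


betas = linear_beta_schedule(1000, 1e-4, 0.02)
alpha_bars = cumulative_alphas(betas)
print(f"alpha_bar_T = {alpha_bars[-1]:.2e}")  # close to zero: x_T is nearly pure noise
```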

Parameters:
  • timesteps (int)

  • beta_start (float)

  • beta_end (float)

q_sample(x_start, t, noise=None)[source]#
Parameters:
  • x_start (Tensor)

  • t (Tensor)

  • noise (Tensor | None)

Return type:

Tensor

p_sample(model, x_t, t)[source]#
Parameters:
  • model (Module)

  • x_t (Tensor)

  • t (Tensor)

Return type:

Tensor
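p_sample performs one reverse step, and sample presumably repeats it from the last timestep down to 0. The following pure-Python sketch makes two assumptions, neither confirmed for torchwm:

- The model predicts the added noise epsilon, the standard DDPM parameterization.
- The step uses sigma_t^2 = beta_t.

`p_sample_step` is a hypothetical name.

```python
import math
import random
from typing import List, Optional


def p_sample_step(x_t: List[float], t: int, eps_pred: List[float], betas: List[float],
                  rng: Optional[random.Random] = None) -> List[float]:
    """One ancestral DDPM reverse step x_t -> x_{t-1} for an epsilon-predicting model.

    mean = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_pred) / sqrt(alpha_t)
    x_{t-1} = mean + sqrt(beta_t) * z, z ~ N(0, 1); no noise is added at t == 0
    (or when rng is None, which this sketch uses for a deterministic mean).
    """
    alpha_bar = 1.0
    for beta in betas[: t + 1]:
        alpha_bar *= 1.0 - beta
    beta_t = betas[t]
    alpha_t = 1.0 - beta_t
    coef = beta_t / math.sqrt(1.0 - alpha_bar)
    mean = [(x - coef * e) / math.sqrt(alpha_t) for x, e in zip(x_t, eps_pred)]
    if t == 0 or rng is None:
        return mean
    sigma = math.sqrt(beta_t)
    return [m + sigma * rng.gauss(0.0, 1.0) for m in mean]
```

At t = 0, a perfect noise prediction recovers x_0 exactly, which makes a convenient sanity check.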

sample(model, n, img_size, channels)[source]#
Parameters:
  • model (Module)

  • n (int)

  • img_size (int)

  • channels (int)

Return type:

Tensor

class torchwm.ActorCriticNetwork(obs_channels=3, action_dim=18, channels=(32, 32, 64, 64), lstm_dim=512)[source]#

Bases: Module

Actor-Critic network for DIAMOND RL training. Shared CNN-LSTM trunk with separate policy and value heads.

Parameters:
  • obs_channels (int)

  • action_dim (int)

  • channels (Tuple[int, ...])

  • lstm_dim (int)

forward(obs, hidden_state=None)[source]#

Forward pass of actor-critic network.

Parameters:
  • obs (Tensor) – Observations [B, T, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

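Returns:

A tuple (policy_logits, values, hidden_state), where policy_logits has shape [B, T, action_dim], values has shape [B, T, 1], and hidden_state is the updated (h, c) tuple.

Return type:

tuple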

get_action(obs, hidden_state=None, deterministic=False)[source]#

Get action from a single observation.

Parameters:
  • obs (Tensor) – Single observation [B, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

  • deterministic (bool) – If True, take argmax; else sample

Returns:

A tuple (action, hidden_state), where action holds the selected actions with shape [B] and hidden_state is the updated (h, c) tuple.

Return type:

tuple

get_actions(obs, hidden_state=None, deterministic=False)[source]#

Batched version of get_action.

Parameters:
  • obs (Tensor) – Tensor of shape [B, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state tuple matching batch size

  • deterministic (bool) – If True, take argmax; else sample from policy

Returns:

A tuple (actions, hidden_state), where actions is a LongTensor of shape [B] and hidden_state is the updated LSTM hidden state tuple.

Return type:

tuple

get_value(obs, hidden_state=None)[source]#

Get value for a single observation.

Parameters:
  • obs (Tensor)

  • hidden_state (Tuple[Tensor, Tensor] | None)

Return type:

Tuple[Tensor, Tuple[Tensor, Tensor] | None]

init_hidden(batch_size, device)[source]#

Initialize LSTM hidden states.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

get_hidden_size()[source]#

Get LSTM hidden size.

Return type:

int

class torchwm.RewardTerminationModel(obs_channels=3, action_dim=18, channels=(32, 32, 32, 32), lstm_dim=512, cond_dim=128)[source]#

Bases: Module

Reward and termination prediction model. CNN + LSTM architecture following DIAMOND paper specifications.

Parameters:
  • obs_channels (int) – Number of observation channels (3 for RGB)

  • action_dim (int) – Number of possible actions

  • channels (Tuple[int, ...]) – List of channel sizes for conv blocks

  • lstm_dim (int) – LSTM hidden dimension

  • cond_dim (int) – Conditioning dimension for adaptive norm

forward(obs, actions, hidden_state=None)[source]#

Forward pass of reward/termination model.

Parameters:
  • obs (Tensor) – Observations [B, T, C, H, W]

  • actions (Tensor) – Actions [B, T]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

Returns:

A tuple (reward_logits, termination_logits, hidden_state), where reward_logits has shape [B, T, 3] (one class each for rewards -1, 0, and 1), termination_logits has shape [B, T, 2], and hidden_state is the updated (h, c) tuple.

Return type:

tuple

predict(obs, actions, hidden_state=None)[source]#

Predict reward and termination for a single step.

Parameters:
  • obs (Tensor) – Single observation [B, C, H, W]

  • actions (Tensor) – Single action [B]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

Returns:

A tuple (reward, terminated, hidden_state), where reward holds the predicted reward class values (-1, 0, or 1), terminated is a boolean termination tensor, and hidden_state is the updated (h, c) tuple.

Return type:

tuple

init_hidden(batch_size, device)[source]#

Initialize LSTM hidden states.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]
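Turning the model's logits into the reward and termination values that predict() reports can be sketched as an argmax per head. This plain-Python sketch makes two assumptions:

- The three reward classes are ordered -1, 0, 1, the order listed for forward().
- Termination index 1 means "terminated".

The helper names are hypothetical.

```python
from typing import Sequence, Tuple

REWARD_VALUES = (-1, 0, 1)  # assumed class order, matching "(for -1, 0, 1)" above


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def decode_step(reward_logits: Sequence[float],
                termination_logits: Sequence[float]) -> Tuple[int, bool]:
    """Map 3-way reward logits and 2-way termination logits to (reward, terminated)."""
    reward = REWARD_VALUES[argmax(reward_logits)]
    terminated = argmax(termination_logits) == 1  # assumes index 1 means "terminated"
    return reward, terminated


print(decode_step([0.1, 0.2, 2.0], [3.0, -1.0]))  # (1, False)
```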

torchwm.sinusoidal_time_embedding(timesteps, dim)[source]#

Create sinusoidal timestep embeddings for diffusion conditioning.

This function generates positional-style embeddings for diffusion timesteps, following the same pattern as transformer positional encodings. The embeddings encode the noise level (t) and are used to condition the diffusion model.

Math:

embedding[t] = [sin(t/10000^(2i/d)), cos(t/10000^(2i/d))] for i in [0, d/2)

Parameters:
  • timesteps (Tensor) – Tensor of integer timesteps, shape (B,) or (B, 1)

  • dim (int) – Embedding dimension (must be even)

Returns:

Tensor of shape (B, dim) with sinusoidal embeddings

Return type:

Tensor

Usage with DiT:

t = torch.tensor([0, 500, 1000])  # timesteps
emb = sinusoidal_time_embedding(t, dim=256)  # (3, 256)

# Condition the model:
# - use the embedding as input to the timestep MLP
# - use AdaLN for adaptive normalization
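The formula above can be written out directly for a single timestep. This sketch follows the bracket in the Math line literally, concatenating all sine terms before all cosine terms. Some implementations interleave them instead. Treat the layout, and the hypothetical name `sinusoidal_embedding_sketch`, as assumptions rather than torchwm's code.

```python
import math
from typing import List


def sinusoidal_embedding_sketch(t: float, dim: int) -> List[float]:
    """[sin(t * w_i) for i] + [cos(t * w_i) for i], with w_i = 10000^(-2i/dim), i in [0, dim/2)."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    freqs = [10000.0 ** (-2.0 * i / dim) for i in range(half)]
    return [math.sin(t * w) for w in freqs] + [math.cos(t * w) for w in freqs]


# Three timesteps, each embedded into 8 values (analogous to a (3, 8) tensor).
emb = [sinusoidal_embedding_sketch(t, 8) for t in (0, 500, 1000)]
```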

class torchwm.STTransformer(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, gradient_checkpointing=False)[source]#

Bases: Module

Spatiotemporal Transformer for video modeling.

Contains depth spatiotemporal blocks, each interleaving spatial attention (over the patches of a frame) with temporal attention (across frames).

Parameters:
  • num_frames (int)

  • num_patches_per_frame (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • qk_scale (float | None)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

  • norm_layer (type[Module])

  • gradient_checkpointing (bool)

forward(x)[source]#
Parameters:

x (Tensor) – Token sequence of shape (B, T*N, C), where T is num_frames, N is num_patches_per_frame, and C is dim

Returns:

(B, T*N, C)

Return type:

Tensor
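Factorized spatiotemporal attention is easiest to understand through which tokens attend to which. The sketch assumes the T*N tokens are stored frame-major (token index = t * N + n). That fits the (B, T*N, C) layout but is not confirmed here. The sketch also does not model any causal masking torchwm may apply in the temporal step. Finally, it counts query-key pairs to show why factorizing is cheaper than joint attention over all T*N tokens.

```python
from typing import Dict, List


def spatial_groups(num_frames: int, num_patches: int) -> List[List[int]]:
    """Token indices that attend to each other in spatial attention: one group per frame."""
    return [[t * num_patches + n for n in range(num_patches)] for t in range(num_frames)]


def temporal_groups(num_frames: int, num_patches: int) -> List[List[int]]:
    """Token indices for temporal attention: one group per patch position, across frames."""
    return [[t * num_patches + n for t in range(num_frames)] for n in range(num_patches)]


def attention_pairs(num_frames: int, num_patches: int) -> Dict[str, int]:
    """Query-key pairs for joint attention vs. one spatial plus one temporal pass."""
    total = num_frames * num_patches
    return {
        "joint": total * total,
        "factorized": num_frames * num_patches ** 2 + num_patches * num_frames ** 2,
    }


print(attention_pairs(16, 256))  # {'joint': 16777216, 'factorized': 1114112}
```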

class torchwm.MultiHeadSelfAttention(d, n_heads=2)[source]#

Bases: Module

Multi-head scaled dot-product self-attention over sequence tokens.

This module projects the input sequence into query/key/value heads, performs attention independently per head, and merges the heads back into the original feature dimension. It is used as a lightweight transformer attention block.

Parameters:
  • d (int)

  • n_heads (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor
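The split, attend and merge pattern described above can be sketched in plain Python. For brevity, this version uses identity query/key/value projections and no output projection, whereas the real module learns linear projections. The function names are hypothetical.

```python
import math
from typing import List, Sequence

Tokens = List[List[float]]  # seq_len x feature_dim


def softmax(xs: Sequence[float]) -> List[float]:
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q: Tokens, k: Tokens, v: Tokens) -> Tokens:
    """softmax(q k^T / sqrt(head_dim)) v, computed row by row."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        weights = softmax([scale * sum(a * b for a, b in zip(qi, kj)) for kj in k])
        out.append([sum(weights[j] * v[j][c] for j in range(len(v))) for c in range(len(v[0]))])
    return out


def split_heads(x: Tokens, n_heads: int) -> List[Tokens]:
    d = len(x[0])
    if d % n_heads != 0:
        raise ValueError("feature dim must be divisible by n_heads")
    hd = d // n_heads
    return [[tok[h * hd:(h + 1) * hd] for tok in x] for h in range(n_heads)]


def merge_heads(heads: List[Tokens]) -> Tokens:
    """Concatenate per-head outputs back into the original feature dimension."""
    return [sum((head[i] for head in heads), []) for i in range(len(heads[0]))]


def multi_head_self_attention(x: Tokens, n_heads: int = 2) -> Tokens:
    # Identity Q/K/V projections: each head attends over its own slice of the features.
    return merge_heads([scaled_dot_product_attention(h, h, h) for h in split_heads(x, n_heads)])
```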

torchwm.MultiHeadAttention#

alias of MultiHeadSelfAttention

class torchwm.AdaLNNormalization(d_model, t_dim)[source]#

Bases: Module

Adaptive layer normalization conditioned on an external embedding.

The module applies RMS normalization and predicts per-channel scale/shift from a conditioning vector (for example diffusion timestep embeddings).

Parameters:
  • d_model (int)

  • t_dim (int)

forward(x, t_emb)[source]#
Parameters:
  • x (Tensor)

  • t_emb (Tensor)

Return type:

Tensor

class torchwm.RMSNorm(dim, eps=1e-06)[source]#

Bases: Module

Root Mean Square Layer Normalization with a learned gain parameter.

RMSNorm rescales activations using their RMS magnitude without centering, providing a lightweight normalization alternative to LayerNorm.

Parameters:
  • dim (int)

  • eps (float)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor
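RMSNorm and the adaptive variant used by AdaLNNormalization can be sketched on a single feature vector. This illustration rests on two assumptions:

- The conditioning vector is mapped to (scale, shift) by a single linear layer.
- The modulation uses the DiT-style form norm(x) * (1 + scale) + shift.

torchwm's exact projection (for example, any activation before it) is not documented on this page.

```python
import math
from typing import List, Sequence


def rms_norm(x: Sequence[float], gain: Sequence[float], eps: float = 1e-6) -> List[float]:
    """y_i = x_i / sqrt(mean(x^2) + eps) * g_i (no mean subtraction, unlike LayerNorm)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * g for v, g in zip(x, gain)]


def ada_rms_norm(x: Sequence[float], t_emb: Sequence[float],
                 weight: Sequence[Sequence[float]], bias: Sequence[float],
                 eps: float = 1e-6) -> List[float]:
    """Adaptive RMS norm where a linear map of t_emb predicts per-channel (scale, shift).

    weight has 2 * len(x) rows of length len(t_emb): the first len(x) outputs are the
    scale, the rest the shift.
    """
    d = len(x)
    proj = [sum(w * e for w, e in zip(row, t_emb)) + b for row, b in zip(weight, bias)]
    scale, shift = proj[:d], proj[d:]
    normed = rms_norm(x, [1.0] * d, eps)
    return [n * (1.0 + s) + h for n, s, h in zip(normed, scale, shift)]
```

With an all-zero projection, the adaptive version reduces to plain RMSNorm with unit gain, which is why this zero initialization is often used for conditioning layers.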

class torchwm.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]#

Bases: object

Model-predictive controller using Cross-Entropy Method (CEM) with RSSM.

Plans actions by optimizing a sequence of future actions in the RSSM’s latent space with the Cross-Entropy Method (CEM). It samples candidate action sequences, rolls them forward in latent space, scores their predicted returns, and refits a Gaussian proposal to the top-performing candidates.

Algorithm:
  1. Initialize Gaussian distribution over action sequences

  2. Sample N candidate action sequences

  3. Rollout each sequence in RSSM latent space

  4. Score by predicted cumulative rewards

  5. Keep top K candidates, fit Gaussian to them

  6. Repeat for T iterations

  7. Execute first action from best sequence
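The algorithm steps above can be traced in a small, dependency-free sketch. It differs from the real policy in three labeled ways:

- `score_fn` is a hypothetical stand-in for rolling a sequence through the RSSM and summing predicted rewards.
- The Gaussian proposal is diagonal.
- `min_std` is an added floor that stops the proposal from collapsing.

As in step 7, only the first action of the best sequence from the final iteration is returned.

```python
import random
import statistics
from typing import Callable, List

ActionSeq = List[List[float]]  # horizon x action_dim


def cem_plan(score_fn: Callable[[ActionSeq], float], horizon: int, action_dim: int,
             num_candidates: int, top_candidates: int, iterations: int,
             rng: random.Random, min_std: float = 1e-3) -> List[float]:
    """Cross-Entropy Method over action sequences; returns the first action to execute."""
    # 1. Start from a zero-mean, unit-variance Gaussian over each action entry.
    mean = [[0.0] * action_dim for _ in range(horizon)]
    std = [[1.0] * action_dim for _ in range(horizon)]
    best: ActionSeq = mean
    for _ in range(iterations):  # 6. repeat
        # 2. Sample candidate sequences from the current proposal.
        candidates = [
            [[rng.gauss(mean[h][a], std[h][a]) for a in range(action_dim)] for h in range(horizon)]
            for _ in range(num_candidates)
        ]
        # 3-4. Score each sequence (stand-in for an RSSM rollout and reward sum).
        ranked = sorted(candidates, key=score_fn, reverse=True)
        elites = ranked[:top_candidates]
        best = ranked[0]
        # 5. Refit the Gaussian to the elite sequences.
        for h in range(horizon):
            for a in range(action_dim):
                values = [seq[h][a] for seq in elites]
                mean[h][a] = statistics.fmean(values)
                std[h][a] = max(statistics.pstdev(values), min_std)
    # 7. Execute only the first action of the best sequence.
    return best[0]


def toy_score(seq: ActionSeq) -> float:
    """Toy return that peaks when every action equals 0.5."""
    return -sum((a - 0.5) ** 2 for step in seq for a in step)


first_action = cem_plan(toy_score, horizon=3, action_dim=1, num_candidates=500,
                        top_candidates=50, iterations=8, rng=random.Random(0))
```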

Parameters:
  • model (Any)

  • planning_horizon (int)

  • num_candidates (int)

  • num_iterations (int)

  • top_candidates (int)

  • device (device | str)

rssm#

The RSSM world model.

N#

Number of candidate action sequences to sample.

K#

Number of top candidates to use for updating the proposal.

T#

Number of CEM iterations per planning step.

H#

Planning horizon (number of future steps to consider).

d#

Action dimensionality.

device#

Device to run computations on.

state_size#

Hidden state dimensionality.

latent_size#

Latent state dimensionality.

Example

>>> policy = RSSMPolicy(
...     model=rssm,
...     planning_horizon=12,
...     num_candidates=1000,
...     num_iterations=5,
...     top_candidates=100,
...     device='cuda'
... )
>>> policy.reset()
>>> action = policy.poll(observation)
reset()[source]#

Reset the policy state.

Initializes the hidden state, latent state, and action to zeros. Should be called at the beginning of each episode.

Return type:

None

poll(observation, explore=False)[source]#

Get action for given observation.

Parameters:
  • observation (Tensor) – Current observation tensor of shape (channels, height, width).

  • explore (bool) – If True, add exploration noise to the selected action.

Returns:

Action tensor of shape (1, action_size).

Return type:

Tensor

class torchwm.IRISActor(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Actor network for the IRIS (Imagination with auto-Regression over an Inner Speech) policy.

Takes reconstructed frames as input and outputs action logits for policy control. Uses a CNN feature extractor followed by an LSTM for temporal processing. Supports a burn-in mechanism for initializing the hidden state with context frames.

Architecture:
  • CNN: Extracts features from input frames (3x64x64 -> 512)

  • LSTM: Processes temporal sequences with configurable layers

  • Linear: Maps hidden states to action logits

Parameters:
  • action_size (int) – Number of discrete actions.

  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

action_size#

Number of discrete actions.

Type:

int

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

forward(frames, hidden_state=None, burn_in_frames=None)[source]#

Forward pass through actor.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W) or (B, C, H, W)

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple for LSTM state

  • burn_in_frames (Tensor | None) – Frames to use for initializing hidden state

Returns:

A tuple (action_logits, hidden_state), where action_logits has shape (B, T, action_size), or (B, action_size) for single-frame input, and hidden_state is the updated (h, c) tuple.

Return type:

tuple

init_hidden_state(batch_size, device)[source]#

Initialize LSTM hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

get_action(frame, temperature=1.0, deterministic=False)[source]#

Get action from a single frame.

Parameters:
  • frame (Tensor) – Single frame (B, C, H, W)

  • temperature (float) – Softmax temperature (higher = more random)

  • deterministic (bool) – If True, return argmax; else sample

Returns:

Selected action indices of shape (B,).

Return type:

Tensor
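Temperature-scaled sampling versus argmax, as described for get_action, can be sketched as follows. ActorCriticNetwork's deterministic flag is documented the same way. The sketch assumes the temperature divides the logits before the softmax, the usual convention, so larger values flatten the distribution. The helper names are hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax of logits / temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    top = max(scaled)
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def select_action(logits: Sequence[float], temperature: float = 1.0,
                  deterministic: bool = False,
                  rng: Optional[random.Random] = None) -> int:
    """Argmax when deterministic, otherwise sample from softmax(logits / temperature)."""
    if deterministic:
        return max(range(len(logits)), key=lambda i: logits[i])
    probs = softmax(logits, temperature)
    rng = rng or random.Random()
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```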

class torchwm.IRISCritic(hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Critic network for IRIS value estimation.

Estimates the value function for given frame sequences. Shares the CNN feature extractor and LSTM backbone with the actor for efficiency, but has a separate value head for estimating expected cumulative rewards.

Architecture:
  • CNN: Shared feature extractor with actor (3x64x64 -> 512)

  • LSTM: Temporal processing with same architecture as actor

  • Linear: Maps hidden states to scalar values

Parameters:
  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

Returns:

From forward(): a tuple (values, hidden_state), where values has shape (B, T) and hidden_state is the updated LSTM (h, c) tuple.

Return type:

tuple

Parameters:
  • hidden_size (int)

  • num_layers (int)

  • frame_shape (Tuple[int, int, int])

forward(frames, hidden_state=None)[source]#

Forward pass through critic.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W)

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple

Returns:

A tuple (values, hidden_state), where values has shape (B, T) and hidden_state is the updated (h, c) tuple.

Return type:

tuple

init_hidden_state(batch_size, device)[source]#

Initialize LSTM hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

class torchwm.IRISPolicy(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Combined policy module for IRIS (Imagination with auto-Regression over an Inner Speech).

Provides a unified interface for actor-only or actor-critic policies. Used in the IRIS algorithm where the actor generates actions from reconstructed frames and the critic estimates value functions for training.

Parameters:
  • action_size (int) – Number of discrete actions.

  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

actor#

The actor network for action selection.

Type:

IRISActor

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

Example

>>> policy = IRISPolicy(
...     action_size=18,
...     hidden_size=512,
...     num_layers=4,
...     frame_shape=(3, 64, 64)
... )
>>> action = policy.act(frame, temperature=1.0, deterministic=False)
forward(frames)[source]#

Get action logits from frames.

Parameters:

frames (Tensor)

Return type:

Tensor

act(frame, temperature=1.0, deterministic=False)[source]#

Sample action from policy.

Parameters:
  • frame (Tensor)

  • temperature (float)

  • deterministic (bool)

Return type:

Tensor

init_hidden(batch_size, device)[source]#

Initialize hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

tuple[Tensor, Tensor]

class torchwm.CNNFeatureExtractor(frame_shape=(3, 64, 64), output_size=512)[source]#

Bases: Module

CNN feature extractor shared between actor and critic networks.

Processes input frames through four stride-2 convolutional layers, each followed by a ReLU, to produce fixed-size feature vectors. A final linear layer projects the flattened features to output_size.

Architecture:
  • Conv layers: 32 -> 64 -> 128 -> 256 channels

  • Each conv has stride=2 for spatial downsampling

  • Final linear layer projects to desired output dimension
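The spatial size after the four stride-2 convolutions follows the standard convolution output formula. The sketch assumes kernel_size=4 and padding=1, which halve each side exactly. The page does not state torchwm's kernel size or padding, so treat the flattened size as illustrative.

```python
from typing import Tuple


def conv2d_output_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Standard convolution output length: floor((size + 2 * padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def feature_map_shape(frame_shape: Tuple[int, int, int] = (3, 64, 64),
                      channels: Tuple[int, ...] = (32, 64, 128, 256),
                      kernel: int = 4, stride: int = 2, padding: int = 1) -> Tuple[int, int, int]:
    """(C, H, W) after the conv stack; kernel=4 and padding=1 are assumptions."""
    _, h, w = frame_shape
    for _ in channels:
        h = conv2d_output_size(h, kernel, stride, padding)
        w = conv2d_output_size(w, kernel, stride, padding)
    return channels[-1], h, w


c, h, w = feature_map_shape()
print((c, h, w), c * h * w)  # (256, 4, 4) 4096 -> the linear layer maps 4096 to output_size
```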

Parameters:
  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

  • output_size (int) – Size of output feature vector (default: 512).

frame_shape#

Input frame shape.

Type:

tuple

output_size#

Output feature dimension.

Type:

int

Returns:

From forward(): feature vectors of shape (B, output_size).

Return type:

Tensor

Parameters:
  • frame_shape (Tuple[int, int, int])

  • output_size (int)

forward(x)[source]#

Extract features from frames.

Parameters:

x (Tensor) – Frames (B, C, H, W)

Returns:

Feature vectors of shape (B, output_size).

Return type:

Tensor

class torchwm.DreamerConfig(env_backend='dmc', env='walker-walk', env_instance=None, image_size=(64, 64), gym_render_mode='rgb_array', dmlab_action_repeat=4, dmlab_action_set=None, dmlab_observations=None, dmlab_config=None, dmlab_renderer='hardware', procgen_distribution_mode='easy', procgen_num_levels=0, procgen_start_level=None, mujoco_xml_path=None, mujoco_xml_string=None, mujoco_binary_path=None, mujoco_camera=None, mujoco_frame_skip=1, mujoco_reset_noise_scale=0.0, brax_backend='generalized', brax_jit=True, brax_auto_reset=False, brax_suppress_warp_warnings=True, unity_file_name=None, unity_behavior_name=None, unity_worker_id=0, unity_base_port=5005, unity_no_graphics=True, unity_time_scale=20.0, unity_quality_level=1, algo='Dreamerv1', exp_name='lr1e-3', train=True, evaluate=False, seed=1, no_gpu=False, max_episode_length=1000, buffer_size=800000, time_limit=1000, cnn_activation_function='relu', dense_activation_function='elu', obs_embed_size=1024, num_units=400, deter_size=200, stoch_size=30, action_repeat=2, action_noise=0.3, total_steps=5000000, seed_steps=5000, update_steps=100, collect_steps=1000, batch_size=50, train_seq_len=50, imagine_horizon=15, use_disc_model=False, free_nats=3.0, discount=0.99, td_lambda=0.95, kl_loss_coeff=1.0, kl_alpha=0.8, disc_loss_coeff=10.0, num_buckets=255, symlog_range=10.0, model_learning_rate=0.0006, actor_learning_rate=8e-05, value_learning_rate=8e-05, adam_epsilon=1e-07, grad_clip_norm=100.0, use_amp=True, test=False, test_interval=10000, test_episodes=10, scalar_freq=1000, log_video_freq=-1, max_videos_to_save=2, video_format='gif', video_fps=20, checkpoint_interval=10000, checkpoint_path='', restore=False, experience_replay='', render=False, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', log_dir='runs', logdir=None, data_dir=None, log_level='INFO', log_file=None, enable_tensorboard=False, enable_console_metrics=True, enable_jsonl=True, jsonl_filename='metrics.jsonl', 
log_system_stats_freq=1000, detect_anomaly=False)[source]#

Bases: SerializableConfigMixin

Configuration container for Dreamer training, evaluation, and environment setup.

This class centralizes environment backend selection (DMC/DMLab/Gym/MuJoCo/Robotics/Unity/Brax), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.

Parameters:
  • env_backend (str)

  • env (str)

  • env_instance (Any)

  • image_size (tuple[int, int])

  • gym_render_mode (str)

  • dmlab_action_repeat (int)

  • dmlab_action_set (Any)

  • dmlab_observations (Any)

  • dmlab_config (Any)

  • dmlab_renderer (str)

  • procgen_distribution_mode (str)

  • procgen_num_levels (int)

  • procgen_start_level (Any)

  • mujoco_xml_path (Any)

  • mujoco_xml_string (Any)

  • mujoco_binary_path (Any)

  • mujoco_camera (Any)

  • mujoco_frame_skip (int)

  • mujoco_reset_noise_scale (float)

  • brax_backend (str)

  • brax_jit (bool)

  • brax_auto_reset (bool)

  • brax_suppress_warp_warnings (bool)

  • unity_file_name (Any)

  • unity_behavior_name (Any)

  • unity_worker_id (int)

  • unity_base_port (int)

  • unity_no_graphics (bool)

  • unity_time_scale (float)

  • unity_quality_level (int)

  • algo (str)

  • exp_name (str)

  • train (bool)

  • evaluate (bool)

  • seed (int)

  • no_gpu (bool)

  • max_episode_length (int)

  • buffer_size (int)

  • time_limit (int)

  • cnn_activation_function (str)

  • dense_activation_function (str)

  • obs_embed_size (int)

  • num_units (int)

  • deter_size (int)

  • stoch_size (int)

  • action_repeat (int)

  • action_noise (float)

  • total_steps (int)

  • seed_steps (int)

  • update_steps (int)

  • collect_steps (int)

  • batch_size (int)

  • train_seq_len (int)

  • imagine_horizon (int)

  • use_disc_model (bool)

  • free_nats (float)

  • discount (float)

  • td_lambda (float)

  • kl_loss_coeff (float)

  • kl_alpha (float)

  • disc_loss_coeff (float)

  • num_buckets (int)

  • symlog_range (float)

  • model_learning_rate (float)

  • actor_learning_rate (float)

  • value_learning_rate (float)

  • adam_epsilon (float)

  • grad_clip_norm (float)

  • use_amp (bool)

  • test (bool)

  • test_interval (int)

  • test_episodes (int)

  • scalar_freq (int)

  • log_video_freq (int)

  • max_videos_to_save (int)

  • video_format (str)

  • video_fps (int)

  • checkpoint_interval (int)

  • checkpoint_path (str)

  • restore (bool)

  • experience_replay (str)

  • render (bool)

  • enable_wandb (bool)

  • wandb_api_key (str)

  • wandb_project (str)

  • wandb_entity (str)

  • log_dir (str)

  • logdir (Any)

  • data_dir (Any)

  • log_level (str)

  • log_file (Any)

  • enable_tensorboard (bool)

  • enable_console_metrics (bool)

  • enable_jsonl (bool)

  • jsonl_filename (str)

  • log_system_stats_freq (int)

  • detect_anomaly (bool)

env_backend: str = 'dmc'#
env: str = 'walker-walk'#
env_instance: Any = None#
image_size: tuple[int, int] = (64, 64)#
gym_render_mode: str = 'rgb_array'#
dmlab_action_repeat: int = 4#
dmlab_action_set: Any = None#
dmlab_observations: Any = None#
dmlab_config: Any = None#
dmlab_renderer: str = 'hardware'#
procgen_distribution_mode: str = 'easy'#
procgen_num_levels: int = 0#
procgen_start_level: Any = None#
mujoco_xml_path: Any = None#
mujoco_xml_string: Any = None#
mujoco_binary_path: Any = None#
mujoco_camera: Any = None#
mujoco_frame_skip: int = 1#
mujoco_reset_noise_scale: float = 0.0#
brax_backend: str = 'generalized'#
brax_jit: bool = True#
brax_auto_reset: bool = False#
brax_suppress_warp_warnings: bool = True#
unity_file_name: Any = None#
unity_behavior_name: Any = None#
unity_worker_id: int = 0#
unity_base_port: int = 5005#
unity_no_graphics: bool = True#
unity_time_scale: float = 20.0#
unity_quality_level: int = 1#
algo: str = 'Dreamerv1'#
exp_name: str = 'lr1e-3'#
train: bool = True#
evaluate: bool = False#
seed: int = 1#
no_gpu: bool = False#
max_episode_length: int = 1000#
buffer_size: int = 800000#
time_limit: int = 1000#
cnn_activation_function: str = 'relu'#
dense_activation_function: str = 'elu'#
obs_embed_size: int = 1024#
num_units: int = 400#
deter_size: int = 200#
stoch_size: int = 30#
action_repeat: int = 2#
action_noise: float = 0.3#
total_steps: int = 5000000#
seed_steps: int = 5000#
update_steps: int = 100#
collect_steps: int = 1000#
batch_size: int = 50#
train_seq_len: int = 50#
imagine_horizon: int = 15#
use_disc_model: bool = False#
free_nats: float = 3.0#
discount: float = 0.99#
td_lambda: float = 0.95#
kl_loss_coeff: float = 1.0#
kl_alpha: float = 0.8#
disc_loss_coeff: float = 10.0#
num_buckets: int = 255#
symlog_range: float = 10.0#
model_learning_rate: float = 0.0006#
actor_learning_rate: float = 8e-05#
value_learning_rate: float = 8e-05#
adam_epsilon: float = 1e-07#
grad_clip_norm: float = 100.0#
use_amp: bool = True#
test: bool = False#
test_interval: int = 10000#
test_episodes: int = 10#
scalar_freq: int = 1000#
log_video_freq: int = -1#
max_videos_to_save: int = 2#
video_format: str = 'gif'#
video_fps: int = 20#
checkpoint_interval: int = 10000#
checkpoint_path: str = ''#
restore: bool = False#
experience_replay: str = ''#
render: bool = False#
enable_wandb: bool = False#
wandb_api_key: str = ''#
wandb_project: str = 'torchwm'#
wandb_entity: str = ''#
log_dir: str = 'runs'#
logdir: Any = None#
data_dir: Any = None#
log_level: str = 'INFO'#
log_file: Any = None#
enable_tensorboard: bool = False#
enable_console_metrics: bool = True#
enable_jsonl: bool = True#
jsonl_filename: str = 'metrics.jsonl'#
log_system_stats_freq: int = 1000#
detect_anomaly: bool = False#
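discount, td_lambda and imagine_horizon parameterize the lambda-returns that Dreamer-style agents compute over imagined trajectories. The sketch below assumes the common backward recursion G_t = r_t + discount * ((1 - lambda) * V_{t+1} + lambda * G_{t+1}), bootstrapped from the last value estimate. It ignores episode-termination masks and is not taken from the torchwm source.

```python
from typing import List, Sequence


def lambda_returns(rewards: Sequence[float], values: Sequence[float],
                   discount: float = 0.99, lam: float = 0.95) -> List[float]:
    """TD(lambda) targets for an imagined trajectory of length H.

    rewards[t] is r_t for t in [0, H); values has length H + 1, and values[H]
    is the bootstrap value at the end of the imagination horizon.
    """
    horizon = len(rewards)
    if len(values) != horizon + 1:
        raise ValueError("values must have one more entry than rewards")
    returns = [0.0] * horizon
    next_return = values[horizon]  # G_H = V_H
    for t in reversed(range(horizon)):
        next_return = rewards[t] + discount * ((1 - lam) * values[t + 1] + lam * next_return)
        returns[t] = next_return
    return returns


# lam = 1 gives a discounted sum plus bootstrap; lam = 0 gives one-step TD targets.
print(lambda_returns([1.0, 1.0], [0.0, 0.0, 10.0], discount=0.5, lam=1.0))  # [4.0, 6.0]
```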
class torchwm.JEPAConfig[source]#

Bases: SerializableConfigMixin

Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.

to_dict()[source]#
Return type:

Dict[str, Dict[str, Any]]

classmethod from_dict(values)[source]#

Load flat field values or the nested trainer dictionary.

Parameters:

values (Dict[str, Any])

Return type:

JEPAConfig

to_train_dict()[source]#

Return the nested dictionary expected by train_jepa.

Return type:

Dict[str, Dict[str, Any]]

to_nested_dict()[source]#

Backward-compatible alias for the nested JEPA trainer dictionary.

Return type:

Dict[str, Dict[str, Any]]

class torchwm.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#

Bases: SerializableConfigMixin

Default configuration values for Diffusion Transformer (DiT) training.

The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.

Field names use UPPER_CASE for backward compatibility with the original DiT codebase. Snake-case aliases are accepted via __getattr__ and get_dit_config().

Parameters:
  • DATASET (str)

  • BATCH (int)

  • EPOCHS (int)

  • LR (float)

  • IMG_SIZE (int)

  • CHANNELS (int)

  • PATCH (int)

  • WIDTH (int)

  • DEPTH (int)

  • HEADS (int)

  • DROP (float)

  • BETA_START (float)

  • BETA_END (float)

  • TIMESTEPS (int)

  • EMA (bool)

  • EMA_DECAY (float)

  • WORKDIR (str)

  • ROOT_PATH (str)

DATASET: str = 'CIFAR10'#
BATCH: int = 128#
EPOCHS: int = 3#
LR: float = 0.0002#
IMG_SIZE: int = 32#
CHANNELS: int = 3#
PATCH: int = 4#
WIDTH: int = 384#
DEPTH: int = 6#
HEADS: int = 6#
DROP: float = 0.1#
BETA_START: float = 0.0001#
BETA_END: float = 0.02#
TIMESTEPS: int = 1000#
EMA: bool = True#
EMA_DECAY: float = 0.999#
WORKDIR: str = './dit_demo'#
ROOT_PATH: str = './data'#
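Several DiTConfig fields must be mutually consistent. The sketch uses a plain dict holding the defaults listed above, not a DiTConfig instance. It computes the patch-token count, the values per patch, and the per-head width. Whether DiTConfig validates these constraints itself is not stated here.

```python
from typing import Dict

# Copied from the DiTConfig defaults listed above.
DIT_DEFAULTS = {"IMG_SIZE": 32, "CHANNELS": 3, "PATCH": 4, "WIDTH": 384, "HEADS": 6}


def dit_derived_sizes(cfg: Dict[str, int]) -> Dict[str, int]:
    """Quantities implied by a DiT config: token count, values per patch, per-head width."""
    if cfg["IMG_SIZE"] % cfg["PATCH"] != 0:
        raise ValueError("IMG_SIZE must be divisible by PATCH")
    if cfg["WIDTH"] % cfg["HEADS"] != 0:
        raise ValueError("WIDTH must be divisible by HEADS")
    grid = cfg["IMG_SIZE"] // cfg["PATCH"]
    return {
        "tokens": grid * grid,                                # sequence length per image
        "patch_values": cfg["PATCH"] ** 2 * cfg["CHANNELS"],  # pixel values per token
        "head_dim": cfg["WIDTH"] // cfg["HEADS"],             # attention width per head
    }


print(dit_derived_sizes(DIT_DEFAULTS))  # {'tokens': 64, 'patch_values': 48, 'head_dim': 64}
```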
torchwm.get_dit_config(**overrides)[source]#

Returns a DiTConfig instance with default values overridden by the provided keyword arguments.

Both UPPER_CASE and snake_case override keys are accepted.

Example usage:

cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)
cfg = get_dit_config(batch=64, epochs=10, lr=1e-3)  # equivalent snake_case keys

Parameters:

overrides (Any)

Return type:

DiTConfig

class torchwm.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', seed: int = 0, obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, data_loader_num_workers: int = 4, pin_memory: bool = True, persistent_workers: bool = True, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, use_amp: bool = True, device: str = <factory>, log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, operator_state_dim: int = 32, operator_action_dim: int = 4)[source]#

Bases: SerializableConfigMixin

Parameters:
  • preset (str | None)

  • game (str)

  • seed (int)

  • obs_size (int)

  • frameskip (int)

  • max_noop (int)

  • terminate_on_life_loss (bool)

  • reward_clip (List[int])

  • num_conditioning_frames (int)

  • diffusion_channels (List[int])

  • diffusion_res_blocks (int)

  • diffusion_cond_dim (int)

  • sigma_data (float)

  • sigma_min (float)

  • sigma_max (float)

  • rho (int)

  • p_mean (float)

  • p_std (float)

  • sampling_method (str)

  • num_sampling_steps (int)

  • reward_channels (List[int])

  • reward_res_blocks (int)

  • reward_cond_dim (int)

  • reward_lstm_dim (int)

  • burn_in_length (int)

  • actor_channels (List[int])

  • actor_res_blocks (int)

  • actor_lstm_dim (int)

  • num_epochs (int)

  • training_steps_per_epoch (int)

  • batch_size (int)

  • environment_steps_per_epoch (int)

  • epsilon_greedy (float)

  • data_loader_num_workers (int)

  • pin_memory (bool)

  • persistent_workers (bool)

  • imagination_horizon (int)

  • discount_factor (float)

  • entropy_weight (float)

  • lambda_returns (float)

  • learning_rate (float)

  • adam_epsilon (float)

  • weight_decay_diffusion (float)

  • weight_decay_reward (float)

  • weight_decay_actor (float)

  • use_amp (bool)

  • device (str)

  • log_interval (int)

  • eval_interval (int)

  • save_interval (int)

  • operator_state_dim (int)

  • operator_action_dim (int)

preset: str | None = None#
game: str = 'Breakout-v5'#
seed: int = 0#
obs_size: int = 64#
frameskip: int = 4#
max_noop: int = 30#
terminate_on_life_loss: bool = True#
reward_clip: List[int]#
num_conditioning_frames: int = 4#
diffusion_channels: List[int]#
diffusion_res_blocks: int = 2#
diffusion_cond_dim: int = 256#
sigma_data: float = 0.5#
sigma_min: float = 0.002#
sigma_max: float = 80.0#
rho: int = 7#
p_mean: float = -0.4#
p_std: float = 1.2#
sampling_method: str = 'euler'#
num_sampling_steps: int = 3#
reward_channels: List[int]#
reward_res_blocks: int = 2#
reward_cond_dim: int = 128#
reward_lstm_dim: int = 512#
burn_in_length: int = 4#
actor_channels: List[int]#
actor_res_blocks: int = 1#
actor_lstm_dim: int = 512#
num_epochs: int = 1000#
training_steps_per_epoch: int = 400#
batch_size: int = 32#
environment_steps_per_epoch: int = 100#
epsilon_greedy: float = 0.01#
data_loader_num_workers: int = 4#
pin_memory: bool = True#
persistent_workers: bool = True#
imagination_horizon: int = 15#
discount_factor: float = 0.985#
entropy_weight: float = 0.001#
lambda_returns: float = 0.95#
learning_rate: float = 0.0001#
adam_epsilon: float = 1e-08#
weight_decay_diffusion: float = 0.01#
weight_decay_reward: float = 0.01#
weight_decay_actor: float = 0.0#
use_amp: bool = True#
device: str#
log_interval: int = 10#
eval_interval: int = 50#
save_interval: int = 100#
operator_state_dim: int = 32#
operator_action_dim: int = 4#
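sigma_data, sigma_min, sigma_max, rho, p_mean and p_std share their names with the EDM diffusion hyperparameters (Karras et al., 2022), on which DIAMOND builds. Assuming DiamondConfig uses them in that sense, the sketch below computes three things:

- the Karras sampling noise levels for num_sampling_steps,
- the EDM preconditioning coefficients,
- a log-normal training noise level.

The helper names are hypothetical.

```python
import math
import random
from typing import Dict, List


def karras_sigmas(num_steps: int, sigma_min: float = 0.002, sigma_max: float = 80.0,
                  rho: float = 7.0) -> List[float]:
    """Noise levels from sigma_max down to sigma_min, spaced uniformly in sigma^(1/rho)."""
    if num_steps < 2:
        return [sigma_max] * num_steps
    lo, hi = sigma_min ** (1.0 / rho), sigma_max ** (1.0 / rho)
    return [(hi + i / (num_steps - 1) * (lo - hi)) ** rho for i in range(num_steps)]


def edm_preconditioning(sigma: float, sigma_data: float = 0.5) -> Dict[str, float]:
    """Coefficients for a denoiser D(x) = c_skip * x + c_out * F(c_in * x, c_noise)."""
    denom = sigma ** 2 + sigma_data ** 2
    return {
        "c_skip": sigma_data ** 2 / denom,
        "c_out": sigma * sigma_data / math.sqrt(denom),
        "c_in": 1.0 / math.sqrt(denom),
        "c_noise": math.log(sigma) / 4.0,
    }


def sample_training_sigma(rng: random.Random, p_mean: float = -0.4, p_std: float = 1.2) -> float:
    """Training noise level with ln(sigma) ~ N(p_mean, p_std^2)."""
    return math.exp(rng.gauss(p_mean, p_std))


sigmas = karras_sigmas(3)  # three Euler steps with the default config values
```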
class torchwm.IRISConfig[source]#

Bases: SerializableConfigMixin

Configuration for IRIS (Imagination with auto-Regression over an Inner Speech).

Based on the paper “Transformers are Sample-Efficient World Models”. IRIS combines a discrete autoencoder with an autoregressive Transformer for sample-efficient RL.
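IRIS's discrete autoencoder turns each frame into tokens by replacing encoder outputs with the nearest entry of a learned codebook, in the style of VQ-VAE. A minimal sketch of that lookup follows, with a toy codebook and hypothetical helper names. The encoder, straight-through gradients and codebook losses are omitted.

```python
from typing import List, Sequence


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(vectors: Sequence[Sequence[float]],
             codebook: Sequence[Sequence[float]]) -> List[int]:
    """Map each encoder output vector to the index of its nearest codebook entry (L2)."""
    return [min(range(len(codebook)), key=lambda k: squared_distance(v, codebook[k]))
            for v in vectors]


codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]
tokens = quantize([[0.9, 1.2], [0.1, -0.2], [-0.8, 1.7]], codebook)  # [1, 0, 2]
```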

get_frame_shape()[source]#
Return type:

tuple

get_autoencoder_config()[source]#
Return type:

dict

get_transformer_config()[source]#
Return type:

dict

get_rl_config()[source]#
Return type:

dict

class torchwm.GenieConfig(num_frames=8, image_size=32, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=256, action_encoder_depth=4, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#

Bases: SerializableConfigMixin

Configuration for Genie model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • tokenizer_num_heads (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_num_heads (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

  • sample_temperature (float)

  • maskgit_steps (int)

num_frames: int = 8#
image_size: int = 32#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 256#
tokenizer_decoder_dim: int = 512#
tokenizer_encoder_depth: int = 4#
tokenizer_decoder_depth: int = 8#
tokenizer_num_heads: int = 8#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 256#
action_encoder_depth: int = 4#
action_num_heads: int = 8#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 4#
learning_rate: float = 3e-05#
weight_decay: float = 0.0001#
warmup_steps: int = 5000#
max_steps: int = 125000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
sample_temperature: float = 2.0#
maskgit_steps: int = 25#
class torchwm.GenieSmallConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=512, action_encoder_depth=8, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=2, learning_rate=0.0001, weight_decay=0.0001, warmup_steps=1000, max_steps=50000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#

Bases: SerializableConfigMixin

Small configuration for development/testing.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • tokenizer_num_heads (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_num_heads (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

  • sample_temperature (float)

  • maskgit_steps (int)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 256#
tokenizer_decoder_dim: int = 512#
tokenizer_encoder_depth: int = 4#
tokenizer_decoder_depth: int = 8#
tokenizer_num_heads: int = 8#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 512#
action_encoder_depth: int = 8#
action_num_heads: int = 8#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 2#
learning_rate: float = 0.0001#
weight_decay: float = 0.0001#
warmup_steps: int = 1000#
max_steps: int = 50000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
sample_temperature: float = 2.0#
maskgit_steps: int = 25#
class torchwm.STTransformerConfig(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: SerializableConfigMixin

Configuration for Spatiotemporal Transformer.

Parameters:
  • num_frames (int)

  • num_patches_per_frame (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

num_frames: int = 16#
num_patches_per_frame: int = 256#
dim: int = 768#
depth: int = 12#
num_heads: int = 12#
mlp_ratio: float = 4.0#
qkv_bias: bool = True#
drop_rate: float = 0.0#
attn_drop_rate: float = 0.0#
drop_path_rate: float = 0.0#
class torchwm.VideoTokenizerConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False, ema_decay=0.99, commitment_weight=0.25)[source]#

Bases: SerializableConfigMixin

Configuration for Video Tokenizer.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • use_ema (bool)

  • ema_decay (float)

  • commitment_weight (float)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
encoder_dim: int = 512#
decoder_dim: int = 1024#
encoder_depth: int = 12#
decoder_depth: int = 20#
num_heads: int = 16#
patch_size: int = 4#
vocab_size: int = 1024#
embedding_dim: int = 32#
use_ema: bool = False#
ema_decay: float = 0.99#
commitment_weight: float = 0.25#
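The tokenizer's token grid follows from image_size and patch_size. Assuming non-overlapping square patches (an assumption; the patching scheme is not stated here), the defaults give (64 / 4)² = 256 tokens per frame, which matches the STTransformerConfig default num_patches_per_frame=256. The helper names below are hypothetical; this is a sketch of the arithmetic only:

```python
def tokens_per_frame(image_size: int, patch_size: int) -> int:
    """Tokens per frame for non-overlapping square patches (assumed scheme)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    side = image_size // patch_size  # patches along each spatial axis
    return side * side


def tokens_per_clip(num_frames: int, image_size: int, patch_size: int) -> int:
    """Total tokens for a clip of num_frames frames."""
    return num_frames * tokens_per_frame(image_size, patch_size)


# VideoTokenizerConfig defaults: num_frames=16, image_size=64, patch_size=4
print(tokens_per_frame(64, 4))     # 256
print(tokens_per_clip(16, 64, 4))  # 4096
```

The same arithmetic with the LatentActionModelConfig default patch_size=16 gives a much coarser 4 × 4 grid of 16 patches per frame.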
class torchwm.LatentActionModelConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=1024, encoder_depth=20, num_heads=16, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#

Bases: SerializableConfigMixin

Configuration for Latent Action Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • encoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
encoder_dim: int = 1024#
encoder_depth: int = 20#
num_heads: int = 16#
patch_size: int = 16#
vocab_size: int = 8#
embedding_dim: int = 32#
commitment_weight: float = 1.0#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
class torchwm.DynamicsModelConfig(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: SerializableConfigMixin

Configuration for Dynamics Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

num_frames: int = 16#
image_size: int = 64#
vocab_size: int = 1024#
embedding_dim: int = 32#
action_vocab_size: int = 8#
dim: int = 5120#
depth: int = 48#
num_heads: int = 36#
patch_size: int = 4#
mlp_ratio: float = 4.0#
qkv_bias: bool = True#
drop_rate: float = 0.0#
attn_drop_rate: float = 0.0#
drop_path_rate: float = 0.0#
class torchwm.OperatorABC(*, device=None)[source]#

Bases: Module, ABC

Structured base class for inference operators.

Operators use a consistent pipeline:

  1. preprocess converts raw inputs into tensors.

  2. forward performs model/operator-specific tensor computation.

  3. postprocess formats the final output mapping.

Subclasses may also declare input_specs and output_specs to validate required tensor keys, shapes, and dtypes. OperatorABC inherits from torch.nn.Module, so operators support to(device), train(), and eval() just like model modules.

Parameters:

device (torch.device | str | None)
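The three-stage contract can be illustrated without PyTorch. The sketch below is a plain-Python stand-in, not OperatorABC itself: MiniOperator and ScaleOperator are hypothetical names, mappings hold lists instead of tensors, and spec validation is omitted. It shows that only preprocess is mandatory and that process chains the stages in order.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class MiniOperator(ABC):
    """Stand-in for the preprocess -> forward -> postprocess pipeline (lists, not tensors)."""

    @abstractmethod
    def preprocess(self, inputs: Any) -> Dict[str, List[float]]:
        """Convert raw inputs into a mapping ready for forward."""

    def forward(self, inputs: Dict[str, List[float]]) -> Dict[str, List[float]]:
        # Identity by default, like the documented behaviour for
        # preprocessing-only operators.
        return inputs

    def postprocess(self, outputs: Dict[str, List[float]]) -> Dict[str, List[float]]:
        return outputs

    def process(self, inputs: Any) -> Dict[str, List[float]]:
        # Run the stages in order: preprocess, then forward, then postprocess.
        return self.postprocess(self.forward(self.preprocess(inputs)))


class ScaleOperator(MiniOperator):
    """Hypothetical operator that rescales 8-bit pixel values to [0, 1]."""

    def preprocess(self, inputs: Any) -> Dict[str, List[float]]:
        return {"pixels": [p / 255.0 for p in inputs]}


op = ScaleOperator()
print(op.process([0, 255]))  # {'pixels': [0.0, 1.0]}
```

A real subclass would inherit from torchwm.OperatorABC and return tensors, which is what makes to(device), train(), and eval() available.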

input_specs: Mapping[str, TensorSpec] = {}#
output_specs: Mapping[str, TensorSpec] = {}#
abstractmethod preprocess(inputs)[source]#

Convert raw inputs into a tensor mapping ready for forward.

Parameters:

inputs (Any)

Return type:

dict[str, Tensor]

forward(inputs)[source]#

Run tensor computation for this operator.

Preprocessing-only operators can rely on this identity implementation. Operators that wrap a model should override this method.

Parameters:

inputs (dict[str, Tensor])

Return type:

dict[str, Tensor]

postprocess(outputs)[source]#

Format validated forward outputs for consumers.

Parameters:

outputs (dict[str, Tensor])

Return type:

dict[str, Tensor]

process(inputs)[source]#

Process raw inputs through preprocess, forward, and postprocess stages.

Parameters:

inputs (Any)

Return type:

dict[str, Tensor]

batch(inputs)[source]#

Preprocess a sequence of inputs and stack matching tensor keys.

Parameters:

inputs (Sequence[Any])

Return type:

dict[str, Tensor]

to(*args, **kwargs)[source]#

Move module parameters/buffers and remember the target tensor device.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

OperatorABC

classmethod validate_mapping(values, specs, *, label)[source]#

Validate tensor keys, shapes, and dtypes against optional specs.

Parameters:
  • values (Mapping[str, Tensor])

  • specs (Mapping[str, TensorSpec])

  • label (str)

Return type:

None

class torchwm.TensorSpec(shape=None, dtype=None, required=True)[source]#

Bases: object

Optional tensor contract used to validate operator inputs or outputs.

Parameters:
  • shape (tuple[int | None, ...] | None) – Expected shape. Use None as a wildcard for dimensions that may vary, such as batch size.

  • dtype (dtype | None) – Expected tensor dtype.

  • required (bool) – Whether the key must be present in the mapping being validated.

shape: tuple[int | None, ...] | None = None#
dtype: dtype | None = None#
required: bool = True#
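The None-as-wildcard rule and the required flag can be sketched in plain Python. ShapeSpec and check_shapes are hypothetical stand-ins that work on shape tuples rather than tensors and skip dtype checks; the real validate_mapping may differ in exception type and message.

```python
from dataclasses import dataclass
from typing import Mapping, Optional, Tuple


@dataclass(frozen=True)
class ShapeSpec:
    """Hypothetical stand-in for TensorSpec; dtype checks are omitted."""
    shape: Optional[Tuple[Optional[int], ...]] = None  # None entries are wildcards
    required: bool = True


def check_shapes(shapes: Mapping[str, Tuple[int, ...]],
                 specs: Mapping[str, ShapeSpec], label: str) -> None:
    """Raise ValueError on a missing required key or a shape mismatch."""
    for key, spec in specs.items():
        if key not in shapes:
            if spec.required:
                raise ValueError(f"{label}: missing required key {key!r}")
            continue  # optional key absent: nothing to check
        if spec.shape is None:
            continue  # key present, no shape constraint
        actual = shapes[key]
        same_rank = len(actual) == len(spec.shape)
        if not same_rank or any(w is not None and w != a for w, a in zip(spec.shape, actual)):
            raise ValueError(f"{label}: {key!r} has shape {actual}, expected {spec.shape}")


specs = {"image": ShapeSpec(shape=(None, 3, 64, 64)), "mask": ShapeSpec(required=False)}
check_shapes({"image": (8, 3, 64, 64)}, specs, label="inputs")  # passes: batch dim is a wildcard
```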
class torchwm.DreamerOperator(image_size=64, action_dim=6)[source]#

Bases: OperatorABC

Operator for Dreamer model preprocessing: normalizes observations and encodes actions.

Parameters:
  • image_size (int)

  • action_dim (int)

preprocess(inputs)[source]#

Process Dreamer inputs: image observation and action.

Expected inputs: {'image': PIL.Image or tensor, 'action': tensor or list}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]
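What “normalizes observations and encodes actions” means is not spelled out here. One common convention in Dreamer implementations, assumed for this sketch only, is to scale 8-bit pixels to [-0.5, 0.5] and to one-hot encode discrete action indices. normalize_pixels and one_hot are hypothetical helpers, not DreamerOperator methods:

```python
from typing import List


def normalize_pixels(pixels: List[int]) -> List[float]:
    """Map 8-bit values in [0, 255] to [-0.5, 0.5] (assumed convention)."""
    return [p / 255.0 - 0.5 for p in pixels]


def one_hot(index: int, action_dim: int = 6) -> List[float]:
    """Encode a discrete action index as a one-hot vector of length action_dim."""
    if not 0 <= index < action_dim:
        raise ValueError(f"action index {index} out of range for action_dim={action_dim}")
    vec = [0.0] * action_dim
    vec[index] = 1.0
    return vec


print(normalize_pixels([0, 255]))  # [-0.5, 0.5]
print(one_hot(2))                  # [0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
```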

class torchwm.JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)[source]#

Bases: OperatorABC

Operator for JEPA model preprocessing: handles image/video masking and patch processing.

Parameters:
  • image_size (int)

  • patch_size (int)

  • mask_ratio (float)

preprocess(inputs)[source]#

Process JEPA inputs: images with masking.

Expected inputs: {'images': list of PIL Images or tensors, 'mask': optional tensor}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]
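With the defaults, a 224×224 image split into 16×16 patches yields 14 × 14 = 196 patches, and mask_ratio=0.75 hides 147 of them. The sketch below samples such a mask uniformly at random. Uniform sampling is an assumption made for illustration (I-JEPA, for example, uses block-shaped target masks), and sample_patch_mask is a hypothetical helper, not part of JEPAOperator.

```python
import random
from typing import List, Optional


def sample_patch_mask(image_size: int = 224, patch_size: int = 16,
                      mask_ratio: float = 0.75, seed: Optional[int] = None) -> List[bool]:
    """Return one flag per patch; True means the patch is masked."""
    side = image_size // patch_size             # 224 // 16 = 14 patches per side
    num_patches = side * side                   # 14 * 14 = 196 patches
    num_masked = int(mask_ratio * num_patches)  # int(0.75 * 196) = 147
    rng = random.Random(seed)
    masked = set(rng.sample(range(num_patches), num_masked))
    return [i in masked for i in range(num_patches)]


mask = sample_patch_mask(seed=0)
print(len(mask), sum(mask))  # 196 147
```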

class torchwm.IrisOperator(seq_length=512, vocab_size=32000)[source]#

Bases: OperatorABC

Operator for Iris transformer model: formats sequences and embeddings.

Parameters:
  • seq_length (int)

  • vocab_size (int)

preprocess(inputs)[source]#

Process Iris inputs: token sequences and optional embeddings.

Expected inputs: {'tokens': list of ints or tensor, 'embeddings': optional tensor}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

class torchwm.PlaNetOperator(state_dim=32, action_dim=4)[source]#

Bases: OperatorABC

Operator for PlaNet model preprocessing: encodes environment states and transitions.

Parameters:
  • state_dim (int)

  • action_dim (int)

preprocess(inputs)[source]#

Process PlaNet inputs: state observations and actions.

Expected inputs: {'obs': tensor or image, 'action': tensor, 'reward': float, 'done': bool}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

torchwm.get_operator(name, **kwargs)[source]#

Factory function to get inference operators by name.

Parameters:
  • name (str) – One of 'dreamer', 'jepa', 'iris', or 'planet'.

  • **kwargs (Any) – Operator-specific configuration

Returns:

Configured OperatorABC instance

Return type:

OperatorABC

Example

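>>> from PIL import Image
>>> op = get_operator('dreamer', image_size=64, action_dim=6)
>>> image, action = Image.new('RGB', (64, 64)), [0.0] * 6  # blank frame; list action (assumed layout)
>>> processed = op.process({'image': image, 'action': action})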
class torchwm.RewardModel(belief_size, state_size, hidden_size, activation_function='relu')[source]#

Bases: Module

Predict scalar rewards from Dreamer latent belief and state vectors.

Implemented as an MLP used for model-based reward supervision and imagined rollout return estimation.

Parameters:
  • belief_size (int)

  • state_size (int)

  • hidden_size (int)

  • activation_function (str)

forward(belief, state)[source]#
Parameters:
  • belief (Tensor)

  • state (Tensor)

Return type:

Tensor

class torchwm.ValueModel(belief_size, state_size, hidden_size, activation_function='relu')[source]#

Bases: Module

Estimate scalar value from Dreamer latent belief and state vectors.

This MLP is trained on imagined returns and used for actor/value updates.

Parameters:
  • belief_size (int)

  • state_size (int)

  • hidden_size (int)

  • activation_function (str)

forward(belief, state)[source]#
Parameters:
  • belief (Tensor)

  • state (Tensor)

Return type:

Tensor

torchwm.DreamerRewardModel#

alias of RewardModel

torchwm.DreamerValueModel#

alias of ValueModel

Primary modules: world_models, world_models.models, world_models.configs, world_models.catalog, world_models.envs, and world_models.inference.

TorchWM public API.

This package keeps imports lightweight while still exposing a friendly top-level surface. Common workflows can use the small factory helpers:

import torchwm

cfg = torchwm.create_config("dreamer", env="walker-walk")
agent = torchwm.create_model("dreamer", cfg)
env = torchwm.make_env("CartPole-v1", backend="gym")

Lower-level research components remain available as lazy top-level exports, for example from torchwm import DreamerAgent, ConvEncoder, ReplayBuffer.

class world_models.EnvBackendSpec(name, factory_path, description='', aliases=())[source]#

Bases: NamedTuple

Metadata describing an environment backend available through make_env.

Parameters:
  • name (str)

  • factory_path (str)

  • description (str)

  • aliases (tuple[str, ...])

name: str#

Alias for field number 0

factory_path: str#

Alias for field number 1

description: str#

Alias for field number 2

aliases: tuple[str, ...]#

Alias for field number 3

class world_models.ModelSpec(name, import_path, config_path=None, description='', aliases=())[source]#

Bases: NamedTuple

Metadata describing a model available through create_model().

Parameters:
  • name (str)

  • import_path (str)

  • config_path (str | None)

  • description (str)

  • aliases (tuple[str, ...])

name: str#

Alias for field number 0

import_path: str#

Alias for field number 1

config_path: str | None#

Alias for field number 2

description: str#

Alias for field number 3

aliases: tuple[str, ...]#

Alias for field number 4

world_models.create_config(model, **overrides)[source]#

Create the default config object for model and apply overrides.

Examples

>>> cfg = create_config("dreamer", env="walker-walk", seed=7)
>>> cfg.env
'walker-walk'
Parameters:
  • model (str)

  • overrides (Any)

Return type:

Any

world_models.create_model(model, config=None, **overrides)[source]#

Instantiate a model or agent from a simple string name.

config is optional for models that define a config class. Keyword overrides are applied to the config when possible; otherwise they are passed directly to the underlying constructor or factory.

Examples

>>> agent = create_model("dreamer", env="walker-walk", total_steps=1000)
>>> genie = create_model("genie-small", image_size=32)
Parameters:
  • model (str)

  • config (Any | None)

  • overrides (Any)

Return type:

Any
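One way to read “applied to the config when possible” is: overrides whose names match a config field update the config, and everything else is forwarded to the constructor. The sketch assumes a dataclass config; ToyConfig and split_overrides are hypothetical, and create_model may resolve overrides differently.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Dict, Tuple


@dataclass
class ToyConfig:
    """Hypothetical config with two fields."""
    env: str = "walker-walk"
    seed: int = 0


def split_overrides(config: Any, overrides: Dict[str, Any]) -> Tuple[Any, Dict[str, Any]]:
    """Apply overrides that match config fields; return the rest as constructor kwargs."""
    field_names = {f.name for f in fields(config)}
    config_updates = {k: v for k, v in overrides.items() if k in field_names}
    passthrough = {k: v for k, v in overrides.items() if k not in field_names}
    return replace(config, **config_updates), passthrough


cfg, extra = split_overrides(ToyConfig(), {"seed": 7, "device": "cpu"})
print(cfg)    # ToyConfig(env='walker-walk', seed=7)
print(extra)  # {'device': 'cpu'}
```

Using dataclasses.replace keeps the default config untouched and returns a new object, so repeated calls with different overrides do not interfere.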

world_models.get_env_backend_spec(name)[source]#

Return metadata for an environment backend name or alias.

Parameters:

name (str)

Return type:

EnvBackendSpec

world_models.get_model_spec(name)[source]#

Return metadata for a model name or alias.

Parameters:

name (str)

Return type:

ModelSpec
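Because ModelSpec carries an aliases tuple, resolving a name reduces to checking the canonical name and then the aliases. A minimal sketch, assuming case-insensitive matching; the registry contents, the ToyModelSpec and lookup names, and the KeyError behaviour are all hypothetical:

```python
from typing import NamedTuple, Optional, Tuple


class ToyModelSpec(NamedTuple):
    """Same fields as ModelSpec; illustration only."""
    name: str
    import_path: str
    config_path: Optional[str] = None
    description: str = ""
    aliases: Tuple[str, ...] = ()


# Hypothetical registry entry; real import paths and aliases may differ.
REGISTRY = [ToyModelSpec("genie-small", "example.genie:Genie", aliases=("genie_small",))]


def lookup(name: str) -> ToyModelSpec:
    """Resolve a canonical name or alias, case-insensitively."""
    key = name.strip().lower()
    for spec in REGISTRY:
        if key == spec.name or key in spec.aliases:
            return spec
    raise KeyError(f"unknown model {name!r}")


print(lookup("GENIE_SMALL").name)  # genie-small
```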

world_models.list_env_backends()[source]#

Return canonical backend names accepted by make_env().

Return type:

list[str]

world_models.list_envs(model=None)[source]#

List known environment ids, optionally filtered by model family.

Parameters:

model (str | None)

Return type:

list[str] | dict[str, list[str]]

world_models.list_models()[source]#

Return canonical model names accepted by create_model().

Return type:

list[str]

world_models.make_env(env_id, backend='auto', **kwargs)[source]#

Create an environment with a consistent TorchWM entry point.

Parameters:
  • env_id (str) – Environment id, XML path, Unity executable path, or backend-specific id.

  • backend (str) – One of list_env_backends(); "auto" tries TorchWM’s compatibility helper.

  • **kwargs (Any) – Backend-specific options.

Return type:

Any

world_models.export_any(obj, path, format='onnx', *, example_inputs=None, target=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)[source]#

Export any TorchWM model/agent or a target module contained by it.

Parameters:
  • obj (Any)

  • path (str | Path)

  • format (str)

  • example_inputs (Any | None)

  • target (str | None)

  • input_names (list[str] | None)

  • output_names (list[str] | None)

  • dynamic_axes (dict[str, dict[int, str]] | None)

  • opset_version (int)

  • kwargs (Any)

Return type:

Path

world_models.export_model(module, path, format='onnx', *, example_inputs=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)[source]#

Export a torch.nn.Module to ONNX, TorchScript, or TensorRT.

Parameters:
  • module (Module)

  • path (str | Path)

  • format (str)

  • example_inputs (Any | None)

  • input_names (list[str] | None)

  • output_names (list[str] | None)

  • dynamic_axes (dict[str, dict[int, str]] | None)

  • opset_version (int)

  • kwargs (Any)

Return type:

Path

class world_models.ExportableAgentMixin[source]#

Bases: object

Mixin for non-nn.Module agents that delegates to the shared exporter.

export(path, format='onnx', *, example_inputs=None, target=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)[source]#

Export this agent or one of its contained modules for deployment.

Parameters:
  • path (str | Path)

  • format (str)

  • example_inputs (Any | None)

  • target (str | None)

  • input_names (list[str] | None)

  • output_names (list[str] | None)

  • dynamic_axes (dict[str, dict[int, str]] | None)

  • opset_version (int)

  • kwargs (Any)

Return type:

Path

class world_models.IRISAgent(config, action_size, device)[source]#

Bases: Module

Complete IRIS Agent with world model and policy.

Combines:

  • Discrete autoencoder (encoder + decoder)

  • Transformer world model

  • Actor-critic for policy and value learning

Parameters:
  • config (IRISConfig)

  • action_size (int)

  • device (device)

classmethod from_config(config=None, *, action_size, device=None, **overrides)[source]#

Build an IRIS agent from a config object, dict, YAML file, or YAML string.

Parameters:
  • config (IRISConfig | dict[str, Any] | str | Path | None)

  • action_size (int)

  • device (device | str | None)

  • overrides (Any)

Return type:

IRISAgent

classmethod from_pretrained(pretrained_model_name_or_path, *, action_size=None, device=None, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, **overrides)[source]#

Load an IRIS agent checkpoint from a local path/directory or HF Hub.

Parameters:
  • pretrained_model_name_or_path (str | Path)

  • action_size (int | None)

  • device (device | str | None)

  • config (IRISConfig | dict[str, Any] | str | Path | None)

  • checkpoint_filename (str | None)

  • config_filename (str)

  • repo_type (str | None)

  • revision (str | None)

  • overrides (Any)

Return type:

IRISAgent

parameter_count(trainable_only=False)[source]#
Parameters:

trainable_only (bool)

Return type:

int

summary()[source]#
Return type:

dict[str, Any]

forward_actor_critic(frames, hidden=None)[source]#

Forward pass through actor-critic.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W)

  • hidden (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state

Returns:

  • action_logits – (B, T, action_size)

  • values – (B, T)

  • hidden_state – (h, c)

Return type:

Tuple[Tensor, Tensor, Tuple[Tensor, Tensor]]

act(frame, epsilon=0.0, temperature=1.0)[source]#

Sample action from policy.

Parameters:
  • frame (Tensor) – Single frame (B, C, H, W)

  • epsilon (float) – Random action probability

  • temperature (float) – Action distribution temperature

Returns:

actions – selected actions (B,)

Return type:

Tensor
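The epsilon and temperature parameters correspond to two standard exploration mechanisms. The sketch below shows one plausible way to combine them for a single frame's action logits; the combination is assumed, not taken from the IRISAgent source, and sample_action is a hypothetical helper. With probability epsilon it picks a uniformly random action; otherwise it samples from softmax(logits / temperature).

```python
import math
import random
from typing import List, Optional


def sample_action(logits: List[float], epsilon: float = 0.0, temperature: float = 1.0,
                  rng: Optional[random.Random] = None) -> int:
    """Epsilon-greedy exploration on top of temperature-scaled softmax sampling."""
    rng = rng or random.Random()
    if rng.random() < epsilon:
        return rng.randrange(len(logits))  # uniform random action
    scaled = [x / temperature for x in logits]
    top = max(scaled)
    weights = [math.exp(s - top) for s in scaled]  # subtract the max for numerical stability
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


print(sample_action([0.1, 2.0, 0.5], temperature=1e-3, rng=random.Random(0)))  # 1 (near-greedy)
```

Lower temperatures sharpen the distribution toward the argmax; higher temperatures flatten it toward uniform.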

imagine_rollout(initial_frame, horizon=20)[source]#

Generate imagined trajectories using world model.

Parameters:
  • initial_frame (Tensor) – Starting frame (B, C, H, W)

  • horizon (int) – Number of steps to imagine

Returns:

trajectory – dictionary with imagined rollout data

Return type:

dict

update_autoencoder(frames)[source]#

Update discrete autoencoder.

Parameters:

frames (Tensor) – Training frames (B, C, H, W)

Returns:

losses – dictionary of loss values

Return type:

dict

update_transformer(frames, actions, rewards, terminals)[source]#

Update transformer world model.

Parameters:
  • frames (Tensor) – Frame sequence

  • actions (Tensor) – Actions taken

  • rewards (Tensor) – Rewards received

  • terminals (Tensor) – Terminal flags

Returns:

losses – dictionary of loss values

Return type:

dict

update_actor_critic(imagined_trajectory)[source]#

Update actor-critic in imagination.

Parameters:

imagined_trajectory (dict) – Dictionary from imagine_rollout

Returns:

losses – dictionary of loss values

Return type:

dict

save(path)[source]#

Save agent state.

Parameters:

path (str)

Return type:

None

load(path)[source]#

Load agent state.

Parameters:

path (str)

Return type:

None

world_models.compute_lambda_return(rewards, values, discounts, lambda_coef=0.95)[source]#

Compute λ-return target for value function training.

Parameters:
  • rewards (Tensor) – Rewards (B, T)

  • values (Tensor) – Value estimates (B, T+1)

  • discounts (Tensor) – Discount factors (B, T)

  • lambda_coef (float) – Lambda parameter for bootstrapping

Returns:

lambda_returns – λ-return targets (B, T)

Return type:

Tensor
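The recursion behind λ-returns, written for a single sequence in plain Python. This is a sketch of the standard Dreamer-style definition, assuming values[T] is the bootstrap value and discounts[t] applies to the step from t to t + 1; the batched tensor implementation of compute_lambda_return may differ in indexing details. With λ = 1 the target is the discounted sum of rewards plus the bootstrap value; with λ = 0 it reduces to the one-step target r_t + d_t·V_{t+1}.

```python
from typing import List


def lambda_returns(rewards: List[float], values: List[float],
                   discounts: List[float], lambda_coef: float = 0.95) -> List[float]:
    """λ-returns for one sequence.

    rewards and discounts have length T; values has length T + 1, and its last
    entry bootstraps the return beyond the horizon:
        G_T = V_T
        G_t = r_t + d_t * ((1 - λ) * V_{t+1} + λ * G_{t+1})
    """
    T = len(rewards)
    if len(values) != T + 1 or len(discounts) != T:
        raise ValueError("expected len(values) == len(rewards) + 1 == len(discounts) + 1")
    returns = [0.0] * T
    next_return = values[T]  # bootstrap from the final value estimate
    for t in reversed(range(T)):
        next_return = rewards[t] + discounts[t] * (
            (1 - lambda_coef) * values[t + 1] + lambda_coef * next_return
        )
        returns[t] = next_return
    return returns


print(lambda_returns([1, 1, 1], [0, 0, 0, 5], [1, 1, 1], lambda_coef=1.0))  # [8.0, 7.0, 6.0]
```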

class world_models.ModularRSSM(encoder, decoder, backbone, reward_decoder=None)[source]#

Bases: Module

Modular RSSM with swappable encoder, decoder, and backbone.

This class makes it easy to experiment with different components:

  • Encoders: Conv, MLP, ViT

  • Decoders: Conv, MLP

  • Backbones: GRU, LSTM, Transformer

Example

>>> encoder = ConvEncoder((3, 64, 64), embed_size=1024)
>>> decoder = ConvDecoder(32, 200, (3, 64, 64))
>>> backbone = GRUBackbone(action_size=6, stoch_size=32, deter_size=200, hidden_size=200, embed_size=1024)
>>> rssm = ModularRSSM(encoder, decoder, backbone)
Parameters:
  • encoder

  • decoder

  • backbone

  • reward_decoder (optional; defaults to None)
property stoch_size: int#
property deter_size: int#
property embed_size: int#
init_state(batch_size, device)[source]#
Parameters:
  • batch_size (int)

  • device (device)

Return type:

Dict[str, Tensor]

get_dist(mean, std)[source]#
Parameters:
  • mean (Tensor)

  • std (Tensor)

Return type:

Distribution

observe_step(prev_state, prev_action, obs, nonterm=1.0)[source]#
Parameters:
  • prev_state (Dict[str, Tensor])

  • prev_action (Tensor)

  • obs (Tensor)

  • nonterm (Any)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

imagine_step(prev_state, prev_action, nonterm=1.0)[source]#
Parameters:
  • prev_state (Dict[str, Tensor])

  • prev_action (Tensor)

  • nonterm (Any)

Return type:

Dict[str, Tensor]

observe_rollout(obs, actions, nonterms, prev_state, horizon)[source]#
Parameters:
  • obs (Tensor)

  • actions (Tensor)

  • nonterms (Tensor)

  • prev_state (Dict[str, Tensor])

  • horizon (int)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

imagine_rollout(actor, prev_state, horizon)[source]#
Parameters:
  • actor (Module)

  • prev_state (Dict[str, Tensor])

  • horizon (int)

Return type:

Dict[str, Tensor]

decode_observation(features)[source]#
Parameters:

features (Tensor)

Return type:

Tensor

decode_reward(features)[source]#
Parameters:

features (Tensor)

Return type:

Tensor

detach_state(state)[source]#
Parameters:

state (Dict[str, Tensor])

Return type:

Dict[str, Tensor]

seq_to_batch(state)[source]#
Parameters:

state (Dict[str, Tensor])

Return type:

Dict[str, Tensor]

world_models.create_modular_rssm(encoder_type='conv', decoder_type='conv', backbone_type='gru', obs_shape=(3, 64, 64), action_size=6, stoch_size=32, deter_size=200, embed_size=1024, hidden_size=200, activation='elu', **kwargs)[source]#

Factory function to create a modular RSSM with specified components.

Parameters:
  • encoder_type (str) – Type of encoder (“conv”, “mlp”, “vit”)

  • decoder_type (str) – Type of decoder (“conv”, “mlp”)

  • backbone_type (str) – Type of backbone (“gru”, “lstm”, “transformer”)

  • obs_shape (Tuple[int, int, int] | Tuple[int]) – Shape of observations (C, H, W) for images or (D,) for state

  • action_size (int) – Action space dimension

  • stoch_size (int) – Stochastic latent dimension

  • deter_size (int) – Deterministic hidden dimension

  • embed_size (int) – Encoder embedding dimension

  • hidden_size (int) – Hidden layer dimension

  • activation (str) – Activation function name

  • kwargs (Any)

Returns:

Configured ModularRSSM instance

Return type:

ModularRSSM
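Because components are selected by string, a mistyped name or a mismatched obs_shape only surfaces when the factory runs. The sketch below is a hypothetical pre-flight check built from the documented choices. The rule that conv and vit encoders need (C, H, W) observations while mlp needs (D,) is an assumption drawn from the obs_shape description, not a documented constraint of create_modular_rssm.

```python
from typing import Dict, Tuple

# Component names accepted by create_modular_rssm, per its parameter descriptions.
VALID_TYPES: Dict[str, Tuple[str, ...]] = {
    "encoder": ("conv", "mlp", "vit"),
    "decoder": ("conv", "mlp"),
    "backbone": ("gru", "lstm", "transformer"),
}


def check_rssm_options(encoder_type: str, decoder_type: str, backbone_type: str,
                       obs_shape: Tuple[int, ...]) -> None:
    """Hypothetical check: valid names and an observation rank that suits the encoder."""
    for kind, value in (("encoder", encoder_type), ("decoder", decoder_type),
                        ("backbone", backbone_type)):
        if value not in VALID_TYPES[kind]:
            raise ValueError(f"unknown {kind}_type {value!r}; expected one of {VALID_TYPES[kind]}")
    # Assumed pairing: conv/vit encoders consume (C, H, W) images, mlp consumes (D,) states.
    if encoder_type in ("conv", "vit") and len(obs_shape) != 3:
        raise ValueError(f"{encoder_type} encoder expects (C, H, W) observations, got {obs_shape}")
    if encoder_type == "mlp" and len(obs_shape) != 1:
        raise ValueError(f"mlp encoder expects (D,) observations, got {obs_shape}")


check_rssm_options("conv", "conv", "gru", (3, 64, 64))  # the factory defaults pass
```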

class world_models.Genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=512, tokenizer_decoder_dim=1024, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=1024, action_decoder_dim=1024, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, encoder_depth=12, decoder_depth=20, latent_action_depth=20, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#

Bases: Module

Genie: Generative Interactive Environment.

A generative model trained from video-only data that can be used as an interactive environment. It contains three key components:

  1. Video Tokenizer: converts raw video frames into discrete tokens.

  2. Latent Action Model (LAM): infers latent actions between frames.

  3. Dynamics Model: predicts future frames given past frames and latent actions.

Based on “Genie: Generative Interactive Environments” paper (arXiv:2402.15391).

Training follows two phases, as in the paper:

  1. Train the video tokenizer first.

  2. Co-train the LAM (from pixels) and the dynamics model (on video tokens).

The LAM uses VQ-VAE training with:

  • Encoder: takes x_{1:t} and x_{t+1} and outputs latent actions.

  • Decoder: takes x_{1:t-1} (masked) plus actions and reconstructs x_t.

  • An auxiliary variance loss to prevent action collapse.

At inference, latent actions are detached (stop-gradient) before being passed to the dynamics model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_decoder_dim (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • latent_action_depth (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

classmethod from_config(config=None, **overrides)[source]#

Build Genie from a config object, dict, YAML file, or YAML string.

Parameters:
  • config – config object, dict, YAML file path, or YAML string

  • overrides (Any)
Return type:

Genie

classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, map_location=None, **overrides)[source]#

Load Genie weights from a local path/directory or HF Hub.

Parameters:
  • pretrained_model_name_or_path (str | Path)

  • config (GenieConfig | dict[str, Any] | str | Path | None)

  • checkpoint_filename (str | None)

  • config_filename (str)

  • repo_type (str | None)

  • revision (str | None)

  • map_location (str | device | None)

  • overrides (Any)

Return type:

Genie

save_pretrained(path)[source]#

Save Genie weights and config in a from_pretrained-compatible format.

Parameters:

path (str | Path)

Return type:

None

parameter_count(trainable_only=False)[source]#
Parameters:

trainable_only (bool)

Return type:

int

summary()[source]#
Return type:

dict[str, Any]

forward(video, mask_prob=0.5, training_phase='all')[source]#

Full forward pass through all components.

Parameters:
  • video (Tensor) – (B, C, T, H, W) input video

  • mask_prob (float) – Probability for random masking in dynamics (0.5-1.0)

  • training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”

Returns:

Dictionary containing losses and predictions

Return type:

Dict[str, Tensor]
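GenieConfig's mask_prob_min and mask_prob_max (0.5 and 1.0 by default) match the MaskGIT-style training scheme described in the Genie paper: draw a masking rate uniformly from that range, then mask tokens with independent Bernoulli draws. A plain-Python sketch; sample_training_mask is hypothetical, and how forward uses its single mask_prob argument is not specified here.

```python
import random
from typing import List, Optional


def sample_training_mask(num_tokens: int, mask_prob_min: float = 0.5,
                         mask_prob_max: float = 1.0,
                         rng: Optional[random.Random] = None) -> List[bool]:
    """Draw a masking rate uniformly from [mask_prob_min, mask_prob_max], then mask
    each token independently with that probability (True = masked)."""
    rng = rng or random.Random()
    rate = rng.uniform(mask_prob_min, mask_prob_max)
    return [rng.random() < rate for _ in range(num_tokens)]


mask = sample_training_mask(256, rng=random.Random(0))
print(len(mask))  # 256
```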

training_step(video, mask_prob=0.5, training_phase='all')[source]#

Single training step computing all losses.

Parameters:
  • video (Tensor) – (B, C, T, H, W) input video

  • mask_prob (float) – Probability for random masking in dynamics

  • training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”

Returns:

Dictionary containing all losses for backpropagation

Return type:

Dict[str, Tensor]

encode_video(video)[source]#

Encode video to discrete tokens.

Parameters:

video (Tensor) – (B, C, T, H, W)

Returns:

video_tokens – discrete token indices (B, T, H*W)

Return type:

Tensor

infer_actions(frames)[source]#

Infer latent actions from a sequence of frames.

Parameters:

frames (Tensor) – (B, C, T, H, W) video frames

Returns:

latent_actions – inferred latent action indices (B, T-1)

Return type:

Tensor

generate(prompt_frame, num_frames=16, actions=None, use_maskgit=True)[source]#

Generate video frames given a prompt frame and actions.

Parameters:
  • prompt_frame (Tensor) – (B, C, H, W) initial frame

  • num_frames (int) – Total number of frames to generate

  • actions (Tensor | None) – (B, num_frames-1) latent action indices, or None for random

  • use_maskgit (bool) – Whether to use MaskGIT sampling

Returns:

generated_video – (B, C, num_frames, H, W)

Return type:

Tensor
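With use_maskgit=True, each frame's tokens are decoded iteratively rather than in one pass; GenieConfig sets maskgit_steps=25. MaskGIT (Chang et al., 2022) uses a cosine schedule to decide how many tokens remain masked after each round, keeping the most confident predictions and re-masking the rest. Whether this implementation uses exactly that schedule is an assumption. The sketch computes only the schedule, for 256 tokens per frame; maskgit_schedule is a hypothetical helper.

```python
import math
from typing import List


def maskgit_schedule(num_tokens: int, steps: int = 25) -> List[int]:
    """Tokens still masked after each decoding round under a cosine schedule."""
    remaining = []
    for t in range(1, steps + 1):
        ratio = math.cos(math.pi / 2 * t / steps)  # decays from ~1 to 0
        remaining.append(int(num_tokens * ratio))
    return remaining


sched = maskgit_schedule(256, steps=25)
print(sched[0], sched[-1])  # 255 0
```

The cosine shape unmasks few tokens early, when predictions are least reliable, and many tokens in the final rounds.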

play(current_frame, action, current_frames=None)[source]#

Play step - generate next frame given current frame and action.

Parameters:
  • current_frame (Tensor) – (B, C, H, W) current frame

  • action (Tensor) – (B,) latent action indices

  • current_frames (Tensor | None) – (B, C, T, H, W) history frames, or None for first frame

Returns:

next_frame – (B, C, H, W)

Return type:

Tensor

get_num_parameters()[source]#

Return total number of parameters.

Return type:

int

class world_models.LatentActionModel(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#

Bases: Module

Latent Action Model (LAM) for unsupervised action learning.

Learns discrete latent actions from unlabeled video frames using a VQ-VAE based objective. The model infers latent actions between frames that encode the most meaningful changes for future frame prediction.

Based on the Genie paper: learns actions from Internet videos without action labels.

Components:

  • Encoder: Takes all previous frames x1:t and next frame x_t+1 → outputs latent actions

  • Decoder: Takes previous frames x1:t-1 and latent actions a1:t-1 → predicts next frame x_t

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

encode(x_prev, x_next)[source]#

Encode frames to latent actions.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W)

  • x_next (Tensor) – Next frame (B, C, H, W)

Returns:

  • latent_actions – Discrete latent action indices (B, T)

  • z_q – Quantized embeddings (B, T, embedding_dim)

Return type:

Tuple[Tensor, Tensor]

decode(x_prev, z_q)[source]#

Decode latent actions to predict next frame.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W); all but the first frame are masked

  • z_q (Tensor) – Quantized action embeddings (B, T-1, embedding_dim)

Returns:

predicted_next_frame – Predicted next frame (B, C, H, W)

Return type:

Tensor

forward(x_prev, x_next)[source]#

Full forward pass: encode to get actions, decode to reconstruct.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W)

  • x_next (Tensor) – Next frame (B, C, H, W)

Returns:

Dictionary with losses and outputs

Return type:

Dict[str, Tensor]

class world_models.DynamicsModel(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, gradient_checkpointing=True)[source]#

Bases: Module

Dynamics Model for action-controllable video generation.

A decoder-only transformer that predicts future frame tokens given past frame tokens and latent actions. Uses MaskGIT for training and sampling.

Based on the Genie paper: trained with a cross-entropy loss under random masking, and sampled with MaskGIT iterative refinement at inference.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

  • gradient_checkpointing (bool)

forward(video_tokens, actions, mask_prob=0.0)[source]#

Forward pass for training.

Parameters:
  • video_tokens (Tensor) – (B, T, H*W) - token indices for frames 1 to T

  • actions (Tensor) – (B, T) - latent action indices for frames 1 to T

  • mask_prob (float) – Per-token Bernoulli probability of masking the input tokens (Genie draws this rate from 0.5 to 1.0 during training; 0.0 disables masking)

Returns:

logits – Token logits (B, T, H*W, vocab_size)

Return type:

Tensor
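The mask_prob argument describes per-token Bernoulli masking of the input tokens. The sketch below is a framework-free illustration only: it masks a flat list of token ids. MASK_ID and bernoulli_mask are hypothetical names, and the library may apply masking to tensors differently, for example by leaving some frames unmasked.

```python
import random

MASK_ID = -1  # hypothetical sentinel marking a masked token


def bernoulli_mask(tokens, mask_prob, rng):
    """Replace each token with MASK_ID independently with probability mask_prob."""
    masked, is_masked = [], []
    for tok in tokens:
        hit = rng.random() < mask_prob  # one Bernoulli draw per token
        masked.append(MASK_ID if hit else tok)
        is_masked.append(hit)
    return masked, is_masked


rng = random.Random(0)
# Genie-style training draws the masking rate uniformly between 0.5 and 1.0
rate = rng.uniform(0.5, 1.0)
tokens = list(range(16))
masked, is_masked = bernoulli_mask(tokens, rate, rng)
```

With mask_prob=0.0 no token is masked, which matches the default in forward.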

sample(prompt_tokens, prompt_actions, num_frames, sampler=None)[source]#

Sample future frames using MaskGIT.

Parameters:
  • prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens

  • prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames

  • num_frames (int) – Total number of frames to generate

  • sampler (MaskGITSampler | None) – MaskGIT sampler instance

Returns:

generated_tokens – Generated frame tokens (B, num_frames, N)

Return type:

Tensor
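MaskGIT decodes all tokens of a frame in parallel over a few iterations. In each round it commits only the most confident predictions. The sketch below is a simplified, framework-free illustration of that idea. It uses the cosine masking schedule from the MaskGIT paper and a random stand-in for the dynamics transformer. The function names are hypothetical, this is not the library's MaskGITSampler, and it omits refinements such as confidence noise or temperature annealing.

```python
import math
import random

MASK_ID = -1  # hypothetical sentinel for a not-yet-decoded token


def cosine_schedule(step, total_steps):
    """Fraction of tokens that should still be masked after `step` iterations."""
    return math.cos(math.pi / 2 * step / total_steps)


def maskgit_decode(num_tokens, vocab_size, total_steps, predict, rng):
    """Iteratively fill a fully masked token grid, committing confident predictions first."""
    tokens = [MASK_ID] * num_tokens
    for step in range(1, total_steps + 1):
        masked = [i for i, t in enumerate(tokens) if t == MASK_ID]
        # The model proposes (token, confidence) for every still-masked position.
        proposals = {i: predict(i, rng, vocab_size) for i in masked}
        # Number of positions that should remain masked after this iteration.
        keep_masked = math.floor(num_tokens * cosine_schedule(step, total_steps))
        num_to_fix = max(0, len(masked) - keep_masked)
        # Commit the most confident proposals; committed tokens are never revisited.
        ranked = sorted(masked, key=lambda i: proposals[i][1], reverse=True)
        for i in ranked[:num_to_fix]:
            tokens[i] = proposals[i][0]
    return tokens


def toy_predict(position, rng, vocab_size):
    """Stand-in for the dynamics transformer: a random token with a random confidence."""
    return rng.randrange(vocab_size), rng.random()


out = maskgit_decode(num_tokens=64, vocab_size=1024, total_steps=8,
                     predict=toy_predict, rng=random.Random(0))
```

Because cos(pi/2) is effectively zero, the final iteration commits every remaining position.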

autoregressive_sample(prompt_tokens, prompt_actions, num_frames, temperature=1.0)[source]#

Simple autoregressive sampling (token by token).

Parameters:
  • prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens

  • prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames

  • num_frames (int) – Total number of frames to generate

  • temperature (float) – Sampling temperature

Returns:

generated_tokens – Generated frame tokens (B, num_frames, N)

Return type:

Tensor
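The temperature argument is not specified further here. Assuming it applies the usual temperature-scaled softmax to each token's logits, its effect can be seen in the minimal stdlib sketch below. Temperatures below 1 sharpen the distribution toward the argmax; temperatures above 1 flatten it. The helper names are hypothetical.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits, temperature, rng):
    """Draw one token index from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.0]
p_sharp = softmax_with_temperature(logits, 0.1)   # nearly one-hot on index 0
p_flat = softmax_with_temperature(logits, 10.0)   # close to uniform
token = sample_token(logits, 1.0, random.Random(0))
```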

world_models.create_genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, action_vocab_size=8, action_embedding_dim=32, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#

Factory function to create a Genie model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

world_models.create_genie_small(num_frames=16, image_size=64, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#

Create a smaller Genie model for development/testing.

Parameters:
  • num_frames (int)

  • image_size (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

world_models.create_genie_large(num_frames=16, image_size=64, use_bfloat16=True, action_pooling='mean', window_attention_heads=1)[source]#

Create a Genie model at approximately the 11B-parameter scale of the full model from the Genie paper.

Parameters:
  • num_frames (int)

  • image_size (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

world_models.create_latent_action_model(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, action_pooling='mean', window_attention_heads=1)[source]#

Factory function to create a Latent Action Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

LatentActionModel

world_models.create_dynamics_model(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4)[source]#

Factory function to create a Dynamics Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

Return type:

DynamicsModel

class world_models.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)[source]#

Bases: Module

Recurrent State-Space Model used by Dreamer for latent dynamics learning.

The RSSM is the core world model component that learns compact representations of environment dynamics. It maintains a hybrid state consisting of:

  1. Deterministic State (h) – A recurrent hidden state updated by a GRU, capturing sequential/temporal information and deterministic transitions.

  2. Stochastic State (s) – A latent variable representing stochastic uncertainty in the environment (e.g., ambiguous observations).

The model operates in two modes:

  • Observe Mode – Updates states using actual observations from the environment. Uses the representation model: p(s_t | h_t, obs_t)

  • Imagine Mode – Predicts future states without observations. Uses the transition/prior model: p(s_t | h_t)

Architecture

  • Input: Previous state (h_{t-1}, s_{t-1}) and action a_{t-1}

  • Process: GRU updates deterministic state, MLP computes stochastic prior/posterior

  • Output: Updated state (h_t, s_t) and distributions

State Representation

  • deter (h): GRU hidden state, captures sequential context

  • stoch (s): Stochastic latent capturing uncertainty

  • mean/std: Parameters of the stochastic distribution

Usage with DreamerAgent:

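import torch
from world_models import RSSM

batch_size, action_dim = 16, 6
device = torch.device('cpu')

rssm = RSSM(
    action_size=action_dim,
    stoch_size=30,       # Stochastic state dimension
    deter_size=200,      # Deterministic (GRU) state dimension
    hidden_size=200,     # MLP hidden layer size
    obs_embed_size=256,  # Observation embedding from encoder
    activation='elu'
)
prev_state = rssm.init_state(batch_size, device)
prev_action = torch.zeros(batch_size, action_dim)
obs_embed = torch.randn(batch_size, 256)

# Observe with actual observation; returns (posterior, prior)
posterior, prior = rssm.observe_step(prev_state, prev_action, obs_embed)

# Imagine one step ahead without an observation
next_prior = rssm.imagine_step(posterior, prev_action)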

Training

The RSSM is trained by maximizing the ELBO (Evidence Lower Bound), which combines two terms:

  • A KL divergence between posterior and prior, which encourages the prior to capture environment dynamics

  • A reconstruction loss from the decoder, which ensures the state captures observation information
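The KL term compares the posterior q(s_t | h_t, obs_t) with the prior p(s_t | h_t). Because get_dist builds both as Independent Normal distributions, this KL has a closed form. The framework-free sketch below computes it per dimension and sums the results. This reference does not state which KL direction, balancing, or free-nats scheme the library uses, so treat this as the plain textbook term. The function name is hypothetical.

```python
import math


def diag_gaussian_kl(mean_q, std_q, mean_p, std_p):
    """KL(q || p) between two diagonal Gaussians, summed over dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(mean_q, std_q, mean_p, std_p):
        total += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
    return total


# Posterior and prior over a 2-dimensional stochastic state (illustrative numbers)
kl = diag_gaussian_kl([0.3, -1.0], [0.5, 2.0], [0.0, 0.0], [1.0, 1.0])
```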

Reference:

Dream to Control: Learning Behaviors by Latent Imagination, Hafner et al., 2020 - https://arxiv.org/abs/1912.01603

Parameters:
  • action_size (int)

  • stoch_size (int)

  • deter_size (int)

  • hidden_size (int)

  • obs_embed_size (int)

  • activation (str)

init_state(batch_size, device)[source]#

Initialize RSSM state with zeros.

Parameters:
  • batch_size (int) – Number of parallel sequences

  • device (device) – torch device for tensors

Returns:

Dictionary containing zero-initialized state components:

  • mean, std: Stochastic distribution parameters

  • stoch: Stochastic state sample

  • deter: Deterministic GRU hidden state

Return type:

dict

get_dist(mean, std)[source]#

Create an Independent Normal distribution from mean and std.

Parameters:
  • mean (Tensor) – Location parameter

  • std (Tensor) – Scale parameter

Returns:

Independent Normal distribution with given parameters

Return type:

Independent
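Wrapping a Normal in Independent reinterprets the trailing dimension as part of the event. As a result, log_prob sums the per-dimension log-densities into one value per batch element. Assuming one reinterpreted dimension, the stdlib sketch below spells out that sum for a single event vector. The function name is hypothetical.

```python
import math


def independent_normal_log_prob(x, mean, std):
    """Log-density of vector x under a diagonal Normal: sum of per-dimension terms."""
    total = 0.0
    for xi, mi, si in zip(x, mean, std):
        total += -0.5 * ((xi - mi) / si) ** 2 - math.log(si) - 0.5 * math.log(2 * math.pi)
    return total


lp = independent_normal_log_prob([0.5, -1.0], [0.0, 0.0], [1.0, 2.0])
```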

observe_step(prev_state, prev_action, obs_embed, nonterm=tensor(1.))[source]#

Update state using actual observation (observe mode).

In observe mode, the RSSM first computes a transition prior from the previous state and action, then refines the stochastic state using the actual observation embedding to form the posterior.

Parameters:
  • prev_state (dict) – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})

  • prev_action (Tensor) – Previous action a_{t-1}, shape (B, action_size)

  • obs_embed (Tensor) – Observation embedding from encoder, shape (B, obs_embed_size)

  • nonterm (Tensor) – Termination mask (1.0 = continue, 0.0 = terminal)

Returns:

A tuple (posterior, prior) of state dictionaries. The posterior incorporates observation information; the prior is the transition prediction before observation. Both share the same deterministic state because the GRU is only advanced once per timestep.

Return type:

Tuple[dict, dict]

imagine_step(prev_state, prev_action, nonterm=tensor(1.))[source]#

Predict next state without observation (imagine mode).

In imagine mode, the RSSM predicts future states using only the prior distribution. This is used for planning and policy learning where actual observations are not available.

Parameters:
  • prev_state (dict) – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})

  • prev_action (Tensor) – Previous action a_{t-1}, shape (B, action_size)

  • nonterm (Tensor) – Termination mask (1.0 = continue, 0.0 = terminal)

Returns:

Dictionary with the predicted state, containing:

  • deter: Predicted deterministic state

  • mean, std, stoch: Prior stochastic state distribution

Return type:

dict

get_prior(prev_state, prev_action, nonterm=tensor(1.))[source]#
Parameters:
  • prev_state (dict)

  • prev_action (Tensor)

  • nonterm (Tensor)

Return type:

dict

get_posterior(prev_state, prev_action, obs_embed, nonterm=tensor(1.))[source]#

Compute posterior distribution over stochastic state.

The posterior incorporates observation information to produce a more accurate state estimate.

Parameters:
  • prev_state (dict) – Previous state dictionary

  • prev_action (Tensor) – Previous action

  • obs_embed (Tensor) – Observation embedding

  • nonterm (Tensor) – Termination mask

Returns:

Dictionary with posterior state (observation-informed). Note that the previous-state shape (B, ...) is preserved; the batch dimension is not flattened.

Return type:

dict

detach_state(state)[source]#

Detach state tensors from computation graph.

Used during DreamerV2 training to prevent gradient flow through the observation/update pathway.

Parameters:

state (dict) – State dictionary with tensor values

Returns:

Detached state dictionary

Return type:

dict

seq_to_batch(state_dict)[source]#

Convert sequence state to batch format.

Parameters:

state_dict (dict) – Dictionary with sequence-dimension tensors (T, B, …)

Returns:

Dictionary with batch-dimension tensors (B*T, …)

Return type:

dict

observe_rollout(obs_embed, actions, nonterms, init_state, seq_len)[source]#

Process a sequence of observations (observe mode rollout).

At each timestep we run observe_step once to obtain the transition prior (the prediction given the previous state and action) and the observation-informed posterior. The posterior is then used as the previous state for the next step, matching the standard Dreamer inference pattern.

Parameters:
  • obs_embed (Tensor) – Observation embeddings, shape (T+1, B, obs_embed_size)

  • actions (Tensor) – Actions, shape (T, B, action_size)

  • nonterms (Tensor) – Non-termination flags, shape (T, B, 1)

  • init_state (dict) – Initial state dictionary

  • seq_len (int) – Sequence length T

Returns:

  • prior – Dictionary with prior states stacked along the time axis

  • posterior – Dictionary with posterior states stacked along the time axis

Return type:

Tuple[dict, dict]
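The rollout pattern described above can be sketched without tensors. In this hypothetical version, observe_step is passed in as a function that returns (posterior, prior), as documented for RSSM.observe_step, and scalars stand in for state dictionaries. The sketch pairs the observation embedding at index t+1 with action t. That pairing is an assumption based on obs_embed having T+1 entries, not a confirmed detail of the implementation.

```python
def observe_rollout_sketch(observe_step, obs_embed, actions, nonterms, init_state, seq_len):
    """Roll observe_step over a sequence, feeding each posterior forward as the next state."""
    priors, posteriors = [], []
    state = init_state
    for t in range(seq_len):
        # Assumed pairing: action t is followed by the embedding at index t + 1.
        posterior, prior = observe_step(state, actions[t], obs_embed[t + 1], nonterms[t])
        priors.append(prior)
        posteriors.append(posterior)
        state = posterior  # the posterior, not the prior, conditions the next step
    return priors, posteriors


def toy_observe_step(state, action, embed, nonterm):
    """Hypothetical scalar stand-in: prior drifts by the action, posterior moves toward the observation."""
    prior = (state + action) * nonterm
    posterior = 0.5 * (prior + embed)
    return posterior, prior


priors, posteriors = observe_rollout_sketch(
    toy_observe_step, obs_embed=[0.0, 3.0, 3.0, 3.0], actions=[1.0, 1.0, 1.0],
    nonterms=[1.0, 1.0, 1.0], init_state=0.0, seq_len=3)
```

In the toy run, the second prior (3.0) is built from the first posterior (2.0), not from the first prior.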

imagine_rollout(policy, init_state, horizon)[source]#

Generate imagined trajectory using policy (imagine mode rollout).

Parameters:
  • policy (Module) – Actor network that outputs actions from state features

  • init_state (dict) – Initial state dictionary

  • horizon (int) – Number of steps to imagine

Returns:

Dictionary with imagined states for each step

Return type:

dict
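In imagine mode, the policy chooses each action from the current latent state, and the prior transition advances the state without any observation. The scalar sketch below shows that control flow. The function and the toy policy and transition are hypothetical stand-ins. Whether the library's output also includes init_state is not specified here.

```python
def imagine_rollout_sketch(imagine_step, policy, init_state, horizon):
    """Unroll the prior for `horizon` steps, choosing each action from the current state."""
    states = []
    state = init_state
    for _ in range(horizon):
        action = policy(state)               # the actor sees only latent features
        state = imagine_step(state, action)  # prior transition, no observation
        states.append(state)
    return states


# Hypothetical scalar stand-ins for the actor and the prior transition.
def toy_policy(state):
    return -0.5 * state


def toy_imagine_step(state, action):
    return state + action


traj = imagine_rollout_sketch(toy_imagine_step, toy_policy, init_state=8.0, horizon=3)
```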

forward(x, u)[source]#

Forward pass for training (computes sequence of states).

Parameters:
  • x (Tensor) – Observations, shape (B, T+1, C, H, W)

  • u (Tensor) – Actions, shape (B, T, action_size)

Returns:

  • states – List of state dictionaries for each timestep

  • priors – List of prior distributions (tuples of mean, std)

  • posteriors – List of posterior distributions (tuples of mean, std)

Return type:

Tuple[list, list, list]

class world_models.RecurrentStateSpaceModel(action_size, state_size=200, latent_size=30, hidden_size=200, embed_size=1024, activation_function='relu')[source]#

Bases: Module

A Recurrent State Space Model (RSSM) for modeling latent dynamics in sequential data.

Parameters:
  • action_size (int)

  • state_size (int)

  • latent_size (int)

  • hidden_size (int)

  • embed_size (int)

  • activation_function (str)

get_init_state(enc, h_t=None, s_t=None, a_t=None, mean=False)[source]#

Returns the initial posterior given the observation.

Parameters:
  • enc (Tensor)

  • h_t (Tensor | None)

  • s_t (Tensor | None)

  • a_t (Tensor | None)

  • mean (bool)

Return type:

tuple[Tensor, Tensor]

deterministic_state_fwd(h_t, s_t, a_t)[source]#

Deterministic transition update.

Ensures a_t is 2D and matches batch dimension of h_t before concatenation. Accepts a_t shaped [B, action_size], [action_size] (expanded to [B, action_size]), or [B]/scalar (reshaped appropriately).

Parameters:
  • h_t (Tensor)

  • s_t (Tensor)

  • a_t (Tensor)

Return type:

Tensor

state_prior(h_t, sample=False)[source]#

Returns the prior distribution over the latent state given the deterministic state.

Parameters:
  • h_t (Tensor)

  • sample (bool)

Return type:

tuple[Tensor, Tensor] | Tensor

state_posterior(h_t, e_t, sample=False)[source]#

Returns the posterior distribution over the latent state given the deterministic state and the observation embedding.

Parameters:
  • h_t (Tensor)

  • e_t (Tensor)

  • sample (bool)

Return type:

tuple[Tensor, Tensor] | Tensor

pred_reward(h_t, s_t)[source]#
Parameters:
  • h_t (Tensor)

  • s_t (Tensor)

Return type:

Tensor

rollout_prior(act, h_t, s_t)[source]#
Parameters:
  • act (Tensor)

  • h_t (Tensor)

  • s_t (Tensor)

Return type:

tuple[Tensor, Tensor]

forward(x, u)[source]#

Forward through the RSSM for a batch of sequences.

Parameters:
  • x (Tensor) – Tensor [B, T+1, C, H, W] (observations including initial frame)

  • u (Tensor) – Tensor [B, T, action_size] (actions for T steps)

Returns:

  • states – list[T] of tensors, each [B, state_size]

  • priors – list[T] of tuples (mean, std), each [B, latent_size]

  • posteriors – list[T] of tuples (mean, std), each [B, latent_size]

Return type:

Tuple[list, list, list]

class world_models.ConvEncoder(input_shape, embed_size, activation, depth=32)[source]#

Bases: Module

Convolutional observation encoder used by Dreamer world models.

This encoder transforms raw image observations (typically RGB frames from environments like Atari or DeepMind Control) into compact latent embeddings that can be processed by the RSSM (Recurrent State-Space Model).

  • Input: (B, C, H, W) images normalized to [-0.5, 0.5]

  • Process: 4 convolutional layers with stride 2, halving spatial dimensions

  • Output: (B, embed_size) compact representation

With the default depth=32, the channel count doubles at each layer: 32 -> 64 -> 128 -> 256. For 64x64 inputs, a fully connected layer then projects the 1024 flattened convolutional features to the desired embedding size.
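The 1024-feature figure can be checked with the standard convolution output-size formula. This reference does not state the kernel size. The sketch assumes kernel size 4, stride 2, and no padding, as in the original Dreamer encoder. Under that assumption a 64x64 input shrinks to 2x2, so the final 256-channel map flattens to 1024 values.

```python
def conv_out(size, kernel=4, stride=2, padding=0):
    """Spatial size after one Conv2d layer (floor formula, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


sizes = [64]
for _ in range(4):
    sizes.append(conv_out(sizes[-1]))  # 64 -> 31 -> 14 -> 6 -> 2

channels = [32 * 2 ** i for i in range(4)]  # depth=32 doubling: 32, 64, 128, 256
flat_features = channels[-1] * sizes[-1] * sizes[-1]
```

This also explains why the spatial dimensions only roughly halve at each layer: unpadded kernels trim a border before the stride-2 downsampling.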

Usage with Dreamer:

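import torch
from world_models import ConvEncoder

encoder = ConvEncoder(
    input_shape=(3, 64, 64),  # RGB 64x64 images
    embed_size=256,           # RSSM observation embedding size
    activation='relu'         # Activation function
)
observation = torch.rand(8, 3, 64, 64) - 0.5  # batch of 8 images in [-0.5, 0.5]
obs_embedding = encoder(observation)  # (8, 256)

Parameters: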
  • input_shape (tuple) – Tuple (C, H, W) for input images, typically (3, 64, 64)

  • embed_size (int) – Output embedding dimension, typically 256 or 1024

  • activation (str) – Activation function name (‘relu’, ‘elu’, ‘tanh’, etc.)

  • depth (int) – Base channel depth for first layer (default 32)

forward(inputs)[source]#
Parameters:

inputs (Tensor)

Return type:

Tensor

class world_models.CNNEncoder(embedding_size, activation_function='relu')[source]#

Bases: Module

A Convolutional Neural Network (CNN) encoder for processing image inputs.

Parameters:
  • embedding_size (int)

  • activation_function (str)

forward(observation)[source]#
Parameters:

observation (Tensor)

Return type:

Tensor

class world_models.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)[source]#

Bases: Module

Convolutional decoder for reconstructing observations from latent states.

Part of Dreamer’s world model, this decoder reconstructs image observations from the combined stochastic (s) and deterministic (h) RSSM states.

  • Input: Concatenated [stoch_state, deter_state], shape (B, stoch+deter)

  • Process: Dense projection + 4 transposed convolutions (upsampling 2x each)

  • Output: Independent Normal distribution over observation pixels

The decoder mirrors the ConvEncoder’s structure but in reverse (transposed convs instead of regular convs). This creates a symmetric autoencoder where the encoder and decoder can be trained jointly to learn compressed representations.

Returns torch.distributions.Independent(Normal(mean, std), len(shape)), so log_prob(observation) can be computed directly for the reconstruction loss.
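The transposed-convolution stack can be checked the same way as the encoder. This reference gives only "4 transposed convolutions (upsampling 2x each)", so the sketch makes two assumptions taken from the original Dreamer decoder. First, the dense output is reshaped to a 1x1 map. Second, the four layers use kernels 5, 5, 6, 6 with stride 2 and no padding. With those settings the output is exactly 64x64.

```python
def conv_transpose_out(size, kernel, stride=2, padding=0, output_padding=0):
    """Spatial size after one ConvTranspose2d layer (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


size = 1  # dense projection reshaped to a 1x1 feature map (assumed)
sizes = [size]
for kernel in (5, 5, 6, 6):  # assumed Dreamer-style kernel sizes
    size = conv_transpose_out(size, kernel)
    sizes.append(size)
```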

Usage in Dreamer world model:

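import torch
from world_models import ConvDecoder

decoder = ConvDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(3, 64, 64),  # RGB images
    activation='relu'
)
latent_features = torch.randn(8, 30 + 200)  # concatenated [stoch_state, deter_state]
target_observation = torch.rand(8, 3, 64, 64) - 0.5
obs_dist = decoder(latent_features)  # Returns a distribution
log_prob = obs_dist.log_prob(target_observation)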

The reconstruction loss is -log_prob(observation), which encourages the RSSM to learn states that capture observation information.

Parameters:
  • stoch_size (int)

  • deter_size (int)

  • output_shape (tuple[int, ...])

  • activation (str)

  • depth (int)

forward(features)[source]#
Parameters:

features (Tensor)

Return type:

Independent

class world_models.CNNDecoder(state_size, latent_size, embedding_size, activation_function='relu')[source]#

Bases: Module

A Convolutional Neural Network (CNN) decoder for reconstructing image outputs.

Parameters:
  • state_size (int)

  • latent_size (int)

  • embedding_size (int)

  • activation_function (str)

forward(latent, state)[source]#
Parameters:
  • latent (Tensor)

  • state (Tensor)

Return type:

Tensor

class world_models.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist, num_buckets=255, symlog_range=10.0)[source]#

Bases: Module

MLP decoder for reward/value/discount prediction from latent features.

Part of Dreamer’s world model, this decoder predicts scalar quantities (rewards, values, discount factors) from RSSM latent states.

  • Input: [stoch_state, deter_state] concatenated, shape (B, stoch+deter)

  • Process: MLP with configurable layers and hidden units

  • Output: Predicted quantity with distribution (normal, binary, or raw)

Supports three output types:

  • 'normal': Gaussian distribution for regression (rewards, values)

  • 'binary': Bernoulli distribution for binary classification (discount)

  • 'none': Raw tensor for non-probabilistic outputs
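With dist='binary', the discount head is trained by maximizing the Bernoulli log-likelihood of the continuation flag. This reference does not say whether the MLP emits logits or probabilities. Assuming it emits a logit, the per-element log-probability can be computed stably as in the stdlib sketch below. The helper names are hypothetical.

```python
import math


def softplus(v):
    """Numerically stable log(1 + exp(v))."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def bernoulli_log_prob(target, logit):
    """log p(target) for a Bernoulli parameterized by a logit; target is 0.0 or 1.0."""
    # log(sigmoid(l)) = -softplus(-l) and log(1 - sigmoid(l)) = -softplus(l)
    return -(target * softplus(-logit) + (1.0 - target) * softplus(logit))


# Continue flag 1.0 (non-terminal) scored under a confident "continue" logit
lp_continue = bernoulli_log_prob(1.0, 10.0)
```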

Usage:

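import torch
from world_models import DenseDecoder

reward_decoder = DenseDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(1,),
    n_layers=2,
    units=400,
    activation='elu',
    dist='normal'
)
latent_features = torch.randn(8, 30 + 200)  # concatenated [stoch_state, deter_state]
target_reward = torch.randn(8, 1)
reward_dist = reward_decoder(latent_features)
reward_loss = -reward_dist.log_prob(target_reward).mean()  # scalar training loss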

For discount prediction (binary):

discount_decoder = DenseDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(1,),
    n_layers=2,
    units=400,
    activation='elu',
    dist='binary'
)

Parameters:
  • stoch_size (int)

  • deter_size (int)

  • output_shape (tuple[int, ...])

  • n_layers (int)

  • units (int)

  • activation (str)

  • dist (str)

  • num_buckets (int)

  • symlog_range (float)

forward(features)[source]#
Parameters:

features (Tensor)

Return type:

Any

class world_models.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)[source]#

Bases: Module

Dreamer actor head producing squashed continuous actions from latent features.

Outputs a transformed Gaussian policy, with an optional deterministic mode (deter=True in forward) and a helper, add_exploration, for additive exploration noise.

Parameters:
  • action_size (int)

  • stoch_size (int)

  • deter_size (int)

  • n_layers (int)

  • units (int)

  • activation (str)

  • min_std (float)

  • init_std (float)

  • mean_scale (float)

forward(features, deter=False)[source]#
Parameters:
  • features (Tensor)

  • deter (bool)

Return type:

Tensor

add_exploration(action, action_noise=0.3)[source]#
Parameters:
  • action (Tensor)

  • action_noise (float)

Return type:

Tensor
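The behavior of add_exploration is not documented here. In the original Dreamer implementation, exploration adds Gaussian noise with standard deviation action_noise to the policy output and clips the result back into [-1, 1]. Assuming add_exploration follows that scheme, a list-based sketch looks like this. The function name is hypothetical.

```python
import random


def add_exploration_sketch(action, action_noise=0.3, rng=random):
    """Perturb each action dimension with N(0, action_noise^2) noise, then clip to [-1, 1]."""
    noisy = []
    for a in action:
        sample = rng.gauss(a, action_noise)  # Gaussian centered on the policy action
        noisy.append(min(1.0, max(-1.0, sample)))
    return noisy


policy_action = [0.9, -0.2, 0.0]
explored = add_exploration_sketch(policy_action, action_noise=0.3, rng=random.Random(0))
```

The clip matters near the bounds: an action of 0.9 plus positive noise would otherwise leave the valid range.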

class world_models.TanhBijector[source]#

Bases: Transform

Bijective tanh transform for squashing Gaussian distributions to [-1, 1].

This transformation is essential for Dreamer’s action policy. Raw neural network outputs are Gaussian distributions over R^n, but actions in continuous control environments are typically bounded in [-1, 1]. The tanh bijector provides:

  1. Bijective mapping: tanh is invertible, with atanh as its inverse

  2. Stable log-det Jacobian: the log-det term needed for the squashed distribution's log-probability has a numerically stable closed form

  3. Bounded actions: samples always lie within [-1, 1], so no explicit clipping is needed

  • Forward: y = tanh(x)

  • Inverse: x = atanh(y) = 0.5 * log((1+y)/(1-y))

  • Log-det: log|dy/dx| = 2*(log(2) - x - softplus(-2x))
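The stable log-det expression above can be checked numerically against the naive form log(1 - tanh(x)^2). The stdlib sketch below evaluates both. It mirrors the formula listed for TanhBijector rather than calling the class. The naive form fails once tanh(x) rounds to exactly 1.0, which is why the stable form is used.

```python
import math


def softplus(v):
    """Numerically stable log(1 + exp(v))."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def tanh_log_abs_det(x):
    """log|d tanh(x)/dx| via the stable form 2 * (log 2 - x - softplus(-2x))."""
    return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))


def naive_log_abs_det(x):
    """Same quantity as log(1 - tanh(x)**2); breaks down once tanh(x) rounds to 1.0."""
    return math.log(1.0 - math.tanh(x) ** 2)


checks = [(x, tanh_log_abs_det(x), naive_log_abs_det(x)) for x in (-3.0, -0.5, 0.0, 0.5, 3.0)]
```

At x = 20, for example, the naive form raises a math domain error, while the stable form still returns about -38.6.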

Usage with Dreamer ActionDecoder:

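import torch
from torch.distributions import Normal, TransformedDistribution
from world_models import TanhBijector

mean, std = torch.zeros(8, 6), torch.ones(8, 6)
dist = TransformedDistribution(
    Normal(mean, std),
    TanhBijector()
)
action = dist.sample()  # Bounded to [-1, 1]

Reference: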

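Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor, Haarnoja et al., 2018 (tanh squashing of Gaussian actions with the log-det correction)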

property sign: int#
atanh(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

log_abs_det_jacobian(x, y)[source]#
Parameters:
  • x (Tensor)

  • y (Tensor)

Return type:

Tensor

class world_models.SampleDist(dist, samples=100)[source]#

Bases: object

Distribution wrapper that estimates statistics via Monte Carlo sampling.

Provides Monte Carlo approximations of the mean, mode, and entropy for transformed distributions whose analytic forms are unavailable or inconvenient.

Parameters:
  • dist (Any)

  • samples (int)

property name: str#
mean()[source]#
Return type:

Tensor

mode()[source]#
Return type:

Tensor

entropy()[source]#
Return type:

Tensor

sample()[source]#
Return type:

Tensor
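SampleDist's implementation is not shown in this reference. The class below is a hypothetical, framework-free sketch of the Monte Carlo technique its description names, in the style of the original Dreamer code:

  • mean: the sample average

  • mode: the highest-density draw

  • entropy: the negative mean log-density

It works with any object that exposes sample() and log_prob().

```python
import math
import random


class SampleDistSketch:
    """Monte Carlo statistics for a distribution exposing sample() and log_prob()."""

    def __init__(self, dist, samples=100):
        self._dist = dist
        self._samples = samples

    def _draw(self):
        return [self._dist.sample() for _ in range(self._samples)]

    def mean(self):
        draws = self._draw()
        return sum(draws) / len(draws)

    def mode(self):
        # Approximate the mode by the draw with the highest log-density.
        return max(self._draw(), key=self._dist.log_prob)

    def entropy(self):
        # H[p] = -E[log p(x)], estimated from samples.
        draws = self._draw()
        return -sum(self._dist.log_prob(x) for x in draws) / len(draws)


class ToyNormal:
    """Scalar Gaussian with sample() and log_prob(), used only for illustration."""

    def __init__(self, mu, sigma, rng):
        self.mu, self.sigma, self.rng = mu, sigma, rng

    def sample(self):
        return self.rng.gauss(self.mu, self.sigma)

    def log_prob(self, x):
        z = (x - self.mu) / self.sigma
        return -0.5 * z * z - math.log(self.sigma) - 0.5 * math.log(2 * math.pi)


est = SampleDistSketch(ToyNormal(1.0, 2.0, random.Random(0)), samples=20000)
```

More samples reduce the estimation noise; the samples=100 default trades accuracy for speed.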

class world_models.IRISEncoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, in_channels=3, base_channels=64, num_residual_blocks=2, frame_shape=(3, 64, 64))[source]#

Bases: Module

CNN Encoder for IRIS discrete autoencoder.

Encodes image observations into latent features, which are then quantized into discrete tokens using the VectorQuantizer.

Architecture:
  • 4 convolutional layers with residual blocks

  • Self-attention at 8x8 and 16x16 resolutions

  • Vector quantization to produce discrete tokens

Parameters:
  • vocab_size (int)

  • tokens_per_frame (int)

  • embedding_dim (int)

  • in_channels (int)

  • base_channels (int)

  • num_residual_blocks (int)

  • frame_shape (Tuple[int, int, int])

forward(x)[source]#

Encode images to discrete tokens.

Parameters:

x (Tensor) – Input images (B, C, H, W); expected to be 64x64

Returns:

  • z_q – Quantized tokens (B, C, H’, W’)

  • indices – Token indices (B, H’, W’)

  • vq_loss – Dictionary with VQ loss components

Return type:

Tuple[Tensor, Tensor, dict]

encode_to_indices(x)[source]#

Encode directly to token indices (for world model).

Parameters:

x (Tensor)

Return type:

Tensor

decode_from_indices(indices)[source]#

Decode token indices to embeddings (for decoder).

Parameters:

indices (Tensor)

Return type:

Tensor

class world_models.IRISDecoder(vocab_size=512, embedding_dim=512, base_channels=32, out_channels=3, frame_shape=(3, 64, 64))[source]#

Bases: Module

CNN Decoder for IRIS discrete autoencoder.

Decodes discrete tokens back into image observations. Uses transposed convolutions to upsample from 4x4 to 64x64.

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • base_channels (int)

  • out_channels (int)

  • frame_shape (Tuple[int, int, int])

forward(z)[source]#

Decode tokens to images.

Parameters:

z (Tensor) – Token embeddings (B, C, H, W), e.g. (B, 512, 4, 4)

Returns:

reconstructed – Reconstructed images (B, C, H, W), e.g. (B, 3, 64, 64)

Return type:

Tensor

decode_from_embeddings(z_flat)[source]#

Decode flattened token embeddings to images.

Parameters:

z_flat (Tensor) – Flattened tokens (B, H*W, C) or (B, C, H, W)

Returns:

Reconstructed images

Return type:

Tensor

decode_from_indices(indices)[source]#

Decode discrete token indices into images.

Parameters:

indices (Tensor) – Tensor of shape (B, H, W) or (B, H*W) containing integer token indices in the range [0, vocab_size).

Returns:

Reconstructed images (B, C, H, W)

Return type:

Tensor

class world_models.VideoTokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, commitment_weight=0.25, use_ema=False, ema_decay=0.99)[source]#

Bases: Module

Video Tokenizer using VQ-VAE with Spatiotemporal Transformer.

This is a core component of Genie (Google DeepMind, 2024), used to compress raw video frames into discrete latent tokens that can be processed by downstream models like the LatentActionModel and DynamicsModel.

The tokenizer uses a Vector Quantized Variational Autoencoder (VQ-VAE) objective to learn a discrete codebook of video representations. Unlike a standard VQ-VAE, it uses a Spatiotemporal (ST) Transformer in both the encoder and decoder to better capture temporal dynamics in videos.

Architecture

  1. Patch Embedding: Convert (B, C, T, H, W) video to patch tokens

  2. Encoder ST-Transformer: Process spatial-temporal patches

  3. Vector Quantization: Discretize continuous embeddings to codebook entries

  4. Decoder ST-Transformer: Reconstruct video from quantized tokens

  5. Patch Unembedding: Convert tokens back to video frames
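The token grid size follows from the patch size. Assuming the patch embedding splits each frame into non-overlapping patch_size x patch_size patches (typical for ViT-style tokenizers, though not stated explicitly here), the defaults give a 16x16 grid per frame. That per-frame count is the N = H’ x W’ used by decode_indices and by DynamicsModel's (B, T, N) token tensors. The helper name is hypothetical.

```python
def tokens_per_clip(num_frames, image_size, patch_size):
    """Token grid for a video clip, assuming non-overlapping square patches per frame."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    side = image_size // patch_size   # H' = W'
    per_frame = side * side           # N = H' x W'
    return side, per_frame, num_frames * per_frame


side, per_frame, total = tokens_per_clip(num_frames=16, image_size=64, patch_size=4)
```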

Key Features

  • Causal processing: Each frame’s encoding depends only on the current and previous frames

  • Discrete tokens: Enables autoregressive prediction with latent actions

  • Memory efficient: Uses ST-Transformer instead of full ViT to reduce complexity

Usage with Genie:

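import torch
from world_models import VideoTokenizer

tokenizer = VideoTokenizer(
    num_frames=16,
    image_size=64,
    patch_size=4,
    vocab_size=1024,
    embedding_dim=32
)
video_frames = torch.rand(2, 3, 16, 64, 64)  # (B, C, T, H, W)
reconstructed, indices, loss_dict = tokenizer(video_frames)

# The token indices, flattened to (B, T, H'*W'), are the discrete input to the
# dynamics model; decode_indices maps them back to quantized embeddings:
token_embeddings = tokenizer.decode_indices(indices)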

The tokenizer is trained with a VQ-VAE objective:

  • Reconstruction loss: MSE between the input and reconstructed video

  • VQ (codebook) loss: Moves codebook embeddings toward the encoder outputs

  • Commitment loss: Penalizes encoder outputs for drifting from their codebook entries

Reference:

Genie: Generative Interactive Environments Bruce et al., Google DeepMind, 2024 - https://arxiv.org/abs/2402.15391

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • use_ema (bool)

  • ema_decay (float)

encode(x)[source]#

Encode video to discrete tokens.

Parameters:

x (Tensor) – Video tensor (B, C, T, H, W)

Returns:

  • z_q – Quantized embeddings (B, T, H’, W’, embedding_dim)

  • indices – Token indices (B, T, H’, W’)

  • vq_loss – Dictionary with VQ loss components

Return type:

Tuple[Tensor, Tensor, dict]

decode_indices(indices)[source]#

Decode token indices to embeddings for video frames.

Parameters:

indices (Tensor) – Token indices (B, T, H’, W’) or (B, T, N) where N = H’ x W’

Returns:

z_q – Quantized embeddings (B, T, H’, W’, embedding_dim)

Return type:

Tensor

decode(z_q)[source]#

Decode discrete tokens to video frames.

Parameters:

z_q (Tensor) – Quantized embeddings (B, T, H’, W’, embedding_dim)

Returns:

Reconstructed video (B, C, T, H, W)

Return type:

Tensor

forward(x)[source]#

Full forward pass with VQ-VAE objective.

Parameters:

x (Tensor) – Video tensor (B, C, T, H, W)

Returns:

  • reconstructed – Reconstructed video (B, C, T, H, W)

  • indices – Token indices (B, T, H’, W’)

  • loss_dict – Dictionary containing loss components

Return type:

Tuple[Tensor, Tensor, dict]

world_models.create_video_tokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False)[source]#

Factory function to create a Video Tokenizer.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • use_ema (bool)

Return type:

VideoTokenizer

class world_models.VectorQuantizer(vocab_size=512, embedding_dim=512, commitment_weight=0.25)[source]#

Bases: Module

Vector Quantizer for discrete autoencoder.

Implements the VQ-VAE quantization from “Neural Discrete Representation Learning” (van den Oord et al., 2017).

Uses a straight-through estimator for gradient flow. For exponential-moving-average codebook updates, see VectorQuantizerEMA.
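At its core, quantization is a nearest-neighbour lookup into the codebook. The framework-free sketch below shows the lookup and the two standard VQ-VAE loss terms. The function name and dict keys are hypothetical. In an autograd framework, the straight-through estimator would copy gradients from z_q back to z unchanged.

```python
def quantize(z, codebook, commitment_weight=0.25):
    """Map each latent vector to its nearest codebook entry (squared Euclidean distance)."""
    indices, z_q = [], []
    for vec in z:
        dists = [sum((a - b) ** 2 for a, b in zip(vec, code)) for code in codebook]
        k = min(range(len(codebook)), key=dists.__getitem__)
        indices.append(k)
        z_q.append(list(codebook[k]))
    # Per-element mean squared error between latents and their codes. In the forward
    # pass both loss terms share this value; they differ only in which side is
    # detached (stop-gradient) during backpropagation.
    num_elems = len(z) * len(z[0])
    sq_err = sum(sum((a - b) ** 2 for a, b in zip(v, q)) for v, q in zip(z, z_q))
    mse = sq_err / num_elems
    losses = {"codebook": mse, "commitment": commitment_weight * mse}
    return z_q, indices, losses


codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]
z = [[0.1, -0.2], [0.9, 1.2], [-0.8, 1.7]]
z_q, indices, losses = quantize(z, codebook)
```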

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

forward(z)[source]#

Quantize the input latents.

Parameters:

z (Tensor) – Input tensor of shape (B, C, H, W) or (B, C)

Returns:

  • z_q – Quantized tensor (same shape as input)

  • indices – Token indices for each position (B, H, W) or (B,)

  • loss – Dictionary containing VQ loss components

Return type:

Tuple[Tensor, Tensor, dict]

decode_indices(indices)[source]#

Decode token indices back to embeddings.

Parameters:

indices (Tensor) – Token indices (B, H, W) or (B,)

Returns:

Embeddings (B, C, H, W) or (B, C)

Return type:

Tensor

class world_models.VectorQuantizerEMA(vocab_size=512, embedding_dim=512, commitment_weight=0.25, ema_decay=0.99, epsilon=1e-05)[source]#

Bases: Module

Vector Quantizer with Exponential Moving Average updates.

Uses EMA updates for the codebook instead of gradient-based updates, which typically makes training more stable.
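The reference gives only ema_decay and epsilon. Assuming the class follows the standard EMA scheme, each step decays running per-code counts and vector sums toward the current batch statistics, then sets each code to their ratio. The sketch below is framework-free and its names are hypothetical. It uses epsilon for Laplace smoothing of the counts, as common implementations do, which may not match this library exactly.

```python
def ema_codebook_update(cluster_size, embed_sum, z, indices, decay=0.99, epsilon=1e-5):
    """One EMA codebook update; cluster_size and embed_sum are updated in place.

    cluster_size: list[float], running count N_k per code
    embed_sum:    list[list[float]], running sum m_k of vectors assigned to code k
    Returns the new codebook e_k = m_k / N_k (with smoothed counts).
    """
    K, D = len(cluster_size), len(embed_sum[0])
    counts = [0.0] * K
    sums = [[0.0] * D for _ in range(K)]
    for vec, k in zip(z, indices):  # batch statistics per code
        counts[k] += 1.0
        for d in range(D):
            sums[k][d] += vec[d]
    for k in range(K):  # exponential moving averages
        cluster_size[k] = decay * cluster_size[k] + (1 - decay) * counts[k]
        for d in range(D):
            embed_sum[k][d] = decay * embed_sum[k][d] + (1 - decay) * sums[k][d]
    # Laplace smoothing keeps rarely used codes from dividing by (near) zero.
    n = sum(cluster_size)
    smoothed = [(c + epsilon) / (n + K * epsilon) * n for c in cluster_size]
    return [[s / smoothed[k] for s in embed_sum[k]] for k in range(K)]


sizes, sums = [1.0, 1.0], [[0.0, 0.0], [1.0, 1.0]]
codebook = ema_codebook_update(sizes, sums, z=[[0.2, 0.1], [0.9, 1.1]], indices=[0, 1])
```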

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • ema_decay (float)

  • epsilon (float)

forward(z)[source]#

Quantize with EMA updates.

Parameters:

z (Tensor)

Return type:

Tuple[Tensor, Tensor, dict[str, Tensor]]
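The codebook update that replaces gradient steps can be sketched with the standard EMA rule from the VQ-VAE paper: moving averages of per-code assignment counts and of the summed inputs, with Laplace smoothing controlled by epsilon. That VectorQuantizerEMA applies ema_decay and epsilon in exactly this way is an assumption, and ema_codebook_update is a hypothetical name.

```python
def ema_codebook_update(z, indices, codebook, cluster_size, embed_sum,
                        ema_decay=0.99, epsilon=1e-5):
    """One EMA codebook step (sketch). Mutates and returns codebook.

    cluster_size[k]: EMA of how many inputs were assigned to code k.
    embed_sum[k]:    EMA of the sum of those inputs.
    """
    K, D = len(codebook), len(codebook[0])
    counts = [0] * K
    sums = [[0.0] * D for _ in range(K)]
    for vec, k in zip(z, indices):                 # batch statistics
        counts[k] += 1
        for d in range(D):
            sums[k][d] += vec[d]
    for k in range(K):                             # exponential moving averages
        cluster_size[k] = ema_decay * cluster_size[k] + (1 - ema_decay) * counts[k]
        for d in range(D):
            embed_sum[k][d] = ema_decay * embed_sum[k][d] + (1 - ema_decay) * sums[k][d]
    # Laplace smoothing: once any input has been seen, no code has a zero count.
    n = sum(cluster_size)
    smoothed = [(c + epsilon) / (n + K * epsilon) * n for c in cluster_size]
    for k in range(K):
        codebook[k] = [s / smoothed[k] for s in embed_sum[k]]
    return codebook


codebook = [[0.0, 0.0], [1.0, 1.0]]
new = ema_codebook_update(z=[[0.2, 0.0], [0.0, 0.2], [2.0, 2.0]], indices=[0, 0, 1],
                          codebook=codebook, cluster_size=[0.0, 0.0],
                          embed_sum=[[0.0, 0.0], [0.0, 0.0]], ema_decay=0.0)
print(new)  # roughly [[0.1, 0.1], [2.0, 2.0]]
```

Setting ema_decay=0.0, as in the example, makes each used code jump to the mean of its assigned vectors, which is a quick way to check the arithmetic; with a decay near 0.99 the codes move gradually instead.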

decode_indices(indices)[source]#

Decode token indices to embeddings.

Parameters:

indices (Tensor) – Token indices (B, H, W) or (B,)

Returns:

Embeddings (B, C, H, W) or (B, C)

Return type:

Tensor

class world_models.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]#

Bases: object

Fixed-size replay buffer for Dreamer with image observations and transitions.

Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world-model training.

Key Features

  • Ring buffer with fixed capacity (FIFO eviction when full)

  • Stores raw uint8 images to save memory

  • Samples sequences (not single transitions) for temporal modeling

  • Validates sampled sequences don’t span episode boundaries

Memory Layout

  • observations: (capacity, C, H, W) uint8 images

  • actions: (capacity, action_dim) float32

  • rewards: (capacity,) float32

  • terminals: (capacity,) float32 (1.0 = terminal, 0.0 = continue)

Sampling Process

  1. Random start index (avoiding episode boundaries)

  2. Collect sequence of length seq_len with wraparound

  3. Validate no terminal in middle of sequence

  4. Return batch of sequences
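The four sampling steps can be sketched with a hypothetical RingReplay class that stores only rewards and terminal flags (observations and actions would be gathered with the same window indices). Two rules go beyond the description above and are assumptions: once the buffer has wrapped, a window may not run from the newest slot into the oldest one, and a terminal flag is allowed only on a window’s final step.

```python
import random


class RingReplay:
    """Hypothetical minimal ring buffer mirroring the sampling steps above."""

    def __init__(self, capacity, seq_len):
        self.capacity, self.seq_len = capacity, seq_len
        self.rewards = [0.0] * capacity
        self.terminals = [0.0] * capacity
        self.idx, self.full = 0, False   # next write slot, has-wrapped flag

    def add(self, reward, done):
        self.rewards[self.idx] = float(reward)
        self.terminals[self.idx] = float(done)
        self.idx = (self.idx + 1) % self.capacity  # FIFO: overwrite the oldest slot
        self.full = self.full or self.idx == 0

    def _window(self, rng, max_tries=1000):
        # Step 1: candidate start indices (only written slots until full).
        hi = self.capacity if self.full else self.idx - self.seq_len + 1
        if hi <= 0:
            raise ValueError("not enough transitions for one sequence")
        for _ in range(max_tries):
            start = rng.randrange(hi)
            # Step 2: contiguous window with wraparound.
            idxs = [(start + i) % self.capacity for i in range(self.seq_len)]
            # Reject windows that run from the newest slot into the oldest one.
            if self.idx in idxs[1:]:
                continue
            # Step 3: a terminal is allowed only at the last step of the window.
            if any(self.terminals[j] for j in idxs[:-1]):
                continue
            return idxs
        raise RuntimeError("no valid window found")

    def sample(self, batch_size, seed=None):
        rng = random.Random(seed)
        windows = [self._window(rng) for _ in range(batch_size)]
        # Step 4: time-major (seq_len, batch) layout, like sample() below.
        rewards = [[self.rewards[w[t]] for w in windows] for t in range(self.seq_len)]
        terminals = [[self.terminals[w[t]] for w in windows] for t in range(self.seq_len)]
        return rewards, terminals
```

Rejection sampling keeps the sketch short; with frequent terminals it may retry many times, which is why the attempts are capped.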

Usage with Dreamer:

from world_models import ReplayBuffer

buffer = ReplayBuffer(
    size=100000,           # Max transitions to store
    obs_shape=(3, 64, 64), # RGB images
    action_size=6,         # Continuous action dim
    seq_len=50,            # Sequence length for training
    batch_size=50          # Parallel sequences per batch
)

# Add transitions during interaction; obs is a dict such as {"image": uint8 array of shape (3, 64, 64)}
buffer.add(obs, action, reward, done)

# Sample batch for world model training
obs_batch, action_batch, reward_batch, term_batch = buffer.sample()

Memory Efficiency

  • Uses uint8 for images (1 byte per pixel vs 4 for float32)

  • Sequences share observations (overlapping windows)

  • Configurable capacity based on available system memory

Note

add() takes observations as {“image”: …} dicts, but the buffer stores and returns only the image arrays for training efficiency.

Parameters:
  • size (int)

  • obs_shape (Tuple[int, ...])

  • action_size (int)

  • seq_len (int)

  • batch_size (int)

add(obs, ac, rew, done)[source]#

Add a transition to the buffer.

Parameters:
  • obs (dict) – Observation dict with ‘image’ key containing the observation

  • ac (ndarray) – Action taken, shape (action_size,)

  • rew (float) – Reward received, scalar

  • done (float) – Terminal flag, 1.0 if episode ended, 0.0 otherwise

Return type:

None

sample()[source]#

Sample a batch of sequences for training.

Returns:

(observations, actions, rewards, terminals)
  • observations: (seq_len, batch, C, H, W)

  • actions: (seq_len, batch, action_dim)

  • rewards: (seq_len, batch)

  • terminals: (seq_len, batch)

Return type:

tuple

class world_models.Memory(size=None)[source]#

Bases: deque

Episode-based replay memory for PlaNet/RSSM training.

Stores episodes as variable-length trajectories and supports sampling sub-sequences for training. Implements a ring-buffer style eviction when capacity is reached.

  • Stores complete episodes as lists of transitions

  • Samples contiguous sub-sequences for sequence models

  • Supports time-major formatting (time-first) for RNN input

  • Memory usage estimation to prevent OOM errors

Parameters:

size (int, optional) – Maximum number of episodes to store. If None, deque grows without limit (useful for unpickling).

episodes#

Collection of Episode objects.

Type:

deque

eps_lengths#

Length of each episode.

Type:

deque

size#

Total number of transitions across all episodes.

Type:

property

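Example:

from world_models import Memory

memory = Memory(size=100)
memory.append([episode1, episode2])  # a list of Episode objects
obs, actions, rewards, terminals, lengths = memory.sample(batch_size=32, tracelen=50)

property size: int#

Total number of transitions across all episodes.

append(episodes)[source]#

Parameters: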

episodes (list[Episode])

Return type:

None

sample(batch_size, tracelen=1, time_first=False)[source]#

Sample random sub-sequences from stored episodes.

Randomly selects episodes and starting positions to create batches of contiguous sequences for training sequence models.

Parameters:
  • batch_size (int) – Number of sequences to sample.

  • tracelen (int) – Length of each sequence (default: 1).

  • time_first (bool) – If True, returns tensors with time dimension first (T, B, …) instead of batch first (B, T, …).

Returns:

(observations, actions, rewards, terminals, lengths)
  • observations: (batch, tracelen+1, *obs_shape) or (tracelen+1, batch, …)

  • actions: (batch, tracelen, action_dim) or (tracelen, batch, …)

  • rewards: (batch, tracelen) or (tracelen, batch)

  • terminals: (batch, tracelen) or (tracelen, batch)

  • lengths: (batch,) original episode lengths for each sample

Return type:

tuple

Raises:
  • ValueError – If memory is empty or no episodes meet minimum length.

  • MemoryError – If estimated memory usage exceeds 200 MiB threshold.
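A pure-Python sketch of the sampling described above, using plain dicts in place of Episode objects (keys mirror Episode’s x, u, r, t attributes). It assumes each episode holds T + 1 observations (including the one passed to terminate) and T actions, which is what makes the tracelen+1 / tracelen shapes line up. Uniform episode selection and the eligibility rule are also assumptions, the 200 MiB memory check is omitted, and sample_subsequences is a hypothetical name.

```python
import random


def sample_subsequences(episodes, batch_size, tracelen=1, time_first=False, seed=None):
    """Hypothetical sketch of Memory.sample over plain-Python episodes.

    Each episode is a dict with "x" (T + 1 observations) and "u", "r", "t"
    (T entries each).
    """
    rng = random.Random(seed)
    eligible = [ep for ep in episodes if len(ep["u"]) >= tracelen]
    if not eligible:
        raise ValueError("memory is empty or no episode is long enough")
    obs, act, rew, term, lengths = [], [], [], [], []
    for _ in range(batch_size):
        ep = rng.choice(eligible)                  # uniform over eligible episodes
        T = len(ep["u"])
        s = rng.randint(0, T - tracelen)           # randint includes both ends
        obs.append(ep["x"][s:s + tracelen + 1])    # tracelen + 1 observations
        act.append(ep["u"][s:s + tracelen])
        rew.append(ep["r"][s:s + tracelen])
        term.append(ep["t"][s:s + tracelen])
        lengths.append(T)
    if time_first:
        # Transpose each field from (batch, time) to (time, batch).
        obs, act, rew, term = ([list(col) for col in zip(*field)]
                               for field in (obs, act, rew, term))
    return obs, act, rew, term, lengths
```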

class world_models.Episode(postprocess_fn=None)[source]#

Bases: object

Records the agent’s interaction with the environment for a single episode.

Stores observations, actions, rewards, and terminal flags during a single trajectory. At termination, converts all lists to numpy arrays for efficient batch processing.

x#

Observations collected during the episode.

Type:

list or np.ndarray

u#

Actions taken.

Type:

list or np.ndarray

r#

Rewards received.

Type:

list or np.ndarray

t#

Terminal flags (0.0 = continue, 1.0 = terminal).

Type:

list or np.ndarray

info#

Additional episode metadata.

Type:

dict

Parameters:

postprocess_fn (callable, optional) – Function to apply to observations before storing (e.g., normalization). Default: identity function.

Example:

from world_models import Episode

episode = Episode()
episode.append(obs, action, reward, False)
episode.append(obs, action, reward, True)
episode.terminate(final_obs)
print(episode.x.shape)  # Now a numpy array

property size: int#

append(obs, act, reward, terminal)[source]#

Parameters:
  • obs (Any)

  • act (Any)

  • reward (Any)

  • terminal (Any)

Return type:

None

terminate(obs)[source]#
Parameters:

obs (Any)

Return type:

None

class world_models.IRISReplayBuffer(size, obs_shape, action_size, seq_len=20, batch_size=64)[source]#

Bases: object

Replay buffer for IRIS (Imagination with auto-Regression over an Inner Speech) training.

Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world model training.

Features:
  • Ring buffer with fixed capacity (FIFO eviction when full)

  • Stores uint8 images for memory efficiency

  • Samples sequences with validation to avoid episode boundaries

  • Supports sequence sampling for temporal learning

Memory Layout:
  • observations: (capacity, C, H, W) uint8

  • actions: (capacity, action_size) float32

  • rewards: (capacity,) float32

  • terminals: (capacity,) float32

Parameters:
  • size (int) – Maximum number of transitions to store.

  • obs_shape (tuple) – Shape of observations as (C, H, W).

  • action_size (int) – Dimension of actions.

  • seq_len (int) – Length of sequences to sample (default: 20).

  • batch_size (int) – Number of sequences per batch (default: 64).

size#

Buffer capacity.

Type:

int

obs_shape#

Observation shape.

Type:

tuple

action_size#

Action dimension.

Type:

int

seq_len#

Sequence length.

Type:

int

batch_size#

Batch size.

Type:

int

steps#

Total transitions added.

Type:

int

episodes#

Number of episode terminations observed.

Type:

int

add(obs, action, reward, terminal)[source]#

Add a transition to the buffer.

Parameters:
  • obs (ndarray) – Observation array with shape (C, H, W).

  • action (ndarray) – Action array with shape (action_size,).

  • reward (float) – Scalar reward value.

  • terminal (bool) – Boolean indicating if episode terminated.

Return type:

None

sample_sequence(seq_len=None)[source]#

Sample a batch of sequences for world model training.

Returns:

  • observations: (batch_size, seq_len+1, C, H, W)

  • actions: (batch_size, seq_len, action_size)

  • rewards: (batch_size, seq_len)

  • terminals: (batch_size, seq_len)

Return type:

tuple

Parameters:

seq_len (int | None)

sample_single()[source]#

Sample a single transition for online updates.

Return type:

Tuple[ndarray, ndarray, float, float]

property buffer_capacity: int#

Returns the total capacity of the buffer.

class world_models.IRISOnPolicyBuffer(max_steps=1000)[source]#

Bases: object

On-policy buffer for collecting trajectories during environment interaction.

Used to store the current episode data before adding to the main replay buffer. Unlike the main replay buffer, this collects trajectories in a list-based structure that’s cleared after each episode.

Useful for:
  • Collecting complete episode trajectories

  • Storing data before batch processing

  • Temporary storage during environment interaction

Parameters:

max_steps (int) – Maximum number of steps to store (default: 1000).

max_steps#

Maximum buffer capacity.

Type:

int

observations#

List of observations.

Type:

list

actions#

List of actions.

Type:

list

rewards#

List of rewards.

Type:

list

terminals#

List of terminal flags.

Type:

list

add(obs, action, reward, terminal)[source]#
Parameters:
  • obs (ndarray)

  • action (ndarray)

  • reward (float)

  • terminal (bool)

Return type:

None

clear()[source]#
Return type:

None

get_arrays()[source]#
Return type:

tuple[ndarray, ndarray, ndarray, ndarray]

class world_models.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]#

Bases: Module

Diffusion Transformer model for image denoising and generation.

The module maps noisy images and timesteps to predicted noise residuals and also provides a classmethod training entrypoint for common datasets.

Parameters:
  • img_size (int)

  • patch_size (int)

  • in_channels (int)

  • d_model (int)

  • depth (int)

  • heads (int)

  • drop (float)

  • t_dim (int)

forward(x, t)[source]#
Parameters:
  • x (Tensor)

  • t (Tensor)

Return type:

Tensor

classmethod from_config(config=None, **overrides)[source]#

Build DiT from a config object, dict, YAML file, or YAML string.

Parameters:
  • config (DiTConfig | dict[str, Any] | str | Path | None)

  • overrides (Any)

Return type:

DiT

classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, map_location=None, **overrides)[source]#

Load DiT weights from a local path or directory, or from the Hugging Face Hub.

Parameters:
  • pretrained_model_name_or_path (str | Path)

  • config (DiTConfig | dict[str, Any] | str | Path | None)

  • checkpoint_filename (str | None)

  • config_filename (str)

  • repo_type (str | None)

  • revision (str | None)

  • map_location (str | device | None)

  • overrides (Any)

Return type:

DiT

save_pretrained(path)[source]#

Save DiT weights and config in a from_pretrained-compatible format.

Parameters:

path (str | Path)

Return type:

None

parameter_count(trainable_only=False)[source]#
Parameters:

trainable_only (bool)

Return type:

int

summary()[source]#
Return type:

dict[str, Any]

classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]#
Parameters:
  • epochs (int)

  • dataset (Any)

  • batch_size (int)

  • lr (float)

  • img_size (int)

  • channels (int)

  • patch (int)

  • width (int)

  • depth (int)

  • heads (int)

  • drop (float)

  • timesteps (int)

  • beta_start (float)

  • beta_end (float)

  • ema (bool)

  • ema_decay (float)

  • workdir (str)

  • root_path (str)

  • image_folder (str | None)

  • crop_size (int)

  • download (bool)

  • copy_data (bool)

  • subset_file (str | None)

  • val_split (float | None)

Return type:

None

world_models.create_dit(config=None, **overrides)[source]#

Create a DiT from a DiTConfig or keyword overrides.

The public factory API works with config objects, while DiT itself has a compact constructor. This adapter keeps the lower-level model constructor unchanged and maps the public config fields onto the expected arguments.

Parameters:
  • config (Any)

  • overrides (Any)

Return type:

DiT

class world_models.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]#

Bases: Module

Patchify an image into a sequence of learnable patch tokens.

Used in Vision Transformers (ViT) and DiT to convert 2D images into sequences of token embeddings that can be processed by transformers.

Process:
  1. A Conv2d with kernel_size=stride=patch_size extracts non-overlapping patches

  2. The same convolution projects each flattened patch linearly to embed_dim

  3. Learnable positional embeddings are added for spatial information

Input: (B, C, H, W) images

Output: (B, N, embed_dim), where N = (H/patch_size) * (W/patch_size)

Parameters:
  • img_size (int) – Image size (assumes square), e.g., 32 for CIFAR

  • patch_size (int) – Size of each patch (typically 4, 8, or 16)

  • in_channels (int) – Number of input channels (3 for RGB)

  • embed_dim (int) – Output dimension for each patch token

Usage with DiT:

patch_embed = PatchEmbed(img_size=32, patch_size=4, in_channels=3, embed_dim=256)
tokens = patch_embed(images)  # (B, 64, 256) for a 32x32 image with patch_size=4
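Steps 1 and 2 amount to cutting the image into non-overlapping patches and applying one shared linear map to each flattened patch. The sketch below performs only the cutting, on nested lists, so the token count can be checked by hand. The row-major token order and the (channel, row, column) order inside each patch match what a Conv2d followed by flattening typically produces, but they are assumptions about PatchEmbed; positional embeddings (step 3) are omitted.

```python
def patchify(image, patch_size):
    """Split a (C, H, W) nested-list image into row-major flattened patches.

    Conv2d(kernel_size=stride=patch_size) applies one linear map to exactly
    these C * p * p values per patch; this sketch stops before that projection.
    """
    C, H, W = len(image), len(image[0]), len(image[0][0])
    p = patch_size
    if H % p or W % p:
        raise ValueError("H and W must be divisible by patch_size")
    tokens = []
    for top in range(0, H, p):          # patch rows
        for left in range(0, W, p):     # patch columns
            tokens.append([image[c][top + i][left + j]
                           for c in range(C) for i in range(p) for j in range(p)])
    return tokens  # N = (H // p) * (W // p) tokens, each of length C * p * p


rgb = [[[0.0] * 32 for _ in range(32)] for _ in range(3)]
tokens = patchify(rgb, 4)
print(len(tokens), len(tokens[0]))  # 64 48
```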

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]#

Bases: Module

Reconstruct image-like tensors from patch-token sequences.

The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.

Parameters:
  • img_size (int)

  • patch_size (int)

  • embed_dim (int)

  • out_channels (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.DDPM(timesteps, beta_start, beta_end)[source]#

Bases: Module

Utility module implementing forward and reverse DDPM diffusion steps.

Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).
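The schedule terms and the two core formulas (forward noising for q_sample and the reverse-step mean used by p_sample, from Ho et al., 2020) can be sketched on plain lists. The linear beta schedule is an assumption, although the DiT.train defaults beta_start=0.0001 and beta_end=0.02 are the values used with a linear schedule in that paper. The variance term added during sampling is left out, and all function names are hypothetical.

```python
import math
import random


def linear_schedule(timesteps, beta_start, beta_end):
    """Linearly spaced betas and cumulative products alpha_bar_t (assumed schedule)."""
    if timesteps < 2:
        raise ValueError("need at least two timesteps")
    betas = [beta_start + (beta_end - beta_start) * i / (timesteps - 1)
             for i in range(timesteps)]
    alpha_bars, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b                  # alpha_bar_t = prod_{s<=t} (1 - beta_s)
        alpha_bars.append(prod)
    return betas, alpha_bars


def q_sample(x_start, t, alpha_bars, noise):
    """Forward noising: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps."""
    a = alpha_bars[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x_start, noise)]


def p_sample_mean(x_t, t, betas, alpha_bars, eps_pred):
    """Mean of p(x_{t-1} | x_t): (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)."""
    beta, a_bar = betas[t], alpha_bars[t]
    coef = beta / math.sqrt(1.0 - a_bar)
    return [(x - coef * e) / math.sqrt(1.0 - beta) for x, e in zip(x_t, eps_pred)]


betas, alpha_bars = linear_schedule(1000, 1e-4, 0.02)
rng = random.Random(0)
x0 = [0.5, -1.0, 2.0]
eps = [rng.gauss(0.0, 1.0) for _ in x0]
x_last = q_sample(x0, 999, alpha_bars, eps)   # nearly pure noise at the last step
```

With a perfect noise prediction at t = 0, the reverse-step mean recovers x_0 exactly, which makes a handy sanity check for a reimplementation.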

Parameters:
  • timesteps (int)

  • beta_start (float)

  • beta_end (float)

q_sample(x_start, t, noise=None)[source]#
Parameters:
  • x_start (Tensor)

  • t (Tensor)

  • noise (Tensor | None)

Return type:

Tensor

p_sample(model, x_t, t)[source]#
Parameters:
  • model (Module)

  • x_t (Tensor)

  • t (Tensor)

Return type:

Tensor

sample(model, n, img_size, channels)[source]#
Parameters:
  • model (Module)

  • n (int)

  • img_size (int)

  • channels (int)

Return type:

Tensor

class world_models.ActorCriticNetwork(obs_channels=3, action_dim=18, channels=(32, 32, 64, 64), lstm_dim=512)[source]#

Bases: Module

Actor-Critic network for DIAMOND RL training. Shared CNN-LSTM trunk with separate policy and value heads.

Parameters:
  • obs_channels (int)

  • action_dim (int)

  • channels (Tuple[int, ...])

  • lstm_dim (int)

forward(obs, hidden_state=None)[source]#

Forward pass of actor-critic network.

Parameters:
  • obs (Tensor) – Observations [B, T, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

Returns:

  • policy_logits: [B, T, action_dim]

  • values: [B, T, 1]

  • hidden_state: (h, c)

Return type:

tuple

get_action(obs, hidden_state=None, deterministic=False)[source]#

Get action from a single observation.

Parameters:
  • obs (Tensor) – Single observation [B, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

  • deterministic (bool) – If True, take argmax; else sample

Returns:

  • action: Selected action [B]

  • hidden_state: (h, c)

Return type:

tuple

get_actions(obs, hidden_state=None, deterministic=False)[source]#

Batched version of get_action.

Parameters:
  • obs (Tensor) – Tensor of shape [B, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state tuple matching batch size

  • deterministic (bool) – If True, take argmax; else sample from policy

Returns:

  • actions: LongTensor of shape [B]

  • hidden_state: Updated LSTM hidden state tuple

Return type:

tuple

get_value(obs, hidden_state=None)[source]#

Get value for a single observation.

Parameters:
  • obs (Tensor)

  • hidden_state (Tuple[Tensor, Tensor] | None)

Return type:

Tuple[Tensor, Tuple[Tensor, Tensor] | None]

init_hidden(batch_size, device)[source]#

Initialize LSTM hidden states.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

get_hidden_size()[source]#

Get LSTM hidden size.

Return type:

int

class world_models.RewardTerminationModel(obs_channels=3, action_dim=18, channels=(32, 32, 32, 32), lstm_dim=512, cond_dim=128)[source]#

Bases: Module

Reward and termination prediction model. CNN + LSTM architecture following DIAMOND paper specifications.

Parameters:
  • obs_channels (int) – Number of observation channels (3 for RGB)

  • action_dim (int) – Number of possible actions

  • channels (Tuple[int, ...]) – List of channel sizes for conv blocks

  • lstm_dim (int) – LSTM hidden dimension

  • cond_dim (int) – Conditioning dimension for adaptive norm

forward(obs, actions, hidden_state=None)[source]#

Forward pass of reward/termination model.

Parameters:
  • obs (Tensor) – Observations [B, T, C, H, W]

  • actions (Tensor) – Actions [B, T]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

Returns:

  • reward_logits: Reward predictions [B, T, 3] (for -1, 0, 1)

  • termination_logits: Termination predictions [B, T, 2]

  • hidden_state: Updated (h, c) hidden states

Return type:

tuple

predict(obs, actions, hidden_state=None)[source]#

Predict reward and termination for a single step.

Parameters:
  • obs (Tensor) – Single observation [B, C, H, W]

  • actions (Tensor) – Single action [B]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

Returns:

  • reward: Predicted reward classes as a tensor (values -1, 0, 1)

  • terminated: Predicted termination (bool tensor)

  • hidden_state: Updated (h, c) hidden states

Return type:

tuple
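Turning the raw logits from forward into the values that predict returns can be sketched as one argmax per head. The class order (-1, 0, +1) follows the forward documentation above; reading termination index 1 as “terminated” is an assumption, and decode_predictions is a hypothetical helper.

```python
def argmax(scores):
    return max(range(len(scores)), key=scores.__getitem__)


def decode_predictions(reward_logits, termination_logits):
    """Hypothetical decoding of per-sample logits into (reward, terminated).

    reward_logits rows hold 3 scores for the classes -1, 0, +1 (in that order);
    termination_logits rows hold 2 scores for (continue, terminate).
    """
    rewards = [argmax(row) - 1 for row in reward_logits]          # 0, 1, 2 -> -1, 0, +1
    terminated = [argmax(row) == 1 for row in termination_logits]
    return rewards, terminated


r, done = decode_predictions([[2.0, 0.1, -1.0], [0.0, 0.0, 3.0]],
                             [[1.5, -0.5], [-2.0, 0.3]])
print(r, done)  # [-1, 1] [False, True]
```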

init_hidden(batch_size, device)[source]#

Initialize LSTM hidden states.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

world_models.sinusoidal_time_embedding(timesteps, dim)[source]#

Create sinusoidal timestep embeddings for diffusion conditioning.

This function generates positional-style embeddings for diffusion timesteps, following the same pattern as transformer positional encodings. The embeddings encode the noise level (t) and are used to condition the diffusion model.

Math:

embedding[t] = [sin(t/10000^(2i/d)), cos(t/10000^(2i/d))] for i in [0, d/2)

Parameters:
  • timesteps (Tensor) – Tensor of integer timesteps, shape (B,) or (B, 1)

  • dim (int) – Embedding dimension (must be even)

Returns:

Tensor of shape (B, dim) with sinusoidal embeddings

Return type:

Tensor

Usage with DiT:

import torch
from world_models import sinusoidal_time_embedding

t = torch.tensor([0, 500, 1000])             # Timesteps
emb = sinusoidal_time_embedding(t, dim=256)  # (3, 256)

# Condition the model:
#   - add the timestep embedding to the MLP input
#   - use AdaLN for adaptive normalization
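The formula above can be written out for a single timestep as follows. Whether sinusoidal_time_embedding concatenates the sine and cosine halves (as here) or interleaves them is not stated, so that layout is an assumption; both layouts carry the same frequencies, 1/10000^(2i/d). sinusoidal_embedding is a hypothetical name.

```python
import math


def sinusoidal_embedding(t, dim, max_period=10000.0):
    """Embedding of one timestep t: [sin(t * w_i)..., cos(t * w_i)...]."""
    if dim % 2:
        raise ValueError("dim must be even")
    half = dim // 2
    # w_i = 1 / max_period^(2i/dim) for i in [0, dim/2)
    freqs = [max_period ** (-2 * i / dim) for i in range(half)]
    return [math.sin(t * w) for w in freqs] + [math.cos(t * w) for w in freqs]


emb = sinusoidal_embedding(500, dim=256)
print(len(emb))                     # 256
print(sinusoidal_embedding(0, 8))   # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
```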

class world_models.STTransformer(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, gradient_checkpointing=False)[source]#

Bases: Module

Spatiotemporal Transformer for video modeling.

Contains L spatiotemporal blocks with interleaved spatial and temporal attention.
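Interleaved spatial and temporal attention means each block attends within groups of tokens rather than over all T*N tokens at once. The sketch below lists those groups and counts the query-key pairs each scheme scores. It assumes frame-major token order (index = t * N + n), which the (B, T*N, C) input shape suggests but does not guarantee, and it says nothing about whether temporal attention is causal; both function names are hypothetical.

```python
def attention_groups(num_frames, num_patches_per_frame):
    """Index groups for factorized attention, assuming token index = t * N + n."""
    T, N = num_frames, num_patches_per_frame
    # Spatial attention: the N tokens of one frame attend to each other.
    spatial = [[t * N + n for n in range(N)] for t in range(T)]
    # Temporal attention: one patch position attends across the T frames.
    temporal = [[t * N + n for t in range(T)] for n in range(N)]
    return spatial, temporal


def attention_pairs(num_frames, num_patches_per_frame):
    """Query-key pairs scored by full vs. factorized attention in one block."""
    T, N = num_frames, num_patches_per_frame
    full = (T * N) ** 2
    factorized = T * N * N + N * T * T   # spatial pass + temporal pass
    return full, factorized


spatial, temporal = attention_groups(num_frames=3, num_patches_per_frame=4)
print(spatial[1], temporal[2])   # [4, 5, 6, 7] [2, 6, 10]
print(attention_pairs(16, 256))  # (16777216, 1114112)
```

With the defaults (16 frames, 256 patches per frame), factorized attention scores roughly 15 times fewer pairs than full attention over all 4,096 tokens.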

Parameters:
  • num_frames (int)

  • num_patches_per_frame (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • qk_scale (float | None)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

  • norm_layer (type[Module])

  • gradient_checkpointing (bool)

forward(x)[source]#
Parameters:

x (Tensor) – (B, T*N, C) where T is num_frames, N is num_patches_per_frame

Returns:

(B, T*N, C)

Return type:

Tensor

class world_models.MultiHeadSelfAttention(d, n_heads=2)[source]#

Bases: Module

Multi-head scaled dot-product self-attention over sequence tokens.

This module projects the input sequence into query/key/value heads, performs attention independently per head, and merges the heads back into the original feature dimension. It is used as a lightweight transformer attention block.
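The per-head computation can be sketched in pure Python. To keep it short, the learned query/key/value and output projections are replaced by the identity, so each head simply attends over its own slice of the features; the head split, the 1/sqrt(d_head) scaling, and the softmax-weighted sum are the parts that carry over. multi_head_self_attention is a hypothetical name.

```python
import math


def softmax(xs):
    m = max(xs)                                  # subtract max for stability
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def multi_head_self_attention(x, n_heads):
    """x: list of L token vectors of size d. Identity Q/K/V projections (sketch)."""
    L, d = len(x), len(x[0])
    if d % n_heads:
        raise ValueError("d must be divisible by n_heads")
    dh = d // n_heads
    out = [[0.0] * d for _ in range(L)]
    for h in range(n_heads):
        lo = h * dh                              # this head's slice of the features
        q = k = v = [tok[lo:lo + dh] for tok in x]   # learned W_q/W_k/W_v omitted
        for i in range(L):
            scores = [sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(dh)
                      for j in range(L)]
            w = softmax(scores)
            for c in range(dh):
                out[i][lo + c] = sum(w[j] * v[j][c] for j in range(L))
    return out  # heads concatenated back to width d (output projection omitted)


y = multi_head_self_attention([[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 3.0]], n_heads=2)
```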

Parameters:
  • d (int)

  • n_heads (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

world_models.MultiHeadAttention#

alias of MultiHeadSelfAttention

class world_models.AdaLNNormalization(d_model, t_dim)[source]#

Bases: Module

Adaptive layer normalization conditioned on an external embedding.

The module applies RMS normalization and predicts per-channel scale/shift from a conditioning vector (for example diffusion timestep embeddings).
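A minimal sketch of the modulation, assuming a single linear layer that maps the conditioning vector to 2 * d_model values split into scale and shift, applied as norm(x) * (1 + scale) + shift (the DiT-style convention). The document does not say whether this module uses (1 + scale) or scale directly, or whether it also predicts a gate; adaln and linear are hypothetical helpers.

```python
import math


def rms_norm(x, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def linear(weight, bias, v):
    """y = W v + b, with W given as a list of rows."""
    return [sum(w * u for w, u in zip(row, v)) + b for row, b in zip(weight, bias)]


def adaln(x, t_emb, weight, bias):
    """Hypothetical AdaLN: rms_norm(x) * (1 + scale) + shift, both from t_emb."""
    d = len(x)
    params = linear(weight, bias, t_emb)       # 2 * d values
    scale, shift = params[:d], params[d:]
    # (1 + scale) leaves a plain normalization when the predicted scale is 0.
    return [n * (1.0 + s) + b for n, s, b in zip(rms_norm(x), scale, shift)]


# With a zero weight matrix the conditioning only acts through the bias.
x = [3.0, -3.0, 3.0, -3.0]
zero_w = [[0.0, 0.0] for _ in range(8)]
print(adaln(x, [0.7, -1.2], zero_w, [0.0] * 4 + [0.5] * 4))  # about [1.5, -0.5, 1.5, -0.5]
```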

Parameters:
  • d_model (int)

  • t_dim (int)

forward(x, t_emb)[source]#
Parameters:
  • x (Tensor)

  • t_emb (Tensor)

Return type:

Tensor

class world_models.RMSNorm(dim, eps=1e-06)[source]#

Bases: Module

Root Mean Square Layer Normalization with a learned gain parameter.

RMSNorm rescales activations using their RMS magnitude without centering, providing a lightweight normalization alternative to LayerNorm.
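The difference from LayerNorm is easiest to see side by side: RMSNorm divides by the root mean square, while LayerNorm first subtracts the mean. The sketch below omits LayerNorm’s affine parameters and represents RMSNorm’s learned gain as a plain list; eps mirrors the constructor default.

```python
import math


def rms_norm(x, gain, eps=1e-6):
    """y_i = g_i * x_i / sqrt(mean(x^2) + eps): rescale only, no mean subtraction."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for g, v in zip(gain, x)]


def layer_norm(x, eps=1e-6):
    """LayerNorm without affine parameters, for comparison: center, then scale."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


x = [2.0, 4.0, 6.0, 8.0]
print(rms_norm(x, gain=[1.0] * 4))  # all positive: the mean is not removed
print(layer_norm(x))                # centered around zero
```

Both outputs are invariant to rescaling the input, but only LayerNorm’s output is also invariant to adding a constant to every element.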

Parameters:
  • dim (int)

  • eps (float)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]#

Bases: object

Model-predictive controller using Cross-Entropy Method (CEM) with RSSM.

Plans actions by optimizing a sequence of future actions in the RSSM’s latent space with a Cross-Entropy Method loop: it samples candidate action sequences, rolls them forward in latent space, scores their predicted returns, and refits a Gaussian proposal to the top-performing candidates.

Algorithm:
  1. Initialize Gaussian distribution over action sequences

  2. Sample N candidate action sequences

  3. Rollout each sequence in RSSM latent space

  4. Score by predicted cumulative rewards

  5. Keep top K candidates, fit Gaussian to them

  6. Repeat for T iterations

  7. Execute first action from best sequence
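The seven steps can be sketched without any world model by swapping the RSSM rollout for a hypothetical toy_score function that plays the role of predicted cumulative reward. cem_plan mirrors the constructor arguments (planning_horizon becomes horizon) but is not the class’s implementation; independent per-step, per-dimension Gaussians and the small standard-deviation floor are assumptions.

```python
import random
import statistics


def cem_plan(score_fn, horizon, action_dim, num_candidates=100,
             num_iterations=5, top_candidates=10, seed=None):
    """Cross-Entropy Method over action sequences, following steps 1-7 above.

    score_fn(seq) returns a predicted return for seq, a list of `horizon`
    actions of length `action_dim`; it stands in for RSSM latent rollouts.
    """
    if num_iterations < 1 or not 0 < top_candidates <= num_candidates:
        raise ValueError("need num_iterations >= 1 and 0 < top_candidates <= num_candidates")
    rng = random.Random(seed)
    # Step 1: independent Gaussian per time step and action dimension.
    mean = [[0.0] * action_dim for _ in range(horizon)]
    std = [[1.0] * action_dim for _ in range(horizon)]
    for _ in range(num_iterations):                                  # step 6
        # Step 2: sample candidate action sequences.
        cands = [[[rng.gauss(mean[h][d], std[h][d]) for d in range(action_dim)]
                  for h in range(horizon)] for _ in range(num_candidates)]
        # Steps 3-4: score every candidate and rank by predicted return.
        ranked = sorted(cands, key=score_fn, reverse=True)
        elite = ranked[:top_candidates]
        # Step 5: refit the Gaussian to the top-K candidates.
        for h in range(horizon):
            for d in range(action_dim):
                vals = [seq[h][d] for seq in elite]
                mean[h][d] = statistics.fmean(vals)
                std[h][d] = statistics.pstdev(vals) + 1e-6   # avoid a zero-width proposal
    # Step 7: first action of the best sequence from the last iteration.
    return ranked[0][0]


def toy_score(seq, target=0.5):
    """Hypothetical return model: best when every action equals target."""
    return -sum((a - target) ** 2 for step in seq for a in step)


first_action = cem_plan(toy_score, horizon=4, action_dim=2, num_candidates=200,
                        num_iterations=10, top_candidates=20, seed=0)
print(first_action)
```

The example sizes are kept small so the sketch runs quickly; the class example below uses 1000 candidates, 100 elites, and 5 iterations.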

Parameters:
  • model (Any)

  • planning_horizon (int)

  • num_candidates (int)

  • num_iterations (int)

  • top_candidates (int)

  • device (device | str)

rssm#

The RSSM world model.

N#

Number of candidate action sequences to sample.

K#

Number of top candidates to use for updating the proposal.

T#

Number of CEM iterations per planning step.

H#

Planning horizon (number of future steps to consider).

d#

Action dimensionality.

device#

Device to run computations on.

state_size#

Hidden state dimensionality.

latent_size#

Latent state dimensionality.

Example

>>> policy = RSSMPolicy(
...     model=rssm,
...     planning_horizon=12,
...     num_candidates=1000,
...     num_iterations=5,
...     top_candidates=100,
...     device='cuda'
... )
>>> policy.reset()
>>> action = policy.poll(observation)

reset()[source]#

Reset the policy state.

Initializes the hidden state, latent state, and action to zeros. Should be called at the beginning of each episode.

Return type:

None

poll(observation, explore=False)[source]#

Get action for given observation.

Parameters:
  • observation (Tensor) – Current observation tensor of shape (channels, height, width).

  • explore (bool) – If True, add exploration noise to the selected action.

Returns:

Action tensor of shape (1, action_size).

Return type:

Tensor

class world_models.IRISActor(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Actor network for the IRIS (Imagination with auto-Regression over an Inner Speech) policy.

Takes reconstructed frames as input and outputs action logits for policy control. Uses a CNN feature extractor followed by an LSTM for temporal processing. Supports a burn-in mechanism for initializing the hidden state with context frames.

Architecture:
  • CNN: Extracts features from input frames (3x64x64 -> 512)

  • LSTM: Processes temporal sequences with configurable layers

  • Linear: Maps hidden states to action logits

Parameters:
  • action_size (int) – Number of discrete actions.

  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

action_size#

Number of discrete actions.

Type:

int

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

forward(frames, hidden_state=None, burn_in_frames=None)[source]#

Forward pass through actor.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W) or (B, C, H, W)

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple for LSTM state

  • burn_in_frames (Tensor | None) – Frames to use for initializing hidden state

Returns:

  • action_logits: Action logits (B, T, action_size) or (B, action_size)

  • hidden_state: Updated (h, c) tuple

Return type:

tuple

init_hidden_state(batch_size, device)[source]#

Initialize LSTM hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

get_action(frame, temperature=1.0, deterministic=False)[source]#

Get action from a single frame.

Parameters:
  • frame (Tensor) – Single frame (B, C, H, W)

  • temperature (float) – Softmax temperature (higher = more random)

  • deterministic (bool) – If True, return argmax; else sample

Returns:

Selected action indices (B,)

Return type:

Tensor
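Temperature divides the logits before the softmax, so values below 1 sharpen the distribution and values above 1 flatten it. A pure-Python sketch of that selection rule for one row of logits follows; select_action is a hypothetical helper and makes no claim about how IRISActor batches the computation.

```python
import math
import random


def select_action(logits, temperature=1.0, deterministic=False, rng=random):
    """Pick an action index from one row of logits (sketch of get_action's rule)."""
    if deterministic:
        return max(range(len(logits)), key=logits.__getitem__)   # argmax
    if temperature <= 0:
        raise ValueError("temperature must be positive when sampling")
    scaled = [x / temperature for x in logits]
    top = max(scaled)
    weights = [math.exp(s - top) for s in scaled]   # softmax numerators, overflow-safe
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


logits = [1.0, 3.0, 0.5]
rng = random.Random(0)
print(select_action(logits, deterministic=True))                             # 1
print([select_action(logits, temperature=0.05, rng=rng) for _ in range(5)])  # [1, 1, 1, 1, 1]
```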

class world_models.IRISCritic(hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Critic network for IRIS value estimation.

Estimates the value function for given frame sequences. Shares the CNN feature extractor and LSTM backbone with the actor for efficiency, but has a separate value head for estimating expected cumulative rewards.

Architecture:
  • CNN: Shared feature extractor with actor (3x64x64 -> 512)

  • LSTM: Temporal processing with same architecture as actor

  • Linear: Maps hidden states to scalar values

Parameters:
  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

Parameters:
  • hidden_size (int)

  • num_layers (int)

  • frame_shape (Tuple[int, int, int])

forward(frames, hidden_state=None)[source]#

Forward pass through critic.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W)

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple

Returns:

  • values: Value estimates (B, T)

  • hidden_state: Updated (h, c) tuple

Return type:

tuple

init_hidden_state(batch_size, device)[source]#

Initialize LSTM hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

class world_models.IRISPolicy(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Combined policy module for IRIS (Imagination with auto-Regression over an Inner Speech).

Provides a unified interface for actor-only or actor-critic policies. Used in the IRIS algorithm where the actor generates actions from reconstructed frames and the critic estimates value functions for training.

Parameters:
  • action_size (int) – Number of discrete actions.

  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

actor#

The actor network for action selection.

Type:

IRISActor

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

Example

>>> policy = IRISPolicy(
...     action_size=18,
...     hidden_size=512,
...     num_layers=4,
...     frame_shape=(3, 64, 64)
... )
>>> action = policy.act(frame, temperature=1.0, deterministic=False)

forward(frames)[source]#

Get action logits from frames.

Parameters:

frames (Tensor)

Return type:

Tensor

act(frame, temperature=1.0, deterministic=False)[source]#

Sample action from policy.

Parameters:
  • frame (Tensor)

  • temperature (float)

  • deterministic (bool)

Return type:

Tensor

init_hidden(batch_size, device)[source]#

Initialize hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

tuple[Tensor, Tensor]

class world_models.CNNFeatureExtractor(frame_shape=(3, 64, 64), output_size=512)[source]#

Bases: Module

CNN feature extractor shared between actor and critic networks.

Processes input frames through four stride-2 Conv2d + ReLU stages (3 -> 32 -> 64 -> 128 -> 256 channels for RGB input), followed by a linear projection to output_size, producing fixed-size feature vectors.

Architecture:
  • Conv layers: 32 -> 64 -> 128 -> 256 channels

  • Each conv has stride=2 for spatial downsampling

  • Final linear layer projects to desired output dimension
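The flattened size that the final linear layer receives depends on the convolution settings. The sketch below applies PyTorch’s Conv2d output-length formula four times. kernel_size=4 and padding=1 are assumptions (the document gives only the channels and the stride), under which a 64x64 frame shrinks to 4x4 and the linear layer sees 256 * 4 * 4 = 4096 features.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a Conv2d along one spatial axis (PyTorch formula, dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(frame_shape=(3, 64, 64), channels=(32, 64, 128, 256),
                       kernel=4, stride=2, padding=1):
    """Flattened size fed to the final linear layer, under assumed conv settings."""
    _, h, w = frame_shape
    for _ in channels:                       # one stride-2 conv per stage
        h = conv_out(h, kernel, stride, padding)
        w = conv_out(w, kernel, stride, padding)
    return channels[-1] * h * w, (h, w)


print(flattened_features())  # (4096, (4, 4))
```

With kernel_size=3 and padding=0 the same stack ends at 3x3, that is 2304 features, so check the actual layer definitions before relying on these numbers.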

Parameters:
  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

  • output_size (int) – Size of output feature vector (default: 512).

frame_shape#

Input frame shape.

Type:

tuple

output_size#

Output feature dimension.

Type:

int

Parameters:
  • frame_shape (Tuple[int, int, int])

  • output_size (int)

forward(x)[source]#

Extract features from frames.

Parameters:

x (Tensor) – Frames (B, C, H, W)

Returns:

Feature vectors (B, output_size)

Return type:

Tensor

class world_models.DreamerConfig(env_backend='dmc', env='walker-walk', env_instance=None, image_size=(64, 64), gym_render_mode='rgb_array', dmlab_action_repeat=4, dmlab_action_set=None, dmlab_observations=None, dmlab_config=None, dmlab_renderer='hardware', procgen_distribution_mode='easy', procgen_num_levels=0, procgen_start_level=None, mujoco_xml_path=None, mujoco_xml_string=None, mujoco_binary_path=None, mujoco_camera=None, mujoco_frame_skip=1, mujoco_reset_noise_scale=0.0, brax_backend='generalized', brax_jit=True, brax_auto_reset=False, brax_suppress_warp_warnings=True, unity_file_name=None, unity_behavior_name=None, unity_worker_id=0, unity_base_port=5005, unity_no_graphics=True, unity_time_scale=20.0, unity_quality_level=1, algo='Dreamerv1', exp_name='lr1e-3', train=True, evaluate=False, seed=1, no_gpu=False, max_episode_length=1000, buffer_size=800000, time_limit=1000, cnn_activation_function='relu', dense_activation_function='elu', obs_embed_size=1024, num_units=400, deter_size=200, stoch_size=30, action_repeat=2, action_noise=0.3, total_steps=5000000, seed_steps=5000, update_steps=100, collect_steps=1000, batch_size=50, train_seq_len=50, imagine_horizon=15, use_disc_model=False, free_nats=3.0, discount=0.99, td_lambda=0.95, kl_loss_coeff=1.0, kl_alpha=0.8, disc_loss_coeff=10.0, num_buckets=255, symlog_range=10.0, model_learning_rate=0.0006, actor_learning_rate=8e-05, value_learning_rate=8e-05, adam_epsilon=1e-07, grad_clip_norm=100.0, use_amp=True, test=False, test_interval=10000, test_episodes=10, scalar_freq=1000, log_video_freq=-1, max_videos_to_save=2, video_format='gif', video_fps=20, checkpoint_interval=10000, checkpoint_path='', restore=False, experience_replay='', render=False, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', log_dir='runs', logdir=None, data_dir=None, log_level='INFO', log_file=None, enable_tensorboard=False, enable_console_metrics=True, enable_jsonl=True, jsonl_filename='metrics.jsonl', log_system_stats_freq=1000, detect_anomaly=False)[source]#

Bases: SerializableConfigMixin

Configuration container for Dreamer training, evaluation, and environment setup.

This class centralizes environment backend selection (DMC/DMLab/Gym/MuJoCo/Robotics/Unity/Brax), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.

Parameters:
  • env_backend (str)

  • env (str)

  • env_instance (Any)

  • image_size (tuple[int, int])

  • gym_render_mode (str)

  • dmlab_action_repeat (int)

  • dmlab_action_set (Any)

  • dmlab_observations (Any)

  • dmlab_config (Any)

  • dmlab_renderer (str)

  • procgen_distribution_mode (str)

  • procgen_num_levels (int)

  • procgen_start_level (Any)

  • mujoco_xml_path (Any)

  • mujoco_xml_string (Any)

  • mujoco_binary_path (Any)

  • mujoco_camera (Any)

  • mujoco_frame_skip (int)

  • mujoco_reset_noise_scale (float)

  • brax_backend (str)

  • brax_jit (bool)

  • brax_auto_reset (bool)

  • brax_suppress_warp_warnings (bool)

  • unity_file_name (Any)

  • unity_behavior_name (Any)

  • unity_worker_id (int)

  • unity_base_port (int)

  • unity_no_graphics (bool)

  • unity_time_scale (float)

  • unity_quality_level (int)

  • algo (str)

  • exp_name (str)

  • train (bool)

  • evaluate (bool)

  • seed (int)

  • no_gpu (bool)

  • max_episode_length (int)

  • buffer_size (int)

  • time_limit (int)

  • cnn_activation_function (str)

  • dense_activation_function (str)

  • obs_embed_size (int)

  • num_units (int)

  • deter_size (int)

  • stoch_size (int)

  • action_repeat (int)

  • action_noise (float)

  • total_steps (int)

  • seed_steps (int)

  • update_steps (int)

  • collect_steps (int)

  • batch_size (int)

  • train_seq_len (int)

  • imagine_horizon (int)

  • use_disc_model (bool)

  • free_nats (float)

  • discount (float)

  • td_lambda (float)

  • kl_loss_coeff (float)

  • kl_alpha (float)

  • disc_loss_coeff (float)

  • num_buckets (int)

  • symlog_range (float)

  • model_learning_rate (float)

  • actor_learning_rate (float)

  • value_learning_rate (float)

  • adam_epsilon (float)

  • grad_clip_norm (float)

  • use_amp (bool)

  • test (bool)

  • test_interval (int)

  • test_episodes (int)

  • scalar_freq (int)

  • log_video_freq (int)

  • max_videos_to_save (int)

  • video_format (str)

  • video_fps (int)

  • checkpoint_interval (int)

  • checkpoint_path (str)

  • restore (bool)

  • experience_replay (str)

  • render (bool)

  • enable_wandb (bool)

  • wandb_api_key (str)

  • wandb_project (str)

  • wandb_entity (str)

  • log_dir (str)

  • logdir (Any)

  • data_dir (Any)

  • log_level (str)

  • log_file (Any)

  • enable_tensorboard (bool)

  • enable_console_metrics (bool)

  • enable_jsonl (bool)

  • jsonl_filename (str)

  • log_system_stats_freq (int)

  • detect_anomaly (bool)

env_backend: str = 'dmc'#
env: str = 'walker-walk'#
env_instance: Any = None#
image_size: tuple[int, int] = (64, 64)#
gym_render_mode: str = 'rgb_array'#
dmlab_action_repeat: int = 4#
dmlab_action_set: Any = None#
dmlab_observations: Any = None#
dmlab_config: Any = None#
dmlab_renderer: str = 'hardware'#
procgen_distribution_mode: str = 'easy'#
procgen_num_levels: int = 0#
procgen_start_level: Any = None#
mujoco_xml_path: Any = None#
mujoco_xml_string: Any = None#
mujoco_binary_path: Any = None#
mujoco_camera: Any = None#
mujoco_frame_skip: int = 1#
mujoco_reset_noise_scale: float = 0.0#
brax_backend: str = 'generalized'#
brax_jit: bool = True#
brax_auto_reset: bool = False#
brax_suppress_warp_warnings: bool = True#
unity_file_name: Any = None#
unity_behavior_name: Any = None#
unity_worker_id: int = 0#
unity_base_port: int = 5005#
unity_no_graphics: bool = True#
unity_time_scale: float = 20.0#
unity_quality_level: int = 1#
algo: str = 'Dreamerv1'#
exp_name: str = 'lr1e-3'#
train: bool = True#
evaluate: bool = False#
seed: int = 1#
no_gpu: bool = False#
max_episode_length: int = 1000#
buffer_size: int = 800000#
time_limit: int = 1000#
cnn_activation_function: str = 'relu'#
dense_activation_function: str = 'elu'#
obs_embed_size: int = 1024#
num_units: int = 400#
deter_size: int = 200#
stoch_size: int = 30#
action_repeat: int = 2#
action_noise: float = 0.3#
total_steps: int = 5000000#
seed_steps: int = 5000#
update_steps: int = 100#
collect_steps: int = 1000#
batch_size: int = 50#
train_seq_len: int = 50#
imagine_horizon: int = 15#
use_disc_model: bool = False#
free_nats: float = 3.0#
discount: float = 0.99#
td_lambda: float = 0.95#
kl_loss_coeff: float = 1.0#
kl_alpha: float = 0.8#
disc_loss_coeff: float = 10.0#
num_buckets: int = 255#
symlog_range: float = 10.0#
model_learning_rate: float = 0.0006#
actor_learning_rate: float = 8e-05#
value_learning_rate: float = 8e-05#
adam_epsilon: float = 1e-07#
grad_clip_norm: float = 100.0#
use_amp: bool = True#
test: bool = False#
test_interval: int = 10000#
test_episodes: int = 10#
scalar_freq: int = 1000#
log_video_freq: int = -1#
max_videos_to_save: int = 2#
video_format: str = 'gif'#
video_fps: int = 20#
checkpoint_interval: int = 10000#
checkpoint_path: str = ''#
restore: bool = False#
experience_replay: str = ''#
render: bool = False#
enable_wandb: bool = False#
wandb_api_key: str = ''#
wandb_project: str = 'torchwm'#
wandb_entity: str = ''#
log_dir: str = 'runs'#
logdir: Any = None#
data_dir: Any = None#
log_level: str = 'INFO'#
log_file: Any = None#
enable_tensorboard: bool = False#
enable_console_metrics: bool = True#
enable_jsonl: bool = True#
jsonl_filename: str = 'metrics.jsonl'#
log_system_stats_freq: int = 1000#
detect_anomaly: bool = False#
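The discount and td_lambda fields parameterize the λ-return targets that Dreamer-style agents regress their value model onto over an imagine_horizon-step imagined rollout. The following is a minimal pure-Python sketch of that recursion, written from the published DreamerV1 definition rather than from TorchWM's source; the helper name lambda_returns and its argument layout (one more value than rewards, with the last value used as the bootstrap) are assumptions for illustration.

```python
from typing import List


def lambda_returns(
    rewards: List[float],
    values: List[float],
    discount: float = 0.99,   # DreamerConfig.discount default
    td_lambda: float = 0.95,  # DreamerConfig.td_lambda default
) -> List[float]:
    """Hypothetical helper: lambda-returns for one imagined trajectory.

    values[t] is the value estimate of state t; values[-1] bootstraps the tail,
    so len(values) must equal len(rewards) + 1.
    """
    if len(values) != len(rewards) + 1:
        raise ValueError("values must have exactly one more entry than rewards")
    returns = [0.0] * len(rewards)
    next_return = values[-1]  # R_H = v(s_H)
    for t in reversed(range(len(rewards))):
        # R_t = r_t + gamma * ((1 - lambda) * v(s_{t+1}) + lambda * R_{t+1})
        next_return = rewards[t] + discount * (
            (1.0 - td_lambda) * values[t + 1] + td_lambda * next_return
        )
        returns[t] = next_return
    return returns


# td_lambda=0 gives one-step TD targets; td_lambda=1 gives bootstrapped Monte Carlo returns.
print(lambda_returns([1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 5.0], discount=0.5, td_lambda=1.0))
# [2.375, 2.75, 3.5]
```

Interpolating with td_lambda trades the low variance of one-step targets against the low bias of long-horizon returns.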
class world_models.JEPAConfig[source]#

Bases: SerializableConfigMixin

Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.

to_dict()[source]#
Return type:

Dict[str, Dict[str, Any]]

classmethod from_dict(values)[source]#

Load flat field values or the nested trainer dictionary.

Parameters:

values (Dict[str, Any])

Return type:

JEPAConfig

to_train_dict()[source]#

Return the nested dictionary expected by train_jepa.

Return type:

Dict[str, Dict[str, Any]]

to_nested_dict()[source]#

Backward-compatible alias for the nested JEPA trainer dictionary.

Return type:

Dict[str, Dict[str, Any]]

class world_models.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#

Bases: SerializableConfigMixin

Default configuration values for Diffusion Transformer (DiT) training.

The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.

Field names use UPPER_CASE for backward compatibility with the original DiT codebase. Snake-case aliases are accepted via __getattr__ and get_dit_config().

Parameters:
  • DATASET (str)

  • BATCH (int)

  • EPOCHS (int)

  • LR (float)

  • IMG_SIZE (int)

  • CHANNELS (int)

  • PATCH (int)

  • WIDTH (int)

  • DEPTH (int)

  • HEADS (int)

  • DROP (float)

  • BETA_START (float)

  • BETA_END (float)

  • TIMESTEPS (int)

  • EMA (bool)

  • EMA_DECAY (float)

  • WORKDIR (str)

  • ROOT_PATH (str)

DATASET: str = 'CIFAR10'#
BATCH: int = 128#
EPOCHS: int = 3#
LR: float = 0.0002#
IMG_SIZE: int = 32#
CHANNELS: int = 3#
PATCH: int = 4#
WIDTH: int = 384#
DEPTH: int = 6#
HEADS: int = 6#
DROP: float = 0.1#
BETA_START: float = 0.0001#
BETA_END: float = 0.02#
TIMESTEPS: int = 1000#
EMA: bool = True#
EMA_DECAY: float = 0.999#
WORKDIR: str = './dit_demo'#
ROOT_PATH: str = './data'#
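BETA_START, BETA_END, and TIMESTEPS define the diffusion noise schedule, while IMG_SIZE, PATCH, WIDTH, and HEADS fix the transformer's token grid. The sketch below assumes the linear DDPM schedule used in the original DiT setup (betas spaced evenly from BETA_START to BETA_END) and ViT-style non-overlapping patches; both are assumptions to check against the TorchWM source, and the helper names are hypothetical.

```python
from typing import List, Tuple


def linear_beta_schedule(beta_start: float, beta_end: float, timesteps: int) -> List[float]:
    """Evenly spaced betas from beta_start to beta_end (linear schedule, assumed)."""
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def alpha_bars(betas: List[float]) -> List[float]:
    """Cumulative product of (1 - beta_t): the fraction of signal kept at step t."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


def token_grid(img_size: int, patch: int, width: int, heads: int) -> Tuple[int, int]:
    """Return (tokens per image, per-head dimension) for non-overlapping patches."""
    if img_size % patch or width % heads:
        raise ValueError("IMG_SIZE must be divisible by PATCH and WIDTH by HEADS")
    return (img_size // patch) ** 2, width // heads


betas = linear_beta_schedule(0.0001, 0.02, 1000)  # DiTConfig defaults
abar = alpha_bars(betas)
print(abar[-1])                     # close to 0: almost no signal left at t = T
print(token_grid(32, 4, 384, 6))    # (64, 64) with the DiTConfig defaults
```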
world_models.get_dit_config(**overrides)[source]#

Returns a DiTConfig instance with default values overridden by the provided keyword arguments.

Both UPPER_CASE and snake_case override keys are accepted.

Example usage:

cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)
cfg = get_dit_config(batch=64, epochs=10, lr=1e-3)

Parameters:

overrides (Any)

Return type:

DiTConfig
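One way to accept both UPPER_CASE and snake_case keys is to normalize override keys to the canonical field names and to resolve snake_case attribute reads in __getattr__. MiniDiTConfig and make_config below are hypothetical stand-ins that carry only three of the DiTConfig fields; they illustrate the aliasing behavior described above, not TorchWM's implementation.

```python
from dataclasses import dataclass, fields, replace
from typing import Any


@dataclass
class MiniDiTConfig:
    """Hypothetical three-field stand-in for DiTConfig."""
    BATCH: int = 128
    EPOCHS: int = 3
    LR: float = 0.0002

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails, e.g. cfg.batch -> cfg.BATCH.
        upper = name.upper()
        if upper != name and upper in self.__dataclass_fields__:
            return getattr(self, upper)
        raise AttributeError(name)


def make_config(**overrides: Any) -> MiniDiTConfig:
    """Apply overrides given in either case; reject unknown keys."""
    valid = {f.name for f in fields(MiniDiTConfig)}
    normalized = {}
    for key, value in overrides.items():
        canonical = key.upper()
        if canonical not in valid:
            raise TypeError(f"unknown config field: {key}")
        normalized[canonical] = value
    return replace(MiniDiTConfig(), **normalized)


cfg = make_config(batch=64, EPOCHS=10, lr=1e-3)
print(cfg.BATCH, cfg.epochs, cfg.lr)  # 64 10 0.001
```

Keeping UPPER_CASE as the stored names preserves compatibility with code written against the original DiT configuration, while the lookup layer makes snake_case usable everywhere.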

class world_models.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', seed: int = 0, obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, data_loader_num_workers: int = 4, pin_memory: bool = True, persistent_workers: bool = True, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, use_amp: bool = True, device: str = <factory>, log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, operator_state_dim: int = 32, operator_action_dim: int = 4)[source]#

Bases: SerializableConfigMixin

Parameters:
  • preset (str | None)

  • game (str)

  • seed (int)

  • obs_size (int)

  • frameskip (int)

  • max_noop (int)

  • terminate_on_life_loss (bool)

  • reward_clip (List[int])

  • num_conditioning_frames (int)

  • diffusion_channels (List[int])

  • diffusion_res_blocks (int)

  • diffusion_cond_dim (int)

  • sigma_data (float)

  • sigma_min (float)

  • sigma_max (float)

  • rho (int)

  • p_mean (float)

  • p_std (float)

  • sampling_method (str)

  • num_sampling_steps (int)

  • reward_channels (List[int])

  • reward_res_blocks (int)

  • reward_cond_dim (int)

  • reward_lstm_dim (int)

  • burn_in_length (int)

  • actor_channels (List[int])

  • actor_res_blocks (int)

  • actor_lstm_dim (int)

  • num_epochs (int)

  • training_steps_per_epoch (int)

  • batch_size (int)

  • environment_steps_per_epoch (int)

  • epsilon_greedy (float)

  • data_loader_num_workers (int)

  • pin_memory (bool)

  • persistent_workers (bool)

  • imagination_horizon (int)

  • discount_factor (float)

  • entropy_weight (float)

  • lambda_returns (float)

  • learning_rate (float)

  • adam_epsilon (float)

  • weight_decay_diffusion (float)

  • weight_decay_reward (float)

  • weight_decay_actor (float)

  • use_amp (bool)

  • device (str)

  • log_interval (int)

  • eval_interval (int)

  • save_interval (int)

  • operator_state_dim (int)

  • operator_action_dim (int)

preset: str | None = None#
game: str = 'Breakout-v5'#
seed: int = 0#
obs_size: int = 64#
frameskip: int = 4#
max_noop: int = 30#
terminate_on_life_loss: bool = True#
reward_clip: List[int]#
num_conditioning_frames: int = 4#
diffusion_channels: List[int]#
diffusion_res_blocks: int = 2#
diffusion_cond_dim: int = 256#
sigma_data: float = 0.5#
sigma_min: float = 0.002#
sigma_max: float = 80.0#
rho: int = 7#
p_mean: float = -0.4#
p_std: float = 1.2#
sampling_method: str = 'euler'#
num_sampling_steps: int = 3#
reward_channels: List[int]#
reward_res_blocks: int = 2#
reward_cond_dim: int = 128#
reward_lstm_dim: int = 512#
burn_in_length: int = 4#
actor_channels: List[int]#
actor_res_blocks: int = 1#
actor_lstm_dim: int = 512#
num_epochs: int = 1000#
training_steps_per_epoch: int = 400#
batch_size: int = 32#
environment_steps_per_epoch: int = 100#
epsilon_greedy: float = 0.01#
data_loader_num_workers: int = 4#
pin_memory: bool = True#
persistent_workers: bool = True#
imagination_horizon: int = 15#
discount_factor: float = 0.985#
entropy_weight: float = 0.001#
lambda_returns: float = 0.95#
learning_rate: float = 0.0001#
adam_epsilon: float = 1e-08#
weight_decay_diffusion: float = 0.01#
weight_decay_reward: float = 0.01#
weight_decay_actor: float = 0.0#
use_amp: bool = True#
device: str#
log_interval: int = 10#
eval_interval: int = 50#
save_interval: int = 100#
operator_state_dim: int = 32#
operator_action_dim: int = 4#
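The sigma_min, sigma_max, rho, p_mean, p_std, and num_sampling_steps fields match the EDM diffusion formulation of Karras et al., on which DIAMOND builds. Assuming DiamondConfig uses the standard EDM definitions, the sketch below derives the sampling noise levels and draws a training noise level from the log-normal distribution; the function names are illustrative and not part of TorchWM.

```python
import math
import random
from typing import List


def karras_sigmas(num_steps: int, sigma_min: float = 0.002,
                  sigma_max: float = 80.0, rho: float = 7.0) -> List[float]:
    """EDM sampling noise levels from sigma_max down to sigma_min, then a final 0."""
    if num_steps < 2:
        raise ValueError("need at least two sampling steps")
    inv_max, inv_min = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    sigmas = [
        (inv_max + i / (num_steps - 1) * (inv_min - inv_max)) ** rho
        for i in range(num_steps)
    ]
    return sigmas + [0.0]  # the sampler finishes on a noise-free sample


def sample_training_sigma(rng: random.Random, p_mean: float = -0.4, p_std: float = 1.2) -> float:
    """EDM training noise level: ln(sigma) ~ Normal(p_mean, p_std ** 2)."""
    return math.exp(rng.gauss(p_mean, p_std))


sigmas = karras_sigmas(3)  # DiamondConfig.num_sampling_steps default
print([round(s, 4) for s in sigmas])
print(sample_training_sigma(random.Random(0)) > 0)  # True
```

A larger rho concentrates steps at low noise levels, which matters when, as with the default of 3, very few sampling steps are used.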
class world_models.IRISConfig[source]#

Bases: SerializableConfigMixin

Configuration for IRIS (Imagination with auto-Regression over an Inner Speech).

Based on the paper “Transformers are Sample-Efficient World Models”. Implements a discrete autoencoder and an autoregressive Transformer for sample-efficient RL.

get_frame_shape()[source]#
Return type:

tuple

get_autoencoder_config()[source]#
Return type:

dict

get_transformer_config()[source]#
Return type:

dict

get_rl_config()[source]#
Return type:

dict

class world_models.GenieConfig(num_frames=8, image_size=32, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=256, action_encoder_depth=4, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#

Bases: SerializableConfigMixin

Configuration for the Genie model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • tokenizer_num_heads (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_num_heads (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

  • sample_temperature (float)

  • maskgit_steps (int)

num_frames: int = 8#
image_size: int = 32#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 256#
tokenizer_decoder_dim: int = 512#
tokenizer_encoder_depth: int = 4#
tokenizer_decoder_depth: int = 8#
tokenizer_num_heads: int = 8#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 256#
action_encoder_depth: int = 4#
action_num_heads: int = 8#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 4#
learning_rate: float = 3e-05#
weight_decay: float = 0.0001#
warmup_steps: int = 5000#
max_steps: int = 125000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
sample_temperature: float = 2.0#
maskgit_steps: int = 25#
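mask_prob_min and mask_prob_max bound the fraction of video tokens hidden from the dynamics model during MaskGIT-style training. Assuming GenieConfig draws one masking rate uniformly per training example, as the Genie paper describes, the masking step can be sketched as follows; sample_token_mask is a hypothetical helper, and token positions are plain Python ints rather than tensors.

```python
import random
from typing import List, Tuple


def sample_token_mask(
    num_tokens: int,
    rng: random.Random,
    mask_prob_min: float = 0.5,  # GenieConfig default
    mask_prob_max: float = 1.0,  # GenieConfig default
) -> Tuple[float, List[bool]]:
    """Draw a masking rate, then mask that fraction of token positions at random."""
    rate = rng.uniform(mask_prob_min, mask_prob_max)
    num_masked = round(rate * num_tokens)
    masked_positions = set(rng.sample(range(num_tokens), num_masked))
    return rate, [i in masked_positions for i in range(num_tokens)]


rate, mask = sample_token_mask(num_tokens=64, rng=random.Random(0))
print(f"rate={rate:.2f}, masked={sum(mask)}/64")
```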
class world_models.GenieSmallConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=512, action_encoder_depth=8, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=2, learning_rate=0.0001, weight_decay=0.0001, warmup_steps=1000, max_steps=50000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#

Bases: SerializableConfigMixin

Small Genie configuration for development and testing.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • tokenizer_num_heads (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_num_heads (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

  • sample_temperature (float)

  • maskgit_steps (int)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 256#
tokenizer_decoder_dim: int = 512#
tokenizer_encoder_depth: int = 4#
tokenizer_decoder_depth: int = 8#
tokenizer_num_heads: int = 8#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 512#
action_encoder_depth: int = 8#
action_num_heads: int = 8#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 2#
learning_rate: float = 0.0001#
weight_decay: float = 0.0001#
warmup_steps: int = 1000#
max_steps: int = 50000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
sample_temperature: float = 2.0#
maskgit_steps: int = 25#
class world_models.STTransformerConfig(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: SerializableConfigMixin

Configuration for the Spatiotemporal Transformer.

Parameters:
  • num_frames (int)

  • num_patches_per_frame (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

num_frames: int = 16#
num_patches_per_frame: int = 256#
dim: int = 768#
depth: int = 12#
num_heads: int = 12#
mlp_ratio: float = 4.0#
qkv_bias: bool = True#
drop_rate: float = 0.0#
attn_drop_rate: float = 0.0#
drop_path_rate: float = 0.0#
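A spatiotemporal (ST) transformer is commonly factorized so that each block attends across the num_patches_per_frame tokens within a frame, then across the num_frames positions at the same patch location. Assuming STTransformerConfig follows that factorization, the back-of-envelope arithmetic below counts query-key pairs per layer (ignoring heads, MLPs, and causal masking) at the default sizes.

```python
from typing import Dict


def attention_pairs(num_frames: int, patches_per_frame: int) -> Dict[str, int]:
    """Count query-key score pairs per attention layer (single head, no masking)."""
    total_tokens = num_frames * patches_per_frame
    full = total_tokens ** 2                        # every token attends to every token
    spatial = num_frames * patches_per_frame ** 2   # attention within each frame
    temporal = patches_per_frame * num_frames ** 2  # attention along time per patch location
    return {"full": full, "factorized": spatial + temporal}


pairs = attention_pairs(num_frames=16, patches_per_frame=256)  # STTransformerConfig defaults
print(pairs)                                          # {'full': 16777216, 'factorized': 1114112}
print(round(pairs["full"] / pairs["factorized"], 1))  # 15.1
```

The factorized cost grows linearly in the number of frames for the spatial term, which is what makes 16-frame clips of 256 patches tractable.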
class world_models.VideoTokenizerConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False, ema_decay=0.99, commitment_weight=0.25)[source]#

Bases: SerializableConfigMixin

Configuration for the Video Tokenizer.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • use_ema (bool)

  • ema_decay (float)

  • commitment_weight (float)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
encoder_dim: int = 512#
decoder_dim: int = 1024#
encoder_depth: int = 12#
decoder_depth: int = 20#
num_heads: int = 16#
patch_size: int = 4#
vocab_size: int = 1024#
embedding_dim: int = 32#
use_ema: bool = False#
ema_decay: float = 0.99#
commitment_weight: float = 0.25#
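VideoTokenizerConfig describes a VQ-VAE-style tokenizer: each patch embedding of size embedding_dim is snapped to the nearest of vocab_size codebook vectors, and commitment_weight scales the term that pulls encoder outputs toward their chosen codes. The pure-Python sketch below shows that lookup and the token count per clip at the defaults; it assumes non-overlapping patch_size-by-patch_size patches in every frame and the standard VQ-VAE loss, and it uses tiny 2-D codes so the numbers are easy to check.

```python
from typing import List, Sequence, Tuple


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(z: Sequence[float], codebook: List[Sequence[float]],
             commitment_weight: float = 0.25) -> Tuple[int, float]:
    """Return (nearest code index, VQ-VAE loss value) for one embedding.

    During training the codebook and commitment terms differ only in where
    gradients are stopped; their forward values are the same distance.
    """
    index = min(range(len(codebook)), key=lambda k: squared_distance(z, codebook[k]))
    dist = squared_distance(z, codebook[index])
    return index, dist + commitment_weight * dist


def tokens_per_clip(num_frames: int, image_size: int, patch_size: int) -> int:
    return num_frames * (image_size // patch_size) ** 2


codebook = [(0.0, 0.0), (1.0, 1.0), (-1.0, 2.0)]
print(quantize((0.9, 1.2), codebook))  # nearest code is index 1
print(tokens_per_clip(16, 64, 4))      # 4096 tokens with the defaults
```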
class world_models.LatentActionModelConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=1024, encoder_depth=20, num_heads=16, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#

Bases: SerializableConfigMixin

Configuration for the Latent Action Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • encoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
encoder_dim: int = 1024#
encoder_depth: int = 20#
num_heads: int = 16#
patch_size: int = 16#
vocab_size: int = 8#
embedding_dim: int = 32#
commitment_weight: float = 1.0#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
class world_models.DynamicsModelConfig(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: SerializableConfigMixin

Configuration for the Dynamics Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

num_frames: int = 16#
image_size: int = 64#
vocab_size: int = 1024#
embedding_dim: int = 32#
action_vocab_size: int = 8#
dim: int = 5120#
depth: int = 48#
num_heads: int = 36#
patch_size: int = 4#
mlp_ratio: float = 4.0#
qkv_bias: bool = True#
drop_rate: float = 0.0#
attn_drop_rate: float = 0.0#
drop_path_rate: float = 0.0#
class world_models.OperatorABC(*, device=None)[source]#

Bases: Module, ABC

Structured base class for inference operators.

Operators use a consistent pipeline:

  1. preprocess converts raw inputs into tensors.

  2. forward performs model/operator-specific tensor computation.

  3. postprocess formats the final output mapping.

Subclasses may also declare input_specs and output_specs to validate required tensor keys, shapes, and dtypes. OperatorABC inherits from torch.nn.Module, so operators support to(device), train(), and eval() just like model modules.

Parameters:

device (torch.device | str | None)

input_specs: Mapping[str, TensorSpec] = {}#
output_specs: Mapping[str, TensorSpec] = {}#
abstractmethod preprocess(inputs)[source]#

Convert raw inputs into a tensor mapping ready for forward.

Parameters:

inputs (Any)

Return type:

dict[str, Tensor]

forward(inputs)[source]#

Run tensor computation for this operator.

Preprocessing-only operators can rely on this identity implementation. Operators that wrap a model should override this method.

Parameters:

inputs (dict[str, Tensor])

Return type:

dict[str, Tensor]

postprocess(outputs)[source]#

Format validated forward outputs for consumers.

Parameters:

outputs (dict[str, Tensor])

Return type:

dict[str, Tensor]

process(inputs)[source]#

Process raw inputs through preprocess, forward, and postprocess stages.

Parameters:

inputs (Any)

Return type:

dict[str, Tensor]

batch(inputs)[source]#

Preprocess a sequence of inputs and stack matching tensor keys.

Parameters:

inputs (Sequence[Any])

Return type:

dict[str, Tensor]

to(*args, **kwargs)[source]#

Move module parameters/buffers and remember the target tensor device.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

OperatorABC

classmethod validate_mapping(values, specs, *, label)[source]#

Validate tensor keys, shapes, and dtypes against optional specs.

Parameters:
  • values (Mapping[str, Tensor])

  • specs (Mapping[str, TensorSpec])

  • label (str)

Return type:

None
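The preprocess, forward, postprocess contract can be illustrated without PyTorch. MiniOperator below is a hypothetical plain-Python stand-in for OperatorABC: preprocess is abstract, forward is the identity (as documented for preprocessing-only operators), postprocess is assumed to pass outputs through, process chains the three stages, and batch preprocesses each item and groups values by key, with lists standing in for stacked tensors. A real operator would subclass world_models.OperatorABC and return torch tensors.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Sequence


class MiniOperator(ABC):
    """Hypothetical stand-in for OperatorABC; lists play the role of tensors."""

    @abstractmethod
    def preprocess(self, inputs: Any) -> Dict[str, List[float]]:
        """Convert raw inputs into a key -> values mapping."""

    def forward(self, inputs: Dict[str, List[float]]) -> Dict[str, List[float]]:
        return inputs  # identity, as for preprocessing-only operators

    def postprocess(self, outputs: Dict[str, List[float]]) -> Dict[str, List[float]]:
        return outputs  # pass-through by default (assumption)

    def process(self, inputs: Any) -> Dict[str, List[float]]:
        return self.postprocess(self.forward(self.preprocess(inputs)))

    def batch(self, inputs: Sequence[Any]) -> Dict[str, List[List[float]]]:
        """Preprocess each item, then group ("stack") values that share a key."""
        items = [self.preprocess(item) for item in inputs]
        if not items:
            return {}
        return {key: [item[key] for item in items] for key in items[0]}


class PixelScaler(MiniOperator):
    """Example subclass: map 0..255 pixel values to 0..1."""

    def preprocess(self, inputs: Dict[str, List[int]]) -> Dict[str, List[float]]:
        return {"image": [p / 255.0 for p in inputs["image"]]}


op = PixelScaler()
print(op.process({"image": [0, 255]}))                       # {'image': [0.0, 1.0]}
print(op.batch([{"image": [0, 255]}, {"image": [255, 0]}]))  # {'image': [[0.0, 1.0], [1.0, 0.0]]}
```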

class world_models.TensorSpec(shape=None, dtype=None, required=True)[source]#

Bases: object

Optional tensor contract used to validate operator inputs or outputs.

Parameters:
  • shape (tuple[int | None, ...] | None) – Expected shape. Use None as a wildcard for dimensions that may vary, such as batch size.

  • dtype (dtype | None) – Expected tensor dtype.

  • required (bool) – Whether the key must be present in the mapping being validated.

shape: tuple[int | None, ...] | None = None#
dtype: dtype | None = None#
required: bool = True#
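TensorSpec treats None as a per-dimension wildcard. A minimal sketch of that shape rule as a standalone function over plain tuples; shape_matches is hypothetical, the exact-rank requirement is an assumption, and OperatorABC.validate_mapping additionally checks dtypes and required keys.

```python
from typing import Optional, Sequence, Tuple


def shape_matches(actual: Tuple[int, ...],
                  expected: Optional[Sequence[Optional[int]]]) -> bool:
    """True if `actual` satisfies `expected`; None means "any" at that level."""
    if expected is None:
        return True  # no shape constraint at all
    if len(actual) != len(expected):
        return False  # rank must match exactly (assumption)
    return all(e is None or a == e for a, e in zip(actual, expected))


spec = (None, 3, 64, 64)  # any batch size, 3x64x64 images
print(shape_matches((8, 3, 64, 64), spec))  # True
print(shape_matches((8, 1, 64, 64), spec))  # False: channel mismatch
print(shape_matches((3, 64, 64), spec))     # False: missing batch dimension
```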
class world_models.DreamerOperator(image_size=64, action_dim=6)[source]#

Bases: OperatorABC

Operator for Dreamer model preprocessing: normalizes observations and encodes actions.

Parameters:
  • image_size (int)

  • action_dim (int)

preprocess(inputs)[source]#

Process Dreamer inputs: image observation and action.

Expected inputs: {'image': PIL.Image or tensor, 'action': tensor or list}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]
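The docstring states that DreamerOperator normalizes observations and encodes actions but does not give the exact transforms. A common Dreamer convention is to scale uint8 pixels to [-0.5, 0.5] and to one-hot encode discrete actions; the sketch below shows both in plain Python under that assumption. normalize_pixels and one_hot are hypothetical helpers, and TorchWM may use a different scaling or action encoding.

```python
from typing import List


def normalize_pixels(pixels: List[int]) -> List[float]:
    """Map uint8 values 0..255 to [-0.5, 0.5] (Dreamer-style convention, assumed)."""
    return [p / 255.0 - 0.5 for p in pixels]


def one_hot(action: int, action_dim: int = 6) -> List[float]:
    """Encode a discrete action index as a one-hot vector of length action_dim."""
    if not 0 <= action < action_dim:
        raise ValueError(f"action {action} outside [0, {action_dim})")
    return [1.0 if i == action else 0.0 for i in range(action_dim)]


print(normalize_pixels([0, 255]))  # [-0.5, 0.5]
print(one_hot(2))                  # [0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
```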

class world_models.JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)[source]#

Bases: OperatorABC

Operator for JEPA model preprocessing: handles image/video masking and patch processing.

Parameters:
  • image_size (int)

  • patch_size (int)

  • mask_ratio (float)

preprocess(inputs)[source]#

Process JEPA inputs: images with masking.

Expected inputs: {'images': list of PIL Images or tensors, 'mask': optional tensor}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]
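With the JEPAOperator defaults, a 224-pixel image split into 16-pixel patches gives a 14 by 14 grid of 196 patches, and mask_ratio=0.75 hides 147 of them. The sketch below reproduces that count and draws a random mask in plain Python; uniform random patch selection is an assumption for illustration (JEPA variants often mask contiguous blocks instead), and mask_patches is a hypothetical helper.

```python
import random
from typing import List, Optional, Tuple


def mask_patches(
    image_size: int = 224,
    patch_size: int = 16,
    mask_ratio: float = 0.75,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    """Split patch indices into (visible, masked) lists, masking a fixed fraction."""
    if rng is None:
        rng = random.Random()
    grid = image_size // patch_size  # patches per side
    num_patches = grid * grid
    num_masked = int(num_patches * mask_ratio)
    order = list(range(num_patches))
    rng.shuffle(order)
    masked = sorted(order[:num_masked])
    visible = sorted(order[num_masked:])
    return visible, masked


visible, masked = mask_patches(rng=random.Random(0))
print(len(visible), len(masked))  # 49 147
```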

class world_models.IrisOperator(seq_length=512, vocab_size=32000)[source]#

Bases: OperatorABC

Operator for Iris transformer model: formats sequences and embeddings.

Parameters:
  • seq_length (int)

  • vocab_size (int)

preprocess(inputs)[source]#

Process Iris inputs: token sequences and optional embeddings.

Expected inputs: {'tokens': list of ints or tensor, 'embeddings': optional tensor}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]
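IrisOperator formats token sequences to seq_length. A typical approach is to truncate long inputs, right-pad short ones, and return an attention mask that marks real tokens. The sketch below assumes a pad id of 0 and right padding, neither of which the docstring confirms; pad_or_truncate is a hypothetical helper.

```python
from typing import Dict, List


def pad_or_truncate(tokens: List[int], seq_length: int = 512, vocab_size: int = 32000,
                    pad_id: int = 0) -> Dict[str, List[int]]:
    """Fix a token list to seq_length and build a matching attention mask."""
    if any(not 0 <= t < vocab_size for t in tokens):
        raise ValueError("token id outside the vocabulary")
    kept = tokens[:seq_length]  # truncate overly long inputs
    padding = seq_length - len(kept)
    return {
        "tokens": kept + [pad_id] * padding,  # right-pad short inputs
        "attention_mask": [1] * len(kept) + [0] * padding,
    }


print(pad_or_truncate([5, 9, 2], seq_length=5))
# {'tokens': [5, 9, 2, 0, 0], 'attention_mask': [1, 1, 1, 0, 0]}
```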

class world_models.PlaNetOperator(state_dim=32, action_dim=4)[source]#

Bases: OperatorABC

Operator for PlaNet model preprocessing: encodes environment states and transitions.

Parameters:
  • state_dim (int)

  • action_dim (int)

preprocess(inputs)[source]#

Process PlaNet inputs: state observations and actions.

Expected inputs: {'obs': tensor or image, 'action': tensor, 'reward': float, 'done': bool}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

world_models.get_operator(name, **kwargs)[source]#

Factory function to get inference operators by name.

Parameters:
  • name (str) – One of 'dreamer', 'jepa', 'iris', 'planet'

  • **kwargs (Any) – Operator-specific configuration

Returns:

Configured OperatorABC instance

Return type:

OperatorABC

Example

>>> op = get_operator('dreamer', image_size=64, action_dim=6)
>>> processed = op.process({'image': image, 'action': action})
class world_models.RewardModel(belief_size, state_size, hidden_size, activation_function='relu')[source]#

Bases: Module

Predict scalar rewards from Dreamer latent belief and state vectors.

Implemented as an MLP used for model-based reward supervision and imagined rollout return estimation.

Parameters:
  • belief_size (int)

  • state_size (int)

  • hidden_size (int)

  • activation_function (str)

forward(belief, state)[source]#
Parameters:
  • belief (Tensor)

  • state (Tensor)

Return type:

Tensor

class world_models.ValueModel(belief_size, state_size, hidden_size, activation_function='relu')[source]#

Bases: Module

Estimate scalar value from Dreamer latent belief and state vectors.

This MLP is trained on imagined returns and used for actor/value updates.

Parameters:
  • belief_size (int)

  • state_size (int)

  • hidden_size (int)

  • activation_function (str)

forward(belief, state)[source]#
Parameters:
  • belief (Tensor)

  • state (Tensor)

Return type:

Tensor

world_models.DreamerRewardModel#

alias of RewardModel

world_models.DreamerValueModel#

alias of ValueModel

User-facing convenience APIs for TorchWM.

The lower-level modules remain available for research workflows, but this module collects the common discovery and construction paths behind small, predictable factory functions.

class world_models.api.EnvBackendSpec(name, factory_path, description='', aliases=())[source]#

Bases: NamedTuple

Metadata describing an environment backend available through make_env.

Parameters:
  • name (str)

  • factory_path (str)

  • description (str)

  • aliases (tuple[str, ...])

name: str#

Alias for field number 0

factory_path: str#

Alias for field number 1

description: str#

Alias for field number 2

aliases: tuple[str, ...]#

Alias for field number 3

class world_models.api.ModelSpec(name, import_path, config_path=None, description='', aliases=())[source]#

Bases: NamedTuple

Metadata describing a model available through create_model().

Parameters:
  • name (str)

  • import_path (str)

  • config_path (str | None)

  • description (str)

  • aliases (tuple[str, ...])

name: str#

Alias for field number 0

import_path: str#

Alias for field number 1

config_path: str | None#

Alias for field number 2

description: str#

Alias for field number 3

aliases: tuple[str, ...]#

Alias for field number 4

world_models.api.create_config(model, **overrides)[source]#

Create the default config object for model and apply overrides.

Examples

>>> cfg = create_config("dreamer", env="walker-walk", seed=7)
>>> cfg.env
'walker-walk'
Parameters:
  • model (str)

  • overrides (Any)

Return type:

Any

world_models.api.create_model(model, config=None, **overrides)[source]#

Instantiate a model or agent from a simple string name.

config is optional for models that define a config class. Keyword overrides are applied to the config when possible, otherwise they are passed directly to the underlying constructor/factory.

Examples

>>> agent = create_model("dreamer", env="walker-walk", total_steps=1000)
>>> genie = create_model("genie-small", image_size=32)
Parameters:
  • model (str)

  • config (Any | None)

  • overrides (Any)

Return type:

Any
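create_model applies keyword overrides to the config when possible and forwards the rest to the underlying constructor or factory. Assuming dataclass-style configs, that routing can be sketched as follows; split_overrides and ToyConfig are hypothetical and only illustrate the documented behavior.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Dict, Tuple


@dataclass
class ToyConfig:
    """Hypothetical config with two fields borrowed from DreamerConfig."""
    env: str = "walker-walk"
    total_steps: int = 5_000_000


def split_overrides(config: ToyConfig, overrides: Dict[str, Any]) -> Tuple[ToyConfig, Dict[str, Any]]:
    """Apply overrides that name config fields; return the rest for the constructor."""
    names = {f.name for f in fields(config)}
    config_updates = {k: v for k, v in overrides.items() if k in names}
    passthrough = {k: v for k, v in overrides.items() if k not in names}
    return replace(config, **config_updates), passthrough


cfg, extra = split_overrides(ToyConfig(), {"total_steps": 1000, "device": "cpu"})
print(cfg.total_steps, extra)  # 1000 {'device': 'cpu'}
```

Using dataclasses.replace leaves the default config object untouched, so the same defaults can be reused across several create_model calls.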

world_models.api.get_env_backend_spec(name)[source]#

Return metadata for an environment backend name or alias.

Parameters:

name (str)

Return type:

EnvBackendSpec

world_models.api.get_model_spec(name)[source]#

Return metadata for a model name or alias.

Parameters:

name (str)

Return type:

ModelSpec
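get_model_spec resolves a canonical name or any alias listed in ModelSpec.aliases. The sketch below builds a minimal alias-aware lookup over a local NamedTuple with the same fields as the documented ModelSpec; the registry entries, placeholder import paths, case-insensitive matching, and the resolve_spec name are all made up for illustration.

```python
from typing import Dict, NamedTuple, Optional, Tuple


class ModelSpec(NamedTuple):
    """Same fields and defaults as the documented world_models.api.ModelSpec."""
    name: str
    import_path: str
    config_path: Optional[str] = None
    description: str = ""
    aliases: Tuple[str, ...] = ()


# Hypothetical registry; real names come from list_models().
_REGISTRY = [
    ModelSpec("dreamer", "example.dreamer:Agent", "example.dreamer:Config", aliases=("dreamerv1",)),
    ModelSpec("genie-small", "example.genie:Genie", aliases=("genie_small",)),
]
# Map every canonical name and alias (lower-cased) to its spec.
_LOOKUP: Dict[str, ModelSpec] = {
    key.lower(): spec for spec in _REGISTRY for key in (spec.name, *spec.aliases)
}


def resolve_spec(name: str) -> ModelSpec:
    try:
        return _LOOKUP[name.lower()]
    except KeyError:
        known = ", ".join(sorted(spec.name for spec in _REGISTRY))
        raise KeyError(f"unknown model {name!r}; expected one of: {known}") from None


print(resolve_spec("DreamerV1").name)  # dreamer
```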

world_models.api.list_env_backends()[source]#

Return canonical backend names accepted by make_env().

Return type:

list[str]

world_models.api.list_envs(model=None)[source]#

List known environment ids, optionally filtered by model family.

Parameters:

model (str | None)

Return type:

list[str] | dict[str, list[str]]

world_models.api.list_models()[source]#

Return canonical model names accepted by create_model().

Return type:

list[str]

world_models.api.make_env(env_id, backend='auto', **kwargs)[source]#

Create an environment with a consistent TorchWM entry point.

Parameters:
  • env_id (str) – Environment id, XML path, Unity executable path, or backend-specific id.

  • backend (str) – One of list_env_backends(); "auto" tries TorchWM’s compatibility helper.

  • **kwargs (Any) – Backend-specific options.

Return type:

Any

Export utilities for production deployment.

The public entry point is obj.export(path, format="onnx"). Importing this module installs that method on torch.nn.Module once, so every TorchWM model (and any other nn.Module subclass) gets ONNX, TorchScript, and TensorRT export support without subclassing a TorchWM-specific base class. Agent wrappers that are not nn.Module instances can inherit ExportableAgentMixin, which uses the same resolver and exporter.

class world_models.export.DreamerPolicyExport(actor)[source]#

Bases: Module

Traceable Dreamer policy head used by the generic export resolver.

Parameters:

actor (nn.Module)

forward(features)[source]#
Parameters:

features (Tensor)

Return type:

Tensor

class world_models.export.ExportableAgentMixin[source]#

Bases: object

Mixin for non-nn.Module agents that delegates to the shared exporter.

export(path, format='onnx', *, example_inputs=None, target=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)[source]#

Export this agent or one of its contained modules for deployment.

Parameters:
  • path (str | Path)

  • format (str)

  • example_inputs (Any | None)

  • target (str | None)

  • input_names (list[str] | None)

  • output_names (list[str] | None)

  • dynamic_axes (dict[str, dict[int, str]] | None)

  • opset_version (int)

  • kwargs (Any)

Return type:

Path

class world_models.export.IRISActorCriticExport(agent)[source]#

Bases: Module

Traceable IRIS policy/value head used by the generic export resolver.

Parameters:

agent (Any)

forward(frames)[source]#
Parameters:

frames (Tensor)

Return type:

tuple[Tensor, Tensor]

world_models.export.export_any(obj, path, format='onnx', *, example_inputs=None, target=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)[source]#

Export any TorchWM model/agent or a target module contained by it.

Parameters:
  • obj (Any)

  • path (str | Path)

  • format (str)

  • example_inputs (Any | None)

  • target (str | None)

  • input_names (list[str] | None)

  • output_names (list[str] | None)

  • dynamic_axes (dict[str, dict[int, str]] | None)

  • opset_version (int)

  • kwargs (Any)

Return type:

Path

world_models.export.export_model(module, path, format='onnx', *, example_inputs=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)[source]#

Export a torch.nn.Module to ONNX, TorchScript, or TensorRT.

Parameters:
  • module (Module)

  • path (str | Path)

  • format (str)

  • example_inputs (Any | None)

  • input_names (list[str] | None)

  • output_names (list[str] | None)

  • dynamic_axes (dict[str, dict[int, str]] | None)

  • opset_version (int)

  • kwargs (Any)

Return type:

Path

world_models.export.install_export_method()[source]#

Install torch.nn.Module.export once for every Torch model class.

Return type:

None

Models sub-module - Core world model implementations.

Exported Components:
Agents (High-level training wrappers):
  • DreamerAgent: High-level Dreamer training API

  • JEPAAgent: JEPA agent for self-supervised learning

  • Planet: PlaNet planning agent

  • VisionTransformer: Vision Transformer for image encoding

  • ModularRSSM: Modular RSSM with swappable components

  • Genie: Generative Interactive Environment model

Core Models:
  • Dreamer: Core Dreamer implementation with RSSM, actor, critic

  • RSSM: Recurrent State-Space Model (Dreamer-style)

  • RecurrentStateSpaceModel: PlaNet-style RSSM

  • LatentActionModel: Latent action learning for Genie

  • DynamicsModel: Future frame prediction for Genie

Factory Functions:
  • create_genie, create_genie_small, create_genie_large

  • create_modular_rssm

  • create_latent_action_model, create_dynamics_model

Small, import-safe catalog of available environments and backends.

This module replaces the previous world_models.ui.catalog and is safe to import from lightweight CLI tools and tests without pulling in any UI dependencies.

Model catalog#

Core model families#

Key classes: Dreamer, DreamerAgent, RSSM, RecurrentStateSpaceModel, Planet, ModularRSSM, JEPAAgent, VisionTransformer, IRISAgent, IRISTransformer, IRISWorldModel, Genie, LatentActionModel, and DynamicsModel.

world_models.models.dreamer.get_available_memory()[source]#

Get available physical memory in bytes.

Return type:

int

world_models.models.dreamer.make_env(args)[source]#

Construct a Dreamer-compatible environment from DreamerConfig options.

Supports DMC, DMLab, Gym/Gymnasium, MuJoCo, Gymnasium Robotics, Procgen, Brax, BSuite, and Unity ML-Agents backends and applies the standard wrapper stack: action repeat, action normalization, and time limit.

Parameters:

args (Any)

Return type:

Any

world_models.models.dreamer.preprocess_obs(obs)[source]#

Convert raw uint8 image observations to Dreamer float input space.

Images are scaled from [0, 255] to roughly [-0.5, 0.5], matching the normalization expected by Dreamer encoders.

Parameters:

obs (Tensor)

Return type:

Tensor
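Assuming the common Dreamer convention of dividing by 255 and subtracting 0.5, the mapping works out as follows. The sketch uses plain Python on a list of pixel values; TorchWM's exact transform (for example, any added noise or dtype handling) may differ.

```python
def preprocess_obs(pixels):
    """Map uint8 pixel values in [0, 255] to floats in [-0.5, 0.5].

    Assumes the x / 255 - 0.5 convention; this is a sketch, not TorchWM's code.
    """
    out = []
    for p in pixels:
        if not 0 <= p <= 255:
            raise ValueError(f"expected a uint8 value, got {p!r}")
        out.append(p / 255.0 - 0.5)
    return out


print(preprocess_obs([0, 128, 255]))  # midpoint 128 lands just above 0.0
```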

class world_models.models.dreamer.Dreamer(args, obs_shape, action_size, device, restore=False)[source]#

Bases: object

Core Dreamer training system combining world model, actor, and value nets.

This class owns model construction, replay sampling, imagination rollouts, loss computation, optimization steps, evaluation loops, and checkpoint I/O.

Parameters:
  • args (Any)

  • obs_shape (Any)

  • action_size (int)

  • device (device | str)

  • restore (bool)

classmethod from_config(config=None, *, obs_shape=None, action_size=None, device=None, restore=None, **overrides)[source]#

Build a core Dreamer model from a config object, dict, or YAML file.

obs_shape and action_size may be supplied directly. When either is omitted, this method constructs a temporary environment from the config to infer the model shapes.

Parameters:
  • config (DreamerConfig | dict[str, Any] | str | Path | None)

  • obs_shape (tuple[int, ...] | None)

  • action_size (int | None)

  • device (str | device | None)

  • restore (bool | None)

  • overrides (Any)

Return type:

Dreamer

classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, map_location=None, **overrides)[source]#

Load a Dreamer checkpoint from a local path/directory or the HF Hub.

Parameters:
  • pretrained_model_name_or_path (str | Path)

  • config (DreamerConfig | dict[str, Any] | str | Path | None)

  • checkpoint_filename (str | None)

  • config_filename (str)

  • repo_type (str | None)

  • revision (str | None)

  • map_location (str | device | None)

  • overrides (Any)

Return type:

Dreamer

parameter_count(trainable_only=False)[source]#

Return the total number of parameters owned by the Dreamer modules.

Parameters:

trainable_only (bool)

Return type:

int

summary()[source]#

Return a compact parameter-count summary for the Dreamer modules.

Return type:

dict[str, Any]

world_model_loss(obs, acs, rews, nonterms)[source]#
Parameters:
  • obs (Tensor)

  • acs (Tensor)

  • rews (Tensor)

  • nonterms (Tensor)

Return type:

Tensor

actor_loss()[source]#
Return type:

Tensor

value_loss()[source]#
Return type:

Tensor

train_one_batch()[source]#
Return type:

list[float]

act_with_world_model(obs, prev_state, prev_action, explore=False)[source]#
Parameters:
  • obs (Any)

  • prev_state (Any)

  • prev_action (Tensor)

  • explore (bool)

Return type:

tuple

act_and_collect_data(env, collect_steps)[source]#
Parameters:
  • env (Any)

  • collect_steps (int)

Return type:

ndarray

evaluate(env, eval_episodes, render=False)[source]#
Parameters:
  • env (Any)

  • eval_episodes (int)

  • render (bool)

Return type:

tuple

collect_random_episodes(env, seed_steps)[source]#
Parameters:
  • env (Any)

  • seed_steps (int)

Return type:

ndarray

save(save_path)[source]#
Parameters:

save_path (str)

Return type:

None

restore_checkpoint(ckpt_path, map_location=None)[source]#
Parameters:
  • ckpt_path (str | Path)

  • map_location (Any)

Return type:

None

class world_models.models.dreamer.DreamerAgent(config=None, **kwargs)[source]#

Bases: ExportableAgentMixin

High-level user API for running Dreamer experiments end to end.

It builds environments from config, initializes seeds and logging, instantiates Dreamer, and exposes simple train() / evaluate() methods.

Parameters:
  • config (Any)

  • kwargs (Any)

classmethod from_config(config=None, **overrides)[source]#

Build a high-level Dreamer agent from a config object, dict, or YAML file.

Parameters:
  • config (DreamerConfig | dict[str, Any] | str | Path | None)

  • overrides (Any)

Return type:

DreamerAgent

classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, map_location=None, **overrides)[source]#

Create a Dreamer agent and restore weights from a local path or HF Hub.

Parameters:
  • pretrained_model_name_or_path (str | Path)

  • config (DreamerConfig | dict[str, Any] | str | Path | None)

  • checkpoint_filename (str | None)

  • config_filename (str)

  • repo_type (str | None)

  • revision (str | None)

  • map_location (str | device | None)

  • overrides (Any)

Return type:

DreamerAgent

parameter_count(trainable_only=False)[source]#

Return the total number of Dreamer parameters.

Parameters:

trainable_only (bool)

Return type:

int

summary()[source]#

Return a compact parameter-count summary for the wrapped Dreamer model.

Return type:

dict[str, Any]

train(total_steps=None)[source]#
Parameters:

total_steps (int | None)

Return type:

None

evaluate()[source]#
Return type:

tuple

class world_models.models.dreamer_rssm.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)[source]#

Bases: Module

Recurrent State-Space Model used by Dreamer for latent dynamics learning.

The RSSM is the core world model component that learns compact representations of environment dynamics. It maintains a hybrid state consisting of:

  1. Deterministic State (h) – A recurrent hidden state updated by a GRU, capturing sequential/temporal information and deterministic transitions.

  2. Stochastic State (s) – A latent variable representing stochastic, multi-modal uncertainty in the environment (e.g., ambiguous observations).

The model operates in two modes:

  • Observe Mode – Updates states using actual observations from the environment. Uses the representation model: p(s_t | h_t, obs_t)

  • Imagine Mode – Predicts future states without observations. Uses the transition/prior model: p(s_t | h_t)

Architecture

  • Input: Previous state (h_{t-1}, s_{t-1}) and action a_{t-1}

  • Process: GRU updates deterministic state, MLP computes stochastic prior/posterior

  • Output: Updated state (h_t, s_t) and distributions

State Representation

  • deter (h): GRU hidden state, captures sequential context

  • stoch (s): Stochastic latent, multi-modal uncertainty

  • mean/std: Parameters of the stochastic distribution

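Usage example:

import torch
from world_models.models.dreamer_rssm import RSSM

batch_size, action_dim = 16, 6
rssm = RSSM(
    action_size=action_dim,
    stoch_size=30,       # Stochastic state dimension
    deter_size=200,      # Deterministic (GRU) state dimension
    hidden_size=200,     # MLP hidden layer size
    obs_embed_size=256,  # Observation embedding size from the encoder
    activation='elu',
)
prev_state = rssm.init_state(batch_size, torch.device('cpu'))
prev_action = torch.zeros(batch_size, action_dim)
obs_embed = torch.randn(batch_size, 256)

# Observe: returns (posterior, prior) for the current step
posterior, prior = rssm.observe_step(prev_state, prev_action, obs_embed)

# Imagine one step ahead from the posterior, without an observation
action = torch.zeros(batch_size, action_dim)
next_prior = rssm.imagine_step(posterior, action)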

Training

The RSSM is trained by maximizing the ELBO (Evidence Lower Bound):

  • KL divergence between prior and posterior encourages the prior to capture environment dynamics

  • Reconstruction loss from the decoder ensures the state captures observation information
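For reference, the Dreamer world-model objective for a sequence of length T can be written as below. To keep the two distributions distinct, q denotes the observation-informed posterior (the representation model) and p the transition prior. The reward term reflects that world_model_loss also receives rewards. Any KL scaling or free-nats clipping that TorchWM applies is not shown.

```latex
\mathcal{J} = \sum_{t=1}^{T} \Big(
  \mathbb{E}_{q(s_t \mid h_t, o_t)}\big[\ln p(o_t \mid h_t, s_t) + \ln p(r_t \mid h_t, s_t)\big]
  - \mathrm{KL}\big[\, q(s_t \mid h_t, o_t) \,\big\|\, p(s_t \mid h_t) \,\big]
\Big)
```

Training maximizes J, which is equivalent to minimizing -J as a loss.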

Reference:

Dreamer: Scalable Reinforcement Learning Using World Models Hafner et al., 2020 - https://arxiv.org/abs/1912.01603

Parameters:
  • action_size (int)

  • stoch_size (int)

  • deter_size (int)

  • hidden_size (int)

  • obs_embed_size (int)

  • activation (str)

init_state(batch_size, device)[source]#

Initialize RSSM state with zeros.

Parameters:
  • batch_size (int) – Number of parallel sequences

  • device (device) – torch device for tensors

Returns:

Dictionary containing zero-initialized state components:

  • mean, std: Stochastic distribution parameters

  • stoch: Stochastic state sample

  • deter: Deterministic GRU hidden state

Return type:

dict

get_dist(mean, std)[source]#

Create an Independent Normal distribution from mean and std.

Parameters:
  • mean (Tensor) – Location parameter

  • std (Tensor) – Scale parameter

Returns:

Independent Normal distribution with given parameters

Return type:

Independent
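An Independent Normal treats each dimension as an independent Gaussian and sums the per-dimension terms, which is why the KL term in the training objective has a cheap closed form. The pure-Python sketch below computes both quantities. In PyTorch the same values come from torch.distributions.Independent(Normal(mean, std), 1) and torch.distributions.kl_divergence; whether get_dist is built exactly that way is an assumption.

```python
import math


def independent_normal_log_prob(x, mean, std):
    """log N(x; mean, diag(std^2)): per-dimension log-densities, summed."""
    total = 0.0
    for xi, mu, sigma in zip(x, mean, std):
        z = (xi - mu) / sigma
        total += -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)
    return total


def independent_normal_kl(mean_q, std_q, mean_p, std_p):
    """Closed-form KL(q || p) between diagonal Gaussians, summed over dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(mean_q, std_q, mean_p, std_p):
        total += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
    return total


# Posterior shifted by one unit from a standard-normal prior in one of two dims.
print(independent_normal_kl([1.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0]))  # 0.5
```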

observe_step(prev_state, prev_action, obs_embed, nonterm=tensor(1.))[source]#

Update state using actual observation (observe mode).

In observe mode, the RSSM first computes a transition prior from the previous state and action, then refines the stochastic state using the actual observation embedding to form the posterior.

Parameters:
  • prev_state (dict) – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})

  • prev_action (Tensor) – Previous action a_{t-1}, shape (B, action_size)

  • obs_embed (Tensor) – Observation embedding from encoder, shape (B, obs_embed_size)

  • nonterm (Tensor) – Termination mask (1.0 = continue, 0.0 = terminal)

Returns:

A tuple (posterior, prior) of state dictionaries. The posterior incorporates observation information; the prior is the transition prediction before observation. Both share the same deterministic state because the GRU is only advanced once per timestep.

Return type:

Tuple[dict, dict]

imagine_step(prev_state, prev_action, nonterm=tensor(1.))[source]#

Predict next state without observation (imagine mode).

In imagine mode, the RSSM predicts future states using only the prior distribution. This is used for planning and policy learning where actual observations are not available.

Parameters:
  • prev_state (dict) – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})

  • prev_action (Tensor) – Previous action a_{t-1}, shape (B, action_size)

  • nonterm (Tensor) – Termination mask (1.0 = continue, 0.0 = terminal)

Returns:

Dictionary with the predicted state, containing:

  • deter: Predicted deterministic state

  • mean, std, stoch: Prior stochastic state distribution

Return type:

dict

get_prior(prev_state, prev_action, nonterm=tensor(1.))[source]#
Parameters:
  • prev_state (dict)

  • prev_action (Tensor)

  • nonterm (Tensor)

Return type:

dict

get_posterior(prev_state, prev_action, obs_embed, nonterm=tensor(1.))[source]#

Compute posterior distribution over stochastic state.

The posterior incorporates observation information to produce a more accurate state estimate.

Parameters:
  • prev_state (dict) – Previous state dictionary

  • prev_action (Tensor) – Previous action

  • obs_embed (Tensor) – Observation embedding

  • nonterm (Tensor) – Termination mask

Returns:

Dictionary with posterior state (observation-informed). Note that the previous-state shape (B, ...) is preserved; the batch dimension is not flattened.

Return type:

dict

detach_state(state)[source]#

Detach state tensors from computation graph.

Used during DreamerV2 training to prevent gradient flow through the observation/update pathway.

Parameters:

state (dict) – State dictionary with tensor values

Returns:

Detached state dictionary

Return type:

dict

seq_to_batch(state_dict)[source]#

Convert sequence state to batch format.

Parameters:

state_dict (dict) – Dictionary with sequence-dimension tensors (T, B, …)

Returns:

Dictionary with batch-dimension tensors (B*T, …)

Return type:

dict

observe_rollout(obs_embed, actions, nonterms, init_state, seq_len)[source]#

Process a sequence of observations (observe mode rollout).

At each timestep we run observe_step once to obtain the transition prior (the prediction given the previous state and action) and the observation-informed posterior. The posterior is then used as the previous state for the next step, matching the standard Dreamer inference pattern.

Parameters:
  • obs_embed (Tensor) – Observation embeddings, shape (T+1, B, obs_embed_size)

  • actions (Tensor) – Actions, shape (T, B, action_size)

  • nonterms (Tensor) – Non-termination flags, shape (T, B, 1)

  • init_state (dict) – Initial state dictionary

  • seq_len (int) – Sequence length T

Returns:

A tuple (prior, posterior):

  • prior: Dictionary with prior states stacked along the time axis

  • posterior: Dictionary with posterior states stacked along the time axis

Return type:

tuple[dict, dict]
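The posterior-feeds-forward pattern described for observe_rollout can be sketched with a toy scalar step function. Here toy_observe_step is hypothetical. The alignment, in which step t consumes actions[t] and obs_embed[t + 1] while obs_embed[0] belongs to the initial frame, is an assumption inferred from the documented T+1 and T shapes.

```python
def stack_states(states):
    """Turn a list of per-step state dicts into one dict of per-key time series."""
    return {key: [s[key] for s in states] for key in states[0]}


def observe_rollout(observe_step, obs_embed, actions, nonterms, init_state, seq_len):
    """Run observe_step for seq_len steps, seeding each step with the last posterior."""
    priors, posteriors = [], []
    prev_state = init_state
    for t in range(seq_len):
        # Assumed alignment: obs_embed[0] belongs to the initial frame.
        posterior, prior = observe_step(prev_state, actions[t], obs_embed[t + 1], nonterms[t])
        priors.append(prior)
        posteriors.append(posterior)
        prev_state = posterior  # the posterior, not the prior, is carried forward
    return stack_states(priors), stack_states(posteriors)


def toy_observe_step(prev_state, action, embed, nonterm):
    """Hypothetical scalar dynamics standing in for RSSM.observe_step."""
    deter = nonterm * prev_state["stoch"] + action  # shared by prior and posterior
    prior = {"deter": deter, "stoch": deter}  # prediction without the observation
    posterior = {"deter": deter, "stoch": 0.5 * (deter + embed)}  # blends in the observation
    return posterior, prior


prior, posterior = observe_rollout(
    toy_observe_step, [0.0, 2.0, 4.0], [1.0, 1.0], [1.0, 1.0], {"deter": 0.0, "stoch": 0.0}, 2
)
```

In the second step the deterministic value is 2.5 rather than 2.0 because it builds on the first posterior (1.5), not the first prior (1.0).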

imagine_rollout(policy, init_state, horizon)[source]#

Generate imagined trajectory using policy (imagine mode rollout).

Parameters:
  • policy (Module) – Actor network that outputs actions from state features

  • init_state (dict) – Initial state dictionary

  • horizon (int) – Number of steps to imagine

Returns:

Dictionary with imagined states for each step

Return type:

dict
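The imagine loop differs from the observe loop in two ways: actions come from the policy, and no observation is consumed. Below is a minimal sketch with hypothetical scalar stand-ins. It does not establish whether TorchWM also includes init_state in the returned trajectory, or whether features are detached before the policy is called.

```python
def imagine_rollout(policy, imagine_step, init_state, horizon):
    """Roll the prior forward for `horizon` steps with actions chosen by `policy`."""
    if horizon < 1:
        raise ValueError("horizon must be at least 1")
    states, state = [], init_state
    for _ in range(horizon):
        action = policy(state)  # the actor acts on the current latent state
        state = imagine_step(state, action)  # prior-only transition, no observation
        states.append(state)
    return {key: [s[key] for s in states] for key in states[0]}


# Hypothetical stand-ins: a constant policy and additive scalar dynamics.
trajectory = imagine_rollout(
    policy=lambda s: 1.0,
    imagine_step=lambda s, a: {"deter": s["deter"] + a},
    init_state={"deter": 0.0},
    horizon=3,
)
```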

forward(x, u)[source]#

Forward pass for training (computes sequence of states).

Parameters:
  • x (Tensor) – Observations, shape (B, T+1, C, H, W)

  • u (Tensor) – Actions, shape (B, T, action_size)

Returns:

A tuple (states, priors, posteriors):

  • states: List of state dictionaries, one per timestep

  • priors: List of prior distributions (tuples of mean, std)

  • posteriors: List of posterior distributions (tuples of mean, std)

Return type:

tuple

class world_models.models.rssm.RecurrentStateSpaceModel(action_size, state_size=200, latent_size=30, hidden_size=200, embed_size=1024, activation_function='relu')[source]#

Bases: Module

A Recurrent State Space Model (RSSM) for modeling latent dynamics in sequential data.

Parameters:
  • action_size (int)

  • state_size (int)

  • latent_size (int)

  • hidden_size (int)

  • embed_size (int)

  • activation_function (str)

get_init_state(enc, h_t=None, s_t=None, a_t=None, mean=False)[source]#

Returns the initial posterior given the observation.

Parameters:
  • enc (Tensor)

  • h_t (Tensor | None)

  • s_t (Tensor | None)

  • a_t (Tensor | None)

  • mean (bool)

Return type:

tuple[Tensor, Tensor]

deterministic_state_fwd(h_t, s_t, a_t)[source]#

Deterministic transition update.

Ensures a_t is 2D and matches batch dimension of h_t before concatenation. Accepts a_t shaped [B, action_size], [action_size] (expanded to [B, action_size]), or [B]/scalar (reshaped appropriately).

Parameters:
  • h_t (Tensor)

  • s_t (Tensor)

  • a_t (Tensor)

Return type:

Tensor

state_prior(h_t, sample=False)[source]#

Returns the prior distribution over the latent state given the deterministic state.

Parameters:
  • h_t (Tensor)

  • sample (bool)

Return type:

tuple[Tensor, Tensor] | Tensor

state_posterior(h_t, e_t, sample=False)[source]#

Returns the posterior distribution over the latent state given the deterministic state and the observation embedding.

Parameters:
  • h_t (Tensor)

  • e_t (Tensor)

  • sample (bool)

Return type:

tuple[Tensor, Tensor] | Tensor

pred_reward(h_t, s_t)[source]#
Parameters:
  • h_t (Tensor)

  • s_t (Tensor)

Return type:

Tensor

rollout_prior(act, h_t, s_t)[source]#
Parameters:
  • act (Tensor)

  • h_t (Tensor)

  • s_t (Tensor)

Return type:

tuple[Tensor, Tensor]

forward(x, u)[source]#

Forward through the RSSM for a batch of sequences.

Parameters:
  • x (Tensor) – Tensor [B, T+1, C, H, W] (observations including initial frame)

  • u (Tensor) – Tensor [B, T, action_size] (actions for T steps)

Returns:

A tuple (states, priors, posteriors):

  • states: list[T] of tensors, each [B, state_size]

  • priors: list[T] of tuples (mean, std), each [B, latent_size]

  • posteriors: list[T] of tuples (mean, std), each [B, latent_size]

Return type:

tuple

class world_models.models.planet.Planet(env, bit_depth=5, device=None, state_size=200, latent_size=30, embedding_size=1024, memory_size=100, policy_cfg=None, headless=False, max_episode_steps=None, action_repeats=1, results_dir=None)[source]#

Bases: ExportableAgentMixin

High-level Planet wrapper.

Usage example:

from world_models.models.planet import Planet

p = Planet(env='CartPole-v1', bit_depth=5)
p.train(epochs=50)

Parameters:
  • env (Any)

  • bit_depth (int)

  • device (device | None)

  • state_size (int)

  • latent_size (int)

  • embedding_size (int)

  • memory_size (int)

  • policy_cfg (dict | None)

  • headless (bool)

  • max_episode_steps (int | None)

  • action_repeats (int)

  • results_dir (str | None)
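bit_depth is not described above. In the PlaNet paper and its widely used PyTorch reference code, 8-bit pixels are quantized to a lower bit depth (5 by default), centred on zero, and dequantized with uniform noise. Assuming Planet follows that convention, a single pixel would be processed as in this sketch. The function name quantize_pixel is hypothetical.

```python
import math
import random


def quantize_pixel(p, bit_depth=5, rng=None):
    """Quantize an 8-bit pixel to `bit_depth` bits and centre it on zero.

    Mirrors PlaNet-style preprocessing: floor(p / 2**(8 - bit_depth)) / 2**bit_depth
    - 0.5, then optional uniform dequantization noise of width 1 / 2**bit_depth.
    """
    levels = 2 ** bit_depth
    x = math.floor(p / 2 ** (8 - bit_depth)) / levels - 0.5
    if rng is not None:
        x += rng.random() / levels
    return x


rng = random.Random(0)
print([quantize_pixel(p) for p in (0, 7, 8, 255)])  # 0 and 7 share a bucket at 5 bits
print(quantize_pixel(255, rng=rng))
```

Lower bit depths discard fine intensity detail that a reconstruction loss would otherwise spend capacity on.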

warmup(n_episodes=1, random_policy=True)[source]#

Collect n_episodes of rollouts into memory (used as warmup).

Parameters:
  • n_episodes (int)

  • random_policy (bool)

Return type:

None

train(epochs=100, steps_per_epoch=150, batch_size=32, H=50, beta=1.0, save_every=25, record_grads=False, results_dir=None, scheduler_type='step', scheduler_kwargs=None)[source]#

High-level training loop. Delegates single-step training to the existing train function.

Parameters:
  • scheduler_type (str) – Type of scheduler to use (“step”, “cosine”, “exponential”, “plateau”, None)

  • scheduler_kwargs (dict) – Additional arguments for the scheduler

  • epochs (int)

  • steps_per_epoch (int)

  • batch_size (int)

  • H (int)

  • beta (float)

  • save_every (int)

  • record_grads (bool)

  • results_dir (str | None)

Return type:

str
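The scheduler_type names match common PyTorch schedulers. Assuming they map to StepLR, ExponentialLR, CosineAnnealingLR, and ReduceLROnPlateau, the first three have closed forms that show how the learning rate evolves. The helper lr_at and its default values are illustrative only and are not TorchWM's.

```python
import math


def lr_at(epoch, base_lr, scheduler_type="step", **kw):
    """Closed-form learning rate after `epoch` scheduler steps (sketch).

    Formulas follow PyTorch's StepLR, ExponentialLR and CosineAnnealingLR; the
    defaults below are illustrative, not TorchWM's. "plateau" reacts to a
    validation metric, so it has no closed form and is rejected here.
    """
    if scheduler_type is None:
        return base_lr
    if scheduler_type == "step":
        return base_lr * kw.get("gamma", 0.1) ** (epoch // kw.get("step_size", 30))
    if scheduler_type == "exponential":
        return base_lr * kw.get("gamma", 0.95) ** epoch
    if scheduler_type == "cosine":
        t_max, eta_min = kw.get("T_max", 100), kw.get("eta_min", 0.0)
        return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2
    raise ValueError(f"no closed form for scheduler_type={scheduler_type!r}")


print([lr_at(e, 1e-3, "step", step_size=25, gamma=0.5) for e in (0, 24, 25, 50)])
```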

Mixture Density Recurrent Neural Network (MDRNN) model implementation.

This module provides implementations of MDRNN models for world modeling. The MDRNN is used to predict future latent states given current latent states and actions, using a Gaussian Mixture Model (GMM) for the output.

Reference:

Ha & Schmidhuber (2018). Recurrent World Models Facilitate Policy Evolution. https://arxiv.org/abs/1805.11111

class world_models.models.mdrnn.MDRNN(latents, actions, hiddens, gaussians)[source]#

Bases: _MDRNNBase

MDRNN model for multi-step sequence prediction.

This model processes entire sequences of latent states and actions, predicting the next latent state using a Gaussian Mixture Model (GMM). It also predicts rewards and terminal states.

Parameters:
  • latents (int) – Dimensionality of latent space (input and output).

  • actions (int) – Dimensionality of action space.

  • hiddens (int) – Number of hidden units in LSTM.

  • gaussians (int) – Number of Gaussian components in GMM output.

Example

>>> import torch
>>> from world_models.models.mdrnn import MDRNN
>>> mdrnn = MDRNN(latents=32, actions=3, hiddens=256, gaussians=5)
>>> actions = torch.randn(10, 4, 3)  # seq_len, batch, action
>>> latents = torch.randn(10, 4, 32)  # seq_len, batch, latent
>>> mus, sigmas, logpi, rs, ds = mdrnn(actions, latents)
>>> # mus.shape = (10, 4, 5, 32)

forward(actions, latents)[source]#

Multi-step forward pass through the MDRNN.

Parameters:
  • actions (Tensor) – (SEQ_LEN, BSIZE, ASIZE) Tensor of actions.

  • latents (Tensor) – (SEQ_LEN, BSIZE, LSIZE) Tensor of latent states.

Returns:

Tuple (mus, sigmas, logpi, rs, ds), where:

  • mus: (SEQ_LEN, BSIZE, N_GAUSS, LSIZE) GMM means

  • sigmas: (SEQ_LEN, BSIZE, N_GAUSS, LSIZE) GMM standard deviations

  • logpi: (SEQ_LEN, BSIZE, N_GAUSS) log GMM weights

  • rs: (SEQ_LEN, BSIZE) predicted rewards

  • ds: (SEQ_LEN, BSIZE) predicted terminal state logits

Return type:

tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
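In Ha & Schmidhuber's setup, the mixture outputs are trained by minimizing the negative log-likelihood of the observed next latent under the GMM. The sketch below computes that quantity for a single (timestep, batch) slice in plain Python. Whether TorchWM normalizes by latent size or adds the reward and terminal losses on top is not shown here.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def gmm_nll(x, mus, sigmas, logpi):
    """Negative log-likelihood of one latent vector x under a diagonal GMM.

    mus[k] and sigmas[k] are the mean and std vectors of component k and
    logpi[k] its log mixture weight, i.e. one (timestep, batch) slice of the
    mus / sigmas / logpi outputs.
    """
    per_component = []
    for mu, sigma, lp in zip(mus, sigmas, logpi):
        log_n = sum(
            -0.5 * ((xi - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
            for xi, m, s in zip(x, mu, sigma)
        )
        per_component.append(lp + log_n)
    return -logsumexp(per_component)
```

Working in log space and combining components with logsumexp avoids underflow when the per-dimension densities are tiny.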

get_init_hidden(batch_size=1)[source]#

Return initial hidden state for the LSTM.

Parameters:

batch_size (int) – Number of sequences in the batch.

Returns:

Tuple of (h, c) with shapes (batch_size, hiddens).

Return type:

tuple[Tensor, Tensor]

class world_models.models.mdrnn.MDRNNCell(latents, actions, hiddens, gaussians)[source]#

Bases: _MDRNNBase

MDRNN model for single-step forward prediction.

This model processes a single step of latent state and action, predicting the next latent state using a Gaussian Mixture Model (GMM). It also predicts rewards and terminal states. Useful for real-time inference.

Parameters:
  • latents (int) – Dimensionality of latent space (input and output).

  • actions (int) – Dimensionality of action space.

  • hiddens (int) – Number of hidden units in LSTMCell.

  • gaussians (int) – Number of Gaussian components in GMM output.

Example

>>> import torch
>>> from world_models.models.mdrnn import MDRNNCell
>>> cell = MDRNNCell(latents=32, actions=3, hiddens=256, gaussians=5)
>>> action = torch.randn(4, 3)  # batch, action
>>> latent = torch.randn(4, 32)  # batch, latent
>>> hidden = (torch.randn(4, 256), torch.randn(4, 256))
>>> mus, sigmas, logpi, r, d, next_hidden = cell(action, latent, hidden)

forward(action, latent, hidden)[source]#

Single-step forward pass through the MDRNN cell.

Parameters:
  • action (Tensor) – (BSIZE, ASIZE) Tensor of actions for current batch.

  • latent (Tensor) – (BSIZE, LSIZE) Tensor of latent states for current batch.

  • hidden (Tuple[Tensor, Tensor]) – Tuple of (h, c) hidden states for LSTMCell.

Returns:

Tuple (mus, sigmas, logpi, r, d, next_hidden), where:

  • mus: (BSIZE, N_GAUSS, LSIZE) GMM means

  • sigmas: (BSIZE, N_GAUSS, LSIZE) GMM standard deviations

  • logpi: (BSIZE, N_GAUSS) log GMM weights

  • r: (BSIZE,) predicted rewards

  • d: (BSIZE,) predicted terminal state logits

  • next_hidden: Tuple of (h, c) next hidden states

Return type:

tuple
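For real-time inference, the next latent is usually drawn from the predicted mixture instead of being evaluated. The sketch first picks a component with probability exp(logpi[k]) and then samples each dimension from that component's Gaussian. It is a plain-Python sketch for one batch element. The temperature parameter used in Ha & Schmidhuber's dream experiments is omitted.

```python
import math
import random


def sample_gmm(mus, sigmas, logpi, rng):
    """Draw one latent vector from a diagonal GMM (single-step inference sketch).

    Component k is chosen with probability exp(logpi[k]); each dimension is
    then drawn from Normal(mus[k][d], sigmas[k][d]).
    """
    weights = [math.exp(lp) for lp in logpi]
    k = rng.choices(range(len(weights)), weights=weights)[0]
    return [rng.gauss(m, s) for m, s in zip(mus[k], sigmas[k])]


rng = random.Random(0)
z_next = sample_gmm([[0.0, 0.0], [3.0, 3.0]], [[1.0, 1.0], [0.1, 0.1]],
                    [math.log(0.7), math.log(0.3)], rng)
```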

get_init_hidden(batch_size=1)[source]#
Parameters:

batch_size (int)

Return type:

Tuple[Tensor, Tensor]

Linear Controller for World Models.

This module provides a simple linear controller that maps latent states and recurrent hidden states to actions. The controller is trained using CMA-ES (Covariance Matrix Adaptation Evolution Strategy).

Reference:

Ha & Schmidhuber (2018). Recurrent World Models Facilitate Policy Evolution. https://arxiv.org/abs/1805.11111

class world_models.models.controller.Controller(latent_size, hidden_size, action_size)[source]#

Bases: Module

Linear controller that maps latent + hidden state to actions.

This is a simple linear controller that takes the latent state and recurrent hidden state as input and outputs actions. It is trained separately from the world model using black-box optimization (CMA-ES).

Parameters:
  • latent_size (int)

  • hidden_size (int)

  • action_size (int)

latent_size#

Dimensionality of latent state from VAE.

hidden_size#

Dimensionality of the recurrent hidden state (e.g., from the MDRNN).

action_size#

Dimensionality of action space.

Example

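>>> import torch
>>> from world_models.models.controller import Controller
>>> controller = Controller(latent_size=32, hidden_size=200, action_size=3)
>>> latent, hidden = torch.randn(4, 32), torch.randn(4, 200)  # batch of 4
>>> state = torch.cat([latent, hidden], dim=-1)
>>> action = controller(state)  # shape (4, 3)
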
forward(state)[source]#

Compute actions from latent and hidden states.

Parameters:

state (Tensor) – Concatenated [latent, hidden] state tensor.

Returns:

Action tensor of shape (batch, action_size).

Return type:

Tensor
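Because CMA-ES is a black-box optimizer, it treats the controller as a single flat parameter vector. The stdlib-only sketch below shows that view: how many parameters a linear controller has, and how a flat vector is applied to [latent, hidden]. The function names are hypothetical, and the tanh squashing is an assumption that may not match Controller.

```python
import math


def controller_param_count(latent_size, hidden_size, action_size):
    """Weights plus biases of one linear layer mapping [z, h] to actions."""
    return (latent_size + hidden_size) * action_size + action_size


def controller_forward(flat_params, z, h, action_size):
    """Apply a linear controller stored as one flat parameter vector.

    A black-box optimizer such as CMA-ES searches directly over flat_params.
    The tanh squashing is an assumption, not necessarily what Controller does.
    """
    x = list(z) + list(h)
    n_in = len(x)
    if len(flat_params) != n_in * action_size + action_size:
        raise ValueError("flat_params has the wrong length")
    weights, bias = flat_params[: n_in * action_size], flat_params[n_in * action_size:]
    actions = []
    for a in range(action_size):
        row = weights[a * n_in:(a + 1) * n_in]
        actions.append(math.tanh(sum(w * xi for w, xi in zip(row, x)) + bias[a]))
    return actions


# Ha & Schmidhuber's CarRacing sizes: 32-d z, 256-d h, 3 actions.
print(controller_param_count(32, 256, 3))  # 867
```

Keeping the controller this small is what makes population-based search like CMA-ES tractable.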

Modular RSSM with swappable encoder/decoder/backbone components.

This module provides a flexible architecture for world model research, allowing researchers to easily swap different encoder, decoder, and backbone implementations for ablations and experimentation.
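The swap-in design depends on every component honouring a shared abstract interface. The stdlib-only sketch below illustrates the idea with hypothetical Encoder, IdentityEncoder, ScaledEncoder, and WorldModel classes. TorchWM's real base classes are the torch Modules EncoderBase, DecoderBase, and BackboneBase.

```python
from abc import ABC, abstractmethod


class Encoder(ABC):
    """Hypothetical stand-in for EncoderBase: every encoder reports embed_size."""

    embed_size: int

    @abstractmethod
    def encode(self, obs):
        """Map an observation to a list of embed_size floats."""


class IdentityEncoder(Encoder):
    def __init__(self, dim):
        self.embed_size = dim

    def encode(self, obs):
        return [float(o) for o in obs]


class ScaledEncoder(Encoder):
    def __init__(self, dim, scale):
        self.embed_size, self.scale = dim, scale

    def encode(self, obs):
        return [self.scale * o for o in obs]


class WorldModel:
    """Depends only on the Encoder interface, so encoders can be swapped freely."""

    def __init__(self, encoder: Encoder):
        self.encoder = encoder

    def embed(self, obs):
        out = self.encoder.encode(obs)
        if len(out) != self.encoder.embed_size:
            raise ValueError("encoder returned the wrong embedding size")
        return out


# An ablation then amounts to looping over interchangeable components.
results = {type(enc).__name__: WorldModel(enc).embed([1, 2, 3])
           for enc in (IdentityEncoder(3), ScaledEncoder(3, 2.0))}
```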

class world_models.models.modular_rssm.EncoderBase(*args, **kwargs)[source]#

Bases: Module, ABC

Abstract base class for observation encoders.

Parameters:
  • args (Any)

  • kwargs (Any)

abstractmethod forward(obs)[source]#

Encode observations to embeddings.

Parameters:

obs (Tensor)

Return type:

Tensor

embed_size: int#
get_embed_size()[source]#

Return the embedding size. Override in subclasses.

Return type:

int

class world_models.models.modular_rssm.DecoderBase(*args, **kwargs)[source]#

Bases: Module, ABC

Abstract base class for observation decoders.

Parameters:
  • args (Any)

  • kwargs (Any)

abstractmethod forward(features)[source]#

Decode latent features to observation distributions.

Parameters:

features (Tensor)

Return type:

Any

class world_models.models.modular_rssm.BackboneBase(*args, **kwargs)[source]#

Bases: Module, ABC

Abstract base class for recurrent dynamics backbones.

Parameters:
  • args (Any)

  • kwargs (Any)

abstractmethod forward(state, action, obs_embed=None, nonterm=1.0)[source]#

Process one step of dynamics. Returns (prior, posterior).

Parameters:
  • state (Dict[str, Tensor])

  • action (Tensor)

  • obs_embed (Tensor | None)

  • nonterm (float)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

abstractmethod init_state(batch_size, device)[source]#

Initialize hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Dict[str, Tensor]

stoch_size: int#
deter_size: int#
class world_models.models.modular_rssm.ConvEncoder(input_shape, embed_size, activation='elu', depth=32)[source]#

Bases: EncoderBase

Convolutional encoder from Dreamer (image observations).

Parameters:
  • input_shape (Tuple[int, int, int])

  • embed_size (int)

  • activation (str)

  • depth (int)

forward(obs)[source]#
Parameters:

obs (Tensor)

Return type:

Tensor

class world_models.models.modular_rssm.MLPEncoder(input_dim, embed_size, hidden_sizes=[256, 256], activation='elu')[source]#

Bases: EncoderBase

MLP encoder for state-based observations.

Parameters:
  • input_dim (int)

  • embed_size (int)

  • hidden_sizes (List[int])

  • activation (str)

forward(obs)[source]#
Parameters:

obs (Tensor)

Return type:

Tensor

class world_models.models.modular_rssm.ViTEncoder(input_shape, embed_size, patch_size=8, depth=6, num_heads=8, mlp_ratio=4.0, activation='gelu')[source]#

Bases: EncoderBase

Vision Transformer encoder for image observations.

Parameters:
  • input_shape (Tuple[int, int, int])

  • embed_size (int)

  • patch_size (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • activation (str)

forward(obs)[source]#
Parameters:

obs (Tensor)

Return type:

Tensor

class world_models.models.modular_rssm.TransformerBlock(embed_size, num_heads, mlp_ratio, activation)[source]#

Bases: Module

Transformer block for ViT encoder.

Parameters:
  • embed_size (int)

  • num_heads (int)

  • mlp_ratio (float)

  • activation (str)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.models.modular_rssm.ConvDecoder(stoch_size, deter_size, output_shape, activation='elu', depth=32)[source]#

Bases: DecoderBase

Convolutional decoder for image observations.

Parameters:
  • stoch_size (int)

  • deter_size (int)

  • output_shape (Tuple[int, int, int])

  • activation (str)

  • depth (int)

forward(features)[source]#
Parameters:

features (Tensor)

Return type:

Any

class world_models.models.modular_rssm.MLPDecoder(stoch_size, deter_size, output_dim, hidden_sizes=[256, 256], activation='elu', dist='normal')[source]#

Bases: DecoderBase

MLP decoder for state-based observations.

Parameters:
  • stoch_size (int)

  • deter_size (int)

  • output_dim (int)

  • hidden_sizes (List[int])

  • activation (str)

  • dist (str)

forward(features)[source]#
Parameters:

features (Tensor)

Return type:

Any

class world_models.models.modular_rssm.GRUBackbone(action_size, stoch_size, deter_size, hidden_size, embed_size, activation='elu')[source]#

Bases: BackboneBase

GRU-based recurrent dynamics backbone (standard RSSM).

Parameters:
  • action_size (int)

  • stoch_size (int)

  • deter_size (int)

  • hidden_size (int)

  • embed_size (int)

  • activation (str)

property embedding_size: int#
init_state(batch_size, device)[source]#
Parameters:
  • batch_size (int)

  • device (device)

Return type:

Dict[str, Tensor]

forward(state, action, obs_embed=None, nonterm=1.0)[source]#
Parameters:
  • state (Dict[str, Tensor])

  • action (Tensor)

  • obs_embed (Tensor | None)

  • nonterm (float)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

class world_models.models.modular_rssm.LSTMBackbone(action_size, stoch_size, deter_size, hidden_size, embed_size, activation='elu')[source]#

Bases: BackboneBase

LSTM-based recurrent dynamics backbone.

Parameters:
  • action_size (int)

  • stoch_size (int)

  • deter_size (int)

  • hidden_size (int)

  • embed_size (int)

  • activation (str)

property embedding_size: int#
init_state(batch_size, device)[source]#
Parameters:
  • batch_size (int)

  • device (device)

Return type:

Dict[str, Tensor]

forward(state, action, obs_embed=None, nonterm=1.0)[source]#
Parameters:
  • state (Dict[str, Tensor])

  • action (Tensor)

  • obs_embed (Tensor | None)

  • nonterm (float)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

class world_models.models.modular_rssm.TransformerBackbone(action_size, stoch_size, deter_size, embed_size, num_heads=4, num_layers=2, activation='gelu')[source]#

Bases: BackboneBase

Transformer-based dynamics backbone for long-range dependencies.

Parameters:
  • action_size (int)

  • stoch_size (int)

  • deter_size (int)

  • embed_size (int)

  • num_heads (int)

  • num_layers (int)

  • activation (str)

property embedding_size: int#
init_state(batch_size, device)[source]#
Parameters:
  • batch_size (int)

  • device (device)

Return type:

Dict[str, Tensor]

forward(state, action, obs_embed=None, nonterm=1.0)[source]#
Parameters:
  • state (Dict[str, Tensor])

  • action (Tensor)

  • obs_embed (Tensor | None)

  • nonterm (float)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

class world_models.models.modular_rssm.ModularRSSM(encoder, decoder, backbone, reward_decoder=None)[source]#

Bases: Module

Modular RSSM with swappable encoder, decoder, and backbone.

This class allows researchers to easily experiment with different components:
  • Encoders: Conv, MLP, ViT

  • Decoders: Conv, MLP

  • Backbones: GRU, LSTM, Transformer

Example

>>> encoder = ConvEncoder((3, 64, 64), embed_size=1024)
>>> decoder = ConvDecoder(32, 200, (3, 64, 64))
>>> backbone = GRUBackbone(action_size=6, stoch_size=32, deter_size=200, hidden_size=200, embed_size=1024)
>>> rssm = ModularRSSM(encoder, decoder, backbone)
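Parameters:
  • encoder – Observation encoder (Conv, MLP, or ViT)

  • decoder – Observation decoder (Conv or MLP)

  • backbone – Dynamics backbone (GRU, LSTM, or Transformer)

  • reward_decoder – Optional reward decoder

property stoch_size: int
property deter_size: int
property embed_size: int
init_state(batch_size, device)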
Parameters:
  • batch_size (int)

  • device (device)

Return type:

Dict[str, Tensor]

get_dist(mean, std)
Parameters:
  • mean (Tensor)

  • std (Tensor)

Return type:

Distribution

observe_step(prev_state, prev_action, obs, nonterm=1.0)
Parameters:
  • prev_state (Dict[str, Tensor])

  • prev_action (Tensor)

  • obs (Tensor)

  • nonterm (Any)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

imagine_step(prev_state, prev_action, nonterm=1.0)
Parameters:
  • prev_state (Dict[str, Tensor])

  • prev_action (Tensor)

  • nonterm (Any)

Return type:

Dict[str, Tensor]
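For intuition, one prior (imagination) step of an RSSM can be sketched in NumPy as below. This is a minimal sketch, not the library's implementation: the state keys (deter, stoch, mean, std), the weight names in params, the plain tanh recurrent cell standing in for a GRU, the 0.1 minimum standard deviation, and the convention that nonterm=0 zeroes the carried state are all assumptions.

```python
import numpy as np


def softplus(x: np.ndarray) -> np.ndarray:
    # Numerically stable log(1 + exp(x))
    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.0)


def imagine_step_sketch(prev_state, prev_action, params, rng, nonterm=1.0, min_std=0.1):
    """Hypothetical RSSM prior step: (h, z), a -> h', z' ~ N(mean(h'), std(h'))."""
    # Assumed convention: nonterm = 0 at an episode boundary resets the carried state
    stoch = prev_state["stoch"] * nonterm
    deter = prev_state["deter"] * nonterm
    # Deterministic path: recurrent update from the previous latent and action
    # (a tanh RNN cell stands in for the GRU used by a real backbone)
    x = np.concatenate([stoch, prev_action], axis=-1)
    deter = np.tanh(x @ params["w_in"] + deter @ params["w_rec"])
    # Stochastic path: the prior is predicted from the deterministic state alone
    mean, std_raw = np.split(deter @ params["w_prior"], 2, axis=-1)
    std = softplus(std_raw) + min_std
    stoch = mean + std * rng.standard_normal(mean.shape)  # reparameterized sample
    return {"deter": deter, "stoch": stoch, "mean": mean, "std": std}
```

Calling such a step repeatedly with actions from a policy yields an imagined rollout. In a standard RSSM, the observe step differs by also conditioning the stochastic state on the encoded observation.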

observe_rollout(obs, actions, nonterms, prev_state, horizon)
Parameters:
  • obs (Tensor)

  • actions (Tensor)

  • nonterms (Tensor)

  • prev_state (Dict[str, Tensor])

  • horizon (int)

Return type:

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

imagine_rollout(actor, prev_state, horizon)
Parameters:
  • actor (Module)

  • prev_state (Dict[str, Tensor])

  • horizon (int)

Return type:

Dict[str, Tensor]

decode_observation(features)
Parameters:

features (Tensor)

Return type:

Tensor

decode_reward(features)
Parameters:

features (Tensor)

Return type:

Tensor

detach_state(state)
Parameters:

state (Dict[str, Tensor])

Return type:

Dict[str, Tensor]

seq_to_batch(state)
Parameters:

state (Dict[str, Tensor])

Return type:

Dict[str, Tensor]

world_models.models.modular_rssm.create_modular_rssm(encoder_type='conv', decoder_type='conv', backbone_type='gru', obs_shape=(3, 64, 64), action_size=6, stoch_size=32, deter_size=200, embed_size=1024, hidden_size=200, activation='elu', **kwargs)

Factory function to create a modular RSSM with specified components.

Parameters:
  • encoder_type (str) – Type of encoder (“conv”, “mlp”, “vit”)

  • decoder_type (str) – Type of decoder (“conv”, “mlp”)

  • backbone_type (str) – Type of backbone (“gru”, “lstm”, “transformer”)

  • obs_shape (Tuple[int, int, int] | Tuple[int]) – Shape of observations (C, H, W) for images or (D,) for state

  • action_size (int) – Action space dimension

  • stoch_size (int) – Stochastic latent dimension

  • deter_size (int) – Deterministic hidden dimension

  • embed_size (int) – Encoder embedding dimension

  • hidden_size (int) – Hidden layer dimension

  • activation (str) – Activation function name

  • kwargs (Any)

Returns:

Configured ModularRSSM instance

Return type:

ModularRSSM

class world_models.models.jepa_agent.JEPAAgent(config=None, **kwargs)

Bases: ExportableAgentMixin

Convenience interface for configuring and launching JEPA training runs.

Accepts a JEPAConfig plus keyword overrides, prepares output folders, and delegates execution to the JEPA training entrypoint.

Parameters:
  • config – JEPAConfig to use for the run

  • kwargs – Keyword overrides applied to the config

classmethod from_config(config=None, **overrides)

Build a JEPA agent from a config object, dict, YAML file, or YAML string.

Parameters:
  • config (JEPAConfig | dict[str, Any] | str | Path | None)

  • overrides (Any)

Return type:

JEPAAgent

classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, **overrides)

Create a JEPA agent from local/HF Hub config and checkpoint metadata.

Parameters:
  • pretrained_model_name_or_path (str | Path)

  • config (JEPAConfig | dict[str, Any] | str | Path | None)

  • checkpoint_filename (str | None)

  • config_filename (str)

  • repo_type (str | None)

  • revision (str | None)

  • overrides (Any)

Return type:

JEPAAgent

parameter_count(trainable_only=False)

Return the number of parameters held by the agent. JEPA models are constructed inside the training entrypoint, so no model parameters are resident on the agent itself.

Parameters:

trainable_only (bool)

Return type:

int

summary()

Return the configured JEPA run metadata.

Return type:

dict[str, Any]

train()

Launch the configured JEPA run by delegating to the JEPA training entrypoint.

Return type:

None

world_models.models.vit.get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)

Generate fixed 2D sine/cosine positional embeddings on a square patch grid.

Returns NumPy embeddings used to initialize non-trainable transformer position encodings, with optional prepended class-token embedding.

Parameters:
  • embed_dim (int)

  • grid_size (int)

  • cls_token (bool)

Return type:

ndarray

world_models.models.vit.get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

Build 2D sine/cosine embeddings from precomputed meshgrid coordinates.

The final embedding concatenates independent encodings for vertical and horizontal coordinates.

Parameters:
  • embed_dim (int)

  • grid (ndarray)

Return type:

ndarray

world_models.models.vit.get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)

Generate 1D sine/cosine positional embeddings for integer positions.

Useful for sequence-style positional encoding and as a building block for 2D embedding construction.

Parameters:
  • embed_dim (int)

  • grid_size (int)

  • cls_token (bool)

Return type:

ndarray

world_models.models.vit.get_1d_sincos_pos_embed_from_grid(embed_dim, pos)

Generate 1D sine/cosine positional embeddings from explicit positions.

Positions are projected onto a log-frequency basis and encoded with sine and cosine components.

Parameters:
  • embed_dim (int)

  • pos (ndarray)

Return type:

ndarray
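The four positional-embedding helpers above follow the common fixed sine/cosine scheme. The NumPy sketch below re-implements that scheme under hypothetical names (sincos_1d, sincos_2d). It assumes that embed_dim is divisible by 4, a frequency base of 10000, half the channels encoding rows and half encoding columns, and an all-zero class-token row. The library's exact channel ordering may differ.

```python
import numpy as np


def sincos_1d(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    """Encode M positions (pos is flattened) as an (M, embed_dim) array."""
    assert embed_dim % 2 == 0, "half the channels are sines, half cosines"
    # Log-spaced frequencies: omega_i = 1 / 10000^(2i / embed_dim)
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2))
    angles = np.outer(pos.reshape(-1).astype(np.float64), omega)  # (M, embed_dim/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)


def sincos_2d(embed_dim: int, grid_size: int, cls_token: bool = False) -> np.ndarray:
    """Embeddings for a grid_size x grid_size patch grid, in row-major patch order."""
    assert embed_dim % 4 == 0
    rows, cols = np.meshgrid(np.arange(grid_size), np.arange(grid_size), indexing="ij")
    # Independent encodings of the vertical and horizontal coordinates, concatenated
    emb = np.concatenate(
        [sincos_1d(embed_dim // 2, rows), sincos_1d(embed_dim // 2, cols)], axis=1
    )  # (grid_size**2, embed_dim)
    if cls_token:
        emb = np.concatenate([np.zeros((1, embed_dim)), emb], axis=0)
    return emb
```

For example, a 224×224 image cut into 16×16 patches has a 14×14 grid, so sincos_2d(768, 14) has shape (196, 768), or (197, 768) with cls_token=True.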

world_models.models.vit.drop_path(x, drop_prob=0.0, training=False)

Apply stochastic depth (DropPath) regularization to residual branches.

Randomly drops entire residual paths per sample during training and scales the surviving activations to preserve expected magnitude.

Parameters:
  • x (Tensor)

  • drop_prob (float)

  • training (bool)

Return type:

Tensor
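The stochastic-depth behaviour described above can be sketched in NumPy as follows. This sketch assumes the usual per-sample formulation: one Bernoulli draw per leading-axis sample, with survivors rescaled by 1/keep_prob. The name drop_path_sketch and the rng argument are hypothetical.

```python
import numpy as np


def drop_path_sketch(x, drop_prob=0.0, training=False, rng=None):
    """Stochastic depth: zero whole samples of a residual branch during training."""
    if drop_prob == 0.0 or not training:
        return x  # identity at inference or when disabled
    assert 0.0 <= drop_prob < 1.0
    keep_prob = 1.0 - drop_prob
    rng = np.random.default_rng() if rng is None else rng
    # One Bernoulli(keep_prob) draw per sample, broadcast over the other axes
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep = rng.random(mask_shape) < keep_prob
    # Rescale survivors so the expected activation is unchanged
    return x * keep / keep_prob
```

In a residual block, the function is applied to the branch output before the addition, for example x + drop_path_sketch(branch(x), p, training).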

class world_models.models.vit.DropPath(drop_prob=None)

Bases: Module

Module wrapper around the functional drop_path stochastic depth utility.

Parameters:

drop_prob (float | None)

forward(x)
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.models.vit.MLP(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>, drop=0.0)

Bases: Module

Two-layer feed-forward network used inside transformer blocks.

Applies linear projection, activation, dropout, and output projection in the standard Vision Transformer MLP pattern.

Parameters:
  • in_features (int)

  • hidden_features (int | None)

  • out_features (int | None)

  • act_layer (type[Module])

  • drop (float)

forward(x)
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.models.vit.Attention(dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)

Bases: Module

Multi-head self-attention block for token sequences.

Computes QKV projections, scaled dot-product attention, and output projection with configurable dropout.

Parameters:
  • dim (int)

  • num_heads (int)

  • qkv_bias (bool)

  • qk_scale (float | None)

  • attn_drop (float)

  • proj_drop (float)

forward(x)
Parameters:

x (Tensor)

Return type:

Tensor
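The multi-head self-attention computation described above is sketched in NumPy below, without dropout or bias terms. Two details are assumptions that follow the common timm-style convention: the weight layout (w_qkv of shape (dim, 3*dim) and w_proj of shape (dim, dim)), and the default scale of head_dim ** -0.5 when qk_scale is None.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def mhsa_sketch(x, w_qkv, w_proj, num_heads, qk_scale=None):
    """x: (B, N, C) tokens -> (B, N, C)."""
    B, N, C = x.shape
    assert C % num_heads == 0
    head_dim = C // num_heads
    scale = qk_scale if qk_scale is not None else head_dim ** -0.5
    # Joint QKV projection, split into 3 x (B, heads, N, head_dim)
    qkv = (x @ w_qkv).reshape(B, N, 3, num_heads, head_dim).transpose(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    attn = softmax((q @ k.transpose(0, 1, 3, 2)) * scale, axis=-1)  # (B, heads, N, N)
    out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, N, C)  # merge the heads
    return out @ w_proj
```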

class world_models.models.vit.Block(dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)

Bases: Module

Transformer encoder block combining attention and MLP residual branches.

Each branch uses pre-normalization and optional stochastic depth.

Parameters:
  • dim (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • qk_scale (float | None)

  • drop (float)

  • attn_drop (float)

  • drop_path (float)

  • act_layer (type[Module])

  • norm_layer (type[Module])

forward(x)
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.models.vit.PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)

Bases: Module

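Image-to-patch embedding: splits an image into non-overlapping patch_size × patch_size patches and projects each patch to an embed_dim-dimensional token.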

Parameters:
  • img_size (int)

  • patch_size (int)

  • in_chans (int)

  • embed_dim (int)

forward(x)
Parameters:

x (Tensor)

Return type:

Tensor
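Patch embedding can be illustrated with a NumPy sketch. It cuts the image into non-overlapping patch_size × patch_size patches and applies one linear projection to each flattened patch. This is equivalent to a convolution whose kernel size and stride both equal patch_size. The function name and the weight/bias layout are hypothetical. With the defaults img_size=224 and patch_size=16, this yields (224 / 16)² = 196 tokens.

```python
import numpy as np


def patch_embed_sketch(images, patch_size, weight, bias):
    """images: (B, C, H, W); weight: (C*p*p, embed_dim); bias: (embed_dim,)."""
    B, C, H, W = images.shape
    assert H % patch_size == 0 and W % patch_size == 0
    gh, gw = H // patch_size, W // patch_size
    # (B, C, gh, p, gw, p) -> (B, gh, gw, C, p, p): group pixels by patch
    patches = images.reshape(B, C, gh, patch_size, gw, patch_size).transpose(0, 2, 4, 1, 3, 5)
    patches = patches.reshape(B, gh * gw, C * patch_size * patch_size)  # row-major patch order
    return patches @ weight + bias  # (B, num_patches, embed_dim)
```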

class world_models.models.vit.ConvEmbed(channels, strides, img_size=224, in_chans=3, batch_norm=True)

Bases: Module

Convolutional stem of stacked 3x3 convolutions that replaces the patch embedding of a ViT, following the ViTC models.

Parameters:
  • channels (list[int])

  • strides (list[int])

  • img_size (int)

  • in_chans (int)

  • batch_norm (bool)

forward(x)
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.models.vit.VisionTransformerPredictor(num_patches, embed_dim=768, predictor_embed_dim=384, depth=6, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, init_std=0.02, **kwargs)

Bases: Module

Vision Transformer predictor network used for JEPA training (constructed by vit_predictor).

Parameters:
  • num_patches (int)

  • embed_dim (int)

  • predictor_embed_dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • qk_scale (float | None)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

  • norm_layer (type[Module])

  • init_std (float)

  • kwargs (Any)

fix_init_weight()
Return type:

None

forward(x, masks_x, masks)
Parameters:
  • x (Tensor)

  • masks_x (Tensor | list[Tensor])

  • masks (Tensor | list[Tensor])

Return type:

Tensor

class world_models.models.vit.VisionTransformer(img_size=[224], patch_size=16, in_chans=3, embed_dim=768, predictor_embed_dim=384, depth=12, predictor_depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, init_std=0.02, **kwargs)

Bases: Module

Vision Transformer

Parameters:
  • img_size (list[int])

  • patch_size (int)

  • in_chans (int)

  • embed_dim (int)

  • predictor_embed_dim (int)

  • depth (int)

  • predictor_depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • qk_scale (float | None)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

  • norm_layer (type[Module])

  • init_std (float)

  • kwargs (Any)

fix_init_weight()
Return type:

None

forward(x, masks=None)
Parameters:
  • x (Tensor)

  • masks (Tensor | list[Tensor] | None)

Return type:

Tensor

interpolate_pos_encoding(x, pos_embed)
Parameters:
  • x (Tensor)

  • pos_embed (Tensor)

Return type:

Tensor

world_models.models.vit.vit_predictor(**kwargs)

Factory for a JEPA predictor transformer with sensible defaults.

Parameters:

kwargs (Any)

Return type:

VisionTransformerPredictor

world_models.models.vit.vit_tiny(patch_size=16, **kwargs)

Factory for a tiny Vision Transformer encoder backbone.

Parameters:
  • patch_size (int)

  • kwargs (Any)

Return type:

Any

world_models.models.vit.vit_small(patch_size=16, **kwargs)

Factory for a small Vision Transformer encoder backbone.

Parameters:
  • patch_size (int)

  • kwargs (Any)

Return type:

Any

world_models.models.vit.vit_base(patch_size=16, **kwargs)

Factory for a base Vision Transformer encoder backbone.

Parameters:
  • patch_size (int)

  • kwargs (Any)

Return type:

Any

world_models.models.vit.vit_large(patch_size=16, **kwargs)

Factory for a large Vision Transformer encoder backbone.

Parameters:
  • patch_size (int)

  • kwargs (Any)

Return type:

Any

world_models.models.vit.vit_huge(patch_size=16, **kwargs)

Factory for a huge Vision Transformer encoder backbone.

Parameters:
  • patch_size (int)

  • kwargs (Any)

Return type:

Any

world_models.models.vit.vit_giant(patch_size=16, **kwargs)

Factory for a giant Vision Transformer encoder backbone.

Parameters:
  • patch_size (int)

  • kwargs (Any)

Return type:

Any

world_models.models.iris_agent.compute_lambda_return(rewards, values, discounts, lambda_coef=0.95)

Compute λ-return target for value function training.

Parameters:
  • rewards (Tensor) – Rewards (B, T)

  • values (Tensor) – Value estimates (B, T+1)

  • discounts (Tensor) – Discount factors (B, T)

  • lambda_coef (float) – Lambda parameter for bootstrapping

Returns:

λ-return targets (B, T)

Return type:

Tensor
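The λ-return is usually computed with a backward recursion, G_t = r_t + d_t((1 − λ)V_{t+1} + λG_{t+1}), bootstrapped from the last value estimate. The NumPy sketch below (with the hypothetical name lambda_return_sketch) implements that recursion for the shapes documented above. Whether compute_lambda_return bootstraps in exactly this way is an assumption.

```python
import numpy as np


def lambda_return_sketch(rewards, values, discounts, lambda_coef=0.95):
    """rewards, discounts: (B, T); values: (B, T+1). Returns (B, T) targets."""
    B, T = rewards.shape
    assert values.shape == (B, T + 1) and discounts.shape == (B, T)
    returns = np.zeros((B, T))
    next_return = values[:, T]  # bootstrap from the final value estimate
    for t in reversed(range(T)):
        # Blend the one-step bootstrap V_{t+1} with the longer return G_{t+1}
        next_return = rewards[:, t] + discounts[:, t] * (
            (1.0 - lambda_coef) * values[:, t + 1] + lambda_coef * next_return
        )
        returns[:, t] = next_return
    return returns
```

Setting lambda_coef=0 reduces this to one-step TD targets r_t + d_t·V_{t+1}. Setting lambda_coef=1 gives discounted returns bootstrapped only from V_T.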

class world_models.models.iris_agent.IRISAgent(config, action_size, device)

Bases: Module

Complete IRIS Agent with world model and policy.

Combines:
  • Discrete autoencoder (encoder + decoder)

  • Transformer world model

  • Actor-critic for policy and value learning

Parameters:
  • config (IRISConfig)

  • action_size (int)

  • device (device)

classmethod from_config(config=None, *, action_size, device=None, **overrides)

Build an IRIS agent from a config object, dict, YAML file, or YAML string.

Parameters:
  • config (IRISConfig | dict[str, Any] | str | Path | None)

  • action_size (int)

  • device (device | str | None)

  • overrides (Any)

Return type:

IRISAgent

classmethod from_pretrained(pretrained_model_name_or_path, *, action_size=None, device=None, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, **overrides)

Load an IRIS agent checkpoint from a local path/directory or HF Hub.

Parameters:
  • pretrained_model_name_or_path (str | Path)

  • action_size (int | None)

  • device (device | str | None)

  • config (IRISConfig | dict[str, Any] | str | Path | None)

  • checkpoint_filename (str | None)

  • config_filename (str)

  • repo_type (str | None)

  • revision (str | None)

  • overrides (Any)

Return type:

IRISAgent

parameter_count(trainable_only=False)
Parameters:

trainable_only (bool)

Return type:

int

summary()
Return type:

dict[str, Any]

forward_actor_critic(frames, hidden=None)

Forward pass through actor-critic.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W)

  • hidden (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state

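Returns:
  • action_logits – Action logits (B, T, action_size)

  • values – Value estimates (B, T)

  • hidden_state – LSTM hidden state (h, c)

Return type:

Tuple[Tensor, Tensor, Tuple[Tensor, Tensor]]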

act(frame, epsilon=0.0, temperature=1.0)

Sample action from policy.

Parameters:
  • frame (Tensor) – Single frame (B, C, H, W)

  • epsilon (float) – Random action probability

  • temperature (float) – Action distribution temperature

Returns:

Selected actions (B,)

Return type:

Tensor
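Action selection with epsilon and temperature is commonly implemented as epsilon-greedy exploration on top of temperature-scaled softmax sampling. The NumPy sketch below assumes that interpretation of epsilon and temperature and takes precomputed policy logits. sample_action_sketch is a hypothetical helper and is not part of IRISAgent.

```python
import numpy as np


def sample_action_sketch(logits, epsilon=0.0, temperature=1.0, rng=None):
    """logits: (B, A) policy logits -> (B,) integer actions."""
    rng = np.random.default_rng() if rng is None else rng
    B, A = logits.shape
    # A temperature below 1 sharpens the distribution; above 1 flattens it
    z = logits / temperature
    z = z - z.max(axis=1, keepdims=True)
    probs = np.exp(z)
    probs /= probs.sum(axis=1, keepdims=True)
    sampled = np.array([rng.choice(A, p=p) for p in probs])
    # With probability epsilon, replace the sample with a uniformly random action
    explore = rng.random(B) < epsilon
    return np.where(explore, rng.integers(0, A, size=B), sampled)
```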

imagine_rollout(initial_frame, horizon=20)

Generate imagined trajectories using world model.

Parameters:
  • initial_frame (Tensor) – Starting frame (B, C, H, W)

  • horizon (int) – Number of steps to imagine

Returns:

Dictionary with imagined rollout data

Return type:

dict

update_autoencoder(frames)

Update discrete autoencoder.

Parameters:

frames (Tensor) – Training frames (B, C, H, W)

Returns:

Dictionary of loss values

Return type:

dict

update_transformer(frames, actions, rewards, terminals)

Update transformer world model.

Parameters:
  • frames (Tensor) – Frame sequence

  • actions (Tensor) – Actions taken

  • rewards (Tensor) – Rewards received

  • terminals (Tensor) – Terminal flags

Returns:

Dictionary of loss values

Return type:

dict

update_actor_critic(imagined_trajectory)

Update actor-critic in imagination.

Parameters:

imagined_trajectory (dict) – Dictionary from imagine_rollout

Returns:

Dictionary of loss values

Return type:

dict

save(path)

Save agent state.

Parameters:

path (str)

Return type:

None

load(path)

Load agent state.

Parameters:

path (str)

Return type:

None

class world_models.models.iris_transformer.IRISTransformer(vocab_size=512, tokens_per_frame=16, action_size=18, embed_dim=256, num_layers=10, num_heads=4, dropout=0.1, gradient_checkpointing=False)

Bases: Module

Autoregressive Transformer for world modeling.

Models the dynamics of the environment by predicting:
  • Next frame tokens (transition model)

  • Rewards

  • Episode termination

The Transformer operates on sequences of interleaved frame tokens and actions.

Parameters:
  • vocab_size (int)

  • tokens_per_frame (int)

  • action_size (int)

  • embed_dim (int)

  • num_layers (int)

  • num_heads (int)

  • dropout (float)

  • gradient_checkpointing (bool)
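The interleaved layout can be pictured as one flat sequence per trajectory. The K tokens of frame t are followed by action t, and this repeats for every timestep, giving T × (K + 1) positions. The sketch below only illustrates this ordering, with frame tokens placed before the action as in the IRIS paper. The tagged-tuple representation is hypothetical; the real model embeds tokens and actions as tensors.

```python
from typing import List, Tuple


def interleave_sketch(frame_tokens: List[List[int]], actions: List[int]) -> List[Tuple[str, int]]:
    """Flatten T frames of K tokens plus T actions into one ordered sequence."""
    assert len(frame_tokens) == len(actions)
    seq: List[Tuple[str, int]] = []
    for tokens_t, action_t in zip(frame_tokens, actions):
        seq.extend(("obs", tok) for tok in tokens_t)  # K frame tokens for step t
        seq.append(("act", action_t))                  # then the action taken at step t
    return seq
```

With the default tokens_per_frame=16, a 20-step segment therefore occupies 20 × 17 = 340 positions.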

forward(tokens, actions, mask=None)

Forward pass through the Transformer world model.

Parameters:
  • tokens (Tensor) – Frame tokens (B, T, K) where T is timesteps

  • actions (Tensor) – Actions (B, T)

  • mask (Tensor | None) – Optional attention mask

Returns:
  • token_logits – Next-token predictions (B, T, K, vocab_size)

  • rewards – Predicted rewards (B, T)

  • terminations – Predicted terminations (B, T, 2)

Return type:

Tuple[Tensor, Tensor, Tensor]

predict_next_tokens(tokens, actions)

Predict the next frame tokens autoregressively.

Used during imagination rollouts.

Parameters:
  • tokens (Tensor) – Current frame tokens (B, K)

  • actions (Tensor) – Actions taken (B,)

Returns:
  • token_logits – Next-frame token predictions (B, K, vocab_size)

  • action_hidden – Hidden states for reward prediction (B, embed_dim)

Return type:

Tuple[Tensor, Tensor]

sample_next_tokens(tokens, actions, temperature=1.0)

Sample next tokens from the distribution.

Parameters:
  • tokens (Tensor) – Current frame tokens (B, K)

  • actions (Tensor) – Actions taken (B,)

  • temperature (float) – Sampling temperature (higher = more random)

Returns:
  • sampled_tokens – Sampled token indices (B, K)

  • log_probs – Log probabilities of the sampled tokens (B, K)

Return type:

Tuple[Tensor, Tensor]

class world_models.models.iris_transformer.IRISWorldModel(encoder, decoder, transformer)

Bases: Module

Complete IRIS World Model combining autoencoder and transformer.

This is the core component that learns environment dynamics entirely in the “imaginary” latent space.

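Parameters:
  • encoder – Encoder of the discrete autoencoder

  • decoder – Decoder of the discrete autoencoder

  • transformer – Autoregressive transformer world model

forward(observations, actions)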

Full world model forward pass.

Parameters:
  • observations (Tensor) – Image sequence (B, T+1, C, H, W)

  • actions (Tensor) – Actions (B, T)

Returns:
  • predictions – Dictionary with predicted tokens, rewards, and terminations

  • losses – Dictionary with loss components

Return type:

Tuple[dict, dict]

imagine(initial_tokens, policy, horizon=20, temperature=1.0)

Generate imagined trajectories.

Parameters:
  • initial_tokens (Tensor) – Initial frame tokens (B, K)

  • policy (Module) – Policy network to sample actions

  • horizon (int) – Number of steps to imagine

  • temperature (float) – Sampling temperature for token prediction

Returns:

Dictionary with imagined trajectories

Return type:

dict

class world_models.models.genie.Genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=512, tokenizer_decoder_dim=1024, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=1024, action_decoder_dim=1024, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, encoder_depth=12, decoder_depth=20, latent_action_depth=20, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)

Bases: Module

Genie: Generative Interactive Environment.

A generative model trained from video-only data that can be used as an interactive environment. It contains three key components:
  1. Video Tokenizer: converts raw video frames into discrete tokens.

  2. Latent Action Model (LAM): infers latent actions between frames.

  3. Dynamics Model: predicts future frames given past frames and latent actions.

Based on the paper “Genie: Generative Interactive Environments” (arXiv:2402.15391).

Training follows the two phases described in the paper:
  1. Train the video tokenizer first (on raw video frames).

  2. Co-train the LAM (from pixels) and the dynamics model (on video tokens).

The LAM is trained as a VQ-VAE with:
  • Encoder: takes x_{1:t} and x_{t+1} and outputs latent actions

  • Decoder: takes x_{1:t-1} (masked) and the actions and reconstructs x_t

  • An auxiliary variance loss to prevent action collapse

At inference, latent actions are detached (stop-gradient) before being passed to the dynamics model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_decoder_dim (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • latent_action_depth (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

classmethod from_config(config=None, **overrides)

Build Genie from a config object, dict, YAML file, or YAML string.

Parameters:
  • config (GenieConfig | dict[str, Any] | str | Path | None)

  • overrides (Any)

Return type:

Genie

classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, map_location=None, **overrides)

Load Genie weights from a local path/directory or HF Hub.

Parameters:
  • pretrained_model_name_or_path (str | Path)

  • config (GenieConfig | dict[str, Any] | str | Path | None)

  • checkpoint_filename (str | None)

  • config_filename (str)

  • repo_type (str | None)

  • revision (str | None)

  • map_location (str | device | None)

  • overrides (Any)

Return type:

Genie

save_pretrained(path)

Save Genie weights and config in a from_pretrained-compatible format.

Parameters:

path (str | Path)

Return type:

None

parameter_count(trainable_only=False)
Parameters:

trainable_only (bool)

Return type:

int

summary()
Return type:

dict[str, Any]

forward(video, mask_prob=0.5, training_phase='all')

Full forward pass through all components.

Parameters:
  • video (Tensor) – (B, C, T, H, W) input video

  • mask_prob (float) – Probability for random masking in dynamics (0.5-1.0)

  • training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”

Returns:

Dictionary containing losses and predictions

Return type:

Dict[str, Tensor]

training_step(video, mask_prob=0.5, training_phase='all')

Single training step computing all losses.

Parameters:
  • video (Tensor) – (B, C, T, H, W) input video

  • mask_prob (float) – Probability for random masking in dynamics

  • training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”

Returns:

Dictionary containing all losses for backpropagation

Return type:

Dict[str, Tensor]

encode_video(video)

Encode video to discrete tokens.

Parameters:

video (Tensor) – (B, C, T, H, W)

Returns:

video_tokens – Discrete token indices (B, T, H*W)

Return type:

Tensor

infer_actions(frames)

Infer latent actions from a sequence of frames.

Parameters:

frames (Tensor) – (B, C, T, H, W) video frames

Returns:

latent_actions – Inferred latent action indices (B, T-1)

Return type:

Tensor

generate(prompt_frame, num_frames=16, actions=None, use_maskgit=True)

Generate video frames given a prompt frame and actions.

Parameters:
  • prompt_frame (Tensor) – (B, C, H, W) initial frame

  • num_frames (int) – Total number of frames to generate

  • actions (Tensor | None) – (B, num_frames-1) latent action indices, or None for random

  • use_maskgit (bool) – Whether to use MaskGIT sampling

Returns:

generated_video – Generated frames (B, C, num_frames, H, W)

Return type:

Tensor

play(current_frame, action, current_frames=None)

Play step: generate the next frame given the current frame and action.

Parameters:
  • current_frame (Tensor) – (B, C, H, W) current frame

  • action (Tensor) – (B,) latent action indices

  • current_frames (Tensor | None) – (B, C, T, H, W) history frames, or None for first frame

Returns:

next_frame – Predicted next frame (B, C, H, W)

Return type:

Tensor

get_num_parameters()

Return total number of parameters.

Return type:

int

world_models.models.genie.create_genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, action_vocab_size=8, action_embedding_dim=32, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)

Factory function to create a Genie model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

world_models.models.genie.create_genie_small(num_frames=16, image_size=64, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)

Create a smaller Genie model for development/testing.

Parameters:
  • num_frames (int)

  • image_size (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

world_models.models.genie.create_genie_large(num_frames=16, image_size=64, use_bfloat16=True, action_pooling='mean', window_attention_heads=1)

Create the full-size Genie model (approximately 11B parameters).

Parameters:
  • num_frames (int)

  • image_size (int)

  • use_bfloat16 (bool)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

Genie

class world_models.models.latent_action_model.LatentActionModel(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)

Bases: Module

Latent Action Model (LAM) for unsupervised action learning.

Learns discrete latent actions from unlabeled video frames using a VQ-VAE based objective. The model infers latent actions between frames that encode the most meaningful changes for future frame prediction.

Based on the Genie paper: learns actions from Internet videos without action labels.

Components:
  • Encoder: takes all previous frames x_{1:t} and the next frame x_{t+1} and outputs latent actions.

  • Decoder: takes previous frames x_{1:t-1} and latent actions a_{1:t-1} and predicts the next frame x_t.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

encode(x_prev, x_next)

Encode frames to latent actions.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W)

  • x_next (Tensor) – Next frame (B, C, H, W)

Returns:
  • latent_actions – Discrete latent action indices (B, T)

  • z_q – Quantized embeddings (B, T, embedding_dim)

Return type:

Tuple[Tensor, Tensor]
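The quantization inside encode is a vector-quantization lookup. Each continuous encoder output is replaced by its nearest codebook vector, and the index of that vector is the discrete latent action. The NumPy sketch below shows only this nearest-neighbour step. It omits the straight-through gradient and commitment loss, which need an autodiff framework. The function name and array layout are hypothetical. With the default vocab_size=8, every transition is summarized by one of 8 latent actions.

```python
import numpy as np


def vector_quantize_sketch(z_e: np.ndarray, codebook: np.ndarray):
    """z_e: (B, T, D) encoder outputs; codebook: (V, D). Returns indices and z_q."""
    # Squared Euclidean distance from every output to every codebook entry: (B, T, V)
    dists = ((z_e[..., None, :] - codebook) ** 2).sum(axis=-1)
    indices = dists.argmin(axis=-1)  # (B, T) discrete latent actions
    z_q = codebook[indices]          # (B, T, D) quantized embeddings
    return indices, z_q
```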

decode(x_prev, z_q)

Decode latent actions to predict next frame.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W); all frames except the first are masked

  • z_q (Tensor) – Quantized action embeddings (B, T-1, embedding_dim)

Returns:

predicted_next_frame – Predicted next frame (B, C, H, W)

Return type:

Tensor

forward(x_prev, x_next)

Full forward pass: encode to get actions, decode to reconstruct.

Parameters:
  • x_prev (Tensor) – Previous frames (B, C, T, H, W)

  • x_next (Tensor) – Next frame (B, C, H, W)

Returns:

Dictionary with losses and outputs

Return type:

Dict[str, Tensor]

world_models.models.latent_action_model.create_latent_action_model(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, action_pooling='mean', window_attention_heads=1)

Factory function to create a Latent Action Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

Return type:

LatentActionModel

class world_models.models.dynamics_model.MaskGITSampler(num_steps=25, temperature=2.0, mask_schedule='cosine')

Bases: object

MaskGIT sampling for token-based video generation.

Uses iterative refinement with a mask schedule to progressively reveal tokens during generation.

Parameters:
  • num_steps (int)

  • temperature (float)

  • mask_schedule (str)

get_mask_prob(step)

Get mask probability for given step.

Parameters:

step (int)

Return type:

float
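With the cosine schedule, the fraction of tokens still masked falls from 1 to 0 along γ(r) = cos(πr/2), where r = step / num_steps is the progress through refinement. The sketch below assumes that standard MaskGIT form. The exact step indexing inside get_mask_prob may differ.

```python
import math


def cosine_mask_prob(step: int, num_steps: int) -> float:
    """Fraction of tokens that remain masked after `step` of `num_steps` iterations."""
    r = step / num_steps  # progress in [0, 1]
    return math.cos(math.pi / 2 * r)
```

With the default num_steps=25, step 0 keeps every token masked, and the masked fraction first drops below one half at step 17.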

sample(logits, mask, step)

Sample tokens from logits with mask.

Parameters:
  • logits (Tensor) – (B, T, vocab_size)

  • mask (Tensor) – (B, T) - 1 for tokens to predict, 0 for fixed tokens

  • step (int) – Current step in refinement

Returns:
  • sampled_tokens – Sampled token indices (B, T)

  • new_mask – Updated mask (B, T)

Return type:

Tuple[Tensor, Tensor]
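A single MaskGIT refinement step typically proceeds in three stages. It samples a token for every masked position, commits those tokens, and then re-masks the least confident predictions so they can be revisited on the next step. The NumPy sketch below illustrates that idea. It takes the number of tokens to re-mask as an argument; in practice this comes from the mask schedule. It omits the temperature-scaled confidence noise used by some MaskGIT variants. Its name and signature are hypothetical rather than those of MaskGITSampler.sample.

```python
import numpy as np


def maskgit_step_sketch(logits, tokens, mask, num_to_mask, rng=None):
    """One refinement step.

    logits: (B, N, V); tokens: (B, N) ints; mask: (B, N) with 1 = still to predict.
    Returns the updated tokens and the mask for the next step.
    """
    rng = np.random.default_rng() if rng is None else rng
    B, N, V = logits.shape
    z = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(z)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Inverse-CDF sampling of one candidate token per position
    u = rng.random((B, N, 1))
    sampled = np.minimum((u > probs.cumsum(axis=-1)).sum(axis=-1), V - 1)
    conf = np.take_along_axis(probs, sampled[..., None], axis=-1)[..., 0]
    is_masked = mask.astype(bool)
    new_tokens = np.where(is_masked, sampled, tokens)  # fixed tokens are kept
    conf = np.where(is_masked, conf, np.inf)           # ...and never re-masked
    # Re-mask the num_to_mask least confident predictions
    new_mask = np.zeros_like(mask)
    lowest = np.argsort(conf, axis=-1)[:, :num_to_mask]
    np.put_along_axis(new_mask, lowest, 1, axis=-1)
    return new_tokens, new_mask
```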

class world_models.models.dynamics_model.DynamicsModel(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, gradient_checkpointing=True)

Bases: Module

Dynamics Model for action-controllable video generation.

A decoder-only transformer that predicts future frame tokens given past frame tokens and latent actions. Uses MaskGIT for training and sampling.

Based on the Genie paper: uses a cross-entropy loss with random masking during training and MaskGIT iterative refinement at inference.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

  • gradient_checkpointing (bool)

forward(video_tokens, actions, mask_prob=0.0)

Forward pass for training.

Parameters:
  • video_tokens (Tensor) – (B, T, H*W) - token indices for frames 1 to T

  • actions (Tensor) – (B, T) - latent action indices for frames 1 to T

  • mask_prob (float) – Per-token Bernoulli probability of masking the input tokens (the Genie recipe uses rates between 0.5 and 1.0)

Returns:

logits – Token logits (B, T, H*W, vocab_size)

Return type:

Tensor
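Training with random masking can be sketched in two steps. First, replace each input token with a mask placeholder with probability mask_prob. Then score the predicted logits with cross-entropy. The NumPy sketch below assumes, as in MaskGIT-style training, that the loss is averaged over masked positions only. It uses a hypothetical MASK_ID of -1 for the placeholder. The Genie recipe draws the masking rate uniformly from [0.5, 1.0]; for example, rate = rng.uniform(0.5, 1.0) followed by mask_tokens_sketch(tokens, rate, rng) mirrors that.

```python
import numpy as np

MASK_ID = -1  # hypothetical placeholder id for masked input positions


def mask_tokens_sketch(video_tokens, mask_prob, rng):
    """Bernoulli(mask_prob) masking of (B, T, N) token ids."""
    mask = rng.random(video_tokens.shape) < mask_prob
    return np.where(mask, MASK_ID, video_tokens), mask


def masked_cross_entropy_sketch(logits, targets, mask):
    """Mean cross-entropy over masked positions. logits: (B, T, N, V)."""
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    nll = -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return float((nll * mask).sum() / max(int(mask.sum()), 1))
```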

sample(prompt_tokens, prompt_actions, num_frames, sampler=None)

Sample future frames using MaskGIT.

Parameters:
  • prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens

  • prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames

  • num_frames (int) – Total number of frames to generate

  • sampler (MaskGITSampler | None) – MaskGIT sampler instance

Returns:

generated_tokens – Generated frame tokens (B, num_frames, N)

Return type:

Tensor

autoregressive_sample(prompt_tokens, prompt_actions, num_frames, temperature=1.0)

Simple autoregressive sampling (token by token).

Parameters:
  • prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens

  • prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames

  • num_frames (int) – Total number of frames to generate

  • temperature (float) – Sampling temperature

Returns:

generated_tokens – (B, num_frames, N)

Return type:

Tensor

world_models.models.dynamics_model.create_dynamics_model(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4)[source]#

Factory function to create a Dynamics Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

Return type:

DynamicsModel

Diffusion and DIAMOND components#

Key classes: DDPM, DiT, DiffusionUNet, EDMPreconditioner, EulerSampler, RewardTerminationModel, and ActorCriticNetwork.

Diffusion submodule: diffusion model components for world models.

Exported Components:
  • DiT: Diffusion Transformer model

  • PatchEmbed: Image patch embedding

  • PatchUnEmbed: Patch unembedding (decode tokens to image)

  • DDPM: Denoising Diffusion Probabilistic Model implementation

  • ActorCriticNetwork: DIAMOND actor-critic network

  • RewardTerminationModel: Reward/termination prediction model

  • sinusoidal_time_embedding: Time embedding for diffusion models

class world_models.models.diffusion.DDPM.DDPM(timesteps, beta_start, beta_end)[source]#

Bases: Module

Utility module implementing forward and reverse DDPM diffusion steps.

Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).

Parameters:
  • timesteps (int)

  • beta_start (float)

  • beta_end (float)

q_sample(x_start, t, noise=None)[source]#
Parameters:
  • x_start (Tensor)

  • t (Tensor)

  • noise (Tensor | None)

Return type:

Tensor

p_sample(model, x_t, t)[source]#
Parameters:
  • model (Module)

  • x_t (Tensor)

  • t (Tensor)

Return type:

Tensor

sample(model, n, img_size, channels)[source]#
Parameters:
  • model (Module)

  • n (int)

  • img_size (int)

  • channels (int)

Return type:

Tensor
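For reference, the closed-form forward noising that `q_sample` is documented to expose can be sketched with a linear beta schedule, which is what the `beta_start`/`beta_end` constructor arguments suggest (an assumption; the defaults below mirror `DiT.train`). The sketch computes `x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps`, where `abar_t` is the cumulative product of `1 - beta`.

```python
import numpy as np


def linear_schedule(timesteps, beta_start, beta_end):
    """Linear beta schedule and its cumulative products (standard DDPM)."""
    betas = np.linspace(beta_start, beta_end, timesteps)
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)
    return betas, alphas, alpha_bars


def q_sample(x_start, t, alpha_bars, noise):
    """Closed-form forward noising: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps."""
    abar = alpha_bars[t].reshape(-1, *([1] * (x_start.ndim - 1)))  # (B,) -> (B, 1, 1, 1)
    return np.sqrt(abar) * x_start + np.sqrt(1.0 - abar) * noise


betas, alphas, alpha_bars = linear_schedule(1000, 1e-4, 0.02)
x0 = np.ones((2, 3, 4, 4))
eps = np.zeros_like(x0)                           # zero noise isolates the signal scale
x_t = q_sample(x0, np.array([0, 999]), alpha_bars, eps)
```

At `t = 0` the signal is almost untouched, while at `t = 999` its scale has shrunk below 0.01, so the input is nearly pure noise once real noise is added.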

world_models.models.diffusion.DiT.sinusoidal_time_embedding(timesteps, dim)[source]#

Create sinusoidal timestep embeddings for diffusion conditioning.

This function generates positional-style embeddings for diffusion timesteps, following the same pattern as transformer positional encodings. The embeddings encode the noise level (t) and are used to condition the diffusion model.

Math:

embedding[t] = [sin(t/10000^(2i/d)), cos(t/10000^(2i/d))] for i in [0, d/2)

Parameters:
  • timesteps (Tensor) – Tensor of integer timesteps, shape (B,) or (B, 1)

  • dim (int) – Embedding dimension (must be even)

Returns:

Tensor of shape (B, dim) with sinusoidal embeddings

Return type:

Tensor

Usage with DiT:

t = torch.tensor([0, 500, 1000])  # Timesteps
emb = sinusoidal_time_embedding(t, dim=256)  # (3, 256)

# Condition the model:
# - Feed the embedding to the timestep MLP
# - Use AdaLN for adaptive normalization
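The formula above translates directly into a few lines of NumPy. This is a sketch, not the library function: it assumes the sine half is concatenated before the cosine half, as the bracket order in the formula suggests, and the name `sinusoidal_time_embedding_sketch` is hypothetical.

```python
import numpy as np


def sinusoidal_time_embedding_sketch(timesteps, dim):
    """Documented formula; sin half first, then cos half (assumed order)."""
    assert dim % 2 == 0, "dim must be even"
    t = np.asarray(timesteps, dtype=np.float64).reshape(-1, 1)   # (B, 1)
    i = np.arange(dim // 2, dtype=np.float64)                    # i in [0, d/2)
    freqs = 1.0 / (10000.0 ** (2.0 * i / dim))                   # 1 / 10000^(2i/d)
    angles = t * freqs                                           # (B, d/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (B, d)


emb = sinusoidal_time_embedding_sketch([0, 500, 1000], dim=256)  # (3, 256)
```

Low indices `i` oscillate quickly and distinguish nearby timesteps; high indices vary slowly and encode coarse noise level.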

class world_models.models.diffusion.DiT.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]#

Bases: Module

Patchify an image into a sequence of learnable patch tokens.

Used in Vision Transformers (ViT) and DiT to convert 2D images into sequences of token embeddings that can be processed by transformers.

Process:
  1. A Conv2d with kernel_size=stride=patch_size extracts non-overlapping patches

  2. Because the kernel and stride match, this Conv2d acts as one linear projection of each flattened patch to embed_dim

  3. Learnable positional embeddings are added to encode spatial position

  • Input: (B, C, H, W) images

  • Output: (B, N, embed_dim), where N = (H/patch_size) * (W/patch_size)

Parameters:
  • img_size (int) – Image size (assumes square), e.g., 32 for CIFAR

  • patch_size (int) – Size of each patch (typically 4, 8, or 16)

  • in_channels (int) – Number of input channels (3 for RGB)

  • embed_dim (int) – Output dimension for each patch token

Usage with DiT:

patch_embed = PatchEmbed(img_size=32, patch_size=4, in_channels=3, embed_dim=256)
tokens = patch_embed(images)  # (B, 64, 256) for a 32x32 image with patch_size=4
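The claim that a Conv2d with kernel_size=stride=patch_size is a per-patch linear map can be checked with a reshape-based patchify. The NumPy sketch below cuts (B, C, H, W) images into flattened patches and multiplies them by one shared matrix; the random `proj` matrix stands in for the learned conv weights (hypothetical), and bias and positional embeddings are omitted.

```python
import numpy as np


def patchify(images, patch_size):
    """Split (B, C, H, W) images into (B, N, C*p*p) flattened non-overlapping patches."""
    b, c, h, w = images.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "H and W must be divisible by patch_size"
    x = images.reshape(b, c, h // p, p, w // p, p)   # split H and W into (grid, patch)
    x = x.transpose(0, 2, 4, 1, 3, 5)                # (B, H/p, W/p, C, p, p)
    return x.reshape(b, (h // p) * (w // p), c * p * p)


images = np.random.default_rng(0).standard_normal((2, 3, 32, 32))
patches = patchify(images, patch_size=4)             # (2, 64, 48)
proj = np.random.default_rng(1).standard_normal((48, 256))
tokens = patches @ proj                              # (2, 64, 256): one shared map per patch
```

Patches are ordered row-major over the 8x8 grid, so token 1 is the patch immediately to the right of token 0.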

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.models.diffusion.DiT.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]#

Bases: Module

Reconstruct image-like tensors from patch-token sequences.

The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.

Parameters:
  • img_size (int)

  • patch_size (int)

  • embed_dim (int)

  • out_channels (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.models.diffusion.DiT.TransformerBlock(d_model, n_heads, mlp_ratio, drop, t_dim)[source]#

Bases: Module

Conditioned transformer block used inside the DiT backbone.

Each block applies adaptive layer-normalized self-attention and MLP residual updates conditioned on timestep embeddings.

Parameters:
  • d_model (int)

  • n_heads (int)

  • mlp_ratio (float)

  • drop (float)

  • t_dim (int)

forward(x, t_emb)[source]#
Parameters:
  • x (Tensor)

  • t_emb (Tensor)

Return type:

Tensor

class world_models.models.diffusion.DiT.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]#

Bases: Module

Diffusion Transformer model for image denoising and generation.

The module maps noisy images and timesteps to predicted noise residuals and also provides a classmethod training entrypoint for common datasets.

Parameters:
  • img_size (int)

  • patch_size (int)

  • in_channels (int)

  • d_model (int)

  • depth (int)

  • heads (int)

  • drop (float)

  • t_dim (int)

forward(x, t)[source]#
Parameters:
  • x (Tensor)

  • t (Tensor)

Return type:

Tensor

classmethod from_config(config=None, **overrides)[source]#

Build DiT from a config object, dict, YAML file, or YAML string.

Parameters:
  • config (DiTConfig | dict[str, Any] | str | Path | None)

  • overrides (Any)

Return type:

DiT

classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, map_location=None, **overrides)[source]#

Load DiT weights from a local file or directory, or from the Hugging Face Hub.

Parameters:
  • pretrained_model_name_or_path (str | Path)

  • config (DiTConfig | dict[str, Any] | str | Path | None)

  • checkpoint_filename (str | None)

  • config_filename (str)

  • repo_type (str | None)

  • revision (str | None)

  • map_location (str | device | None)

  • overrides (Any)

Return type:

DiT

save_pretrained(path)[source]#

Save DiT weights and config in a from_pretrained-compatible format.

Parameters:

path (str | Path)

Return type:

None

parameter_count(trainable_only=False)[source]#
Parameters:

trainable_only (bool)

Return type:

int

summary()[source]#
Return type:

dict[str, Any]

classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]#
Parameters:
  • epochs (int)

  • dataset (Any)

  • batch_size (int)

  • lr (float)

  • img_size (int)

  • channels (int)

  • patch (int)

  • width (int)

  • depth (int)

  • heads (int)

  • drop (float)

  • timesteps (int)

  • beta_start (float)

  • beta_end (float)

  • ema (bool)

  • ema_decay (float)

  • workdir (str)

  • root_path (str)

  • image_folder (str | None)

  • crop_size (int)

  • download (bool)

  • copy_data (bool)

  • subset_file (str | None)

  • val_split (float | None)

Return type:

None

world_models.models.diffusion.DiT.create_dit(config=None, **overrides)[source]#

Create a DiT from a DiTConfig or keyword overrides.

The public factory API works with config objects, while DiT itself has a compact constructor. This adapter keeps the lower-level model constructor unchanged and maps the public config fields onto the expected arguments.

Parameters:
  • config (Any)

  • overrides (Any)

Return type:

DiT

class world_models.models.diffusion.diamond_diffusion.AdaptiveGroupNorm(num_groups, num_channels, cond_dim)[source]#

Bases: Module

Adaptive Group Normalization that conditions on actions and diffusion time.

Parameters:
  • num_groups (int)

  • num_channels (int)

  • cond_dim (int)

forward(x, cond)[source]#
Parameters:
  • x (Tensor) – Input tensor [B, C, H, W]

  • cond (Tensor) – Conditioning tensor [B, cond_dim]

Return type:

Tensor
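Adaptive group normalization is commonly written as `GN(x) * (1 + scale) + shift`, with `scale` and `shift` predicted from the conditioning vector by a linear layer. The NumPy sketch below assumes that form; the exact parameterization inside `AdaptiveGroupNorm` is not documented here, and `w_proj` is a hypothetical stand-in for its learned projection.

```python
import numpy as np


def group_norm(x, num_groups, eps=1e-5):
    """GroupNorm without affine parameters. x: (B, C, H, W)."""
    b, c, h, w = x.shape
    g = x.reshape(b, num_groups, c // num_groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(b, c, h, w)


def adaptive_group_norm(x, cond, w_proj, num_groups):
    """out = GN(x) * (1 + scale) + shift, with scale/shift from cond (assumed form)."""
    scale, shift = np.split(cond @ w_proj, 2, axis=1)          # each (B, C)
    normed = group_norm(x, num_groups)
    return normed * (1 + scale[:, :, None, None]) + shift[:, :, None, None]


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 4, 4))
cond = rng.standard_normal((2, 16))      # e.g. action + noise-level embedding
w_zero = np.zeros((16, 2 * 8))           # zero projection reduces to plain GroupNorm
out = adaptive_group_norm(x, cond, w_zero, num_groups=4)
```

The `1 + scale` form means a zero-initialized projection starts as an ordinary GroupNorm, which tends to keep early training stable.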

class world_models.models.diffusion.diamond_diffusion.ResBlock(in_channels, out_channels, cond_dim, dropout=0.0)[source]#

Bases: Module

Residual block with adaptive group normalization.

Parameters:
  • in_channels (int)

  • out_channels (int)

  • cond_dim (int)

  • dropout (float)

forward(x, cond)[source]#
Parameters:
  • x (Tensor)

  • cond (Tensor)

Return type:

Tensor

class world_models.models.diffusion.diamond_diffusion.AttentionBlock(channels, cond_dim)[source]#

Bases: Module

Self-attention block for U-Net.

Parameters:
  • channels (int)

  • cond_dim (int)

forward(x, cond)[source]#
Parameters:
  • x (Tensor)

  • cond (Tensor)

Return type:

Tensor

class world_models.models.diffusion.diamond_diffusion.TimestepEmbedding(dim, freq_dim=256)[source]#

Bases: Module

Sinusoidal timestep embedding.

Parameters:
  • dim (int)

  • freq_dim (int)

forward(t)[source]#
Parameters:

t (Tensor)

Return type:

Tensor

class world_models.models.diffusion.diamond_diffusion.DownBlock(in_channels, out_channels, cond_dim, num_res_blocks=2, attention=False)[source]#

Bases: Module

Downsampling block for U-Net encoder.

Parameters:
  • in_channels (int)

  • out_channels (int)

  • cond_dim (int)

  • num_res_blocks (int)

  • attention (bool)

forward(x, cond)[source]#
Parameters:
  • x (Tensor)

  • cond (Tensor)

Return type:

Tensor

class world_models.models.diffusion.diamond_diffusion.UpBlock(in_channels, out_channels, cond_dim, num_res_blocks=2, attention=False)[source]#

Bases: Module

Upsampling block for U-Net decoder with skip connections.

Parameters:
  • in_channels (int)

  • out_channels (int)

  • cond_dim (int)

  • num_res_blocks (int)

  • attention (bool)

forward(x, cond, skip=None)[source]#
Parameters:
  • x (Tensor)

  • cond (Tensor)

  • skip (Tensor | None)

Return type:

Tensor

class world_models.models.diffusion.diamond_diffusion.DiffusionUNet(obs_channels=3, num_conditioning_frames=4, base_channels=64, channel_multipliers=(1, 1, 1, 1), num_res_blocks=2, cond_dim=256, action_dim=18)[source]#

Bases: Module

U-Net architecture for EDM diffusion world model. Uses frame stacking for observation conditioning and adaptive group norm for action conditioning.

Parameters:
  • obs_channels (int)

  • num_conditioning_frames (int)

  • base_channels (int)

  • channel_multipliers (Tuple[int, ...])

  • num_res_blocks (int)

  • cond_dim (int)

  • action_dim (int)

forward(x, t, obs_history, actions)[source]#

Forward pass of the diffusion model.

Parameters:
  • x (Tensor) – Noised observation at timestep t [B, C, H, W]

  • t (Tensor) – Diffusion timestep [B]

  • obs_history (Tensor) – Past observations for conditioning [B, L, C, H, W]

  • actions (Tensor) – Past actions [B, L]

Returns:

Predicted clean observation [B, C, H, W]

Return type:

Tensor

class world_models.models.diffusion.diamond_diffusion.EDMPreconditioner(sigma_data=0.5, p_mean=-0.4, p_std=1.2)[source]#

Bases: object

EDM preconditioner following Karras et al. (2022).

Parameters:
  • sigma_data (float)

  • p_mean (float)

  • p_std (float)

get_preconditioners(sigma)[source]#

Compute EDM preconditioners for given noise levels.

Returns:

Dictionary with c_skip, c_out, c_in, c_noise

Parameters:

sigma (Tensor)

Return type:

dict
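The four coefficients returned by `get_preconditioners` are defined in Karras et al. (2022). Below is a scalar sketch of those published formulas, assuming this class follows them unchanged; `edm_preconditioners` is a hypothetical helper that works on plain floats, not tensors.

```python
import math


def edm_preconditioners(sigma, sigma_data=0.5):
    """EDM preconditioning coefficients (Karras et al., 2022)."""
    denom = sigma ** 2 + sigma_data ** 2
    return {
        "c_skip": sigma_data ** 2 / denom,                 # weight on the noisy input
        "c_out": sigma * sigma_data / math.sqrt(denom),    # scale of the network output
        "c_in": 1.0 / math.sqrt(denom),                    # input normalization
        "c_noise": math.log(sigma) / 4.0,                  # noise-level conditioning
    }


coeffs = edm_preconditioners(0.5)
```

At low noise `c_skip` approaches 1, so the denoiser mostly passes the input through; at high noise it relies on the network output instead.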

sample_noise_level(batch_size, device)[source]#

Sample noise level from log-normal distribution.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tensor

denoise(model, x, sigma, **kwargs)[source]#

Apply EDM denoising with preconditioners.

Parameters:
  • model (Module) – Diffusion model

  • x (Tensor) – Noised input [B, C, H, W]

  • sigma (Tensor) – Noise level [B]

  • **kwargs (Any) – Additional conditioning (obs_history, actions)

Returns:

Denoised prediction [B, C, H, W]

Return type:

Tensor
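The EDM denoiser wraps the raw network as `D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise)`, and training draws sigma from a log-normal with the `p_mean`/`p_std` defaults above. The NumPy sketch below assumes both follow Karras et al. (2022). `zero_net` is a hypothetical stand-in for `DiffusionUNet`, and the `obs_history`/`actions` conditioning is omitted.

```python
import numpy as np


def sample_noise_level(batch_size, rng, p_mean=-0.4, p_std=1.2):
    """Log-normal noise levels: ln(sigma) ~ N(p_mean, p_std^2)."""
    return np.exp(p_mean + p_std * rng.standard_normal(batch_size))


def edm_denoise(network, x, sigma, sigma_data=0.5):
    """D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise), per sample."""
    s = sigma.reshape(-1, *([1] * (x.ndim - 1)))      # (B,) -> (B, 1, 1, 1)
    denom = s ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / denom
    c_out = s * sigma_data / np.sqrt(denom)
    c_in = 1.0 / np.sqrt(denom)
    c_noise = np.log(sigma) / 4.0
    return c_skip * x + c_out * network(c_in * x, c_noise)


rng = np.random.default_rng(0)
sigma = sample_noise_level(4, rng)
clean = rng.standard_normal((4, 3, 8, 8))
noisy = clean + sigma[:, None, None, None] * rng.standard_normal(clean.shape)
zero_net = lambda x_in, c_noise: np.zeros_like(x_in)   # placeholder network
denoised = edm_denoise(zero_net, noisy, sigma)
```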

class world_models.models.diffusion.diamond_diffusion.EulerSampler(sigma_min=0.002, sigma_max=80.0, rho=7, num_steps=3, edm_precond=None)[source]#

Bases: object

Euler method sampler for reverse diffusion.

Parameters:
  • sigma_min (float)

  • sigma_max (float)

  • rho (int)

  • num_steps (int)

  • edm_precond (EDMPreconditioner | None)

sample(model, shape, device, obs_history=None, actions=None)[source]#

Generate samples using Euler method.

Parameters:
  • model (Module) – Diffusion model

  • shape (Tuple[int, ...]) – Output shape [B, C, H, W]

  • device (device) – Device to run on

  • obs_history (Tensor | None) – Conditioning observations [B, L, C, H, W]

  • actions (Tensor | None) – Conditioning actions [B, L]

Returns:

Generated samples [B, C, H, W]

Return type:

Tensor
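The `sigma_min`, `sigma_max`, and `rho` parameters match the Karras et al. (2022) noise schedule. The NumPy sketch below builds that schedule with a final sigma of 0 and takes Euler steps along `dx/dsigma = (x - D(x)) / sigma`. That `EulerSampler.sample` does exactly this is an assumption, and the lambda denoiser in the usage line is a hypothetical oracle.

```python
import numpy as np


def karras_sigmas(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7):
    """Karras et al. (2022) noise schedule, with a final sigma of 0 appended."""
    ramp = np.linspace(0.0, 1.0, num_steps)
    inv_max, inv_min = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    sigmas = (inv_max + ramp * (inv_min - inv_max)) ** rho
    return np.append(sigmas, 0.0)


def euler_sample(denoise, shape, rng, num_steps=3, **schedule):
    """Euler integration of the probability-flow ODE dx/dsigma = (x - D(x)) / sigma."""
    sigmas = karras_sigmas(num_steps, **schedule)
    x = rng.standard_normal(shape) * sigmas[0]           # start from pure noise
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoise(x, sigma)) / sigma              # ODE slope at this noise level
        x = x + (sigma_next - sigma) * d                 # Euler step toward sigma_next
    return x


sigmas = karras_sigmas(3)                                # [80, ..., 0.002, 0]
target = np.full((1, 3, 8, 8), 0.25)
out = euler_sample(lambda x, s: target, target.shape, np.random.default_rng(0))
```

With a perfect denoiser, the final step to sigma = 0 lands exactly on the denoiser's prediction. This helps explain why DIAMOND can get away with very few steps (the default here is 3).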

class world_models.models.diffusion.reward_termination.ConvBlock(in_channels, out_channels, cond_dim, stride=2)[source]#

Bases: Module

Convolutional block with adaptive group normalization.

Parameters:
  • in_channels (int)

  • out_channels (int)

  • cond_dim (int)

  • stride (int)

forward(x, cond)[source]#
Parameters:
  • x (Tensor)

  • cond (Tensor)

Return type:

Tensor

class world_models.models.diffusion.reward_termination.RewardTerminationModel(obs_channels=3, action_dim=18, channels=(32, 32, 32, 32), lstm_dim=512, cond_dim=128)[source]#

Bases: Module

Reward and termination prediction model. CNN + LSTM architecture following DIAMOND paper specifications.

Parameters:
  • obs_channels (int) – Number of observation channels (3 for RGB)

  • action_dim (int) – Number of possible actions

  • channels (Tuple[int, ...]) – List of channel sizes for conv blocks

  • lstm_dim (int) – LSTM hidden dimension

  • cond_dim (int) – Conditioning dimension for adaptive norm

forward(obs, actions, hidden_state=None)[source]#

Forward pass of reward/termination model.

Parameters:
  • obs (Tensor) – Observations [B, T, C, H, W]

  • actions (Tensor) – Actions [B, T]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

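Returns:

Tuple of (reward_logits, termination_logits, hidden_state) where
  • reward_logits: Reward predictions [B, T, 3] (for -1, 0, 1)

  • termination_logits: Termination predictions [B, T, 2]

  • hidden_state: Updated (h, c) hidden states

Return type:

Tuple[Tensor, Tensor, Tuple[Tensor, Tensor]]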

predict(obs, actions, hidden_state=None)[source]#

Predict reward and termination for a single step.

Parameters:
  • obs (Tensor) – Single observation [B, C, H, W]

  • actions (Tensor) – Single action [B]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

Returns:

Tuple of (reward, terminated, hidden_state) where
  • reward: Predicted reward classes as a tensor (values -1, 0, 1)

  • terminated: Predicted termination (bool tensor)

  • hidden_state: Updated (h, c) hidden states

Return type:

Tuple[Tensor, Tensor, Tuple[Tensor, Tensor]]

init_hidden(batch_size, device)[source]#

Initialize LSTM hidden states.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

class world_models.models.diffusion.reward_termination.RewardTerminationLoss[source]#

Bases: Module

Loss function for reward and termination prediction.

forward(reward_logits, termination_logits, rewards, terminated)[source]#

Compute loss for reward and termination predictions.

Parameters:
  • reward_logits (Tensor) – [B, T, 3]

  • termination_logits (Tensor) – [B, T, 2]

  • rewards (Tensor) – Rewards as class indices [B, T] (values -1, 0, 1 mapped to 0, 1, 2)

  • terminated (Tensor) – Termination flags [B, T]

Returns:

total_loss, reward_loss, termination_loss

Return type:

Tuple[Tensor, Tensor, Tensor]
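The class-index mapping described for `rewards` (-1, 0, 1 to 0, 1, 2) and the two classification losses can be sketched in NumPy. Reducing raw rewards to their sign before mapping, and weighting the two cross-entropies equally, are assumptions of this sketch; all helper names are hypothetical.

```python
import numpy as np


def log_softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def reward_to_class(rewards):
    """Map reward signs {-1, 0, 1} to class indices {0, 1, 2}."""
    return (np.sign(rewards) + 1).astype(np.int64)


def reward_termination_loss(reward_logits, termination_logits, rewards, terminated):
    """Mean cross-entropy for each head; total is their unweighted sum (assumed)."""
    r_cls = reward_to_class(rewards)
    t_cls = terminated.astype(np.int64)
    r_nll = -np.take_along_axis(log_softmax(reward_logits), r_cls[..., None], axis=-1)
    t_nll = -np.take_along_axis(log_softmax(termination_logits), t_cls[..., None], axis=-1)
    reward_loss, termination_loss = r_nll.mean(), t_nll.mean()
    return reward_loss + termination_loss, reward_loss, termination_loss


rewards = np.array([[-1.0, 0.0, 1.0]])                  # (B, T)
terminated = np.array([[False, False, True]])
total, r_loss, t_loss = reward_termination_loss(
    np.zeros((1, 3, 3)), np.zeros((1, 3, 2)), rewards, terminated
)
```

With uniform logits the losses equal log 3 and log 2, the chance-level baselines for each head.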

class world_models.models.diffusion.actor_critic.ActorCriticNetwork(obs_channels=3, action_dim=18, channels=(32, 32, 64, 64), lstm_dim=512)[source]#

Bases: Module

Actor-Critic network for DIAMOND RL training. Shared CNN-LSTM trunk with separate policy and value heads.

Parameters:
  • obs_channels (int)

  • action_dim (int)

  • channels (Tuple[int, ...])

  • lstm_dim (int)

forward(obs, hidden_state=None)[source]#

Forward pass of actor-critic network.

Parameters:
  • obs (Tensor) – Observations [B, T, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

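Returns:

Tuple of (policy_logits, values, hidden_state) where
  • policy_logits: [B, T, action_dim]

  • values: [B, T, 1]

  • hidden_state: (h, c)

Return type:

Tuple[Tensor, Tensor, Tuple[Tensor, Tensor]]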

get_action(obs, hidden_state=None, deterministic=False)[source]#

Get action from a single observation.

Parameters:
  • obs (Tensor) – Single observation [B, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states

  • deterministic (bool) – If True, take argmax; else sample

Returns:

Tuple of (action, hidden_state) where
  • action: Selected action [B]

  • hidden_state: (h, c)

Return type:

Tuple[Tensor, Tuple[Tensor, Tensor]]

get_actions(obs, hidden_state=None, deterministic=False)[source]#

Batched version of get_action.

Parameters:
  • obs (Tensor) – Tensor of shape [B, C, H, W]

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state tuple matching batch size

  • deterministic (bool) – If True, take argmax; else sample from policy

Returns:

Tuple of (actions, hidden_state) where
  • actions: LongTensor of shape [B]

  • hidden_state: Updated LSTM hidden state tuple

Return type:

Tuple[Tensor, Tuple[Tensor, Tensor]]

get_value(obs, hidden_state=None)[source]#

Get value for a single observation.

Parameters:
  • obs (Tensor)

  • hidden_state (Tuple[Tensor, Tensor] | None)

Return type:

Tuple[Tensor, Tuple[Tensor, Tensor] | None]

init_hidden(batch_size, device)[source]#

Initialize LSTM hidden states.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

get_hidden_size()[source]#

Get LSTM hidden size.

Return type:

int

class world_models.models.diffusion.actor_critic.RLLoss(discount_factor=0.985, lambda_returns=0.95, entropy_weight=0.001)[source]#

Bases: Module

RL loss functions for DIAMOND. Implements REINFORCE with value baseline and λ-returns.

Parameters:
  • discount_factor (float)

  • lambda_returns (float)

  • entropy_weight (float)

compute_lambda_returns(rewards, values, dones)[source]#

Compute λ-returns.

Parameters:
  • rewards (Tensor) – [B, T]

  • values (Tensor) – [B, T+1]

  • dones (Tensor) – [B, T]

Returns:

lambda_returns – [B, T]

Return type:

Tensor
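Given the documented shapes (values carry one extra bootstrap step), λ-returns are usually computed with a backward recursion: `G_t = r_t + gamma * (1 - d_t) * ((1 - lam) * V_{t+1} + lam * G_{t+1})`, with `G_T = V_T`. The NumPy sketch below implements that standard recursion, using `RLLoss`'s default `discount_factor` and `lambda_returns` values. That `compute_lambda_returns` matches it term for term is an assumption.

```python
import numpy as np


def lambda_returns(rewards, values, dones, gamma=0.985, lam=0.95):
    """Backward λ-return recursion bootstrapped from values[:, T].

    rewards, dones: (B, T); values: (B, T + 1).
    """
    T = rewards.shape[1]
    returns = np.zeros_like(rewards, dtype=np.float64)
    next_return = values[:, T]                            # G_T = V_T
    for t in reversed(range(T)):
        not_done = 1.0 - dones[:, t]                      # cut bootstrapping at episode end
        next_return = rewards[:, t] + gamma * not_done * (
            (1.0 - lam) * values[:, t + 1] + lam * next_return
        )
        returns[:, t] = next_return
    return returns


rewards = np.array([[1.0, 1.0]])
values = np.array([[0.0, 0.0, 10.0]])
dones = np.zeros((1, 2))
mc_like = lambda_returns(rewards, values, dones, gamma=0.5, lam=1.0)   # [[4.0, 6.0]]
one_step = lambda_returns(rewards, values, dones, gamma=0.5, lam=0.0)  # [[1.0, 6.0]]
```

`lam = 1` gives discounted Monte Carlo returns plus a final bootstrap, and `lam = 0` gives one-step TD targets. Intermediate values trade variance against value-function bias.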

policy_loss(policy_logits, actions, lambda_returns, values)[source]#

Compute policy loss with REINFORCE and entropy regularization.

Parameters:
  • policy_logits (Tensor) – [B, T, A]

  • actions (Tensor) – [B, T]

  • lambda_returns (Tensor) – [B, T]

  • values (Tensor) – [B, T+1]

Returns:

policy_loss – scalar

Return type:

Tensor

value_loss(values, lambda_returns)[source]#

Compute value loss (MSE between value and lambda returns).

Parameters:
  • values (Tensor)

  • lambda_returns (Tensor)

Return type:

Tensor
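As a forward-only illustration of the two losses above, here is REINFORCE with a value baseline, an entropy bonus, and an MSE value loss in NumPy. Three choices here are assumptions rather than documented behavior: using `values[:, :-1]` as the baseline, plain means over `[B, T]`, and the absence of any 0.5 factor on the MSE. In autograd code the advantage would be detached.

```python
import numpy as np


def reinforce_losses(policy_logits, actions, returns, values, entropy_weight=0.001):
    """Policy and value losses (forward pass only).

    policy_logits (B, T, A), actions (B, T), returns (B, T), values (B, T + 1).
    """
    shifted = policy_logits - policy_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    chosen = np.take_along_axis(log_probs, actions[..., None], axis=-1)[..., 0]
    advantage = returns - values[:, :-1]                 # baseline-subtracted return
    entropy = -(np.exp(log_probs) * log_probs).sum(axis=-1)
    policy_loss = -(chosen * advantage).mean() - entropy_weight * entropy.mean()
    value_loss = ((values[:, :-1] - returns) ** 2).mean()
    return policy_loss, value_loss


logits = np.zeros((1, 2, 2))                  # uniform policy over 2 actions
actions = np.array([[0, 1]])
values = np.array([[1.0, 2.0, 0.0]])
returns = np.array([[1.0, 2.0]])              # returns equal the baseline: zero advantage
p_loss, v_loss = reinforce_losses(logits, actions, returns, values)
```

When returns match the baseline, only the entropy term remains, so the policy loss reduces to `-entropy_weight * log(2)` for a uniform two-action policy.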

Vision, tokenization, and layers#

Key classes: ConvEncoder, ConvDecoder, DenseDecoder, ActionDecoder, CNNEncoder, CNNDecoder, IRISEncoder, IRISDecoder, DiscreteAutoencoder, VectorQuantizer, VectorQuantizerEMA, VideoTokenizer, MultiHeadSelfAttention, and STTransformer.

Convolutional Variational Autoencoder (ConvVAE) implementation.

This module provides the ConvVAE model architecture for encoding and decoding images in the World Models framework. The VAE uses a convolutional encoder and decoder with a variational latent space.

class world_models.vision.VAE.ConvVAE.ConvVAEEncoder(img_channels, latent_size)[source]#

Bases: Module

Convolutional encoder for VAE.

This encoder takes images and produces the parameters (mean and log variance) of a Gaussian distribution in the latent space.

Parameters:
  • img_channels (int)

  • latent_size (int)

latent_size#

Dimensionality of the latent space.

img_channels#

Number of input image channels.

Example

>>> encoder = ConvVAEEncoder(img_channels=3, latent_size=32)
>>> mu, logsigma = encoder(images)
forward(x)[source]#

Encode images to latent distribution parameters.

Parameters:

x (Tensor) – Input tensor of shape (batch, channels, height, width).

Returns:

Tuple of (mu, logsigma) where
  • mu: Mean of the latent distribution

  • logsigma: Log variance of the latent distribution

Return type:

Tuple[Tensor, Tensor]

class world_models.vision.VAE.ConvVAE.ConvVAEDecoder(latent_size, img_channels)[source]#

Bases: Module

Convolutional decoder for VAE.

This decoder takes latent vectors and reconstructs images.

Parameters:
  • latent_size (int)

  • img_channels (int)

latent_size#

Dimensionality of the input latent space.

img_channels#

Number of output image channels.

forward(z)[source]#

Decode latent vectors to images.

Parameters:

z (Tensor) – Latent vector of shape (batch, latent_size).

Returns:

Reconstructed image tensor of shape (batch, channels, height, width).

Return type:

Tensor

class world_models.vision.VAE.ConvVAE.ConvVAE(img_channels, latent_size)[source]#

Bases: Module

Convolutional Variational Autoencoder.

The ConvVAE is a generative model that encodes images into a latent distribution and reconstructs them. It uses the reparameterization trick to enable backpropagation through the sampling process.

Parameters:
  • img_channels (int)

  • latent_size (int)

encoder#

ConvVAEEncoder that encodes images to latent parameters.

decoder#

ConvVAEDecoder that decodes latent vectors to images.

Example

>>> vae = ConvVAE(img_channels=3, latent_size=32)
>>> recon_x, mu, logsigma = vae(images)
>>> # Training loss combines reconstruction and KL divergence
forward(x)[source]#

Encode and decode an image.

Parameters:

x (Tensor) – Input image tensor of shape (batch, channels, height, width).

Returns:

Tuple of (recon_x, mu, logsigma) where
  • recon_x: Reconstructed image

  • mu: Mean of latent distribution

  • logsigma: Log variance of latent distribution

Return type:

Tuple[Tensor, Tensor, Tensor]
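The reparameterization trick and the "reconstruction plus KL divergence" loss mentioned in the example can be sketched in NumPy. The docstrings call `logsigma` a log variance while its name suggests log standard deviation; this sketch assumes log standard deviation (if it is a log variance, use `np.exp(0.5 * logsigma)` as the std instead). The summed squared-error reconstruction term is also an assumption.

```python
import numpy as np


def reparameterize(mu, logsigma, eps):
    """z = mu + sigma * eps, treating logsigma as log(std) (assumption)."""
    return mu + np.exp(logsigma) * eps


def vae_loss(x, recon_x, mu, logsigma):
    """Summed squared-error reconstruction plus analytic KL(q(z|x) || N(0, I))."""
    recon = ((recon_x - x) ** 2).sum()
    # KL for a diagonal Gaussian with variance exp(2 * logsigma)
    kl = -0.5 * np.sum(1 + 2 * logsigma - mu ** 2 - np.exp(2 * logsigma))
    return recon + kl


rng = np.random.default_rng(0)
mu, logsigma = np.zeros(32), np.zeros(32)
z = reparameterize(mu, logsigma, rng.standard_normal(32))   # latent sample, shape (32,)
```

Because the randomness enters only through `eps`, gradients flow to `mu` and `logsigma` through an ordinary deterministic expression.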

class world_models.vision.dreamer_encoder.ConvEncoder(input_shape, embed_size, activation, depth=32)[source]#

Bases: Module

Convolutional observation encoder used by Dreamer world models.

This encoder transforms raw image observations (typically RGB frames from environments like Atari or DeepMind Control) into compact latent embeddings that can be processed by the RSSM (Recurrent State-Space Model).

  • Input: (B, C, H, W) raw images, values in [-0.5, 0.5]

  • Process: 4 convolutional layers with stride 2, halving spatial dimensions

  • Output: (B, embed_size) compact representation

The encoder uses a depth doubling pattern: 32 -> 64 -> 128 -> 256 channels. After convolutions, a fully connected layer projects from 1024 features to the desired embedding size.

Usage with Dreamer:

encoder = ConvEncoder(
    input_shape=(3, 64, 64),  # RGB 64x64 images
    embed_size=256,           # RSSM observation embedding size
    activation='relu'         # Activation function
)
obs_embedding = encoder(observation)  # (B, 256)
Parameters:
  • input_shape (tuple) – Tuple (C, H, W) for input images, typically (3, 64, 64)

  • embed_size (int) – Output embedding dimension, typically 256 or 1024

  • activation (str) – Activation function name (‘relu’, ‘elu’, ‘tanh’, etc.)

  • depth (int) – Base channel depth for first layer (default 32)
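The "1024 features" figure follows from the conv arithmetic. This sketch assumes kernel_size=4, stride 2, and no padding (the original Dreamer layout, not confirmed by this page), and traces a 64x64 input through the four layers with depth 32.

```python
def conv_out(size, kernel=4, stride=2, padding=0):
    """Spatial output size of one Conv2d layer (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_shapes(input_hw=64, depth=32, kernel=4):
    """Trace (channels, H, W) through four stride-2 convs with doubling depth."""
    shapes, size = [], input_hw
    for layer in range(4):
        size = conv_out(size, kernel=kernel)
        shapes.append((depth * 2 ** layer, size, size))
    return shapes


shapes = encoder_shapes()
# [(32, 31, 31), (64, 14, 14), (128, 6, 6), (256, 2, 2)]
flat = shapes[-1][0] * shapes[-1][1] * shapes[-1][2]   # 256 * 2 * 2 = 1024
```

Without padding, each layer slightly more than halves the spatial size, which is why the final map is 2x2 rather than 4x4.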

forward(inputs)[source]#
Parameters:

inputs (Tensor)

Return type:

Tensor

class world_models.vision.dreamer_decoder.TanhBijector[source]#

Bases: Transform

Bijective tanh transform for squashing Gaussian distributions to [-1, 1].

This transformation is essential for Dreamer’s action policy. Raw neural network outputs are Gaussian distributions over R^n, but actions in continuous control environments are typically bounded in [-1, 1]. The tanh bijector provides:

  1. Bijective mapping: tanh is invertible (with atanh as inverse)

  2. Stable log-det Jacobian: Computable for gradient-based training

  3. Bounded actions: samples always lie in (-1, 1), so no explicit clipping is needed

Math:
  • Forward: y = tanh(x)

  • Inverse: x = atanh(y) = 0.5 * log((1+y)/(1-y))

  • Log-det: log|dy/dx| = 2*(log(2) - x - softplus(-2x))

Usage with Dreamer ActionDecoder:

dist = TransformedDistribution(
    Normal(mean, std),
    TanhBijector()
)
action = dist.sample()  # Bounded to [-1, 1]
Reference:

Building a Scalable Deep RL Library by Learning from Mistakes, Haarnoja et al.
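The log-det formula is algebraically equal to `log(1 - tanh(x)^2)`, but it stays finite where the direct form fails. The pure-Python check below compares the two; `softplus` is written out by hand for this sketch.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def tanh_log_det(x):
    """Stable form: log|d tanh(x)/dx| = 2 * (log 2 - x - softplus(-2x))."""
    return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))


def naive_log_det(x):
    """Direct form log(1 - tanh(x)^2); breaks once tanh(x) rounds to +/-1."""
    return math.log(1.0 - math.tanh(x) ** 2)


stable = tanh_log_det(20.0)   # finite, about -38.61
# naive_log_det(20.0) raises ValueError: tanh(20.0) == 1.0 in float64, so log(0)
```

This matters for Dreamer's policy because pre-tanh samples with large magnitude are common early in training, and a single infinite log-probability would poison the actor gradient.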

property sign: int#
atanh(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

log_abs_det_jacobian(x, y)[source]#
Parameters:
  • x (Tensor)

  • y (Tensor)

Return type:

Tensor

class world_models.vision.dreamer_decoder.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)[source]#

Bases: Module

Convolutional decoder for reconstructing observations from latent states.

Part of Dreamer’s world model, this decoder reconstructs image observations from the combined stochastic (s) and deterministic (h) RSSM states.

  • Input: Concatenated [stoch_state, deter_state], shape (B, stoch+deter)

  • Process: Dense projection + 4 transposed convolutions (upsampling 2x each)

  • Output: Independent Normal distribution over observation pixels

The decoder mirrors the ConvEncoder’s structure but in reverse (transposed convs instead of regular convs). This creates a symmetric autoencoder where the encoder and decoder can be trained jointly to learn compressed representations.

The forward pass returns torch.distributions.Independent(Normal(mean, std), len(shape)), so log_prob(observation) can be computed directly for the reconstruction loss.

Usage in Dreamer world model:

decoder = ConvDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(3, 64, 64),  # RGB images
    activation='relu'
)
obs_dist = decoder(latent_features)  # Returns distribution
log_prob = obs_dist.log_prob(target_observation)

The reconstruction loss is -log_prob(observation), which encourages the RSSM to learn states that capture observation information.
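To see what `Independent(Normal(mean, std), 3).log_prob` computes, here is a NumPy sketch that sums per-pixel Gaussian log densities over (C, H, W). It assumes a fixed unit std, as in the original Dreamer; the std this class actually uses is not shown on this page.

```python
import numpy as np


def independent_normal_log_prob(target, mean, std=1.0):
    """log p(target) under Independent(Normal(mean, std), 3): sum over (C, H, W)."""
    z = (target - mean) / std
    per_pixel = -0.5 * z ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)
    return per_pixel.reshape(per_pixel.shape[0], -1).sum(axis=1)   # (B,)


obs = np.zeros((2, 3, 64, 64))
mean = np.zeros_like(obs)
recon_loss = -independent_normal_log_prob(obs, mean).mean()   # constant term only
```

With std fixed at 1, `-log_prob` is half the summed squared error plus a constant, so the objective behaves like an MSE reconstruction loss.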

Parameters:
  • stoch_size (int)

  • deter_size (int)

  • output_shape (tuple[int, ...])

  • activation (str)

  • depth (int)

forward(features)[source]#
Parameters:

features (Tensor)

Return type:

Independent

class world_models.vision.dreamer_decoder.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist, num_buckets=255, symlog_range=10.0)[source]#

Bases: Module

MLP decoder for reward/value/discount prediction from latent features.

Part of Dreamer’s world model, this decoder predicts scalar quantities (rewards, values, discount factors) from RSSM latent states.

  • Input: [stoch_state, deter_state] concatenated, shape (B, stoch+deter)

  • Process: MLP with configurable layers and hidden units

  • Output: Predicted quantity with distribution (normal, binary, or raw)

Supports three output types:
  • 'normal': Gaussian distribution for regression (rewards, values)

  • 'binary': Bernoulli distribution for binary classification (discount)

  • 'none': Raw tensor for non-probabilistic outputs

Usage:

reward_decoder = DenseDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(1,),
    n_layers=2,
    units=400,
    activation='elu',
    dist='normal'
)
reward_dist = reward_decoder(latent_features)
reward_loss = -reward_dist.log_prob(target_reward)

For discount prediction (binary):

discount_decoder = DenseDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(1,),
    n_layers=2,
    units=400,
    activation='elu',
    dist='binary'
)
Parameters:
  • stoch_size (int)

  • deter_size (int)

  • output_shape (tuple[int, ...])

  • n_layers (int)

  • units (int)

  • activation (str)

  • dist (str)

  • num_buckets (int)

  • symlog_range (float)

forward(features)[source]#
Parameters:

features (Tensor)

Return type:

Any

class world_models.vision.dreamer_decoder.SampleDist(dist, samples=100)[source]#

Bases: object

Distribution wrapper that estimates statistics via Monte Carlo sampling.

Provides approximated mean, mode, and entropy helpers for transformed distributions where analytic forms may be inconvenient.

Parameters:
  • dist (Any)

  • samples (int)

property name: str#
mean()[source]#
Return type:

Tensor

mode()[source]#
Return type:

Tensor

entropy()[source]#
Return type:

Tensor

sample()[source]#
Return type:

Tensor
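The Monte Carlo estimates can be sketched as follows: the mean is the sample average, the mode is the draw with the highest log density, and the entropy is `-E[log p(x)]`. This mirrors the usual Dreamer `SampleDist` behavior but is an assumption here. `SampleDistSketch` takes plain sampling and log-density callables instead of a torch distribution.

```python
import numpy as np


class SampleDistSketch:
    """Monte Carlo mean, mode, and entropy from samples (assumed behavior)."""

    def __init__(self, sample_fn, log_prob_fn, samples=100):
        self.sample_fn = sample_fn        # draws an array of shape (samples, ...)
        self.log_prob_fn = log_prob_fn    # log density evaluated per sample
        self.samples = samples

    def mean(self):
        return self.sample_fn(self.samples).mean(axis=0)

    def mode(self):
        draws = self.sample_fn(self.samples)
        return draws[np.argmax(self.log_prob_fn(draws))]   # highest-density draw

    def entropy(self):
        draws = self.sample_fn(self.samples)
        return -self.log_prob_fn(draws).mean()             # H = -E[log p(x)]


rng = np.random.default_rng(0)
dist = SampleDistSketch(
    sample_fn=lambda n: rng.standard_normal(n),
    log_prob_fn=lambda x: -0.5 * x ** 2 - 0.5 * np.log(2 * np.pi),
    samples=10_000,
)
```

For a standard normal, these estimates approach 0, 0, and `0.5 * log(2 * pi * e)` as the sample count grows. For tanh-squashed Gaussians, where no closed form is convenient, sampling is the practical option.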

class world_models.vision.dreamer_decoder.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)[source]#

Bases: Module

Dreamer actor head producing squashed continuous actions from latent features.

Outputs a transformed Gaussian policy with optional deterministic mode and utility for additive exploration noise.

Parameters:
  • action_size (int)

  • stoch_size (int)

  • deter_size (int)

  • n_layers (int)

  • units (int)

  • activation (str)

  • min_std (float)

  • init_std (float)

  • mean_scale (float)

forward(features, deter=False)[source]#
Parameters:
  • features (Tensor)

  • deter (bool)

Return type:

Tensor

add_exploration(action, action_noise=0.3)[source]#
Parameters:
  • action (Tensor)

  • action_noise (float)

Return type:

Tensor
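The `mean_scale`, `init_std`, and `min_std` arguments match the Dreamer actor's parameterization: the pre-tanh mean is soft-bounded by `mean_scale * tanh(mean / mean_scale)`, and the std is `softplus(raw + softplus^-1(init_std)) + min_std`, so an untrained head starts near `init_std`. The NumPy sketch below assumes that scheme. It takes `tanh(mean)` as the deterministic action, although the class may use a sample-based mode instead, and it models `add_exploration` as clipped additive Gaussian noise.

```python
import numpy as np


def softplus(x):
    return np.logaddexp(0.0, x)                              # stable log(1 + exp(x))


def action_params(raw_mean, raw_std, mean_scale=5.0, init_std=5.0, min_std=1e-4):
    """Dreamer-style mean/std from the actor's raw outputs (assumed parameterization)."""
    raw_init_std = np.log(np.exp(init_std) - 1.0)           # inverse softplus of init_std
    mean = mean_scale * np.tanh(raw_mean / mean_scale)       # soft-bounded pre-tanh mean
    std = softplus(raw_std + raw_init_std) + min_std         # positive, starts near init_std
    return mean, std


def sample_action(mean, std, rng, deterministic=False):
    """tanh(mean) if deterministic, else tanh of a Gaussian sample."""
    pre_tanh = mean if deterministic else mean + std * rng.standard_normal(mean.shape)
    return np.tanh(pre_tanh)


def add_exploration(action, rng, action_noise=0.3):
    """Additive Gaussian exploration noise, clipped back to [-1, 1]."""
    return np.clip(action + action_noise * rng.standard_normal(action.shape), -1.0, 1.0)


rng = np.random.default_rng(0)
mean, std = action_params(np.zeros(6), np.zeros(6))
action = sample_action(mean, std, rng)
noisy = add_exploration(action, rng)
```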

class world_models.vision.planet_encoder.CNNEncoder(embedding_size, activation_function='relu')[source]#

Bases: Module

A Convolutional Neural Network (CNN) encoder for processing image inputs.

Parameters:
  • embedding_size (int)

  • activation_function (str)

forward(observation)[source]#
Parameters:

observation (Tensor)

Return type:

Tensor

class world_models.vision.planet_decoder.CNNDecoder(state_size, latent_size, embedding_size, activation_function='relu')[source]#

Bases: Module

A Convolutional Neural Network (CNN) decoder for reconstructing image outputs.

Parameters:
  • state_size (int)

  • latent_size (int)

  • embedding_size (int)

  • activation_function (str)

forward(latent, state)[source]#
Parameters:
  • latent (Tensor)

  • state (Tensor)

Return type:

Tensor

class world_models.vision.iris_encoder.IRISEncoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, in_channels=3, base_channels=64, num_residual_blocks=2, frame_shape=(3, 64, 64))[source]#

Bases: Module

CNN Encoder for IRIS discrete autoencoder.

Encodes image observations into latent features, which are then quantized into discrete tokens using the VectorQuantizer.

Architecture:
  • 4 convolutional layers with residual blocks

  • Self-attention at 8x8 and 16x16 resolutions

  • Vector quantization to produce discrete tokens

Parameters:
  • vocab_size (int)

  • tokens_per_frame (int)

  • embedding_dim (int)

  • in_channels (int)

  • base_channels (int)

  • num_residual_blocks (int)

  • frame_shape (Tuple[int, int, int])

forward(x)[source]#

Encode images to discrete tokens.

Parameters:

x (Tensor) – Input images (B, C, H, W) - should be 64x64

Returns:

Tuple of (z_q, indices, vq_loss) where
  • z_q: Quantized tokens (B, C, H', W')

  • indices: Token indices (B, H', W')

  • vq_loss: Dictionary with VQ loss components

Return type:

Tuple[Tensor, Tensor, dict]

encode_to_indices(x)[source]#

Encode directly to token indices (for world model).

Parameters:

x (Tensor)

Return type:

Tensor

decode_from_indices(indices)[source]#

Decode token indices to embeddings (for decoder).

Parameters:

indices (Tensor)

Return type:

Tensor
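The quantization step that turns encoder features into discrete tokens is a nearest-neighbor lookup in the codebook. The NumPy sketch below shows only that lookup. The straight-through gradient and the codebook/commitment losses reported in `vq_loss` are omitted, and `quantize` is a hypothetical helper rather than the `VectorQuantizer` API. With `tokens_per_frame=16`, each frame would contribute 16 such feature vectors (a 4x4 grid).

```python
import numpy as np


def quantize(z, codebook):
    """Nearest-codebook lookup. z: (N, D) features, codebook: (K, D)."""
    # Squared distances ||z_n - e_k||^2 for every (feature, code) pair: (N, K)
    dists = ((z[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    indices = dists.argmin(axis=1)            # discrete token ids
    return codebook[indices], indices         # quantized vectors and their ids


codebook = np.array([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])
z = np.array([[0.1, -0.1], [0.9, 1.2], [-0.8, 1.9]])
z_q, ids = quantize(z, codebook)              # ids -> [0, 1, 2]
```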

class world_models.vision.iris_encoder.ResidualBlock(channels)[source]#

Bases: Module

Residual block for encoder.

Parameters:

channels (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.vision.iris_encoder.SelfAttentionBlock(channels)[source]#

Bases: Module

Self-attention block for encoder.

Applies spatial self-attention to capture long-range dependencies.

Parameters:

channels (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.vision.iris_decoder.IRISDecoder(vocab_size=512, embedding_dim=512, base_channels=32, out_channels=3, frame_shape=(3, 64, 64))[source]#

Bases: Module

CNN Decoder for IRIS discrete autoencoder.

Decodes discrete tokens back into image observations. Uses transposed convolutions to upsample from 4x4 to 64x64.

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • base_channels (int)

  • out_channels (int)

  • frame_shape (Tuple[int, int, int])

forward(z)[source]#

Decode tokens to images.

Parameters:

z (Tensor) – Token embeddings (B, C, H, W), e.g. (B, 512, 4, 4)

Returns:

Reconstructed images (B, C, H, W), e.g. (B, 3, 64, 64)

Return type:

Tensor

decode_from_embeddings(z_flat)[source]#

Decode flattened token embeddings to images.

Parameters:

z_flat (Tensor) – Flattened tokens (B, H*W, C) or (B, C, H, W)

Returns:

Reconstructed images

Return type:

Tensor

decode_from_indices(indices)[source]#

Decode discrete token indices into images.

Parameters:

indices (Tensor) – Tensor of shape (B, H, W) or (B, H*W) containing integer token indices in the range [0, vocab_size).

Returns:

Reconstructed images (B, C, H, W)

Return type:

Tensor

class world_models.vision.iris_decoder.UpsampleBlock(in_channels, mid_channels, out_channels)[source]#

Bases: Module

Upsampling block with optional residual connection.

Parameters:
  • in_channels (int)

  • mid_channels (int)

  • out_channels (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.vision.iris_decoder.ResidualBlock(channels)[source]#

Bases: Module

Residual block for decoder.

Parameters:

channels (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.vision.iris_decoder.DiscreteAutoencoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, base_channels=64, frame_shape=(3, 64, 64))[source]#

Bases: Module

Complete Discrete Autoencoder combining encoder and decoder.

Used to train the VQ-VAE component of IRIS.

Parameters:
  • vocab_size (int)

  • tokens_per_frame (int)

  • embedding_dim (int)

  • base_channels (int)

  • frame_shape (Tuple[int, int, int])

forward(x)[source]#

Full encode-decode forward pass.

Parameters:

x (Tensor) – Input images (B, C, H, W)

Returns:

A tuple (reconstruction, indices, loss_dict):
  • reconstruction – Reconstructed images

  • indices – Token indices (B, H’, W’)

  • loss_dict – Dictionary with loss components

Return type:

tuple

encode(x)[source]#

Encode to token indices.

Parameters:

x (Tensor)

Return type:

Tensor

decode(indices)[source]#

Decode token indices to images.

Parameters:

indices (Tensor)

Return type:

Tensor

class world_models.vision.vq_layer.VectorQuantizer(vocab_size=512, embedding_dim=512, commitment_weight=0.25)[source]#

Bases: Module

Vector Quantizer for discrete autoencoder.

Implements the VQ-VAE quantization from “Neural Discrete Representation Learning” (van den Oord et al., 2017).

Uses a straight-through estimator for gradient flow. For exponential-moving-average codebook updates, see VectorQuantizerEMA.

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)
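The nearest-neighbour lookup at the heart of this quantizer can be sketched without any framework. This is a minimal, forward-only illustration of the standard VQ-VAE rule, which picks the codebook entry with the smallest squared Euclidean distance. Plain lists stand in for tensors. quantize and nearest_code_indices are hypothetical names, not this module's API.

```python
def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest_code_indices(latents, codebook):
    """Index of the closest codebook entry for every latent vector."""
    return [
        min(range(len(codebook)), key=lambda k: squared_distance(z, codebook[k]))
        for z in latents
    ]


def quantize(latents, codebook):
    """Replace each latent with its nearest codebook vector (forward pass only)."""
    indices = nearest_code_indices(latents, codebook)
    quantized = [list(codebook[k]) for k in indices]
    return quantized, indices


codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]  # vocab_size=3, embedding_dim=2
latents = [[0.9, 1.2], [0.1, -0.1], [-0.8, 1.7]]
quantized, indices = quantize(latents, codebook)
print(indices)  # [1, 0, 2]
```

On the gradient side, PyTorch code commonly writes the straight-through estimator as z_q = z + (z_q - z).detach(). The forward value is then the codebook vector, while the encoder receives the gradient of z_q unchanged.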

forward(z)[source]#

Quantize the input latents.

Parameters:

z (Tensor) – Input tensor of shape (B, C, H, W) or (B, C)

Returns:

A tuple (z_q, indices, loss):
  • z_q – Quantized tensor (same shape as input)

  • indices – Token indices for each position (B, H, W) or (B,)

  • loss – Dictionary containing VQ loss components

Return type:

tuple

decode_indices(indices)[source]#

Decode token indices back to embeddings.

Parameters:

indices (Tensor) – Token indices (B, H, W) or (B,)

Returns:

Embeddings (B, C, H, W) or (B, C)

Return type:

Tensor

class world_models.vision.vq_layer.VectorQuantizerEMA(vocab_size=512, embedding_dim=512, commitment_weight=0.25, ema_decay=0.99, epsilon=1e-05)[source]#

Bases: Module

Vector Quantizer with Exponential Moving Average updates.

Uses EMA updates for the codebook instead of gradient-based updates, which leads to more stable training.

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • ema_decay (float)

  • epsilon (float)
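The sketch below shows one EMA codebook step. It assumes VectorQuantizerEMA follows the usual formulation from the VQ-VAE paper's appendix. In that formulation, running per-code counts and vector sums are decayed by ema_decay, and epsilon is used for Laplace smoothing of the counts. The function ema_codebook_update is hypothetical and operates on plain lists.

```python
def ema_codebook_update(codebook, cluster_size, embed_sum, latents, indices,
                        decay=0.99, epsilon=1e-5):
    """One EMA codebook update in the standard VQ-VAE formulation.

    cluster_size[k]: running (EMA) count N_k of vectors assigned to code k
    embed_sum[k]:    running (EMA) sum m_k of vectors assigned to code k
    Returns new (codebook, cluster_size, embed_sum); the inputs are not mutated.
    """
    num_codes, dim = len(codebook), len(codebook[0])
    counts = [0] * num_codes
    sums = [[0.0] * dim for _ in range(num_codes)]
    for z, k in zip(latents, indices):  # accumulate this batch's assignments
        counts[k] += 1
        for d in range(dim):
            sums[k][d] += z[d]

    new_size = [decay * n + (1 - decay) * c for n, c in zip(cluster_size, counts)]
    new_sum = [[decay * m + (1 - decay) * s for m, s in zip(m_k, s_k)]
               for m_k, s_k in zip(embed_sum, sums)]

    # Laplace smoothing keeps codes that received no vectors away from zero counts.
    total = sum(new_size)
    smoothed = [(n + epsilon) / (total + num_codes * epsilon) * total for n in new_size]
    new_codebook = [[m / n for m in m_k] for m_k, n in zip(new_sum, smoothed)]
    return new_codebook, new_size, new_sum


# Code 1 receives the vector [2, 2]; with decay=0.5 it moves from 1.0 to about 1.5.
codebook, sizes, sums = ema_codebook_update(
    codebook=[[0.0, 0.0], [1.0, 1.0]],
    cluster_size=[1.0, 1.0],
    embed_sum=[[0.0, 0.0], [1.0, 1.0]],
    latents=[[2.0, 2.0]],
    indices=[1],
    decay=0.5,
)
```

No gradient touches the codebook here. The entries are simply running averages of the encoder outputs assigned to them, which is the source of the stability the description refers to.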

forward(z)[source]#

Quantize with EMA updates.

Parameters:

z (Tensor)

Return type:

Tuple[Tensor, Tensor, dict[str, Tensor]]

decode_indices(indices)[source]#

Decode token indices to embeddings.

Parameters:

indices (Tensor) – Token indices (B, H, W) or (B,)

Returns:

Embeddings (B, C, H, W) or (B, C)

Return type:

Tensor

class world_models.vision.video_tokenizer.VideoTokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, commitment_weight=0.25, use_ema=False, ema_decay=0.99)[source]#

Bases: Module

Video Tokenizer using VQ-VAE with Spatiotemporal Transformer.

This is a core component of Genie (Google DeepMind, 2024), used to compress raw video frames into discrete latent tokens that can be processed by downstream models like the LatentActionModel and DynamicsModel.

The tokenizer uses a Vector Quantized Variational Autoencoder (VQ-VAE) objective to learn a discrete codebook of video representations. Unlike a standard VQ-VAE, it uses a spatiotemporal (ST) Transformer in both the encoder and the decoder to better capture temporal dynamics in videos.

Architecture

  1. Patch Embedding: Convert (B, C, T, H, W) video to patch tokens

  2. Encoder ST-Transformer: Process spatial-temporal patches

  3. Vector Quantization: Discretize continuous embeddings to codebook entries

  4. Decoder ST-Transformer: Reconstruct video from quantized tokens

  5. Patch Unembedding: Convert tokens back to video frames
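To give a feel for the sequence lengths involved, the sketch below counts tokens for the default settings and compares bits before and after tokenization. It makes three assumptions: patches are non-overlapping square spatial patches with no temporal patching, pixels are 8-bit RGB, and each token costs log2(vocab_size) bits. tokenizer_budget is a hypothetical helper.

```python
import math


def tokenizer_budget(num_frames=16, image_size=64, in_channels=3,
                     patch_size=4, vocab_size=1024):
    """Tokens per frame, tokens per clip, and raw-bits / token-bits ratio."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_frame = (image_size // patch_size) ** 2
    num_tokens = num_frames * patches_per_frame
    raw_bits = num_frames * in_channels * image_size * image_size * 8  # 8-bit pixels
    token_bits = num_tokens * math.log2(vocab_size)  # one codebook index per token
    return patches_per_frame, num_tokens, raw_bits / token_bits


per_frame, total, ratio = tokenizer_budget()
print(per_frame, total, round(ratio, 1))  # 256 4096 38.4
```

With num_frames=16, image_size=64 and patch_size=4, each frame becomes a 16x16 grid of 256 tokens, giving 4096 tokens per clip. Under these assumptions, that is roughly a 38x reduction in raw bits.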

Key Features

  • Causal processing: Each frame’s encoding only uses previous frames

  • Discrete tokens: Enables autoregressive prediction with latent actions

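  • Memory efficient: Factorizes attention into spatial (within-frame) and temporal (across-frame) steps, so each layer computes on the order of T·N² + N·T² attention scores instead of (T·N)² for a full ViT over all T·N patches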

Usage with Genie:

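import torch
from world_models.vision.video_tokenizer import VideoTokenizer

tokenizer = VideoTokenizer(
    num_frames=16,
    image_size=64,
    patch_size=4,
    vocab_size=1024,
    embedding_dim=32,
)
video_frames = torch.randn(1, 3, 16, 64, 64)  # dummy clip: (B, C, T, H, W)
reconstructed, indices, loss_dict = tokenizer(video_frames)

# Look up codebook embeddings for the discrete tokens (e.g., as dynamics-model input):
token_embeddings = tokenizer.decode_indices(indices)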

The tokenizer is trained with the VQ-VAE objective:
  • Reconstruction loss: MSE between the input and reconstructed video

  • Codebook (VQ) loss: Moves codebook embeddings toward the encoder outputs

  • Commitment loss: Penalizes encoder outputs for drifting away from their assigned codebook entries
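The three terms can be combined as shown below. This is a forward-value sketch of the standard VQ-VAE objective, with commitment_weight playing the role of the usual β. The helper names are hypothetical. In EMA variants of VQ-VAE, the EMA update normally replaces the codebook term; whether VideoTokenizer drops that term when use_ema=True is not stated here.

```python
def mse(a, b):
    """Mean squared error between two equal-length flat lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def vqvae_loss_terms(video, reconstruction, z_e, z_q, commitment_weight=0.25):
    """Forward values of the standard VQ-VAE objective on flattened data.

    z_e: encoder outputs; z_q: their assigned codebook vectors.
    With autograd, the codebook term is ||sg(z_e) - z_q||^2 and the commitment
    term is ||z_e - sg(z_q)||^2 (sg = stop-gradient). Their forward values are
    identical; they differ only in which parameters receive gradients.
    """
    reconstruction_loss = mse(video, reconstruction)
    codebook_loss = mse(z_e, z_q)
    commitment_loss = mse(z_e, z_q)
    total = reconstruction_loss + codebook_loss + commitment_weight * commitment_loss
    return {
        "reconstruction": reconstruction_loss,
        "codebook": codebook_loss,
        "commitment": commitment_loss,
        "total": total,
    }


terms = vqvae_loss_terms(
    video=[0.0, 0.5, 1.0],
    reconstruction=[0.1, 0.5, 0.9],
    z_e=[1.0, 2.0],
    z_q=[1.5, 2.0],
)
```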

Reference:

Genie: Generative Interactive Environments. Bruce et al., Google DeepMind, 2024. https://arxiv.org/abs/2402.15391

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • use_ema (bool)

  • ema_decay (float)

encode(x)[source]#

Encode video to discrete tokens.

Parameters:

x (Tensor) – Video tensor (B, C, T, H, W)

Returns:

A tuple (z_q, indices, vq_loss):
  • z_q – Quantized embeddings (B, T, H’, W’, embedding_dim)

  • indices – Token indices (B, T, H’, W’)

  • vq_loss – Dictionary with VQ loss components

Return type:

tuple

decode_indices(indices)[source]#

Decode token indices to embeddings for video frames.

Parameters:

indices (Tensor) – Token indices (B, T, H’, W’) or (B, T, N) where N = H’ x W’

Returns:

Quantized embeddings (B, T, H’, W’, embedding_dim)

Return type:

Tensor

decode(z_q)[source]#

Decode quantized embeddings to video frames.

Parameters:

z_q (Tensor) – Quantized embeddings (B, T, H’, W’, embedding_dim)

Returns:

Reconstructed video (B, C, T, H, W)

Return type:

Tensor

forward(x)[source]#

Full forward pass with VQ-VAE objective.

Parameters:

x (Tensor) – Video tensor (B, C, T, H, W)

Returns:

A tuple (reconstructed, indices, loss_dict):
  • reconstructed – Reconstructed video (B, C, T, H, W)

  • indices – Token indices (B, T, H’, W’)

  • loss_dict – Dictionary containing loss components

Return type:

tuple

world_models.vision.video_tokenizer.create_video_tokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False)[source]#

Factory function to create a Video Tokenizer.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • use_ema (bool)

Return type:

VideoTokenizer

Blocks submodule: transformer blocks and attention mechanisms.

Exported Components:
Transformers:
  • STTransformer: Spatiotemporal Transformer for video processing

  • STSpatialAttention: Spatial attention layer

  • STTemporalAttention: Temporal attention layer

  • STTransformerBlock: Combined spatiotemporal transformer block

  • STBlock: Backwards-compatible alias for STTransformerBlock

Attention:
  • MultiHeadSelfAttention: Multi-head self-attention

  • MultiHeadAttention: Backwards-compatible alias for MultiHeadSelfAttention

  • Attention: Attention mechanism

Normalization:
  • RMSNorm: Root Mean Square Layer Normalization

  • AdaLNNormalization: Adaptive Layer Normalization

class world_models.blocks.mhsa.MultiHeadSelfAttention(d, n_heads=2)[source]#

Bases: Module

Multi-head scaled dot-product self-attention over sequence tokens.

This module projects the input sequence into query/key/value heads, performs attention independently per head, and merges the heads back into the original feature dimension. It is used as a lightweight transformer attention block.
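The head split described above can be illustrated in pure Python. To stay short, this sketch replaces the learned query/key/value and output projections with identity maps; the real module uses learned linear layers. It therefore only shows how d channels are divided into n_heads slices, attended independently with scaled dot-product attention, and concatenated back. multi_head_self_attention here is a hypothetical stand-in.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def multi_head_self_attention(x, n_heads):
    """Multi-head scaled dot-product self-attention with identity projections.

    x: list of L token vectors of width d, with d divisible by n_heads.
    Each head attends over its own slice of d // n_heads channels, and the
    head outputs are concatenated back to width d.
    """
    d = len(x[0])
    if d % n_heads:
        raise ValueError("d must be divisible by n_heads")
    head_dim = d // n_heads
    out = [[0.0] * d for _ in x]
    for h in range(n_heads):
        lo, hi = h * head_dim, (h + 1) * head_dim
        head = [tok[lo:hi] for tok in x]  # queries = keys = values (identity projections)
        for i, query in enumerate(head):
            scores = [
                sum(q * k for q, k in zip(query, key)) / math.sqrt(head_dim)
                for key in head
            ]
            weights = softmax(scores)
            for c in range(head_dim):
                out[i][lo + c] = sum(w * value[c] for w, value in zip(weights, head))
    return out


tokens = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]]
out = multi_head_self_attention(tokens, n_heads=2)
```

Each output channel is a convex combination of the same channel across all tokens. Splitting into heads lets different channel groups attend to different tokens at the same cost as one full-width head.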

Parameters:
  • d (int)

  • n_heads (int)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.blocks.st_transformer.STSpatialAttention(dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)[source]#

Bases: Module

Spatial attention layer for spatiotemporal transformer.

Processes video tokens by attending over spatial positions (H*W) within each time step independently. Captures within-frame spatial relationships.

  • Input: (B, T, N, C) – B batches, T time steps, N spatial positions (H*W), C channels

  • Output: (B, T, N, C) – Same shape, spatially attended features

Architecture

  • QKV projection: Linear(dim, dim*3)

  • Reshape to multi-head attention format

  • Fused scaled dot-product attention (FlashAttention on supported GPUs)

  • Output projection

Applied to video tokens of shape (B, T, N, C) to capture within-frame spatial structure (e.g., object positions).
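When tokens arrive as a flat (B, T*N, C) sequence, as in STTransformer.forward, spatial attention only mixes tokens from the same frame. The sketch assumes a frame-major layout, in which the token for frame t and patch n sits at flat index t * N + n. The docstrings do not spell this layout out, and both helpers are hypothetical.

```python
def spatial_attention_groups(num_frames, patches_per_frame):
    """Flat token indices that spatial attention mixes: one group per frame.

    Assumes a frame-major layout where frame t, patch p sits at t * N + p.
    """
    n = patches_per_frame
    return [[t * n + p for p in range(n)] for t in range(num_frames)]


def split_frames(sequence, num_frames):
    """Split one clip's flat token list (T*N items) into T lists of N tokens,
    mirroring a (T*N, C) -> (T, N, C) reshape before per-frame attention."""
    n = len(sequence) // num_frames
    return [sequence[t * n:(t + 1) * n] for t in range(num_frames)]


print(spatial_attention_groups(num_frames=2, patches_per_frame=3))
```

Folding the T axis into the batch, (B, T, N, C) to (B*T, N, C), lets a single attention call process every frame independently.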

Parameters:
  • dim (int)

  • num_heads (int)

  • qkv_bias (bool)

  • qk_scale (float | None)

  • attn_drop (float)

  • proj_drop (float)

forward(x)[source]#
Parameters:

x (Tensor) – (B, T, N, C) where T is temporal dim, N is spatial dim (H*W)

Returns:

(B, T, N, C)

Return type:

Tensor

class world_models.blocks.st_transformer.STTemporalAttention(dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)[source]#

Bases: Module

Temporal attention layer with causal masking for spatiotemporal transformer.

Processes video tokens by attending over time steps (T) at each spatial position. Uses causal masking so that each frame attends only to itself and earlier frames (important for autoregressive video generation).

  • Input: (B, T, N, C) – B batches, T time steps, N spatial positions, C channels

  • Output: (B, T, N, C) – Same shape, temporally attended features

Causal masking

  • Frame t can only attend to frames 0…t (itself and earlier frames)

  • Prevents information leakage from future frames

  • Essential for autoregressive video generation models

Applied after STSpatialAttention to model temporal dynamics in the Genie VideoTokenizer.
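Temporal attention uses the complementary grouping: each spatial position forms its own group of T tokens, and a causal mask limits which frames each query may see. This sketch assumes the usual convention in which a frame may attend to itself as well as to earlier frames. It also assumes a frame-major flat layout, with flat index t * N + n. Both helpers are hypothetical.

```python
def causal_mask(num_frames):
    """mask[i][j] is True when frame i may attend to frame j, i.e. when j <= i."""
    return [[j <= i for j in range(num_frames)] for i in range(num_frames)]


def temporal_attention_groups(num_frames, patches_per_frame):
    """Flat token indices mixed by temporal attention: one group per spatial
    position, ordered by time (frame-major layout, flat index t * N + n)."""
    n = patches_per_frame
    return [[t * n + p for t in range(num_frames)] for p in range(n)]


for row in causal_mask(4):
    print("".join("x" if allowed else "." for allowed in row))
```

In PyTorch, torch.tril(torch.ones(T, T, dtype=torch.bool)) builds the same lower-triangular mask. torch.nn.functional.scaled_dot_product_attention also accepts is_causal=True as a shorthand. Keeping the diagonal matters, because without it frame 0 would have nothing to attend to.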

Parameters:
  • dim (int)

  • num_heads (int)

  • qkv_bias (bool)

  • qk_scale (float | None)

  • attn_drop (float)

  • proj_drop (float)

causal_mask: Tensor#
forward(x, causal=True)[source]#
Parameters:
  • x (Tensor) – (B, T, N, C) where T is temporal dim, N is spatial dim (H*W)

  • causal (bool) – whether to apply causal masking

Returns:

(B, T, N, C)

Return type:

Tensor

class world_models.blocks.st_transformer.STMLP(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>, drop=0.0)[source]#

Bases: Module

MLP for ST-Transformer block.

Parameters:
  • in_features (int)

  • hidden_features (int | None)

  • out_features (int | None)

  • act_layer (type[Module])

  • drop (float)

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.blocks.st_transformer.STTransformerBlock(dim, num_heads=8, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]#

Bases: Module

Combined spatiotemporal transformer block with interleaved attention.

A single block applies:

  1. Spatial attention (within each time frame)

  2. Temporal attention (across frames with causal mask)

  3. MLP projection

The order is: x -> + SpatialAttn -> + TemporalAttn -> + MLP -> x

This interleaved design captures both spatial structure and temporal dynamics efficiently. It is used in Genie’s VideoTokenizer and DynamicsModel.

Parameters:
  • dim (int) – Feature dimension (must match patch embedding dimension)

  • num_heads (int) – Number of attention heads

  • mlp_ratio (float) – MLP hidden dim = dim * mlp_ratio

  • drop (float) – Output dropout rate

  • attn_drop (float) – Dropout rate applied to the attention weights

  • drop_path (float) – Stochastic depth rate for drop path regularization

  • norm_layer (type[Module]) – Normalization layer class (default: nn.LayerNorm)

  • qkv_bias (bool)

  • qk_scale (float | None)

  • act_layer (type[Module])

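Usage in Genie (STTransformer stacks these blocks):

import torch
from world_models.blocks.st_transformer import STTransformer

B, T, N = 1, 16, 256  # batch, frames, patches per frame (64x64 frames, 4x4 patches)

# VideoTokenizer encoder (12 layers)
encoder = STTransformer(
    num_frames=T,
    num_patches_per_frame=N,
    dim=512,
    depth=12,
    num_heads=16,
)
tokens = torch.randn(B, T * N, 512)
encoded = encoder(tokens)  # (B, T*N, 512)

# Dynamics model decoder (24 layers); its input width must match dim=1024
decoder = STTransformer(
    num_frames=T,
    num_patches_per_frame=N,
    dim=1024,
    depth=24,
    num_heads=16,
)
decoded = decoder(torch.randn(B, T * N, 1024))  # (B, T*N, 1024)

forward(x)[source]#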
Parameters:

x (Tensor) – (B, T, N, C) or (B, T*H*W, C)

Returns:

Same shape as input

Return type:

Tensor

class world_models.blocks.st_transformer.DropPath(drop_prob=0.0)[source]#

Bases: Module

Drop paths (Stochastic Depth) per sample.

Parameters:

drop_prob (float)
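Stochastic depth randomly skips a residual branch for whole samples. The stdlib sketch below follows the common formulation. During training, it zeroes the branch per sample with probability drop_prob and rescales survivors by 1 / (1 - drop_prob); at evaluation time it does nothing. drop_path is a hypothetical function on nested lists, not this module's implementation.

```python
import random


def drop_path(batch, drop_prob=0.0, training=True, rng=random):
    """Stochastic depth applied per sample to a residual-branch output.

    batch: list of per-sample outputs (each a list of floats).
    Each sample's branch is zeroed with probability drop_prob; surviving
    samples are scaled by 1 / (1 - drop_prob) so the expectation is unchanged.
    """
    if drop_prob == 0.0 or not training:
        return [list(sample) for sample in batch]
    keep_prob = 1.0 - drop_prob
    out = []
    for sample in batch:
        keep = rng.random() < keep_prob  # one coin flip per sample, not per element
        out.append([v / keep_prob if keep else 0.0 for v in sample])
    return out


branch = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
dropped = drop_path(branch, drop_prob=0.5, rng=random.Random(0))
```

Inside a block, drop_path wraps the branch output, as in x + drop_path(branch(x)). A dropped sample therefore passes x through unchanged.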

forward(x)[source]#
Parameters:

x (Tensor)

Return type:

Tensor

class world_models.blocks.st_transformer.STTransformer(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, gradient_checkpointing=False)[source]#

Bases: Module

Spatiotemporal Transformer for video modeling.

Stacks depth spatiotemporal blocks, each with interleaved spatial and temporal attention.

Parameters:
  • num_frames (int)

  • num_patches_per_frame (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • qk_scale (float | None)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

  • norm_layer (type[Module])

  • gradient_checkpointing (bool)
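drop_path_rate is a single number, while each block owns its own DropPath. A common convention, used for example by timm's Vision Transformer, spreads the rate linearly over depth. Whether STTransformer does this is an assumption here, not something the docstring states. drop_path_schedule is a hypothetical helper.

```python
def drop_path_schedule(drop_path_rate, depth):
    """Per-block stochastic-depth rates rising linearly from 0 to drop_path_rate.

    Mirrors torch.linspace(0, drop_path_rate, depth): the first block is never
    dropped and the last block uses the full rate.
    """
    if depth < 1:
        raise ValueError("depth must be positive")
    if depth == 1:
        return [0.0]
    return [drop_path_rate * i / (depth - 1) for i in range(depth)]


rates = drop_path_schedule(0.1, depth=12)
```

The reasoning behind the ramp is that early blocks compute low-level features every later block depends on, so they are dropped least often.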

forward(x)[source]#
Parameters:

x (Tensor) – (B, T*N, C) where T is num_frames, N is num_patches_per_frame

Returns:

(B, T*N, C)

Return type:

Tensor

world_models.blocks.st_transformer.create_st_transformer(num_frames=16, patch_size=4, img_size=64, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, gradient_checkpointing=False)[source]#

Factory function to create an ST-Transformer.

Parameters:
  • num_frames (int)

  • patch_size (int)

  • img_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

  • gradient_checkpointing (bool)

Return type:

STTransformer

Configuration objects#

Lazy config exports.

Configuration modules can have optional training dependencies, so the package initializer avoids importing every config eagerly.

class world_models.configs.DreamerConfig(env_backend='dmc', env='walker-walk', env_instance=None, image_size=(64, 64), gym_render_mode='rgb_array', dmlab_action_repeat=4, dmlab_action_set=None, dmlab_observations=None, dmlab_config=None, dmlab_renderer='hardware', procgen_distribution_mode='easy', procgen_num_levels=0, procgen_start_level=None, mujoco_xml_path=None, mujoco_xml_string=None, mujoco_binary_path=None, mujoco_camera=None, mujoco_frame_skip=1, mujoco_reset_noise_scale=0.0, brax_backend='generalized', brax_jit=True, brax_auto_reset=False, brax_suppress_warp_warnings=True, unity_file_name=None, unity_behavior_name=None, unity_worker_id=0, unity_base_port=5005, unity_no_graphics=True, unity_time_scale=20.0, unity_quality_level=1, algo='Dreamerv1', exp_name='lr1e-3', train=True, evaluate=False, seed=1, no_gpu=False, max_episode_length=1000, buffer_size=800000, time_limit=1000, cnn_activation_function='relu', dense_activation_function='elu', obs_embed_size=1024, num_units=400, deter_size=200, stoch_size=30, action_repeat=2, action_noise=0.3, total_steps=5000000, seed_steps=5000, update_steps=100, collect_steps=1000, batch_size=50, train_seq_len=50, imagine_horizon=15, use_disc_model=False, free_nats=3.0, discount=0.99, td_lambda=0.95, kl_loss_coeff=1.0, kl_alpha=0.8, disc_loss_coeff=10.0, num_buckets=255, symlog_range=10.0, model_learning_rate=0.0006, actor_learning_rate=8e-05, value_learning_rate=8e-05, adam_epsilon=1e-07, grad_clip_norm=100.0, use_amp=True, test=False, test_interval=10000, test_episodes=10, scalar_freq=1000, log_video_freq=-1, max_videos_to_save=2, video_format='gif', video_fps=20, checkpoint_interval=10000, checkpoint_path='', restore=False, experience_replay='', render=False, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', log_dir='runs', logdir=None, data_dir=None, log_level='INFO', log_file=None, enable_tensorboard=False, enable_console_metrics=True, enable_jsonl=True, jsonl_filename='metrics.jsonl', log_system_stats_freq=1000, detect_anomaly=False)[source]#

Bases: SerializableConfigMixin

Configuration container for Dreamer training, evaluation, and environment setup.

This class centralizes environment backend selection (DMC/DMLab/Gym/MuJoCo/Robotics/Unity/Brax), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.

Parameters:
  • env_backend (str)

  • env (str)

  • env_instance (Any)

  • image_size (tuple[int, int])

  • gym_render_mode (str)

  • dmlab_action_repeat (int)

  • dmlab_action_set (Any)

  • dmlab_observations (Any)

  • dmlab_config (Any)

  • dmlab_renderer (str)

  • procgen_distribution_mode (str)

  • procgen_num_levels (int)

  • procgen_start_level (Any)

  • mujoco_xml_path (Any)

  • mujoco_xml_string (Any)

  • mujoco_binary_path (Any)

  • mujoco_camera (Any)

  • mujoco_frame_skip (int)

  • mujoco_reset_noise_scale (float)

  • brax_backend (str)

  • brax_jit (bool)

  • brax_auto_reset (bool)

  • brax_suppress_warp_warnings (bool)

  • unity_file_name (Any)

  • unity_behavior_name (Any)

  • unity_worker_id (int)

  • unity_base_port (int)

  • unity_no_graphics (bool)

  • unity_time_scale (float)

  • unity_quality_level (int)

  • algo (str)

  • exp_name (str)

  • train (bool)

  • evaluate (bool)

  • seed (int)

  • no_gpu (bool)

  • max_episode_length (int)

  • buffer_size (int)

  • time_limit (int)

  • cnn_activation_function (str)

  • dense_activation_function (str)

  • obs_embed_size (int)

  • num_units (int)

  • deter_size (int)

  • stoch_size (int)

  • action_repeat (int)

  • action_noise (float)

  • total_steps (int)

  • seed_steps (int)

  • update_steps (int)

  • collect_steps (int)

  • batch_size (int)

  • train_seq_len (int)

  • imagine_horizon (int)

  • use_disc_model (bool)

  • free_nats (float)

  • discount (float)

  • td_lambda (float)

  • kl_loss_coeff (float)

  • kl_alpha (float)

  • disc_loss_coeff (float)

  • num_buckets (int)

  • symlog_range (float)

  • model_learning_rate (float)

  • actor_learning_rate (float)

  • value_learning_rate (float)

  • adam_epsilon (float)

  • grad_clip_norm (float)

  • use_amp (bool)

  • test (bool)

  • test_interval (int)

  • test_episodes (int)

  • scalar_freq (int)

  • log_video_freq (int)

  • max_videos_to_save (int)

  • video_format (str)

  • video_fps (int)

  • checkpoint_interval (int)

  • checkpoint_path (str)

  • restore (bool)

  • experience_replay (str)

  • render (bool)

  • enable_wandb (bool)

  • wandb_api_key (str)

  • wandb_project (str)

  • wandb_entity (str)

  • log_dir (str)

  • logdir (Any)

  • data_dir (Any)

  • log_level (str)

  • log_file (Any)

  • enable_tensorboard (bool)

  • enable_console_metrics (bool)

  • enable_jsonl (bool)

  • jsonl_filename (str)

  • log_system_stats_freq (int)

  • detect_anomaly (bool)

env_backend: str = 'dmc'#
env: str = 'walker-walk'#
env_instance: Any = None#
image_size: tuple[int, int] = (64, 64)#
gym_render_mode: str = 'rgb_array'#
dmlab_action_repeat: int = 4#
dmlab_action_set: Any = None#
dmlab_observations: Any = None#
dmlab_config: Any = None#
dmlab_renderer: str = 'hardware'#
procgen_distribution_mode: str = 'easy'#
procgen_num_levels: int = 0#
procgen_start_level: Any = None#
mujoco_xml_path: Any = None#
mujoco_xml_string: Any = None#
mujoco_binary_path: Any = None#
mujoco_camera: Any = None#
mujoco_frame_skip: int = 1#
mujoco_reset_noise_scale: float = 0.0#
brax_backend: str = 'generalized'#
brax_jit: bool = True#
brax_auto_reset: bool = False#
brax_suppress_warp_warnings: bool = True#
unity_file_name: Any = None#
unity_behavior_name: Any = None#
unity_worker_id: int = 0#
unity_base_port: int = 5005#
unity_no_graphics: bool = True#
unity_time_scale: float = 20.0#
unity_quality_level: int = 1#
algo: str = 'Dreamerv1'#
exp_name: str = 'lr1e-3'#
train: bool = True#
evaluate: bool = False#
seed: int = 1#
no_gpu: bool = False#
max_episode_length: int = 1000#
buffer_size: int = 800000#
time_limit: int = 1000#
cnn_activation_function: str = 'relu'#
dense_activation_function: str = 'elu'#
obs_embed_size: int = 1024#
num_units: int = 400#
deter_size: int = 200#
stoch_size: int = 30#
action_repeat: int = 2#
action_noise: float = 0.3#
total_steps: int = 5000000#
seed_steps: int = 5000#
update_steps: int = 100#
collect_steps: int = 1000#
batch_size: int = 50#
train_seq_len: int = 50#
imagine_horizon: int = 15#
use_disc_model: bool = False#
free_nats: float = 3.0#
discount: float = 0.99#
td_lambda: float = 0.95#
kl_loss_coeff: float = 1.0#
kl_alpha: float = 0.8#
disc_loss_coeff: float = 10.0#
num_buckets: int = 255#
symlog_range: float = 10.0#
model_learning_rate: float = 0.0006#
actor_learning_rate: float = 8e-05#
value_learning_rate: float = 8e-05#
adam_epsilon: float = 1e-07#
grad_clip_norm: float = 100.0#
use_amp: bool = True#
test: bool = False#
test_interval: int = 10000#
test_episodes: int = 10#
scalar_freq: int = 1000#
log_video_freq: int = -1#
max_videos_to_save: int = 2#
video_format: str = 'gif'#
video_fps: int = 20#
checkpoint_interval: int = 10000#
checkpoint_path: str = ''#
restore: bool = False#
experience_replay: str = ''#
render: bool = False#
enable_wandb: bool = False#
wandb_api_key: str = ''#
wandb_project: str = 'torchwm'#
wandb_entity: str = ''#
log_dir: str = 'runs'#
logdir: Any = None#
data_dir: Any = None#
log_level: str = 'INFO'#
log_file: Any = None#
enable_tensorboard: bool = False#
enable_console_metrics: bool = True#
enable_jsonl: bool = True#
jsonl_filename: str = 'metrics.jsonl'#
log_system_stats_freq: int = 1000#
detect_anomaly: bool = False#
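discount, td_lambda and imagine_horizon parameterize the λ-return targets that Dreamer-style agents fit their value function to over imagined rollouts. The sketch below uses the standard backward recursion. That DreamerAgent computes its targets exactly this way is an assumption, and lambda_returns is a hypothetical helper.

```python
def lambda_returns(rewards, values, bootstrap, discount=0.99, td_lambda=0.95):
    """TD(lambda) targets over an imagined trajectory of length H.

    rewards[t] and values[t] cover t = 0..H-1; bootstrap is the value estimate
    for the state reached after the last step. Backward recursion:
        G_t = r_t + discount * ((1 - td_lambda) * v_{t+1} + td_lambda * G_{t+1}),
    with v_H = G_H = bootstrap.
    """
    horizon = len(rewards)
    next_values = list(values[1:]) + [bootstrap]
    targets = [0.0] * horizon
    running = bootstrap
    for t in reversed(range(horizon)):
        running = rewards[t] + discount * (
            (1 - td_lambda) * next_values[t] + td_lambda * running
        )
        targets[t] = running
    return targets


# A 15-step imagined rollout (imagine_horizon=15) with constant reward and value.
targets = lambda_returns([1.0] * 15, [10.0] * 15, bootstrap=10.0)
```

td_lambda interpolates between one-step TD targets (td_lambda=0, leaning on the value estimates) and Monte Carlo returns (td_lambda=1, leaning on the imagined rewards).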
class world_models.configs.JEPAConfig[source]#

Bases: SerializableConfigMixin

Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.

to_dict()[source]#
Return type:

Dict[str, Dict[str, Any]]

classmethod from_dict(values)[source]#

Load flat field values or the nested trainer dictionary.

Parameters:

values (Dict[str, Any])

Return type:

JEPAConfig

to_train_dict()[source]#

Return the nested dictionary expected by train_jepa.

Return type:

Dict[str, Dict[str, Any]]

to_nested_dict()[source]#

Backward-compatible alias for the nested JEPA trainer dictionary.

Return type:

Dict[str, Dict[str, Any]]

class world_models.configs.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#

Bases: SerializableConfigMixin

Default configuration values for Diffusion Transformer (DiT) training.

The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.

Field names use UPPER_CASE for backward compatibility with the original DiT codebase. Snake-case aliases are accepted via __getattr__ and get_dit_config().

Parameters:
  • DATASET (str)

  • BATCH (int)

  • EPOCHS (int)

  • LR (float)

  • IMG_SIZE (int)

  • CHANNELS (int)

  • PATCH (int)

  • WIDTH (int)

  • DEPTH (int)

  • HEADS (int)

  • DROP (float)

  • BETA_START (float)

  • BETA_END (float)

  • TIMESTEPS (int)

  • EMA (bool)

  • EMA_DECAY (float)

  • WORKDIR (str)

  • ROOT_PATH (str)

DATASET: str = 'CIFAR10'#
BATCH: int = 128#
EPOCHS: int = 3#
LR: float = 0.0002#
IMG_SIZE: int = 32#
CHANNELS: int = 3#
PATCH: int = 4#
WIDTH: int = 384#
DEPTH: int = 6#
HEADS: int = 6#
DROP: float = 0.1#
BETA_START: float = 0.0001#
BETA_END: float = 0.02#
TIMESTEPS: int = 1000#
EMA: bool = True#
EMA_DECAY: float = 0.999#
WORKDIR: str = './dit_demo'#
ROOT_PATH: str = './data'#
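The defaults BETA_START=0.0001, BETA_END=0.02 and TIMESTEPS=1000 match the linear noise schedule from DDPM. The sketch below assumes the built-in trainer uses that linear schedule, which is not confirmed here. It computes β_t and the cumulative product ᾱ_t, which sets how much signal survives at step t. Separately, with IMG_SIZE=32 and PATCH=4, a patch-based DiT would see (32 / 4)² = 64 tokens per image. linear_beta_schedule is a hypothetical helper.

```python
def linear_beta_schedule(beta_start=0.0001, beta_end=0.02, timesteps=1000):
    """Linear beta schedule and cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    if timesteps < 2:
        raise ValueError("timesteps must be at least 2")
    step = (beta_end - beta_start) / (timesteps - 1)
    betas = [beta_start + step * t for t in range(timesteps)]
    alpha_bars = []
    running = 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return betas, alpha_bars


betas, alpha_bars = linear_beta_schedule()
# Forward noising at step t: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
```

With these defaults, ᾱ falls close to zero by the final step, so x_T is nearly pure noise. The reverse process therefore starts from approximately standard Gaussian samples.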
world_models.configs.get_dit_config(**overrides)[source]#

Returns a DiTConfig instance with default values overridden by the provided keyword arguments.

Both UPPER_CASE and snake_case override keys are accepted.

Example usage:

from world_models.configs import get_dit_config

cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)  # UPPER_CASE keys
cfg = get_dit_config(batch=64, epochs=10, lr=1e-3)  # equivalent snake_case keys

Parameters:

overrides (Any)

Return type:

DiTConfig

class world_models.configs.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', seed: int = 0, obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, data_loader_num_workers: int = 4, pin_memory: bool = True, persistent_workers: bool = True, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, use_amp: bool = True, device: str = <factory>, log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, operator_state_dim: int = 32, operator_action_dim: int = 4)[source]#

Bases: SerializableConfigMixin

Parameters:
  • preset (str | None)

  • game (str)

  • seed (int)

  • obs_size (int)

  • frameskip (int)

  • max_noop (int)

  • terminate_on_life_loss (bool)

  • reward_clip (List[int])

  • num_conditioning_frames (int)

  • diffusion_channels (List[int])

  • diffusion_res_blocks (int)

  • diffusion_cond_dim (int)

  • sigma_data (float)

  • sigma_min (float)

  • sigma_max (float)

  • rho (int)

  • p_mean (float)

  • p_std (float)

  • sampling_method (str)

  • num_sampling_steps (int)

  • reward_channels (List[int])

  • reward_res_blocks (int)

  • reward_cond_dim (int)

  • reward_lstm_dim (int)

  • burn_in_length (int)

  • actor_channels (List[int])

  • actor_res_blocks (int)

  • actor_lstm_dim (int)

  • num_epochs (int)

  • training_steps_per_epoch (int)

  • batch_size (int)

  • environment_steps_per_epoch (int)

  • epsilon_greedy (float)

  • data_loader_num_workers (int)

  • pin_memory (bool)

  • persistent_workers (bool)

  • imagination_horizon (int)

  • discount_factor (float)

  • entropy_weight (float)

  • lambda_returns (float)

  • learning_rate (float)

  • adam_epsilon (float)

  • weight_decay_diffusion (float)

  • weight_decay_reward (float)

  • weight_decay_actor (float)

  • use_amp (bool)

  • device (str)

  • log_interval (int)

  • eval_interval (int)

  • save_interval (int)

  • operator_state_dim (int)

  • operator_action_dim (int)

preset: str | None = None#
game: str = 'Breakout-v5'#
seed: int = 0#
obs_size: int = 64#
frameskip: int = 4#
max_noop: int = 30#
terminate_on_life_loss: bool = True#
reward_clip: List[int]#
num_conditioning_frames: int = 4#
diffusion_channels: List[int]#
diffusion_res_blocks: int = 2#
diffusion_cond_dim: int = 256#
sigma_data: float = 0.5#
sigma_min: float = 0.002#
sigma_max: float = 80.0#
rho: int = 7#
p_mean: float = -0.4#
p_std: float = 1.2#
sampling_method: str = 'euler'#
num_sampling_steps: int = 3#
reward_channels: List[int]#
reward_res_blocks: int = 2#
reward_cond_dim: int = 128#
reward_lstm_dim: int = 512#
burn_in_length: int = 4#
actor_channels: List[int]#
actor_res_blocks: int = 1#
actor_lstm_dim: int = 512#
num_epochs: int = 1000#
training_steps_per_epoch: int = 400#
batch_size: int = 32#
environment_steps_per_epoch: int = 100#
epsilon_greedy: float = 0.01#
data_loader_num_workers: int = 4#
pin_memory: bool = True#
persistent_workers: bool = True#
imagination_horizon: int = 15#
discount_factor: float = 0.985#
entropy_weight: float = 0.001#
lambda_returns: float = 0.95#
learning_rate: float = 0.0001#
adam_epsilon: float = 1e-08#
weight_decay_diffusion: float = 0.01#
weight_decay_reward: float = 0.01#
weight_decay_actor: float = 0.0#
use_amp: bool = True#
device: str#
log_interval: int = 10#
eval_interval: int = 50#
save_interval: int = 100#
operator_state_dim: int = 32#
operator_action_dim: int = 4#
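The diffusion fields carry the names and several default values of the EDM framework (Karras et al., 2022). In that framework, sigma_min, sigma_max and rho define the sampling noise levels, and p_mean/p_std define a log-normal training noise distribution. The sketch below assumes DiamondConfig uses them that way; the class has no docstring to confirm it. It computes the num_sampling_steps noise levels and draws a training noise level. The helper names are hypothetical.

```python
import math
import random


def karras_sigmas(num_steps=3, sigma_min=0.002, sigma_max=80.0, rho=7):
    """Sampling noise levels from Karras et al. (2022), highest first, then a final 0."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    inv_rho = 1.0 / rho
    start, end = sigma_max ** inv_rho, sigma_min ** inv_rho
    sigmas = [
        (start + i / (num_steps - 1) * (end - start)) ** rho
        for i in range(num_steps)
    ]
    return sigmas + [0.0]


def sample_training_sigma(p_mean=-0.4, p_std=1.2, rng=random):
    """Training noise level drawn log-normally: ln(sigma) ~ Normal(p_mean, p_std)."""
    return math.exp(rng.gauss(p_mean, p_std))


sigmas = karras_sigmas()  # from sigma_max down to sigma_min, then 0
sigma = sample_training_sigma(rng=random.Random(0))
```

Interpolating in sigma^(1/rho) space rather than linearly in sigma concentrates steps at low noise levels, where fine detail is resolved. That distribution matters when only 3 sampling steps are used.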
class world_models.configs.IRISConfig[source]#

Bases: SerializableConfigMixin

Configuration for IRIS (Imagination with auto-Regression over an Inner Speech).

Based on the paper “Transformers are Sample-Efficient World Models”. Implements a discrete autoencoder plus an autoregressive Transformer for sample-efficient RL.

get_frame_shape()[source]#
Return type:

tuple

get_autoencoder_config()[source]#
Return type:

dict

get_transformer_config()[source]#
Return type:

dict

get_rl_config()[source]#
Return type:

dict

class world_models.configs.GenieConfig(num_frames=8, image_size=32, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=256, action_encoder_depth=4, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#

Bases: SerializableConfigMixin

Configuration for Genie model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • tokenizer_num_heads (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_num_heads (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

  • sample_temperature (float)

  • maskgit_steps (int)

num_frames: int = 8#
image_size: int = 32#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 256#
tokenizer_decoder_dim: int = 512#
tokenizer_encoder_depth: int = 4#
tokenizer_decoder_depth: int = 8#
tokenizer_num_heads: int = 8#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 256#
action_encoder_depth: int = 4#
action_num_heads: int = 8#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 4#
learning_rate: float = 3e-05#
weight_decay: float = 0.0001#
warmup_steps: int = 5000#
max_steps: int = 125000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
sample_temperature: float = 2.0#
maskgit_steps: int = 25#
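mask_prob_min and mask_prob_max bound the token-masking rate used to train a MaskGIT-style dynamics model, while maskgit_steps and sample_temperature control iterative decoding. The sketch below assumes a masking rate drawn uniformly from [mask_prob_min, mask_prob_max] with each token then masked independently. The cosine unmasking schedule is borrowed from the original MaskGIT recipe and is an assumption here; TorchWM may use a different schedule. All function names are hypothetical.

```python
import math
import random
from typing import List, Optional


def sample_training_mask(num_tokens: int, p_min: float = 0.5, p_max: float = 1.0,
                         rng: Optional[random.Random] = None) -> List[bool]:
    """Boolean mask over tokens (True = hidden), one masking rate per sample."""
    rng = rng if rng is not None else random.Random()
    rate = rng.uniform(p_min, p_max)
    return [rng.random() < rate for _ in range(num_tokens)]


def masked_counts_per_step(num_tokens: int, steps: int = 25) -> List[int]:
    """Tokens left masked after each decoding step under a cosine schedule."""
    return [
        math.floor(num_tokens * math.cos(math.pi / 2 * step / steps))
        for step in range(1, steps + 1)
    ]


def softmax_with_temperature(logits: List[float], temperature: float = 2.0) -> List[float]:
    """Higher temperature flattens the distribution used to sample tokens."""
    scaled = [l / temperature for l in logits]
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


mask = sample_training_mask(256, rng=random.Random(0))  # e.g. a 16x16 token grid
schedule = masked_counts_per_step(256, steps=25)         # maskgit_steps default
```

The schedule reveals few tokens early and many late; the final step always leaves zero tokens masked.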
class world_models.configs.GenieSmallConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=512, action_encoder_depth=8, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=2, learning_rate=0.0001, weight_decay=0.0001, warmup_steps=1000, max_steps=50000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#

Bases: SerializableConfigMixin

Small configuration for development/testing.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • tokenizer_num_heads (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_num_heads (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

  • sample_temperature (float)

  • maskgit_steps (int)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 256#
tokenizer_decoder_dim: int = 512#
tokenizer_encoder_depth: int = 4#
tokenizer_decoder_depth: int = 8#
tokenizer_num_heads: int = 8#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 512#
action_encoder_depth: int = 8#
action_num_heads: int = 8#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 2#
learning_rate: float = 0.0001#
weight_decay: float = 0.0001#
warmup_steps: int = 1000#
max_steps: int = 50000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
sample_temperature: float = 2.0#
maskgit_steps: int = 25#
class world_models.configs.STTransformerConfig(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: SerializableConfigMixin

Configuration for the Spatiotemporal Transformer.

Parameters:
  • num_frames (int)

  • num_patches_per_frame (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

num_frames: int = 16#
num_patches_per_frame: int = 256#
dim: int = 768#
depth: int = 12#
num_heads: int = 12#
mlp_ratio: float = 4.0#
qkv_bias: bool = True#
drop_rate: float = 0.0#
attn_drop_rate: float = 0.0#
drop_path_rate: float = 0.0#
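With the defaults above, the Spatiotemporal Transformer processes num_frames × num_patches_per_frame tokens. Assuming the usual ST factorization (spatial attention within each frame, temporal attention across frames at each patch position, ignoring causal masking), the number of attention scores per layer and head is much smaller than for full attention over all tokens. Here is a back-of-the-envelope check with a hypothetical helper name:

```python
from typing import Dict


def st_attention_costs(num_frames: int = 16, patches: int = 256,
                       dim: int = 768, heads: int = 12) -> Dict[str, int]:
    """Compare attention-score counts per layer and head: full vs factorized."""
    tokens = num_frames * patches
    full = tokens * tokens                         # every token attends to every token
    spatial = num_frames * patches * patches       # attention inside each frame
    temporal = patches * num_frames * num_frames   # attention along time per patch
    return {
        "tokens": tokens,
        "full": full,
        "factorized": spatial + temporal,
        "head_dim": dim // heads,
    }


costs = st_attention_costs()
print(costs["tokens"], costs["head_dim"])       # 4096 64
print(costs["full"] // costs["factorized"])     # 15
```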
class world_models.configs.VideoTokenizerConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False, ema_decay=0.99, commitment_weight=0.25)[source]#

Bases: SerializableConfigMixin

Configuration for the Video Tokenizer.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • use_ema (bool)

  • ema_decay (float)

  • commitment_weight (float)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
encoder_dim: int = 512#
decoder_dim: int = 1024#
encoder_depth: int = 12#
decoder_depth: int = 20#
num_heads: int = 16#
patch_size: int = 4#
vocab_size: int = 1024#
embedding_dim: int = 32#
use_ema: bool = False#
ema_decay: float = 0.99#
commitment_weight: float = 0.25#
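The tokenizer defaults imply (image_size / patch_size)² = 256 tokens per frame. Each token is an index into a codebook of vocab_size = 1024 entries, each of width embedding_dim = 32. The sketch below is a generic VQ-VAE quantization step in NumPy, with the commitment term weighted by commitment_weight. It is a textbook formulation, not TorchWM's code. It omits the straight-through gradient and the EMA codebook update that use_ema / ema_decay would enable.

```python
import numpy as np


def vq_quantize(z: np.ndarray, codebook: np.ndarray, commitment_weight: float = 0.25):
    """Nearest-neighbour vector quantization.

    z: (N, D) encoder outputs; codebook: (K, D) embeddings.
    Returns code indices, quantized vectors, and the summed VQ loss value.
    """
    # Squared Euclidean distances via ||z||^2 - 2 z.c + ||c||^2 -> shape (N, K).
    d = (z ** 2).sum(axis=1, keepdims=True) - 2.0 * z @ codebook.T + (codebook ** 2).sum(axis=1)
    idx = d.argmin(axis=1)
    zq = codebook[idx]
    # In an autodiff framework these two terms differ only in where the
    # stop-gradient sits; numerically both are the same mean squared error.
    codebook_loss = float(((zq - z) ** 2).mean())
    commitment_loss = commitment_weight * float(((z - zq) ** 2).mean())
    return idx, zq, codebook_loss + commitment_loss


image_size, patch_size, num_frames = 64, 4, 16
tokens_per_frame = (image_size // patch_size) ** 2  # 256
rng = np.random.default_rng(0)
codebook = rng.normal(size=(1024, 32))
z = rng.normal(size=(tokens_per_frame, 32))
idx, zq, loss = vq_quantize(z, codebook)
print(tokens_per_frame * num_frames, idx.shape)  # 4096 (256,)
```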
class world_models.configs.LatentActionModelConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=1024, encoder_depth=20, num_heads=16, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#

Bases: SerializableConfigMixin

Configuration for the Latent Action Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • encoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
encoder_dim: int = 1024#
encoder_depth: int = 20#
num_heads: int = 16#
patch_size: int = 16#
vocab_size: int = 8#
embedding_dim: int = 32#
commitment_weight: float = 1.0#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
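The latent action model maps transitions between frames to one of vocab_size = 8 discrete latent actions. With patch_size = 16 and image_size = 64, each frame splits into (64 / 16)² = 16 patches. One plausible reading of action_pooling='mean' is that patch features are averaged into one vector per transition before quantization. The pure-Python sketch below illustrates that reading and assumes one latent action per consecutive frame pair; both are assumptions about the implementation. The 'windowed_attention' option, which would replace the average with attention using window_attention_heads heads, is not shown.

```python
from typing import List, Sequence

Vector = List[float]


def mean_pool(patches: Sequence[Vector]) -> Vector:
    """Average patch features into a single vector (action_pooling='mean')."""
    n = len(patches)
    return [sum(p[i] for p in patches) / n for i in range(len(patches[0]))]


def nearest_action(v: Vector, codebook: Sequence[Vector]) -> int:
    """Index of the closest latent action code (squared L2 distance)."""
    dists = [sum((a - b) ** 2 for a, b in zip(v, c)) for c in codebook]
    return min(range(len(codebook)), key=dists.__getitem__)


image_size, patch_size, num_frames, vocab_size = 64, 16, 16, 8
patches_per_frame = (image_size // patch_size) ** 2  # 16
num_transitions = num_frames - 1                     # one latent action per frame pair
# Toy 2-D codebook with vocab_size entries on a line, for illustration only.
codebook = [[float(k), 0.0] for k in range(vocab_size)]
pooled = mean_pool([[2.9, 0.1]] * 8 + [[3.1, -0.1]] * 8)
print(patches_per_frame, num_transitions, nearest_action(pooled, codebook))  # 16 15 3
```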
class world_models.configs.DynamicsModelConfig(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: SerializableConfigMixin

Configuration for the Dynamics Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

num_frames: int = 16#
image_size: int = 64#
vocab_size: int = 1024#
embedding_dim: int = 32#
action_vocab_size: int = 8#
dim: int = 5120#
depth: int = 48#
num_heads: int = 36#
patch_size: int = 4#
mlp_ratio: float = 4.0#
qkv_bias: bool = True#
drop_rate: float = 0.0#
attn_drop_rate: float = 0.0#
drop_path_rate: float = 0.0#
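The DynamicsModelConfig defaults describe a very large model. A rough size estimate uses the common 12 · depth · dim² approximation for a Transformer with mlp_ratio = 4: about 4 · dim² attention weights plus 8 · dim² MLP weights per layer, ignoring embeddings, biases, and norms. The estimate assumes the attention projections are dim wide; if the implementation uses a separate per-head width, the true count differs. The helper also flags that dim = 5120 is not divisible by num_heads = 36. That matters if attention splits dim evenly across heads, as torch.nn.MultiheadAttention does. The helper name is hypothetical.

```python
from typing import Dict, Optional, Union


def transformer_size_report(dim: int = 5120, depth: int = 48, num_heads: int = 36,
                            mlp_ratio: float = 4.0) -> Dict[str, Union[int, bool, Optional[int]]]:
    """Approximate weight count and head-split check for a Transformer stack."""
    attn = 4 * dim * dim                  # Q, K, V and output projections
    mlp = 2 * dim * int(dim * mlp_ratio)  # up- and down-projection
    even = dim % num_heads == 0
    return {
        "approx_params": depth * (attn + mlp),
        "even_head_split": even,
        "head_dim_if_even": dim // num_heads if even else None,
    }


report = transformer_size_report()
print(f"{report['approx_params'] / 1e9:.1f}B", report["even_head_split"])  # 15.1B False
```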

Configuration classes for World Models training.

class world_models.configs.wm_config.WMVAEConfig(config_dict)[source]#

Bases: object

Configuration class for Variational Autoencoder (VAE) training.

This class manages all hyperparameters and settings for training a ConvVAE model on observation data. It provides validation and dictionary conversion utilities.

Parameters:

config_dict (dict)

height#

Height of input images (pixels).

width#

Width of input images (pixels).

device#

Device to train on (‘cpu’ or ‘cuda’).

train_batch_size#

Number of samples per training batch.

num_epochs#

Total number of training epochs.

latent_size#

Dimensionality of the VAE latent space.

data_dir#

Path to the dataset directory.

learning_rate#

Initial learning rate for the optimizer.

logdir#

Directory for saving logs and checkpoints.

noreload#

If True, skip loading existing checkpoints.

nosamples#

If True, skip saving sample images during training.

scheduler_patience#

Epochs to wait before reducing the learning rate.

scheduler_factor#

Multiplicative factor for learning rate reduction.

early_stopping_patience#

Epochs to wait before early stopping.

sample_interval#

Epoch interval for saving sample images.

extra#

Dictionary for additional custom parameters.

Example

>>> config = WMVAEConfig({
...     'height': 64,
...     'width': 64,
...     'latent_size': 32,
...     'logdir': 'results',
... })
>>> config.latent_size
32
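WMVAEConfig is built from a plain config_dict, and unrecognized keys are kept in extra. The sketch below shows one way a dict-backed config can fill defaults, validate values, and collect extra keys. The class name, the default values, and the validation rules are hypothetical and are not TorchWM's. Only the attribute names mirror the documentation above.

```python
from typing import Any, Dict


class DictBackedVAEConfig:
    """Hypothetical stand-in for a dict-driven VAE config."""

    # Illustrative defaults only; they are not TorchWM's defaults.
    DEFAULTS: Dict[str, Any] = {
        "height": 64,
        "width": 64,
        "latent_size": 32,
        "train_batch_size": 32,
        "learning_rate": 1e-3,
        "device": "cpu",
    }

    def __init__(self, config_dict: Dict[str, Any]):
        merged = {**self.DEFAULTS, **config_dict}
        for key in self.DEFAULTS:
            setattr(self, key, merged.pop(key))
        # Anything not recognized is preserved rather than silently dropped.
        self.extra = merged
        self._validate()

    def _validate(self) -> None:
        if self.latent_size <= 0:
            raise ValueError("latent_size must be positive")
        if self.device not in ("cpu", "cuda"):
            raise ValueError("device must be 'cpu' or 'cuda'")

    def to_dict(self) -> Dict[str, Any]:
        return {**{k: getattr(self, k) for k in self.DEFAULTS}, **self.extra}


cfg = DictBackedVAEConfig({"latent_size": 16, "my_tag": "run-1"})
print(cfg.latent_size, cfg.extra)  # 16 {'my_tag': 'run-1'}
```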
class world_models.configs.wm_config.WMMDNRNNConfig(config_dict)[source]#

Bases: object

Configuration class for Mixture Density Recurrent Neural Network (MDRNN) training.

This class manages all hyperparameters and settings for training an MDRNN model on sequence data. It provides validation and dictionary conversion utilities.

Parameters:

config_dict (dict)

latent_size#

Dimensionality of the VAE latent space.

action_size#

Dimensionality of the action space.

hidden_size#

Number of hidden units in the RNN.

gmm_components#

Number of Gaussian mixture components.

device#

Device to train on (‘cpu’ or ‘cuda’).

batch_size#

Number of sequences per batch.

seq_len#

Length of each sequence.

num_epochs#

Total number of training epochs.

data_dir#

Path to the dataset directory.

learning_rate#

Initial learning rate for the optimizer.

logdir#

Directory for saving logs and checkpoints.

noreload#

If True, skip loading existing checkpoints.

include_reward#

If True, include reward prediction in the loss.

scheduler_patience#

Epochs to wait before reducing the learning rate.

scheduler_factor#

Multiplicative factor for learning rate reduction.

early_stopping_patience#

Epochs to wait before early stopping.

extra#

Dictionary for additional custom parameters.

Example

>>> config = WMMDNRNNConfig({
...     'latent_size': 32,
...     'action_size': 3,
...     'hidden_size': 256,
...     'gmm_components': 5,
... })
>>> config.hidden_size
256
to_dict()[source]#

Convert configuration to dictionary.

Returns:

Dictionary containing all configuration parameters.

Return type:

Dict[str, Any]
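gmm_components sets how many Gaussian components the mixture density head predicts for each next-latent dimension. The sketch below computes the negative log-likelihood of one scalar target under a 1-D Gaussian mixture, using log-sum-exp for numerical stability. It is a generic mixture-density formulation in pure Python; the function name is hypothetical and this is not TorchWM's loss.

```python
import math
from typing import Sequence


def gmm_nll(x: float, logit_pi: Sequence[float], mu: Sequence[float],
            log_sigma: Sequence[float]) -> float:
    """-log sum_k pi_k N(x | mu_k, sigma_k^2) for one scalar target x."""
    # Normalize the mixture weights in log space.
    max_logit = max(logit_pi)
    log_z = max_logit + math.log(sum(math.exp(l - max_logit) for l in logit_pi))
    log_terms = []
    for l, m_k, ls in zip(logit_pi, mu, log_sigma):
        sigma = math.exp(ls)
        log_normal = -0.5 * ((x - m_k) / sigma) ** 2 - ls - 0.5 * math.log(2 * math.pi)
        log_terms.append((l - log_z) + log_normal)
    # log-sum-exp over components
    top = max(log_terms)
    return -(top + math.log(sum(math.exp(t - top) for t in log_terms)))


# gmm_components = 5, with uniform weights and unit-variance components.
nll = gmm_nll(0.0, [0.0] * 5, [0.0, 1.0, -1.0, 2.0, -2.0], [0.0] * 5)
```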

class world_models.configs.wm_config.WMControllerConfig(config_dict)[source]#

Bases: object

Configuration class for Controller training with CMA-ES.

This class manages hyperparameters for training a linear controller using Covariance Matrix Adaptation Evolution Strategy (CMA-ES).

Parameters:

config_dict (dict)

latent_size#

Dimensionality of the VAE latent state.

hidden_size#

Dimensionality of the MDN-RNN hidden state.

action_size#

Dimensionality of the action space.

logdir#

Directory for saving logs and checkpoints.

n_samples#

Number of rollouts used to obtain the return estimate.

pop_size#

Population size for CMA-ES.

target_return#

Training stops once the return exceeds this threshold.

max_workers#

Maximum number of workers for parallel evaluation.

display#

If True, show progress bars during training.

time_limit#

Maximum number of steps per episode.

to_dict()[source]#

Convert configuration to dictionary.

Return type:

Dict[str, Any]
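In the original World Models setup, the controller is a single linear layer, a = W [z, h] + b. It maps the concatenated VAE latent z and recurrent hidden state h to an action. Its parameter count is therefore (latent_size + hidden_size) · action_size + action_size, and that flat vector is what CMA-ES searches over. The sketch below assumes this layout; the helper names are hypothetical, and any squashing of actions into environment bounds is left out.

```python
import random
from typing import List


def controller_param_count(latent_size: int, hidden_size: int, action_size: int) -> int:
    """Weights plus biases of a linear map [z, h] -> action."""
    return (latent_size + hidden_size) * action_size + action_size


def linear_controller(flat_params: List[float], z: List[float], h: List[float],
                      action_size: int) -> List[float]:
    """Unpack a flat CMA-ES candidate and apply a = W [z, h] + b."""
    x = z + h  # list concatenation
    n_in = len(x)
    expected = n_in * action_size + action_size
    if len(flat_params) != expected:
        raise ValueError(f"expected {expected} parameters, got {len(flat_params)}")
    weights, bias = flat_params[: n_in * action_size], flat_params[n_in * action_size:]
    return [
        sum(weights[a * n_in + i] * x[i] for i in range(n_in)) + bias[a]
        for a in range(action_size)
    ]


print(controller_param_count(32, 256, 3))  # 867
rng = random.Random(0)
params = [rng.gauss(0.0, 0.1) for _ in range(controller_param_count(32, 256, 3))]
action = linear_controller(params, [0.0] * 32, [0.0] * 256, action_size=3)
```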

class world_models.configs.dreamer_config.DreamerConfig(env_backend='dmc', env='walker-walk', env_instance=None, image_size=(64, 64), gym_render_mode='rgb_array', dmlab_action_repeat=4, dmlab_action_set=None, dmlab_observations=None, dmlab_config=None, dmlab_renderer='hardware', procgen_distribution_mode='easy', procgen_num_levels=0, procgen_start_level=None, mujoco_xml_path=None, mujoco_xml_string=None, mujoco_binary_path=None, mujoco_camera=None, mujoco_frame_skip=1, mujoco_reset_noise_scale=0.0, brax_backend='generalized', brax_jit=True, brax_auto_reset=False, brax_suppress_warp_warnings=True, unity_file_name=None, unity_behavior_name=None, unity_worker_id=0, unity_base_port=5005, unity_no_graphics=True, unity_time_scale=20.0, unity_quality_level=1, algo='Dreamerv1', exp_name='lr1e-3', train=True, evaluate=False, seed=1, no_gpu=False, max_episode_length=1000, buffer_size=800000, time_limit=1000, cnn_activation_function='relu', dense_activation_function='elu', obs_embed_size=1024, num_units=400, deter_size=200, stoch_size=30, action_repeat=2, action_noise=0.3, total_steps=5000000, seed_steps=5000, update_steps=100, collect_steps=1000, batch_size=50, train_seq_len=50, imagine_horizon=15, use_disc_model=False, free_nats=3.0, discount=0.99, td_lambda=0.95, kl_loss_coeff=1.0, kl_alpha=0.8, disc_loss_coeff=10.0, num_buckets=255, symlog_range=10.0, model_learning_rate=0.0006, actor_learning_rate=8e-05, value_learning_rate=8e-05, adam_epsilon=1e-07, grad_clip_norm=100.0, use_amp=True, test=False, test_interval=10000, test_episodes=10, scalar_freq=1000, log_video_freq=-1, max_videos_to_save=2, video_format='gif', video_fps=20, checkpoint_interval=10000, checkpoint_path='', restore=False, experience_replay='', render=False, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', log_dir='runs', logdir=None, data_dir=None, log_level='INFO', log_file=None, enable_tensorboard=False, enable_console_metrics=True, enable_jsonl=True, jsonl_filename='metrics.jsonl', log_system_stats_freq=1000, detect_anomaly=False)[source]#

Bases: SerializableConfigMixin

Configuration container for Dreamer training, evaluation, and environment setup.

This class centralizes environment backend selection (DMC/DMLab/Gym/Procgen/MuJoCo/Robotics/Unity/Brax), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.

Parameters:
  • env_backend (str)

  • env (str)

  • env_instance (Any)

  • image_size (tuple[int, int])

  • gym_render_mode (str)

  • dmlab_action_repeat (int)

  • dmlab_action_set (Any)

  • dmlab_observations (Any)

  • dmlab_config (Any)

  • dmlab_renderer (str)

  • procgen_distribution_mode (str)

  • procgen_num_levels (int)

  • procgen_start_level (Any)

  • mujoco_xml_path (Any)

  • mujoco_xml_string (Any)

  • mujoco_binary_path (Any)

  • mujoco_camera (Any)

  • mujoco_frame_skip (int)

  • mujoco_reset_noise_scale (float)

  • brax_backend (str)

  • brax_jit (bool)

  • brax_auto_reset (bool)

  • brax_suppress_warp_warnings (bool)

  • unity_file_name (Any)

  • unity_behavior_name (Any)

  • unity_worker_id (int)

  • unity_base_port (int)

  • unity_no_graphics (bool)

  • unity_time_scale (float)

  • unity_quality_level (int)

  • algo (str)

  • exp_name (str)

  • train (bool)

  • evaluate (bool)

  • seed (int)

  • no_gpu (bool)

  • max_episode_length (int)

  • buffer_size (int)

  • time_limit (int)

  • cnn_activation_function (str)

  • dense_activation_function (str)

  • obs_embed_size (int)

  • num_units (int)

  • deter_size (int)

  • stoch_size (int)

  • action_repeat (int)

  • action_noise (float)

  • total_steps (int)

  • seed_steps (int)

  • update_steps (int)

  • collect_steps (int)

  • batch_size (int)

  • train_seq_len (int)

  • imagine_horizon (int)

  • use_disc_model (bool)

  • free_nats (float)

  • discount (float)

  • td_lambda (float)

  • kl_loss_coeff (float)

  • kl_alpha (float)

  • disc_loss_coeff (float)

  • num_buckets (int)

  • symlog_range (float)

  • model_learning_rate (float)

  • actor_learning_rate (float)

  • value_learning_rate (float)

  • adam_epsilon (float)

  • grad_clip_norm (float)

  • use_amp (bool)

  • test (bool)

  • test_interval (int)

  • test_episodes (int)

  • scalar_freq (int)

  • log_video_freq (int)

  • max_videos_to_save (int)

  • video_format (str)

  • video_fps (int)

  • checkpoint_interval (int)

  • checkpoint_path (str)

  • restore (bool)

  • experience_replay (str)

  • render (bool)

  • enable_wandb (bool)

  • wandb_api_key (str)

  • wandb_project (str)

  • wandb_entity (str)

  • log_dir (str)

  • logdir (Any)

  • data_dir (Any)

  • log_level (str)

  • log_file (Any)

  • enable_tensorboard (bool)

  • enable_console_metrics (bool)

  • enable_jsonl (bool)

  • jsonl_filename (str)

  • log_system_stats_freq (int)

  • detect_anomaly (bool)

env_backend: str = 'dmc'#
env: str = 'walker-walk'#
env_instance: Any = None#
image_size: tuple[int, int] = (64, 64)#
gym_render_mode: str = 'rgb_array'#
dmlab_action_repeat: int = 4#
dmlab_action_set: Any = None#
dmlab_observations: Any = None#
dmlab_config: Any = None#
dmlab_renderer: str = 'hardware'#
procgen_distribution_mode: str = 'easy'#
procgen_num_levels: int = 0#
procgen_start_level: Any = None#
mujoco_xml_path: Any = None#
mujoco_xml_string: Any = None#
mujoco_binary_path: Any = None#
mujoco_camera: Any = None#
mujoco_frame_skip: int = 1#
mujoco_reset_noise_scale: float = 0.0#
brax_backend: str = 'generalized'#
brax_jit: bool = True#
brax_auto_reset: bool = False#
brax_suppress_warp_warnings: bool = True#
unity_file_name: Any = None#
unity_behavior_name: Any = None#
unity_worker_id: int = 0#
unity_base_port: int = 5005#
unity_no_graphics: bool = True#
unity_time_scale: float = 20.0#
unity_quality_level: int = 1#
algo: str = 'Dreamerv1'#
exp_name: str = 'lr1e-3'#
train: bool = True#
evaluate: bool = False#
seed: int = 1#
no_gpu: bool = False#
max_episode_length: int = 1000#
buffer_size: int = 800000#
time_limit: int = 1000#
cnn_activation_function: str = 'relu'#
dense_activation_function: str = 'elu'#
obs_embed_size: int = 1024#
num_units: int = 400#
deter_size: int = 200#
stoch_size: int = 30#
action_repeat: int = 2#
action_noise: float = 0.3#
total_steps: int = 5000000#
seed_steps: int = 5000#
update_steps: int = 100#
collect_steps: int = 1000#
batch_size: int = 50#
train_seq_len: int = 50#
imagine_horizon: int = 15#
use_disc_model: bool = False#
free_nats: float = 3.0#
discount: float = 0.99#
td_lambda: float = 0.95#
kl_loss_coeff: float = 1.0#
kl_alpha: float = 0.8#
disc_loss_coeff: float = 10.0#
num_buckets: int = 255#
symlog_range: float = 10.0#
model_learning_rate: float = 0.0006#
actor_learning_rate: float = 8e-05#
value_learning_rate: float = 8e-05#
adam_epsilon: float = 1e-07#
grad_clip_norm: float = 100.0#
use_amp: bool = True#
test: bool = False#
test_interval: int = 10000#
test_episodes: int = 10#
scalar_freq: int = 1000#
log_video_freq: int = -1#
max_videos_to_save: int = 2#
video_format: str = 'gif'#
video_fps: int = 20#
checkpoint_interval: int = 10000#
checkpoint_path: str = ''#
restore: bool = False#
experience_replay: str = ''#
render: bool = False#
enable_wandb: bool = False#
wandb_api_key: str = ''#
wandb_project: str = 'torchwm'#
wandb_entity: str = ''#
log_dir: str = 'runs'#
logdir: Any = None#
data_dir: Any = None#
log_level: str = 'INFO'#
log_file: Any = None#
enable_tensorboard: bool = False#
enable_console_metrics: bool = True#
enable_jsonl: bool = True#
jsonl_filename: str = 'metrics.jsonl'#
log_system_stats_freq: int = 1000#
detect_anomaly: bool = False#
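When use_disc_model is enabled, num_buckets and symlog_range suggest a DreamerV3-style discrete regression head. In that scheme, targets are compressed with symlog and encoded as a "two-hot" distribution over evenly spaced buckets. The sketch below implements that standard encoding and assumes the buckets span [-symlog_range, symlog_range]. How TorchWM actually places its buckets, and where it applies disc_loss_coeff, are not confirmed here; all names are hypothetical.

```python
import math
from typing import List


def symlog(x: float) -> float:
    return math.copysign(math.log1p(abs(x)), x)


def symexp(y: float) -> float:
    return math.copysign(math.expm1(abs(y)), y)


def two_hot(value: float, num_buckets: int = 255, symlog_range: float = 10.0) -> List[float]:
    """Encode symlog(value) as weights on its two neighbouring buckets."""
    step = 2 * symlog_range / (num_buckets - 1)
    y = min(max(symlog(value), -symlog_range), symlog_range)  # clip to bucket span
    pos = (y + symlog_range) / step
    lo = min(int(math.floor(pos)), num_buckets - 2)
    frac = pos - lo
    target = [0.0] * num_buckets
    target[lo] = 1.0 - frac
    target[lo + 1] = frac
    return target


def decode(target: List[float], symlog_range: float = 10.0) -> float:
    """Expected bucket value, mapped back through symexp."""
    n = len(target)
    buckets = [-symlog_range + i * 2 * symlog_range / (n - 1) for i in range(n)]
    return symexp(sum(w * b for w, b in zip(target, buckets)))


enc = two_hot(3.7)
print(round(decode(enc), 6))  # 3.7
```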
class world_models.configs.jepa_config.JEPAConfig[source]#

Bases: SerializableConfigMixin

Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.

to_dict()[source]#
Return type:

Dict[str, Dict[str, Any]]

classmethod from_dict(values)[source]#

Load flat field values or the nested trainer dictionary.

Parameters:

values (Dict[str, Any])

Return type:

JEPAConfig

to_train_dict()[source]#

Return the nested dictionary expected by train_jepa.

Return type:

Dict[str, Dict[str, Any]]

to_nested_dict()[source]#

Backward-compatible alias for the nested JEPA trainer dictionary.

Return type:

Dict[str, Dict[str, Any]]
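JEPAConfig stores flat fields but hands train_jepa a nested dictionary, and from_dict accepts either shape. The sketch below shows that round trip with a hypothetical field-to-section mapping. The section names, the field names, and the "all values are dicts" detection heuristic are illustrative only and are not JEPAConfig's actual schema.

```python
from typing import Any, Dict

# Hypothetical mapping from flat field names to trainer sections.
SECTIONS: Dict[str, str] = {
    "batch_size": "data",
    "crop_size": "data",
    "lr": "optimization",
    "epochs": "optimization",
    "model_name": "meta",
}


def to_nested(flat: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Group flat fields into {section: {field: value}}."""
    nested: Dict[str, Dict[str, Any]] = {}
    for key, value in flat.items():
        nested.setdefault(SECTIONS[key], {})[key] = value
    return nested


def from_any(values: Dict[str, Any]) -> Dict[str, Any]:
    """Accept a flat dict or a nested {section: {field: value}} dict."""
    if values and all(isinstance(v, dict) for v in values.values()):
        return {k: v for section in values.values() for k, v in section.items()}
    return dict(values)


flat = {"batch_size": 64, "lr": 1e-3, "model_name": "vit_small"}
nested = to_nested(flat)
assert from_any(nested) == flat  # round trip is lossless
```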

class world_models.configs.iris_config.IRISConfig[source]#

Bases: SerializableConfigMixin

Configuration for IRIS (Imagination with auto-Regression over an Inner Speech).

Based on the paper “Transformers are Sample-Efficient World Models”. Implements a discrete autoencoder and an autoregressive Transformer for sample-efficient RL.

get_frame_shape()[source]#
Return type:

tuple

get_autoencoder_config()[source]#
Return type:

dict

get_transformer_config()[source]#
Return type:

dict

get_rl_config()[source]#
Return type:

dict

class world_models.configs.genie_config.GenieConfig(num_frames=8, image_size=32, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=256, action_encoder_depth=4, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#

Bases: SerializableConfigMixin

Configuration for the Genie model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • tokenizer_num_heads (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_num_heads (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

  • sample_temperature (float)

  • maskgit_steps (int)

num_frames: int = 8#
image_size: int = 32#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 256#
tokenizer_decoder_dim: int = 512#
tokenizer_encoder_depth: int = 4#
tokenizer_decoder_depth: int = 8#
tokenizer_num_heads: int = 8#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 256#
action_encoder_depth: int = 4#
action_num_heads: int = 8#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 4#
learning_rate: float = 3e-05#
weight_decay: float = 0.0001#
warmup_steps: int = 5000#
max_steps: int = 125000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
sample_temperature: float = 2.0#
maskgit_steps: int = 25#
class world_models.configs.genie_config.GenieSmallConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=512, action_encoder_depth=8, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=2, learning_rate=0.0001, weight_decay=0.0001, warmup_steps=1000, max_steps=50000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#

Bases: SerializableConfigMixin

Small configuration for development/testing.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • tokenizer_num_heads (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_num_heads (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

  • sample_temperature (float)

  • maskgit_steps (int)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 256#
tokenizer_decoder_dim: int = 512#
tokenizer_encoder_depth: int = 4#
tokenizer_decoder_depth: int = 8#
tokenizer_num_heads: int = 8#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 512#
action_encoder_depth: int = 8#
action_num_heads: int = 8#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 2#
learning_rate: float = 0.0001#
weight_decay: float = 0.0001#
warmup_steps: int = 1000#
max_steps: int = 50000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
sample_temperature: float = 2.0#
maskgit_steps: int = 25#
class world_models.configs.genie_config.STTransformerConfig(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: SerializableConfigMixin

Configuration for the Spatiotemporal Transformer.

Parameters:
  • num_frames (int)

  • num_patches_per_frame (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

num_frames: int = 16#
num_patches_per_frame: int = 256#
dim: int = 768#
depth: int = 12#
num_heads: int = 12#
mlp_ratio: float = 4.0#
qkv_bias: bool = True#
drop_rate: float = 0.0#
attn_drop_rate: float = 0.0#
drop_path_rate: float = 0.0#
class world_models.configs.genie_config.VideoTokenizerConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False, ema_decay=0.99, commitment_weight=0.25)[source]#

Bases: SerializableConfigMixin

Configuration for the Video Tokenizer.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • decoder_dim (int)

  • encoder_depth (int)

  • decoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • use_ema (bool)

  • ema_decay (float)

  • commitment_weight (float)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
encoder_dim: int = 512#
decoder_dim: int = 1024#
encoder_depth: int = 12#
decoder_depth: int = 20#
num_heads: int = 16#
patch_size: int = 4#
vocab_size: int = 1024#
embedding_dim: int = 32#
use_ema: bool = False#
ema_decay: float = 0.99#
commitment_weight: float = 0.25#
class world_models.configs.genie_config.LatentActionModelConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=1024, encoder_depth=20, num_heads=16, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#

Bases: SerializableConfigMixin

Configuration for the Latent Action Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • encoder_dim (int)

  • encoder_depth (int)

  • num_heads (int)

  • patch_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • commitment_weight (float)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
encoder_dim: int = 1024#
encoder_depth: int = 20#
num_heads: int = 16#
patch_size: int = 16#
vocab_size: int = 8#
embedding_dim: int = 32#
commitment_weight: float = 1.0#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
class world_models.configs.genie_config.DynamicsModelConfig(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#

Bases: SerializableConfigMixin

Configuration for the Dynamics Model.

Parameters:
  • num_frames (int)

  • image_size (int)

  • vocab_size (int)

  • embedding_dim (int)

  • action_vocab_size (int)

  • dim (int)

  • depth (int)

  • num_heads (int)

  • patch_size (int)

  • mlp_ratio (float)

  • qkv_bias (bool)

  • drop_rate (float)

  • attn_drop_rate (float)

  • drop_path_rate (float)

num_frames: int = 16#
image_size: int = 64#
vocab_size: int = 1024#
embedding_dim: int = 32#
action_vocab_size: int = 8#
dim: int = 5120#
depth: int = 48#
num_heads: int = 36#
patch_size: int = 4#
mlp_ratio: float = 4.0#
qkv_bias: bool = True#
drop_rate: float = 0.0#
attn_drop_rate: float = 0.0#
drop_path_rate: float = 0.0#
class world_models.configs.dit_config.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#

Bases: SerializableConfigMixin

Default configuration values for Diffusion Transformer (DiT) training.

The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.

Field names use UPPER_CASE for backward compatibility with the original DiT codebase. Snake-case aliases are accepted via __getattr__ and get_dit_config().

Parameters:
  • DATASET (str)

  • BATCH (int)

  • EPOCHS (int)

  • LR (float)

  • IMG_SIZE (int)

  • CHANNELS (int)

  • PATCH (int)

  • WIDTH (int)

  • DEPTH (int)

  • HEADS (int)

  • DROP (float)

  • BETA_START (float)

  • BETA_END (float)

  • TIMESTEPS (int)

  • EMA (bool)

  • EMA_DECAY (float)

  • WORKDIR (str)

  • ROOT_PATH (str)

DATASET: str = 'CIFAR10'#
BATCH: int = 128#
EPOCHS: int = 3#
LR: float = 0.0002#
IMG_SIZE: int = 32#
CHANNELS: int = 3#
PATCH: int = 4#
WIDTH: int = 384#
DEPTH: int = 6#
HEADS: int = 6#
DROP: float = 0.1#
BETA_START: float = 0.0001#
BETA_END: float = 0.02#
TIMESTEPS: int = 1000#
EMA: bool = True#
EMA_DECAY: float = 0.999#
WORKDIR: str = './dit_demo'#
ROOT_PATH: str = './data'#
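BETA_START, BETA_END and TIMESTEPS describe a diffusion noise schedule. The sketch below assumes a linear DDPM schedule, though the training code could use another spacing. Under that schedule, the cumulative product ᾱ_t = ∏(1 − β_s) shows how much signal survives at step t. The same defaults also fix the patch grid: IMG_SIZE / PATCH = 8, so each 32×32 image becomes 64 tokens of width WIDTH = 384, split over HEADS = 6 heads of 64 dimensions. Function names are hypothetical.

```python
from typing import List


def linear_beta_schedule(beta_start: float = 1e-4, beta_end: float = 0.02,
                         timesteps: int = 1000) -> List[float]:
    """Evenly spaced betas from beta_start to beta_end (inclusive)."""
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def alpha_bars(betas: List[float]) -> List[float]:
    """Cumulative product of (1 - beta_t): the signal fraction kept by x_t."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out


betas = linear_beta_schedule()
abar = alpha_bars(betas)
# Forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise
print(f"{abar[0]:.4f} {abar[499]:.4f} {abar[-1]:.2e}")

img_size, patch, width, heads = 32, 4, 384, 6
print((img_size // patch) ** 2, width // heads)  # 64 64
```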
world_models.configs.dit_config.get_dit_config(**overrides)[source]#

Returns a DiTConfig instance with default values overridden by the provided keyword arguments.

Both UPPER_CASE and snake_case override keys are accepted.

Example usage:

cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)
cfg = get_dit_config(batch=64, epochs=10, lr=1e-3)  # same overrides, snake_case keys

Parameters:

overrides (Any)

Return type:

DiTConfig
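get_dit_config accepts both BATCH=64 and batch=64. One minimal way to get that behaviour from a dataclass is to upper-case override keys before calling dataclasses.replace, and to map snake_case attribute reads in __getattr__. The class below is a trimmed, hypothetical stand-in with three fields, not the real DiTConfig. It also assumes every snake_case alias is simply the lower-cased field name (for example img_size for IMG_SIZE).

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class TinyDiTConfig:
    BATCH: int = 128
    EPOCHS: int = 3
    LR: float = 2e-4

    def __getattr__(self, name: str) -> Any:
        # Only called when normal lookup fails, e.g. for cfg.batch.
        upper = name.upper()
        if upper != name and upper in self.__dataclass_fields__:
            return getattr(self, upper)
        raise AttributeError(name)


def get_tiny_config(**overrides: Any) -> TinyDiTConfig:
    """Apply UPPER_CASE or snake_case overrides to the defaults."""
    fields = {f.name for f in dataclasses.fields(TinyDiTConfig)}
    normalized: Dict[str, Any] = {}
    for key, value in overrides.items():
        upper = key.upper()
        if upper not in fields:
            raise TypeError(f"unknown config key: {key}")
        normalized[upper] = value
    return dataclasses.replace(TinyDiTConfig(), **normalized)


a = get_tiny_config(BATCH=64, EPOCHS=10, LR=1e-3)
b = get_tiny_config(batch=64, epochs=10, lr=1e-3)
print(a == b, b.batch)  # True 64
```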

world_models.configs.diamond_config.get_default_device()[source]#
Return type:

str

class world_models.configs.diamond_config.ModelPreset(diffusion_channels, diffusion_res_blocks, diffusion_cond_dim, reward_channels, reward_lstm_dim, actor_channels, actor_lstm_dim)[source]#

Bases: SerializableConfigMixin

Model architecture preset for different hardware tiers.

Parameters:
  • diffusion_channels (List[int])

  • diffusion_res_blocks (int)

  • diffusion_cond_dim (int)

  • reward_channels (List[int])

  • reward_lstm_dim (int)

  • actor_channels (List[int])

  • actor_lstm_dim (int)

diffusion_channels: List[int]#
diffusion_res_blocks: int#
diffusion_cond_dim: int#
reward_channels: List[int]#
reward_lstm_dim: int#
actor_channels: List[int]#
actor_lstm_dim: int#
class world_models.configs.diamond_config.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', seed: int = 0, obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, data_loader_num_workers: int = 4, pin_memory: bool = True, persistent_workers: bool = True, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, use_amp: bool = True, device: str = <factory>, log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, operator_state_dim: int = 32, operator_action_dim: int = 4)[source]#

Bases: SerializableConfigMixin

Parameters:
  • preset (str | None)

  • game (str)

  • seed (int)

  • obs_size (int)

  • frameskip (int)

  • max_noop (int)

  • terminate_on_life_loss (bool)

  • reward_clip (List[int])

  • num_conditioning_frames (int)

  • diffusion_channels (List[int])

  • diffusion_res_blocks (int)

  • diffusion_cond_dim (int)

  • sigma_data (float)

  • sigma_min (float)

  • sigma_max (float)

  • rho (int)

  • p_mean (float)

  • p_std (float)

  • sampling_method (str)

  • num_sampling_steps (int)

  • reward_channels (List[int])

  • reward_res_blocks (int)

  • reward_cond_dim (int)

  • reward_lstm_dim (int)

  • burn_in_length (int)

  • actor_channels (List[int])

  • actor_res_blocks (int)

  • actor_lstm_dim (int)

  • num_epochs (int)

  • training_steps_per_epoch (int)

  • batch_size (int)

  • environment_steps_per_epoch (int)

  • epsilon_greedy (float)

  • data_loader_num_workers (int)

  • pin_memory (bool)

  • persistent_workers (bool)

  • imagination_horizon (int)

  • discount_factor (float)

  • entropy_weight (float)

  • lambda_returns (float)

  • learning_rate (float)

  • adam_epsilon (float)

  • weight_decay_diffusion (float)

  • weight_decay_reward (float)

  • weight_decay_actor (float)

  • use_amp (bool)

  • device (str)

  • log_interval (int)

  • eval_interval (int)

  • save_interval (int)

  • operator_state_dim (int)

  • operator_action_dim (int)

preset: str | None = None#
game: str = 'Breakout-v5'#
seed: int = 0#
obs_size: int = 64#
frameskip: int = 4#
max_noop: int = 30#
terminate_on_life_loss: bool = True#
reward_clip: List[int]#
num_conditioning_frames: int = 4#
diffusion_channels: List[int]#
diffusion_res_blocks: int = 2#
diffusion_cond_dim: int = 256#
sigma_data: float = 0.5#
sigma_min: float = 0.002#
sigma_max: float = 80.0#
rho: int = 7#
p_mean: float = -0.4#
p_std: float = 1.2#
sampling_method: str = 'euler'#
num_sampling_steps: int = 3#
reward_channels: List[int]#
reward_res_blocks: int = 2#
reward_cond_dim: int = 128#
reward_lstm_dim: int = 512#
burn_in_length: int = 4#
actor_channels: List[int]#
actor_res_blocks: int = 1#
actor_lstm_dim: int = 512#
num_epochs: int = 1000#
training_steps_per_epoch: int = 400#
batch_size: int = 32#
environment_steps_per_epoch: int = 100#
epsilon_greedy: float = 0.01#
data_loader_num_workers: int = 4#
pin_memory: bool = True#
persistent_workers: bool = True#
imagination_horizon: int = 15#
discount_factor: float = 0.985#
entropy_weight: float = 0.001#
lambda_returns: float = 0.95#
learning_rate: float = 0.0001#
adam_epsilon: float = 1e-08#
weight_decay_diffusion: float = 0.01#
weight_decay_reward: float = 0.01#
weight_decay_actor: float = 0.0#
use_amp: bool = True#
device: str#
log_interval: int = 10#
eval_interval: int = 50#
save_interval: int = 100#
operator_state_dim: int = 32#
operator_action_dim: int = 4#
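Several DiamondConfig fields (sigma_data, sigma_min, sigma_max, rho, p_mean, p_std, num_sampling_steps) share names with the EDM diffusion formulation of Karras et al. (2022). Assuming they play the same roles here, the sketch below derives sampling noise levels, draws a training noise level, and computes the EDM preconditioning scalings. The helper names are hypothetical, and this is not DIAMOND's own code.

```python
import math
import random


def karras_sigmas(sigma_min: float, sigma_max: float, rho: float, num_steps: int) -> list[float]:
    """Noise levels from sigma_max down to sigma_min, followed by a final 0.0."""
    max_inv = sigma_max ** (1.0 / rho)
    min_inv = sigma_min ** (1.0 / rho)
    sigmas = []
    for i in range(num_steps):
        frac = i / (num_steps - 1) if num_steps > 1 else 0.0
        sigmas.append((max_inv + frac * (min_inv - max_inv)) ** rho)
    return sigmas + [0.0]


def sample_training_sigma(p_mean: float, p_std: float, rng: random.Random) -> float:
    """Log-normal training noise level: ln(sigma) ~ N(p_mean, p_std**2)."""
    return math.exp(rng.gauss(p_mean, p_std))


def edm_preconditioning(sigma: float, sigma_data: float) -> tuple[float, float, float]:
    """EDM scalings (c_skip, c_out, c_in) applied around the denoiser network."""
    denom = sigma**2 + sigma_data**2
    c_skip = sigma_data**2 / denom
    c_out = sigma * sigma_data / math.sqrt(denom)
    c_in = 1.0 / math.sqrt(denom)
    return c_skip, c_out, c_in


sigmas = karras_sigmas(0.002, 80.0, 7, 3)  # sigma_min, sigma_max, rho, num_sampling_steps
train_sigma = sample_training_sigma(-0.4, 1.2, random.Random(0))  # p_mean, p_std
print([round(s, 4) for s in sigmas])
```

A larger rho concentrates the few sampling steps near sigma_min, where fine detail is resolved.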

Training entry points#

Complete World Model training pipeline for any Gym environment.

This script trains a complete World Model pipeline consisting of:

  1. ConvVAE - Encodes observations into latent space

  2. MDNRNN - Predicts future latent states given actions

  3. Controller - Linear controller trained with CMA-ES

Usage:

python train_world_model.py --env CarRacing-v2 --data_dir ./data --logdir ./results
python train_world_model.py --env BipedalWalker-v3 --action_size 4  # set --action_size explicitly if env loading fails

The script will:

  1. Generate rollout data (if not already present)

  2. Train the VAE

  3. Train the MDNRNN

  4. Train the controller

world_models.training.train_world_model.generate_rollouts(data_dir, env_name, num_rollouts=1000, seq_len=1000, num_workers=8)[source]#

Generate random rollouts from the specified environment.

Parameters:
  • data_dir (str) – Directory to save rollout files

  • env_name (str) – Name of the Gym environment

  • num_rollouts (int) – Total number of rollouts to generate

  • seq_len (int) – Maximum length per rollout

  • num_workers (int) – Number of parallel workers

Return type:

None

world_models.training.train_world_model.run_training_pipeline(args, action_size)[source]#

Execute the complete World Model training pipeline.

Parameters:
  • args (Any)

  • action_size (int)

Return type:

None

world_models.training.train_world_model.test_trained_model(logdir, env_name, action_size, num_episodes=5)[source]#

Test the trained world model with controller in the environment.

Parameters:
  • logdir (str)

  • env_name (str)

  • action_size (int)

  • num_episodes (int)

Return type:

None

world_models.training.train_world_model.main()[source]#
Return type:

None

Training script for Convolutional Variational Autoencoder (ConvVAE).

This module provides functions to train a ConvVAE model on observation data for world model learning.

world_models.training.train_convvae.save_checkpoint(state, is_best, filename, best_filename)[source]#

Save model checkpoint.

Parameters:
  • state (dict) – Dictionary containing model state to save.

  • is_best (bool) – If True, also save as best checkpoint.

  • filename (str) – Path to save checkpoint.

  • best_filename (str) – Path to save best checkpoint.

Return type:

None
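The save-then-copy-if-best behavior described above is commonly implemented as in the sketch below. This is an assumption about the pattern, not this module's code. The sketch uses pickle in place of torch.save so it runs without PyTorch, and save_checkpoint_sketch is a hypothetical name.

```python
import os
import pickle
import shutil
import tempfile


def save_checkpoint_sketch(state: dict, is_best: bool, filename: str, best_filename: str) -> None:
    """Write state to filename and, if it is the best so far, copy it to best_filename."""
    with open(filename, "wb") as f:
        pickle.dump(state, f)  # a PyTorch trainer would typically call torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)


with tempfile.TemporaryDirectory() as logdir:
    current = os.path.join(logdir, "checkpoint.pkl")
    best = os.path.join(logdir, "best.pkl")
    save_checkpoint_sketch({"epoch": 3, "test_loss": 0.12}, True, current, best)
    with open(best, "rb") as f:
        restored = pickle.load(f)
print(restored)  # {'epoch': 3, 'test_loss': 0.12}
```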

world_models.training.train_convvae.test_epoch(model, test_loader, device, loss_fn)[source]#

Run one epoch of validation.

Parameters:
  • model (ConvVAE) – The VAE model to evaluate.

  • test_loader (DataLoader) – DataLoader for test/validation data.

  • device (device) – Device to run evaluation on.

  • loss_fn (Any) – Loss function to use.

Returns:

Average test loss for the epoch.

Return type:

float

world_models.training.train_convvae.train_epoch(epoch, model, optimizer, train_loader, device, train_dataset, loss_fn, use_amp=False, scaler=None)[source]#

Run one epoch of training.

Parameters:
  • epoch (int) – Current epoch number.

  • model (Any) – The VAE model to train.

  • optimizer (Any) – Optimizer for training.

  • train_loader (Any) – DataLoader for training data.

  • device (Any) – Device to run training on.

  • train_dataset (Any) – Training dataset (used to load next buffer if applicable).

  • loss_fn (Any) – Loss function to use.

  • use_amp (bool) – Whether to use automatic mixed precision.

  • scaler (GradScaler | None) – GradScaler for mixed precision training.

Return type:

float

world_models.training.train_convvae.train_convae(config)[source]#

Train a Convolutional VAE model.

This function trains a ConvVAE on observation data using the provided configuration. It handles data loading, model initialization, training loop, checkpointing, and sample generation.

Parameters:

config (WMVAEConfig) – WMVAEConfig object containing all training hyperparameters.

Return type:

None

The training process includes:
  • Loading pretrained VAE if available (unless noreload is True)

  • Training for specified number of epochs

  • Validating after each epoch

  • Learning rate scheduling with ReduceLROnPlateau

  • Early stopping based on validation loss

  • Checkpointing best and current models

  • Generating sample images at specified intervals

Example

>>> config = WMVAEConfig({
...     'height': 64,
...     'width': 64,
...     'latent_size': 32,
...     'num_epochs': 100,
...     'logdir': 'results',
... })
>>> train_convae(config)

Training script for Mixture Density Recurrent Neural Network (MDRNN).

This module provides functions to train an MDRNN model for sequence prediction in world models. The MDRNN predicts future latent states using a Gaussian Mixture Model (GMM) based on current latent states and actions.

world_models.training.train_mdn_rnn.precompute_latents(vae_config, mdrnn_config)[source]#

Pre-compute and save VAE latents to disk for memory-efficient RNN training.

This function encodes all observations using the VAE and saves the latent representations to disk. This allows RNN training without keeping the VAE in GPU memory.

Parameters:
  • vae_config (WMVAEConfig) – WMVAEConfig for loading pretrained VAE.

  • mdrnn_config (WMMDNRNNConfig) – WMMDNRNNConfig containing latent_size and device settings.

Return type:

None

world_models.training.train_mdn_rnn.save_checkpoint(state, is_best, filename, best_filename)[source]#

Save model checkpoint.

Parameters:
  • state (Any) – Dictionary containing model state to save.

  • is_best (bool) – If True, also save as best checkpoint.

  • filename (str) – Path to save checkpoint.

  • best_filename (str) – Path to save best checkpoint.

Return type:

None

world_models.training.train_mdn_rnn.to_latent(vae, obs, next_obs, device, red_size=64)[source]#

Transform observations to latent space using VAE encoder.

This function encodes observations into the latent space using the VAE’s encoder network. It applies the reparameterization trick to sample from the learned latent distribution.

Parameters:
  • vae (ConvVAE) – Trained VAE model with encoder.

  • obs (Tensor) – Batch of current observations.

  • next_obs (Tensor) – Batch of next observations.

  • device (device) – Device to run encoding on.

  • red_size (int) – Target size for resizing images (default: 64).

Returns:

Tuple of (latent_obs, latent_next_obs) tensors in latent space.

Return type:

tuple[Tensor, Tensor]

world_models.training.train_mdn_rnn.get_loss(mdrnn, latent_obs, action, reward, terminal, latent_next_obs, include_reward, latent_size)[source]#

Compute MDRNN loss.

Computes the combined loss for the MDRNN model:

  • GMM loss for next latent state prediction

  • BCE loss for terminal state prediction

  • MSE loss for reward prediction (if include_reward is True)

Parameters:
  • mdrnn (MDRNN) – MDRNN model.

  • latent_obs (Tensor) – Current latent observations.

  • action (Tensor) – Actions taken.

  • reward (Tensor) – Rewards received.

  • terminal (Tensor) – Terminal state flags.

  • latent_next_obs (Tensor) – Next latent observations (target).

  • include_reward (bool) – Whether to include reward prediction in loss.

  • latent_size (int) – Size of latent space.

Returns:

Dictionary containing gmm, bce, mse, and total loss values.

Return type:

dict[str, Tensor]
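The GMM term is the negative log-likelihood of the next latent under a mixture of diagonal Gaussians. The NumPy sketch below assumes means mus and standard deviations sigmas of shape (batch, k, latent) and log mixture weights logpi of shape (batch, k). The final scaling by the number of predicted quantities (latent_size + 2 with reward, latent_size + 1 without) follows the widely used reference World Models reimplementation and is an assumption about this module. Both function names are hypothetical.

```python
import numpy as np


def gmm_nll(target: np.ndarray, mus: np.ndarray, sigmas: np.ndarray, logpi: np.ndarray) -> float:
    """Mean negative log-likelihood of target under a diagonal Gaussian mixture.

    target: (batch, latent); mus, sigmas: (batch, k, latent); logpi: (batch, k).
    """
    x = target[:, None, :]  # broadcast the target over the k components
    # Per-dimension Gaussian log-density, summed over latent dims -> (batch, k)
    log_prob = -0.5 * (((x - mus) / sigmas) ** 2 + 2 * np.log(sigmas) + np.log(2 * np.pi))
    log_prob = log_prob.sum(axis=-1) + logpi
    # Numerically stable log-sum-exp over components -> (batch,)
    m = log_prob.max(axis=1, keepdims=True)
    log_mix = m[:, 0] + np.log(np.exp(log_prob - m).sum(axis=1))
    return float(-log_mix.mean())


def combined_loss(gmm: float, bce: float, mse: float, include_reward: bool, latent_size: int) -> float:
    """Sum the loss terms and scale by the number of predicted quantities (assumed convention)."""
    if include_reward:
        return (gmm + bce + mse) / (latent_size + 2)
    return (gmm + bce) / (latent_size + 1)


latent = 4
target = np.zeros((2, latent))
mus = np.zeros((2, 1, latent))
sigmas = np.ones((2, 1, latent))
logpi = np.zeros((2, 1))  # a single component with weight 1
print(gmm_nll(target, mus, sigmas, logpi))  # equals 4 * 0.5 * log(2 * pi)
```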

world_models.training.train_mdn_rnn.data_pass(epoch, mdrnn, vae, train_loader, test_loader, optimizer, device, include_reward, test_every=10, epochs=1, use_amp=False, scaler=None, latent_size=32, max_seq_len=50, log_wandb=False, prev_val_loss=1000000.0, early_stop=None, lr_scheduler=None, batch_size=50, train=True)[source]#

Run one epoch of training or validation.

Parameters:
  • epoch (int) – Current epoch number.

  • mdrnn (Any) – MDRNN model.

  • vae (Any) – VAE model for encoding observations (None if using precomputed latents).

  • train_loader (Any) – Training data loader.

  • test_loader (Any) – Test/validation data loader.

  • optimizer (Any) – Optimizer (used only for training).

  • device (Any) – Device to run on.

  • include_reward (bool) – Whether to include reward in loss.

  • latent_size (int) – Size of latent space.

  • batch_size (int) – Batch size.

  • train (bool) – If True, run training pass; otherwise run validation.

  • use_amp (bool) – If True, use automatic mixed precision.

  • scaler (Any) – GradScaler for mixed precision training.

  • test_every (int)

  • epochs (int)

  • max_seq_len (int)

  • log_wandb (bool)

  • prev_val_loss (float)

  • early_stop (Any)

  • lr_scheduler (Any)

Returns:

Average loss for the epoch.

Return type:

float

world_models.training.train_mdn_rnn.train_mdn_rnn(vae_config, mdrnn_config, use_precomputed_latents=True, use_amp=True)[source]#

Train an MDRNN model.

This function trains an MDRNN on sequence data using the provided configurations. It loads a pretrained VAE for encoding observations into latent space, then trains the MDRNN to predict future latent states given current latent states and actions.

Parameters:
  • vae_config (WMVAEConfig) – WMVAEConfig for loading pretrained VAE.

  • mdrnn_config (WMMDNRNNConfig) – WMMDNRNNConfig containing MDRNN training hyperparameters.

  • use_precomputed_latents (bool) – If True, use pre-encoded latents from disk.

  • use_amp (bool) – If True, use automatic mixed precision for memory efficiency.

Return type:

None

The training process includes:
  • Loading pretrained VAE from vae_config.logdir

  • Training for specified number of epochs

  • Validating after each epoch

  • Learning rate scheduling with ReduceLROnPlateau

  • Early stopping based on validation loss

  • Checkpointing best and current models

Example

>>> vae_config = WMVAEConfig({
...     'height': 64, 'width': 64, 'latent_size': 32, 'logdir': 'results'
... })
>>> mdrnn_config = WMMDNRNNConfig({
...     'latent_size': 32, 'action_size': 3, 'hidden_size': 256,
...     'gmm_components': 5, 'logdir': 'results'
... })
>>> train_mdn_rnn(vae_config, mdrnn_config)

Training a linear controller on latent + recurrent state with CMA-ES.

This module provides functions to train a linear controller using Covariance Matrix Adaptation Evolution Strategy (CMA-ES). The controller maps latent and hidden states to actions for the learned world model.

Reference:

Ha & Schmidhuber (2018). Recurrent World Models Facilitate Policy Evolution. https://arxiv.org/abs/1805.11111

world_models.training.train_controller.flatten_parameters(parameters)[source]#
Parameters:

parameters (Any)

Return type:

ndarray

world_models.training.train_controller.load_parameters(params, controller)[source]#
Parameters:
  • params (Any)

  • controller (Any)

Return type:

None
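flatten_parameters and load_parameters convert between the controller's parameter tensors and the single flat vector that CMA-ES optimizes. The NumPy sketch below shows that round trip with plain arrays in place of PyTorch parameters. The _sketch function names and the 3x5 controller shape are hypothetical.

```python
import numpy as np


def flatten_parameters_sketch(parameters: list[np.ndarray]) -> np.ndarray:
    """Concatenate every parameter array into one 1-D float vector."""
    return np.concatenate([p.ravel() for p in parameters]).astype(np.float64)


def unflatten_parameters_sketch(flat: np.ndarray, shapes: list[tuple[int, ...]]) -> list[np.ndarray]:
    """Split a flat vector back into arrays with the given shapes, in order."""
    out, start = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        out.append(flat[start:start + size].reshape(shape))
        start += size
    if start != flat.size:
        raise ValueError(f"expected {start} values, got {flat.size}")
    return out


# A linear controller on the concatenated [z, h] input: weight (action_size, in_dim), bias (action_size,)
weight = np.arange(3 * 5, dtype=np.float64).reshape(3, 5)
bias = np.array([0.1, 0.2, 0.3])
flat = flatten_parameters_sketch([weight, bias])
print(flat.shape)  # (18,)
restored = unflatten_parameters_sketch(flat, [weight.shape, bias.shape])
```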

world_models.training.train_controller.slave_routine(p_queue, r_queue, e_queue, p_index, config, time_limit)[source]#

Worker process routine for parallel rollout evaluation.

Parameters:
  • p_queue (Any) – Queue containing (s_id, parameters) to evaluate.

  • r_queue (Any) – Queue in which to place results (s_id, reward).

  • e_queue (Any) – End queue; the process terminates once it is non-empty.

  • p_index (int) – Process index for GPU assignment.

  • config (Any) – Controller configuration (must include env_name and action_size).

  • time_limit (int) – Maximum steps per episode.

Return type:

None

world_models.training.train_controller.evaluate(solutions, results, rollouts, p_queue, r_queue)[source]#

Evaluate current controller.

Parameters:
  • solutions (Any)

  • results (Any)

  • rollouts (int)

  • p_queue (Any)

  • r_queue (Any)

Return type:

Any
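One plausible reading of evaluate is modeled on a widely used PyTorch reimplementation of World Models: it re-evaluates the candidate with the lowest CMA-ES objective over several rollouts through the worker queues and reports the mean and spread. That behavior is an assumption. The sketch substitutes threads for worker processes and a hypothetical fake_rollout for environment rollouts. It mirrors the negated-reward convention because CMA-ES minimizes.

```python
import queue
import statistics
import threading


def fake_rollout(params: list[float]) -> float:
    """Hypothetical stand-in for an environment rollout; returns a cumulative reward."""
    return -sum((p - 1.0) ** 2 for p in params)


def worker(p_queue: queue.Queue, r_queue: queue.Queue, e_queue: queue.Queue) -> None:
    # Keep evaluating until something is placed on the end queue.
    while e_queue.empty():
        try:
            s_id, params = p_queue.get(timeout=0.05)
        except queue.Empty:
            continue
        r_queue.put((s_id, -fake_rollout(params)))  # negated: CMA-ES minimizes


def evaluate_best(solutions, results, rollouts, p_queue, r_queue):
    """Re-evaluate the lowest-objective solution and summarize its estimates."""
    best_index = min(range(len(results)), key=results.__getitem__)
    best_guess = solutions[best_index]
    for s_id in range(rollouts):
        p_queue.put((s_id, best_guess))
    estimates = [r_queue.get()[1] for _ in range(rollouts)]
    return best_guess, statistics.mean(estimates), statistics.pstdev(estimates)


p_q, r_q, e_q = queue.Queue(), queue.Queue(), queue.Queue()
threads = [threading.Thread(target=worker, args=(p_q, r_q, e_q), daemon=True) for _ in range(2)]
for t in threads:
    t.start()
solutions = [[0.0, 0.0], [1.0, 1.0], [2.0, 0.5]]
results = [2.0, 0.0, 1.25]  # CMA-ES objectives (negated rewards) from the last generation
best, mean_obj, std_obj = evaluate_best(solutions, results, 4, p_q, r_q)
e_q.put(True)  # signal the workers to stop
for t in threads:
    t.join()
print(best, mean_obj, std_obj)
```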

world_models.training.train_controller.train_controller(config)[source]#

Train a linear controller using CMA-ES.

Parameters:

config (WMControllerConfig) – WMControllerConfig containing training hyperparameters, including env_name and action_size.

Return type:

None

The training process includes:
  • Setting up parallel evaluation workers (each loads VAE + MDRNN)

  • Running CMA-ES optimization with parallel rollout evaluation

  • Evaluating and saving best controller checkpoint

world_models.training.train_jepa.main(args=None, resume_preempt=False)[source]#

Run JEPA training using a CLI argv, nested dict, or JEPAConfig instance.

This entrypoint initializes distributed context, data pipeline, masking, models, optimizers/schedulers, checkpointing, and the full epoch loop.

Parameters:
  • args (Any)

  • resume_preempt (bool)

Return type:

Any

world_models.training.train_jepa.sweep_train()[source]#

Training function invoked by a WandB sweep agent.

Return type:

None

world_models.training.train_jepa.main_from_cli(argv=None)[source]#

Compose JEPA config from YAML/dot-list overrides and launch training.

Parameters:

argv (list[str] | None)

Return type:

Any

class world_models.training.train_iris.IRISTrainer(game='ALE/Pong-v5', device='cuda', seed=42, config=None)[source]#

Bases: object

Training loop for IRIS on Atari 100k benchmark.

Parameters:
  • game (str)

  • device (str)

  • seed (int)

  • config (IRISConfig | None)

preprocess_frame(frame)[source]#

Preprocess frame: resize to 64x64 and normalize.

Parameters:

frame (ndarray)

Return type:

ndarray

collect_experience(num_steps, epsilon=0.01)[source]#

Collect experience from environment.

Parameters:
  • num_steps (int) – Number of steps to collect

  • epsilon (float) – Random action probability

Returns:

Mean episode return

Return type:

float

train_epoch(epoch)[source]#

Train for one epoch.

Parameters:

epoch (int) – Current epoch number

Returns:

Dictionary of metrics

Return type:

dict

get_epsilon(epoch)[source]#

Get exploration epsilon with decay.

Parameters:

epoch (int)

Return type:

float

evaluate(num_episodes=100, render=False)[source]#

Evaluate agent performance.

Parameters:
  • num_episodes (int) – Number of evaluation episodes

  • render (bool) – If True, also return video frames and per-step latent vectors

Returns:

If render is False (default): dict with evaluation metrics. If render is True: tuple (episode_returns_array, videos_list, latents_array).

Return type:

dict | tuple

train(total_epochs=None, eval_interval=50, save_dir='checkpoints/iris')[source]#

Full training loop.

Parameters:
  • total_epochs (int | None) – Total training epochs

  • eval_interval (int) – Evaluate every N epochs

  • save_dir (str) – Directory to save checkpoints

Return type:

None

world_models.training.train_iris.main(argv=None)[source]#

Run IRIS training with YAML config files and Hydra dot-list overrides.

Parameters:

argv (list[str] | None)

Return type:

IRISConfig

class world_models.training.train_genie.GenieConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=512, tokenizer_decoder_dim=1024, tokenizer_encoder_depth=12, tokenizer_decoder_depth=20, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=1024, action_encoder_depth=20, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#

Bases: SerializableConfigMixin

Configuration for Genie training.

Parameters:
  • num_frames (int)

  • image_size (int)

  • in_channels (int)

  • tokenizer_vocab_size (int)

  • tokenizer_embedding_dim (int)

  • tokenizer_encoder_dim (int)

  • tokenizer_decoder_dim (int)

  • tokenizer_encoder_depth (int)

  • tokenizer_decoder_depth (int)

  • action_vocab_size (int)

  • action_embedding_dim (int)

  • action_encoder_dim (int)

  • action_encoder_depth (int)

  • action_pooling (Literal['mean', 'windowed_attention'])

  • window_attention_heads (int)

  • dynamics_dim (int)

  • dynamics_depth (int)

  • dynamics_num_heads (int)

  • batch_size (int)

  • learning_rate (float)

  • weight_decay (float)

  • warmup_steps (int)

  • max_steps (int)

  • mask_prob_min (float)

  • mask_prob_max (float)

  • sample_temperature (float)

  • maskgit_steps (int)

num_frames: int = 16#
image_size: int = 64#
in_channels: int = 3#
tokenizer_vocab_size: int = 1024#
tokenizer_embedding_dim: int = 32#
tokenizer_encoder_dim: int = 512#
tokenizer_decoder_dim: int = 1024#
tokenizer_encoder_depth: int = 12#
tokenizer_decoder_depth: int = 20#
action_vocab_size: int = 8#
action_embedding_dim: int = 32#
action_encoder_dim: int = 1024#
action_encoder_depth: int = 20#
action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
window_attention_heads: int = 1#
dynamics_dim: int = 512#
dynamics_depth: int = 8#
dynamics_num_heads: int = 8#
batch_size: int = 4#
learning_rate: float = 3e-05#
weight_decay: float = 0.0001#
warmup_steps: int = 5000#
max_steps: int = 125000#
mask_prob_min: float = 0.5#
mask_prob_max: float = 1.0#
sample_temperature: float = 2.0#
maskgit_steps: int = 25#
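In Genie's MaskGIT-style dynamics training, a masking rate is drawn per example between a lower and upper bound, and tokens are replaced by a mask token with that probability. Assuming mask_prob_min and mask_prob_max are used this way here, a NumPy sketch follows. The function name, the MASK_ID choice, and the 16x16 token grid are hypothetical.

```python
import numpy as np


def mask_tokens(tokens: np.ndarray, p_min: float, p_max: float, mask_id: int,
                rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray]:
    """Bernoulli-mask each token with a per-example rate drawn from U(p_min, p_max).

    tokens: (batch, time, h, w) integer token ids.
    Returns (masked_tokens, mask), where mask is True at masked positions.
    """
    batch = tokens.shape[0]
    # One rate per example, broadcast over the remaining axes.
    rates = rng.uniform(p_min, p_max, size=(batch,) + (1,) * (tokens.ndim - 1))
    mask = rng.random(tokens.shape) < rates
    return np.where(mask, mask_id, tokens), mask


rng = np.random.default_rng(0)
vocab_size = 1024                 # tokenizer_vocab_size
MASK_ID = vocab_size              # hypothetical: an extra id reserved for [MASK]
tokens = rng.integers(0, vocab_size, size=(4, 16, 16, 16))  # (B, T, h, w); sizes illustrative
masked, mask = mask_tokens(tokens, 0.5, 1.0, MASK_ID, rng)  # mask_prob_min, mask_prob_max
print(mask.mean())
```

The prediction loss is then typically computed only at the masked positions.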
class world_models.training.train_genie.VideoDataset(video_paths, num_frames=16, image_size=64)[source]#

Bases: Dataset

Dataset for video data.

Parameters:
  • video_paths (list)

  • num_frames (int)

  • image_size (int)

class world_models.training.train_genie.GenieTrainer(model, config, device=None)[source]#

Bases: object

Trainer for Genie model.

Parameters:
  • model (Module)

  • config (GenieConfig)

  • device (device | None)

train_step(batch)[source]#

Single training step.

Parameters:

batch (Tensor) – (B, C, T, H, W) video batch

Returns:

Dictionary of losses

Return type:

Dict[str, Tensor | float | None]

validate(val_batch)[source]#

Validation step.

Parameters:

val_batch (Tensor) – (B, C, T, H, W) validation video batch

Returns:

Dictionary of validation metrics

Return type:

Dict[str, Tensor]

train(train_dataloader, val_dataloader=None, num_steps=None, log_interval=100, val_interval=1000)[source]#

Full training loop.

Parameters:
  • train_dataloader (DataLoader) – Training data loader

  • val_dataloader (DataLoader | None) – Validation data loader (optional)

  • num_steps (int | None) – Number of training steps (uses config.max_steps if None)

  • log_interval (int) – Logging frequency

  • val_interval (int) – Validation frequency

Return type:

None

save_checkpoint(path)[source]#

Save model checkpoint.

Parameters:

path (str)

Return type:

None

load_checkpoint(path)[source]#

Load model checkpoint.

Parameters:

path (str)

Return type:

None

world_models.training.train_genie.create_genie_trainer(config=None, device=None)[source]#

Factory function to create Genie trainer and model.

Parameters:
  • config

  • device

Return type:

Tuple[GenieTrainer, Module]

world_models.training.train_genie.main(argv=None)[source]#

Console entrypoint for Genie trainer setup.

The generic VideoDataset in this module is intentionally abstract, so this command provides a discoverable entrypoint for inspecting defaults and constructing a trainer. Use a concrete dataset script, such as scripts/train_genie_tinyworlds.py, for end-to-end data loading.

Parameters:

argv (list[str] | None)

Return type:

None

world_models.training.train_planet.train(memory, rssm, optimizer, device, N=32, H=50, beta=1.0, grads=False)[source]#

Training implementation following Learning Latent Dynamics for Planning from Pixels (arXiv:1811.04551), method (a.) The Standard Variational Bound Method, using only single-step predictions.

Parameters:
  • memory (Any)

  • rssm (Any)

  • optimizer (Any)

  • device (device)

  • N (int)

  • H (int)

  • beta (float)

  • grads (bool)

Return type:

dict
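The standard variational bound combines reconstruction and reward log-likelihoods with a KL term between the posterior and prior state distributions at each step. For diagonal Gaussians the KL has a closed form. The NumPy sketch below computes it, weights it by beta, and applies PlaNet-style free-nats clipping. The free_nats default of 3.0 and the helper names are assumptions, since train() exposes only beta.

```python
import numpy as np


def diag_gaussian_kl(mu_q, std_q, mu_p, std_p):
    """KL(q || p) for diagonal Gaussians, summed over the last (state) dimension."""
    var_q, var_p = std_q ** 2, std_p ** 2
    kl = np.log(std_p / std_q) + (var_q + (mu_q - mu_p) ** 2) / (2.0 * var_p) - 0.5
    return kl.sum(axis=-1)


def kl_term(mu_q, std_q, mu_p, std_p, beta=1.0, free_nats=3.0):
    """beta-weighted KL, clipped from below at free_nats before averaging."""
    kl = diag_gaussian_kl(mu_q, std_q, mu_p, std_p)
    return beta * float(np.maximum(kl, free_nats).mean())


mu = np.zeros((2, 30))
std = np.ones((2, 30))
print(diag_gaussian_kl(mu, std, mu, std))  # [0. 0.]
print(kl_term(mu, std, mu + 1.0, std))     # 15.0: each of the 30 dims contributes 0.5
```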

world_models.training.train_planet.main()[source]#

Example PlaNet/RSSM training script with rollout collection and evaluation.

Builds environment/model/policy objects, iteratively trains on replayed episodes, and periodically saves videos and checkpoints.

Return type:

None

world_models.training.train_rssm.train_rssm(memory, model, optimizer, record_grads=True)[source]#

Train an RSSM on replayed trajectories for one optimization phase.

Samples batches from memory, computes reconstruction and KL objectives across rollout steps, and returns aggregated loss metrics.

Parameters:
  • memory (Any)

  • model (Any)

  • optimizer (Any)

  • record_grads (bool)

Return type:

dict

world_models.training.train_rssm.evaluate(memory, model, path, eps)[source]#

Run one RSSM reconstruction/prediction evaluation and save visual outputs.

Decodes priors/posteriors for a sampled sequence and writes frame grids for qualitative inspection.

Parameters:
  • memory (Any)

  • model (Any)

  • path (str)

  • eps (Any)

Return type:

None

world_models.training.train_rssm.main()[source]#

Standalone training loop for RSSM with generated replay fallback support.

Initializes environment/policy/memory, trains over episodes, logs metrics, and periodically evaluates and checkpoints the model.

Return type:

None

class world_models.training.train_diamond.DiamondAgent(config)[source]#

Bases: object

DIAMOND: DIffusion As a Model Of eNvironment Dreams

RL agent trained entirely within a diffusion world model.

Parameters:

config (DiamondConfig)

classmethod from_config(config=None, **overrides)[source]#

Build a DIAMOND agent from a config object, dict, YAML file, or YAML string.

Parameters:
  • config (DiamondConfig | dict | str | Path | None)

  • overrides (Any)

Return type:

DiamondAgent

classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, **overrides)[source]#

Load a DIAMOND checkpoint from a local path/directory or HF Hub.

Parameters:
  • pretrained_model_name_or_path (str | Path)

  • config (DiamondConfig | dict | str | Path | None)

  • checkpoint_filename (str | None)

  • config_filename (str)

  • repo_type (str | None)

  • revision (str | None)

  • overrides (Any)

Return type:

DiamondAgent

parameter_count(trainable_only=False)[source]#
Parameters:

trainable_only (bool)

Return type:

int

summary()[source]#
Return type:

dict

train()[source]#

Main training loop, following Algorithm 1 of the DIAMOND paper.

Return type:

None
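The RL stage trains the actor-critic on imagined trajectories of imagination_horizon steps using lambda-returns built from discount_factor and lambda_returns. Assuming the standard recursion Λ_t = r_t + γ(1 − d_t)[(1 − λ)V(s_{t+1}) + λΛ_{t+1}] with Λ_H = V(s_H), a NumPy sketch looks like the following. The function name is hypothetical.

```python
import numpy as np


def lambda_returns(rewards, values, dones, gamma=0.985, lam=0.95):
    """Lambda-returns over an imagined trajectory.

    rewards, dones: length H (one per imagined step); values: length H + 1,
    where values[H] bootstraps the value of the final imagined state.
    """
    H = len(rewards)
    returns = np.zeros(H, dtype=np.float64)
    next_return = values[H]
    for t in reversed(range(H)):
        not_done = 1.0 - dones[t]
        next_return = rewards[t] + gamma * not_done * (
            (1.0 - lam) * values[t + 1] + lam * next_return
        )
        returns[t] = next_return
    return returns


H = 15  # imagination_horizon
rets = lambda_returns(np.ones(H), np.zeros(H + 1), np.zeros(H))  # discount_factor, lambda_returns defaults
print(rets[:3])
```

lam = 1 gives discounted Monte Carlo returns bootstrapped at the horizon, and lam = 0 gives one-step TD targets.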

evaluate(num_episodes=1)[source]#

Evaluate the agent.

Parameters:

num_episodes (int)

Return type:

float

save_checkpoint(path=None)[source]#

Save model checkpoint.

Parameters:

path (str | PathLike | None) – Optional destination for the checkpoint. If path is None, the legacy default checkpoints/diamond/checkpoint.pt is used. If path is a bare filename, the file is written to checkpoints/diamond/<filename>. If path contains a directory component (absolute or relative), it is used as given.

Return type:

None
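The path rules above can be expressed as a small pure function. The sketch below mirrors the documented behavior with os.path and pathlib. resolve_checkpoint_path is a hypothetical name, not part of DiamondAgent.

```python
import os
from pathlib import Path
from typing import Optional, Union

DEFAULT_DIR = Path("checkpoints") / "diamond"


def resolve_checkpoint_path(path: Optional[Union[str, os.PathLike]] = None) -> Path:
    """Map a save_checkpoint/load_checkpoint path argument to a concrete file path."""
    if path is None:
        return DEFAULT_DIR / "checkpoint.pt"  # legacy default
    text = os.fspath(path)
    if os.path.dirname(text):  # e.g. "runs/a.pt", "./a.pt", "/abs/a.pt"
        return Path(text)      # has a directory component: use as given
    return DEFAULT_DIR / text  # bare filename such as "epoch_10.pt"


for arg in (None, "epoch_10.pt", "runs/exp1/ckpt.pt"):
    print(arg, "->", resolve_checkpoint_path(arg).as_posix())
```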

load_checkpoint(path=None)[source]#

Load model checkpoint.

Parameters:

path (str | None) – Optional path to the checkpoint. If None, the default checkpoints/diamond/checkpoint.pt is loaded. If a bare filename is provided, checkpoints/diamond/<filename> is tried; a path with directory components is used directly.

Return type:

None

world_models.training.train_diamond.train_diamond(game=None, seed=None, preset=None, device=None, config=None)[source]#

Train DIAMOND on a specific game or a composed experiment config.

Parameters:
  • game (str | None)

  • seed (int | None)

  • preset (str | None)

  • device (str | None)

  • config (DiamondConfig | None)

Return type:

None

world_models.training.train_diamond.main(argv=None)[source]#

Compose DIAMOND config from YAML/dot-list overrides and launch training.

Parameters:

argv (list[str] | None)

Return type:

Any

class world_models.training.rl_harness.ActorCritic(obs_shape, action_dim, hidden_dim=256)[source]#

Bases: Module

Simple actor-critic network for RL harness.

Parameters:
  • obs_shape (tuple)

  • action_dim (int)

  • hidden_dim (int)

forward(obs)[source]#

Forward pass through CNN, then actor and critic heads.

Parameters:

obs (Tensor)

Return type:

tuple[Tensor, Tensor]

get_action(obs)[source]#

Sample action from policy.

Parameters:

obs (Tensor)

Return type:

tuple[Tensor, Tensor, Tensor]

class world_models.training.rl_harness.PPOTrainer(vec_env, device='cpu', lr=0.0003, gamma=0.99, gae_lambda=0.95, clip_ratio=0.2, num_epochs=10, batch_size=64, max_grad_norm=0.5, entropy_coeff=0.01, value_coeff=0.5)[source]#

Bases: object

Simple PPO trainer for testing vectorized environments.

Parameters:
  • vec_env (TorchVectorizedEnv)

  • device (str)

  • lr (float)

  • gamma (float)

  • gae_lambda (float)

  • clip_ratio (float)

  • num_epochs (int)

  • batch_size (int)

  • max_grad_norm (float)

  • entropy_coeff (float)

  • value_coeff (float)

collect_trajectories(num_steps)[source]#

Collect trajectories using the vectorized environment.

Parameters:

num_steps (int)

Return type:

Dict[str, Tensor]

compute_gae(rewards, values, dones)[source]#

Compute Generalized Advantage Estimation.

Parameters:
  • rewards (Tensor)

  • values (Tensor)

  • dones (Tensor)

Return type:

Tensor
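Generalized Advantage Estimation accumulates discounted TD errors backward in time and resets at episode ends. The NumPy sketch below assumes dones[t] = 1 marks an episode that ended after step t and takes an explicit bootstrap value for the state after the last step. Both conventions are assumptions, since compute_gae's exact inputs are not documented here. It works for arrays of shape (T,) or (T, num_envs).

```python
import numpy as np


def compute_gae_sketch(rewards, values, dones, gamma=0.99, gae_lambda=0.95, last_value=0.0):
    """Advantages A_t = sum_k (gamma * lambda)^k * delta_{t+k}, cut at episode ends."""
    T = len(rewards)
    advantages = np.zeros_like(rewards, dtype=np.float64)
    gae = 0.0
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * not_done - values[t]  # one-step TD error
        gae = delta + gamma * gae_lambda * not_done * gae
        advantages[t] = gae
    return advantages


adv = compute_gae_sketch(np.ones(3), np.zeros(3), np.array([0.0, 0.0, 1.0]), gamma=1.0, gae_lambda=1.0)
print(adv)  # [3. 2. 1.]
```

The critic targets are usually advantages + values.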

train_step(trajectories)[source]#

Perform one training step using PPO.

Parameters:

trajectories (Dict[str, Tensor])

Return type:

None
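A PPO step minimizes the clipped surrogate objective plus weighted value and entropy terms. The NumPy sketch below uses PPOTrainer's hyperparameter names (clip_ratio, value_coeff, entropy_coeff) but is an illustration, not the trainer's code. Details such as advantage normalization, value clipping, and gradient clipping with max_grad_norm are not shown.

```python
import numpy as np


def ppo_loss(new_log_probs, old_log_probs, advantages, values, returns, entropy,
             clip_ratio=0.2, value_coeff=0.5, entropy_coeff=0.01):
    """Clipped-surrogate PPO objective (to be minimized) for one minibatch."""
    ratio = np.exp(new_log_probs - old_log_probs)  # pi_new(a|s) / pi_old(a|s)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * advantages
    policy_loss = -np.mean(np.minimum(unclipped, clipped))  # pessimistic bound
    value_loss = np.mean((returns - values) ** 2)
    return float(policy_loss + value_coeff * value_loss - entropy_coeff * np.mean(entropy))


# A doubled action probability with positive advantage gains at most 1 + clip_ratio.
print(ppo_loss(np.log([2.0]), np.zeros(1), np.ones(1), np.zeros(1), np.zeros(1), np.zeros(1)))
```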

train(total_timesteps, log_interval=1000)[source]#

Main training loop.

Parameters:
  • total_timesteps (int)

  • log_interval (int)

Return type:

None

world_models.training.rl_harness.create_rl_harness_example()[source]#

Example that creates and runs the RL harness. Usage: adapt it to use your own environment factory.

Return type:

None

Memory, controllers, and inference operators#

class world_models.memory.dreamer_memory.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]#

Bases: object

Fixed-size replay buffer for Dreamer with image observations and transitions.

Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world-model training.

Key Features

  • Ring buffer with fixed capacity (FIFO eviction when full)

  • Stores raw uint8 images to save memory

  • Samples sequences (not single transitions) for temporal modeling

  • Validates sampled sequences don’t span episode boundaries

Memory Layout

  • observations: (capacity, C, H, W) uint8 images

  • actions: (capacity, action_dim) float32

  • rewards: (capacity,) float32

  • terminals: (capacity,) float32 (1.0 = terminal, 0.0 = continue)

Sampling Process

  1. Random start index (avoiding episode boundaries)

  2. Collect sequence of length seq_len with wraparound

  3. Validate no terminal in middle of sequence

  4. Return batch of sequences

Usage with Dreamer:

buffer = ReplayBuffer(
    size=100000,           # Max transitions to store
    obs_shape=(3, 64, 64), # RGB images
    action_size=6,         # Continuous action dim
    seq_len=50,            # Sequence length for training
    batch_size=50          # Parallel sequences per batch
)

# Add transitions during interaction
buffer.add(obs, action, reward, done)

# Sample batch for world model training
obs_batch, action_batch, reward_batch, term_batch = buffer.sample()

Memory Efficiency

  • Uses uint8 for images (1 byte per pixel vs 4 for float32)

  • Sequences share observations (overlapping windows)

  • Configurable capacity based on available system memory

Note

add() expects each observation as an {"image": …} dict, but only the image array is stored, and sample() returns plain image arrays for training efficiency.

Parameters:
  • size (int)

  • obs_shape (Tuple[int, ...])

  • action_size (int)

  • seq_len (int)

  • batch_size (int)

add(obs, ac, rew, done)[source]#

Add a transition to the buffer.

Parameters:
  • obs (dict) – Observation dict with ‘image’ key containing the observation

  • ac (ndarray) – Action taken, shape (action_size,)

  • rew (float) – Reward received, scalar

  • done (float) – Terminal flag, 1.0 if episode ended, 0.0 otherwise

Return type:

None

sample()[source]#

Sample a batch of sequences for training.

Returns:

(observations, actions, rewards, terminals)
  • observations: (seq_len, batch, C, H, W)

  • actions: (seq_len, batch, action_dim)

  • rewards: (seq_len, batch)

  • terminals: (seq_len, batch)

Return type:

tuple
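The sampling process described above can be illustrated with a small NumPy ring buffer. This is a hypothetical, self-contained sketch (the class SequenceRingBuffer and its rejection loop are assumptions, not the library's implementation): it picks a random start, gathers seq_len indices with wraparound, rejects windows that run across the write pointer or contain a terminal before the last step, and returns time-major arrays shaped like sample()'s output.

```python
import numpy as np


class SequenceRingBuffer:
    """Hypothetical minimal ring buffer illustrating sequence sampling."""

    def __init__(self, capacity, obs_shape, action_size, seq_len, batch_size, seed=0):
        self.capacity, self.seq_len, self.batch_size = capacity, seq_len, batch_size
        self.obs = np.zeros((capacity, *obs_shape), dtype=np.uint8)  # raw uint8 images
        self.actions = np.zeros((capacity, action_size), dtype=np.float32)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.terminals = np.zeros(capacity, dtype=np.float32)
        self.idx = 0        # next write position (also the oldest entry once full)
        self.full = False   # becomes True after the first wraparound
        self.rng = np.random.default_rng(seed)

    def add(self, image, action, reward, done):
        self.obs[self.idx] = image
        self.actions[self.idx] = action
        self.rewards[self.idx] = reward
        self.terminals[self.idx] = float(done)
        self.idx = (self.idx + 1) % self.capacity  # FIFO overwrite when full
        self.full = self.full or self.idx == 0

    def _valid_window(self):
        # Assumes at least seq_len stored transitions and episodes of length >= seq_len;
        # otherwise this rejection loop may never terminate.
        high = self.capacity if self.full else self.idx - self.seq_len + 1
        while True:
            start = int(self.rng.integers(0, high))
            idxs = (start + np.arange(self.seq_len)) % self.capacity  # wraparound
            # A window must not run from the newest entry into the oldest one.
            crosses_seam = self.full and self.idx in idxs[1:]
            # No terminal before the final step, so the window stays in one episode.
            if not crosses_seam and not self.terminals[idxs[:-1]].any():
                return idxs

    def sample(self):
        # Stack windows along axis 1 to get time-major indices of shape (seq_len, batch).
        idxs = np.stack([self._valid_window() for _ in range(self.batch_size)], axis=1)
        return self.obs[idxs], self.actions[idxs], self.rewards[idxs], self.terminals[idxs]
```

Storing terminals alongside the data is what lets the sampler reject cross-episode windows cheaply instead of tracking episode boundaries separately.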

class world_models.memory.dreamer_memory.Memory(capacity=10000)[source]#

Bases: object

Simple deque-based memory for storing transitions.

Used by PlaNet for online planning. Stores recent transitions and provides random sampling for policy updates.

Parameters:

capacity (int) – Maximum number of transitions to store

Usage:

memory = Memory(capacity=10000)
memory.append(obs, action, reward, done, info)
batch = memory.sample(batch_size=32)  # list of 32 stored transitions

append(*args)[source]#

Append a transition to memory.

Parameters:

*args (Any) – Variable length argument list containing transition data. Typically (observation, action, reward, done, info).

Return type:

None

sample(batch_size)[source]#

Sample random batch of transitions from memory.

Parameters:

batch_size (int) – Number of transitions to sample.

Returns:

List of sampled transitions.

Return type:

list
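The deque-based pattern can be sketched with the standard library alone. This assumes Memory behaves like a bounded collections.deque with uniform sampling; the class DequeMemory is hypothetical and only illustrates the FIFO eviction and sampling behaviour described above.

```python
import random
from collections import deque


class DequeMemory:
    """Hypothetical stand-in mirroring Memory(capacity): bounded deque + random sampling."""

    def __init__(self, capacity=10000, seed=0):
        self.buffer = deque(maxlen=capacity)  # oldest transition is dropped when full
        self.rng = random.Random(seed)

    def append(self, *args):
        self.buffer.append(tuple(args))  # e.g. (obs, action, reward, done, info)

    def sample(self, batch_size):
        # Uniform sampling without replacement; raises ValueError if batch_size > len(self)
        return self.rng.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)
```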

class world_models.memory.dreamer_memory.Episode(observation, action=None, reward=None, terminal=None, info=None)[source]#

Bases: object

Stores a single episode for PlaNet’s imagination and planning.

An episode is a sequence of (observation, action, reward) tuples collected during environment interaction. Episodes are used for computing returns and training value functions.

Parameters:
  • observation (Any) – Initial observation

  • action (Any) – First action (optional)

  • reward (Any) – Initial reward (optional)

  • terminal (Any) – Terminal flag (optional)

  • info (Any) – Additional info dict (optional)

Usage:

episode = Episode(obs, info=info)
episode.append(action, obs, reward, done, info)
episodes = [episode]  # one Episode object per collected trajectory

# Use with a PlaNet agent for planning
imag_state, imag_reward, imag_action = planet.imagine(episodes)

append(action, observation, reward, terminal, info=None)[source]#
Parameters:
  • action (Any)

  • observation (Any)

  • reward (Any)

  • terminal (Any)

  • info (Any)

Return type:

None

class world_models.memory.planet_memory.Episode(postprocess_fn=None)[source]#

Bases: object

Records the agent’s interaction with the environment for a single episode.

Stores observations, actions, rewards, and terminal flags during a single trajectory. At termination, converts all lists to numpy arrays for efficient batch processing.

x#

Observations collected during the episode.

Type:

list or np.ndarray

u#

Actions taken.

Type:

list or np.ndarray

r#

Rewards received.

Type:

list or np.ndarray

t#

Terminal flags (0.0 = continue, 1.0 = terminal).

Type:

list or np.ndarray

info#

Additional episode metadata.

Type:

dict

Parameters:

postprocess_fn (callable, optional) – Function to apply to observations before storing (e.g., normalization). Default: identity function.

Example:

episode = Episode()
episode.append(obs, action, reward, False)
episode.append(obs, action, reward, True)
episode.terminate(final_obs)
print(episode.x.shape)  # Now a numpy array

property size: int#
append(obs, act, reward, terminal)[source]#
Parameters:
  • obs (Any)

  • act (Any)

  • reward (Any)

  • terminal (Any)

Return type:

None

terminate(obs)[source]#
Parameters:

obs (Any)

Return type:

None
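The record-then-convert lifecycle can be sketched as follows. EpisodeSketch is hypothetical; it assumes that terminate(obs) appends the final observation (so x holds one more entry than u, r, and t, matching the tracelen+1 observations returned by Memory.sample) and that size counts recorded transitions.

```python
import numpy as np


class EpisodeSketch:
    """Hypothetical stand-in for planet_memory.Episode: lists while recording, arrays after terminate()."""

    def __init__(self, postprocess_fn=None):
        self.postprocess_fn = postprocess_fn or (lambda obs: obs)  # identity by default
        self.x, self.u, self.r, self.t = [], [], [], []

    @property
    def size(self):
        return len(self.u)  # assumed: number of transitions recorded

    def append(self, obs, act, reward, terminal):
        self.x.append(self.postprocess_fn(obs))
        self.u.append(act)
        self.r.append(reward)
        self.t.append(float(terminal))

    def terminate(self, obs):
        # Store the final observation, then convert every list to a NumPy array
        self.x.append(self.postprocess_fn(obs))
        self.x, self.u, self.r, self.t = map(np.asarray, (self.x, self.u, self.r, self.t))
```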

class world_models.memory.planet_memory.Memory(size=None)[source]#

Bases: deque

Episode-based replay memory for PlaNet/RSSM training.

Stores episodes as variable-length trajectories and supports sampling sub-sequences for training. Implements a ring-buffer style eviction when capacity is reached.

Key Features

  • Stores complete episodes as lists of transitions

  • Samples contiguous sub-sequences for sequence models

  • Supports time-major formatting (time-first) for RNN input

  • Memory usage estimation to prevent OOM errors

Parameters:

size (int, optional) – Maximum number of episodes to store. If None, deque grows without limit (useful for unpickling).

episodes#

Collection of Episode objects.

Type:

deque

eps_lengths#

Length of each episode.

Type:

deque

size#

Total number of transitions across all episodes.

Type:

int (read-only property)

Example:

memory = Memory(size=100)
memory.append([episode1, episode2])
obs, actions, rewards, terminals, lengths = memory.sample(batch_size=32, tracelen=50)

property size: int#
append(episodes)[source]#
Parameters:

episodes (list[Episode])

Return type:

None

sample(batch_size, tracelen=1, time_first=False)[source]#

Sample random sub-sequences from stored episodes.

Randomly selects episodes and starting positions to create batches of contiguous sequences for training sequence models.

Parameters:
  • batch_size (int) – Number of sequences to sample.

  • tracelen (int) – Length of each sequence (default: 1).

  • time_first (bool) – If True, returns tensors with time dimension first (T, B, …) instead of batch first (B, T, …).

Returns:

(observations, actions, rewards, terminals, lengths)
  • observations: (batch, tracelen+1, *obs_shape) or (tracelen+1, batch, …)

  • actions: (batch, tracelen, action_dim) or (tracelen, batch, …)

  • rewards: (batch, tracelen) or (tracelen, batch)

  • terminals: (batch, tracelen) or (tracelen, batch)

  • lengths: (batch,) original episode lengths for each sample

Return type:

tuple

Raises:
  • ValueError – If memory is empty or no episodes meet minimum length.

  • MemoryError – If estimated memory usage exceeds 200 MiB threshold.
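The sub-sequence sampling and the time_first layout can be sketched in NumPy. This is a hypothetical illustration: sample_subsequences is not part of the library, episodes are assumed to be dicts of arrays with x of length T+1 and u, r, t of length T, and the memory-usage check is omitted.

```python
import numpy as np


def sample_subsequences(episodes, batch_size, tracelen=1, time_first=False, seed=0):
    """Hypothetical sketch: each episode is a dict with arrays
    'x' (T+1, *obs_shape), 'u' (T, action_dim), 'r' (T,), 't' (T,)."""
    rng = np.random.default_rng(seed)
    eligible = [ep for ep in episodes if len(ep["u"]) >= tracelen]
    if not eligible:
        raise ValueError("no episode has at least tracelen transitions")
    obs, acts, rews, terms, lengths = [], [], [], [], []
    for _ in range(batch_size):
        ep = eligible[rng.integers(len(eligible))]
        T = len(ep["u"])
        s = int(rng.integers(0, T - tracelen + 1))  # start so the window fits
        obs.append(ep["x"][s : s + tracelen + 1])   # tracelen + 1 observations
        acts.append(ep["u"][s : s + tracelen])
        rews.append(ep["r"][s : s + tracelen])
        terms.append(ep["t"][s : s + tracelen])
        lengths.append(T)
    out = [np.stack(v) for v in (obs, acts, rews, terms)]  # batch-first (B, T, ...)
    if time_first:
        out = [np.swapaxes(v, 0, 1) for v in out]          # time-first (T, B, ...)
    return (*out, np.array(lengths))
```

The extra observation lets a sequence model pair each action u[i] with both the observation before it and the one after it.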

class world_models.memory.iris_memory.IRISReplayBuffer(size, obs_shape, action_size, seq_len=20, batch_size=64)[source]#

Bases: object

Replay buffer for IRIS (Imagination with auto-Regression over an Inner Speech) training.

Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world model training.

Features:
  • Ring buffer with fixed capacity (FIFO eviction when full)

  • Stores uint8 images for memory efficiency

  • Samples sequences with validation to avoid episode boundaries

  • Supports sequence sampling for temporal learning

Memory Layout:
  • observations: (capacity, C, H, W) uint8

  • actions: (capacity, action_size) float32

  • rewards: (capacity,) float32

  • terminals: (capacity,) float32

Parameters:
  • size (int) – Maximum number of transitions to store.

  • obs_shape (tuple) – Shape of observations as (C, H, W).

  • action_size (int) – Dimension of actions.

  • seq_len (int) – Length of sequences to sample (default: 20).

  • batch_size (int) – Number of sequences per batch (default: 64).

size#

Buffer capacity.

Type:

int

obs_shape#

Observation shape.

Type:

tuple

action_size#

Action dimension.

Type:

int

seq_len#

Sequence length.

Type:

int

batch_size#

Batch size.

Type:

int

steps#

Total transitions added.

Type:

int

episodes#

Number of episode terminations observed.

Type:

int

add(obs, action, reward, terminal)[source]#

Add a transition to the buffer.

Parameters:
  • obs (ndarray) – Observation array with shape (C, H, W).

  • action (ndarray) – Action array with shape (action_size,).

  • reward (float) – Scalar reward value.

  • terminal (bool) – Boolean indicating if episode terminated.

Return type:

None

sample_sequence(seq_len=None)[source]#

Sample a batch of sequences for world model training.

Returns:

(observations, actions, rewards, terminals)
  • observations: (batch_size, seq_len+1, C, H, W)

  • actions: (batch_size, seq_len, action_size)

  • rewards: (batch_size, seq_len)

  • terminals: (batch_size, seq_len)

Return type:

tuple

Parameters:

seq_len (int | None)

sample_single()[source]#

Sample a single transition for online updates.

Return type:

Tuple[ndarray, ndarray, float, float]

property buffer_capacity: int#

Returns the total capacity of the buffer.

class world_models.memory.iris_memory.IRISOnPolicyBuffer(max_steps=1000)[source]#

Bases: object

On-policy buffer for collecting trajectories during environment interaction.

Used to store the current episode data before adding to the main replay buffer. Unlike the main replay buffer, this collects trajectories in a list-based structure that’s cleared after each episode.

Useful for:
  • Collecting complete episode trajectories

  • Storing data before batch processing

  • Temporary storage during environment interaction

Parameters:

max_steps (int) – Maximum number of steps to store (default: 1000).

max_steps#

Maximum buffer capacity.

Type:

int

observations#

List of observations.

Type:

list

actions#

List of actions.

Type:

list

rewards#

List of rewards.

Type:

list

terminals#

List of terminal flags.

Type:

list

add(obs, action, reward, terminal)[source]#
Parameters:
  • obs (ndarray)

  • action (ndarray)

  • reward (float)

  • terminal (bool)

Return type:

None

clear()[source]#
Return type:

None

get_arrays()[source]#
Return type:

tuple[ndarray, ndarray, ndarray, ndarray]

RSSM-based policy for model-predictive control.

This module provides the RSSMPolicy class that implements model-predictive control using the RSSM (Recurrent State Space Model) latent dynamics model. The policy uses a Cross-Entropy Method (CEM) for planning actions in latent space.

Reference:

Ha & Schmidhuber (2018). Recurrent World Models Facilitate Policy Evolution. https://arxiv.org/abs/1805.11111

class world_models.controller.rssm_policy.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]#

Bases: object

Model-predictive controller using the Cross-Entropy Method (CEM) with an RSSM.

Plans by optimizing a sequence of future actions in the RSSM’s latent space with a CEM loop: it samples candidate action sequences, rolls them forward in latent space, scores their predicted returns, and refits a Gaussian proposal to the top-performing candidates.

Algorithm:
  1. Initialize Gaussian distribution over action sequences

  2. Sample N candidate action sequences

  3. Rollout each sequence in RSSM latent space

  4. Score by predicted cumulative rewards

  5. Keep top K candidates, fit Gaussian to them

  6. Repeat for T iterations

  7. Execute first action from best sequence
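The algorithm steps above can be sketched in NumPy. The sketch is hypothetical: the callable rollout_return stands in for rolling candidate sequences through the RSSM and summing predicted rewards, actions are assumed to lie in [-1, 1], and step 7 returns the first action of the final refit mean (a common CEM variant) rather than of the single best sequence.

```python
import numpy as np


def cem_plan(rollout_return, action_dim, horizon=12, num_candidates=1000,
             num_iterations=5, top_candidates=100, seed=0):
    """CEM over open-loop action sequences.

    rollout_return is a hypothetical callable mapping candidates of shape
    (N, H, d) to predicted returns of shape (N,).
    """
    rng = np.random.default_rng(seed)
    mean = np.zeros((horizon, action_dim))  # 1. Gaussian over action sequences
    std = np.ones((horizon, action_dim))
    for _ in range(num_iterations):         # 6. repeat for T iterations
        # 2. sample N candidates, clipped to the assumed [-1, 1] action range
        noise = rng.standard_normal((num_candidates, horizon, action_dim))
        candidates = np.clip(mean + std * noise, -1.0, 1.0)
        returns = rollout_return(candidates)                        # 3-4. score
        elite = candidates[np.argsort(returns)[-top_candidates:]]   # 5. top K
        mean, std = elite.mean(axis=0), elite.std(axis=0) + 1e-6    # refit Gaussian
    return mean[0]  # 7. first action of the planned sequence
```

A toy objective that rewards actions near 0.5 shows the proposal converging:

    score = lambda seqs: -((seqs - 0.5) ** 2).sum(axis=(1, 2))
    first_action = cem_plan(score, action_dim=2, horizon=4, num_candidates=500, num_iterations=10, top_candidates=50)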

Parameters:
  • model (Any)

  • planning_horizon (int)

  • num_candidates (int)

  • num_iterations (int)

  • top_candidates (int)

  • device (device | str)

rssm#

The RSSM world model.

N#

Number of candidate action sequences to sample (num_candidates).

K#

Number of top candidates used to refit the proposal (top_candidates).

T#

Number of CEM iterations per planning step (num_iterations).

H#

Planning horizon, i.e. the number of future steps to consider (planning_horizon).

d#

Action dimensionality.

device#

Device to run computations on.

state_size#

Hidden state dimensionality.

latent_size#

Latent state dimensionality.

Example

>>> policy = RSSMPolicy(
...     model=rssm,
...     planning_horizon=12,
...     num_candidates=1000,
...     num_iterations=5,
...     top_candidates=100,
...     device='cuda'
... )
>>> policy.reset()
>>> action = policy.poll(observation)

reset()[source]#

Reset the policy state.

Initializes the hidden state, latent state, and action to zeros. Should be called at the beginning of each episode.

Return type:

None

poll(observation, explore=False)[source]#

Get action for given observation.

Parameters:
  • observation (Tensor) – Current observation tensor of shape (channels, height, width).

  • explore (bool) – If True, add exploration noise to the selected action.

Returns:

Action tensor of shape (1, action_size).

Return type:

Tensor

class world_models.controller.iris_policy.IRISActor(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Actor network for the IRIS (Imagination with auto-Regression over an Inner Speech) policy.

Takes reconstructed frames as input and outputs action logits for policy control. Uses a CNN feature extractor followed by an LSTM for temporal processing. Supports a burn-in mechanism for initializing the hidden state with context frames.

Architecture:
  • CNN: Extracts features from input frames (3x64x64 -> 512)

  • LSTM: Processes temporal sequences with configurable layers

  • Linear: Maps hidden states to action logits

Parameters:
  • action_size (int) – Number of discrete actions.

  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

action_size#

Number of discrete actions.

Type:

int

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

forward(frames, hidden_state=None, burn_in_frames=None)[source]#

Forward pass through actor.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W) or (B, C, H, W)

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple for LSTM state

  • burn_in_frames (Tensor | None) – Frames to use for initializing hidden state

Returns:

(action_logits, hidden_state)
  • action_logits: (B, T, action_size) or (B, action_size)

  • hidden_state: updated (h, c) tuple

Return type:

tuple

init_hidden_state(batch_size, device)[source]#

Initialize LSTM hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

get_action(frame, temperature=1.0, deterministic=False)[source]#

Get action from a single frame.

Parameters:
  • frame (Tensor) – Single frame (B, C, H, W)

  • temperature (float) – Softmax temperature (higher = more random)

  • deterministic (bool) – If True, return argmax; else sample

Returns:

Selected action indices with shape (B,).

Return type:

Tensor
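The temperature and deterministic options can be illustrated with a NumPy sketch of softmax sampling. It assumes get_action divides the logits by temperature before the softmax and uses argmax when deterministic=True; sample_action is hypothetical and operates on arrays rather than the actor's tensors.

```python
import numpy as np


def sample_action(logits, temperature=1.0, deterministic=False, rng=None):
    """Pick one discrete action per row of logits with shape (B, action_size)."""
    logits = np.asarray(logits, dtype=np.float64)
    if deterministic:
        return logits.argmax(axis=-1)  # greedy choice; temperature has no effect
    scaled = logits / temperature      # T > 1 flattens, T < 1 sharpens the distribution
    scaled -= scaled.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum(axis=-1, keepdims=True)
    rng = rng or np.random.default_rng()
    return np.array([rng.choice(len(p), p=p) for p in probs])
```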

class world_models.controller.iris_policy.IRISCritic(hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Critic network for IRIS value estimation.

Estimates the value function for given frame sequences. Shares the CNN feature extractor and LSTM backbone with the actor for efficiency, but has a separate value head for estimating expected cumulative rewards.

Architecture:
  • CNN: Shared feature extractor with actor (3x64x64 -> 512)

  • LSTM: Temporal processing with same architecture as actor

  • Linear: Maps hidden states to scalar values

Parameters:
  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

Returns:

(values, hidden_state)
  • values: value estimates with shape (B, T)

  • hidden_state: updated LSTM hidden state as an (h, c) tuple

Return type:

tuple

Parameters:
  • hidden_size (int)

  • num_layers (int)

  • frame_shape (Tuple[int, int, int])

forward(frames, hidden_state=None)[source]#

Forward pass through critic.

Parameters:
  • frames (Tensor) – Input frames (B, T, C, H, W)

  • hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple

Returns:

(values, hidden_state)
  • values: value estimates with shape (B, T)

  • hidden_state: updated (h, c) tuple

Return type:

tuple

init_hidden_state(batch_size, device)[source]#

Initialize LSTM hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

Tuple[Tensor, Tensor]

class world_models.controller.iris_policy.CNNFeatureExtractor(frame_shape=(3, 64, 64), output_size=512)[source]#

Bases: Module

CNN feature extractor shared between actor and critic networks.

Processes input frames through four stride-2 convolutional layers with ReLU activations (C -> 32 -> 64 -> 128 -> 256 channels), then projects the flattened result to output_size with a linear layer.

Architecture:
  • Conv layers: 32 -> 64 -> 128 -> 256 channels

  • Each conv has stride=2 for spatial downsampling

  • Final linear layer projects to desired output dimension

Parameters:
  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

  • output_size (int) – Size of output feature vector (default: 512).

frame_shape#

Input frame shape.

Type:

tuple

output_size#

Output feature dimension.

Type:

int

Returns:

Feature vectors with shape (B, output_size).

Return type:

Tensor

Parameters:
  • frame_shape (Tuple[int, int, int])

  • output_size (int)

forward(x)[source]#

Extract features from frames.

Parameters:

x (Tensor) – Frames (B, C, H, W)

Returns:

Feature vectors with shape (B, output_size).

Return type:

Tensor
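The input size of the final linear layer depends on the spatial size left after the four stride-2 convolutions. The kernel size and padding are not documented here, so the sketch below (the helper names are hypothetical) computes the flattened size for two common choices using the standard convolution output-size formula.

```python
def conv_out(size, kernel, stride, padding=0):
    """Standard output-size formula for a 2D convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def flatten_size(frame_shape=(3, 64, 64), channels=(32, 64, 128, 256), kernel=4, padding=1):
    """Number of features the final linear projection must accept (assumed kernel/padding)."""
    _, h, w = frame_shape
    for _ in channels:  # stride 2 roughly halves each spatial dimension per layer
        h, w = conv_out(h, kernel, 2, padding), conv_out(w, kernel, 2, padding)
    return channels[-1] * h * w
```

With kernel_size=4 and padding=1, a 64x64 frame shrinks to 4x4 (4096 features); with padding=0 it shrinks to 2x2 (1024 features).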

class world_models.controller.iris_policy.IRISPolicy(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#

Bases: Module

Combined policy module for IRIS (Imagination with auto-Regression over an Inner Speech).

Provides a unified interface for actor-only or actor-critic policies. Used in the IRIS algorithm where the actor generates actions from reconstructed frames and the critic estimates value functions for training.

Parameters:
  • action_size (int) – Number of discrete actions.

  • hidden_size (int) – LSTM hidden state size (default: 512).

  • num_layers (int) – Number of LSTM layers (default: 4).

  • frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).

actor#

The actor network for action selection.

Type:

IRISActor

hidden_size#

LSTM hidden state size.

Type:

int

num_layers#

Number of LSTM layers.

Type:

int

frame_shape#

Input frame shape.

Type:

tuple

Example

>>> import torch
>>> policy = IRISPolicy(
...     action_size=18,
...     hidden_size=512,
...     num_layers=4,
...     frame_shape=(3, 64, 64)
... )
>>> frame = torch.rand(1, 3, 64, 64)  # one frame as (B, C, H, W)
>>> action = policy.act(frame, temperature=1.0, deterministic=False)

forward(frames)[source]#

Get action logits from frames.

Parameters:

frames (Tensor)

Return type:

Tensor

act(frame, temperature=1.0, deterministic=False)[source]#

Sample action from policy.

Parameters:
  • frame (Tensor)

  • temperature (float)

  • deterministic (bool)

Return type:

Tensor

init_hidden(batch_size, device)[source]#

Initialize hidden state.

Parameters:
  • batch_size (int)

  • device (device)

Return type:

tuple[Tensor, Tensor]

Rollout generation utilities for World Models.

This module provides the RolloutGenerator class for collecting episode experience using trained policies in environments.

class world_models.controller.rollout_generator.RolloutGenerator(env, device, policy=None, max_episode_steps=None, episode_gen=None, name='', enable_streaming_video=False, streaming_video_path=None, streaming_video_fps=20, streaming_video_format='mp4')[source]#

Bases: object

Generator for collecting environment rollouts.

This class handles environment interactions and rollout collection, supporting both random and policy-based action selection.

Parameters:
  • env (Any)

  • device (device | str)

  • policy (Any)

  • max_episode_steps (int | None)

  • episode_gen (Any)

  • name (str)

  • enable_streaming_video (bool)

  • streaming_video_path (str | None)

  • streaming_video_fps (int)

  • streaming_video_format (str)

env#

The environment to interact with.

device#

Device to run computations on.

policy#

The policy to use for action selection (optional).

episode_gen#

Factory for creating episode objects.

name#

Name identifier for the generator.

max_episode_steps#

Maximum steps per episode.

Example

>>> generator = RolloutGenerator(
...     env=env,
...     device='cuda',
...     policy=policy,
...     max_episode_steps=1000
... )
>>> episode = generator.rollout_once()

rollout_once(random_policy=False, explore=False)[source]#

Perform a single rollout of the environment.

Parameters:
  • random_policy (bool) – If True, use random actions instead of policy.

  • explore (bool) – If True, add exploration noise to policy actions.

Returns:

Episode object containing the rollout experience.

Return type:

Episode

rollout_n(n=1, random_policy=False)[source]#

Perform multiple rollouts.

Parameters:
  • n (int) – Number of rollouts to perform.

  • random_policy (bool) – If True, use random actions.

Returns:

List of Episode objects.

Return type:

list

rollout_eval_n(n)[source]#

Perform multiple evaluation rollouts with metrics.

Parameters:

n (int) – Number of evaluation rollouts.

Returns:

Tuple of (episodes, frames, metrics).

Return type:

tuple

rollout_eval(collect_latents=False)[source]#
Parameters:

collect_latents (bool)

Return type:

tuple

class world_models.inference.operators.OperatorABC(*, device=None)[source]

Bases: Module, ABC

Structured base class for inference operators.

Operators use a consistent pipeline:

  1. preprocess converts raw inputs into tensors.

  2. forward performs model/operator-specific tensor computation.

  3. postprocess formats the final output mapping.

Subclasses may also declare input_specs and output_specs to validate required tensor keys, shapes, and dtypes. OperatorABC inherits from torch.nn.Module, so operators support to(device), train(), and eval() just like model modules.

Parameters:

device (torch.device | str | None)

input_specs: Mapping[str, TensorSpec] = {}
output_specs: Mapping[str, TensorSpec] = {}
abstractmethod preprocess(inputs)[source]

Convert raw inputs into a tensor mapping ready for forward.

Parameters:

inputs (Any)

Return type:

dict[str, Tensor]

forward(inputs)[source]

Run tensor computation for this operator.

Preprocessing-only operators can rely on this identity implementation. Operators that wrap a model should override this method.

Parameters:

inputs (dict[str, Tensor])

Return type:

dict[str, Tensor]

postprocess(outputs)[source]

Format validated forward outputs for consumers.

Parameters:

outputs (dict[str, Tensor])

Return type:

dict[str, Tensor]

process(inputs)[source]

Process raw inputs through preprocess, forward, and postprocess stages.

Parameters:

inputs (Any)

Return type:

dict[str, Tensor]

batch(inputs)[source]

Preprocess a sequence of inputs and stack matching tensor keys.

Parameters:

inputs (Sequence[Any])

Return type:

dict[str, Tensor]

to(*args, **kwargs)[source]

Move module parameters/buffers and remember the target tensor device.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

OperatorABC

classmethod validate_mapping(values, specs, *, label)[source]

Validate tensor keys, shapes, and dtypes against optional specs.

Parameters:
  • values (Mapping[str, Tensor])

  • specs (Mapping[str, TensorSpec])

  • label (str)

Return type:

None
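The preprocess -> forward -> postprocess contract can be illustrated without PyTorch. The sketch below is a hypothetical stand-in (MiniOperator and ScaleImageOperator are not library classes) that works on NumPy arrays; a real operator would subclass OperatorABC, implement preprocess, and override forward only if it wraps a model.

```python
from abc import ABC, abstractmethod

import numpy as np


class MiniOperator(ABC):
    """Hypothetical stand-in for OperatorABC's preprocess -> forward -> postprocess flow."""

    @abstractmethod
    def preprocess(self, inputs):
        ...

    def forward(self, inputs):
        return inputs  # identity default for preprocessing-only operators

    def postprocess(self, outputs):
        return outputs

    def process(self, inputs):
        return self.postprocess(self.forward(self.preprocess(inputs)))

    def batch(self, inputs):
        # Preprocess each item, then stack arrays that share a key along a new axis 0
        items = [self.preprocess(x) for x in inputs]
        return {k: np.stack([item[k] for item in items]) for k in items[0]}


class ScaleImageOperator(MiniOperator):
    """Hypothetical preprocessing-only operator: uint8 HWC image -> float32 CHW in [0, 1]."""

    def preprocess(self, inputs):
        image = np.asarray(inputs["image"], dtype=np.float32) / 255.0
        return {"image": image.transpose(2, 0, 1)}
```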

class world_models.inference.operators.TensorSpec(shape=None, dtype=None, required=True)[source]

Bases: object

Optional tensor contract used to validate operator inputs or outputs.

Parameters:
  • shape (tuple[int | None, ...] | None) – Expected shape. Use None as a wildcard for dimensions that may vary, such as batch size.

  • dtype (dtype | None) – Expected tensor dtype.

  • required (bool) – Whether the key must be present in the mapping being validated.

shape: tuple[int | None, ...] | None = None
dtype: dtype | None = None
required: bool = True
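The wildcard semantics of TensorSpec.shape can be sketched as a plain shape check. This is a hypothetical illustration, not validate_mapping itself: it assumes a spec shape of None means "no constraint", that the rank must match otherwise, and it represents each spec as a (shape, required) pair while ignoring dtype.

```python
from typing import Optional, Sequence, Tuple


def shape_matches(shape: Sequence[int], spec: Optional[Tuple[Optional[int], ...]]) -> bool:
    """True if shape satisfies spec; None inside spec acts as a per-dimension wildcard."""
    if spec is None:             # no shape constraint at all
        return True
    if len(shape) != len(spec):  # rank must agree
        return False
    return all(want is None or want == got for got, want in zip(shape, spec))


def validate_shapes(values, specs):
    """values: {key: shape}; specs: {key: (spec_shape, required)}. Raises on mismatch."""
    for key, (spec_shape, required) in specs.items():
        if key not in values:
            if required:
                raise KeyError(f"missing required key {key!r}")
            continue
        if not shape_matches(values[key], spec_shape):
            raise ValueError(f"{key!r}: shape {tuple(values[key])} does not match {spec_shape}")
```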
class world_models.inference.operators.DreamerOperator(image_size=64, action_dim=6)[source]

Bases: OperatorABC

Operator for Dreamer model preprocessing: normalizes observations and encodes actions.

Parameters:
  • image_size (int)

  • action_dim (int)

preprocess(inputs)[source]

Process Dreamer inputs: image observation and action.

Expected inputs: {‘image’: PIL.Image or tensor, ‘action’: tensor or list}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

class world_models.inference.operators.JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)[source]

Bases: OperatorABC

Operator for JEPA model preprocessing: handles image/video masking and patch processing.

Parameters:
  • image_size (int)

  • patch_size (int)

  • mask_ratio (float)

preprocess(inputs)[source]

Process JEPA inputs: images with masking.

Expected inputs: {‘images’: list of PIL Images or tensors, ‘mask’: optional tensor}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]
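The default JEPAOperator settings imply a 14x14 patch grid (224 / 16) with 196 patches, of which a mask_ratio of 0.75 hides 147. The sketch below assumes mask_ratio is the fraction of patches masked and draws them uniformly at random; random_patch_mask is hypothetical, and JEPA variants commonly use structured block masks instead, so the operator's actual strategy may differ.

```python
import numpy as np


def random_patch_mask(image_size=224, patch_size=16, mask_ratio=0.75, seed=0):
    """Boolean mask over the patch grid; True marks a masked (hidden) patch."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    grid = image_size // patch_size                    # 224 // 16 = 14 patches per side
    num_patches = grid * grid                          # 14 * 14 = 196
    num_masked = int(round(mask_ratio * num_patches))  # 0.75 * 196 = 147
    rng = np.random.default_rng(seed)
    mask = np.zeros(num_patches, dtype=bool)
    mask[rng.choice(num_patches, size=num_masked, replace=False)] = True
    return mask.reshape(grid, grid)
```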

class world_models.inference.operators.IrisOperator(seq_length=512, vocab_size=32000)[source]

Bases: OperatorABC

Operator for Iris transformer model: formats sequences and embeddings.

Parameters:
  • seq_length (int)

  • vocab_size (int)

preprocess(inputs)[source]

Process Iris inputs: token sequences and optional embeddings.

Expected inputs: {‘tokens’: list of ints or tensor, ‘embeddings’: optional tensor}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

class world_models.inference.operators.PlaNetOperator(state_dim=32, action_dim=4)[source]

Bases: OperatorABC

Operator for PlaNet model preprocessing: encodes environment states and transitions.

Parameters:
  • state_dim (int)

  • action_dim (int)

preprocess(inputs)[source]

Process PlaNet inputs: state observations and actions.

Expected inputs: {‘obs’: tensor or image, ‘action’: tensor, ‘reward’: float, ‘done’: bool}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

world_models.inference.operators.get_operator(name, **kwargs)[source]

Factory function to get inference operators by name.

Parameters:
  • name (str) – One of ‘dreamer’, ‘jepa’, ‘iris’, ‘planet’

  • **kwargs (Any) – Operator-specific configuration

Returns:

Configured OperatorABC instance

Return type:

OperatorABC

Example

>>> op = get_operator('dreamer', image_size=64, action_dim=6)
>>> processed = op.process({'image': image, 'action': action})

class world_models.inference.operators.base.TensorSpec(shape=None, dtype=None, required=True)[source]#

Bases: object

Optional tensor contract used to validate operator inputs or outputs.

Parameters:
  • shape (tuple[int | None, ...] | None) – Expected shape. Use None as a wildcard for dimensions that may vary, such as batch size.

  • dtype (dtype | None) – Expected tensor dtype.

  • required (bool) – Whether the key must be present in the mapping being validated.

shape: tuple[int | None, ...] | None = None#
dtype: dtype | None = None#
required: bool = True#
class world_models.inference.operators.base.OperatorABC(*, device=None)[source]#

Bases: Module, ABC

Structured base class for inference operators.

Operators use a consistent pipeline:

  1. preprocess converts raw inputs into tensors.

  2. forward performs model/operator-specific tensor computation.

  3. postprocess formats the final output mapping.

Subclasses may also declare input_specs and output_specs to validate required tensor keys, shapes, and dtypes. OperatorABC inherits from torch.nn.Module, so operators support to(device), train(), and eval() just like model modules.

Parameters:

device (torch.device | str | None)

input_specs: Mapping[str, TensorSpec] = {}#
output_specs: Mapping[str, TensorSpec] = {}#
abstractmethod preprocess(inputs)[source]#

Convert raw inputs into a tensor mapping ready for forward.

Parameters:

inputs (Any)

Return type:

dict[str, Tensor]

forward(inputs)[source]#

Run tensor computation for this operator.

Preprocessing-only operators can rely on this identity implementation. Operators that wrap a model should override this method.

Parameters:

inputs (dict[str, Tensor])

Return type:

dict[str, Tensor]

postprocess(outputs)[source]#

Format validated forward outputs for consumers.

Parameters:

outputs (dict[str, Tensor])

Return type:

dict[str, Tensor]

process(inputs)[source]#

Process raw inputs through preprocess, forward, and postprocess stages.

Parameters:

inputs (Any)

Return type:

dict[str, Tensor]

batch(inputs)[source]#

Preprocess a sequence of inputs and stack matching tensor keys.

Parameters:

inputs (Sequence[Any])

Return type:

dict[str, Tensor]

to(*args, **kwargs)[source]#

Move module parameters/buffers and remember the target tensor device.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

OperatorABC

classmethod validate_mapping(values, specs, *, label)[source]#

Validate tensor keys, shapes, and dtypes against optional specs.

Parameters:
  • values (Mapping[str, Tensor])

  • specs (Mapping[str, TensorSpec])

  • label (str)

Return type:

None

class world_models.inference.operators.dreamer_operator.DreamerOperator(image_size=64, action_dim=6)[source]#

Bases: OperatorABC

Operator for Dreamer model preprocessing: normalizes observations and encodes actions.

Parameters:
  • image_size (int)

  • action_dim (int)

preprocess(inputs)[source]#

Process Dreamer inputs: image observation and action.

Expected inputs: {‘image’: PIL.Image or tensor, ‘action’: tensor or list}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

class world_models.inference.operators.planet_operator.PlaNetOperator(state_dim=32, action_dim=4)[source]#

Bases: OperatorABC

Operator for PlaNet model preprocessing: encodes environment states and transitions.

Parameters:
  • state_dim (int)

  • action_dim (int)

preprocess(inputs)[source]#

Process PlaNet inputs: state observations and actions.

Expected inputs: {‘obs’: tensor or image, ‘action’: tensor, ‘reward’: float, ‘done’: bool}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

class world_models.inference.operators.iris_operator.IrisOperator(seq_length=512, vocab_size=32000)[source]#

Bases: OperatorABC

Operator for Iris transformer model: formats sequences and embeddings.

Parameters:
  • seq_length (int)

  • vocab_size (int)

preprocess(inputs)[source]#

Process Iris inputs: token sequences and optional embeddings.

Expected inputs: {‘tokens’: list of ints or tensor, ‘embeddings’: optional tensor}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

class world_models.inference.operators.jepa_operator.JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)[source]#

Bases: OperatorABC

Operator for JEPA model preprocessing: handles image/video masking and patch processing.

Parameters:
  • image_size (int)

  • patch_size (int)

  • mask_ratio (float)

preprocess(inputs)[source]#

Process JEPA inputs: images with masking.

Expected inputs: {‘images’: list of PIL Images or tensors, ‘mask’: optional tensor}

Parameters:

inputs (Dict[str, Any])

Return type:

Dict[str, Tensor]

Datasets, environments, and transforms#

Environment adapters#

The environment APIs below mirror the dedicated environment guide pages: DMC, DeepMind Lab, Gym/Gymnasium, Atari/ALE, Procgen, MuJoCo, Unity ML-Agents, and vectorization utilities. DIAMOND-style Atari support is intentionally not listed as an environment adapter because it is Atari preprocessing rather than a separate environment family.

world_models.envs.make_atari_env(env_id, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, **kwargs)[source]#

Create an Atari environment from the Arcade Learning Environment (ALE).

Parameters:
  • env_id (str) – The id of the Atari environment to create.

  • obs_type (str) – The type of observation to return (“rgb” or “ram”).

  • frameskip (int) – The number of frames to skip between actions.

  • repeat_action_probability (float) – The probability of repeating the last action.

  • full_action_space (bool) – Whether to use the full action space.

  • max_episode_steps (Optional[int]) – Maximum number of steps per episode.

  • **kwargs – Additional keyword arguments for environment configuration.

Returns:

The created Atari environment.

Return type:

gym.Env
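The interaction between frameskip and repeat_action_probability can be illustrated with a pure-Python sketch. ALE's sticky actions execute the previous action instead of the requested one with probability repeat_action_probability on each emulator frame, and rewards are summed over the skipped frames. The function below is hypothetical (step_fn stands in for a single emulator frame) and only mirrors that behaviour; the real logic lives inside ALE.

```python
import random


def sticky_frameskip_step(step_fn, action, prev_action, frameskip=4,
                          repeat_action_probability=0.25, rng=None):
    """Apply one agent action for `frameskip` frames with sticky actions.

    step_fn is a hypothetical single-frame emulator step returning (reward, done).
    """
    rng = rng or random.Random()
    total_reward, done = 0.0, False
    for _ in range(frameskip):
        # With probability p the previously executed action is repeated instead
        executed = prev_action if rng.random() < repeat_action_probability else action
        reward, done = step_fn(executed)
        total_reward += reward  # rewards are summed over the skipped frames
        prev_action = executed
        if done:
            break
    return total_reward, done, prev_action
```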

world_models.envs.list_available_atari_envs()[source]#

Get a list of all Atari environment IDs available in the Arcade Learning Environment (ALE).

Returns:

List of available Atari environment IDs.

Return type:

list[str]

world_models.envs.make_atari_vector_env(game, num_envs, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, seed=None, **kwargs)[source]#

Create vectorized Atari environments using ALE’s native AtariVectorEnv.

Parameters:
  • game (str) – The name of the Atari game (e.g., “pong”, “breakout”).

  • num_envs (int) – Number of parallel environments.

  • obs_type (str) – The type of observation to return (“rgb” or “ram”).

  • frameskip (int) – The number of frames to skip between actions.

  • repeat_action_probability (float) – The probability of repeating the last action.

  • full_action_space (bool) – Whether to use the full action space.

  • max_episode_steps (Optional[int]) – Maximum number of steps per episode.

  • seed (Optional[int]) – Random seed for reproducibility.

  • **kwargs – Additional keyword arguments for environment configuration.

Returns:

The vectorized Atari environment.

Return type:

AtariVectorEnv

class world_models.envs.MuJoCoImageEnv(xml_path=None, *, xml_string=None, binary_path=None, assets=None, seed=0, size=(64, 64), camera=None, reward_fn=None, terminal_fn=None, frame_skip=1, reset_noise_scale=0.0, default_control_range=(-1.0, 1.0))[source]#

Bases: object

Native MuJoCo environment adapter for pixel-based world-model training.

The adapter uses the low-level mujoco Python package directly: models are compiled from MJCF XML strings/files or MJB binaries via mujoco.MjModel; simulation state lives in mujoco.MjData; actions are written to data.ctrl; and images are produced with mujoco.Renderer. Observations follow TorchWM’s Dreamer-style contract: {"image": uint8[C, H, W]}.

Native MuJoCo models do not define task rewards or episode termination by themselves, so callers can supply reward_fn and terminal_fn callbacks. By default, rewards are 0.0 and episodes terminate only through external wrappers such as TimeLimit.

Parameters:
  • xml_path (str | Path | None)

  • xml_string (str | None)

  • binary_path (str | Path | None)

  • assets (dict[str, bytes] | None)

  • seed (int)

  • size (tuple[int, int])

  • camera (str | int | None)

  • reward_fn (RewardFn | None)

  • terminal_fn (TerminalFn | None)

  • frame_skip (int)

  • reset_noise_scale (float)

  • default_control_range (tuple[float, float])

property observation_space: Dict#
property action_space: Box#
reset(seed=None)[source]#
Parameters:

seed (int | None)

Return type:

dict[str, ndarray]

step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[dict[str, ndarray], float, bool, dict[str, Any]]

render()[source]#
Return type:

Any

close()[source]#
Return type:

None
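mujoco.Renderer.render() returns frames in height-width-channel (HWC) order, while the observation contract above is channel-first uint8[C, H, W]. With NumPy the conversion is frame.transpose(2, 0, 1). The dependency-free sketch below spells out the same index mapping on nested lists. The documentation does not say exactly how MuJoCoImageEnv performs this step.

```python
def hwc_to_chw(image_hwc):
    """Reorder a nested-list image from (H, W, C) to (C, H, W)."""
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    # Output element [c][h][w] is input element [h][w][c].
    return [
        [[image_hwc[h][w][c] for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]


# A 2x2 RGB frame: each pixel is (R, G, B).
frame = [[(255, 0, 0), (0, 255, 0)],
         [(0, 0, 255), (10, 20, 30)]]
chw = hwc_to_chw(frame)
print(chw[0])  # [[255, 0], [0, 10]]  (the red channel)
```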

world_models.envs.make_mujoco_env(model=None, *, backend='auto', seed=0, size=(64, 64), render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#

Create a MuJoCo image environment from a Gymnasium MuJoCo task id or an MJCF/MJB model.

Parameters:
  • model (str | Path | None) – Either a Gymnasium MuJoCo task id such as "Humanoid-v4", an MJCF XML path/string, or an MJB binary path.

  • backend (str) – "auto" infers native vs Gymnasium task mode. Use "native" for MJCF/MJB, "gymnasium" for task ids, or "robotics" for Gymnasium Robotics registrations.

  • seed (int) – Seed forwarded to the image wrapper.

  • size (tuple[int, int]) – Target (height, width) image size.

  • render_mode (str) – Render mode used for Gymnasium MuJoCo task ids.

  • gym_kwargs (dict[str, Any] | None) – Optional keyword arguments forwarded to gymnasium.make in task-id mode. Extra **kwargs are also forwarded there.

  • **kwargs (Any) – Native MuJoCoImageEnv options for MJCF/MJB mode, or environment-constructor options for Gymnasium task-id mode.

Returns:

A TorchWM image environment returning {"image": uint8[C, H, W]}.

Return type:

GymImageEnv | MuJoCoImageEnv
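With backend="auto", the factory has to tell a task id apart from an MJCF/MJB model. One plausible heuristic is sketched below. Its rules (a leading "<" for inline XML, .xml/.mjb suffixes for files, everything else treated as a task id) are hypothetical and are not make_mujoco_env's documented logic. Pass backend explicitly when a model could be ambiguous.

```python
from pathlib import Path


def infer_mujoco_backend(model) -> str:
    """Hypothetical heuristic choosing between 'native' and 'gymnasium' modes."""
    if model is None:
        raise ValueError("this sketch requires a model")
    if isinstance(model, Path):
        return "native"  # filesystem paths point at MJCF or MJB files
    text = str(model).strip()
    if text.startswith("<"):
        return "native"  # inline MJCF XML string
    if text.lower().endswith((".xml", ".mjb")):
        return "native"  # MJCF or MJB file given as a string path
    return "gymnasium"  # anything else is treated as a task id such as "Humanoid-v4"


print(infer_mujoco_backend("Humanoid-v4"))  # gymnasium
print(infer_mujoco_backend("robot.xml"))    # native
```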

world_models.envs.make_mujoco_env_from_config(args, size)[source]#

Build a MuJoCo image environment from a DreamerConfig-like object.

Parameters:
  • args (Any)

  • size (tuple[int, int])

Return type:

Any

world_models.envs.list_gymnasium_robotics_envs()[source]#

List all Gymnasium Robotics ids registered by the installed package.

Returns an empty list when the optional dependency is not installed. When it is installed, the list is derived from Gymnasium’s registry rather than a hand-maintained subset, so newly added Robotics environments are exposed automatically.

Return type:

list[str]

world_models.envs.make_robotics_env(env, *, seed=0, size=(64, 64), render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#

Create a TorchWM image wrapper for a Gymnasium Robotics environment.

Parameters:
  • env (str) – Any environment id registered by gymnasium-robotics.

  • seed (int) – Seed forwarded to GymImageEnv.

  • size (tuple[int, int]) – Target (height, width) image size.

  • render_mode (str) – Render mode forwarded to gymnasium.make.

  • gym_kwargs (dict[str, Any] | None) – Optional keyword arguments forwarded to gymnasium.make.

  • **kwargs (Any) – Additional keyword arguments forwarded to gymnasium.make.

Returns:

A GymImageEnv that emits {"image": uint8[C, H, W]} observations.

Return type:

GymImageEnv

world_models.envs.register_gymnasium_robotics_envs()[source]#

Import Gymnasium Robotics so its environments are registered with Gymnasium.

Gymnasium moved legacy MuJoCo v2/v3 task registrations into the external gymnasium-robotics package. Current Gymnasium Robotics versions register environments during import, while older plugin-style installations may rely on gymnasium.register_envs; this helper supports both paths.

Return type:

Any

class world_models.envs.GymImageEnv(env, seed=0, size=(64, 64), render_mode='rgb_array')[source]#

Bases: object

Gym-like environment wrapper that always returns image observations.

This wrapper normalizes diverse environment interfaces to return consistent image-based observations suitable for pixel-based world models like Dreamer.

Features:
  • Supports environment IDs (string) and pre-built environment objects.

  • Synthesizes RGB images from vector observations for pixel-based training.

  • Exposes continuous action spaces mapped to the [-1, 1] range.

  • Converts discrete actions to one-hot vectors.

  • Returns observations as dict {“image”: (C, H, W)} with uint8 values.

Parameters:
  • env (Any) – Either a string environment ID (e.g., “Pendulum-v1”) or a pre-built gym environment instance.

  • seed (int) – Random seed for environment reset (default: 0).

  • size (tuple) – Target image size as (height, width) (default: (64, 64)).

  • render_mode (str) – Render mode for environment (default: “rgb_array”).

observation_space#

Dict space with “image” key containing (C, H, W) Box.

action_space#

Box space with actions in [-1, 1] range.

max_episode_steps#

Maximum steps per episode (default: 1000).

property observation_space: Dict#
property action_space: Box#
property max_episode_steps: int#
reset()[source]#
Return type:

dict[str, Any]

step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[dict[str, Any], float, bool, dict[str, Any]]

render(*args, **kwargs)[source]#
Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

ndarray

close()[source]#
Return type:

None

world_models.envs.make_gym_env(env, **kwargs)[source]#

Create a GymImageEnv wrapper for generic Gym/Gymnasium environments.

Parameters:
  • env (Any) – Either a string environment ID (e.g., “Pendulum-v1”) or a pre-built gym environment instance.

  • **kwargs (Any) – Additional keyword arguments passed to GymImageEnv, including:

      ◦ seed (int): Random seed for the environment (default: 0).

      ◦ size (tuple): Target image size as (height, width) (default: (64, 64)).

      ◦ render_mode (str): Render mode for the environment (default: “rgb_array”).

Returns:

A wrapper that always returns image observations in the format {“image”: (C, H, W)}, suitable for pixel-based world models.

Return type:

GymImageEnv

class world_models.envs.WorldModelEnv(world_model, observation_space, action_space, *, initial_observation=None, initial_state=None, reset_fn=None, transition_fn=None, reward_fn=None, terminal_fn=None, render_fn=None, action_transform_fn=None, max_episode_steps=None, render_mode=None, device=None, torch_actions=True, seed=None)[source]#

Bases: Env

Expose a trained world model through the Gymnasium Env API.

WorldModelEnv keeps the current latent/model state and advances it with a transition callable or with a compatible method on world_model. The wrapper returns Gymnasium-style (obs, info) from reset and (obs, reward, terminated, truncated, info) from step, making learned model rollouts pluggable into RL libraries such as Stable-Baselines3, TorchRL, and CleanRL.

Parameters:
  • world_model (Any) – Trained model or lightweight adapter object used for simulated dynamics.

  • observation_space (gym.Space) – Gymnasium observation space emitted by the wrapper.

  • action_space (gym.Space) – Gymnasium action space accepted by the wrapper.

  • initial_observation (Any | None) – Optional observation returned when no reset callable provides one. Defaults to observation_space.sample().

  • initial_state (Any | None) – Optional latent/model state used at reset.

  • reset_fn (ResetFn | None) – Optional callable for resetting model state. Accepted return forms are obs, (obs, info), (state, obs), (state, obs, info), or a mapping with state/observation.

  • transition_fn (TransitionFn | None) – Optional callable for one model step. If omitted, the wrapper tries common methods on world_model: env_step, step, predict_step, predict, imagine_step, transition, then __call__.

  • reward_fn (RewardFn | None) – Optional callable used when the transition output omits a reward.

  • terminal_fn (TerminalFn | None) – Optional callable used when the transition output omits a termination flag.

  • render_fn (RenderFn | None) – Optional callable used by render.

  • action_transform_fn (ActionTransformFn | None) – Optional callable that converts library actions into the format expected by the world model.

  • max_episode_steps (int | None) – Optional time limit. Reaching it sets truncated.

  • render_mode (str | None) – Optional Gymnasium render mode. rgb_array is supported by default when observations contain image-like data.

  • device (Any | None) – Device used for tensor actions when torch_actions=True.

  • torch_actions (bool) – Convert actions to torch.Tensor before model calls.

  • seed (int | None) – Optional RNG seed for observation/action spaces and NumPy.

metadata = {'render_fps': 30, 'render_modes': ['rgb_array']}#
property state: Any#

Current latent/model state tracked by the wrapper.

reset(*, seed=None, options=None)[source]#

Reset the simulated rollout and return (observation, info).

Parameters:
  • seed (int | None)

  • options (dict[str, Any] | None)

Return type:

tuple[Any, dict[str, Any]]

step(action)[source]#

Roll the learned model forward for one simulated environment step.

Parameters:

action (Any)

Return type:

tuple[Any, float, bool, bool, dict[str, Any]]

render()[source]#

Render the latest simulated observation or delegate to render_fn.

Return type:

Any

close()[source]#

Close the wrapped world model if it exposes close.

Return type:

None
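When transition_fn is omitted, WorldModelEnv probes world_model for a step method in a fixed order. The helper below reproduces that documented order as a standalone sketch. The function name and the error type are assumptions, and the real wrapper may also adapt arguments and outputs around the resolved callable.

```python
from typing import Any, Callable

# Lookup order copied from the WorldModelEnv description above.
TRANSITION_METHOD_NAMES = (
    "env_step", "step", "predict_step", "predict", "imagine_step", "transition",
)


def resolve_transition(world_model: Any) -> Callable[..., Any]:
    """Return the first compatible transition callable on world_model."""
    for name in TRANSITION_METHOD_NAMES:
        method = getattr(world_model, name, None)
        if callable(method):
            return method
    if callable(world_model):
        return world_model  # fall back to __call__
    raise TypeError("world_model has no compatible transition method")


class PredictOnlyModel:
    def predict(self, state, action):
        return state + action

    def __call__(self, state, action):
        return state - action


step_fn = resolve_transition(PredictOnlyModel())
print(step_fn(1, 2))  # 3, because predict is found before __call__
```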

world_models.envs.make_world_model_env(world_model, **kwargs)[source]#

Create a WorldModelEnv from a trained model and spaces.

Parameters:
  • world_model (Any)

  • kwargs (Any)

Return type:

WorldModelEnv

class world_models.envs.ProcgenImageEnv(env, seed=0, size=(64, 64), distribution_mode='easy', num_levels=0, start_level=None, action_n=15, **procgen_kwargs)[source]#

Bases: object

Adapt Procgen’s vector API to TorchWM’s single-env image interface.

The upstream procgen.ProcgenEnv API is vectorized, so this wrapper builds a one-environment vector and unwraps the leading batch dimension. Actions are exposed as a continuous one-hot-like Box[-1, 1] with one element per discrete Procgen action, matching TorchWM’s other discrete image adapters.

Parameters:
  • env (str)

  • seed (int)

  • size (tuple[int, int])

  • distribution_mode (str)

  • num_levels (int)

  • start_level (int | None)

  • action_n (int)

  • procgen_kwargs (Any)

property observation_space: Dict#
property action_space: _ProcgenActionSpace#
property max_episode_steps: int#
reset()[source]#
Return type:

dict[str, ndarray[tuple[Any, …], dtype[uint8]]]

step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[dict[str, ndarray[tuple[Any, …], dtype[uint8]]], float, bool, dict[str, Any]]

render(*args, **kwargs)[source]#
Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

ndarray[tuple[Any, …], dtype[uint8]]

close()[source]#
Return type:

None

world_models.envs.list_procgen_envs()[source]#

Return the Procgen game names understood by ProcgenImageEnv.

Return type:

list[str]

world_models.envs.make_procgen_env(env, **kwargs)[source]#

Create a single-environment Procgen adapter.

Parameters:
  • env (str) – Procgen game name or Gym-style id.

  • **kwargs (Any) – Options forwarded to ProcgenImageEnv.

Returns:

TorchWM-compatible image wrapper exposing {"image": (3, H, W) uint8} observations and one-hot-like actions.

Return type:

ProcgenImageEnv

world_models.envs.normalize_procgen_env_name(env)[source]#

Normalize Procgen Gym ids and shorthand names to Procgen game names.

Accepted forms include "coinrun", "procgen-coinrun-v0", and "procgen:procgen-coinrun-v0".

Parameters:

env (str)

Return type:

str
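The three accepted forms differ only in Gym-style decoration around the game name. The illustrative re-implementation below strips a module prefix, a "procgen-" prefix, and a version suffix. It is a sketch, not the library function, which may also validate the result against list_procgen_envs().

```python
import re


def normalize_procgen_name(env: str) -> str:
    """Strip Gym-style decoration from a Procgen id (illustrative sketch)."""
    name = env.split(":", 1)[-1]           # "procgen:procgen-coinrun-v0" -> "procgen-coinrun-v0"
    name = re.sub(r"^procgen-", "", name)  # "procgen-coinrun-v0" -> "coinrun-v0"
    name = re.sub(r"-v\d+$", "", name)     # "coinrun-v0" -> "coinrun"
    return name


for raw in ("coinrun", "procgen-coinrun-v0", "procgen:procgen-coinrun-v0"):
    print(normalize_procgen_name(raw))  # coinrun (printed three times)
```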

class world_models.envs.UnityMLAgentsEnv(file_name, behavior_name=None, seed=0, size=(64, 64), worker_id=0, base_port=5005, no_graphics=True, time_scale=20.0, quality_level=1, max_episode_steps=1000)[source]#

Bases: object

Gym-like wrapper for Unity ML-Agents environments.

Provides a unified interface for Unity-based environments, converting observations to image format compatible with pixel-based world models.

Features:
  • Supports single-agent control with continuous action spaces.

  • Returns observations as {“image”: (C, H, W)} with uint8 values.

  • Normalizes actions to the [-1, 1] range.

  • Includes rendered frames in observations for visual policies.

Parameters:
  • file_name (str) – Path to the Unity environment binary.

  • behavior_name (str, optional) – Name of the behavior to use. If None, uses the first available behavior.

  • seed (int) – Random seed for environment (default: 0).

  • size (tuple) – Target image size as (height, width) (default: (64, 64)).

  • worker_id (int) – Worker ID for multi-environment setup (default: 0).

  • base_port (int) – Base port for Unity environment communication (default: 5005).

  • no_graphics (bool) – Disable graphics rendering for faster simulation (default: True).

  • time_scale (float) – Simulation time scale multiplier (default: 20.0).

  • quality_level (int) – Graphics quality level 0-5 (default: 1).

  • max_episode_steps (int) – Maximum steps per episode (default: 1000).

observation_space#

Dict space with “image” key containing (3, H, W) Box.

action_space#

Box space with actions in [-1, 1] range.

max_episode_steps#

Maximum steps per episode.

Raises:
  • ValueError – If no behaviors found or action space is not continuous.

  • RuntimeError – If no agents available after reset.

Parameters:
  • file_name (str)

  • behavior_name (str | None)

  • seed (int)

  • size (tuple[int, int])

  • worker_id (int)

  • base_port (int)

  • no_graphics (bool)

  • time_scale (float)

  • quality_level (int)

  • max_episode_steps (int)

property observation_space: Dict#
property action_space: Box#
property max_episode_steps: int#
reset()[source]#
Return type:

dict[str, Any]

step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[dict[str, Any], float, bool, dict[str, Any]]

render(*args, **kwargs)[source]#
Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

close()[source]#
Return type:

None
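In the ML-Agents Python API (mlagents_envs.environment.UnityEnvironment), each instance connects on port base_port + worker_id, so Unity environments running in parallel need distinct worker_id values. Assuming UnityMLAgentsEnv forwards both arguments unchanged, the ports for a batch of workers are:

```python
def unity_worker_ports(num_envs: int, base_port: int = 5005) -> list[int]:
    """Port per worker, assuming ML-Agents' base_port + worker_id rule."""
    return [base_port + worker_id for worker_id in range(num_envs)]


print(unity_worker_ports(4))  # [5005, 5006, 5007, 5008]
```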

world_models.envs.make_unity_mlagents_env(env_id=None, **kwargs)[source]#

Create a Unity ML-Agents environment wrapper.

Factory function that instantiates a UnityMLAgentsEnv with the provided keyword arguments. Suitable for integrating Unity-based environments with Dreamer-style world model pipelines.

Parameters:
  • env_id (str | None)

  • **kwargs (Any) – Keyword arguments passed to UnityMLAgentsEnv, including:

      ◦ file_name (str): Path to the Unity environment binary.

      ◦ behavior_name (str, optional): Name of the behavior to use.

      ◦ seed (int): Random seed (default: 0).

      ◦ size (tuple): Image size as (height, width) (default: (64, 64)).

      ◦ worker_id (int): Worker ID for multi-environment setup (default: 0).

      ◦ base_port (int): Base port for communication (default: 5005).

      ◦ no_graphics (bool): Disable graphics rendering (default: True).

      ◦ time_scale (float): Simulation time scale (default: 20.0).

      ◦ quality_level (int): Graphics quality level (default: 1).

      ◦ max_episode_steps (int): Max steps per episode (default: 1000).

Returns:

A Gym-compatible wrapper for Unity environments.

Return type:

UnityMLAgentsEnv

class world_models.envs.DeepMindControlEnv(name, seed, size=(64, 64), camera=None)[source]#

Bases: object

Gym-style adapter for DeepMind Control Suite tasks.

The wrapper exposes DMC observations and actions through Gym spaces and adds a rendered RGB image to each observation dict so image-based world model pipelines can train consistently across backends.

Features:
  • Parses domain-task names (e.g., “cheetah-run” -> domain=“cheetah”, task=“run”)

  • Automatically handles special cases like “cup” -> “ball_in_cup”

  • Renders RGB images at configurable resolution

  • Returns observations as dict with both state vectors and images

Parameters:
  • name (str) – Environment name in format “domain-task” (e.g., “cheetah-run”).

  • seed (int) – Random seed for environment initialization.

  • size (tuple) – Target image size as (height, width) (default: (64, 64)).

  • camera (int, optional) – Camera ID for rendering. Defaults to 0 for most domains, 2 for quadruped.

observation_space#

Dict space with state keys and “image”.

Type:

gym.spaces.Dict

action_space#

Continuous action space from DMC spec.

Type:

gym.spaces.Box

Example

>>> from world_models.envs import DeepMindControlEnv
>>> env = DeepMindControlEnv("cheetah-run", seed=0, size=(64, 64))
>>> obs = env.reset()
>>> print(obs.keys())  # dict_keys(['position', 'velocity', 'image'])

property observation_space: Dict#
property action_space: Box#
step(action)[source]#
Parameters:

action (ndarray)

Return type:

tuple[dict, float, bool, dict]

reset()[source]#
Return type:

dict

render(*args, **kwargs)[source]#
Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

ndarray
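The name handling described in the feature list can be sketched as follows. Splitting on the first hyphen is an assumption of this illustration, not necessarily what DeepMindControlEnv does. The "cup" alias and the camera defaults (2 for quadruped, 0 otherwise) come from the documentation above.

```python
def parse_dmc_name(name: str) -> tuple[str, str, int]:
    """Split "domain-task" and pick a default camera (illustrative sketch)."""
    domain, task = name.split("-", 1)
    if domain == "cup":
        domain = "ball_in_cup"  # the suite's registered domain name
    camera = 2 if domain == "quadruped" else 0  # defaults stated above
    return domain, task, camera


print(parse_dmc_name("cheetah-run"))     # ('cheetah', 'run', 0)
print(parse_dmc_name("cup-catch"))       # ('ball_in_cup', 'catch', 0)
print(parse_dmc_name("quadruped-walk"))  # ('quadruped', 'walk', 2)
```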

class world_models.envs.BraxImageEnv(env, seed=0, size=(64, 64), backend=None, episode_length=None, auto_reset=False, jit=True, suppress_warp_warnings=True, **env_kwargs)[source]#

Bases: object

Gym-like adapter for training TorchWM world models on Brax tasks.

Brax environments are functional JAX environments: reset consumes a PRNG key and returns a state, while step consumes the previous state plus an action and returns the next state. This adapter stores the Brax state between calls and converts state observations into image observations compatible with pixel-based TorchWM agents such as Dreamer.

If a Brax renderer is not available, vector observations are rendered as deterministic feature-band images so training code can still consume a pixel stream. The original vector observation is also exposed through info["vector_observation"] after step for diagnostics.

Parameters:
  • env (str | Any)

  • seed (int)

  • size (tuple[int, int])

  • backend (str | None)

  • episode_length (int | None)

  • auto_reset (bool)

  • jit (bool)

  • suppress_warp_warnings (bool)

  • env_kwargs (Any)

property observation_space: Space#
property action_space: Space#
property max_episode_steps: int#
reset()[source]#
Return type:

dict[str, Any]

step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[dict[str, Any], float, bool, dict[str, Any]]

render(*args, **kwargs)[source]#
Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

ndarray

close()[source]#
Return type:

None
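The documentation does not specify the band layout, so the function below is a hypothetical illustration of the idea and not BraxImageEnv's renderer. Each state feature is squashed with tanh and painted as a horizontal stripe whose brightness encodes its value. The mapping is deterministic, matching the property described above.

```python
import math


def feature_band_image(vector, height=64, width=64):
    """Hypothetical feature-band rendering of a state vector as a (3, H, W) image."""
    if not vector:
        raise ValueError("vector must contain at least one feature")
    band_height = max(1, height // len(vector))
    channel = []
    for row in range(height):
        # Each feature owns a horizontal band; leftover rows reuse the last feature.
        feature = vector[min(row // band_height, len(vector) - 1)]
        # tanh squashes any real value into [-1, 1]; map that onto 0..255.
        intensity = round((math.tanh(feature) + 1.0) * 127.5)
        channel.append([intensity] * width)
    # Repeat the grayscale channel three times to form RGB.
    return [[list(row) for row in channel] for _ in range(3)]


image = feature_band_image([-100.0, 0.0, 100.0], height=6, width=4)
print([row[0] for row in image[0]])  # [0, 0, 128, 128, 255, 255]
```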

world_models.envs.make_brax_env(env, **kwargs)[source]#

Create a TorchWM image wrapper for Brax environments.

Parameters:
  • env (str | Any) – Brax environment name (for example, "ant") or a pre-built Brax environment object exposing reset(rng) and step(state, action).

  • **kwargs (Any) – Additional keyword arguments passed to BraxImageEnv.

Returns:

A Gym-like wrapper that returns {"image": (C, H, W)} observations and exposes continuous actions in the Brax [-1, 1] range.

Return type:

BraxImageEnv

class world_models.envs.DMLabEnv(level, seed=0, size=(64, 64), action_repeat=4, action_set=None, observations=None, config=None, renderer='hardware', **lab_kwargs)[source]#

Bases: object

Gym-style adapter for DeepMind Lab 3D environments.

The native deepmind_lab API exposes RGB observations as HWC arrays and expects a seven-element integer action vector. This adapter presents a TorchWM-friendly image observation dict and a Box action space whose actions are one-hot vectors in [-1, 1], so it composes with Dreamer’s normalization wrappers.

Parameters:
  • level (str)

  • seed (int)

  • size (tuple[int, int])

  • action_repeat (int)

  • action_set (Sequence[Sequence[int]] | np.ndarray | None)

  • observations (Sequence[str] | None)

  • config (dict[str, Any] | None)

  • renderer (str)

  • lab_kwargs (Any)

property observation_space: Dict#
property action_space: _OneHotActionSpace#
property max_episode_steps: int#
reset()[source]#
Return type:

dict[str, ndarray]

step(action)[source]#
Parameters:

action (ndarray)

Return type:

tuple[dict[str, ndarray], float, bool, dict[str, Any]]

render(*args, **kwargs)[source]#
Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

ndarray

close()[source]#
Return type:

None
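A DeepMind Lab action is a seven-element integer vector (look left/right, look down/up, strafe left/right, move back/forward, fire, jump, crouch). One way to turn a one-hot policy output into such a vector is to select the active row of a discrete action table, as sketched below. The four-row table is a hypothetical example rather than DMLabEnv's default action_set, and treating the largest entry as the active index is an assumption.

```python
# Hypothetical action table; each row is one seven-element DeepMind Lab action:
# (look_left_right, look_down_up, strafe_left_right, move_back_forward, fire, jump, crouch)
EXAMPLE_ACTION_SET = (
    (0, 0, 0, 1, 0, 0, 0),    # move forward
    (0, 0, 0, -1, 0, 0, 0),   # move backward
    (-20, 0, 0, 0, 0, 0, 0),  # look left
    (20, 0, 0, 0, 0, 0, 0),   # look right
)


def one_hot_to_lab_action(one_hot, action_set=EXAMPLE_ACTION_SET):
    """Map a one-hot vector to the matching seven-element Lab action."""
    if len(one_hot) != len(action_set):
        raise ValueError("one-hot length must equal the number of actions")
    index = max(range(len(one_hot)), key=lambda i: one_hot[i])
    return list(action_set[index])


print(one_hot_to_lab_action([0, 0, 1, 0]))  # [-20, 0, 0, 0, 0, 0, 0]
```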

world_models.envs.make_dmlab_env(level, **kwargs)[source]#

Create a DeepMind Lab environment adapter for TorchWM.

Parameters:
  • level (str) – DeepMind Lab level name, for example "rooms_collect_good_objects_train".

  • **kwargs (Any) – Additional keyword arguments passed to DMLabEnv.

Returns:

A Gym-like wrapper returning {"image": (C, H, W)} uint8 observations and normalized one-hot discrete actions.

Return type:

DMLabEnv

class world_models.envs.BSuiteImageEnv(bsuite_id, seed=0, size=(64, 64), env=None)[source]#

Bases: object

Gym-like wrapper for DeepMind BSuite dm_env environments.

BSuite tasks expose compact dm_env observations and mostly discrete actions. This adapter presents a Gym/Gymnasium-style API with image observations under {"image": (C, H, W)} so TorchWM’s pixel-based world models can train and evaluate on BSuite diagnostic tasks without requiring the base environment to implement rendering.

Parameters:
  • bsuite_id (str)

  • seed (int)

  • size (tuple[int, int])

  • env (Any | None)

property observation_space: Dict#
property action_space: Space#
property max_episode_steps: int#
reset(seed=None)[source]#
Parameters:

seed (int | None)

Return type:

dict[str, ndarray]

step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[dict[str, ndarray], float, bool, dict[str, Any]]

render(*args, **kwargs)[source]#
Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

ndarray

close()[source]#
Return type:

None

world_models.envs.make_bsuite_env(bsuite_id, **kwargs)[source]#

Create a Dreamer-compatible image wrapper around a BSuite task.

Parameters:
  • bsuite_id (str)

  • kwargs (Any)

Return type:

BSuiteImageEnv

world_models.envs.list_available_bsuite_ids()[source]#

Return the BSuite sweep IDs from the installed package, or a list of example IDs if BSuite is not installed.

Return type:

list[str]

class world_models.envs.TimeLimit(env, duration)[source]#

Bases: object

Terminate episodes after a fixed number of wrapper steps.

If the wrapped environment does not provide a discount flag at timeout, the wrapper injects a default discount of 1.0 for downstream learners.

Parameters:
  • env (Any)

  • duration (int)

step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[Any, Any, bool, dict[str, Any]]

reset()[source]#
Return type:

Any
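The time-limit rule can be sketched as a small wrapper around the four-tuple (obs, reward, done, info) step API used by these adapters. This is an illustrative class, not world_models.envs.TimeLimit itself. Storing the injected discount in info is an assumption of the sketch.

```python
class ToyTimeLimit:
    """Minimal step-limit wrapper over a four-tuple step API."""

    def __init__(self, env, duration):
        self._env = env
        self._duration = duration
        self._step = None

    def reset(self):
        self._step = 0
        return self._env.reset()

    def step(self, action):
        if self._step is None:
            raise RuntimeError("call reset() before step()")
        obs, reward, done, info = self._env.step(action)
        self._step += 1
        if self._step >= self._duration:
            done = True
            info = dict(info)
            # A timeout is not a true terminal state, so keep bootstrapping.
            info.setdefault("discount", 1.0)
            self._step = None  # force a reset before the next episode
        return obs, reward, done, info


class CounterEnv:
    """Never terminates on its own; the observation counts steps."""

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 1.0, False, {}


env = ToyTimeLimit(CounterEnv(), duration=3)
env.reset()
results = [env.step(None) for _ in range(3)]
print(results[-1])  # (3, 1.0, True, {'discount': 1.0})
```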

class world_models.envs.ActionRepeat(env, amount)[source]#

Bases: object

Repeat each action for a fixed number of environment steps.

Rewards are accumulated and the loop stops early if the environment terminates, mirroring common action-repeat behavior in world model papers.

Parameters:
  • env (Any)

  • amount (int)

step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[Any, float, bool, dict[str, Any]]
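The action-repeat loop described above can be sketched as a plain function. It sums rewards and stops as soon as the wrapped environment reports done. This is an illustration of the behavior, not the ActionRepeat source, and it assumes amount is at least 1.

```python
def repeat_action(env, action, amount):
    """Step env up to `amount` times with one action, summing rewards."""
    total_reward = 0.0
    done = False
    steps = 0
    obs, info = None, {}
    while steps < amount and not done:
        obs, reward, done, info = env.step(action)
        total_reward += reward
        steps += 1
    return obs, total_reward, done, info


class ShortEnv:
    """Terminates after its second step and pays 1.0 reward per step."""

    def __init__(self):
        self.t = 0

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= 2, {}


print(repeat_action(ShortEnv(), action=0, amount=4))  # (2, 2.0, True, {})
```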

class world_models.envs.NormalizeActions(env)[source]#

Bases: object

Expose a normalized [-1, 1] action space for bounded continuous controls.

Incoming normalized actions are mapped back to the wrapped environment action bounds before stepping the environment.

Parameters:

env (Any)

property action_space: Box#
step(action)[source]#
Parameters:

action (ndarray)

Return type:

tuple[Any, Any, bool, dict[str, Any]]
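Mapping a normalized action back to the environment bounds is the per-dimension affine map a_env = low + (a + 1) / 2 · (high − low). A dependency-free sketch, assuming every dimension has finite bounds:

```python
def denormalize_action(action, low, high):
    """Map each component from [-1, 1] onto [low, high] (finite bounds assumed)."""
    return [lo + (a + 1.0) * 0.5 * (hi - lo) for a, lo, hi in zip(action, low, high)]


print(denormalize_action([-1.0, 0.0, 1.0], low=[0.0, 0.0, 0.0], high=[10.0, 10.0, 10.0]))
# [0.0, 5.0, 10.0]
```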

class world_models.envs.ObsDict(env, key='obs')[source]#

Bases: object

Convert scalar/array observations into a dictionary observation format.

This harmonizes outputs for code paths that expect keyed observations (for example {“image”: …} style world model inputs).

Parameters:
  • env (Any)

  • key (str)

property observation_space: Dict#
property action_space: Any#
step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[dict[str, Any], Any, bool, dict[str, Any]]

reset()[source]#
Return type:

dict[str, Any]

class world_models.envs.OneHotAction(env)[source]#

Bases: object

Wrap discrete-action environments to accept one-hot action vectors.

The wrapper validates one-hot inputs and converts them to integer action indices before forwarding to the underlying environment.

Parameters:

env (Any)

property action_space: Box#
step(action)[source]#
Parameters:

action (ndarray)

Return type:

tuple[Any, Any, bool, dict[str, Any]]

reset()[source]#
Return type:

Any
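The validation step can be sketched as follows. Take the largest entry as the candidate index, rebuild the exact one-hot vector for it, and reject the input if the two differ. The tolerance value is an assumption of this sketch.

```python
def one_hot_to_index(action, tol=1e-6):
    """Validate a one-hot vector and return the index of its active entry."""
    index = max(range(len(action)), key=lambda i: action[i])
    reference = [1.0 if i == index else 0.0 for i in range(len(action))]
    if any(abs(a - r) > tol for a, r in zip(action, reference)):
        raise ValueError(f"invalid one-hot action: {action}")
    return index


print(one_hot_to_index([0.0, 0.0, 1.0, 0.0]))  # 2
```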

class world_models.envs.RewardObs(env)[source]#

Bases: object

Augment observations with the latest scalar reward under obs[“reward”].

Useful for agents that consume reward as part of the observation stream during model learning or recurrent policy inference.

Parameters:

env (Any)

property observation_space: Dict#
step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[dict[str, Any], Any, bool, dict[str, Any]]

reset()[source]#
Return type:

dict[str, Any]

class world_models.envs.ResizeImage(env, size=(64, 64))[source]#

Bases: object

Resize image-like observation entries to a target spatial size.

The wrapper discovers image keys from env.obs_space, applies nearest-neighbor resizing, and updates the advertised observation space shapes.

Parameters:
  • env (Any)

  • size (tuple[int, int])

property obs_space: dict[str, Any]#
step(action)[source]#
Parameters:

action (Any)

Return type:

Any

reset()[source]#
Return type:

Any
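Nearest-neighbor resizing copies the value of one input pixel into each output pixel instead of interpolating. The sketch below uses the simplest index mapping (floor of the scaled coordinate) on a single 2-D channel. Image libraries such as PIL may sample pixel centers slightly differently, and this is not ResizeImage's own code.

```python
def resize_nearest(channel, out_height, out_width):
    """Nearest-neighbor resize of a 2-D nested list (one image channel)."""
    in_height, in_width = len(channel), len(channel[0])
    return [
        [channel[r * in_height // out_height][c * in_width // out_width]
         for c in range(out_width)]
        for r in range(out_height)
    ]


print(resize_nearest([[1, 2], [3, 4]], 4, 4))
# [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
```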

class world_models.envs.RenderImage(env, key='image')[source]#

Bases: object

Inject RGB renders from env.render(“rgb_array”) into observations.

This is useful when the base environment returns non-image observations but a rendered camera view is needed for world-model training.

Parameters:
  • env (Any)

  • key (str)

property obs_space: dict[str, Any]#
step(action)[source]#
Parameters:

action (Any)

Return type:

Any

reset()[source]#
Return type:

Any

class world_models.envs.SelectAction(env, key)[source]#

Bases: Wrapper

Gym wrapper for dictionary actions that forwards a selected key only.

This enables integration with policies that emit action dicts while the environment expects a single tensor/array action payload.

Parameters:
  • env (Any)

  • key (str)

step(action)[source]#
Parameters:

action (dict[str, Any])

Return type:

Any

world_models.envs.make_env(env_id, **kwargs)[source]#

Compatibility helper: create an environment by delegating to package-specific factories when available, falling back to gym.make.

It is kept so that older callers that import make_env continue to work.

Parameters:
  • env_id (str)

  • kwargs (Any)

Return type:

Any

class world_models.envs.dmc.DeepMindControlEnv(name, seed, size=(64, 64), camera=None)[source]#

Bases: object

Gym-style adapter for DeepMind Control Suite tasks.

The wrapper exposes DMC observations and actions through Gym spaces and adds a rendered RGB image to each observation dict so image-based world model pipelines can train consistently across backends.

Features:
  • Parses domain-task names (e.g., “cheetah-run” -> domain=“cheetah”, task=“run”)

  • Automatically handles special cases like “cup” -> “ball_in_cup”

  • Renders RGB images at configurable resolution

  • Returns observations as dict with both state vectors and images

Parameters:
  • name (str) – Environment name in format “domain-task” (e.g., “cheetah-run”).

  • seed (int) – Random seed for environment initialization.

  • size (tuple) – Target image size as (height, width) (default: (64, 64)).

  • camera (int, optional) – Camera ID for rendering. Defaults to 0 for most domains, 2 for quadruped.

observation_space#

Dict space with state keys and “image”.

Type:

gym.spaces.Dict

action_space#

Continuous action space from DMC spec.

Type:

gym.spaces.Box

Example

>>> from world_models.envs.dmc import DeepMindControlEnv
>>> env = DeepMindControlEnv("cheetah-run", seed=0, size=(64, 64))
>>> obs = env.reset()
>>> print(obs.keys())  # dict_keys(['position', 'velocity', 'image'])

property observation_space: Dict#
property action_space: Box#
step(action)[source]#
Parameters:

action (ndarray)

Return type:

tuple[dict, float, bool, dict]

reset()[source]#
Return type:

dict

render(*args, **kwargs)[source]#
Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

ndarray

world_models.envs.dmlab.make_dmlab_env(level, **kwargs)[source]#

Create a DeepMind Lab environment adapter for TorchWM.

Parameters:
  • level (str) – DeepMind Lab level name, for example "rooms_collect_good_objects_train".

  • **kwargs (Any) – Additional keyword arguments passed to DMLabEnv.

Returns:

A Gym-like wrapper returning {"image": (C, H, W)} uint8 observations and normalized one-hot discrete actions.

Return type:

DMLabEnv

class world_models.envs.dmlab.DMLabEnv(level, seed=0, size=(64, 64), action_repeat=4, action_set=None, observations=None, config=None, renderer='hardware', **lab_kwargs)[source]#

Bases: object

Gym-style adapter for DeepMind Lab 3D environments.

The native deepmind_lab API exposes RGB observations as HWC arrays and expects a seven-element integer action vector. This adapter presents a TorchWM-friendly image observation dict and a Box action space whose actions are one-hot vectors in [-1, 1], so it composes with Dreamer’s normalization wrappers.

Parameters:
  • level (str)

  • seed (int)

  • size (tuple[int, int])

  • action_repeat (int)

  • action_set (Sequence[Sequence[int]] | np.ndarray | None)

  • observations (Sequence[str] | None)

  • config (dict[str, Any] | None)

  • renderer (str)

  • lab_kwargs (Any)

property observation_space: Dict#
property action_space: _OneHotActionSpace#
property max_episode_steps: int#
reset()[source]#
Return type:

dict[str, ndarray]

step(action)[source]#
Parameters:

action (ndarray)

Return type:

tuple[dict[str, ndarray], float, bool, dict[str, Any]]

render(*args, **kwargs)[source]#
Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

ndarray

close()[source]#
Return type:

None

world_models.envs.gym_env.make_gym_env(env, **kwargs)[source]#

Create a GymImageEnv wrapper for generic Gym/Gymnasium environments.

Parameters:
  • env (Any) – Either a string environment ID (e.g., “Pendulum-v1”) or a pre-built gym environment instance.

  • **kwargs (Any) – Additional keyword arguments passed to GymImageEnv, including:

      ◦ seed (int): Random seed for the environment (default: 0).

      ◦ size (tuple): Target image size as (height, width) (default: (64, 64)).

      ◦ render_mode (str): Render mode for the environment (default: “rgb_array”).

Returns:

A wrapper that always returns image observations in the format {"image": (C, H, W)}, suitable for pixel-based world models.

Return type:

GymImageEnv

class world_models.envs.gym_env.GymImageEnv(env, seed=0, size=(64, 64), render_mode='rgb_array')[source]#

Bases: object

Gym-like environment wrapper that always returns image observations.

This wrapper normalizes diverse environment interfaces to return consistent image-based observations suitable for pixel-based world models like Dreamer.

Features:
  • Supports environment IDs (string) and pre-built environment objects.

  • Synthesizes RGB images from vector observations for pixel-based training.

  • Exposes continuous action spaces mapped to the [-1, 1] range.

  • Converts discrete actions to one-hot vectors.

  • Returns observations as dict {“image”: (C, H, W)} with uint8 values.
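
The observation contract listed above can be sketched as follows. It assumes frames arrive as height × width × channel (HWC) arrays. Each frame is reordered to (C, H, W) and its values are clamped to the uint8 range. hwc_to_chw and to_image_obs are hypothetical helpers written with nested lists so the example has no dependencies. With NumPy, the reorder is frame.transpose(2, 0, 1). Whether GymImageEnv clamps values or simply casts them is an assumption here.

```python
from typing import Dict, List

Frame = List[List[List[int]]]


def hwc_to_chw(frame: Frame) -> Frame:
    """Reorder a nested-list (H, W, C) frame into (C, H, W) layout."""
    height, width, channels = len(frame), len(frame[0]), len(frame[0][0])
    return [
        [[frame[y][x][c] for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


def to_image_obs(frame: Frame) -> Dict[str, Frame]:
    """Wrap an (H, W, C) frame as {"image": (C, H, W)} with values clamped to 0..255."""
    chw = hwc_to_chw(frame)
    clamped = [[[min(255, max(0, int(v))) for v in row] for row in plane] for plane in chw]
    return {"image": clamped}


obs = to_image_obs([[[255, 0, 0], [0, 255, 300]]])  # a 1x2 RGB frame
print(obs["image"])  # [[[255, 0]], [[0, 255]], [[0, 255]]]
```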

Parameters:
  • env (Any) – Either a string environment ID (e.g., “Pendulum-v1”) or a pre-built gym environment instance.

  • seed (int) – Random seed for environment reset (default: 0).

  • size (tuple) – Target image size as (height, width) (default: (64, 64)).

  • render_mode (str) – Render mode for environment (default: “rgb_array”).

observation_space#

Dict space with “image” key containing (C, H, W) Box.

action_space#

Box space with actions in the [-1, 1] range.

max_episode_steps#

Maximum steps per episode (default: 1000).

property observation_space: Dict#
property action_space: Box#
property max_episode_steps: int#
reset()[source]#
Return type:

dict[str, Any]

step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[dict[str, Any], float, bool, dict[str, Any]]

render(*args, **kwargs)[source]#
Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

ndarray

close()[source]#
Return type:

None

world_models.envs.ale_atari_env.make_atari_env(env_id, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, **kwargs)[source]#

Create an Atari environment supported by the Arcade Learning Environment (ALE).

Parameters:
  • env_id (str) – The id of the Atari environment to create.

  • obs_type (str) – The type of observation to return (“rgb” or “ram”).

  • frameskip (int) – The number of frames to skip between actions.

  • repeat_action_probability (float) – The probability of repeating the previous action instead of the requested one (sticky actions).

  • full_action_space (bool) – Whether to use the full 18-action space instead of the game's minimal action set.

  • max_episode_steps (Optional[int]) – Maximum number of steps per episode.

  • **kwargs – Additional keyword arguments for environment configuration.

Returns:

The created Atari environment.

Return type:

gym.Env
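
repeat_action_probability implements "sticky actions": at each step, the emulator may ignore the agent's new action and keep executing the previous one. The sketch below is a simplified, step-level model of that rule. ALE itself applies the rule per emulator frame, inside frameskip. The sketch also assumes the previous action starts as NOOP at index 0. sticky_actions is a hypothetical helper, not part of TorchWM or ALE.

```python
import random
from typing import Iterable, Iterator


def sticky_actions(
    requested: Iterable[int], repeat_action_probability: float, rng: random.Random
) -> Iterator[int]:
    """Yield the action actually executed at each step under sticky actions."""
    previous = 0  # assume NOOP (action index 0) before the first step
    for action in requested:
        # With probability p the emulator ignores the new action and repeats the old one.
        if rng.random() < repeat_action_probability:
            action = previous
        previous = action
        yield action


rng = random.Random(0)
print(list(sticky_actions([1, 2, 3, 4], repeat_action_probability=0.25, rng=rng)))
```

Setting repeat_action_probability=0.0 makes the executed actions identical to the requested ones, which is the usual way to recover deterministic behavior.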

world_models.envs.ale_atari_env.list_available_atari_envs()[source]#

Get a list of all Atari environments available in the Arcade Learning Environment (ALE).

Returns:

List of available Atari environment IDs.

Return type:

list[str]

world_models.envs.ale_atari_vector_env.make_atari_vector_env(game, num_envs, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, seed=None, **kwargs)[source]#

Create vectorized Atari environments using ALE’s native AtariVectorEnv.

Parameters:
  • game (str) – The name of the Atari game (e.g., “pong”, “breakout”).

  • num_envs (int) – Number of parallel environments.

  • obs_type (str) – The type of observation to return (“rgb” or “ram”).

  • frameskip (int) – The number of frames to skip between actions.

  • repeat_action_probability (float) – The probability of repeating the previous action instead of the requested one (sticky actions).

  • full_action_space (bool) – Whether to use the full 18-action space instead of the game's minimal action set.

  • max_episode_steps (Optional[int]) – Maximum number of steps per episode.

  • seed (Optional[int]) – Random seed for reproducibility.

  • **kwargs – Additional keyword arguments for environment configuration.

Returns:

The vectorized Atari environment.

Return type:

AtariVectorEnv

Procgen environment adapter for TorchWM image-based agents.

world_models.envs.procgen_env.list_procgen_envs()[source]#

Return the Procgen game names understood by ProcgenImageEnv.

Return type:

list[str]

world_models.envs.procgen_env.normalize_procgen_env_name(env)[source]#

Normalize Procgen Gym ids and shorthand names to Procgen game names.

Accepted forms include "coinrun", "procgen-coinrun-v0", and "procgen:procgen-coinrun-v0".

Parameters:

env (str)

Return type:

str
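
A plausible sketch of this normalization, covering the three documented input forms. normalize_procgen_name is a hypothetical stand-in. The real normalize_procgen_env_name may also validate names against list_procgen_envs() or accept additional forms.

```python
def normalize_procgen_name(env: str) -> str:
    """Reduce ids such as 'procgen:procgen-coinrun-v0' to the game name 'coinrun'."""
    name = env.strip()
    if ":" in name:  # drop a "module:" prefix as used by gym.make
        name = name.split(":", 1)[1]
    if name.startswith("procgen-"):
        name = name[len("procgen-"):]
    head, sep, tail = name.rpartition("-v")
    if sep and tail.isdigit():  # drop a trailing version suffix such as "-v0"
        name = head
    return name.lower()


for raw in ("coinrun", "procgen-coinrun-v0", "procgen:procgen-coinrun-v0"):
    print(raw, "->", normalize_procgen_name(raw))
```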

world_models.envs.procgen_env.make_procgen_env(env, **kwargs)[source]#

Create a single-environment Procgen adapter.

Parameters:
  • env (str) – Procgen game name or Gym-style id.

  • **kwargs (Any) – Options forwarded to ProcgenImageEnv.

Returns:

TorchWM-compatible image wrapper exposing {"image": (3, H, W) uint8} observations and one-hot-like actions.

Return type:

ProcgenImageEnv

class world_models.envs.procgen_env.ProcgenImageEnv(env, seed=0, size=(64, 64), distribution_mode='easy', num_levels=0, start_level=None, action_n=15, **procgen_kwargs)[source]#

Bases: object

Adapt Procgen’s vector API to TorchWM’s single-env image interface.

The upstream procgen.ProcgenEnv API is vectorized, so this wrapper builds a one-environment vector and unwraps the leading batch dimension. Actions are exposed as a continuous one-hot-like Box[-1, 1] with one element per discrete Procgen action, matching TorchWM’s other discrete image adapters.
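
The two conversions this adapter performs can be sketched against plain lists. The sketch assumes that the batched observation dict uses an "rgb" key, as procgen.ProcgenEnv does to our knowledge. It also assumes that the vectorized step expects a batch of integer action indices. unbatch_obs and to_procgen_action are hypothetical helpers.

```python
from typing import Any, Dict, List, Sequence


def unbatch_obs(batched_obs: Dict[str, Sequence[Any]]) -> Dict[str, Any]:
    """Drop the leading batch dimension (size 1) from every entry of a batched obs dict."""
    single = {}
    for key, value in batched_obs.items():
        if len(value) != 1:
            raise ValueError(f"{key!r}: expected batch size 1, got {len(value)}")
        single[key] = value[0]
    return single


def to_procgen_action(action: Sequence[float], action_n: int = 15) -> List[int]:
    """Turn a one-hot-like vector into a batch holding one discrete action index."""
    if len(action) != action_n:
        raise ValueError(f"expected {action_n} entries, got {len(action)}")
    return [max(range(action_n), key=lambda i: action[i])]


print(unbatch_obs({"rgb": [[[0, 0, 0]]]}))  # {'rgb': [[0, 0, 0]]}
vec = [-1.0] * 15
vec[4] = 1.0
print(to_procgen_action(vec))  # [4]
```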

Parameters:
  • env (str)

  • seed (int)

  • size (tuple[int, int])

  • distribution_mode (str)

  • num_levels (int)

  • start_level (int | None)

  • action_n (int)

  • procgen_kwargs (Any)

property observation_space: Dict#
property action_space: _ProcgenActionSpace#
property max_episode_steps: int#
reset()[source]#
Return type:

dict[str, ndarray[tuple[Any, …], dtype[uint8]]]

step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[dict[str, ndarray[tuple[Any, …], dtype[uint8]]], float, bool, dict[str, Any]]

render(*args, **kwargs)[source]#
Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

ndarray[tuple[Any, …], dtype[uint8]]

close()[source]#
Return type:

None

world_models.envs.mujoco_env.make_mujoco_env_from_config(args, size)[source]#

Build a MuJoCo image environment from a DreamerConfig-like object.

Parameters:
  • args (Any)

  • size (tuple[int, int])

Return type:

Any

class world_models.envs.mujoco_env.MuJoCoImageEnv(xml_path=None, *, xml_string=None, binary_path=None, assets=None, seed=0, size=(64, 64), camera=None, reward_fn=None, terminal_fn=None, frame_skip=1, reset_noise_scale=0.0, default_control_range=(-1.0, 1.0))[source]#

Bases: object

Native MuJoCo environment adapter for pixel-based world-model training.

The adapter uses the low-level mujoco Python package directly: models are compiled from MJCF XML strings/files or MJB binaries via mujoco.MjModel; simulation state lives in mujoco.MjData; actions are written to data.ctrl; and images are produced with mujoco.Renderer. Observations follow TorchWM’s Dreamer-style contract: {"image": uint8[C, H, W]}.

Native MuJoCo models do not define task rewards or episode termination by themselves, so callers can supply reward_fn and terminal_fn callbacks. By default, rewards are 0.0 and episodes terminate only through external wrappers such as TimeLimit.
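
The following sketch shows how frame_skip, reward_fn, and terminal_fn can interact in one environment step. Several parts are assumptions made for illustration: the dict-based State, the callback signatures, and the toy integrator standing in for mujoco.mj_step. The real RewardFn and TerminalFn types may receive different arguments, for example the mujoco.MjModel and mujoco.MjData pair. The real adapter may also evaluate the callbacks at a different point. Only the defaults follow the docstring: a reward of 0.0 and no termination.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

State = Dict[str, List[float]]
# Hypothetical callback types; TorchWM's RewardFn/TerminalFn signatures may differ.
RewardFn = Callable[[State], float]
TerminalFn = Callable[[State], bool]


def toy_physics_step(state: State, ctrl: Sequence[float], dt: float = 0.5) -> None:
    """Stand-in for mujoco.mj_step: each control acts as an acceleration on one joint."""
    for i, u in enumerate(ctrl):
        state["qvel"][i] += u * dt
        state["qpos"][i] += state["qvel"][i] * dt


def step_with_callbacks(
    state: State,
    ctrl: Sequence[float],
    frame_skip: int = 1,
    reward_fn: Optional[RewardFn] = None,
    terminal_fn: Optional[TerminalFn] = None,
) -> Tuple[float, bool]:
    for _ in range(frame_skip):
        toy_physics_step(state, ctrl)
    # Without callbacks: zero reward and no termination (a TimeLimit wrapper ends episodes).
    reward = reward_fn(state) if reward_fn is not None else 0.0
    done = terminal_fn(state) if terminal_fn is not None else False
    return reward, done


state = {"qpos": [0.0], "qvel": [0.0]}
print(step_with_callbacks(state, [1.0], frame_skip=2,
                          reward_fn=lambda s: s["qpos"][0],
                          terminal_fn=lambda s: s["qpos"][0] > 0.5))  # (0.75, True)
```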

Parameters:
  • xml_path (str | Path | None)

  • xml_string (str | None)

  • binary_path (str | Path | None)

  • assets (dict[str, bytes] | None)

  • seed (int)

  • size (tuple[int, int])

  • camera (str | int | None)

  • reward_fn (RewardFn | None)

  • terminal_fn (TerminalFn | None)

  • frame_skip (int)

  • reset_noise_scale (float)

  • default_control_range (tuple[float, float])

property observation_space: Dict#
property action_space: Box#
reset(seed=None)[source]#
Parameters:

seed (int | None)

Return type:

dict[str, ndarray]

step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[dict[str, ndarray], float, bool, dict[str, Any]]

render()[source]#
Return type:

Any

close()[source]#
Return type:

None

world_models.envs.mujoco_env.make_mujoco_env(model=None, *, backend='auto', seed=0, size=(64, 64), render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#

Create a MuJoCo image environment from a Gymnasium task id or an MJCF/MJB model.

Parameters:
  • model (str | Path | None) – Either a Gymnasium MuJoCo task id such as "Humanoid-v4", an MJCF XML path/string, or an MJB binary path.

  • backend (str) – "auto" infers native vs Gymnasium task mode. Use "native" for MJCF/MJB, "gymnasium" for task ids, or "robotics" for Gymnasium Robotics registrations.

  • seed (int) – Seed forwarded to the image wrapper.

  • size (tuple[int, int]) – Target (height, width) image size.

  • render_mode (str) – Render mode used for Gymnasium MuJoCo task ids.

  • gym_kwargs (dict[str, Any] | None) – Optional keyword arguments forwarded to gymnasium.make in task-id mode. Extra **kwargs are also forwarded there.

  • **kwargs (Any) – Native MuJoCoImageEnv options for MJCF/MJB mode, or environment-constructor options for Gymnasium task-id mode.

Returns:

A TorchWM image environment returning {"image": uint8[C, H, W]}.

Return type:

GymImageEnv | MuJoCoImageEnv
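
One way backend="auto" could distinguish native models from task ids is by inspecting the model argument. The heuristic below is a hypothetical sketch; infer_backend is not a TorchWM function. It does not cover the "robotics" backend, which is selected explicitly.

```python
from pathlib import Path
from typing import Optional, Union


def infer_backend(model: Optional[Union[str, Path]]) -> str:
    """Guess 'native' (MJCF/MJB) vs 'gymnasium' (task id) the way backend='auto' might."""
    if model is None or isinstance(model, Path):
        return "native"  # assumption: None means the model arrives via xml_string=
    text = model.strip()
    if text.startswith("<"):  # inline MJCF XML such as "<mujoco>...</mujoco>"
        return "native"
    if text.lower().endswith((".xml", ".mjb")):
        return "native"
    return "gymnasium"  # e.g. "Humanoid-v4"


print(infer_backend("Humanoid-v4"), infer_backend("models/robot.xml"))  # gymnasium native
```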

world_models.envs.robotics_env.is_moved_mujoco_error(exc)[source]#

Return whether exc is Gymnasium's error reporting that the MuJoCo v2/v3 environments moved to gymnasium-robotics.

Parameters:

exc (BaseException)

Return type:

bool

world_models.envs.robotics_env.register_gymnasium_robotics_envs()[source]#

Import Gymnasium Robotics so its environments are registered with Gymnasium.

Gymnasium moved legacy MuJoCo v2/v3 task registrations into the external gymnasium-robotics package. Current Gymnasium Robotics versions register environments during import, while older plugin-style installations may rely on gymnasium.register_envs; this helper supports both paths.

Return type:

Any

world_models.envs.robotics_env.list_gymnasium_robotics_envs()[source]#

List all Gymnasium Robotics ids registered by the installed package.

Returns an empty list when the optional dependency is not installed. When it is installed, the list is derived from Gymnasium’s registry rather than a hand-maintained subset, so newly added Robotics environments are exposed automatically.

Return type:

list[str]

world_models.envs.robotics_env.make_gymnasium_env_with_robotics_fallback(env, *, render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#

Create a Gymnasium env and retry after Robotics registration if needed.

Parameters:
  • env (str)

  • render_mode (str)

  • gym_kwargs (dict[str, Any] | None)

  • kwargs (Any)

Return type:

Any
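
The retry-after-registration pattern behind this helper can be sketched with injected callables. Several parts are illustrative assumptions: make_with_registration_fallback, the toy registry, and the use of KeyError as the "not registered" signal. The real helper would instead catch Gymnasium's own registration errors and register the environments by importing gymnasium-robotics (see register_gymnasium_robotics_envs).

```python
from typing import Any, Callable, Tuple, Type


def make_with_registration_fallback(
    make: Callable[..., Any],
    register: Callable[[], Any],
    env_id: str,
    retry_on: Tuple[Type[BaseException], ...] = (KeyError,),
    **kwargs: Any,
) -> Any:
    """Call make(env_id); if it fails as 'not registered', register and retry once."""
    try:
        return make(env_id, **kwargs)
    except retry_on:
        register()  # e.g. import gymnasium_robotics so its ids enter the registry
        return make(env_id, **kwargs)  # a second failure propagates to the caller


# Toy registry standing in for Gymnasium's.
registry = {}


def toy_make(env_id, **kwargs):
    return (registry[env_id], kwargs)  # raises KeyError while the id is unregistered


def toy_register():
    registry["ToyRobot-v0"] = "toy-robot"


print(make_with_registration_fallback(toy_make, toy_register, "ToyRobot-v0",
                                      render_mode="rgb_array"))
# ('toy-robot', {'render_mode': 'rgb_array'})
```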

world_models.envs.robotics_env.make_robotics_env(env, *, seed=0, size=(64, 64), render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#

Create a TorchWM image wrapper for a Gymnasium Robotics environment.

Parameters:
  • env (str) – Any environment id registered by gymnasium-robotics.

  • seed (int) – Seed forwarded to GymImageEnv.

  • size (tuple[int, int]) – Target (height, width) image size.

  • render_mode (str) – Render mode forwarded to gymnasium.make.

  • gym_kwargs (dict[str, Any] | None) – Optional keyword arguments forwarded to gymnasium.make.

  • **kwargs (Any) – Additional keyword arguments forwarded to gymnasium.make.

Returns:

A GymImageEnv that emits {"image": uint8[C, H, W]} observations.

Return type:

GymImageEnv

world_models.envs.unity_env.make_unity_mlagents_env(env_id=None, **kwargs)[source]#

Create a Unity ML-Agents environment wrapper.

Factory function that instantiates a UnityMLAgentsEnv with the provided keyword arguments. Suitable for integrating Unity-based environments with Dreamer-style world model pipelines.

Parameters:
  • env_id (str | None)

  • **kwargs (Any) – Keyword arguments passed to UnityMLAgentsEnv, including:
      ◦ file_name (str): Path to the Unity environment binary.
      ◦ behavior_name (str, optional): Name of the behavior to use.
      ◦ seed (int): Random seed (default: 0).
      ◦ size (tuple): Image size as (height, width) (default: (64, 64)).
      ◦ worker_id (int): Worker ID for multi-environment setup (default: 0).
      ◦ base_port (int): Base port for communication (default: 5005).
      ◦ no_graphics (bool): Disable graphics rendering (default: True).
      ◦ time_scale (float): Simulation time scale (default: 20.0).
      ◦ quality_level (int): Graphics quality level (default: 1).
      ◦ max_episode_steps (int): Max steps per episode (default: 1000).

Returns:

A Gym-compatible wrapper for Unity environments.

Return type:

UnityMLAgentsEnv

class world_models.envs.unity_env.UnityMLAgentsEnv(file_name, behavior_name=None, seed=0, size=(64, 64), worker_id=0, base_port=5005, no_graphics=True, time_scale=20.0, quality_level=1, max_episode_steps=1000)[source]#

Bases: object

Gym-like wrapper for Unity ML-Agents environments.

Provides a unified interface for Unity-based environments, converting observations to image format compatible with pixel-based world models.

Features:
  • Supports single-agent control with continuous action spaces.

  • Returns observations as {“image”: (C, H, W)} with uint8 values.

  • Normalizes actions to the [-1, 1] range.

  • Includes rendered frames in observations for visual policies.

Parameters:
  • file_name (str) – Path to the Unity environment binary.

  • behavior_name (str, optional) – Name of the behavior to use. If None, uses the first available behavior.

  • seed (int) – Random seed for environment (default: 0).

  • size (tuple[int, int]) – Target image size as (height, width) (default: (64, 64)).

  • worker_id (int) – Worker ID for multi-environment setup (default: 0).

  • base_port (int) – Base port for Unity environment communication (default: 5005).

  • no_graphics (bool) – Disable graphics rendering for faster simulation (default: True).

  • time_scale (float) – Simulation time scale multiplier (default: 20.0).

  • quality_level (int) – Graphics quality level 0-5 (default: 1).

  • max_episode_steps (int) – Maximum steps per episode (default: 1000).

observation_space#

Dict space with “image” key containing (3, H, W) Box.

action_space#

Box space with actions in the [-1, 1] range.

max_episode_steps#

Maximum steps per episode.

Raises:
  • ValueError – If no behaviors found or action space is not continuous.

  • RuntimeError – If no agents available after reset.

property observation_space: Dict#
property action_space: Box#
property max_episode_steps: int#
reset()[source]#
Return type:

dict[str, Any]

step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[dict[str, Any], float, bool, dict[str, Any]]

render(*args, **kwargs)[source]#
Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

close()[source]#
Return type:

None
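
When several UnityMLAgentsEnv instances run side by side, each one needs its own communication port. To our understanding, ML-Agents' UnityEnvironment derives its port as base_port + worker_id. Assuming the wrapper forwards both values unchanged, the sketch below lists the ports that a batch of environments would occupy. unity_ports is a hypothetical helper.

```python
from typing import List


def unity_ports(num_envs: int, base_port: int = 5005, first_worker_id: int = 0) -> List[int]:
    """Port each environment would use, assuming port = base_port + worker_id."""
    return [base_port + first_worker_id + i for i in range(num_envs)]


print(unity_ports(4))  # [5005, 5006, 5007, 5008]
```

Giving each process a distinct worker_id is therefore enough to avoid port collisions, as long as no other service holds those ports.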

class world_models.envs.vector_env.SimWorker(worker_id, env_factory, num_envs, command_queue, result_queue, seed=None)[source]#

Bases: Process

Worker process that manages a batch of environment instances. Handles batched stepping for parallel rollouts.

Parameters:
  • worker_id (int)

  • env_factory (Callable)

  • num_envs (int)

  • command_queue (Queue)

  • result_queue (Queue)

  • seed (Optional[int])

run()[source]#

Main worker loop.

Return type:

None

class world_models.envs.vector_env.VectorizedEnv(env_factory, num_workers=2, envs_per_worker=4, seed=None)[source]#

Bases: ABC

Abstract base class for vectorized environments. Manages multiple worker processes for parallel simulation.

Parameters:
  • env_factory (Callable)

  • num_workers (int)

  • envs_per_worker (int)

  • seed (Optional[int])

abstractmethod step_batch(actions)[source]#

Step all environments with batched actions.

Parameters:

actions (Tensor)

Return type:

Dict[str, Any]

abstractmethod reset_batch()[source]#

Reset all environments.

Return type:

Dict[str, Any]

render_batch()[source]#

Render all environments.

Return type:

List[ndarray]

close()[source]#

Shutdown all workers.

Return type:

None
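
The batching performed by VectorizedEnv and TorchVectorizedEnv can be sketched in two steps. First, the (total_envs, ...) action batch is sliced into one contiguous chunk per worker. Second, the worker results are concatenated in worker order. This is a simplified illustration with plain lists. The real classes exchange data through multiprocessing queues and return tensors. split_actions and merge_results are hypothetical names.

```python
from typing import Any, Dict, List, Sequence


def split_actions(actions: Sequence[Any], num_workers: int, envs_per_worker: int) -> List[List[Any]]:
    """Slice a (total_envs, ...) batch into one contiguous chunk per worker."""
    total = num_workers * envs_per_worker
    if len(actions) != total:
        raise ValueError(f"expected {total} actions, got {len(actions)}")
    return [list(actions[w * envs_per_worker:(w + 1) * envs_per_worker]) for w in range(num_workers)]


def merge_results(worker_results: Sequence[Dict[str, List[Any]]]) -> Dict[str, List[Any]]:
    """Concatenate per-worker results back into one batch, preserving worker order."""
    merged: Dict[str, List[Any]] = {"obs": [], "reward": [], "done": []}
    for result in worker_results:
        for key in merged:
            merged[key].extend(result[key])
    return merged


print(split_actions(list(range(8)), num_workers=2, envs_per_worker=4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```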

class world_models.envs.vector_env.TorchVectorizedEnv(*args, **kwargs)[source]#

Bases: VectorizedEnv

TorchWM-compatible vectorized environment. Returns batched tensors suitable for PyTorch training.

Parameters:
  • args (Any)

  • kwargs (Any)

step_batch(actions)[source]#

Step all environments with batched actions.

Parameters:

actions (Tensor) – Tensor of shape (total_envs, action_dim)

Returns:

Dict with ‘obs’, ‘reward’, ‘done’, ‘info’ tensors

Return type:

Dict[str, Any]

reset_batch()[source]#

Reset all environments and return initial observations.

Return type:

Dict[str, Any]

class world_models.envs.wrappers.TimeLimit(env, duration)[source]#

Bases: object

Terminate episodes after a fixed number of wrapper steps.

If the wrapped environment does not provide a discount flag at timeout, the wrapper injects a default discount of 1.0 for downstream learners.

Parameters:
  • env (Any)

  • duration (int)

step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[Any, Any, bool, dict[str, Any]]

reset()[source]#
Return type:

Any
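
A minimal restatement of the documented TimeLimit behavior for environments with a 4-tuple step API. It assumes the discount is carried in info["discount"], as in several Dreamer reference wrappers. TimeLimitSketch and CountingEnv are hypothetical names, not TorchWM classes.

```python
from typing import Any, Dict, Optional, Tuple


class TimeLimitSketch:
    """End episodes after `duration` steps, marking the timeout as non-terminal."""

    def __init__(self, env: Any, duration: int) -> None:
        self._env = env
        self._duration = duration
        self._step: Optional[int] = None

    def reset(self) -> Any:
        self._step = 0
        return self._env.reset()

    def step(self, action: Any) -> Tuple[Any, Any, bool, Dict[str, Any]]:
        if self._step is None:
            raise RuntimeError("call reset() before step()")
        obs, reward, done, info = self._env.step(action)
        self._step += 1
        if self._step >= self._duration:
            done = True
            info = dict(info)
            info.setdefault("discount", 1.0)  # a timeout is not a true terminal state
            self._step = None
        return obs, reward, done, info


class CountingEnv:
    """Toy env that never terminates on its own."""

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 1.0, False, {}


env = TimeLimitSketch(CountingEnv(), duration=3)
env.reset()
print([env.step(None)[2] for _ in range(3)])  # [False, False, True]
```

Keeping the discount at 1.0 lets a learner continue to bootstrap value estimates past an artificial cutoff instead of treating the cutoff as a real episode end.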

class world_models.envs.wrappers.ActionRepeat(env, amount)[source]#

Bases: object

Repeat each action for a fixed number of environment steps.

Rewards are accumulated and the loop stops early if the environment terminates, mirroring common action-repeat behavior in world model papers.

Parameters:
  • env (Any)

  • amount (int)

step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[Any, float, bool, dict[str, Any]]
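
The action-repeat loop described above can be sketched as follows. ActionRepeatSketch and the toy ShortEpisodeEnv are hypothetical. The real wrapper may differ in details, such as how it handles info from intermediate steps.

```python
from typing import Any, Dict, Tuple


class ActionRepeatSketch:
    """Repeat one action up to `amount` times, summing rewards and stopping at termination."""

    def __init__(self, env: Any, amount: int) -> None:
        self._env = env
        self._amount = amount

    def step(self, action: Any) -> Tuple[Any, float, bool, Dict[str, Any]]:
        obs, info = None, {}
        total_reward, done, steps = 0.0, False, 0
        while steps < self._amount and not done:
            obs, reward, done, info = self._env.step(action)
            total_reward += reward
            steps += 1
        return obs, total_reward, done, info


class ShortEpisodeEnv:
    """Toy env paying reward 1.0 per step and terminating after 3 steps."""

    def __init__(self) -> None:
        self.calls = 0

    def step(self, action):
        self.calls += 1
        return self.calls, 1.0, self.calls >= 3, {}


base = ShortEpisodeEnv()
print(ActionRepeatSketch(base, amount=4).step(0))  # (3, 3.0, True, {})
```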

class world_models.envs.wrappers.NormalizeActions(env)[source]#

Bases: object

Expose a normalized [-1, 1] action space for bounded continuous controls.

Incoming normalized actions are mapped back to the wrapped environment action bounds before stepping the environment.

Parameters:

env (Any)

property action_space: Box#
step(action)[source]#
Parameters:

action (ndarray)

Return type:

tuple[Any, Any, bool, dict[str, Any]]
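
Mapping normalized actions back to the environment bounds is an affine rescale, a' = low + (a + 1) / 2 * (high - low). The sketch below applies it per dimension. Dimensions with infinite bounds are passed through unchanged, which is how the Dreamer reference wrapper behaves. Whether TorchWM's NormalizeActions does the same is an assumption. denormalize_action is a hypothetical helper.

```python
import math
from typing import List, Sequence


def denormalize_action(action: Sequence[float], low: Sequence[float], high: Sequence[float]) -> List[float]:
    """Map each entry from [-1, 1] to [low, high]; dims with infinite bounds pass through."""
    out = []
    for a, lo, hi in zip(action, low, high):
        if math.isfinite(lo) and math.isfinite(hi):
            out.append(lo + (a + 1.0) / 2.0 * (hi - lo))
        else:
            out.append(a)  # unbounded dimension: no sensible rescaling
    return out


print(denormalize_action([-1.0, 0.0, 1.0], low=[0.0, -2.0, 0.0], high=[10.0, 2.0, 1.0]))
# [0.0, 0.0, 1.0]
```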

class world_models.envs.wrappers.ObsDict(env, key='obs')[source]#

Bases: object

Convert scalar/array observations into a dictionary observation format.

This harmonizes outputs for code paths that expect keyed observations (for example {“image”: …} style world model inputs).

Parameters:
  • env (Any)

  • key (str)

property observation_space: Dict#
property action_space: Any#
step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[dict[str, Any], Any, bool, dict[str, Any]]

reset()[source]#
Return type:

dict[str, Any]

class world_models.envs.wrappers.OneHotAction(env)[source]#

Bases: object

Wrap discrete-action environments to accept one-hot action vectors.

The wrapper validates one-hot inputs and converts them to integer action indices before forwarding to the underlying environment.

Parameters:

env (Any)

property action_space: Box#
step(action)[source]#
Parameters:

action (ndarray)

Return type:

tuple[Any, Any, bool, dict[str, Any]]

reset()[source]#
Return type:

Any
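
A minimal sketch of the validation step. It takes the argmax, rebuilds the exact one-hot vector, and rejects inputs that differ from it. one_hot_to_index and its atol tolerance are assumptions for illustration.

```python
from typing import Sequence


def one_hot_to_index(action: Sequence[float], atol: float = 1e-6) -> int:
    """Validate a one-hot vector and return the index of its active entry."""
    if len(action) == 0:
        raise ValueError("empty action")
    index = max(range(len(action)), key=lambda i: action[i])
    reference = [1.0 if i == index else 0.0 for i in range(len(action))]
    if any(abs(a - r) > atol for a, r in zip(action, reference)):
        raise ValueError(f"Invalid one-hot action: {list(action)}")
    return index


print(one_hot_to_index([0.0, 0.0, 1.0, 0.0]))  # 2
```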

class world_models.envs.wrappers.RewardObs(env)[source]#

Bases: object

Augment observations with the latest scalar reward under obs[“reward”].

Useful for agents that consume reward as part of the observation stream during model learning or recurrent policy inference.

Parameters:

env (Any)

property observation_space: Dict#
step(action)[source]#
Parameters:

action (Any)

Return type:

tuple[dict[str, Any], Any, bool, dict[str, Any]]

reset()[source]#
Return type:

dict[str, Any]

class world_models.envs.wrappers.ResizeImage(env, size=(64, 64))[source]#

Bases: object

Resize image-like observation entries to a target spatial size.

The wrapper discovers image keys from env.obs_space, applies nearest-neighbor resizing, and updates the advertised observation space shapes.

Parameters:
  • env (Any)

  • size (tuple[int, int])

property obs_space: dict[str, Any]#
step(action)[source]#
Parameters:

action (Any)

Return type:

Any

reset()[source]#
Return type:

Any
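
Nearest-neighbor resizing sets each output pixel to the input pixel whose scaled index lands on it. The sketch below shows the index arithmetic on nested lists, assuming size is (height, width). The real wrapper may delegate to an imaging library, whose rounding rule can differ by one pixel at some positions. resize_nearest is a hypothetical helper.

```python
from typing import Any, List, Sequence, Tuple


def resize_nearest(image: Sequence[Sequence[Any]], size: Tuple[int, int]) -> List[List[Any]]:
    """Nearest-neighbor resize of an (H, W[, C]) nested list to size=(height, width)."""
    in_h, in_w = len(image), len(image[0])
    out_h, out_w = size
    # Each output pixel copies the input pixel whose index floor-maps onto it.
    return [
        [image[y * in_h // out_h][x * in_w // out_w] for x in range(out_w)]
        for y in range(out_h)
    ]


print(resize_nearest([[1, 2], [3, 4]], (4, 4)))
# [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
```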

class world_models.envs.wrappers.RenderImage(env, key='image')[source]#

Bases: object

Inject RGB renders from env.render(“rgb_array”) into observations.

This is useful when the base environment returns non-image observations but a rendered camera view is needed for world-model training.

Parameters:
  • env (Any)

  • key (str)

property obs_space: dict[str, Any]#
step(action)[source]#
Parameters:

action (Any)

Return type:

Any

reset()[source]#
Return type:

Any

class world_models.envs.wrappers.UUID(env)[source]#

Bases: Wrapper

Gym wrapper that tracks a unique run identifier per environment reset.

The ID combines timestamp and UUID and can be used to tag episodes or artifacts generated during data collection.

Parameters:

env (Any)

reset(**kwargs)[source]#
Parameters:

kwargs (Any)

Return type:

Any
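
A sketch of an identifier that combines a timestamp with a random UUID. The exact format is an assumption. This one mirrors the YYYYMMDDTHHMMSS-<hex> layout used by some Dreamer reference implementations. make_run_id is a hypothetical helper.

```python
import datetime
import uuid
from typing import Optional


def make_run_id(now: Optional[datetime.datetime] = None) -> str:
    """Build an ID like '20240102T030405-<32 hex chars>' from a timestamp and a random UUID."""
    if now is None:
        now = datetime.datetime.now()
    return f"{now:%Y%m%dT%H%M%S}-{uuid.uuid4().hex}"


print(make_run_id(datetime.datetime(2024, 1, 2, 3, 4, 5))[:16])  # 20240102T030405-
```

The timestamp prefix makes IDs sort chronologically, and the UUID suffix keeps them unique when several environments reset within the same second.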

class world_models.envs.wrappers.SelectAction(env, key)[source]#

Bases: Wrapper

Gym wrapper for dictionary actions that forwards a selected key only.

This enables integration with policies that emit action dicts while the environment expects a single tensor/array action payload.

Parameters:
  • env (Any)

  • key (str)

step(action)[source]#
Parameters:

action (dict[str, Any])

Return type:

Any

Atari preprocessing helpers#

These helpers wrap Atari environments for specific training recipes. They are not separate environment families.

class world_models.envs.diamond_atari.DiamondAtariWrapper(env, frameskip=4, max_noop=30, terminate_on_life_loss=True, reward_clip=True, resize=(64, 64))[source]#

Bases: Wrapper

Atari wrapper for DIAMOND following the paper specifications:
  • frameskip: number of frames to skip (default 4)

  • max_noop: maximum number of noop actions at reset (default 30)

  • terminate_on_life_loss: terminate the episode when a life is lost (default True)

  • reward_clip: clip rewards to [-1, 0, 1] (default True)

  • resize: resize observations to the specified size (default 64x64)

Parameters:
  • env (Env)

  • frameskip (int)

  • max_noop (int)

  • terminate_on_life_loss (bool)

  • reward_clip (bool)

  • resize (Tuple[int, int] | None)

step(action)[source]#

Step the environment.

For backwards compatibility with older gym APIs, this wrapper returns a 4-tuple: (obs, reward, done, info). Internally it supports gymnasium's 5-tuple and collapses (terminated, truncated) into a single done bool.

Parameters:

action (int)

Return type:

Any

reset(**kwargs)[source]#
Parameters:

kwargs (Any)

Return type:

Tuple[Any, Dict[str, Any]]
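
This entry mentions three behaviors: collapsing Gymnasium's 5-tuple, clipping rewards, and detecting life loss. Each can be sketched as a small pure function. The sketch makes two assumptions. First, "clip rewards to [-1, 0, 1]" is read as sign clipping. Second, life loss is detected from the "lives" counter that ALE environments report in info. The helper names are hypothetical.

```python
from typing import Any, Dict, Optional, Tuple


def collapse_step(result: Tuple[Any, ...]) -> Tuple[Any, float, bool, Dict[str, Any]]:
    """Accept a Gymnasium 5-tuple or an old Gym 4-tuple and return (obs, reward, done, info)."""
    if len(result) == 5:
        obs, reward, terminated, truncated, info = result
        done = bool(terminated or truncated)
    else:
        obs, reward, done, info = result
    return obs, float(reward), bool(done), info


def clip_reward(reward: float) -> float:
    """Sign clipping: positive -> 1.0, negative -> -1.0, zero -> 0.0."""
    return float((reward > 0) - (reward < 0))


def lost_life(previous_lives: Optional[int], info: Dict[str, Any]) -> bool:
    """True when the info['lives'] counter dropped since the previous step."""
    lives = info.get("lives")
    return previous_lives is not None and lives is not None and lives < previous_lives


obs, reward, done, info = collapse_step(("frame", 7.0, False, True, {"lives": 2}))
print(clip_reward(reward), done, lost_life(3, info))  # 1.0 True True
```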

world_models.envs.diamond_atari.make_diamond_atari_env(game, frameskip=4, max_noop=30, terminate_on_life_loss=True, reward_clip=True, resize=(64, 64), seed=None)[source]#

Create a DIAMOND-compatible Atari environment.

Parameters:
  • game (str) – Atari game name (e.g., “Breakout-v5”)

  • frameskip (int) – Number of frames to skip between actions

  • max_noop (int) – Maximum number of noop actions at reset

  • terminate_on_life_loss (bool) – Whether to terminate on life loss

  • reward_clip (bool) – Whether to clip rewards to [-1, 0, 1]

  • resize (Tuple[int, int]) – Target size for observations

  • seed (int | None) – Random seed

Returns:

Configured Atari environment

Return type:

DiamondAtariWrapper

Datasets and transforms#

Data generation and dataset classes for World Models.

This module provides utilities for generating rollout data from environments and PyTorch dataset classes for loading observation sequences.

class world_models.datasets.wm_dataset.RolloutDataset(root, transform, train=True, buffer_size=1000, num_test_files=600)[source]#

Bases: Dataset

PyTorch Dataset for loading rollout data.

This dataset loads pre-collected rollout trajectories from disk, providing a buffer-based mechanism for efficient data loading. It supports train/test splits and custom transforms.

Parameters:
  • root (str)

  • transform (Compose)

  • train (bool)

  • buffer_size (int)

  • num_test_files (int)

root#

Root directory containing rollout .npz files.

transform#

Albumentations transform to apply to observations.

train#

If True, use training split; otherwise use test split.

buffer_size#

Maximum number of files to keep in memory.

num_test_files#

Number of files to use for test set.

Example

>>> transform = transforms.Compose([transforms.ToTensor()])
>>> dataset = RolloutDataset(
...     root='data/carracing',
...     transform=transform,
...     train=True,
...     buffer_size=100,
... )
>>> obs, action, reward, terminal = dataset[0]
load_next_buffer()[source]#

Load the next batch of rollout files into memory.

This method implements a circular buffer, loading buffer_size files at a time and advancing through the dataset sequentially.

Return type:

None
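
A minimal sketch of the buffer rotation described above. It also assumes a train/test split in which the last num_test_files files form the test set. split_files, buffer_batches, and the rollout_*.npz names are hypothetical. Only the idea of loading buffer_size files at a time and wrapping around comes from the docstring.

```python
import itertools
from typing import Iterator, List, Sequence


def split_files(files: Sequence[str], train: bool, num_test_files: int) -> List[str]:
    """Assumed split: the last num_test_files files form the test set."""
    cut = max(0, len(files) - num_test_files)
    return list(files[:cut]) if train else list(files[cut:])


def buffer_batches(files: Sequence[str], buffer_size: int) -> Iterator[List[str]]:
    """Yield consecutive groups of up to buffer_size files, wrapping to the start."""
    if not files or buffer_size <= 0:
        return
    start = 0
    while True:
        end = start + buffer_size
        yield list(files[start:end])
        start = end if end < len(files) else 0


files = [f"rollout_{i}.npz" for i in range(5)]
print(list(itertools.islice(buffer_batches(files, 2), 4)))
# [['rollout_0.npz', 'rollout_1.npz'], ['rollout_2.npz', 'rollout_3.npz'], ['rollout_4.npz'], ['rollout_0.npz', 'rollout_1.npz']]
```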

class world_models.datasets.wm_dataset.ObservationDataset(root, transform, train=True, buffer_size=1000, num_test_files=600)[source]#

Bases: RolloutDataset

Dataset for single observation samples (not sequences).

This dataset extends RolloutDataset to provide individual observations rather than sequences, suitable for VAE training.

Example

>>> dataset = ObservationDataset(
...     root='data/carracing',
...     transform=transform,
...     train=True,
... )
>>> obs = dataset[0]
Parameters:
  • root (str)

  • transform (Compose)

  • train (bool)

  • buffer_size (int)

  • num_test_files (int)

class world_models.datasets.wm_dataset.SequenceDataset(root, transform, train, buffer_size, num_test_files, seq_len)[source]#

Bases: RolloutDataset

Dataset for sequential rollout data.

This dataset provides sequences of observations, actions, rewards, and terminal flags suitable for training recurrent models like MDRNN.

Parameters:
  • root (str)

  • transform (Compose)

  • train (bool)

  • buffer_size (int)

  • num_test_files (int)

  • seq_len (int)

seq_len#

Length of sequences to return.

Example

>>> dataset = SequenceDataset(
...     root='data/carracing',
...     transform=transform,
...     train=True,
...     buffer_size=100,
...     num_test_files=600,
...     seq_len=32,
... )
>>> obs, action, reward, terminal, next_obs = dataset[0]
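
Each SequenceDataset item pairs seq_len observations with the seq_len observations that follow them, so a window needs seq_len + 1 frames. The sketch below cuts one such window from an episode stored as parallel sequences. The key names are illustrative. The sketch also assumes that actions[t] is the action taken at observations[t]; the actual alignment depends on how the rollouts were recorded. sequence_window is a hypothetical helper.

```python
from typing import Dict, Sequence, Tuple


def sequence_window(episode: Dict[str, Sequence], start: int, seq_len: int) -> Tuple[Sequence, ...]:
    """Cut one training window: seq_len steps plus the observation that follows them."""
    end = start + seq_len
    if start < 0 or end + 1 > len(episode["observations"]):
        raise ValueError("window runs past the end of the episode")
    obs_window = episode["observations"][start:end + 1]  # seq_len + 1 frames
    return (
        obs_window[:-1],                    # obs at t
        episode["actions"][start:end],      # action taken at t (assumed alignment)
        episode["rewards"][start:end],
        episode["terminals"][start:end],
        obs_window[1:],                     # next_obs at t + 1
    )


episode = {key: list(range(10)) for key in ("observations", "actions", "rewards", "terminals")}
obs, action, reward, terminal, next_obs = sequence_window(episode, start=2, seq_len=3)
print(obs, next_obs)  # [2, 3, 4] [3, 4, 5]
```
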
class world_models.datasets.wm_dataset.LatentSequenceDataset(latents_arr, actions, rewards, terminals, train, buffer_size, num_test_files, seq_len)[source]#

Bases: Dataset

Dataset for pre-computed latent sequences.

This dataset uses pre-encoded latent representations instead of raw images, which significantly reduces memory usage during RNN training.

Parameters:
  • latents_arr (ndarray)

  • actions (ndarray)

  • rewards (ndarray)

  • terminals (ndarray)

  • train (bool)

  • buffer_size (int)

  • num_test_files (int)

  • seq_len (int)

class world_models.datasets.video_datasets.DatasetConfig(num_frames=16, image_size=64, batch_size=4, num_workers=4, pin_memory=True, shuffle=True)[source]#

Bases: object

Base configuration for datasets.

Parameters:
  • num_frames (int)

  • image_size (int)

  • batch_size (int)

  • num_workers (int)

  • pin_memory (bool)

  • shuffle (bool)

num_frames: int = 16#
image_size: int = 64#
batch_size: int = 4#
num_workers: int = 4#
pin_memory: bool = True#
shuffle: bool = True#
class world_models.datasets.video_datasets.VideoDatasetBase(data_source, num_frames=16, image_size=64, transform=None, normalize=True)[source]#

Bases: Dataset

Base class for video datasets.

All video datasets should inherit from this class and implement the _load_video method.

Parameters:
  • data_source (str | Path | List[str] | List[Path])

  • num_frames (int)

  • image_size (int)

  • transform (Callable | None)

  • normalize (bool)

data_source: str | Path | List[str] | List[Path]#
video_paths: Sequence[Path | int]#
class world_models.datasets.video_datasets.VideoFolderDataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, extensions=('.mp4', '.avi', '.mkv', '.webm', '.mov'), recursive=True)[source]#

Bases: VideoDatasetBase

Dataset that loads videos from a folder.

Supports common video formats: .mp4, .avi, .mkv, .webm, .mov

Usage:

from world_models.datasets.video_datasets import VideoFolderDataset
dataset = VideoFolderDataset(
    data_source="/path/to/videos",
    num_frames=16,
    image_size=64
)
Parameters:
  • data_source (str | Path | List[str] | List[Path])

  • num_frames (int)

  • image_size (int)

  • transform (Callable | None)

  • normalize (bool)

  • extensions (Tuple[str, ...])

  • recursive (bool)

class world_models.datasets.video_datasets.ImageFolderDataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, extensions=('.jpg', '.jpeg', '.png', '.bmp'), image_sort_key=None)[source]#

Bases: VideoDatasetBase

Dataset that loads image sequences from folders.

Each subfolder is treated as a video sequence.

Usage:

from world_models.datasets.video_datasets import ImageFolderDataset
dataset = ImageFolderDataset(
    data_source="/path/to/images",
    num_frames=16,
    image_size=64
)
Parameters:
  • data_source (str | Path | List[str] | List[Path])

  • num_frames (int)

  • image_size (int)

  • transform (Callable | None)

  • normalize (bool)

  • extensions (Tuple[str, ...])

  • image_sort_key (Callable | None)
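
image_sort_key controls frame order inside each subfolder. Plain string sorting puts frame10.png before frame2.png. A "natural" key that compares digit runs numerically is therefore a common choice. natural_sort_key is a hypothetical helper. Passing it as image_sort_key assumes that the key is called with a file path or name.

```python
import re
from pathlib import Path
from typing import List, Union


def natural_sort_key(path: Union[str, Path]) -> List[Union[int, str]]:
    """Split a file name into text and integer runs so 'frame2' sorts before 'frame10'."""
    name = Path(path).name
    return [int(part) if part.isdigit() else part.lower() for part in re.split(r"(\d+)", name)]


frames = ["frame10.png", "frame2.png", "frame1.png"]
print(sorted(frames))                        # ['frame1.png', 'frame10.png', 'frame2.png']
print(sorted(frames, key=natural_sort_key))  # ['frame1.png', 'frame2.png', 'frame10.png']
```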

class world_models.datasets.video_datasets.NumPyDataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, key=None)[source]#

Bases: VideoDatasetBase

Dataset that loads videos from numpy files.

Supports .npy and .npz files.

Usage:

from world_models.datasets.video_datasets import NumPyDataset
dataset = NumPyDataset(
    data_source="/path/to/videos.npy",
    num_frames=16,
    image_size=64
)
Parameters:
  • data_source (str | Path)

  • num_frames (int)

  • image_size (int)

  • transform (Callable | None)

  • normalize (bool)

  • key (str | None)

class world_models.datasets.video_datasets.RLEnvironmentDataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, obs_key='observations')[source]#

Bases: VideoDatasetBase

Dataset for RL environment recordings.

Loads trajectories stored as:
  • .npz files with 'observations' and 'actions' keys

  • directories of episode folders

Usage:

from world_models.datasets.video_datasets import RLEnvironmentDataset
dataset = RLEnvironmentDataset(
    data_source="/path/to/rl_episodes",
    num_frames=16,
    image_size=64
)
Parameters:
  • data_source (str | Path)

  • num_frames (int)

  • image_size (int)

  • transform (Callable | None)

  • normalize (bool)

  • obs_key (str)

class world_models.datasets.video_datasets.HDF5Dataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, key='videos', memmap=False)[source]#

Bases: VideoDatasetBase

Dataset that loads videos from HDF5 files.

Supports pre-processed video datasets stored in HDF5 format. Expected structure: an HDF5 file containing a dataset named by key (default 'videos') with shape (N, T, H, W, C) or (N, T, C, H, W).

Usage:

from world_models.datasets.video_datasets import HDF5Dataset
dataset = HDF5Dataset(
    data_source="/path/to/videos.h5",
    num_frames=16,
    image_size=64
)
Parameters:
  • data_source (str | Path)

  • num_frames (int)

  • image_size (int)

  • transform (Callable | None)

  • normalize (bool)

  • key (str)

  • memmap (bool)
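
HDF5Dataset accepts either (N, T, H, W, C) or (N, T, C, H, W), so the layout must be inferred from the array shape. h5py exposes the shape as f[key].shape without reading the data. The heuristic below looks for a channel axis of size 1, 3, or 4. It is an assumption for illustration, and detect_layout is hypothetical. The heuristic is ambiguous when a spatial dimension also has one of those sizes.

```python
from typing import Sequence, Tuple


def detect_layout(shape: Sequence[int], channel_sizes: Tuple[int, ...] = (1, 3, 4)) -> str:
    """Guess whether a 5-D video array is (N, T, H, W, C) or (N, T, C, H, W)."""
    if len(shape) != 5:
        raise ValueError(f"expected a 5-D shape, got {tuple(shape)}")
    if shape[-1] in channel_sizes:
        return "NTHWC"
    if shape[2] in channel_sizes:
        return "NTCHW"
    raise ValueError(f"cannot find a channel axis in {tuple(shape)}")


# Axis permutation that turns an NTHWC array into NTCHW (e.g. for numpy.transpose).
NTHWC_TO_NTCHW = (0, 1, 4, 2, 3)

print(detect_layout((100, 16, 64, 64, 3)))  # NTHWC
print(detect_layout((100, 16, 3, 64, 64)))  # NTCHW
```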

world_models.datasets.video_datasets.create_video_dataloader(dataset_type, data_source, num_frames=16, image_size=64, batch_size=4, num_workers=4, shuffle=True, pin_memory=True, **kwargs)[source]#

Factory function to create video dataloaders.

Parameters:
  • dataset_type (str) – Type of dataset (“video_folder”, “image_folder”, “numpy”, “rl”)

  • data_source (str | Path | List[str]) – Path or list of paths to data

  • num_frames (int) – Number of frames per video

  • image_size (int) – Target image size (height and width)

  • batch_size (int) – Batch size for dataloader

  • num_workers (int) – Number of workers for data loading

  • shuffle (bool) – Whether to shuffle data

  • pin_memory (bool) – Whether to pin memory for faster GPU transfer

  • **kwargs (Any) – Additional arguments for specific dataset types

Returns:

Tuple of (dataset, dataloader)

Return type:

Tuple[Dataset, DataLoader]

Usage:

from world_models.datasets.video_datasets import create_video_dataloader
dataset, loader = create_video_dataloader(
    dataset_type="video_folder",
    data_source="/path/to/videos",
    num_frames=16,
    image_size=64,
    batch_size=4
)
class world_models.datasets.video_datasets.VideoDatasetConfig(num_frames=16, image_size=64, batch_size=4, num_workers=4, pin_memory=True, shuffle=True, dataset_type='video_folder', data_source='', extensions=('.mp4', '.avi', '.mkv'), recursive=True, obs_key='observations')[source]#

Bases: DatasetConfig

Configuration for video datasets.

Parameters:
  • num_frames (int)

  • image_size (int)

  • batch_size (int)

  • num_workers (int)

  • pin_memory (bool)

  • shuffle (bool)

  • dataset_type (str)

  • data_source (str)

  • extensions (Tuple[str, ...])

  • recursive (bool)

  • obs_key (str)

dataset_type: str = 'video_folder'#
data_source: str = ''#
extensions: Tuple[str, ...] = ('.mp4', '.avi', '.mkv')#
recursive: bool = True#
obs_key: str = 'observations'#
world_models.datasets.video_datasets.create_video_dataset_from_config(config)[source]#

Create video dataset and dataloader from config.

Parameters:

config (VideoDatasetConfig)

Return type:

Tuple[Dataset, DataLoader]

TinyWorlds Dataset Loaders

Loads pre-processed video datasets from HuggingFace for training Genie-style world models. Based on: AlmondGod/tinyworlds

Available datasets:
  • PICO_DOOM: Minimal Doom gameplay

  • PONG: Classic Pong

  • ZELDA: Zelda Ocarina of Time (2D)

  • POLE_POSITION: Racing game

  • SONIC: Sonic the Hedgehog

class world_models.datasets.tinyworlds.TinyWorldsConfig(dataset_name='SONIC', num_frames=16, image_size=64, batch_size=4, num_workers=4, cache_dir=None, split='train')[source]#

Bases: object

Configuration for TinyWorlds datasets.

Parameters:
  • dataset_name (str)

  • num_frames (int)

  • image_size (int)

  • batch_size (int)

  • num_workers (int)

  • cache_dir (str | None)

  • split (str)

dataset_name: str = 'SONIC'#
num_frames: int = 16#
image_size: int = 64#
batch_size: int = 4#
num_workers: int = 4#
cache_dir: str | None = None#
split: str = 'train'#
class world_models.datasets.tinyworlds.TinyWorldsDataset(dataset_name='SONIC', num_frames=16, image_size=64, split='train', cache_dir=None, download=True, data_file=None)[source]#

Bases: Dataset

Dataset for TinyWorlds game video data.

Loads pre-processed frames from the HuggingFace datasets repository.

Parameters:
  • dataset_name (str)

  • num_frames (int)

  • image_size (int)

  • split (str)

  • cache_dir (str | None)

  • download (bool)

  • data_file (str | None)

DATASET_CONFIGS = {'PICO_DOOM': {'description': 'Minimal Doom gameplay', 'filename': 'picodoom_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}, 'POLE_POSITION': {'description': 'Racing game', 'filename': 'pole_position_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}, 'PONG': {'description': 'Classic Pong', 'filename': 'pong_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}, 'SONIC': {'description': 'Sonic the Hedgehog', 'filename': 'sonic_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}, 'ZELDA': {'description': 'Zelda Ocarina of Time (2D)', 'filename': 'zelda_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}}#
get_info()[source]#

Return dataset information.

Return type:

Dict[str, Any]

class world_models.datasets.tinyworlds.TinyWorldsDataLoader[source]#

Bases: object

Factory class for creating TinyWorlds dataloaders.

DATASET_NAMES = ['PICO_DOOM', 'PONG', 'ZELDA', 'POLE_POSITION', 'SONIC']#
static create_dataloader(dataset_name='SONIC', num_frames=16, image_size=64, batch_size=4, num_workers=4, shuffle=True, cache_dir=None, download=True, data_file=None)[source]#
Parameters:
  • dataset_name (str)

  • num_frames (int)

  • image_size (int)

  • batch_size (int)

  • num_workers (int)

  • shuffle (bool)

  • cache_dir (str | None)

  • download (bool)

  • data_file (str | None)

Return type:

Tuple[TinyWorldsDataset, DataLoader]

static list_available_datasets()[source]#

List all available dataset names.

Return type:

List[str]

static get_dataset_info(dataset_name)[source]#

Get information about a specific dataset without downloading.

Parameters:

dataset_name (str)

Return type:

Dict
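Because DATASET_CONFIGS is a static table, dataset metadata can be served without touching the network. The sketch below shows that idea with a hypothetical lookup_dataset_info helper over a two-entry excerpt of the table; the case-insensitive matching and the error message are assumptions, not documented behavior of get_dataset_info.

```python
from typing import Any, Dict

# Hypothetical excerpt with the same shape as TinyWorldsDataset.DATASET_CONFIGS
# (two entries copied from the table above; the others are omitted).
DATASET_CONFIGS: Dict[str, Dict[str, str]] = {
    "PONG": {"description": "Classic Pong", "filename": "pong_frames.h5",
             "repo_id": "AlmondGod/tinyworlds"},
    "SONIC": {"description": "Sonic the Hedgehog", "filename": "sonic_frames.h5",
              "repo_id": "AlmondGod/tinyworlds"},
}


def lookup_dataset_info(dataset_name: str) -> Dict[str, Any]:
    """Return metadata for one dataset from the static table (no download)."""
    key = dataset_name.upper()  # assumption: names are matched case-insensitively
    if key not in DATASET_CONFIGS:
        available = ", ".join(sorted(DATASET_CONFIGS))
        raise ValueError(f"Unknown dataset {dataset_name!r}; choose from: {available}")
    return {"name": key, **DATASET_CONFIGS[key]}


info = lookup_dataset_info("pong")
print(info["filename"])  # pong_frames.h5
```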

world_models.datasets.tinyworlds.create_tinyworlds_dataloader(dataset_name='SONIC', num_frames=16, image_size=64, batch_size=4, num_workers=4, shuffle=True, cache_dir=None, download=True, data_file=None)[source]#
Parameters:
  • dataset_name (str)

  • num_frames (int)

  • image_size (int)

  • batch_size (int)

  • num_workers (int)

  • shuffle (bool)

  • cache_dir (str | None)

  • download (bool)

  • data_file (str | None)

Return type:

Tuple[TinyWorldsDataset, DataLoader]

world_models.datasets.tinyworlds.download_all_datasets(cache_dir=None)[source]#

Download all available TinyWorlds datasets.

Parameters:

cache_dir (str | None) – Directory to cache downloaded datasets

Returns:

Dictionary mapping dataset names to local file paths

Return type:

Dict[str, str | None]

class world_models.datasets.diamond_dataset.ReplayBuffer(capacity=1000, obs_shape=(64, 64, 3), action_dim=1, device='cpu')[source]#

Bases: object

Replay buffer for storing environment interactions. Stores (observation, action, reward, done, next_observation) tuples.

Parameters:
  • capacity (int)

  • obs_shape (Tuple[int, int, int])

  • action_dim (int)

  • device (str)

add(obs, action, reward, done, next_obs)[source]#

Add a transition to the buffer.

Parameters:
  • obs (ndarray)

  • action (int)

  • reward (float)

  • done (bool)

  • next_obs (ndarray)

Return type:

None

sample(batch_size)[source]#

Sample a random batch of transitions.

Parameters:

batch_size (int)

Return type:

Dict[str, Tensor]

sample_sequence(batch_size, sequence_length, burn_in=0)[source]#

Sample a sequence of transitions for training.

Parameters:
  • batch_size (int) – Number of sequences to sample

  • sequence_length (int) – Total sequence length (burn_in + horizon)

  • burn_in (int) – Number of initial frames to use for conditioning

Returns:

Dictionary with tensors of shape (batch_size, sequence_length, …)

Return type:

Dict[str, Tensor]
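The sequence_length = burn_in + horizon convention can be illustrated in plain Python: pick a start index so the whole window fits, then split the window into conditioning frames and prediction targets. This is a simplified sketch with hypothetical helper names; it assumes a flat store with no ring-buffer wraparound and ignores episode boundaries (done flags), which a real buffer would likely need to respect. The defaults mirror SequenceDataset (sequence_length=5, burn_in=4).

```python
import random
from typing import Dict, List


def sample_sequence_starts(num_stored: int, batch_size: int, sequence_length: int,
                           rng: random.Random) -> List[int]:
    """Pick start indices so every window [start, start + sequence_length) fits."""
    max_start = num_stored - sequence_length
    if max_start < 0:
        raise ValueError("not enough transitions for one sequence")
    # randint is inclusive on both ends, so max_start itself is a valid start.
    return [rng.randint(0, max_start) for _ in range(batch_size)]


def split_burn_in(window: List[int], burn_in: int) -> Dict[str, List[int]]:
    """Split one window into conditioning frames and the prediction horizon."""
    return {"context": window[:burn_in], "horizon": window[burn_in:]}


rng = random.Random(0)
frames = list(range(100))  # stand-in for 100 stored observations
starts = sample_sequence_starts(len(frames), batch_size=2, sequence_length=5, rng=rng)
batch = [frames[s:s + 5] for s in starts]
parts = split_burn_in(batch[0], burn_in=4)
print(len(parts["context"]), len(parts["horizon"]))  # 4 1
```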

is_ready(min_size)[source]#

Return True if the buffer holds at least min_size transitions.

Parameters:

min_size (int)

Return type:

bool
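A fixed-capacity replay buffer is usually a ring: once full, each new transition overwrites the oldest slot. The sketch below shows add, is_ready, and uniform sample with plain Python lists instead of the preallocated arrays/tensors a real implementation would use; the class name RingBuffer and its internals are hypothetical, not this module's code.

```python
import random
from typing import Any, Dict, List, Tuple

Transition = Tuple[Any, int, float, bool, Any]  # (obs, action, reward, done, next_obs)


class RingBuffer:
    """Fixed-capacity FIFO store: once full, new transitions overwrite the oldest."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._data: List[Transition] = []
        self._pos = 0  # next slot to write

    def add(self, obs, action, reward, done, next_obs) -> None:
        item = (obs, action, reward, done, next_obs)
        if len(self._data) < self.capacity:
            self._data.append(item)
        else:
            self._data[self._pos] = item  # overwrite the oldest transition
        self._pos = (self._pos + 1) % self.capacity

    def __len__(self) -> int:
        return len(self._data)

    def is_ready(self, min_size: int) -> bool:
        return len(self) >= min_size

    def sample(self, batch_size: int, rng: random.Random) -> Dict[str, list]:
        # Uniform sampling with replacement over the stored transitions.
        idx = [rng.randrange(len(self._data)) for _ in range(batch_size)]
        obs, act, rew, done, nxt = zip(*(self._data[i] for i in idx))
        return {"obs": list(obs), "action": list(act), "reward": list(rew),
                "done": list(done), "next_obs": list(nxt)}


buf = RingBuffer(capacity=3)
for t in range(5):
    buf.add(obs=t, action=0, reward=float(t), done=False, next_obs=t + 1)
print(len(buf), buf.is_ready(3))  # 3 True
obs = buf.sample(8, random.Random(0))["obs"]
print(min(obs) >= 2)  # True: observations 0 and 1 were overwritten
```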

state_dict()[source]#

Return a serializable state dict for checkpointing.

Contains numpy arrays and scalar metadata so it can be saved with torch.save or numpy.save.

Return type:

dict

load_state_dict(state)[source]#

Load state previously produced by state_dict().

This will resize internal arrays if the saved capacity differs from the current buffer capacity.

Parameters:

state (dict)

Return type:

None

class world_models.datasets.diamond_dataset.SequenceDataset(replay_buffer, sequence_length=5, burn_in=4)[source]#

Bases: Dataset

PyTorch Dataset for sampling sequences from the replay buffer. Used for training the diffusion world model.

Parameters:
  • replay_buffer (ReplayBuffer)

  • sequence_length (int)

  • burn_in (int)

world_models.datasets.cifar10.make_cifar10(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, drop_last=True, train=True, download=False)[source]#

Create CIFAR-10 dataset and distributed dataloader.

Factory function that creates a CIFAR-10 dataset with the provided transforms and returns a tuple of (dataset, dataloader, sampler) for use in JEPA or diffusion training pipelines.

Parameters:
  • transform (Any) – Transforms to apply to images (e.g., RandomCrop, ColorJitter).

  • batch_size (int) – Number of samples per batch.

  • collator (callable, optional) – Custom collate function for batching (e.g., mask collator for JEPA).

  • pin_mem (bool) – Whether to pin memory for faster GPU transfer (default: True).

  • num_workers (int) – Number of data loading workers (default: 8).

  • world_size (int) – Number of distributed processes (default: 1).

  • rank (int) – Rank of current process in distributed setting (default: 0).

  • root_path (str, optional) – Path to store/load CIFAR-10 data.

  • drop_last (bool) – Whether to drop incomplete final batch (default: True).

  • train (bool) – Whether to load train or test split (default: True).

  • download (bool) – Whether to download dataset if not present (default: False).

Returns:

(dataset, dataloader, sampler)
  • dataset: torchvision.datasets.CIFAR10 instance

  • dataloader: torch.utils.data.DataLoader with distributed sampling

  • sampler: torch.utils.data.distributed.DistributedSampler

Return type:

tuple

Example

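>>> from world_models.datasets.cifar10 import make_cifar10
>>> from world_models.transforms.image import make_transforms
>>> transform = make_transforms(crop_size=224)
>>> dataset, loader, sampler = make_cifar10(
...     transform=transform,
...     batch_size=256,
...     root_path="./data",
...     download=True
... )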
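The returned sampler is a torch.utils.data.distributed.DistributedSampler, so each of the world_size ranks reads a disjoint strided shard of one shared permutation; calling sampler.set_epoch(epoch) at the start of every epoch changes that permutation. The pure-Python sketch below imitates the sharding (padding by repeating leading indices so every rank gets the same count). It uses Python's random module rather than torch's generator, so the actual index order will differ; shard_indices is a hypothetical helper.

```python
import math
import random
from typing import List


def shard_indices(dataset_len: int, world_size: int, rank: int,
                  shuffle: bool = True, seed: int = 0, epoch: int = 0) -> List[int]:
    """Indices one rank would read, mimicking DistributedSampler with drop_last=False."""
    indices = list(range(dataset_len))
    if shuffle:
        # Every rank builds the same permutation because seed and epoch are shared.
        random.Random(seed + epoch).shuffle(indices)
    num_per_rank = math.ceil(dataset_len / world_size)
    total = num_per_rank * world_size
    while len(indices) < total:  # pad by repeating leading indices
        indices += indices[: total - len(indices)]
    return indices[rank:total:world_size]  # strided slice for this rank


shards = [shard_indices(10, world_size=4, rank=r) for r in range(4)]
print([len(s) for s in shards])  # [3, 3, 3, 3]
```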
world_models.datasets.imagenet1k.make_imagenet1k(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, training=True, copy_data=False, drop_last=True, subset_file=None)[source]#

Build an ImageNet-1K dataset and dataloader with distributed sampling.

Factory function that creates an ImageNet dataset and returns a tuple of (dataset, dataloader, sampler) for use in JEPA or other self-supervised training pipelines.

Supports:
  • Optional data staging from network storage to local scratch

  • Subset filtering via text file listing allowed image IDs

  • Distributed sampling for multi-GPU training

Parameters:
  • transform (Any) – Transforms to apply to images.

  • batch_size (int) – Number of samples per batch.

  • collator (callable, optional) – Custom collate function (e.g., mask collator).

  • pin_mem (bool) – Whether to pin memory for GPU transfer (default: True).

  • num_workers (int) – Number of data loading workers (default: 8).

  • world_size (int) – Number of distributed processes (default: 1).

  • rank (int) – Rank of current process (default: 0).

  • root_path (str, optional) – Root path containing ImageNet data.

  • image_folder (str, optional) – Subfolder containing ImageNet data.

  • training (bool) – Load train or validation split (default: True).

  • copy_data (bool) – Copy data locally for faster loading (default: False).

  • drop_last (bool) – Drop incomplete final batch (default: True).

  • subset_file (str, optional) – Path to file listing allowed image IDs.

Returns:

(dataset, dataloader, sampler)
  • dataset: ImageNet dataset instance

  • dataloader: DataLoader with distributed sampling

  • sampler: DistributedSampler instance

Return type:

tuple

class world_models.datasets.imagenet1k.ImageNet(root, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', transform=None, train=True, job_id=None, local_rank=None, copy_data=True, index_targets=False)[source]#

Bases: ImageFolder

ImageNet dataset wrapper with optional local copy/extract workflow.

Extends torchvision.datasets.ImageFolder to support data staging from network storage to local scratch space for faster multi-process training on cluster environments (e.g., SLURM).

Features:
  • Optional data copying from network storage to local /scratch

  • Extracts tar archives automatically on first access

  • Supports train/validation splits

  • Optional target indexing for balanced sampling

Parameters:
  • root (str)

  • image_folder (str)

  • tar_file (str)

  • transform (Any)

  • train (bool)

  • job_id (str | None)

  • local_rank (int | None)

  • copy_data (bool)

  • index_targets (bool)

class world_models.datasets.imagenet1k.ImageNetSubset(dataset, subset_file)[source]#

Bases: object

View over an ImageNet dataset filtered by an explicit image-id list.

The subset file contains target image names; only matching samples are kept while preserving transforms and label mapping from the base dataset.

Parameters:
  • dataset (Any)

  • subset_file (str)

filter_dataset_(subset_file)[source]#

Filter self.dataset in place, keeping only the images listed in subset_file.

Parameters:

subset_file (str)

Return type:

None
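Subset filtering can be pictured on an ImageFolder-style samples list of (path, class_index) pairs: keep entries whose file name appears in the subset file and leave their labels untouched. This sketch assumes the subset file holds one image file name per line and that matching is done on the base name; both are assumptions, and the helper names are hypothetical.

```python
import os
from typing import List, Set, Tuple

Sample = Tuple[str, int]  # (image path, class index), as in ImageFolder.samples


def read_subset_ids(lines: List[str]) -> Set[str]:
    """Parse subset-file lines: assumed one image file name per line, blanks ignored."""
    return {line.strip() for line in lines if line.strip()}


def filter_samples(samples: List[Sample], allowed: Set[str]) -> List[Sample]:
    """Keep samples whose file name is listed; class indices are left unchanged."""
    return [(path, target) for path, target in samples
            if os.path.basename(path) in allowed]


samples = [("train/n01440764/a.JPEG", 0), ("train/n01443537/b.JPEG", 1),
           ("train/n01443537/c.JPEG", 1)]
allowed = read_subset_ids(["a.JPEG\n", "c.JPEG\n", "\n"])
print(filter_samples(samples, allowed))
# [('train/n01440764/a.JPEG', 0), ('train/n01443537/c.JPEG', 1)]
```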

property classes: Any#
world_models.datasets.imagenet1k.copy_imgnt_locally(root, suffix, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', job_id=None, local_rank=None)[source]#

Copy and extract ImageNet archives to per-job local scratch storage.

In SLURM environments this reduces network filesystem pressure by unpacking once per job and synchronizing worker processes with a signal file.

Parameters:
  • root (str)

  • suffix (str)

  • image_folder (str)

  • tar_file (str)

  • job_id (str | None)

  • local_rank (int | None)

Return type:

str | None
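The "unpack once per job, synchronize with a signal file" pattern can be sketched with the standard library: local rank 0 does the copy and then writes a marker file, while other ranks poll for the marker. The marker name copy_done, the polling interval, and the timeout are hypothetical; the real function's file names and waiting logic may differ.

```python
import os
import tempfile
import time
from typing import Callable


def stage_once(local_dir: str, local_rank: int, copy_fn: Callable[[str], None],
               poll_s: float = 0.5, timeout_s: float = 3600.0) -> str:
    """Rank 0 copies/extracts into local_dir; other ranks wait for a signal file."""
    signal = os.path.join(local_dir, "copy_done")  # hypothetical signal-file name
    if local_rank == 0:
        if not os.path.exists(signal):  # skip work if this job already staged data
            os.makedirs(local_dir, exist_ok=True)
            copy_fn(local_dir)  # e.g. copy the tarball and unpack it
            with open(signal, "w") as f:
                f.write("done")
    else:
        deadline = time.monotonic() + timeout_s
        while not os.path.exists(signal):
            if time.monotonic() > deadline:
                raise TimeoutError(f"rank {local_rank}: no signal file in {local_dir}")
            time.sleep(poll_s)
    return local_dir


with tempfile.TemporaryDirectory() as scratch:
    job_dir = os.path.join(scratch, "job_123")  # hypothetical per-job directory
    stage_once(job_dir, 0, lambda d: open(os.path.join(d, "data.txt"), "w").close())
    print(stage_once(job_dir, 1, lambda d: None) == job_dir)  # True: signal present
```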

world_models.datasets.imagenet1k.make_imagefolder(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, drop_last=True, val_split=None)[source]#

Create an ImageFolder dataset loader for custom folder-structured datasets.

Supports optional train/validation split and distributed sampling, making it a drop-in replacement for ImageNet loaders in training scripts.

Parameters:
  • transform (Any)

  • batch_size (int)

  • collator (Any)

  • pin_mem (bool)

  • num_workers (int)

  • world_size (int)

  • rank (int)

  • root_path (str | None)

  • image_folder (str | None)

  • drop_last (bool)

  • val_split (float | None)

Return type:

Tuple[Dataset, DataLoader, DistributedSampler]
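When val_split is set, a common approach is to shuffle indices once with a fixed seed and carve off the validation fraction; a fixed seed matters in distributed runs so that every rank computes the same split. The sketch below is a hypothetical illustration of that approach, not the function's confirmed logic. The resulting index lists could be wrapped with torch.utils.data.Subset.

```python
import random
from typing import List, Tuple


def split_indices(n: int, val_split: float, seed: int = 0) -> Tuple[List[int], List[int]]:
    """Shuffle indices with a fixed seed and split off a validation fraction."""
    if not 0.0 < val_split < 1.0:
        raise ValueError("val_split must be in (0, 1)")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # identical on every rank for a shared seed
    n_val = int(round(n * val_split))
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_indices(1000, val_split=0.1)
print(len(train_idx), len(val_idx))  # 900 100
```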

PyTorch Dataset for the NuPlan autonomous driving dataset.

Requires nuplan-devkit and a local copy of the NuPlan dataset. Download from https://www.nuplan.org/nuplan and set NUPLAN_DATA_ROOT to the extracted path (default: ~/nuplan/dataset).

class world_models.datasets.nuplan.NuPlanSample(scenario_name, map_raster, ego_past, ego_future, agents_past, agents_future, agents_mask, agent_types, planning_target)[source]#

Bases: object

A single training sample from the NuPlan dataset.

Parameters:
  • scenario_name (str)

  • map_raster (Tensor)

  • ego_past (Tensor)

  • ego_future (Tensor)

  • agents_past (Tensor)

  • agents_future (Tensor)

  • agents_mask (Tensor)

  • agent_types (Tensor)

  • planning_target (Tensor)

scenario_name: str#
map_raster: Tensor#
ego_past: Tensor#
ego_future: Tensor#
agents_past: Tensor#
agents_future: Tensor#
agents_mask: Tensor#
agent_types: Tensor#
planning_target: Tensor#
class world_models.datasets.nuplan.NuPlanDataset(data_root=None, map_root=None, split='train', db_files=None, map_version='nuplan-maps-v1.0', planning_horizon=80, past_horizon=20, map_extent=(100.0, 100.0), map_resolution=0.1, max_agents=32, limit_scenarios=None)[source]#

Bases: Dataset[NuPlanSample]

PyTorch Dataset over NuPlan scenarios for world model training.

Each sample contains rasterised map tiles, ego and agent history, and future planning targets at 10 Hz.

Parameters:
  • data_root (str | Path | None) – Path to the NuPlan dataset root. Defaults to $NUPLAN_DATA_ROOT.

  • map_root (str | Path | None) – Path to NuPlan map data. Defaults to $NUPLAN_MAP_ROOT.

  • split (str) – "train", "val", or "test". The mini split is used automatically when data_root / "mini" exists.

  • db_files (list[str] | None) – Explicit list of .db files. When None, the builder auto-discovers files under data_root / split.

  • map_version (str) – Map version string, e.g. "nuplan-maps-v1.0".

  • planning_horizon (int) – Number of future steps at 10 Hz (default 80 = 8 s).

  • past_horizon (int) – Number of past steps at 10 Hz (default 20 = 2 s).

  • map_extent (Tuple[float, float]) – Raster crop half-extent in metres (width, height).

  • map_resolution (float) – Metres per pixel for the raster.

  • max_agents (int) – Maximum agents per sample; fewer are zero-padded.

  • limit_scenarios (int | None) – Cap on total scenarios (useful for prototyping).
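The defaults above translate into concrete sizes: at 10 Hz, planning_horizon=80 covers 8 s and past_horizon=20 covers 2 s. Assuming the raster spans plus and minus the half-extent around the ego vehicle, map_extent=(100.0, 100.0) at 0.1 m per pixel gives a 2000 x 2000 raster. The padding helper shows one way max_agents could be enforced with a validity mask; which agents are kept when there are too many (here simply the first ones) is an assumption, and all helper names are hypothetical.

```python
from typing import List, Sequence, Tuple

HZ = 10  # sampling rate stated above for NuPlan samples


def horizon_seconds(steps: int, hz: int = HZ) -> float:
    return steps / hz


def raster_pixels(half_extent_m: Tuple[float, float], metres_per_px: float) -> Tuple[int, int]:
    """Raster size assuming the crop spans +/- half_extent around the ego vehicle."""
    w, h = half_extent_m
    return round(2 * w / metres_per_px), round(2 * h / metres_per_px)


def pad_agents(agents: Sequence[Sequence[float]], max_agents: int,
               feat_dim: int) -> Tuple[List[List[float]], List[bool]]:
    """Truncate/zero-pad per-agent features to max_agents and build a validity mask."""
    kept = [list(a) for a in agents[:max_agents]]  # assumption: keep the first agents
    mask = [True] * len(kept) + [False] * (max_agents - len(kept))
    kept += [[0.0] * feat_dim for _ in range(max_agents - len(kept))]
    return kept, mask


print(horizon_seconds(80), horizon_seconds(20))  # 8.0 2.0
print(raster_pixels((100.0, 100.0), 0.1))        # (2000, 2000)
feats, mask = pad_agents([[1.0, 2.0], [3.0, 4.0]], max_agents=4, feat_dim=2)
print(mask)  # [True, True, False, False]
```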

world_models.datasets.nuplan.make_nuplan_dataloader(data_root=None, split='train', batch_size=32, num_workers=4, **dataset_kwargs)[source]#

Create a NuPlan DataLoader.

Parameters:
  • data_root (str | Path | None) – Root of the NuPlan dataset (default: $NUPLAN_DATA_ROOT).

  • split (str) – Dataset split.

  • batch_size (int) – Batch size.

  • num_workers (int) – Worker count for the DataLoader.

  • **dataset_kwargs (Any) – Extra arguments forwarded to NuPlanDataset.

Returns:

(dataset, dataloader)

world_models.transforms.image.make_transforms(crop_size=224, crop_scale=(0.3, 1.0), color_jitter=1.0, horizontal_flip=False, color_distortion=False, gaussian_blur=False, normalization=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))[source]#

Compose image augmentations and normalization for vision model training.

Supports random crops, optional flip/color distortion/blur, and returns a torchvision.transforms.Compose pipeline.

Parameters:
  • crop_size (int)

  • crop_scale (tuple[float, float])

  • color_jitter (float)

  • horizontal_flip (bool)

  • color_distortion (bool)

  • gaussian_blur (bool)

  • normalization (tuple[tuple[float, ...], tuple[float, ...]])

Return type:

Any

class world_models.transforms.image.GaussianBlur(p=0.5, radius_min=0.1, radius_max=2.0)[source]#

Bases: object

Probabilistic Gaussian blur augmentation for PIL images.

With probability p, applies a blur whose radius is drawn at random from the range [radius_min, radius_max].

Parameters:
  • p (float)

  • radius_min (float)

  • radius_max (float)

Masking and JEPA helpers#

Masks sub-module - Masking strategies for JEPA and masked training.

This package provides various masking collator classes for generating encoder/predictor masks during masked representation learning.

Usage:

from world_models.masks import MaskCollator, DefaultCollator
collator = MaskCollator(input_size=(64, 64), patch_size=8)

world_models.masks.MultiblockMaskCollator#

alias of MaskCollator

world_models.masks.RandomMaskCollator#

alias of MaskCollator

class world_models.masks.DefaultCollator[source]#

Bases: object

Simple collator that returns batch data and no masking metadata.

This is used when training code expects the JEPA-style collator return shape (batch, masks_enc, masks_pred) but masking is disabled.

class world_models.masks.default.DefaultCollator[source]#

Bases: object

Simple collator that returns batch data and no masking metadata.

This is used when training code expects the JEPA-style collator return shape (batch, masks_enc, masks_pred) but masking is disabled.

class world_models.masks.multiblock.MaskCollator(input_size=(224, 224), patch_size=16, enc_mask_scale=(0.2, 0.8), pred_mask_scale=(0.5, 1.0), aspect_ratio=(0.3, 3.0), nenc=1, npred=2, min_keep=4, allow_overlap=False)[source]#

Bases: object

Generate multi-block encoder and predictor masks for JEPA training.

For each sample, this collator samples predictor target blocks and context encoder blocks (optionally non-overlapping), then returns masked patch indices aligned across the batch.

Parameters:
  • input_size (tuple[int, int])

  • patch_size (int)

  • enc_mask_scale (tuple[float, float])

  • pred_mask_scale (tuple[float, float])

  • aspect_ratio (tuple[float, float])

  • nenc (int)

  • npred (int)

  • min_keep (int)

  • allow_overlap (bool)

step()[source]#
Return type:

int
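Multi-block masks are defined on the patch grid: with the defaults input_size=(224, 224) and patch_size=16 the grid is 14 x 14. The sketch below, in the style of the I-JEPA reference code, turns a sampled scale (fraction of patches) and aspect ratio into a block height and width, then places the block at a random position. Independent draws for scale and aspect ratio, the clamping rule, and the helper names are assumptions; overlap removal between encoder and predictor blocks and the min_keep check are not shown.

```python
import math
import random
from typing import List, Tuple


def sample_block_size(grid: Tuple[int, int], scale: Tuple[float, float],
                      aspect_ratio: Tuple[float, float],
                      rng: random.Random) -> Tuple[int, int]:
    """Pick a block (h, w) in patch units covering roughly `scale` of the grid."""
    gh, gw = grid
    s = rng.uniform(*scale)          # fraction of patches the block should cover
    ar = rng.uniform(*aspect_ratio)  # height / width ratio
    target = int(gh * gw * s)
    h = int(round(math.sqrt(target * ar)))
    w = int(round(math.sqrt(target / ar)))
    # Keep the block strictly inside the grid so a top-left corner can be sampled.
    return max(1, min(h, gh - 1)), max(1, min(w, gw - 1))


def sample_block_indices(grid: Tuple[int, int], hw: Tuple[int, int],
                         rng: random.Random) -> List[int]:
    """Flattened (row-major) patch indices of one block at a random position."""
    gh, gw = grid
    h, w = hw
    top = rng.randint(0, gh - h)
    left = rng.randint(0, gw - w)
    return [r * gw + c for r in range(top, top + h) for c in range(left, left + w)]


grid = (224 // 16, 224 // 16)  # defaults input_size=(224, 224), patch_size=16
rng = random.Random(0)
hw = sample_block_size(grid, scale=(0.5, 1.0), aspect_ratio=(0.3, 3.0), rng=rng)
block = sample_block_indices(grid, hw, rng)
print(grid, len(block) == hw[0] * hw[1])  # (14, 14) True
```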

class world_models.masks.random.MaskCollator(ratio=(0.4, 0.6), input_size=(224, 224), patch_size=16)[source]#

Bases: object

Generate random context/prediction patch splits for masked training.

A random permutation of patch indices is sampled per image; a configurable fraction is assigned to context and the remainder to prediction targets.

Parameters:
  • ratio (tuple)

  • input_size (tuple)

  • patch_size (int)

step()[source]#
Return type:

int
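The random split described above amounts to permuting patch indices and cutting the permutation in two. In this sketch the masked fraction is drawn once per batch so every sample has context and target sets of the same length, which makes them stackable; that choice, and treating ratio as the bounds on the masked (prediction) share as in the I-JEPA reference code, are assumptions. random_split_batch is a hypothetical helper.

```python
import random
from typing import List, Tuple


def random_split_batch(batch_size: int, grid_size: int, ratio: Tuple[float, float],
                       rng: random.Random) -> List[Tuple[List[int], List[int]]]:
    """Per-sample (context, prediction) patch indices from a random permutation."""
    num_patches = grid_size * grid_size
    masked_frac = rng.uniform(*ratio)  # assumption: drawn once, bounds the masked share
    num_keep = int(num_patches * (1.0 - masked_frac))
    out = []
    for _ in range(batch_size):
        perm = list(range(num_patches))
        rng.shuffle(perm)
        out.append((perm[:num_keep], perm[num_keep:]))
    return out


splits = random_split_batch(4, 224 // 16, ratio=(0.4, 0.6), rng=random.Random(0))
ctx, pred = splits[0]
print(len(ctx) + len(pred))  # 196
```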

world_models.helpers.jepa_helper.load_checkpoint(device, r_path, encoder, predictor, target_encoder, opt, scaler)[source]#

Load JEPA training state from disk into model and optimizer objects.

Restores encoder, predictor, optional target encoder, optimizer state, and optional AMP scaler, returning the resumed epoch for training restart.

Parameters:
  • device (device)

  • r_path (str)

  • encoder (Module)

  • predictor (Module)

  • target_encoder (Module | None)

  • opt (Optimizer)

  • scaler (Any | None)

Return type:

tuple

world_models.helpers.jepa_helper.init_model(device, patch_size=16, model_name='vit_base', crop_size=224, pred_depth=6, pred_emb_dim=384)[source]#

Initialize JEPA encoder and predictor modules with ViT backbones.

Applies truncated-normal parameter initialization, moves modules to the requested device, and returns (encoder, predictor).

Parameters:
  • device (device)

  • patch_size (int)

  • model_name (str)

  • crop_size (int)

  • pred_depth (int)

  • pred_emb_dim (int)

Return type:

tuple

world_models.helpers.jepa_helper.init_opt(encoder, predictor, iterations_per_epoch, start_lr, ref_lr, warmup, num_epochs, wd=1e-06, final_wd=1e-06, final_lr=0.0, use_bfloat16=False, ipe_scale=1.25)[source]#

Build optimizer, AMP scaler, LR scheduler, and weight-decay scheduler for JEPA.

Parameters are grouped to exclude bias/norm tensors from weight decay, matching typical transformer training best practices.

Parameters:
  • encoder (Module)

  • predictor (Module)

  • iterations_per_epoch (int)

  • start_lr (float)

  • ref_lr (float)

  • warmup (float)

  • num_epochs (int)

  • wd (float)

  • final_wd (float)

  • final_lr (float)

  • use_bfloat16 (bool)

  • ipe_scale (float)

Return type:

tuple
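Two ideas in init_opt can be sketched without PyTorch: the weight-decay grouping (biases and 1-D tensors such as norm weights excluded) and a warmup-then-cosine learning-rate curve. This assumes, as in the I-JEPA reference code, that warmup is measured in epochs and that ipe_scale stretches the total schedule length; the helper names and the (name, shape) representation are hypothetical.

```python
import math
from typing import Dict, Iterable, List, Tuple

Param = Tuple[str, Tuple[int, ...]]  # (parameter name, shape)


def split_decay_groups(params: Iterable[Param]) -> Dict[str, List[str]]:
    """Bias vectors and 1-D (norm) weights get no weight decay; the rest do."""
    groups: Dict[str, List[str]] = {"decay": [], "no_decay": []}
    for name, shape in params:
        if name.endswith("bias") or len(shape) == 1:
            groups["no_decay"].append(name)
        else:
            groups["decay"].append(name)
    return groups


def warmup_cosine_lr(step: int, warmup_steps: int, total_steps: int,
                     start_lr: float, ref_lr: float, final_lr: float) -> float:
    """Linear warmup start_lr -> ref_lr, then cosine decay ref_lr -> final_lr."""
    if step < warmup_steps:
        return start_lr + (ref_lr - start_lr) * step / max(1, warmup_steps)
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return final_lr + 0.5 * (ref_lr - final_lr) * (1.0 + math.cos(math.pi * progress))


print(split_decay_groups([("blocks.0.attn.qkv.weight", (2304, 768)),
                          ("blocks.0.attn.qkv.bias", (2304,)),
                          ("blocks.0.norm1.weight", (768,))]))
ipe, epochs, warmup = 100, 10, 1.0
total = int(1.25 * epochs * ipe)  # assumption: ipe_scale=1.25 stretches the schedule
print(warmup_cosine_lr(0, int(warmup * ipe), total, 1e-4, 1e-3, 0.0))  # 0.0001
```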

Benchmarks and reports#

Benchmarks sub-module - Benchmark runners and adapters for world models.

This package provides tools for running standardized evaluations of world models (Dreamer, IRIS, DIAMOND) across multiple seeds and computing aggregate metrics.

Usage:

from world_models.benchmarks import BenchmarkRunner, DiamondAdapter
runner = BenchmarkRunner(adapter_cls=DiamondAdapter, …)

class world_models.benchmarks.runner.BenchmarkRunner(adapter_cls, out_dir='results')[source]#

Bases: object

Run evaluations for adapters across seeds and export results.

Usage:

from world_models.benchmarks import adapters
runner = BenchmarkRunner(adapter_cls=adapters.DiamondAdapter)
results = runner.run(env_spec="Breakout-v5", seeds=[0, 1], num_episodes=5)

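Parameters:
  • adapter_cls – Adapter class to evaluate, e.g. DiamondAdapter.

  • out_dir (str) – Directory for exported results (default: 'results').
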
run(env_spec=None, seeds=None, num_episodes=5, checkpoint=None, extra_kwargs=None)[source]#

Run benchmark.

Returns a results dict with per-seed episode returns and computed metrics.

Parameters:
  • env_spec (Any | None)

  • seeds (List[int] | None)

  • num_episodes (int)

  • checkpoint (str | None)

  • extra_kwargs (Dict[str, Any] | None)

Return type:

Dict[str, Any]

class world_models.benchmarks.runner.MultiAgentBenchmarkRunner(adapter_classes, out_dir='results')[source]#

Bases: object

Run evaluations for multiple adapters on the same environment.

Usage:

runner = MultiAgentBenchmarkRunner(adapter_classes=[adapters.DiamondAdapter, adapters.IRISAdapter])
# env_spec is a Dict[str, Any] describing the target environment (e.g. Breakout-v5)
results = runner.run_all(env_spec=env_spec, seeds=[0, 1], num_episodes=5)

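Parameters:
  • adapter_classes – Adapter classes to evaluate on the same environment, e.g. [DiamondAdapter, IRISAdapter].

  • out_dir (str) – Directory for exported results (default: 'results').
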
run_all(env_spec, seeds=None, num_episodes=5, checkpoints=None, extra_kwargs=None, train_epochs=None)[source]#

Run benchmarks for all adapters on the same environment.

Returns a results dict with results for each adapter.

Parameters:
  • env_spec (Dict[str, Any])

  • seeds (List[int] | None)

  • num_episodes (int)

  • checkpoints (Dict[str, str] | None)

  • extra_kwargs (Dict[str, Any] | None)

  • train_epochs (int | None)

Return type:

Dict[str, Any]

class world_models.benchmarks.adapters.BaseAdapter(env_spec=None, seed=0, **kwargs)[source]#

Bases: object

Parameters:
  • env_spec (Any | None)

  • seed (int)

  • kwargs (Any)

load_checkpoint(path)[source]#
Parameters:

path (str)

Return type:

None

evaluate(num_episodes=1, render=False)[source]#

Evaluate the agent for num_episodes episodes and return a standardized result dict. Preferred format: {'episode_returns': List[float]}.

Parameters:
  • num_episodes (int)

  • render (bool)

Return type:

dict
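The adapter contract above is duck-typed: evaluate() returns a dict with an 'episode_returns' list. A plausible way for a runner to use it is one adapter per seed, averaging each seed's returns into the per_seed_means that compute_aggregate_metrics accepts. The sketch below shows that flow with a hypothetical ToyAdapter that fabricates returns; it does not subclass BaseAdapter, and the per-seed instantiation is an assumption about BenchmarkRunner.

```python
import random
import statistics
from typing import Dict, List


class ToyAdapter:
    """Hypothetical stand-in that follows the evaluate() return convention above."""

    def __init__(self, env_spec=None, seed: int = 0) -> None:
        self.env_spec = env_spec
        self.rng = random.Random(seed)

    def load_checkpoint(self, path: str) -> None:
        pass  # a real adapter would restore model weights here

    def evaluate(self, num_episodes: int = 1, render: bool = False) -> Dict[str, List[float]]:
        # Fake returns; a real adapter would roll out its policy in env_spec.
        return {"episode_returns": [self.rng.uniform(0.0, 10.0) for _ in range(num_episodes)]}


def per_seed_means(adapter_cls, seeds: List[int], num_episodes: int) -> List[float]:
    """One adapter per seed; average its episode returns into one number per seed."""
    means = []
    for seed in seeds:
        adapter = adapter_cls(env_spec=None, seed=seed)
        returns = adapter.evaluate(num_episodes=num_episodes)["episode_returns"]
        means.append(statistics.fmean(returns))
    return means


print(len(per_seed_means(ToyAdapter, seeds=[0, 1, 2], num_episodes=5)))  # 3
```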

class world_models.benchmarks.adapters.DiamondAdapter(env_spec=None, seed=0, **kwargs)[source]#

Bases: BaseAdapter

Parameters:
  • env_spec (Any | None)

  • seed (int)

  • kwargs (Any)

load_checkpoint(path)[source]#
Parameters:

path (str)

Return type:

None

evaluate(num_episodes=1, render=False)[source]#
Parameters:
  • num_episodes (int)

  • render (bool)

Return type:

dict

class world_models.benchmarks.adapters.IRISAdapter(env_spec=None, seed=0, **kwargs)[source]#

Bases: BaseAdapter

Parameters:
  • env_spec (Any | None)

  • seed (int)

  • kwargs (Any)

load_checkpoint(path)[source]#
Parameters:

path (str)

Return type:

None

evaluate(num_episodes=1, render=False)[source]#
Parameters:
  • num_episodes (int)

  • render (bool)

Return type:

dict

class world_models.benchmarks.adapters.DreamerAdapter(env_spec=None, seed=0, **kwargs)[source]#

Bases: BaseAdapter

Parameters:
  • env_spec (Any | None)

  • seed (int)

  • kwargs (Any)

load_checkpoint(path)[source]#
Parameters:

path (str)

Return type:

None

evaluate(num_episodes=1, render=False)[source]#
Parameters:
  • num_episodes (int)

  • render (bool)

Return type:

dict

class world_models.benchmarks.adapters.DreamerV1Adapter(env_spec=None, seed=0, **kwargs)[source]#

Bases: DreamerAdapter

Parameters:
  • env_spec (Any | None)

  • seed (int)

  • kwargs (Any)

class world_models.benchmarks.adapters.DreamerV2Adapter(env_spec=None, seed=0, **kwargs)[source]#

Bases: DreamerAdapter

Parameters:
  • env_spec (Any | None)

  • seed (int)

  • kwargs (Any)

world_models.benchmarks.metrics.compute_aggregate_metrics(per_seed_means)[source]#
Parameters:

per_seed_means (Iterable[float])

Return type:

Dict[str, float]

world_models.benchmarks.metrics.bootstrap_ci(values, num_samples=1000, alpha=0.05)[source]#

Compute a simple bootstrap (1 - alpha) confidence interval for the mean.

Parameters:
  • values (List[float])

  • num_samples (int)

  • alpha (float)

Return type:

tuple[float, float]
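A percentile bootstrap resamples the values with replacement, recomputes the statistic on each resample, and reads the alpha/2 and 1 - alpha/2 quantiles of the results. The sketch below uses a nearest-rank quantile and a fixed seed; the library's quantile rule and RNG may differ. Passing an IQM function as statistic gives the analogue of bootstrap_iqm_ci. percentile_bootstrap_ci is a hypothetical name.

```python
import random
import statistics
from typing import Callable, List, Sequence, Tuple


def percentile_bootstrap_ci(values: Sequence[float], num_samples: int = 1000,
                            alpha: float = 0.05,
                            statistic: Callable[[Sequence[float]], float] = statistics.fmean,
                            seed: int = 0) -> Tuple[float, float]:
    """Resample with replacement; return the alpha/2 and 1 - alpha/2 quantiles."""
    rng = random.Random(seed)
    n = len(values)
    stats: List[float] = sorted(
        statistic([values[rng.randrange(n)] for _ in range(n)]) for _ in range(num_samples)
    )
    lo_idx = int(round((alpha / 2) * (num_samples - 1)))
    hi_idx = int(round((1 - alpha / 2) * (num_samples - 1)))
    return stats[lo_idx], stats[hi_idx]


lo, hi = percentile_bootstrap_ci([1.0, 2.0, 3.0, 4.0, 5.0])
print(lo <= hi)  # True
```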

world_models.benchmarks.metrics.iqm_of_array(values)[source]#

Compute the Interquartile Mean (IQM) of an array of values.

IQM is the mean of values that lie between the 25th and 75th percentiles (inclusive). This is a robust central tendency measure used in RL benchmark reporting.

Parameters:

values (Iterable[float])

Return type:

float
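Following the definition above literally (mean of the values between the 25th and 75th percentiles, inclusive), the IQM can be computed with statistics.quantiles using linear interpolation. Some RL reporting tools instead use a 25% trimmed mean, which can differ slightly when quartiles fall between samples or values tie; this sketch is not necessarily the library's exact implementation.

```python
import statistics
from typing import Iterable


def iqm(values: Iterable[float]) -> float:
    """Mean of the values lying within [Q1, Q3], inclusive (interpolated quartiles)."""
    data = sorted(values)
    if not data:
        raise ValueError("iqm() needs at least one value")
    if len(data) < 2:
        return float(data[0])
    q1, _, q3 = statistics.quantiles(data, n=4, method="inclusive")
    inner = [v for v in data if q1 <= v <= q3]
    return statistics.fmean(inner)


print(iqm([1, 2, 3, 4, 5, 6, 7, 8]))  # 4.5
print(iqm([1, 2, 3, 4, 100]))         # 3.0, the outlier 100 is excluded
```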

world_models.benchmarks.metrics.bootstrap_iqm_ci(values, num_samples=1000, alpha=0.05)[source]#

Bootstrap a confidence interval for the IQM.

Returns (lower, upper) percentiles of the bootstrap IQM distribution.

Parameters:
  • values (List[float])

  • num_samples (int)

  • alpha (float)

Return type:

tuple[float, float]

world_models.benchmarks.reporting.export_csv(results, path)[source]#
Parameters:
  • results (Dict[str, Any])

  • path (str)

Return type:

None

world_models.benchmarks.reporting.export_markdown(results, path)[source]#
Parameters:
  • results (Dict[str, Any])

  • path (str)

Return type:

None

world_models.benchmarks.reporting.export_latex(results, path, caption='Benchmark results')[source]#
Parameters:
  • results (Dict[str, Any])

  • path (str)

  • caption (str)

Return type:

None

Utilities#

Loss functions for World Models training.

This module provides loss functions for training VAE and other world model components.

world_models.losses.convae_loss.conv_vae_loss_fn(reconst, x, mu, logsigma)[source]#

Compute the ConvVAE loss function.

The loss combines:
  1. Reconstruction loss (MSE) between the input and reconstructed images.

  2. KL divergence between the learned latent distribution and the standard normal prior.

The total loss is the sum of the two terms: reconstruction loss + KLD.

Parameters:
  • reconst (Tensor) – Reconstructed images from the VAE decoder.

  • x (Tensor) – Original input images.

  • mu (Tensor) – Mean of the latent distribution.

  • logsigma (Tensor) – Log variance of the latent distribution.

Returns:

Scalar tensor containing the total VAE loss.

Return type:

Tensor

Example

>>> recon_x, mu, logsigma = vae(images)
>>> loss = conv_vae_loss_fn(recon_x, images, mu, logsigma)
>>> loss.backward()
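Both terms have closed forms. For a diagonal Gaussian with mean mu and log-variance logvar, the KL divergence to N(0, I) is 0.5 * sum(exp(logvar) + mu^2 - 1 - logvar). The sketch below computes it element-wise in plain Python, taking log-variance as the parameter description above states; if logsigma actually holds the log standard deviation, pass 2 * logsigma instead. Whether the library sums or averages each term is not stated, so both helpers simply sum.

```python
import math
from typing import Sequence


def kl_to_standard_normal(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dimensions."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, logvar))


def squared_error(recon: Sequence[float], x: Sequence[float]) -> float:
    """Summed squared error, the MSE reconstruction term before any averaging."""
    return sum((r - t) ** 2 for r, t in zip(recon, x))


print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # 0.0 (identical to the prior)
print(kl_to_standard_normal([1.0], [0.0]))            # 0.5
```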

Gaussian Mixture Model (GMM) loss for MDRNN training.

This module provides the GMM loss function used in the Mixture Density Recurrent Neural Network (MDRNN) for world model training.

world_models.losses.gmm_loss.gmm_loss(latent_next_obs, mus, sigmas, logpi, reduce=True)[source]#

Compute the negative log-likelihood of a batch under a Gaussian Mixture Model.

This function computes minus the log probability of the batch under the GMM model described by mus, sigmas, and pi:

\[p(x) = \sum_k \pi_k \cdot \mathcal{N}(x \mid \mu_k, \sigma_k)\]

This is the loss function used in the MDRNN paper for predicting the next latent state.

Parameters:
  • latent_next_obs (Tensor) – (bs1, bs2, …, fs) Tensor containing the batch of target data.

  • mus (Tensor) – (bs1, bs2, …, gs, fs) Tensor of mixture means.

  • sigmas (Tensor) – (bs1, bs2, …, gs, fs) Tensor of mixture standard deviations.

  • logpi (Tensor) – (bs1, bs2, …, gs) Tensor of log mixture weights (log pi_k).

  • reduce (bool) – If True, mean over batch dimensions; otherwise return per-sample loss.

Returns:

If reduce is True, a scalar tensor with the mean negative log-likelihood. If reduce is False, a tensor of per-sample negative log-likelihoods.

Return type:

Tensor

Reference:

Ha & Schmidhuber (2018). Recurrent World Models Facilitate Policy Evolution.

Example

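>>> import torch
>>> from world_models.losses.gmm_loss import gmm_loss
>>> batch = torch.randn(32, 10, 10)                     # (bs1, bs2, fs)
>>> mus = torch.randn(32, 10, 5, 10)                    # (bs1, bs2, gs, fs)
>>> sigmas = torch.randn(32, 10, 5, 10).exp()           # positive standard deviations
>>> logpi = torch.randn(32, 10, 5).log_softmax(dim=-1)  # (bs1, bs2, gs)
>>> loss = gmm_loss(batch, mus, sigmas, logpi)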
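For a single sample, the negative log-likelihood above is -log sum_k pi_k N(x | mu_k, sigma_k), with each Gaussian factorising over features. Computing it through log-sum-exp (subtracting the largest log term before exponentiating) avoids underflow when densities are tiny. This plain-Python sketch handles one sample and diagonal covariances only; the library function works on batched tensors with broadcasting.

```python
import math
from typing import Sequence


def gmm_nll(x: Sequence[float], mus: Sequence[Sequence[float]],
            sigmas: Sequence[Sequence[float]], logpi: Sequence[float]) -> float:
    """-log sum_k pi_k N(x | mu_k, diag(sigma_k^2)) for one sample, via log-sum-exp."""
    log_terms = []
    for mu_k, sig_k, lp_k in zip(mus, sigmas, logpi):
        # log N(x | mu, sigma) is a sum over the feature dimension.
        log_n = sum(-0.5 * ((xi - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
                    for xi, m, s in zip(x, mu_k, sig_k))
        log_terms.append(lp_k + log_n)
    top = max(log_terms)  # subtract the max before exponentiating for stability
    return -(top + math.log(sum(math.exp(t - top) for t in log_terms)))


# One unit-variance component at the origin: NLL = 0.5 * log(2*pi) per dimension.
print(round(gmm_nll([0.0], [[0.0]], [[1.0]], [0.0]), 4))  # 0.9189
```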

Training utilities for World Models.

This module provides utility classes for training neural networks including early stopping and learning rate scheduling.

class world_models.utils.train_utils.EarlyStopping(mode='min', patience=10, threshold=0.0001, threshold_mode='rel')[source]#

Bases: object

Early stopping handler to stop training when validation metric stops improving.

This class monitors a validation metric and stops training when no improvement is seen for a specified number of epochs (patience). This helps prevent overfitting and reduces unnecessary computation.

Parameters:
  • mode (str) – One of ‘min’ or ‘max’. In ‘min’ mode, training stops when the metric stops decreasing; in ‘max’ mode, when it stops increasing.

  • patience (int) – Number of epochs with no improvement after which to stop training.

  • threshold (float) – Minimum change to qualify as an improvement.

  • threshold_mode (str) – One of ‘rel’ or ‘abs’. In ‘rel’ mode, dynamic threshold is relative to best value; in ‘abs’ mode, it’s absolute.

stop#

Property that returns True if training should stop.

Example

>>> early_stopping = EarlyStopping(mode='min', patience=10)
>>> for epoch in range(100):
...     val_loss = validate()
...     early_stopping.step(val_loss)
...     if early_stopping.stop:
...         print(f"Stopped at epoch {epoch}")
...         break
step(metrics, epoch=None)[source]#

Update early stopping state with new metric value.

Parameters:
  • metrics (float) – Current epoch’s metric value.

  • epoch (int | None) – Current epoch number. If None, auto-increments from last epoch.

Return type:

None

property stop: bool#

True if training should stop due to no improvement.

Type:

bool

state_dict()[source]#

Get state dictionary for checkpointing.

Returns:

Dictionary containing early stopping state.

Return type:

dict

load_state_dict(state_dict)[source]#

Load state from checkpoint.

Parameters:

state_dict (dict) – Dictionary containing early stopping state.

Return type:

None
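The mode, threshold, and threshold_mode parameters combine into a single "is this an improvement?" test, and patience is a count of consecutive non-improving steps. The sketch below follows the comparison rules used by PyTorch's ReduceLROnPlateau; treating "stop" as the count exceeding patience is an assumption about this class, and PlateauTracker is a hypothetical name.

```python
import math


class PlateauTracker:
    """Counts epochs without improvement using mode/threshold/threshold_mode rules."""

    def __init__(self, mode="min", patience=10, threshold=1e-4, threshold_mode="rel"):
        self.mode, self.patience = mode, patience
        self.threshold, self.threshold_mode = threshold, threshold_mode
        self.best = math.inf if mode == "min" else -math.inf
        self.num_bad_epochs = 0

    def is_better(self, value: float) -> bool:
        if self.mode == "min":
            bound = (self.best * (1 - self.threshold) if self.threshold_mode == "rel"
                     else self.best - self.threshold)
            return value < bound
        bound = (self.best * (1 + self.threshold) if self.threshold_mode == "rel"
                 else self.best + self.threshold)
        return value > bound

    def step(self, value: float) -> None:
        if self.is_better(value):
            self.best, self.num_bad_epochs = value, 0
        else:
            self.num_bad_epochs += 1

    @property
    def stop(self) -> bool:
        return self.num_bad_epochs > self.patience  # assumption: strictly exceeds


tracker = PlateauTracker(mode="min", patience=2)
for loss in [1.0, 0.9, 0.95, 0.93, 0.91]:
    tracker.step(loss)
print(tracker.stop)  # True: three non-improving epochs after the best of 0.9
```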

class world_models.utils.train_utils.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, threshold=0.0001, threshold_mode='rel', min_lr=0, eps=1e-08)[source]#

Bases: object

Reduce learning rate when a metric stops improving.

This scheduler reduces the learning rate by a factor when a validation metric stops improving for a specified number of epochs. This helps models converge better by reducing the step size as they approach optimal weights.

Parameters:
  • optimizer (Optimizer) – The PyTorch optimizer to adjust.

  • mode (str) – One of ‘min’ or ‘max’. In ‘min’ mode, lr is reduced when metric stops decreasing; in ‘max’ mode, when it stops increasing.

  • factor (float) – Factor by which to reduce the learning rate.

  • patience (int) – Number of epochs with no improvement after which to reduce lr.

  • threshold (float) – Minimum change to qualify as an improvement.

  • threshold_mode (str) – One of ‘rel’ or ‘abs’.

  • min_lr (float) – Minimum learning rate to reduce to.

  • eps (float) – Minimal decay applied to lr. If the difference between the old and new lr is smaller than eps, the update is ignored.

lr#

Current learning rates for each parameter group.

Example

>>> optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
>>> scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
>>> for epoch in range(100):
...     train_loss = train()
...     val_loss = validate()
...     scheduler.step(val_loss)
step(metrics, epoch=None)[source]#

Update learning rate based on metric value.

Parameters:
  • metrics (float) – Current epoch’s metric value.

  • epoch (int | None) – Current epoch number. If None, auto-increments from last epoch.

Return type:

None

property lr: list#

Current learning rates for each parameter group.

Type:

list

state_dict()[source]#

Get state dictionary for checkpointing.

Returns:

Dictionary containing scheduler state.

Return type:

dict

load_state_dict(state_dict)[source]#

Load state from checkpoint.

Parameters:

state_dict (dict) – Dictionary containing scheduler state.

Return type:

None
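The least obvious part of the scheduler is the improvement test controlled by mode, threshold, and threshold_mode. The sketch below is a simplified, hypothetical `PlateauSketch` (not the torchwm class). It assumes torchwm follows the conventions of `torch.optim.lr_scheduler.ReduceLROnPlateau` (for example, in 'min'/'rel' mode a value counts as an improvement only if it is below best * (1 - threshold)), and it tracks a plain list of learning rates instead of optimizer parameter groups.

```python
import math


class PlateauSketch:
    """Hypothetical plateau detector; assumes PyTorch-style threshold rules."""

    def __init__(self, lrs, mode="min", factor=0.1, patience=10,
                 threshold=1e-4, threshold_mode="rel", min_lr=0.0, eps=1e-8):
        self.lrs = list(lrs)
        self.mode, self.factor, self.patience = mode, factor, patience
        self.threshold, self.threshold_mode = threshold, threshold_mode
        self.min_lr, self.eps = min_lr, eps
        self.best = math.inf if mode == "min" else -math.inf
        self.num_bad_epochs = 0

    def is_better(self, a):
        # 'rel' scales the required margin by the best value; 'abs' uses a fixed margin.
        if self.mode == "min" and self.threshold_mode == "rel":
            return a < self.best * (1.0 - self.threshold)
        if self.mode == "min":
            return a < self.best - self.threshold
        if self.threshold_mode == "rel":
            return a > self.best * (1.0 + self.threshold)
        return a > self.best + self.threshold

    def step(self, metric):
        if self.is_better(metric):
            self.best, self.num_bad_epochs = metric, 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            # Reduce each lr, never below min_lr, and skip changes smaller than eps.
            for i, old in enumerate(self.lrs):
                new = max(old * self.factor, self.min_lr)
                if old - new > self.eps:
                    self.lrs[i] = new
            self.num_bad_epochs = 0


sched = PlateauSketch([0.1], factor=0.5, patience=2)
for val_loss in [1.0, 1.0, 1.0, 1.0]:
    sched.step(val_loss)
# The first value improves on +inf; the next three do not, which exceeds
# patience=2, so the learning rate is halved on the fourth step.
```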

world_models.utils.dreamer_utils.symlog(x)[source]#

Symmetric log transform used by Dreamer V2 for reward/value targets.

Defined as sign(x) * log(1 + |x|). This compresses large positive or negative values into a range that is easier to predict with a categorical distribution over a bounded set of buckets.

Parameters:

x (Tensor)

Return type:

Tensor

world_models.utils.dreamer_utils.symexp(x)[source]#

Inverse of symlog().

Defined as sign(x) * (exp(|x|) - 1).

Parameters:

x (Tensor)

Return type:

Tensor
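The two formulas translate directly into tensor operations. This is a minimal sketch written from the definitions above, not the library source; `torch.log1p` and `torch.expm1` are used because they stay accurate near zero.

```python
import torch


def symlog(x: torch.Tensor) -> torch.Tensor:
    # sign(x) * log(1 + |x|): roughly linear near 0, logarithmic for large |x|.
    return torch.sign(x) * torch.log1p(torch.abs(x))


def symexp(x: torch.Tensor) -> torch.Tensor:
    # sign(x) * (exp(|x|) - 1): the inverse of symlog.
    return torch.sign(x) * torch.expm1(torch.abs(x))


x = torch.tensor([-1000.0, -1.0, 0.0, 1.0, 1000.0])
y = symlog(x)
# Magnitudes of 1000 shrink to log(1001), about 6.9, while signs and zero are kept.
```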

class world_models.utils.dreamer_utils.TwoHotEncoder(num_buckets=255, symlog_range=10.0)[source]#

Bases: object

Two-hot encoding for symlog targets (Dreamer V2 reward/value heads).

A target value is softly assigned to the two nearest buckets on a uniform grid spanning [-symlog_range, symlog_range]. The categorical logits produced by a network can then be decoded back into a real value by computing the expected bucket center.

Parameters:
  • num_buckets (int) – Number of buckets in the categorical distribution.

  • symlog_range (float) – Maximum absolute value (in symlog space) covered by the grid. Values outside the range are clipped to the boundary buckets.

register_buffers()[source]#

Allocate the bucket-center buffer on CPU. Use to() to move.

Return type:

None

to(device)[source]#
Parameters:

device (device)

Return type:

TwoHotEncoder

encode(target)[source]#

Two-hot encode a real-valued target into soft bucket probabilities.

Parameters:

target (Tensor) – Tensor of arbitrary shape containing real-valued targets.

Returns:

Tensor with an extra final dimension of size num_buckets containing the soft two-hot distribution. The encoding assumes the target is already in symlog space, matching Dreamer V2.

Return type:

Tensor

decode(logits)[source]#

Decode categorical logits into the expected real-valued prediction.

The logits are first softmaxed and then combined with the bucket centers. The output is passed through symexp() to invert the symlog transform.

Parameters:

logits (Tensor) – Tensor with a final dimension of num_buckets.

Returns:

Tensor with the same shape as logits minus the last dimension.

Return type:

Tensor
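Two-hot encoding and expected-value decoding can be illustrated in isolation. The sketch below follows the description above under two assumptions: the grid is `torch.linspace(-symlog_range, symlog_range, num_buckets)`, and targets passed to the encoder are already in symlog space. The helper names `make_centers`, `two_hot`, and `decode_symlog` are hypothetical and are not part of the TwoHotEncoder API.

```python
import torch


def make_centers(num_buckets: int = 255, symlog_range: float = 10.0) -> torch.Tensor:
    # Uniform grid of bucket centers over [-symlog_range, symlog_range].
    return torch.linspace(-symlog_range, symlog_range, num_buckets)


def two_hot(target: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
    # Clip to the grid, then locate the bucket at or just below each target.
    target = target.clamp(centers[0].item(), centers[-1].item())
    idx = torch.bucketize(target, centers, right=True) - 1
    lo = idx.clamp(0, len(centers) - 2)
    hi = lo + 1
    # Weight on the upper neighbour grows linearly as the target approaches it.
    w_hi = (target - centers[lo]) / (centers[hi] - centers[lo])
    out = torch.zeros(*target.shape, len(centers), dtype=target.dtype, device=target.device)
    out.scatter_(-1, lo.unsqueeze(-1), (1 - w_hi).unsqueeze(-1))
    out.scatter_add_(-1, hi.unsqueeze(-1), w_hi.unsqueeze(-1))
    return out


def decode_symlog(logits: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
    # Expected bucket center under softmax(logits), then invert symlog.
    y = (torch.softmax(logits, dim=-1) * centers).sum(-1)
    return torch.sign(y) * torch.expm1(torch.abs(y))


centers = make_centers()
t = torch.tensor([0.3, -2.7, 10.0, 15.0])  # symlog-space targets; 15.0 is clipped
enc = two_hot(t, centers)
```

Because each row puts weight only on two adjacent centers, the expected center of the encoding reproduces the (clipped) target exactly, which is what makes the decode step a consistent inverse.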

world_models.utils.dreamer_utils.get_parameters(modules)[source]#

Given an iterable of torch modules, return a list of all their parameters.

Parameters:

modules (Iterable[Module])

Return type:

list[Parameter]

class world_models.utils.dreamer_utils.FreezeParameters(modules)[source]#

Bases: object

Context manager that temporarily disables gradients for given modules.

Useful during imagination or target-network forward passes where gradients through certain components should be blocked for speed and correctness.

Parameters:

modules (Iterable[Module])
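The pattern behind get_parameters and FreezeParameters can be sketched as follows. This is a minimal, hypothetical reimplementation (`get_parameters_sketch`, `FreezeParametersSketch`), assuming the context manager saves each parameter's `requires_grad` flag on entry and restores it on exit. The usage mirrors Dreamer's actor update: gradients still flow through the frozen world model to the actor, but the world model's own weights receive none.

```python
from itertools import chain
from typing import Iterable, List

import torch
from torch import nn


def get_parameters_sketch(modules: Iterable[nn.Module]) -> List[nn.Parameter]:
    # Flatten the parameters of every module into a single list.
    return list(chain.from_iterable(m.parameters() for m in modules))


class FreezeParametersSketch:
    """Hypothetical context manager: disable grads on entry, restore on exit."""

    def __init__(self, modules: Iterable[nn.Module]):
        self.params = get_parameters_sketch(modules)
        self.saved: List[bool] = []

    def __enter__(self):
        self.saved = [p.requires_grad for p in self.params]
        for p in self.params:
            p.requires_grad_(False)
        return self

    def __exit__(self, exc_type, exc, tb):
        # Restore the original flags even if the body raised.
        for p, flag in zip(self.params, self.saved):
            p.requires_grad_(flag)
        return False


torch.manual_seed(0)
actor = nn.Linear(4, 4)
world_model = nn.Linear(4, 1)
x = torch.randn(8, 4)
with FreezeParametersSketch([world_model]):
    # world_model's weights are excluded from the graph built here.
    loss = world_model(actor(x)).sum()
loss.backward()
```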

class world_models.utils.dreamer_utils.Logger(log_dir, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', video_format='gif', video_fps=20, enable_tensorboard=False, enable_console=True, enable_jsonl=True, jsonl_filename='metrics.jsonl')[source]#

Bases: object

Experiment logger for scalar metrics and video rollouts (GIF by default), with optional console, JSONL, TensorBoard, and WandB outputs.

Provides helpers to write scalar metrics, dump pickle snapshots, and save video previews during Dreamer training/evaluation.

Parameters:
  • log_dir (str)

  • enable_wandb (bool)

  • wandb_api_key (str)

  • wandb_project (str)

  • wandb_entity (str)

  • video_format (str)

  • video_fps (int)

  • enable_tensorboard (bool)

  • enable_console (bool)

  • enable_jsonl (bool)

  • jsonl_filename (str)

log_scalar(scalar, name, step_)[source]#
Parameters:
  • scalar (Any)

  • name (str)

  • step_ (int)

Return type:

None

log_scalars(scalar_dict, step)[source]#
Parameters:
  • scalar_dict (dict[str, Any])

  • step (int)

Return type:

None

log_videos(videos, step, max_videos_to_save=1, fps=None, video_title='video')[source]#
Parameters:
  • videos (Any)

  • step (int)

  • max_videos_to_save (int)

  • fps (int | None)

  • video_title (str)

Return type:

None

dump_scalars_to_pickle(metrics, step, log_title=None)[source]#
Parameters:
  • metrics (dict[str, Any])

  • step (int)

  • log_title (str | None)

Return type:

None

flush()[source]#
Return type:

None

world_models.utils.dreamer_utils.compute_return(rewards, values, discounts, td_lam, last_value)[source]#

Compute TD(lambda) returns from imagined rewards, values, and discounts.

Implements backward recursion used by Dreamer actor/value objectives.

Parameters:
  • rewards (Tensor)

  • values (Tensor)

  • discounts (Tensor)

  • td_lam (float)

  • last_value (Tensor)

Return type:

Tensor
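The backward recursion can be written out as a short loop. This sketch assumes time-major inputs of shape (T, ...), a `last_value` with the shape of a single time step used as the bootstrap for both the final value and the final return, and the Dreamer-style recursion shown in the docstring; `lambda_returns` is a hypothetical name and the library's exact indexing may differ.

```python
import torch


def lambda_returns(rewards, values, discounts, td_lam, last_value):
    """Hypothetical TD(lambda) sketch over time-major tensors of shape (T, ...).

    R_t = r_t + g_t * ((1 - lam) * v_{t+1} + lam * R_{t+1}),  with R_T = v_T = last_value
    """
    # v_{t+1} for every t, bootstrapping the last step with last_value.
    next_values = torch.cat([values[1:], last_value.unsqueeze(0)], dim=0)
    returns = []
    acc = last_value
    for t in reversed(range(rewards.shape[0])):
        acc = rewards[t] + discounts[t] * ((1 - td_lam) * next_values[t] + td_lam * acc)
        returns.append(acc)
    # Returns were built from the last step backwards; restore time order.
    return torch.stack(returns[::-1], dim=0)


r = torch.tensor([1.0, 1.0])
v = torch.tensor([0.0, 0.0])
g = torch.tensor([0.5, 0.5])
last = torch.tensor(2.0)
```

With td_lam = 1 the result is the discounted Monte Carlo return bootstrapped from last_value; with td_lam = 0 it reduces to the one-step target r_t + g_t * v_{t+1}.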

world_models.utils.jepa_utils.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2, b=2.0)[source]#

Initialize a tensor in-place from a truncated normal distribution.

Values are drawn from N(mean, std) truncated to the interval [a, b], so every sample lies in [a, b].

Parameters:
  • tensor (Tensor)

  • mean (float)

  • std (float)

  • a (float)

  • b (float)

Return type:

Tensor
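A truncated normal is usually sampled with the inverse-CDF method rather than by clipping, which would pile mass onto the boundaries. The sketch below shows that method; it is an assumption about the technique, not a copy of the torchwm code. PyTorch also ships `torch.nn.init.trunc_normal_` with the same `mean`, `std`, `a`, `b` parameters.

```python
import math

import torch


def trunc_normal_sketch_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0,
                         a: float = -2.0, b: float = 2.0) -> torch.Tensor:
    def norm_cdf(z: float) -> float:
        # Standard normal CDF.
        return (1.0 + math.erf(z / math.sqrt(2.0))) / 2.0

    with torch.no_grad():
        lo = norm_cdf((a - mean) / std)
        hi = norm_cdf((b - mean) / std)
        # Sample u in [2*lo - 1, 2*hi - 1]; erfinv(u) * sqrt(2) is then a
        # standard normal restricted to [(a - mean)/std, (b - mean)/std].
        tensor.uniform_(2 * lo - 1, 2 * hi - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.0)).add_(mean)
        # Guard against floating-point spill just outside [a, b].
        tensor.clamp_(min=a, max=b)
    return tensor


torch.manual_seed(0)
w = trunc_normal_sketch_(torch.empty(10000))
```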

world_models.utils.jepa_utils.repeat_interleave_batch(x, B, repeat)[source]#

Split x along its first dimension into consecutive chunks of size B and repeat each chunk repeat times, preserving the order of the chunks.

Used in JEPA masking code to align context and target token batches.

Parameters:
  • x (Tensor)

  • B (int)

  • repeat (int)

Return type:

Tensor
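The chunk-wise repetition differs from `torch.repeat_interleave`, which repeats individual elements. A minimal sketch, assuming the first dimension of x is a multiple of B; `repeat_chunks` is a hypothetical name.

```python
import torch


def repeat_chunks(x: torch.Tensor, B: int, repeat: int) -> torch.Tensor:
    # Split dim 0 into consecutive chunks of size B and emit each chunk
    # `repeat` times before moving on to the next chunk.
    chunks = x.split(B, dim=0)
    return torch.cat([c for c in chunks for _ in range(repeat)], dim=0)


x = torch.arange(4)
y = repeat_chunks(x, B=2, repeat=2)
```

Here `y` is `[0, 1, 0, 1, 2, 3, 2, 3]`, whereas `torch.repeat_interleave(x, 2)` gives `[0, 0, 1, 1, 2, 2, 3, 3]`.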

class world_models.utils.jepa_utils.WarmupCosineSchedule(optimizer, warmup_steps, start_lr, ref_lr, T_max, last_epoch=-1, final_lr=0.0)[source]#

Bases: object

Learning-rate schedule with linear warmup followed by cosine decay.

Updates optimizer parameter-group LRs on each call to step().

Parameters:
  • optimizer (Optimizer)

  • warmup_steps (int)

  • start_lr (float)

  • ref_lr (float)

  • T_max (int)

  • last_epoch (int)

  • final_lr (float)

step()[source]#
Return type:

float
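The schedule's shape is easiest to check as a pure function of the step count. This sketch assumes T_max counts the steps after warmup and that the learning rate stays at final_lr once the decay finishes; `warmup_cosine_lr` is a hypothetical helper, and the loop only illustrates how a value is written into every optimizer parameter group.

```python
import math

import torch


def warmup_cosine_lr(step, warmup_steps, start_lr, ref_lr, T_max, final_lr=0.0):
    if step < warmup_steps:
        # Linear ramp from start_lr up to ref_lr.
        progress = step / max(1, warmup_steps)
        return start_lr + progress * (ref_lr - start_lr)
    # Cosine decay from ref_lr down to final_lr over T_max post-warmup steps.
    progress = min(1.0, (step - warmup_steps) / max(1, T_max))
    return final_lr + (ref_lr - final_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


model = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(model.parameters(), lr=1e-4)
for step in range(1, 11):
    lr = warmup_cosine_lr(step, warmup_steps=3, start_lr=1e-4, ref_lr=1e-3, T_max=7)
    for group in opt.param_groups:
        group["lr"] = lr  # every parameter group receives the same value
```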

class world_models.utils.jepa_utils.CosineWDSchedule(optimizer, ref_wd, T_max, final_wd=0.0)[source]#

Bases: object

Cosine scheduler for optimizer weight decay values.

Skips parameter groups flagged with WD_exclude to keep bias/norm decay at zero.

Parameters:
  • optimizer (Optimizer)

  • ref_wd (float)

  • T_max (int)

  • final_wd (float)

step()[source]#
Return type:

float
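The WD_exclude convention relies on the fact that PyTorch optimizer parameter groups are plain dicts that keep extra keys. The sketch below assumes the same cosine formula as the learning-rate schedule and a group-level `WD_exclude` flag; `cosine_wd` and `apply_wd` are hypothetical helpers.

```python
import math

import torch


def cosine_wd(step, ref_wd, T_max, final_wd=0.0):
    # Cosine interpolation from ref_wd (step 0) to final_wd (step T_max).
    progress = min(1.0, step / max(1, T_max))
    return final_wd + (ref_wd - final_wd) * 0.5 * (1.0 + math.cos(math.pi * progress))


def apply_wd(optimizer, wd):
    for group in optimizer.param_groups:
        # Groups flagged WD_exclude (e.g. biases, norm weights) keep their value.
        if not group.get("WD_exclude", False):
            group["weight_decay"] = wd


model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW([
    {"params": [model.weight]},
    {"params": [model.bias], "WD_exclude": True, "weight_decay": 0.0},
])
apply_wd(opt, cosine_wd(step=50, ref_wd=0.04, T_max=100, final_wd=0.4))
```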

world_models.utils.jepa_utils.gpu_timer(closure, log_timings=True)[source]#

Measure CUDA execution time for a closure and return (result, elapsed_ms).

Returns an elapsed time of -1 when CUDA timing is unavailable.

Parameters:
  • closure (Any)

  • log_timings (bool)

Return type:

Tuple[Any, float]
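Wall-clock timing of GPU work is misleading because CUDA kernels launch asynchronously; recording CUDA events and synchronizing before reading them measures the device-side duration. A minimal sketch of that pattern, assuming the -1 fallback also applies when log_timings is False; `gpu_timer_sketch` is a hypothetical name.

```python
from typing import Any, Callable, Tuple

import torch


def gpu_timer_sketch(closure: Callable[[], Any], log_timings: bool = True) -> Tuple[Any, float]:
    use_cuda = log_timings and torch.cuda.is_available()
    elapsed_ms = -1.0
    if use_cuda:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    result = closure()
    if use_cuda:
        end.record()
        # Wait for queued kernels to finish before reading the events.
        torch.cuda.synchronize()
        elapsed_ms = start.elapsed_time(end)
    return result, elapsed_ms


result, ms = gpu_timer_sketch(lambda: 2 + 3, log_timings=False)
```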

class world_models.utils.jepa_utils.CSVLogger(fname, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', *argv)[source]#

Bases: object

Lightweight CSV logger with per-column printf-style formatting and WandB support.

Parameters:
  • fname (str)

  • enable_wandb (bool)

  • wandb_api_key (str)

  • wandb_project (str)

  • wandb_entity (str)

  • argv (Any)

log(step, *argv)[source]#
Parameters:
  • step (int)

  • argv (Any)

Return type:

None

class world_models.utils.jepa_utils.AverageMeter[source]#

Bases: object

Track running statistics (val, avg, min, max, sum, count) for metrics.

reset()[source]#
Return type:

None

update(val, n=1)[source]#
Parameters:
  • val (float)

  • n (int)

Return type:

None
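A running meter of this kind typically treats each `update(val, n)` as the mean of n samples, so the average is weighted by n. The hypothetical `AverageMeterSketch` below illustrates that assumption; the torchwm class may differ in details.

```python
class AverageMeterSketch:
    """Hypothetical running-statistics tracker (val, avg, min, max, sum, count)."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0
        self.min = float("inf")
        self.max = float("-inf")

    def update(self, val: float, n: int = 1) -> None:
        # `val` is treated as the mean of `n` samples, hence the n-weighting.
        self.val = val
        self.min = min(self.min, val)
        self.max = max(self.max, val)
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = AverageMeterSketch()
meter.update(2.0)
meter.update(4.0, n=3)
```

After these two calls the sum is 14.0 over 4 samples, so `meter.avg` is 3.5.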

world_models.utils.jepa_utils.grad_logger(named_params)[source]#

Aggregate gradient norm statistics over model parameters for logging.

Also exposes first/last qkv-layer gradient norms when available.

Parameters:

named_params (Any)

Return type:

AverageMeter

world_models.utils.jepa_utils.init_distributed(port=40112, rank_and_world_size=(None, None))[source]#

Initialize torch distributed process groups when the environment supports it.

Returns (world_size, rank) and gracefully falls back to single-process mode.

Parameters:
  • port (int)

  • rank_and_world_size (tuple)

Return type:

Tuple[int, int]

class world_models.utils.jepa_utils.AllGather(*args, **kwargs)[source]#

Bases: Function

Autograd-aware all-gather operation across distributed workers.

Forward concatenates worker tensors; backward reduces and slices gradients.

static forward(ctx, x)[source]#
Parameters:
  • ctx (Any)

  • x (Tensor)

Return type:

Tensor

static backward(ctx, grads)[source]#
Parameters:
  • ctx (Any)

  • grads (Tensor)

Return type:

Tensor

class world_models.utils.jepa_utils.AllReduceSum(*args, **kwargs)[source]#

Bases: Function

Autograd function that sums tensors across distributed workers in the forward pass.

static forward(ctx, x)[source]#
Parameters:
  • ctx (Any)

  • x (Tensor)

Return type:

Tensor

static backward(ctx, grads)[source]#
Parameters:
  • ctx (Any)

  • grads (Tensor)

Return type:

Tensor

class world_models.utils.jepa_utils.AllReduce(*args, **kwargs)[source]#

Bases: Function

Autograd function that all-reduces and averages tensors across workers.

Used to synchronize scalar losses for consistent distributed logging/training.

static forward(ctx, x)[source]#
Parameters:
  • ctx (Any)

  • x (Tensor)

Return type:

Tensor

static backward(ctx, grads)[source]#
Parameters:
  • ctx (Any)

  • grads (Tensor)

Return type:

Tensor

world_models.utils.data_utils.create_efficient_dataloader(dataset, batch_size, num_workers=None, pin_memory=True, prefetch_factor=2, persistent_workers=True)[source]#

Create a DataLoader configured for throughput, using worker processes, batch prefetching, pinned memory, and persistent workers.

Parameters:
  • dataset (Dataset)

  • batch_size (int)

  • num_workers (int | None)

  • pin_memory (bool)

  • prefetch_factor (int)

  • persistent_workers (bool)

Return type:

DataLoader
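One practical detail when building such a loader: in PyTorch, `persistent_workers=True` (and, in recent releases, any `prefetch_factor`) is only valid with `num_workers > 0`. The sketch below guards those options accordingly. The name `make_loader`, the CPU-count default for num_workers, and disabling pinning without CUDA are assumptions, not documented torchwm behavior.

```python
import os
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


def make_loader(dataset: Dataset, batch_size: int, num_workers: Optional[int] = None,
                pin_memory: bool = True, prefetch_factor: int = 2,
                persistent_workers: bool = True) -> DataLoader:
    if num_workers is None:
        # Hypothetical default: one worker per CPU core, capped at 8.
        num_workers = min(8, os.cpu_count() or 1)
    kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned host memory only speeds up host-to-GPU copies.
        pin_memory=pin_memory and torch.cuda.is_available(),
    )
    if num_workers > 0:
        # These options are rejected when loading in the main process.
        kwargs["prefetch_factor"] = prefetch_factor
        kwargs["persistent_workers"] = persistent_workers
    return DataLoader(dataset, **kwargs)


ds = TensorDataset(torch.arange(10).float())
loader = make_loader(ds, batch_size=4, num_workers=0)
```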

world_models.utils.data_utils.prefetch_iterator(iterator, buffer_size=3)[source]#

Wrap any iterator so that up to buffer_size items are fetched ahead of the consumer.

Parameters:
  • iterator (Iterator)

  • buffer_size (int)

Return type:

Iterator
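One common way to prefetch from an iterator is a background thread feeding a bounded queue. The sketch below assumes that mechanism (torchwm may implement it differently); `prefetch_sketch` is a hypothetical name. Exceptions raised by the source iterator are forwarded to the consumer.

```python
import queue
import threading
from typing import Iterator, TypeVar

T = TypeVar("T")


def prefetch_sketch(iterator: Iterator[T], buffer_size: int = 3) -> Iterator[T]:
    buf: "queue.Queue[tuple]" = queue.Queue(maxsize=buffer_size)

    def producer() -> None:
        try:
            for item in iterator:
                buf.put(("item", item))  # blocks while the buffer is full
        except Exception as exc:  # hand the error to the consumer
            buf.put(("error", exc))
        buf.put(("done", None))

    # Start fetching immediately, before the first item is requested.
    threading.Thread(target=producer, daemon=True).start()

    def consumer() -> Iterator[T]:
        while True:
            kind, payload = buf.get()
            if kind == "done":
                return
            if kind == "error":
                raise payload
            yield payload

    return consumer()


items = list(prefetch_sketch(iter(range(5))))
```

If the consumer stops early, the daemon producer thread simply blocks on the full queue; a production version would need a way to cancel it.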

world_models.utils.jit_utils.jit_compile_function(func)[source]#

JIT compile a function for performance.

Parameters:

func (Callable)

Return type:

Callable

world_models.utils.jit_utils.jit_compile_module(module)[source]#

JIT compile a PyTorch module.

Parameters:

module (Module)

Return type:

Module

world_models.utils.memory_utils.apply_gradient_checkpointing(model, checkpoint_ratio=0.5)[source]#

Apply gradient checkpointing to reduce memory usage during training.

Parameters:
  • model (Module)

  • checkpoint_ratio (float)

Return type:

None
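Gradient checkpointing trades compute for memory: activations inside a checkpointed segment are discarded after the forward pass and recomputed during backward. PyTorch's standard tool for this is `torch.utils.checkpoint.checkpoint`. The hypothetical `CheckpointedBlock` wrapper below checkpoints a single block and does not model how checkpoint_ratio selects modules.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    """Hypothetical wrapper that recomputes `block`'s activations in backward."""

    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Intermediate activations of `block` are not stored; they are
            # recomputed when gradients are needed.
            return checkpoint(self.block, x, use_reentrant=False)
        return self.block(x)


torch.manual_seed(0)
block = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
x = torch.randn(4, 8)
block(x).sum().backward()
g_ref = block[0].weight.grad.clone()
block.zero_grad()
CheckpointedBlock(block)(x).sum().backward()
```

The recomputed backward pass produces the same gradients as the uncheckpointed one.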

world_models.utils.memory_utils.enable_mixed_precision(model, scaler=None)[source]#

Enable mixed precision training.

Parameters:
  • model (Module)

  • scaler (GradScaler | None)

Return type:

GradScaler

world_models.utils.memory_utils.optimize_memory_efficient_ops()[source]#

Configure global PyTorch settings for memory-efficient operations.

Return type:

None

Logging, metrics, and numerical-safety helpers for torchwm.

world_models.utils.logging_utils.get_package_logger(name=None)[source]#

Return a logger under the world_models package namespace.

Parameters:

name (str | None)

Return type:

Logger

world_models.utils.logging_utils.setup_logging(name='world_models', level='INFO', log_file=None, fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s')[source]#

Set up structured package logging with optional file output.

Parameters:
  • name (str) – Logger name to configure. Defaults to the package logger.

  • level (str | int) – Logging level name or numeric value.

  • log_file (str | None) – Optional file path for a file handler.

  • fmt (str) – logging.Formatter format string.

Return type:

Logger

class world_models.utils.logging_utils.MetricsLogger(log_dir, *, logger=None, enable_console=True, enable_jsonl=True, jsonl_filename='metrics.jsonl', enable_tensorboard=False, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', run_name=None)[source]#

Bases: object

Fan-out metric logger for console, JSONL, TensorBoard, and W&B.

JSONL output is enabled by default because it is dependency-free and easy to reload for offline plots. TensorBoard and W&B are optional and activated only when requested and available/configured.

Parameters:
  • log_dir (str)

  • logger (logging.Logger | None)

  • enable_console (bool)

  • enable_jsonl (bool)

  • jsonl_filename (str)

  • enable_tensorboard (bool)

  • enable_wandb (bool)

  • wandb_api_key (str)

  • wandb_project (str)

  • wandb_entity (str)

  • run_name (str | None)

log(metrics, step, prefix=None)[source]#

Log scalar metrics to every enabled sink.

Parameters:
  • metrics (Mapping[str, Any])

  • step (int)

  • prefix (str | None)

Return type:

dict[str, Any]

log_video(name, video, step, fps=20)[source]#

Log a video to TensorBoard and W&B when enabled.

Parameters:
  • name (str)

  • video (Any)

  • step (int)

  • fps (int)

Return type:

None

flush()[source]#
Return type:

None

close()[source]#
Return type:

None
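JSONL stores one JSON object per line, so metrics can be appended during training and reloaded later with only the standard library. The sketch below assumes a hypothetical record layout (a `step` key plus metric names joined to prefix with "/"); the actual MetricsLogger format may differ, and `append_jsonl` and `read_jsonl` are illustrative helpers.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional


def append_jsonl(path: Path, metrics: Mapping[str, Any], step: int,
                 prefix: Optional[str] = None) -> None:
    # Hypothetical layout: {"step": step, "<prefix>/<name>": value, ...}
    record: Dict[str, Any] = {"step": step}
    for key, value in metrics.items():
        record[f"{prefix}/{key}" if prefix else key] = float(value)
    with path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")


def read_jsonl(path: Path) -> List[Dict[str, Any]]:
    # One record per non-empty line, ready for offline plotting.
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


with tempfile.TemporaryDirectory() as d:
    path = Path(d) / "metrics.jsonl"
    append_jsonl(path, {"loss": 1.0}, step=0, prefix="train")
    append_jsonl(path, {"loss": 0.5}, step=1, prefix="train")
    rows = read_jsonl(path)
```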

world_models.utils.logging_utils.collect_system_stats(device=None)[source]#

Collect CPU/GPU memory and CUDA utilization counters when available.

Parameters:

device (device | str | None)

Return type:

dict[str, float]

world_models.utils.logging_utils.assert_finite_values(value, name='value')[source]#

Raise FloatingPointError if any tensor contains NaN or Inf.

Parameters:
  • value (Any)

  • name (str)

Return type:

Any

world_models.utils.logging_utils.assert_finite(fn)[source]#

Decorator that checks that tensor outputs of a loss function are finite.

Parameters:

fn (Any)

Return type:

Any
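The decorator pattern can be sketched with `torch.isfinite`. This hypothetical `assert_finite_sketch` only checks a single returned tensor, whereas the torchwm version may also inspect nested outputs; it raises FloatingPointError to match the behavior documented for assert_finite_values.

```python
import functools

import torch


def assert_finite_sketch(fn):
    """Hypothetical decorator: raise if the returned tensor has NaN or Inf."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)
        if isinstance(out, torch.Tensor) and not torch.isfinite(out).all():
            raise FloatingPointError(f"{fn.__name__} returned a non-finite tensor")
        return out

    return wrapper


@assert_finite_sketch
def log_loss(p: torch.Tensor) -> torch.Tensor:
    # -log(0) is +inf, which the decorator turns into an immediate error.
    return -torch.log(p).mean()
```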

class world_models.utils.utils.AttrDict[source]#

Bases: dict

world_models.utils.utils.load_yml_config(path)[source]#
Parameters:

path (str)

Return type:

AttrDict | None

world_models.utils.utils.to_tensor_obs(image)[source]#

Convert an input NumPy image to a channel-first 64x64 torch tensor.

Parameters:

image (ndarray)

Return type:

Tensor

world_models.utils.utils.postprocess_img(image, depth)[source]#

Postprocess an image observation for storage, converting a float32 NumPy array in [-0.5, 0.5] to a uint8 NumPy array in [0, 255].

Parameters:
  • image (ndarray)

  • depth (int)

Return type:

ndarray

world_models.utils.utils.preprocess_img(image, depth)[source]#

Preprocess an observation in place, converting a float32 tensor in [0, 255] to [-0.5, 0.5]. Note that this also adds noise to the observation.

Parameters:
  • image (Tensor)

  • depth (int)

Return type:

None
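A common PlaNet-style way to implement this pair quantizes pixels to `depth` bits, shifts them to [-0.5, 0.5), and adds uniform dequantization noise within one quantization bin; postprocessing floors back to the bin and rescales to 8 bits. The sketch below assumes that scheme (it is not necessarily torchwm's exact code); `preprocess_sketch_` and `postprocess_sketch` are hypothetical names.

```python
import numpy as np
import torch


def preprocess_sketch_(image: torch.Tensor, depth: int) -> None:
    # Quantize [0, 255] to `depth` bits, then shift to [-0.5, 0.5).
    image.div_(2 ** (8 - depth)).floor_().div_(2 ** depth).sub_(0.5)
    # Dequantization noise: uniform within one quantization bin.
    image.add_(torch.rand_like(image).div_(2 ** depth))


def postprocess_sketch(image: np.ndarray, depth: int) -> np.ndarray:
    # Undo the shift, recover the `depth`-bit level, and rescale to 8 bits.
    image = np.floor((image + 0.5) * 2 ** depth)
    return np.clip(image * 2 ** (8 - depth), 0, 2 ** 8 - 1).astype(np.uint8)


torch.manual_seed(0)
obs = torch.tensor([0.0, 128.0, 255.0])
preprocess_sketch_(obs, depth=5)
```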

world_models.utils.utils.bottle(func, *tensors)[source]#

Evaluate func, which operates on inputs of shape N x D, on inputs of shape N x T x D.

Parameters:
  • func (Any)

  • tensors (Tensor)

Return type:

Tensor
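The usual way to do this is to merge the two leading dimensions, call func once on the flattened batch, and restore them on the output. A minimal sketch under that assumption; `bottle_sketch` is a hypothetical name.

```python
import torch
from torch import nn


def bottle_sketch(func, *tensors: torch.Tensor) -> torch.Tensor:
    # Merge the leading (N, T) dims of every input into one batch dim of N * T.
    n, t = tensors[0].shape[:2]
    flat = [x.reshape(n * t, *x.shape[2:]) for x in tensors]
    y = func(*flat)
    # Restore (N, T) as the leading dims of the output.
    return y.reshape(n, t, *y.shape[1:])


torch.manual_seed(0)
lin = nn.Linear(4, 5)
x = torch.randn(2, 3, 4)
y = bottle_sketch(lin, x)
```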

world_models.utils.utils.get_combined_params(*models)[source]#

Return the combined parameter list of all models given as input.

Parameters:

models (Any)

Return type:

list

world_models.utils.utils.save_video(frames, path, name)[source]#

Saves a video containing frames.

Accepts frames in either:
  • (T, C, H, W) float in [0,1]

  • (T, H, W, C) float in [0,1]

Produces {path}/{name}.mp4 and a debug PNG {path}/{name}_debug_frame.png with per-channel statistics printed to stdout.

Parameters:
  • frames (Any)

  • path (str)

  • name (str)

Return type:

str

world_models.utils.utils.combine_videos(video_dir, output_name='combined.mp4', pattern='vid_*.mp4', fps=25, resize=True)[source]#

Combine all videos matching pattern in video_dir into a single MP4 file and return the output file path as a string.

Example

>>> combine_videos("results/planet", output_name="all_training.mp4")

Parameters:
  • video_dir (str)

  • output_name (str)

  • pattern (str)

  • fps (int)

  • resize (bool)

Return type:

str

world_models.utils.utils.ensure_results_dir_exists(results_dir)[source]#

Check that a results directory exists, raising FileNotFoundError if it does not.

Parameters:

results_dir (str)

Return type:

None

world_models.utils.utils.save_frames(target, pred_prior, pred_posterior, name, n_rows=5)[source]#

Save side-by-side target, prior-prediction, and posterior-prediction frames.

The function accepts tensors with optional time dimension and writes a PNG grid to {name}.png. Spatial sizes are aligned per timestep before concatenation and values are normalized to [0, 1] when needed.

Parameters:
  • target (Tensor)

  • pred_prior (Tensor)

  • pred_posterior (Tensor)

  • name (str)

  • n_rows (int)

Return type:

None

world_models.utils.utils.get_mask(tensor, lengths)[source]#

Build a batch-first validity mask from sequence lengths.

tensor may be a tensor/array with shape (N, T, ...) or (N,). The returned mask marks valid timesteps with ones up to each element in lengths and preserves device/dtype conventions from the input.

Parameters:
  • tensor (Any)

  • lengths (Any)

Return type:

Tensor
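A batch-first length mask is typically built by comparing a time index against each sequence length. In this sketch the maximum length is passed explicitly instead of being read from `tensor`, and the output dtype defaults to float32; both are assumptions, and `length_mask` is a hypothetical name.

```python
import torch


def length_mask(lengths: torch.Tensor, max_len: int, dtype=torch.float32) -> torch.Tensor:
    # Row i holds 1 for t < lengths[i] and 0 afterwards; shape (N, max_len).
    steps = torch.arange(max_len, device=lengths.device)
    return (steps.unsqueeze(0) < lengths.unsqueeze(1)).to(dtype)


mask = length_mask(torch.tensor([1, 3]), max_len=4)
```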

world_models.utils.utils.load_memory(path, device, *, trusted=False)[source]#

Loads an experience replay buffer.

Pickle can execute arbitrary code during unrestricted deserialization, so user-supplied replay buffers are always loaded with a restricted unpickler that only allows the replay buffer classes and numpy containers required by historical buffers. The trusted argument is retained for backwards compatibility, but it no longer enables unrestricted pickle loading.

Converts legacy list/.data formats into the current Memory(episodes) object.

Parameters:
  • path (str)

  • device (device)

  • trusted (bool)

Return type:

Any
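The restricted-unpickler technique described above follows the "Restricting Globals" pattern from the Python `pickle` documentation: override `find_class` so that only allow-listed globals can be resolved. The allow-list below is hypothetical (the real one covers torchwm's replay-buffer classes and NumPy containers), and `restricted_loads` is an illustrative helper.

```python
import collections
import io
import pickle

# Hypothetical allow-list of (module, name) pairs that may be resolved.
ALLOWED = {("collections", "OrderedDict")}


class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module: str, name: str):
        # Any global not on the allow-list (for example os.system) is
        # rejected before it can be referenced or called.
        if (module, name) in ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"blocked global {module}.{name}")


def restricted_loads(data: bytes):
    return RestrictedUnpickler(io.BytesIO(data)).load()


ok = restricted_loads(pickle.dumps(collections.OrderedDict(a=1)))
```

Built-in containers such as dict and list are encoded with dedicated opcodes and never reach `find_class`, so they load without being listed.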

world_models.utils.utils.flatten_dict(data, sep='.', prefix='')[source]#

Flattens a nested dict into a single-level dict.

Example

>>> flatten_dict({'a': 2, 'b': {'c': 20}})
{'a': 2, 'b.c': 20}

Parameters:
  • data (dict)

  • sep (str)

  • prefix (str)

Return type:

dict

world_models.utils.utils.normalize_frames_for_saving(frames)[source]#

Ensure frames have shape (T, H, W, 3) with float values in [0, 1]. Accepts inputs in (T, C, H, W) or (T, H, W, C) layout, repeats a single channel to form RGB, drops an alpha channel if present, and maps [-0.5, 0.5] to [0, 1] when that range is detected.

Parameters:

frames (Any)

Return type:

ndarray

class world_models.utils.utils.TensorBoardMetrics(path)[source]#

Bases: object

Plots and (optionally) stores metrics for an experiment.

Parameters:

path (str)

assign_type(key, val)[source]#
Parameters:
  • key (str)

  • val (Any)

Return type:

None

update(metrics)[source]#
Parameters:

metrics (dict)

Return type:

None

world_models.utils.utils.apply_model(model, inputs, ignore_dim=None)[source]#

Placeholder helper for generic model application across input structures.

Currently not implemented; kept as an extension hook for future utility code.

Parameters:
  • model (Any)

  • inputs (Any)

  • ignore_dim (Any)

Return type:

None

world_models.utils.utils.plot_metrics(metrics, path, prefix)[source]#

Render and save line plots for each metric series in a dictionary.

Parameters:
  • metrics (dict)

  • path (str)

  • prefix (str)

Return type:

None

world_models.utils.utils.lineplot(xs, ys, title, path='', xaxis='episode')[source]#

Create a Plotly line plot for scalar, dict, or ensemble-series data.

Supports uncertainty-band plotting when ys is a 2D array.

Parameters:
  • xs (ndarray | list)

  • ys (Any)

  • title (str)

  • path (str)

  • xaxis (str)

Return type:

None

class world_models.utils.utils.TorchImageEnvWrapper(env, bit_depth, observation_shape=None, act_rep=2)[source]#

Bases: object

Wrapper around a Gym environment that exchanges actions and observations as torch tensors and returns observations as images.

Parameters:
  • env (Any)

  • bit_depth (int)

  • observation_shape (Any)

  • act_rep (int)

reset()[source]#
Return type:

Tensor

step(u)[source]#
Parameters:

u (Tensor | ndarray | list | float | int)

Return type:

tuple

render()[source]#
Return type:

None

close()[source]#
Return type:

None

property observation_size: tuple[int, int, int]#
property action_size: int#
sample_random_action()[source]#
Return type:

Tensor

property max_episode_steps: int#

Return environment max episode steps (compatible with TimeLimit/spec).

world_models.utils.utils.apply_masks(x, masks)[source]#

Gather token subsets from patch sequences using index masks.

Each mask selects token positions from x; selected groups are concatenated along the batch dimension.

Parameters:
  • x (Tensor)

  • masks (list[Tensor])

Return type:

Tensor
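The gather-and-concatenate behavior can be shown with a small sketch, assuming x has shape (B, N, D) and each mask is a (B, K) tensor of token indices; `apply_masks_sketch` is a hypothetical name.

```python
import torch


def apply_masks_sketch(x: torch.Tensor, masks: list) -> torch.Tensor:
    # x: (B, N, D) patch tokens; each mask: (B, K) indices of tokens to keep.
    out = []
    for m in masks:
        idx = m.unsqueeze(-1).expand(-1, -1, x.size(-1))  # (B, K, D)
        out.append(torch.gather(x, dim=1, index=idx))
    # One (B, K, D) selection per mask, stacked along the batch dimension.
    return torch.cat(out, dim=0)


x = torch.arange(8).reshape(2, 4, 1).float()  # x[b, n, 0] == 4 * b + n
masks = [torch.tensor([[0, 2], [1, 3]]), torch.tensor([[3, 3], [0, 1]])]
out = apply_masks_sketch(x, masks)
```

The result has shape (len(masks) * B, K, D); here (4, 2, 1).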

world_models.utils.utils.visualize_latent_tsne(latents, labels=None, save_path=None, perplexity=30)[source]#

Visualize latent representations using t-SNE.

Parameters:
  • latents (Tensor | ndarray) – torch.Tensor of shape (N, D) or numpy array

  • labels (ndarray | None) – optional list or array of labels for coloring

  • save_path (str | None) – path to save the plot (HTML for plotly)

  • perplexity (int) – t-SNE perplexity parameter

Return type:

Figure

world_models.utils.utils.visualize_latent_umap(latents, labels=None, save_path=None, n_neighbors=15)[source]#

Visualize latent representations using UMAP.

Parameters:
  • latents (Tensor | ndarray) – torch.Tensor of shape (N, D) or numpy array

  • labels (ndarray | None) – optional list or array of labels for coloring

  • save_path (str | None) – path to save the plot (HTML for plotly)

  • n_neighbors (int) – UMAP n_neighbors parameter

Return type:

Figure

class world_models.utils.utils.StreamingVideoWriter(path, fps=20, frame_shape=None, format='mp4')[source]#

Bases: object

Streaming video writer that appends frames to a video file as they are produced.

Parameters:
  • path (str) – output video file path

  • fps (int) – frames per second

  • frame_shape (Any) – (height, width) of frames

  • format (str) – ‘mp4’ or ‘avi’

write_frame(frame)[source]#

Write a single frame to the video.

Parameters:

frame (ndarray) – numpy array of shape (H, W, C) or (H, W), uint8 or float in [0,1]

Return type:

None

close()[source]#
Return type:

None