API Reference
This reference is generated from source docstrings and grouped by workflow. Use the World Models Study Guide for conceptual explanations and this page for exact classes, functions, and module-level APIs.
Public package surface
These modules expose the most common imports and lazy constructors.
Primary module: torchwm. Implementation modules are documented below for API completeness.
Use torchwm for common workflows:
import torchwm
agent = torchwm.create_model("dreamer", env="walker-walk")
Friendly top-level package for TorchWM. torchwm is the recommended public namespace for users.
- class torchwm.EnvBackendSpec(name, factory_path, description='', aliases=())
Bases: NamedTuple
Metadata describing an environment backend available through make_env().
- Parameters:
name (str)
factory_path (str)
description (str)
aliases (tuple[str, ...])
- name: str
Alias for field number 0
- factory_path: str
Alias for field number 1
- description: str
Alias for field number 2
- aliases: tuple[str, ...]
Alias for field number 3
- class torchwm.ModelSpec(name, import_path, config_path=None, description='', aliases=())
Bases: NamedTuple
Metadata describing a model available through create_model().
- Parameters:
name (str)
import_path (str)
config_path (str | None)
description (str)
aliases (tuple[str, ...])
- name: str
Alias for field number 0
- import_path: str
Alias for field number 1
- config_path: str | None
Alias for field number 2
- description: str
Alias for field number 3
- aliases: tuple[str, ...]
Alias for field number 4
- torchwm.create_config(model, **overrides)
Create the default config object for model and apply overrides.
Examples
>>> cfg = create_config("dreamer", env="walker-walk", seed=7)
>>> cfg.env
'walker-walk'
- Parameters:
model (str)
overrides (Any)
- Return type:
Any
- torchwm.create_model(model, config=None, **overrides)
Instantiate a model or agent from a simple string name.
config is optional for models that define a config class. Keyword overrides are applied to the config when possible; otherwise they are passed directly to the underlying constructor or factory.
Examples
>>> agent = create_model("dreamer", env="walker-walk", total_steps=1000)
>>> genie = create_model("genie-small", image_size=32)
- Parameters:
model (str)
config (Any | None)
overrides (Any)
- Return type:
Any
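The override rule above (a keyword goes into the config when the config has a field with that name, otherwise it is forwarded to the constructor) can be sketched with a dataclass. This is a minimal illustration under that reading, not TorchWM's code: `DreamerConfig`, `DummyAgent`, `split_overrides`, and `create_model_sketch` are hypothetical names, and the real function may resolve configs differently.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Dict, Optional, Tuple


@dataclass
class DreamerConfig:
    # Hypothetical config; these field names are illustrative, not TorchWM's.
    env: str = "walker-walk"
    seed: int = 0
    total_steps: int = 100_000


class DummyAgent:
    # Hypothetical constructor taking a config plus extra keyword arguments.
    def __init__(self, config: DreamerConfig, **kwargs: Any) -> None:
        self.config = config
        self.extra = kwargs


def split_overrides(config: Any, overrides: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split overrides into config fields and leftover constructor kwargs."""
    names = {f.name for f in fields(config)}
    to_config = {k: v for k, v in overrides.items() if k in names}
    to_ctor = {k: v for k, v in overrides.items() if k not in names}
    return to_config, to_ctor


def create_model_sketch(config: Optional[DreamerConfig] = None, **overrides: Any) -> DummyAgent:
    config = config if config is not None else DreamerConfig()
    to_config, to_ctor = split_overrides(config, overrides)
    config = replace(config, **to_config)  # returns a copy; the caller's config is untouched
    return DummyAgent(config, **to_ctor)


agent = create_model_sketch(seed=7, total_steps=1000, device="cpu")
print(agent.config, agent.extra)
# DreamerConfig(env='walker-walk', seed=7, total_steps=1000) {'device': 'cpu'}
```

Copying the config with `dataclasses.replace` keeps a caller-supplied config object unchanged, which is one reasonable design choice when the same config is reused across several models.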
- torchwm.get_env_backend_spec(name)
Return metadata for an environment backend name or alias.
- Parameters:
name (str)
- Return type:
EnvBackendSpec
- torchwm.get_model_spec(name)
Return metadata for a model name or alias.
- Parameters:
name (str)
- Return type:
ModelSpec
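Name-or-alias lookup over ModelSpec records can be sketched as an index from every canonical name and alias to its spec. The case-insensitive matching, the duplicate check, the registry entries, and the `build_index`/`resolve_spec` helpers are assumptions for illustration; TorchWM's actual registry may behave differently.

```python
from typing import Dict, Iterable, NamedTuple, Optional, Tuple


class ModelSpec(NamedTuple):
    # Same field order as the documented ModelSpec; defaults mirror its signature.
    name: str
    import_path: str
    config_path: Optional[str] = None
    description: str = ""
    aliases: Tuple[str, ...] = ()


def build_index(specs: Iterable[ModelSpec]) -> Dict[str, ModelSpec]:
    """Map every canonical name and alias (lower-cased) to its spec."""
    index: Dict[str, ModelSpec] = {}
    for spec in specs:
        for key in (spec.name, *spec.aliases):
            normalized = key.lower()
            if normalized in index:
                raise ValueError(f"duplicate model name or alias: {normalized!r}")
            index[normalized] = spec
    return index


def resolve_spec(index: Dict[str, ModelSpec], name: str) -> ModelSpec:
    """Return the spec for a canonical name or alias, or raise KeyError."""
    try:
        return index[name.lower()]
    except KeyError:
        known = ", ".join(sorted({spec.name for spec in index.values()}))
        raise KeyError(f"unknown model {name!r}; known models: {known}") from None


# Hypothetical entries: the import paths and aliases are placeholders.
index = build_index([
    ModelSpec("dreamer", "example.dreamer:DreamerAgent", aliases=("dreamerv1",)),
    ModelSpec("genie-small", "example.genie:create_genie_small", aliases=("genie_small",)),
])
print(resolve_spec(index, "DreamerV1").name)  # dreamer
print(ModelSpec._fields.index("config_path"))  # 2, i.e. "field number 2"
```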
- torchwm.list_env_backends()
Return canonical backend names accepted by make_env().
- Return type:
list[str]
- torchwm.list_envs(model=None)
List known environment ids, optionally filtered by model family.
- Parameters:
model (str | None)
- Return type:
list[str] | dict[str, list[str]]
- torchwm.list_models()
Return canonical model names accepted by create_model().
- Return type:
list[str]
- torchwm.make_env(env_id, backend='auto', **kwargs)
Create an environment with a consistent TorchWM entry point.
- Parameters:
env_id (str) – Environment id, XML path, Unity executable path, or backend-specific id.
backend (str) – One of list_env_backends(); "auto" tries TorchWM's compatibility helper.
**kwargs (Any) – Backend-specific options.
- Return type:
Any
- class torchwm.IRISAgent(config, action_size, device)
Bases: Module
Complete IRIS agent with world model and policy.
It combines a discrete autoencoder (encoder + decoder), a Transformer world model, and an actor-critic for policy and value learning.
- Parameters:
config (IRISConfig)
action_size (int)
device (device)
- forward_actor_critic(frames, hidden=None)
Forward pass through the actor-critic.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W)
hidden (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state
- Returns:
action_logits: (B, T, action_size)
values: (B, T)
hidden_state: (h, c)
- act(frame, epsilon=0.0, temperature=1.0)
Sample an action from the policy.
- Parameters:
frame (Tensor) – Single frame (B, C, H, W)
epsilon (float) – Random action probability
temperature (float) – Action distribution temperature
- Returns:
Selected actions (B,)
- Return type:
Tensor
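The roles of `epsilon` and `temperature` can be illustrated with a plain-Python sketch that assumes their conventional meanings: with probability `epsilon` a uniformly random action is taken, otherwise an action is sampled from softmax(logits / temperature). This is an assumption about `IRISAgent.act`, not its source; `act_sketch` and `softmax_with_temperature` are hypothetical helpers operating on a single list of logits.

```python
import math
import random
from typing import List, Optional, Sequence


def softmax_with_temperature(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Softmax of logits / temperature; lower temperature gives a sharper distribution."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def act_sketch(logits: Sequence[float], epsilon: float = 0.0, temperature: float = 1.0,
               rng: Optional[random.Random] = None) -> int:
    """Epsilon-random action selection on top of temperature-scaled sampling."""
    rng = rng if rng is not None else random.Random()
    if rng.random() < epsilon:
        return rng.randrange(len(logits))  # explore: uniform random action
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


# A very low temperature makes sampling effectively greedy.
print(act_sketch([0.0, 5.0, 1.0], epsilon=0.0, temperature=0.01, rng=random.Random(0)))  # 1
```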
- imagine_rollout(initial_frame, horizon=20)
Generate imagined trajectories using the world model.
- Parameters:
initial_frame (Tensor) – Starting frame (B, C, H, W)
horizon (int) – Number of steps to imagine
- Returns:
trajectory – Dictionary with imagined rollout data
- update_autoencoder(frames)
Update the discrete autoencoder.
- Parameters:
frames (Tensor) – Training frames (B, C, H, W)
- Returns:
losses – Dictionary of loss values
- update_transformer(frames, actions, rewards, terminals)
Update the Transformer world model.
- Parameters:
frames (Tensor) – Frame sequence
actions (Tensor) – Actions taken
rewards (Tensor) – Rewards received
terminals (Tensor) – Terminal flags
- Returns:
losses – Dictionary of loss values
- torchwm.compute_lambda_return(rewards, values, discounts, lambda_coef=0.95)
Compute λ-return targets for value function training.
- Parameters:
rewards (Tensor) – Rewards (B, T)
values (Tensor) – Value estimates (B, T+1)
discounts (Tensor) – Discount factors (B, T)
lambda_coef (float) – Lambda parameter for bootstrapping
- Returns:
λ-return targets (B, T)
- Return type:
Tensor
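One common form of the λ-return (the form used in Dreamer) is computed backwards in time as G_t = r_t + γ_t((1 - λ)v_{t+1} + λG_{t+1}), with the final step bootstrapped from v_T. The plain-Python sketch below implements that recursion for one sequence; it assumes `values` has one more entry than `rewards`, matching the (B, T+1) shape above, omits batching, and is not TorchWM's implementation.

```python
from typing import List, Sequence


def lambda_return(rewards: Sequence[float], values: Sequence[float],
                  discounts: Sequence[float], lambda_coef: float = 0.95) -> List[float]:
    """Backward λ-return recursion for a single sequence of length T."""
    T = len(rewards)
    if len(values) != T + 1 or len(discounts) != T:
        raise ValueError("expected len(values) == T + 1 and len(discounts) == T")
    returns = [0.0] * T
    next_return = values[T]  # bootstrap from the final value estimate
    for t in reversed(range(T)):
        # Blend the one-step bootstrap v_{t+1} with the longer return G_{t+1}.
        blended = (1.0 - lambda_coef) * values[t + 1] + lambda_coef * next_return
        next_return = rewards[t] + discounts[t] * blended
        returns[t] = next_return
    return returns


rewards, values, discounts = [1.0, 1.0], [0.0, 2.0, 4.0], [0.5, 0.5]
print(lambda_return(rewards, values, discounts, lambda_coef=0.0))  # [2.0, 3.0]: one-step TD targets
print(lambda_return(rewards, values, discounts, lambda_coef=1.0))  # [2.5, 3.0]: bootstrapped n-step returns
```

The two extremes show the trade-off λ controls: λ = 0 leans entirely on the value estimate one step ahead, while λ = 1 sums discounted rewards until the horizon and only bootstraps at the end. A discount of 0 at a terminal step cuts the return off there.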
- class torchwm.ModularRSSM(encoder, decoder, backbone, reward_decoder=None)
Bases: Module
Modular RSSM with swappable encoder, decoder, and backbone.
This class lets researchers experiment with different components:
Encoders: Conv, MLP, ViT
Decoders: Conv, MLP
Backbones: GRU, LSTM, Transformer
Example
>>> encoder = ConvEncoder((3, 64, 64), embed_size=1024)
>>> decoder = ConvDecoder(32, 200, (3, 64, 64))
>>> backbone = GRUBackbone(action_size=6, stoch_size=32, deter_size=200, hidden_size=200, embed_size=1024)
>>> rssm = ModularRSSM(encoder, decoder, backbone)
- Parameters:
encoder (EncoderBase)
decoder (DecoderBase)
backbone (BackboneBase)
reward_decoder (DecoderBase | None)
- property stoch_size: int
- property deter_size: int
- property embed_size: int
- init_state(batch_size, device)
- Parameters:
batch_size (int)
device (device)
- Return type:
Dict[str, Tensor]
- observe_step(prev_state, prev_action, obs, nonterm=1.0)
- Parameters:
prev_state (Dict[str, Tensor])
prev_action (Tensor)
obs (Tensor)
nonterm (Any)
- Return type:
Tuple[Dict[str, Tensor], Dict[str, Tensor]]
- imagine_step(prev_state, prev_action, nonterm=1.0)
- Parameters:
prev_state (Dict[str, Tensor])
prev_action (Tensor)
nonterm (Any)
- Return type:
Dict[str, Tensor]
- observe_rollout(obs, actions, nonterms, prev_state, horizon)
- Parameters:
obs (Tensor)
actions (Tensor)
nonterms (Tensor)
prev_state (Dict[str, Tensor])
horizon (int)
- Return type:
Tuple[Dict[str, Tensor], Dict[str, Tensor]]
- torchwm.create_modular_rssm(encoder_type='conv', decoder_type='conv', backbone_type='gru', obs_shape=(3, 64, 64), action_size=6, stoch_size=32, deter_size=200, embed_size=1024, hidden_size=200, activation='elu', **kwargs)
Factory function to create a modular RSSM with specified components.
- Parameters:
encoder_type (str) – Type of encoder (“conv”, “mlp”, “vit”)
decoder_type (str) – Type of decoder (“conv”, “mlp”)
backbone_type (str) – Type of backbone (“gru”, “lstm”, “transformer”)
obs_shape (Tuple[int, int, int] | Tuple[int]) – Shape of observations (C, H, W) for images or (D,) for state
action_size (int) – Action space dimension
stoch_size (int) – Stochastic latent dimension
deter_size (int) – Deterministic hidden dimension
embed_size (int) – Encoder embedding dimension
hidden_size (int) – Hidden layer dimension
activation (str) – Activation function name
- Returns:
Configured ModularRSSM instance
- Return type:
ModularRSSM
- class torchwm.Genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=512, tokenizer_decoder_dim=1024, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=1024, action_decoder_dim=1024, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, encoder_depth=12, decoder_depth=20, latent_action_depth=20, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)
Bases: Module
Genie: Generative Interactive Environment.
A generative model trained from video-only data that can be used as an interactive environment. It contains three key components:
1. Video tokenizer: converts raw video frames into discrete tokens.
2. Latent Action Model (LAM): infers latent actions between frames.
3. Dynamics model: predicts future frames given past frames and latent actions.
Based on the paper "Genie: Generative Interactive Environments" (arXiv:2402.15391).
Training follows the two phases described in the paper:
1. Train the video tokenizer first (on raw video frames).
2. Co-train the LAM (from pixels) and the dynamics model (on video tokens).
The LAM is trained with a VQ-VAE objective:
Encoder: takes x_{1:t} and x_{t+1} and outputs latent actions.
Decoder: takes x_{1:t-1} (masked) plus actions and reconstructs x_t.
An auxiliary variance loss prevents action collapse.
At inference, a stop-gradient is applied to latent actions when they are passed to the dynamics model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_decoder_dim (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
encoder_depth (int)
decoder_depth (int)
latent_action_depth (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- forward(video, mask_prob=0.5, training_phase='all')
Full forward pass through all components.
- Parameters:
video (Tensor) – (B, C, T, H, W) input video
mask_prob (float) – Probability for random masking in dynamics (0.5-1.0)
training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”
- Returns:
Dictionary containing losses and predictions
- Return type:
Dict[str, Tensor]
- training_step(video, mask_prob=0.5, training_phase='all')
Single training step computing all losses.
- Parameters:
video (Tensor) – (B, C, T, H, W) input video
mask_prob (float) – Probability for random masking in dynamics
training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”
- Returns:
Dictionary containing all losses for backpropagation
- Return type:
Dict[str, Tensor]
- encode_video(video)
Encode video to discrete tokens.
- Parameters:
video (Tensor) – (B, C, T, H, W)
- Returns:
video_tokens – (B, T, H*W)
- infer_actions(frames)
Infer latent actions from a sequence of frames.
- Parameters:
frames (Tensor) – (B, C, T, H, W) video frames
- Returns:
latent_actions – (B, T-1) inferred latent action indices
- generate(prompt_frame, num_frames=16, actions=None, use_maskgit=True)
Generate video frames given a prompt frame and actions.
- Parameters:
prompt_frame (Tensor) – (B, C, H, W) initial frame
num_frames (int) – Total number of frames to generate
actions (Tensor | None) – (B, num_frames-1) latent action indices, or None for random actions
use_maskgit (bool) – Whether to use MaskGIT sampling
- Returns:
generated_video – (B, C, num_frames, H, W)
- play(current_frame, action, current_frames=None)
Play step: generate the next frame given the current frame and action.
- Parameters:
current_frame (Tensor) – (B, C, H, W) current frame
action (Tensor) – (B,) latent action indices
current_frames (Tensor | None) – (B, C, T, H, W) history frames, or None for the first frame
- Returns:
next_frame – (B, C, H, W)
- class torchwm.LatentActionModel(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)
Bases: Module
Latent Action Model (LAM) for unsupervised action learning.
Learns discrete latent actions from unlabeled video frames using a VQ-VAE-based objective. The model infers latent actions between frames that encode the most meaningful changes for future frame prediction.
Based on the Genie paper: learns actions without action labels from Internet videos.
Components:
Encoder: takes all previous frames x_{1:t} and the next frame x_{t+1} and outputs latent actions.
Decoder: takes previous frames x_{1:t-1} and latent actions a_{1:t-1} and predicts the next frame x_t.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- encode(x_prev, x_next)
Encode frames to latent actions.
- Parameters:
x_prev (Tensor) – Previous frames (B, C, T, H, W)
x_next (Tensor) – Next frame (B, C, H, W)
- Returns:
latent_actions: Discrete latent action indices (B, T)
z_q: Quantized embeddings (B, T, embedding_dim)
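The discrete indices and quantized embeddings returned by `encode` come from a VQ-VAE bottleneck. Its core step, replacing each encoder output with the nearest codebook entry, is sketched below in plain Python. The toy codebook is made up, and the commitment loss and straight-through gradient that a real VQ-VAE also needs are omitted.

```python
from typing import List, Sequence, Tuple


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(vectors: Sequence[Sequence[float]],
             codebook: Sequence[Sequence[float]]) -> Tuple[List[int], List[List[float]]]:
    """Return (indices, z_q): the nearest codebook index and entry for each vector."""
    indices: List[int] = []
    z_q: List[List[float]] = []
    for v in vectors:
        best = min(range(len(codebook)), key=lambda k: squared_distance(v, codebook[k]))
        indices.append(best)
        z_q.append(list(codebook[best]))
    return indices, z_q


# A toy codebook with vocab_size=4 and embedding_dim=2.
codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
indices, z_q = quantize([[0.9, 0.1], [0.2, 0.8], [0.6, 0.7]], codebook)
print(indices)  # [1, 2, 3]
```

With a small `vocab_size` (the LAM default is 8), every transition between frames is forced into one of a handful of codes, which is what makes the inferred actions discrete and controllable.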
- class torchwm.DynamicsModel(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)
Bases: Module
Dynamics model for action-controllable video generation.
A decoder-only transformer that predicts future frame tokens given past frame tokens and latent actions. Uses MaskGIT for training and sampling.
Based on the Genie paper: uses a cross-entropy loss with random masking during training, and MaskGIT iterative refinement at inference.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- forward(video_tokens, actions, mask_prob=0.0)
Forward pass for training.
- Parameters:
video_tokens (Tensor) – (B, T, H*W) token indices for frames 1 to T
actions (Tensor) – (B, T) latent action indices for frames 1 to T
mask_prob (float) – Probability of masking input tokens (Bernoulli, 0.5-1.0)
- Returns:
logits – (B, T, H*W, vocab_size)
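The 0.5 to 1.0 range for `mask_prob` matches the masking scheme described in the Genie paper: a masking rate is drawn uniformly from [0.5, 1], then each input token is masked independently with that probability. The plain-Python sketch below shows that two-stage sampling; `sample_token_mask` is a hypothetical helper, and how TorchWM maps `mask_prob` onto this process is not shown here and may differ.

```python
import random
from typing import List, Optional


def sample_token_mask(num_tokens: int, low: float = 0.5, high: float = 1.0,
                      rng: Optional[random.Random] = None) -> List[bool]:
    """Per-token mask (True = masked) using one randomly drawn masking rate."""
    rng = rng if rng is not None else random.Random()
    rate = rng.uniform(low, high)  # one rate per training sample
    return [rng.random() < rate for _ in range(num_tokens)]


mask = sample_token_mask(16 * 16, rng=random.Random(0))  # e.g. one 16x16 token grid
print(len(mask), sum(mask))  # 256 tokens, roughly half or more of them masked
```

Drawing the rate per sample, instead of fixing it, exposes the model to both lightly and heavily masked inputs, which resembles the range of states it meets during iterative MaskGIT decoding.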
- sample(prompt_tokens, prompt_actions, num_frames, sampler=None)
Sample future frames using MaskGIT.
- Parameters:
prompt_tokens (Tensor) – (B, T_prompt, N) starting frame tokens
prompt_actions (Tensor) – (B, T_prompt) actions for prompt frames
num_frames (int) – Total number of frames to generate
sampler (MaskGITSampler | None) – MaskGIT sampler instance
- Returns:
generated_tokens – (B, num_frames, N)
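MaskGIT decodes all tokens of a frame in a fixed number of parallel steps. At each step it predicts every masked token, keeps the most confident predictions, and re-masks the rest according to a decreasing schedule. The sketch below shows the cosine schedule from the MaskGIT paper (Chang et al., 2022) together with the re-masking rule; whether `MaskGITSampler` uses this exact schedule is an assumption, and both helpers are hypothetical.

```python
import math
from typing import List, Sequence


def cosine_mask_schedule(num_tokens: int, num_steps: int) -> List[int]:
    """Number of tokens still masked after each decoding step (cosine schedule)."""
    counts = []
    for step in range(num_steps):
        ratio = math.cos(math.pi / 2 * (step + 1) / num_steps)
        counts.append(math.floor(num_tokens * ratio))
    return counts


def tokens_to_remask(confidences: Sequence[float], num_masked: int) -> List[int]:
    """Indices of the least confident predictions, which are masked again.

    Tokens fixed in earlier steps are usually given infinite confidence
    so that they are never re-masked.
    """
    order = sorted(range(len(confidences)), key=lambda i: confidences[i])
    return sorted(order[:num_masked])


print(cosine_mask_schedule(16, 4))  # [14, 11, 6, 0]
print(tokens_to_remask([0.9, 0.2, 0.7, 0.1], 2))  # [1, 3]
```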
- autoregressive_sample(prompt_tokens, prompt_actions, num_frames, temperature=1.0)
Simple autoregressive sampling (token by token).
- Parameters:
prompt_tokens (Tensor) – (B, T_prompt, N) starting frame tokens
prompt_actions (Tensor) – (B, T_prompt) actions for prompt frames
num_frames (int) – Total number of frames to generate
temperature (float) – Sampling temperature
- Returns:
generated_tokens – (B, num_frames, N)
- torchwm.create_genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, action_vocab_size=8, action_embedding_dim=32, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)
Factory function to create a Genie model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
action_vocab_size (int)
action_embedding_dim (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
Genie
- torchwm.create_genie_small(num_frames=16, image_size=64, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)
Create a smaller Genie model for development/testing.
- Parameters:
num_frames (int)
image_size (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
Genie
- torchwm.create_genie_large(num_frames=16, image_size=64, use_bfloat16=True, action_pooling='mean', window_attention_heads=1)
Create the full-size Genie model (approximately 11B parameters).
- Parameters:
num_frames (int)
image_size (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
Genie
- torchwm.create_latent_action_model(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, action_pooling='mean', window_attention_heads=1)
Factory function to create a Latent Action Model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
LatentActionModel
- torchwm.create_dynamics_model(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4)
Factory function to create a Dynamics Model.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
- Return type:
DynamicsModel
- class torchwm.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)
Bases: Module
Recurrent State-Space Model used by Dreamer for latent dynamics learning.
The RSSM is the core world model component that learns compact representations of environment dynamics. It maintains a hybrid state consisting of:
Deterministic State (h): A recurrent hidden state updated by a GRU, capturing sequential/temporal information and deterministic transitions.
Stochastic State (s): A latent variable representing stochastic, multi-modal uncertainty in the environment (e.g., ambiguous observations).
The model operates in two modes:
Observe Mode: Updates states using actual observations from the environment. Uses the representation model: p(s_t | h_t, obs_t)
Imagine Mode: Predicts future states without observations. Uses the transition/prior model: p(s_t | h_t)
- Architecture:
Input: previous state (h_{t-1}, s_{t-1}) and action a_{t-1}
Process: the GRU updates the deterministic state; an MLP computes the stochastic prior/posterior
Output: updated state (h_t, s_t) and distributions
- State Representation:
deter (h): GRU hidden state, captures sequential context
stoch (s): Stochastic latent, multi-modal uncertainty
mean/std: Parameters of the stochastic distribution
- Usage with DreamerAgent:
rssm = RSSM(
    action_size=action_dim,
    stoch_size=30,       # Stochastic state dimension
    deter_size=200,      # Deterministic (GRU) state dimension
    hidden_size=200,     # MLP hidden layer size
    obs_embed_size=256,  # Observation embedding from encoder
    activation='elu',
)
# Observe with an actual observation; observe_step returns (posterior, prior)
posterior, prior = rssm.observe_step(prev_state, prev_action, obs_embed)
# Imagine the future without an observation
prior = rssm.imagine_step(current_state, action)
- Training:
The RSSM is trained by maximizing the ELBO (evidence lower bound):
The KL divergence between prior and posterior encourages the prior to capture environment dynamics.
The reconstruction loss from the decoder ensures the state captures observation information.
- Reference:
Dream to Control: Learning Behaviors by Latent Imagination, Hafner et al., 2020 - https://arxiv.org/abs/1912.01603
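The KL term in the ELBO above compares two diagonal Gaussians; in Dreamer it is KL(posterior || prior). For diagonal Gaussians this has a closed form, sketched below for a single state vector in plain Python. This is the standard formula, not TorchWM's loss code, and refinements such as free nats or KL balancing are omitted.

```python
import math
from typing import Sequence


def diag_gaussian_kl(mean_q: Sequence[float], std_q: Sequence[float],
                     mean_p: Sequence[float], std_p: Sequence[float]) -> float:
    """KL(q || p) for Gaussians with diagonal covariance, summed over dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(mean_q, std_q, mean_p, std_p):
        total += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
    return total


# Posterior q and prior p for a toy 2-dimensional stochastic state.
print(diag_gaussian_kl([0.0, 0.0], [1.0, 1.0], [1.0, 0.0], [1.0, 1.0]))  # 0.5
```

The KL is zero only when the prior matches the posterior, so minimizing it pushes the prior (which never sees the observation) to predict what the observation will reveal.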
- init_state(batch_size, device)
Initialize RSSM state with zeros.
- Parameters:
batch_size – Number of parallel sequences
device – torch device for tensors
- Returns:
Dictionary containing zero-initialized state components:
mean, std: Stochastic distribution parameters
stoch: Stochastic state sample
deter: Deterministic GRU hidden state
- get_dist(mean, std)
Create an Independent Normal distribution from mean and std.
- Parameters:
mean – Location parameter
std – Scale parameter
- Returns:
Independent Normal distribution with given parameters
- observe_step(prev_state, prev_action, obs_embed, nonterm=1.0)
Update state using actual observation (observe mode).
In observe mode, the RSSM first computes a transition prior from the previous state and action, then refines the stochastic state using the actual observation embedding to form the posterior.
- Parameters:
prev_state – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})
prev_action – Previous action a_{t-1}, shape (B, action_size)
obs_embed – Observation embedding from encoder, shape (B, obs_embed_size)
nonterm – Termination mask (1.0 = continue, 0.0 = terminal)
- Returns:
A tuple (posterior, prior) of state dictionaries. The posterior incorporates observation information; the prior is the transition prediction before observation. Both share the same deterministic state because the GRU is only advanced once per timestep.
- imagine_step(prev_state, prev_action, nonterm=1.0)
Predict next state without observation (imagine mode).
In imagine mode, the RSSM predicts future states using only the prior distribution. This is used for planning and policy learning where actual observations are not available.
- Parameters:
prev_state – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})
prev_action – Previous action a_{t-1}, shape (B, action_size)
nonterm – Termination mask (1.0 = continue, 0.0 = terminal)
- Returns:
Dictionary with the predicted state, containing:
deter: Predicted deterministic state
mean, std, stoch: Prior stochastic state distribution
- get_prior(prev_state, prev_action, nonterm=1.0)
Compute prior distribution over stochastic state.
The prior represents the model’s belief about the stochastic state before observing the actual outcome.
- Parameters:
prev_state – Previous state dictionary
prev_action – Previous action
nonterm – Termination mask
- Returns:
Dictionary with prior state (no observation information)
- get_posterior(prev_state, prev_action, obs_embed, nonterm=1.0)
Compute posterior distribution over stochastic state.
The posterior incorporates observation information to produce a more accurate state estimate.
- Parameters:
prev_state – Previous state dictionary
prev_action – Previous action
obs_embed – Observation embedding
nonterm – Termination mask
- Returns:
Dictionary with posterior state (observation-informed). Note that the previous-state shape (B, ...) is preserved; the batch dimension is not flattened.
- detach_state(state)
Detach state tensors from computation graph.
Used during DreamerV2 training to prevent gradient flow through the observation/update pathway.
- Parameters:
state – State dictionary with tensor values
- Returns:
Detached state dictionary
- seq_to_batch(state_dict)
Convert sequence state to batch format.
- Parameters:
state_dict – Dictionary with sequence-dimension tensors (T, B, …)
- Returns:
Dictionary with batch-dimension tensors (B*T, …)
- observe_rollout(obs_embed, actions, nonterms, init_state, seq_len)
Process a sequence of observations (observe mode rollout).
At each timestep we run observe_step once to obtain the transition prior (the prediction given the previous state and action) and the observation-informed posterior. The posterior is then used as the previous state for the next step, matching the standard Dreamer inference pattern.
- Parameters:
obs_embed – Observation embeddings, shape (T+1, B, obs_embed_size)
actions – Actions, shape (T, B, action_size)
nonterms – Non-termination flags, shape (T, B, 1)
init_state – Initial state dictionary
seq_len – Sequence length T
- Returns:
prior: Dictionary with prior states stacked along the time axis
posterior: Dictionary with posterior states stacked along the time axis
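The control flow of `observe_rollout` (one `observe_step` per timestep, with the posterior carried forward as the next previous state) can be sketched with scalar stand-ins. `toy_observe_step` and `toy_observe_rollout` are hypothetical: their arithmetic replaces the GRU and MLP heads, and both the nonterm reset and the pairing of `actions[t]` with `obs_embed[t + 1]` are assumptions. Only the loop structure and the shared deterministic state reflect the description above.

```python
import math
from typing import Dict, List, Tuple

State = Dict[str, float]


def toy_observe_step(prev_state: State, prev_action: float, obs_embed: float,
                     nonterm: float = 1.0) -> Tuple[State, State]:
    """Hypothetical scalar stand-in for observe_step; returns (posterior, prior)."""
    # Assumed convention: multiplying by nonterm resets the carried state after a terminal step.
    h_prev = prev_state["deter"] * nonterm
    s_prev = prev_state["stoch"] * nonterm
    deter = math.tanh(0.5 * h_prev + 0.5 * s_prev + prev_action)  # stand-in for the GRU update
    prior = {"deter": deter, "stoch": deter}  # stand-in prior head: uses h_t only
    posterior = {"deter": deter, "stoch": 0.5 * (deter + obs_embed)}  # also sees the observation
    return posterior, prior


def toy_observe_rollout(obs_embed: List[float], actions: List[float], nonterms: List[float],
                        init_state: State, seq_len: int) -> Tuple[List[State], List[State]]:
    priors: List[State] = []
    posteriors: List[State] = []
    state = init_state
    for t in range(seq_len):
        # Pairing actions[t] with obs_embed[t + 1] is assumed from the T+1 vs T shapes.
        posterior, prior = toy_observe_step(state, actions[t], obs_embed[t + 1], nonterms[t])
        priors.append(prior)
        posteriors.append(posterior)
        state = posterior  # the posterior is carried forward as the next previous state
    return priors, posteriors


priors, posteriors = toy_observe_rollout(
    obs_embed=[0.1, 0.2, 0.3, 0.4], actions=[1.0, -1.0, 0.5], nonterms=[1.0, 0.0, 1.0],
    init_state={"deter": 0.0, "stoch": 0.0}, seq_len=3,
)
print(len(priors), len(posteriors))  # 3 3
```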
- imagine_rollout(policy, init_state, horizon)
Generate imagined trajectory using policy (imagine mode rollout).
- Parameters:
policy – Actor network that outputs actions from state features
init_state – Initial state dictionary
horizon – Number of steps to imagine
- Returns:
Dictionary with imagined states for each step
- forward(x, u)
Forward pass for training (computes a sequence of states).
- Parameters:
x – Observations, shape (B, T+1, C, H, W)
u – Actions, shape (B, T, action_size)
- Returns:
states: List of state dictionaries for each timestep
priors: List of prior distributions (tuples of mean, std)
posteriors: List of posterior distributions (tuples of mean, std)
- class torchwm.RecurrentStateSpaceModel(action_size, state_size=200, latent_size=30, hidden_size=200, embed_size=1024, activation_function='relu')
Bases: Module
A Recurrent State Space Model (RSSM) for modeling latent dynamics in sequential data.
- get_init_state(enc, h_t=None, s_t=None, a_t=None, mean=False)
Returns the initial posterior given the observation.
- deterministic_state_fwd(h_t, s_t, a_t)
Deterministic transition update. It accepts:
a_t shaped [B, action_size]
a_t shaped [action_size] (unbatched), which is expanded to [B, action_size]
a_t shaped [B] or a scalar, which is reshaped appropriately
It ensures a_t is 2D and matches the batch dimension of h_t before concatenation.
- state_prior(h_t, sample=False)
Returns the prior distribution over the latent state given the deterministic state.
- state_posterior(h_t, e_t, sample=False)
Returns the posterior distribution over the latent state given the deterministic state and the observation.
- forward(x, u)
Forward through the RSSM for a batch of sequences.
Inputs:
x: Tensor [B, T+1, C, H, W] (observations including the initial frame)
u: Tensor [B, T, action_size] (actions for T steps)
- Returns:
states: list[T] of tensors [B, state_size]
priors: list[T] of tuples (mean, std), each [B, latent_size]
posteriors: list[T] of tuples (mean, std), each [B, latent_size]
- class torchwm.ConvEncoder(input_shape, embed_size, activation, depth=32)
Bases: Module
Convolutional observation encoder used by Dreamer world models.
This encoder transforms raw image observations (typically RGB frames from environments like Atari or DeepMind Control) into compact latent embeddings that can be processed by the RSSM (Recurrent State-Space Model).
- Architecture:
Input: (B, C, H, W) raw images, values in [-0.5, 0.5]
Process: 4 convolutional layers with stride 2, each approximately halving the spatial dimensions
Output: (B, embed_size) compact representation
The encoder uses a depth-doubling pattern: 32 -> 64 -> 128 -> 256 channels. After the convolutions, a fully connected layer projects the 1024 flattened features to the desired embedding size.
- Usage with Dreamer:
encoder = ConvEncoder(
    input_shape=(3, 64, 64),  # RGB 64x64 images
    embed_size=256,           # RSSM observation embedding size
    activation='relu',        # Activation function
)
obs_embedding = encoder(observation)  # (B, 256)
- Parameters:
input_shape – Tuple (C, H, W) for input images, typically (3, 64, 64)
embed_size – Output embedding dimension, typically 256 or 1024
activation – Activation function name (‘relu’, ‘elu’, ‘tanh’, etc.)
depth – Base channel depth for first layer (default 32)
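The 1024-feature figure follows from the convolution output-size formula if one assumes a kernel size of 4, stride 2, and no padding, as in the original Dreamer encoder. Under that assumption a 64x64 input shrinks to 31, 14, 6, and finally 2 pixels per side, so the last 256-channel map flattens to 256 x 2 x 2 = 1024. The kernel size and padding are assumptions, not values read from TorchWM, and `encoder_shapes` is a hypothetical helper.

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 0) -> int:
    """Spatial output size of a 2D convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_shapes(input_shape: Tuple[int, int, int], depth: int = 32,
                   num_layers: int = 4) -> List[Tuple[int, int, int]]:
    """(channels, height, width) after each conv layer, with channels doubling per layer."""
    _, h, w = input_shape
    shapes = []
    for i in range(num_layers):
        h, w = conv_out(h), conv_out(w)
        shapes.append((depth * 2 ** i, h, w))  # 32, 64, 128, 256 channels
    return shapes


shapes = encoder_shapes((3, 64, 64))
c, h, w = shapes[-1]
print(shapes)     # [(32, 31, 31), (64, 14, 14), (128, 6, 6), (256, 2, 2)]
print(c * h * w)  # 1024
```

The same arithmetic shows why the flattened size depends on the input resolution: a different `input_shape` changes the final spatial size and therefore the width of the fully connected projection.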
- class torchwm.CNNEncoder(embedding_size, activation_function='relu')
Bases: Module
A Convolutional Neural Network (CNN) encoder for processing image inputs.
- class torchwm.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)
Bases: Module
Convolutional decoder for reconstructing observations from latent states.
Part of Dreamer’s world model, this decoder reconstructs image observations from the combined stochastic (s) and deterministic (h) RSSM states.
- Architecture:
Input: concatenated [stoch_state, deter_state], shape (B, stoch+deter)
Process: dense projection + 4 transposed convolutions (upsampling 2x each)
Output: Independent Normal distribution over observation pixels
The decoder mirrors the ConvEncoder’s structure but in reverse (transposed convs instead of regular convs). This creates a symmetric autoencoder where the encoder and decoder can be trained jointly to learn compressed representations.
- Output Distribution:
Returns torch.distributions.Independent(Normal(mean, std), len(shape)). This allows computing log_prob(observation) for the reconstruction loss.
- Usage in Dreamer world model:
decoder = ConvDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(3, 64, 64),  # RGB images
    activation='relu',
)
obs_dist = decoder(latent_features)  # Returns a distribution
log_prob = obs_dist.log_prob(target_observation)
- Training:
The reconstruction loss is -log_prob(observation). This encourages the RSSM to learn states that capture observation information.
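For an Independent Normal over pixels, -log_prob(observation) is a sum of per-pixel Gaussian negative log-densities. The plain-Python sketch below assumes a fixed standard deviation of 1, as in the original Dreamer image decoder; with that choice the loss equals half the summed squared error plus a constant, so maximizing the likelihood amounts to minimizing squared reconstruction error. `independent_normal_nll` is a hypothetical helper, not TorchWM code.

```python
import math
from typing import Sequence


def independent_normal_nll(target: Sequence[float], mean: Sequence[float], std: float = 1.0) -> float:
    """-log p(target) under an independent Normal(mean_i, std) for every pixel i."""
    nll = 0.0
    for x, mu in zip(target, mean):
        nll += 0.5 * ((x - mu) / std) ** 2 + math.log(std) + 0.5 * math.log(2 * math.pi)
    return nll


target = [0.1, -0.2, 0.3]
mean = [0.0, 0.0, 0.0]
constant = 0.5 * len(target) * math.log(2 * math.pi)
squared_error = sum((x - m) ** 2 for x, m in zip(target, mean))
print(math.isclose(independent_normal_nll(target, mean), 0.5 * squared_error + constant))  # True
```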
- class torchwm.CNNDecoder(state_size, latent_size, embedding_size, activation_function='relu')
Bases: Module
A Convolutional Neural Network (CNN) decoder for reconstructing image outputs.
- class torchwm.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist, num_buckets=255, symlog_range=10.0)
Bases: Module
MLP decoder for reward/value/discount prediction from latent features.
Part of Dreamer's world model, this decoder predicts scalar quantities (rewards, values, discount factors) from RSSM latent states.
- Architecture:
Input: [stoch_state, deter_state] concatenated, shape (B, stoch+deter)
Process: MLP with configurable layers and hidden units
Output: predicted quantity as a distribution (normal or binary) or as a raw tensor
- Supports three output types:
‘normal’: Gaussian distribution for regression (rewards, values)
‘binary’: Bernoulli distribution for binary classification (discount)
‘none’: Raw tensor for non-probabilistic outputs
- Usage:
reward_decoder = DenseDecoder(
    stoch_size=30, deter_size=200, output_shape=(1,),
    n_layers=2, units=400, activation='elu', dist='normal',
)
reward_dist = reward_decoder(latent_features)
reward_loss = -reward_dist.log_prob(target_reward)
- For discount prediction (binary):
discount_decoder = DenseDecoder(
    stoch_size=30, deter_size=200, output_shape=(1,),
    n_layers=2, units=400, activation='elu',
    dist='binary',  # Bernoulli for P(continue)
)
- class torchwm.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)
Bases: Module
Dreamer actor head producing squashed continuous actions from latent features.
Outputs a transformed Gaussian policy with an optional deterministic mode and a utility for adding exploration noise.
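The `min_std`, `init_std`, and `mean_scale` defaults match the action head of the original Dreamer code. There, the raw mean is soft-clipped with a scaled tanh, and the raw standard deviation passes through a softplus shifted so that a zero network output yields `init_std`; the resulting Normal is then squashed by a tanh bijector. The sketch below assumes TorchWM follows the same recipe; it is not read from TorchWM's source, and `action_head_params` is a hypothetical helper for one action dimension.

```python
import math
from typing import Tuple


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def action_head_params(raw_mean: float, raw_std: float, min_std: float = 1e-4,
                       init_std: float = 5.0, mean_scale: float = 5.0) -> Tuple[float, float]:
    """Map raw network outputs to the (mean, std) of the pre-tanh Normal."""
    raw_init_std = math.log(math.exp(init_std) - 1.0)  # inverse softplus of init_std
    mean = mean_scale * math.tanh(raw_mean / mean_scale)  # soft-clip into (-mean_scale, mean_scale)
    std = softplus(raw_std + raw_init_std) + min_std  # positive and at least min_std
    return mean, std


print(action_head_params(0.0, 0.0))  # mean 0.0, std close to 5.0001
```

Starting with a large standard deviation (before the tanh squashing) spreads early actions across the whole [-1, 1] range, which gives the agent broad exploration at the start of training.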
- class torchwm.TanhBijector
Bases: Transform
Bijective tanh transform for squashing Gaussian distributions to [-1, 1].
This transformation is essential for Dreamer’s action policy. Raw neural network outputs are Gaussian distributions over R^n, but actions in continuous control environments are typically bounded in [-1, 1]. The tanh bijector provides:
Bijective mapping: tanh is invertible (with atanh as inverse)
Stable log-det Jacobian: Computable for gradient-based training
Clipped actions: During inference, actions are naturally bounded
- Math:
Forward: y = tanh(x)
Inverse: x = atanh(y) = 0.5 * log((1+y)/(1-y))
Log-det: log|dy/dx| = 2*(log(2) - x - softplus(-2x))
- Usage with Dreamer ActionDecoder:
dist = TransformedDistribution(Normal(mean, std), TanhBijector())
action = dist.sample()  # Bounded to [-1, 1]
- Reference:
Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor, Haarnoja et al., 2018.
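The log-det expression in the Math block is a numerically stable rewrite of log(1 - tanh(x)^2). The plain-Python check below compares the two forms; the stable version stays finite for large |x|, where the naive form hits log(0) once tanh(x) rounds to exactly 1.0 in double precision. The helper names are illustrative, not TorchWM functions.

```python
import math


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def tanh_log_det_stable(x: float) -> float:
    # log|d tanh(x)/dx| = 2 * (log 2 - x - softplus(-2x))
    return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))


def tanh_log_det_naive(x: float) -> float:
    return math.log(1.0 - math.tanh(x) ** 2)


for x in (-3.0, -0.5, 0.0, 0.5, 3.0):
    assert math.isclose(tanh_log_det_stable(x), tanh_log_det_naive(x), rel_tol=1e-9, abs_tol=1e-12)

print(tanh_log_det_stable(30.0))  # finite (about -58.6); the naive form raises ValueError here
```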
- property sign
- class torchwm.SampleDist(dist, samples=100)[source]#
Bases:
object
Distribution wrapper that estimates statistics via Monte Carlo sampling.
Provides approximate mean, mode, and entropy helpers for transformed distributions whose analytic forms are unavailable or inconvenient.
- property name#
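The Monte Carlo estimators behind a SampleDist-style wrapper can be sketched in plain Python: the mean is the sample average, the mode is the sample with the highest log-density, and the entropy is the negative average log-density. SampleDistSketch is a hypothetical stand-in, not torchwm.SampleDist, and the tanh-Gaussian below is an illustrative target.

```python
import math
import random


class SampleDistSketch:
    """Hypothetical Monte Carlo estimator, not torchwm.SampleDist."""

    def __init__(self, sample_fn, log_prob_fn, samples=100):
        self.sample_fn = sample_fn      # () -> float
        self.log_prob_fn = log_prob_fn  # float -> float
        self.samples = samples

    def _draw(self):
        return [self.sample_fn() for _ in range(self.samples)]

    def mean(self):
        xs = self._draw()
        return sum(xs) / len(xs)

    def mode(self):
        # The highest-density sample approximates the mode
        return max(self._draw(), key=self.log_prob_fn)

    def entropy(self):
        # H = -E[log p(x)], estimated from samples
        xs = self._draw()
        return -sum(self.log_prob_fn(x) for x in xs) / len(xs)


rng = random.Random(0)
mu, sigma = 0.0, 0.5


def sample():
    return math.tanh(rng.gauss(mu, sigma))


def log_prob(y):
    x = math.atanh(y)
    log_normal = -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)
    return log_normal - math.log(1.0 - y * y)  # change of variables through tanh


dist = SampleDistSketch(sample, log_prob, samples=2000)
```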
- class torchwm.IRISEncoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, in_channels=3, base_channels=64, num_residual_blocks=2, frame_shape=(3, 64, 64))[source]#
Bases:
Module
CNN encoder for the IRIS discrete autoencoder.
Encodes image observations into latent features, which are then quantized into discrete tokens using the VectorQuantizer.
- Architecture:
4 convolutional layers with residual blocks
Self-attention at 8x8 and 16x16 resolutions
Vector quantization to produce discrete tokens
- Parameters:
vocab_size (int)
tokens_per_frame (int)
embedding_dim (int)
in_channels (int)
base_channels (int)
num_residual_blocks (int)
frame_shape (Tuple[int, int, int])
- forward(x)[source]#
Encode images to discrete tokens.
- Parameters:
x (Tensor) – Input images (B, C, H, W) - should be 64x64
- Returns:
A tuple (z_q, indices, vq_loss):
z_q: quantized tokens (B, C, H’, W’)
indices: token indices (B, H’, W’)
vq_loss: dictionary with VQ loss components
- Return type:
tuple
- class torchwm.IRISDecoder(vocab_size=512, embedding_dim=512, base_channels=32, out_channels=3, frame_shape=(3, 64, 64))[source]#
Bases:
Module
CNN decoder for the IRIS discrete autoencoder.
Decodes discrete tokens back into image observations. Uses transposed convolutions to upsample from 4x4 to 64x64.
- Parameters:
vocab_size (int)
embedding_dim (int)
base_channels (int)
out_channels (int)
frame_shape (Tuple[int, int, int])
- forward(z)[source]#
Decode tokens to images.
- Parameters:
z (Tensor) – Token embeddings (B, C, H, W) - e.g., (B, 512, 4, 4)
- Returns:
Reconstructed images (B, C, H, W), e.g., (B, 3, 64, 64)
- Return type:
Tensor
- class torchwm.VideoTokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, commitment_weight=0.25, use_ema=False, ema_decay=0.99)[source]#
Bases:
Module
Video tokenizer using a VQ-VAE with a spatiotemporal transformer.
This is a core component of Genie (Google DeepMind, 2024), used to compress raw video frames into discrete latent tokens that can be processed by downstream models like the LatentActionModel and DynamicsModel.
The tokenizer uses the Vector Quantized Variational Autoencoder (VQ-VAE) objective to learn a discrete codebook of video representations. Unlike a standard VQ-VAE, it uses a spatiotemporal (ST) transformer in both the encoder and the decoder to better capture temporal dynamics in videos.
- Architecture:
Patch Embedding: Convert (B, C, T, H, W) video to patch tokens
Encoder ST-Transformer: Process spatial-temporal patches
Vector Quantization: Discretize continuous embeddings to codebook entries
Decoder ST-Transformer: Reconstruct video from quantized tokens
Patch Unembedding: Convert tokens back to video frames
- Key Features:
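Causal processing: each frame’s encoding attends only to the current and earlier frames
Discrete tokens: enable autoregressive prediction with latent actions
Memory efficient: factorizes attention into spatial (within-frame) and temporal (across-frame) layers instead of full ViT attention over every token in the clip, avoiding cost quadratic in the total token count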
- Usage with Genie:
tokenizer = VideoTokenizer(
    num_frames=16, image_size=64, patch_size=4, vocab_size=1024, embedding_dim=32
)
reconstructed, indices, loss_dict = tokenizer(video_frames)
# For discrete token input to the dynamics model:
token_embeddings = tokenizer.decode_indices(indices)
- Training:
The tokenizer is trained with the VQ-VAE objective:
Reconstruction loss: MSE between input and reconstructed video
Codebook (VQ) loss: pulls codebook embeddings toward the encoder outputs
Commitment loss: penalizes encoder outputs for drifting from their assigned codebook entries (weighted by commitment_weight)
- Reference:
Genie: Generative Interactive Environments Bruce et al., Google DeepMind, 2024 - https://arxiv.org/abs/2402.15391
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
use_ema (bool)
ema_decay (float)
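The default settings imply a concrete token count, and the factorized attention claim can be made quantitative. This sketch assumes one token per non-overlapping patch per frame (no temporal patching) and counts query-key pairs per attention layer, ignoring heads and masking; token_grid and attention_pairs are hypothetical helpers.

```python
def token_grid(image_size: int, patch_size: int) -> int:
    """Number of patch tokens per frame for square frames."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    side = image_size // patch_size
    return side * side


def attention_pairs(num_frames: int, tokens_per_frame: int) -> dict:
    """Query-key pairs for full vs. factorized spatiotemporal attention."""
    n_total = num_frames * tokens_per_frame
    full = n_total ** 2                            # every token attends to every token
    spatial = num_frames * tokens_per_frame ** 2   # within each frame
    temporal = tokens_per_frame * num_frames ** 2  # across frames, per spatial position
    # A causal temporal mask would roughly halve the temporal term; ignored here.
    return {"full": full, "st": spatial + temporal}


n = token_grid(64, 4)           # defaults: image_size=64, patch_size=4
pairs = attention_pairs(16, n)  # default num_frames=16
print(n, pairs)                 # prints: 256 {'full': 16777216, 'st': 1114112}
```

Under these assumptions the factorized layers touch roughly 15x fewer pairs than full attention over the 4096 tokens of a 16-frame clip.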
- encode(x)[source]#
Encode video to discrete tokens.
- Parameters:
x (Tensor) – Video tensor (B, C, T, H, W)
- Returns:
A tuple (z_q, indices, vq_loss):
z_q: quantized embeddings (B, T, H’, W’, embedding_dim)
indices: token indices (B, T, H’, W’)
vq_loss: dictionary with VQ loss components
- Return type:
tuple
- decode_indices(indices)[source]#
Decode token indices to embeddings for video frames.
- Parameters:
indices (Tensor) – Token indices (B, T, H’, W’) or (B, T, N) where N = H’*W’
- Returns:
Quantized embeddings (B, T, H’, W’, embedding_dim)
- Return type:
Tensor
- torchwm.create_video_tokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False)[source]#
Factory function to create a Video Tokenizer.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
use_ema (bool)
- Return type:
VideoTokenizer
- class torchwm.VectorQuantizer(vocab_size=512, embedding_dim=512, commitment_weight=0.25)[source]#
Bases:
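Module
Vector quantizer for the discrete autoencoder.
Implements the VQ-VAE quantization from “Neural Discrete Representation Learning” (van den Oord et al., 2017).
Updates the codebook by gradient descent on the codebook loss, adds a commitment loss on the encoder outputs, and uses the straight-through estimator for gradient flow to the encoder. VectorQuantizerEMA provides the variant with exponential-moving-average codebook updates.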
- Parameters:
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
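The quantization step itself is a nearest-neighbour lookup followed by two loss terms. The plain-Python sketch below assumes squared-error losses averaged over all elements; quantize and sq_dist are hypothetical names, and a real implementation would operate on tensors with stop-gradients (noted in the docstring) instead of plain floats.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def sq_dist(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(z_e: List[Vector], codebook: List[Vector], commitment_weight: float = 0.25
             ) -> Tuple[List[int], List[List[float]], float, float]:
    """Nearest-neighbour vector quantization with VQ-VAE loss terms.

    Returns (indices, z_q, codebook_loss, commitment_loss). With autograd the two losses
    share a value but differ in where gradients stop:
      codebook_loss   = ||sg[z_e] - e||^2          (updates the codebook)
      commitment_loss = beta * ||z_e - sg[e]||^2   (updates the encoder)
    and the straight-through estimator feeds z_e + sg[z_q - z_e] to the decoder.
    """
    indices, z_q = [], []
    for v in z_e:
        k = min(range(len(codebook)), key=lambda i: sq_dist(v, codebook[i]))
        indices.append(k)
        z_q.append(list(codebook[k]))
    mse = sum(sq_dist(v, q) for v, q in zip(z_e, z_q)) / (len(z_e) * len(z_e[0]))
    return indices, z_q, mse, commitment_weight * mse


codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]
idx, zq, cb_loss, commit_loss = quantize([[0.1, -0.1], [0.9, 1.2], [-0.8, 1.9]], codebook)
```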
- class torchwm.VectorQuantizerEMA(vocab_size=512, embedding_dim=512, commitment_weight=0.25, ema_decay=0.99, epsilon=1e-05)[source]#
Bases:
Module
Vector quantizer with exponential moving average (EMA) codebook updates.
Updates each codebook entry from EMAs of the encoder outputs assigned to it instead of from gradients, which typically leads to more stable training.
- Parameters:
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
ema_decay (float)
epsilon (float)
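An EMA codebook update keeps running averages of how many encoder outputs each code receives and of their sum, then sets each code to the ratio. A minimal stdlib sketch, assuming Laplace smoothing with epsilon (as in common VQ-VAE EMA implementations) and state initialized so that the first ratio reproduces the starting codebook; ema_update is a hypothetical helper, not the torchwm.VectorQuantizerEMA internals.

```python
from typing import List


def ema_update(codebook: List[List[float]], cluster_size: List[float],
               embed_sum: List[List[float]], z_e: List[List[float]], indices: List[int],
               decay: float = 0.99, epsilon: float = 1e-5) -> None:
    """One EMA codebook update; modifies codebook, cluster_size and embed_sum in place."""
    K, D = len(codebook), len(codebook[0])
    counts = [0.0] * K
    sums = [[0.0] * D for _ in range(K)]
    for v, k in zip(z_e, indices):
        counts[k] += 1.0
        for d in range(D):
            sums[k][d] += v[d]
    for k in range(K):
        cluster_size[k] = decay * cluster_size[k] + (1 - decay) * counts[k]
        for d in range(D):
            embed_sum[k][d] = decay * embed_sum[k][d] + (1 - decay) * sums[k][d]
    # Laplace smoothing keeps rarely used codes from dividing by (near) zero
    n = sum(cluster_size)
    for k in range(K):
        smoothed = (cluster_size[k] + epsilon) / (n + K * epsilon) * n
        codebook[k] = [s / smoothed for s in embed_sum[k]]


# State initialized so that embed_sum / cluster_size reproduces the starting codebook
codebook = [[0.0], [10.0]]
cluster_size = [1.0, 1.0]
embed_sum = [row[:] for row in codebook]
for _ in range(500):
    # Both encoder outputs are assigned to code 0, which drifts toward their mean, 3.0
    ema_update(codebook, cluster_size, embed_sum, z_e=[[2.0], [4.0]], indices=[0, 0])
```

No gradient ever touches the codebook here; the encoder is still trained through the commitment loss.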
- class torchwm.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]#
Bases:
object
Fixed-size replay buffer for Dreamer with image observations and transitions.
Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world-model training.
- Key Features:
Ring buffer with fixed capacity (FIFO eviction when full)
Stores raw uint8 images to save memory
Samples sequences (not single transitions) for temporal modeling
Validates sampled sequences don’t span episode boundaries
- Memory Layout:
observations: (capacity, C, H, W) uint8 images
actions: (capacity, action_dim) float32
rewards: (capacity,) float32
terminals: (capacity,) float32 (1.0 = terminal, 0.0 = continue)
- Sampling Process:
Random start index (avoiding episode boundaries)
Collect sequence of length seq_len with wraparound
Validate no terminal in middle of sequence
Return batch of sequences
- Usage with Dreamer:
buffer = ReplayBuffer(
    size=100000,            # Max transitions to store
    obs_shape=(3, 64, 64),  # RGB images
    action_size=6,          # Continuous action dim
    seq_len=50,             # Sequence length for training
    batch_size=50,          # Parallel sequences per batch
)
# Add transitions during interaction
buffer.add(obs, action, reward, done)
# Sample batch for world model training
obs_batch, action_batch, reward_batch, term_batch = buffer.sample()
# Shapes: (seq_len, batch, C, H, W), (seq_len, batch, action_dim), etc.
- Memory Efficiency:
Uses uint8 for images (1 byte per pixel vs 4 for float32)
Sequences share observations (overlapping windows)
Configurable capacity based on available system memory
Note
add() expects observations as {“image”: …} dicts; only the image arrays are stored and returned, for training efficiency.
- Parameters:
size (int)
obs_shape (Tuple[int, ...])
action_size (int)
seq_len (int)
batch_size (int)
- add(obs, ac, rew, done)[source]#
Add a transition to the buffer.
- Parameters:
obs (dict) – Observation dict with ‘image’ key containing the observation
ac (ndarray) – Action taken, shape (action_size,)
rew (float) – Reward received, scalar
done (float) – Terminal flag, 1.0 if episode ended, 0.0 otherwise
- Return type:
None
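The ring-buffer mechanics described above (FIFO overwrite, sequence windows, no terminal before the last step) can be sketched with the standard library. SequenceRingBuffer is hypothetical and returns batch-major lists rather than the time-major arrays torchwm.ReplayBuffer produces; ordering slots chronologically is one assumed way to handle wraparound without letting a window straddle the write pointer.

```python
import random
from typing import Any, List, Tuple


class SequenceRingBuffer:
    """Hypothetical stdlib sketch of a Dreamer-style ring buffer, not torchwm.ReplayBuffer."""

    def __init__(self, size: int, seq_len: int, batch_size: int, seed: int = 0):
        self.size, self.seq_len, self.batch_size = size, seq_len, batch_size
        self.obs: List[Any] = [None] * size
        self.actions: List[Any] = [None] * size
        self.rewards = [0.0] * size
        self.terminals = [0.0] * size
        self.idx, self.full = 0, False
        self.rng = random.Random(seed)

    def add(self, obs, action, reward: float, done: float) -> None:
        self.obs[self.idx], self.actions[self.idx] = obs, action
        self.rewards[self.idx], self.terminals[self.idx] = reward, float(done)
        self.idx = (self.idx + 1) % self.size  # FIFO: the oldest slot is overwritten next
        self.full = self.full or self.idx == 0

    def _chronological(self) -> List[int]:
        # Slot indices from oldest to newest, so windows wrap around but never cross the write pointer
        if self.full:
            return [(self.idx + k) % self.size for k in range(self.size)]
        return list(range(self.idx))

    def sample(self) -> Tuple[list, list, list, list]:
        order = self._chronological()
        # A window is valid if no terminal appears before its last step
        starts = [p for p in range(len(order) - self.seq_len + 1)
                  if not any(self.terminals[order[q]] for q in range(p, p + self.seq_len - 1))]
        if not starts:
            raise ValueError("not enough data for a valid sequence")
        batch = [order[p:p + self.seq_len]
                 for p in (self.rng.choice(starts) for _ in range(self.batch_size))]

        def pick(store):
            return [[store[i] for i in seq] for seq in batch]

        return pick(self.obs), pick(self.actions), pick(self.rewards), pick(self.terminals)


buf = SequenceRingBuffer(size=10, seq_len=3, batch_size=8)
for t in range(12):  # 12 adds into 10 slots: steps 0 and 1 are evicted
    buf.add(obs=t, action=t, reward=1.0, done=(t == 5))
obs, act, rew, term = buf.sample()
```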
- class torchwm.Memory(size=None)[source]#
Bases:
deque
Episode-based replay memory for PlaNet/RSSM training.
Stores episodes as variable-length trajectories and supports sampling sub-sequences for training. Implements a ring-buffer style eviction when capacity is reached.
- Features:
Stores complete episodes as lists of transitions
Samples contiguous sub-sequences for sequence models
Supports time-major formatting (time-first) for RNN input
Memory usage estimation to prevent OOM errors
- Parameters:
size (int, optional) – Maximum number of episodes to store. If None, deque grows without limit (useful for unpickling).
- episodes#
Collection of Episode objects.
- Type:
deque
- eps_lengths#
Length of each episode.
- Type:
deque
- size#
Total number of transitions across all episodes.
- Type:
property
Example
>>> memory = Memory(size=100)
>>> memory.append([episode1, episode2])
>>> obs, actions, rewards, terminals, lengths = memory.sample(batch_size=32, tracelen=50)
- property size#
- sample(batch_size, tracelen=1, time_first=False)[source]#
Sample random sub-sequences from stored episodes.
Randomly selects episodes and starting positions to create batches of contiguous sequences for training sequence models.
- Parameters:
batch_size (int) – Number of sequences to sample.
tracelen (int) – Length of each sequence (default: 1).
time_first (bool) – If True, returns tensors with time dimension first (T, B, …) instead of batch first (B, T, …).
- Returns:
- (observations, actions, rewards, terminals, lengths)
observations: (batch, tracelen+1, *obs_shape) or (tracelen+1, batch, …)
actions: (batch, tracelen, action_dim) or (tracelen, batch, …)
rewards: (batch, tracelen) or (tracelen, batch)
terminals: (batch, tracelen) or (tracelen, batch)
lengths: (batch,) original episode lengths for each sample
- Return type:
tuple
- Raises:
ValueError – If memory is empty or no episodes meet minimum length.
MemoryError – If estimated memory usage exceeds 200 MiB threshold.
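The sampling contract above (tracelen + 1 observations per window, optional time-first layout) can be sketched in plain Python. The dict-based episode layout mirrors the x/u/r/t attributes of Episode but is an assumption, sample_subsequences is a hypothetical helper, and the memory-usage check is omitted.

```python
import random
from typing import Dict, List


def sample_subsequences(episodes: List[Dict[str, list]], batch_size: int, tracelen: int = 1,
                        time_first: bool = False, rng=random):
    """Sample contiguous windows from variable-length episodes.

    Each episode is a dict with 'x' (T + 1 observations) and 'u', 'r', 't' (T entries each).
    """
    eligible = [ep for ep in episodes if len(ep["u"]) >= tracelen]
    if not eligible:
        raise ValueError("no episode is long enough for the requested tracelen")
    obs, act, rew, term, lengths = [], [], [], [], []
    for _ in range(batch_size):
        ep = rng.choice(eligible)
        s = rng.randint(0, len(ep["u"]) - tracelen)  # inclusive upper bound
        obs.append(ep["x"][s:s + tracelen + 1])      # tracelen + 1 observations
        act.append(ep["u"][s:s + tracelen])
        rew.append(ep["r"][s:s + tracelen])
        term.append(ep["t"][s:s + tracelen])
        lengths.append(len(ep["u"]))
    if time_first:
        # (batch, time) -> (time, batch)
        obs, act, rew, term = (list(map(list, zip(*seq))) for seq in (obs, act, rew, term))
    return obs, act, rew, term, lengths


rng = random.Random(0)
eps = [{"x": list(range(T + 1)), "u": [0] * T, "r": [0.0] * T, "t": [0.0] * (T - 1) + [1.0]}
       for T in (3, 8)]
o, a, r, t, L = sample_subsequences(eps, batch_size=4, tracelen=5, rng=rng)
```

The 3-step episode is too short for tracelen=5, so every window comes from the 8-step episode.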
- class torchwm.Episode(postprocess_fn=None)[source]#
Bases:
object
Records the agent’s interaction with the environment for a single episode.
Stores observations, actions, rewards, and terminal flags during a single trajectory. At termination, converts all lists to numpy arrays for efficient batch processing.
- x#
Observations collected during the episode.
- Type:
list or np.ndarray
- u#
Actions taken.
- Type:
list or np.ndarray
- r#
Rewards received.
- Type:
list or np.ndarray
- t#
Terminal flags (0.0 = continue, 1.0 = terminal).
- Type:
list or np.ndarray
- info#
Additional episode metadata.
- Type:
dict
- Parameters:
postprocess_fn (callable, optional) – Function to apply to observations before storing (e.g., normalization). Default: identity function.
Example
>>> episode = Episode()
>>> episode.append(obs, action, reward, False)
>>> episode.append(obs, action, reward, True)
>>> episode.terminate(final_obs)
>>> print(episode.x.shape)  # Now a numpy array
- property size#
- class torchwm.IRISReplayBuffer(size, obs_shape, action_size, seq_len=20, batch_size=64)[source]#
Bases:
object
Replay buffer for IRIS (Imagination with auto-Regression over an Inner Speech) training.
Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world model training.
- Features:
Ring buffer with fixed capacity (FIFO eviction when full)
Stores uint8 images for memory efficiency
Samples sequences with validation to avoid episode boundaries
Supports sequence sampling for temporal learning
- Memory Layout:
observations: (capacity, C, H, W) uint8
actions: (capacity, action_size) float32
rewards: (capacity,) float32
terminals: (capacity,) float32
- Parameters:
size (int) – Maximum number of transitions to store.
obs_shape (tuple) – Shape of observations as (C, H, W).
action_size (int) – Dimension of actions.
seq_len (int) – Length of sequences to sample (default: 20).
batch_size (int) – Number of sequences per batch (default: 64).
- size#
Buffer capacity.
- Type:
int
- obs_shape#
Observation shape.
- Type:
tuple
- action_size#
Action dimension.
- Type:
int
- seq_len#
Sequence length.
- Type:
int
- batch_size#
Batch size.
- Type:
int
- steps#
Total transitions added.
- Type:
int
- episodes#
Number of episode terminations observed.
- Type:
int
- add(obs, action, reward, terminal)[source]#
Add a transition to the buffer.
- Parameters:
obs (ndarray) – Observation array with shape (C, H, W).
action (ndarray) – Action array with shape (action_size,).
reward (float) – Scalar reward value.
terminal (bool) – Boolean indicating if episode terminated.
- sample_sequence(seq_len=None)[source]#
Sample a batch of sequences for world model training.
- Returns:
A tuple (observations, actions, rewards, terminals):
observations: (batch_size, seq_len+1, C, H, W)
actions: (batch_size, seq_len, action_size)
rewards: (batch_size, seq_len)
terminals: (batch_size, seq_len)
- Return type:
tuple
- Parameters:
seq_len (int | None)
- sample_single()[source]#
Sample a single transition for online updates.
- Return type:
Tuple[ndarray, ndarray, float, float]
- property buffer_capacity#
Returns the total capacity of the buffer.
- class torchwm.IRISOnPolicyBuffer(max_steps=1000)[source]#
Bases:
object
On-policy buffer for collecting trajectories during environment interaction.
Used to store the current episode data before adding to the main replay buffer. Unlike the main replay buffer, this collects trajectories in a list-based structure that’s cleared after each episode.
- Useful for:
Collecting complete episode trajectories
Storing data before batch processing
Temporary storage during environment interaction
- Parameters:
max_steps (int) – Maximum number of steps to store (default: 1000).
- max_steps#
Maximum buffer capacity.
- Type:
int
- observations#
List of observations.
- Type:
list
- actions#
List of actions.
- Type:
list
- rewards#
List of rewards.
- Type:
list
- terminals#
List of terminal flags.
- Type:
list
- class torchwm.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]#
Bases:
Module
Diffusion Transformer model for image denoising and generation.
The module maps noisy images and timesteps to predicted noise residuals and also provides a classmethod training entrypoint for common datasets.
- classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]#
- class torchwm.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]#
Bases:
Module
Patchify an image into a sequence of learnable patch tokens.
Used in Vision Transformers (ViT) and DiT to convert 2D images into sequences of token embeddings that can be processed by transformers.
- Process:
Conv2d with kernel_size=stride=patch_size extracts non-overlapping patches
Each patch is linearly projected to embed_dim (the same Conv2d performs the projection)
Learnable positional embeddings are added for spatial information
Input: (B, C, H, W) images
Output: (B, N, embed_dim) where N = (H/patch_size) * (W/patch_size)
- Parameters:
img_size – Image size (assumes square), e.g., 32 for CIFAR
patch_size – Size of each patch (typically 4, 8, or 16)
in_channels – Number of input channels (3 for RGB)
embed_dim – Output dimension for each patch token
- Usage with DiT:
patch_embed = PatchEmbed(img_size=32, patch_size=4, in_channels=3, embed_dim=256)
tokens = patch_embed(images)  # (B, 64, 256) for 32x32 images with patch_size=4
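PatchEmbed performs the patch split and the linear projection in one Conv2d. The plain-Python sketch below shows only the rearrangement that fixes N = (H/p) * (W/p), plus the inverse grid rearrangement that a PatchUnEmbed-style module undoes before decoding. There are no learned weights, and patchify and unpatchify are hypothetical helpers.

```python
from typing import List

Image = List[List[List[float]]]  # (C, H, W) nested lists


def patchify(img: Image, p: int) -> List[List[float]]:
    """Split a (C, H, W) image into N = (H/p)*(W/p) flattened patches of length C*p*p."""
    C, H, W = len(img), len(img[0]), len(img[0][0])
    if H % p or W % p:
        raise ValueError("image size must be divisible by patch size")
    patches = []
    for gy in range(H // p):  # row-major over the patch grid
        for gx in range(W // p):
            patches.append([img[c][gy * p + y][gx * p + x]
                            for c in range(C) for y in range(p) for x in range(p)])
    return patches


def unpatchify(patches: List[List[float]], C: int, H: int, W: int, p: int) -> Image:
    """Inverse of patchify: scatter each flattened patch back into its grid cell."""
    img = [[[0.0] * W for _ in range(H)] for _ in range(C)]
    grid_w = W // p
    for n, patch in enumerate(patches):
        gy, gx = divmod(n, grid_w)
        i = 0
        for c in range(C):
            for y in range(p):
                for x in range(p):
                    img[c][gy * p + y][gx * p + x] = patch[i]
                    i += 1
    return img


img = [[[float(c * 1000 + y * 32 + x) for x in range(32)] for y in range(32)] for c in range(3)]
patches = patchify(img, 4)  # 64 patches of length 3 * 4 * 4 = 48
```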
- class torchwm.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]#
Bases:
Module
Reconstruct image-like tensors from patch-token sequences.
The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.
- class torchwm.DDPM(timesteps, beta_start, beta_end, device)[source]#
Bases:
object
Utility class implementing forward and reverse DDPM diffusion steps.
Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).
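The forward-noising helper relies on the closed form from Ho et al. (2020): x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, where alpha_bar_t is the running product of (1 - beta). The stdlib sketch below assumes a linear beta schedule between beta_start and beta_end. Whether torchwm.DDPM uses exactly this schedule is an assumption, and linear_schedule is a hypothetical helper.

```python
import math
import random


def linear_schedule(timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear betas and cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    betas = [beta_start + (beta_end - beta_start) * i / (timesteps - 1) for i in range(timesteps)]
    alpha_bars, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        alpha_bars.append(prod)
    return betas, alpha_bars


def q_sample(x0, t, alpha_bars, noise):
    """Forward noising: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    a = alpha_bars[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, noise)]


betas, alpha_bars = linear_schedule()  # defaults mirror DiTConfig TIMESTEPS/BETA_START/BETA_END
rng = random.Random(0)
x0 = [0.5, -1.0, 0.25]
eps = [rng.gauss(0.0, 1.0) for _ in x0]
x_t = q_sample(x0, t=500, alpha_bars=alpha_bars, noise=eps)
```

During training, the denoiser receives (x_t, t) and is regressed onto eps. Because alpha_bar_t approaches zero at the last step, x_T is close to pure noise.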
- class torchwm.ActorCriticNetwork(obs_channels=3, action_dim=18, channels=(32, 32, 64, 64), lstm_dim=512)[source]#
Bases:
Module
Actor-critic network for DIAMOND RL training: a shared CNN-LSTM trunk with separate policy and value heads.
- Parameters:
obs_channels (int)
action_dim (int)
channels (Tuple[int, ...])
lstm_dim (int)
- forward(obs, hidden_state=None)[source]#
Forward pass of actor-critic network.
- Parameters:
obs (Tensor) – Observations [B, T, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
A tuple (policy_logits, values, hidden_state):
policy_logits: [B, T, action_dim]
values: [B, T, 1]
hidden_state: (h, c)
- Return type:
tuple
- get_action(obs, hidden_state=None, deterministic=False)[source]#
Get action from a single observation.
- Parameters:
obs (Tensor) – Single observation [B, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
deterministic (bool) – If True, take argmax; else sample
- Returns:
A tuple (action, hidden_state):
action: selected action [B]
hidden_state: (h, c)
- Return type:
tuple
- get_actions(obs, hidden_state=None, deterministic=False)[source]#
Batched version of get_action.
- Parameters:
obs (Tensor) – Tensor of shape [B, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state tuple matching batch size
deterministic (bool) – If True, take argmax; else sample from policy
- Returns:
A tuple (actions, hidden_state):
actions: LongTensor of shape [B]
hidden_state: updated LSTM hidden state tuple
- Return type:
tuple
- get_value(obs, hidden_state=None)[source]#
Get value for a single observation.
- Parameters:
obs (Tensor)
hidden_state (Tuple[Tensor, Tensor] | None)
- Return type:
Tuple[Tensor, Tuple[Tensor, Tensor] | None]
Initialize LSTM hidden states.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
Get LSTM hidden size.
- Return type:
int
- class torchwm.RewardTerminationModel(obs_channels=3, action_dim=18, channels=(32, 32, 32, 32), lstm_dim=512, cond_dim=128)[source]#
Bases:
Module
Reward and termination prediction model: a CNN + LSTM architecture following the DIAMOND paper specifications.
- Parameters:
obs_channels (int) – Number of observation channels (3 for RGB)
action_dim (int) – Number of possible actions
channels (Tuple[int, ...]) – List of channel sizes for conv blocks
lstm_dim (int) – LSTM hidden dimension
cond_dim (int) – Conditioning dimension for adaptive norm
- forward(obs, actions, hidden_state=None)[source]#
Forward pass of reward/termination model.
- Parameters:
obs (Tensor) – Observations [B, T, C, H, W]
actions (Tensor) – Actions [B, T]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
A tuple (reward_logits, termination_logits, hidden_state):
reward_logits: reward predictions [B, T, 3] (classes for -1, 0, 1)
termination_logits: termination predictions [B, T, 2]
hidden_state: updated (h, c) hidden states
- Return type:
tuple
- predict(obs, actions, hidden_state=None)[source]#
Predict reward and termination for a single step.
- Parameters:
obs (Tensor) – Single observation [B, C, H, W]
actions (Tensor) – Single action [B]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
A tuple (reward, terminated, hidden_state):
reward: predicted reward classes as a tensor (values -1, 0, 1)
terminated: predicted termination (bool tensor)
hidden_state: updated (h, c) hidden states
- Return type:
tuple
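Treating reward as a 3-way classification means logits must be decoded back to values, and raw rewards must be mapped to class targets. A minimal stdlib sketch, assuming the class order (-1, 0, 1) stated above, that termination class 1 means “terminated”, and sign-clipping of raw rewards. decode_step and reward_to_class are hypothetical helpers.

```python
from typing import List, Tuple

REWARD_VALUES = (-1, 0, 1)  # class index -> clipped reward


def argmax(xs: List[float]) -> int:
    return max(range(len(xs)), key=xs.__getitem__)


def decode_step(reward_logits: List[float], termination_logits: List[float]) -> Tuple[int, bool]:
    """Turn one step's logits into a reward in {-1, 0, 1} and a termination flag."""
    reward = REWARD_VALUES[argmax(reward_logits)]
    terminated = argmax(termination_logits) == 1  # assumes class 1 means 'terminated'
    return reward, terminated


def reward_to_class(r: float) -> int:
    """Training target: sign-clip a raw reward and map it to a class index."""
    return (r > 0) - (r < 0) + 1
```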
Initialize LSTM hidden states.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- torchwm.sinusoidal_time_embedding(timesteps, dim)[source]#
Create sinusoidal timestep embeddings for diffusion conditioning.
This function generates positional-style embeddings for diffusion timesteps, following the same pattern as transformer positional encodings. The embeddings encode the noise level (t) and are used to condition the diffusion model.
- Math:
embedding[t] = [sin(t/10000^(2i/d)), cos(t/10000^(2i/d))] for i in [0, d/2)
- Parameters:
timesteps – Tensor of integer timesteps, shape (B,) or (B, 1)
dim – Embedding dimension (must be even)
- Returns:
Tensor of shape (B, dim) with sinusoidal embeddings
- Usage with DiT:
t = torch.tensor([0, 500, 1000])  # Timesteps
emb = sinusoidal_time_embedding(t, dim=256)  # (3, 256)
# Condition the model:
# - Pass the timestep embedding through an MLP
# - Use the result for AdaLN adaptive normalization
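The formula in the Math block can be written out directly in plain Python. This sketch assumes all sine terms are followed by all cosine terms (some implementations interleave them instead); sinusoidal_embedding is a hypothetical stand-in for the tensor-based torchwm function.

```python
import math
from typing import List


def sinusoidal_embedding(t: float, dim: int, max_period: float = 10000.0) -> List[float]:
    """[sin(t / 10000^(2i/d)) for i < d/2] followed by the matching cos terms."""
    if dim % 2:
        raise ValueError("dim must be even")
    half = dim // 2
    freqs = [max_period ** (-2.0 * i / dim) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]


batch = [sinusoidal_embedding(t, 256) for t in (0, 500, 1000)]  # 3 x 256
```

Low-index dimensions oscillate quickly with t while high-index dimensions change slowly, so nearby timesteps get similar but distinguishable embeddings.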
- class torchwm.STTransformer(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]#
Bases:
Module
Spatiotemporal transformer for video modeling.
Stacks depth (L) spatiotemporal blocks with interleaved spatial and temporal attention.
- Parameters:
num_frames (int)
num_patches_per_frame (int)
dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
qk_scale (float | None)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
norm_layer (type[Module])
- class torchwm.MultiHeadSelfAttention(d, n_heads=2)[source]#
Bases:
Module
Multi-head scaled dot-product self-attention over sequence tokens.
This module projects the input sequence into query/key/value heads, performs attention independently per head, and merges the heads back into the original feature dimension. It is used as a lightweight transformer attention block.
- torchwm.MultiHeadAttention#
alias of
MultiHeadSelfAttention
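The per-head computation is softmax(Q K^T / sqrt(d_head)) V, applied to slices of the feature dimension and concatenated back. The plain-Python sketch below omits the learned query/key/value and output projections a real module would apply; for self-attention, pass the same matrix as q, k and v. All function names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


def attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """softmax(Q K^T / sqrt(d_head)) V for one head."""
    scale = 1.0 / math.sqrt(len(q[0]))
    scores = matmul(q, [list(col) for col in zip(*k)])  # Q K^T
    weights = [softmax([s * scale for s in row]) for row in scores]
    return matmul(weights, v)


def multi_head(q: Matrix, k: Matrix, v: Matrix, n_heads: int = 2) -> Matrix:
    """Split features into n_heads slices, attend per head, concatenate the results."""
    d = len(q[0])
    if d % n_heads:
        raise ValueError("feature dim must be divisible by n_heads")
    h = d // n_heads
    heads = []
    for i in range(n_heads):
        sl = slice(i * h, (i + 1) * h)
        heads.append(attention([r[sl] for r in q], [r[sl] for r in k], [r[sl] for r in v]))
    return [sum((head[t] for head in heads), []) for t in range(len(q))]


x = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 3.0], [1.0, 1.0, 0.0, 0.0]]
out = multi_head(x, x, x, n_heads=2)  # self-attention: 3 tokens x 4 features
```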
- class torchwm.AdaLNNormalization(d_model, t_dim)[source]#
Bases:
Module
Adaptive layer normalization conditioned on an external embedding.
The module applies RMS normalization and predicts per-channel scale/shift from a conditioning vector (for example diffusion timestep embeddings).
- class torchwm.RMSNorm(dim, eps=1e-06)[source]#
Bases:
Module
Root Mean Square Layer Normalization with a learned gain parameter.
RMSNorm rescales activations using their RMS magnitude without centering, providing a lightweight normalization alternative to LayerNorm.
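RMSNorm divides by the root mean square without subtracting the mean. The AdaLN variant then applies a scale and shift predicted from the conditioning vector. This stdlib sketch assumes eps sits inside the square root and uses the DiT-style (1 + scale) modulation. The linear layer that predicts scale and shift is omitted, and rms_norm and ada_ln are hypothetical helpers.

```python
import math
from typing import List


def rms_norm(x: List[float], gain: List[float], eps: float = 1e-6) -> List[float]:
    """y_i = g_i * x_i / sqrt(mean(x^2) + eps): rescale by RMS, no mean subtraction."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for g, v in zip(gain, x)]


def ada_ln(x: List[float], scale: List[float], shift: List[float], eps: float = 1e-6) -> List[float]:
    """Scale/shift come from the conditioning vector (e.g. a timestep embedding).

    The (1 + scale) form, which reduces to plain normalization when scale == 0, is an assumption.
    """
    normed = rms_norm(x, [1.0] * len(x), eps)
    return [n * (1.0 + s) + b for n, s, b in zip(normed, scale, shift)]


y = rms_norm([3.0, -4.0], [1.0, 1.0])                  # output RMS is ~1
z = ada_ln([3.0, -4.0], [0.0, 0.0], [0.5, 0.5])        # same as y shifted by 0.5
```

Unlike LayerNorm, a constant input such as [2.0, 2.0] normalizes to about [1.0, 1.0] rather than to zeros.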
- class torchwm.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]#
Bases:
object
Model-predictive controller using the Cross-Entropy Method (CEM) with an RSSM.
Plans actions by optimizing a sequence of future actions in the RSSM’s latent space. Uses Cross-Entropy Method to refine action sequences based on predicted returns.
- Algorithm:
Initialize Gaussian distribution over action sequences
Sample N candidate action sequences
Rollout each sequence in RSSM latent space
Score by predicted cumulative rewards
Keep top K candidates, fit Gaussian to them
Repeat for T iterations
Execute first action from best sequence
- Why latent space planning?
Images are high-dimensional; latent states are compact
Enables thousands of rollouts in parallel
Predicting compact latent states avoids modeling every pixel at each rollout step
- Parameters:
model – RSSM instance for latent dynamics
planning_horizon – Number of future steps to plan (H)
num_candidates – Number of action sequences to sample (N)
num_iterations – CEM refinement iterations (T)
top_candidates – Number of best candidates to keep (K)
device – torch device
- Usage with Planet agent:
policy = RSSMPolicy(
    model=rssm, planning_horizon=12, num_candidates=1000, num_iterations=8, top_candidates=100, device='cuda'
)
policy.reset()
action = policy.poll(observation)  # (1, action_dim)
# For continuous control:
next_obs, reward, done, info = env.step(action)
- Comparison with Dreamer:
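RSSMPolicy: online planning; chooses actions by optimization at each step
DreamerActor: trains an actor network to predict actions from latent states
Dreamer amortizes planning into the actor, so acting is a single forward pass and it is typically more sample-efficient on complex tasks; CEM needs no policy training, which makes it more flexible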
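The CEM loop in the Algorithm list can be sketched in plain Python. Here a toy scoring function stands in for rolling each sequence through the RSSM and summing predicted rewards. cem_plan and toy_score are hypothetical. The sketch returns the first action of the final mean, as in PlaNet. Returning elites[0][0] would match the “best sequence” wording instead.

```python
import random
import statistics
from typing import Callable, List


def cem_plan(score: Callable[[List[List[float]]], float], horizon: int, action_dim: int,
             num_candidates: int = 1000, num_iterations: int = 8, top_candidates: int = 100,
             seed: int = 0) -> List[float]:
    """Cross-Entropy Method over open-loop action sequences; returns the first planned action."""
    rng = random.Random(seed)
    mean = [[0.0] * action_dim for _ in range(horizon)]
    std = [[1.0] * action_dim for _ in range(horizon)]
    for _ in range(num_iterations):
        candidates = [[[rng.gauss(mean[t][d], std[t][d]) for d in range(action_dim)]
                       for t in range(horizon)] for _ in range(num_candidates)]
        # Keep the highest-scoring sequences (an RSSM would score by predicted return)
        elites = sorted(candidates, key=score, reverse=True)[:top_candidates]
        # Refit the Gaussian to the elites, per time step and action dimension
        for t in range(horizon):
            for d in range(action_dim):
                vals = [seq[t][d] for seq in elites]
                mean[t][d] = statistics.fmean(vals)
                std[t][d] = statistics.pstdev(vals) + 1e-6
    return mean[0]  # execute only the first action, then re-plan at the next step


def toy_score(seq: List[List[float]]) -> float:
    """Toy stand-in for latent rollouts: highest when every action equals 0.5."""
    return -sum((a - 0.5) ** 2 for step in seq for a in step)


first_action = cem_plan(toy_score, horizon=3, action_dim=1, num_candidates=200, top_candidates=20)
```

Re-planning at every environment step is what makes this a model-predictive controller: only the first action is ever executed.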
- class torchwm.IRISActor(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Actor network for the IRIS (Imagination with auto-Regression over an Inner Speech) policy.
Takes reconstructed frames as input and outputs action logits for policy control. Uses a CNN feature extractor followed by an LSTM for temporal processing. Supports a burn-in mechanism for initializing the hidden state with context frames.
- Architecture:
CNN: Extracts features from input frames (3x64x64 -> 512)
LSTM: Processes temporal sequences with configurable layers
Linear: Maps hidden states to action logits
- Parameters:
action_size (int) – Number of discrete actions.
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
- action_size#
Number of discrete actions.
- Type:
int
- hidden_size#
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
- forward(frames, hidden_state=None, burn_in_frames=None)[source]#
Forward pass through actor.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W) or (B, C, H, W)
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple for LSTM state
burn_in_frames (Tensor | None) – Frames to use for initializing hidden state
- Returns:
A tuple (action_logits, hidden_state):
action_logits: (B, T, action_size) or (B, action_size)
hidden_state: updated (h, c) tuple
- Return type:
tuple
Initialize LSTM hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- get_action(frame, temperature=1.0, deterministic=False)[source]#
Get action from a single frame.
- Parameters:
frame (Tensor) – Single frame (B, C, H, W)
temperature (float) – Softmax temperature (higher = more random)
deterministic (bool) – If True, return argmax; else sample
- Returns:
Selected action indices (B,)
- Return type:
Tensor
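The temperature and deterministic flags follow the usual categorical-policy recipe: argmax for greedy play, otherwise a draw from softmax(logits / temperature). A minimal stdlib sketch of that recipe; sample_action is a hypothetical helper, and that IRISActor.get_action (or ActorCriticNetwork.get_action) implements it exactly this way is an assumption.

```python
import math
import random
from typing import List


def sample_action(logits: List[float], temperature: float = 1.0, deterministic: bool = False,
                  rng=random) -> int:
    """Pick an action index: argmax, or a draw from softmax(logits / temperature)."""
    if deterministic:
        return max(range(len(logits)), key=logits.__getitem__)
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]  # unnormalized softmax; choices() normalizes
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]
```

Low temperatures concentrate probability on the best logit, while high temperatures flatten the distribution toward uniform.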
- class torchwm.IRISCritic(hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Critic network for IRIS value estimation.
Estimates the value function for given frame sequences. Shares the CNN feature extractor and LSTM backbone with the actor for efficiency, but has a separate value head for estimating expected cumulative rewards.
- Architecture:
CNN: Shared feature extractor with actor (3x64x64 -> 512)
LSTM: Temporal processing with same architecture as actor
Linear: Maps hidden states to scalar values
- Parameters:
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
- hidden_size#
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
- Returns:
A tuple (values, hidden_state):
values: value estimates with shape (B, T)
hidden_state: updated LSTM hidden state (h, c) tuple
- Return type:
tuple
- Parameters:
hidden_size (int)
num_layers (int)
frame_shape (Tuple[int, int, int])
- forward(frames, hidden_state=None)[source]#
Forward pass through critic.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W)
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple
- Returns:
A tuple (values, hidden_state):
values: value estimates (B, T)
hidden_state: updated (h, c) tuple
- Return type:
tuple
Initialize LSTM hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- class torchwm.IRISPolicy(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Combined policy module for IRIS (Imagination with auto-Regression over an Inner Speech).
Provides a unified interface for actor-only or actor-critic policies. Used in the IRIS algorithm where the actor generates actions from reconstructed frames and the critic estimates value functions for training.
- Parameters:
action_size (int) – Number of discrete actions.
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
- hidden_size#
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
Example
>>> policy = IRISPolicy(
...     action_size=18,
...     hidden_size=512,
...     num_layers=4,
...     frame_shape=(3, 64, 64)
... )
>>> action = policy.act(frame, temperature=1.0, deterministic=False)
- forward(frames)[source]#
Get action logits from frames.
- Parameters:
frames (Tensor)
- Return type:
Tensor
- act(frame, temperature=1.0, deterministic=False)[source]#
Sample action from policy.
- Parameters:
frame (Tensor)
temperature (float)
deterministic (bool)
- Return type:
Tensor
Initialize hidden state.
- Parameters:
batch_size (int)
device (device)
- class torchwm.CNNFeatureExtractor(frame_shape=(3, 64, 64), output_size=512)[source]#
Bases:
Module
CNN feature extractor shared between actor and critic networks.
Processes input frames through a series of convolutional layers to produce fixed-size feature vectors: four stride-2 Conv2d + ReLU stages (3 -> 32 -> 64 -> 128 -> 256 channels), followed by a linear projection to output_size.
- Architecture:
Conv layers: 32 -> 64 -> 128 -> 256 channels
Each conv has stride=2 for spatial downsampling
Final linear layer projects to desired output dimension
- Parameters:
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
output_size (int) – Size of output feature vector (default: 512).
- frame_shape#
Input frame shape.
- Type:
tuple
- output_size#
Output feature dimension.
- Type:
int
- Returns:
Feature vectors with shape (B, output_size).
- Return type:
Tensor
- Parameters:
frame_shape (Tuple[int, int, int])
output_size (int)
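The size of the flattened feature map that the final linear layer consumes follows from the standard convolution output-size formula. This sketch assumes kernel_size=4 and padding=1, which the docstring does not state, so every stride-2 stage exactly halves the spatial size. conv_out and feature_shapes are hypothetical helpers.

```python
def conv_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
    """Standard Conv2d output size: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def feature_shapes(frame_shape=(3, 64, 64), channels=(32, 64, 128, 256)):
    """Trace (C, H, W) through four stride-2 convs and report the flattened size."""
    _, h, w = frame_shape
    shapes = []
    for c in channels:
        h, w = conv_out(h), conv_out(w)
        shapes.append((c, h, w))
    c, h, w = shapes[-1]
    return shapes, c * h * w  # flat size feeds the final linear layer -> output_size


shapes, flat = feature_shapes()  # 64x64 input -> 256 x 4 x 4 feature map
```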
- class torchwm.DreamerConfig[source]#
Bases:
object
Configuration container for Dreamer training, evaluation, and environment setup.
This class centralizes environment backend selection (DMC/Gym/MuJoCo/Robotics/Unity/Brax), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.
- class torchwm.JEPAConfig[source]#
Bases:
object
Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.
- class torchwm.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#
Bases:
object
Default configuration values for Diffusion Transformer (DiT) training.
The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.
- Parameters:
DATASET (str)
BATCH (int)
EPOCHS (int)
LR (float)
IMG_SIZE (int)
CHANNELS (int)
PATCH (int)
WIDTH (int)
DEPTH (int)
HEADS (int)
DROP (float)
BETA_START (float)
BETA_END (float)
TIMESTEPS (int)
EMA (bool)
EMA_DECAY (float)
WORKDIR (str)
ROOT_PATH (str)
- DATASET: str = 'CIFAR10'#
- BATCH: int = 128#
- EPOCHS: int = 3#
- LR: float = 0.0002#
- IMG_SIZE: int = 32#
- CHANNELS: int = 3#
- PATCH: int = 4#
- WIDTH: int = 384#
- DEPTH: int = 6#
- HEADS: int = 6#
- DROP: float = 0.1#
- BETA_START: float = 0.0001#
- BETA_END: float = 0.02#
- TIMESTEPS: int = 1000#
- EMA: bool = True#
- EMA_DECAY: float = 0.999#
- WORKDIR: str = './dit_demo'#
- ROOT_PATH: str = './data'#
- torchwm.get_dit_config(**overrides)[source]#
Returns a DiTConfig instance with default values overridden by the provided keyword arguments.
- Example usage:
cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)
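A defaults-plus-overrides helper like get_dit_config can be built on dataclasses.replace. The sketch below uses a hypothetical subset of the DiTConfig fields with their documented defaults. The frozen dataclass and the early rejection of unknown names are design assumptions, not documented torchwm behavior.

```python
from dataclasses import dataclass, fields, replace


@dataclass(frozen=True)
class DiTConfigSketch:
    """Hypothetical subset of DiTConfig fields, with the documented defaults."""
    DATASET: str = "CIFAR10"
    BATCH: int = 128
    EPOCHS: int = 3
    LR: float = 0.0002
    IMG_SIZE: int = 32
    PATCH: int = 4


def get_config_sketch(**overrides) -> DiTConfigSketch:
    """Defaults plus keyword overrides; unknown names are rejected with a clear message."""
    known = {f.name for f in fields(DiTConfigSketch)}
    unknown = set(overrides) - known
    if unknown:
        raise TypeError(f"unknown config fields: {sorted(unknown)}")
    return replace(DiTConfigSketch(), **overrides)


cfg = get_config_sketch(BATCH=64, EPOCHS=10, LR=1e-3)
```

Rejecting unknown names early turns a typo such as BATCHSIZE into an immediate error instead of a silently ignored setting.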
- class torchwm.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', seed: int = 0, obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, device: str = <factory>, log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, operator_state_dim: int = 32, operator_action_dim: int = 4)[source]#
Bases: object
- Parameters:
preset (str | None)
game (str)
seed (int)
obs_size (int)
frameskip (int)
max_noop (int)
terminate_on_life_loss (bool)
reward_clip (List[int])
num_conditioning_frames (int)
diffusion_channels (List[int])
diffusion_res_blocks (int)
diffusion_cond_dim (int)
sigma_data (float)
sigma_min (float)
sigma_max (float)
rho (int)
p_mean (float)
p_std (float)
sampling_method (str)
num_sampling_steps (int)
reward_channels (List[int])
reward_res_blocks (int)
reward_cond_dim (int)
reward_lstm_dim (int)
burn_in_length (int)
actor_channels (List[int])
actor_res_blocks (int)
actor_lstm_dim (int)
num_epochs (int)
training_steps_per_epoch (int)
batch_size (int)
environment_steps_per_epoch (int)
epsilon_greedy (float)
imagination_horizon (int)
discount_factor (float)
entropy_weight (float)
lambda_returns (float)
learning_rate (float)
adam_epsilon (float)
weight_decay_diffusion (float)
weight_decay_reward (float)
weight_decay_actor (float)
device (str)
log_interval (int)
eval_interval (int)
save_interval (int)
operator_state_dim (int)
operator_action_dim (int)
- preset: str | None = None#
- game: str = 'Breakout-v5'#
- seed: int = 0#
- obs_size: int = 64#
- frameskip: int = 4#
- max_noop: int = 30#
- terminate_on_life_loss: bool = True#
- reward_clip: List[int]#
- num_conditioning_frames: int = 4#
- diffusion_channels: List[int]#
- diffusion_res_blocks: int = 2#
- diffusion_cond_dim: int = 256#
- sigma_data: float = 0.5#
- sigma_min: float = 0.002#
- sigma_max: float = 80.0#
- rho: int = 7#
- p_mean: float = -0.4#
- p_std: float = 1.2#
- sampling_method: str = 'euler'#
- num_sampling_steps: int = 3#
- reward_channels: List[int]#
- reward_res_blocks: int = 2#
- reward_cond_dim: int = 128#
- reward_lstm_dim: int = 512#
- burn_in_length: int = 4#
- actor_channels: List[int]#
- actor_res_blocks: int = 1#
- actor_lstm_dim: int = 512#
- num_epochs: int = 1000#
- training_steps_per_epoch: int = 400#
- batch_size: int = 32#
- environment_steps_per_epoch: int = 100#
- epsilon_greedy: float = 0.01#
- imagination_horizon: int = 15#
- discount_factor: float = 0.985#
- entropy_weight: float = 0.001#
- lambda_returns: float = 0.95#
- learning_rate: float = 0.0001#
- adam_epsilon: float = 1e-08#
- weight_decay_diffusion: float = 0.01#
- weight_decay_reward: float = 0.01#
- weight_decay_actor: float = 0.0#
- device: str#
- log_interval: int = 10#
- eval_interval: int = 50#
- save_interval: int = 100#
- operator_state_dim: int = 32#
- operator_action_dim: int = 4#
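The fields sigma_min, sigma_max, rho, p_mean, and p_std use the parameter names of the EDM diffusion formulation (Karras et al., 2022), which DIAMOND builds on. The sketch below assumes DiamondConfig follows EDM. Under that assumption, sampling walks down a ρ-warped grid of num_sampling_steps noise levels that ends at 0, and training draws ln σ from a normal distribution. The helper names are hypothetical and this is not the TorchWM implementation.

```python
import math
import random


def karras_sigmas(sigma_min=0.002, sigma_max=80.0, rho=7, num_steps=3):
    """EDM (Karras) noise levels from sigma_max down to sigma_min, followed by 0."""
    inv_rho_min = sigma_min ** (1.0 / rho)
    inv_rho_max = sigma_max ** (1.0 / rho)
    sigmas = []
    for i in range(num_steps):
        frac = i / (num_steps - 1) if num_steps > 1 else 0.0
        sigmas.append((inv_rho_max + frac * (inv_rho_min - inv_rho_max)) ** rho)
    return sigmas + [0.0]  # the final step lands on the clean sample


def sample_training_sigma(p_mean=-0.4, p_std=1.2, rng=random):
    """Log-normal training noise level: ln(sigma) ~ N(p_mean, p_std ** 2)."""
    return math.exp(rng.gauss(p_mean, p_std))


print(karras_sigmas())  # three decreasing noise levels followed by 0.0
```

A larger rho packs more of the few sampling steps near sigma_min, where fine image detail is resolved.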
- class torchwm.IRISConfig[source]#
Bases: object
Configuration for IRIS (Imagination with auto-Regression over an Inner Speech).
Based on the paper “Transformers are Sample-Efficient World Models”. Implements a discrete autoencoder and an autoregressive Transformer for sample-efficient RL.
- class torchwm.GenieConfig(num_frames=8, image_size=32, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=256, action_encoder_depth=4, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#
Bases: object
Configuration for the Genie model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
tokenizer_num_heads (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_num_heads (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
sample_temperature (float)
maskgit_steps (int)
- num_frames: int = 8#
- image_size: int = 32#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 256#
- tokenizer_decoder_dim: int = 512#
- tokenizer_encoder_depth: int = 4#
- tokenizer_decoder_depth: int = 8#
- tokenizer_num_heads: int = 8#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 256#
- action_encoder_depth: int = 4#
- action_num_heads: int = 8#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 4#
- learning_rate: float = 3e-05#
- weight_decay: float = 0.0001#
- warmup_steps: int = 5000#
- max_steps: int = 125000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- sample_temperature: float = 2.0#
- maskgit_steps: int = 25#
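mask_prob_min and mask_prob_max bound the masking rate used when training the dynamics model MaskGIT-style. The Genie paper describes drawing that rate uniformly from [0.5, 1] and then masking each token with a Bernoulli draw. The sketch below assumes the same procedure here. The function name sample_token_mask is hypothetical.

```python
import random


def sample_token_mask(num_tokens, mask_prob_min=0.5, mask_prob_max=1.0, rng=random):
    """Return one boolean per token; True means the token is replaced by the mask token."""
    rate = rng.uniform(mask_prob_min, mask_prob_max)  # one masking rate per sample
    return [rng.random() < rate for _ in range(num_tokens)]


rng = random.Random(0)
mask = sample_token_mask(256, rng=rng)
print(sum(mask) / len(mask))  # fraction masked, usually between 0.5 and 1.0
```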
- class torchwm.GenieSmallConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=512, action_encoder_depth=8, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=2, learning_rate=0.0001, weight_decay=0.0001, warmup_steps=1000, max_steps=50000, mask_prob_min=0.5, mask_prob_max=1.0)[source]#
Bases: object
Small configuration for development and testing.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
tokenizer_num_heads (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_num_heads (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 256#
- tokenizer_decoder_dim: int = 512#
- tokenizer_encoder_depth: int = 4#
- tokenizer_decoder_depth: int = 8#
- tokenizer_num_heads: int = 8#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 512#
- action_encoder_depth: int = 8#
- action_num_heads: int = 8#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 2#
- learning_rate: float = 0.0001#
- weight_decay: float = 0.0001#
- warmup_steps: int = 1000#
- max_steps: int = 50000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- class torchwm.STTransformerConfig(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases: object
Configuration for the spatiotemporal transformer.
- Parameters:
num_frames (int)
num_patches_per_frame (int)
dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- num_frames: int = 16#
- num_patches_per_frame: int = 256#
- dim: int = 768#
- depth: int = 12#
- num_heads: int = 12#
- mlp_ratio: float = 4.0#
- qkv_bias: bool = True#
- drop_rate: float = 0.0#
- attn_drop_rate: float = 0.0#
- drop_path_rate: float = 0.0#
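STTransformerConfig sets num_frames and num_patches_per_frame separately. The sketch below assumes the transformer factorizes attention into a spatial pass within each frame and a temporal pass across frames at each patch position, as the Genie ST-transformer does. Counting query-key pairs with the defaults shows why that split is attractive. The helper attention_pairs is hypothetical.

```python
def attention_pairs(num_frames=16, num_patches_per_frame=256):
    """Query-key pairs per layer and head: full vs. factorized spatiotemporal attention.

    Ignores causal masking in the temporal pass, which would roughly halve that term.
    """
    tokens = num_frames * num_patches_per_frame
    full = tokens ** 2
    spatial = num_frames * num_patches_per_frame ** 2  # every frame attends within itself
    temporal = num_patches_per_frame * num_frames ** 2  # every patch position attends over time
    return full, spatial + temporal


full, factorized = attention_pairs()
print(full, factorized)  # 16777216 1114112
```

With the defaults, the factorized form needs roughly 15 times fewer pairs than full attention over all 4096 tokens.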
- class torchwm.VideoTokenizerConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False, ema_decay=0.99, commitment_weight=0.25)[source]#
Bases: object
Configuration for the video tokenizer.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
use_ema (bool)
ema_decay (float)
commitment_weight (float)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- encoder_dim: int = 512#
- decoder_dim: int = 1024#
- encoder_depth: int = 12#
- decoder_depth: int = 20#
- num_heads: int = 16#
- patch_size: int = 4#
- vocab_size: int = 1024#
- embedding_dim: int = 32#
- use_ema: bool = False#
- ema_decay: float = 0.99#
- commitment_weight: float = 0.25#
- class torchwm.LatentActionModelConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=1024, encoder_depth=20, num_heads=16, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#
Bases: object
Configuration for the Latent Action Model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
encoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- encoder_dim: int = 1024#
- encoder_depth: int = 20#
- num_heads: int = 16#
- patch_size: int = 16#
- vocab_size: int = 8#
- embedding_dim: int = 32#
- commitment_weight: float = 1.0#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- class torchwm.DynamicsModelConfig(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases: object
Configuration for the dynamics model.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- num_frames: int = 16#
- image_size: int = 64#
- vocab_size: int = 1024#
- embedding_dim: int = 32#
- action_vocab_size: int = 8#
- dim: int = 5120#
- depth: int = 48#
- num_heads: int = 36#
- patch_size: int = 4#
- mlp_ratio: float = 4.0#
- qkv_bias: bool = True#
- drop_rate: float = 0.0#
- attn_drop_rate: float = 0.0#
- drop_path_rate: float = 0.0#
- class torchwm.OperatorABC[source]#
Bases: ABC
Abstract base class for operators that preprocess inputs for inference pipelines.
- class torchwm.DreamerOperator(image_size=64, action_dim=6)[source]#
Bases: OperatorABC
Operator for Dreamer model preprocessing: normalizes observations and encodes actions.
- Parameters:
image_size (int)
action_dim (int)
- class torchwm.JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)[source]#
Bases: OperatorABC
Operator for JEPA model preprocessing: handles image/video masking and patch processing.
- Parameters:
image_size (int)
patch_size (int)
mask_ratio (float)
- class torchwm.IrisOperator(seq_length=512, vocab_size=32000)[source]#
Bases: OperatorABC
Operator for the IRIS transformer model: formats sequences and embeddings.
- Parameters:
seq_length (int)
vocab_size (int)
- class torchwm.PlaNetOperator(state_dim=32, action_dim=4)[source]#
Bases: OperatorABC
Operator for PlaNet model preprocessing: encodes environment states and transitions.
- Parameters:
state_dim (int)
action_dim (int)
- torchwm.get_operator(name, **kwargs)[source]#
Factory function to get inference operators by name.
- Parameters:
name (str) – One of ‘dreamer’, ‘jepa’, ‘iris’, ‘planet’
**kwargs – Operator-specific configuration
- Returns:
Configured OperatorABC instance
Example
>>> op = get_operator('dreamer', image_size=64, action_dim=6)
>>> processed = op.process({'image': image, 'action': action})
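get_operator maps a name to an operator class and forwards the keyword arguments to its constructor. The sketch below shows that name-based factory pattern with self-contained stand-ins. ToyOperatorABC, ToyImageOperator, and get_operator_sketch are all hypothetical. The toy image rescaling is an illustration and not what DreamerOperator actually does.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Type


class ToyOperatorABC(ABC):
    """Hypothetical stand-in for OperatorABC: turns raw inputs into model inputs."""

    @abstractmethod
    def process(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        ...


class ToyImageOperator(ToyOperatorABC):
    """Toy operator that rescales 8-bit pixel values into [0, 1]."""

    def __init__(self, image_size: int = 64):
        self.image_size = image_size

    def process(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return {**inputs, "image": [p / 255.0 for p in inputs["image"]]}


# Name -> class registry; the documented get_operator accepts 'dreamer', 'jepa', 'iris', 'planet'.
_OPERATORS: Dict[str, Type[ToyOperatorABC]] = {"dreamer": ToyImageOperator}


def get_operator_sketch(name: str, **kwargs: Any) -> ToyOperatorABC:
    key = name.lower()
    if key not in _OPERATORS:
        raise ValueError(f"unknown operator {name!r}; expected one of {sorted(_OPERATORS)}")
    return _OPERATORS[key](**kwargs)  # forward operator-specific configuration


op = get_operator_sketch("dreamer", image_size=64)
print(op.process({"image": [0, 255]}))  # {'image': [0.0, 1.0]}
```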
- class torchwm.RewardModel(belief_size, state_size, hidden_size, activation_function='relu')[source]#
Bases: Module
Predict scalar rewards from Dreamer latent belief and state vectors.
Implemented as an MLP used for model-based reward supervision and imagined rollout return estimation.
- class torchwm.ValueModel(belief_size, state_size, hidden_size, activation_function='relu')[source]#
Bases: Module
Estimate scalar value from Dreamer latent belief and state vectors.
This MLP is trained on imagined returns and used for actor/value updates.
TorchWM public API.
This package keeps imports lightweight while still exposing a friendly top-level surface. Common workflows can use the small factory helpers:
import torchwm
cfg = torchwm.create_config("dreamer", env="walker-walk")
agent = torchwm.create_model("dreamer", cfg)
env = torchwm.make_env("CartPole-v1", backend="gym")
Lower-level research components remain available as lazy top-level exports, for
example from torchwm import DreamerAgent, ConvEncoder, ReplayBuffer.
- class world_models.EnvBackendSpec(name, factory_path, description='', aliases=())[source]#
Bases: NamedTuple
Metadata describing an environment backend available through make_env.
- Parameters:
name (str)
factory_path (str)
description (str)
aliases (tuple[str, ...])
- name: str#
Alias for field number 0
- factory_path: str#
Alias for field number 1
- description: str#
Alias for field number 2
- aliases: tuple[str, ...]#
Alias for field number 3
- class world_models.ModelSpec(name, import_path, config_path=None, description='', aliases=())[source]#
Bases: NamedTuple
Metadata describing a model available through create_model().
- Parameters:
name (str)
import_path (str)
config_path (str | None)
description (str)
aliases (tuple[str, ...])
- name: str#
Alias for field number 0
- import_path: str#
Alias for field number 1
- config_path: str | None#
Alias for field number 2
- description: str#
Alias for field number 3
- aliases: tuple[str, ...]#
Alias for field number 4
- world_models.create_config(model, **overrides)[source]#
Create the default config object for model and apply overrides.
Examples
>>> cfg = create_config("dreamer", env="walker-walk", seed=7)
>>> cfg.env
'walker-walk'
- Parameters:
model (str)
overrides (Any)
- Return type:
Any
- world_models.create_model(model, config=None, **overrides)[source]#
Instantiate a model or agent from a simple string name.
config is optional for models that define a config class. Keyword overrides are applied to the config when possible; otherwise they are passed directly to the underlying constructor or factory.
Examples
>>> agent = create_model("dreamer", env="walker-walk", total_steps=1000)
>>> genie = create_model("genie-small", image_size=32)
- Parameters:
model (str)
config (Any | None)
overrides (Any)
- Return type:
Any
- world_models.get_env_backend_spec(name)[source]#
Return metadata for an environment backend name or alias.
- Parameters:
name (str)
- Return type:
- world_models.get_model_spec(name)[source]#
Return metadata for a model name or alias.
- Parameters:
name (str)
- Return type:
- world_models.list_env_backends()[source]#
Return canonical backend names accepted by make_env().
- Return type:
list[str]
- world_models.list_envs(model=None)[source]#
List known environment ids, optionally filtered by model family.
- Parameters:
model (str | None)
- Return type:
list[str] | dict[str, list[str]]
- world_models.list_models()[source]#
Return canonical model names accepted by create_model().
- Return type:
list[str]
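ModelSpec stores a canonical name plus aliases, and get_model_spec accepts either one. The sketch below shows that lookup and assumes case-insensitive matching. The registry entries, alias strings, and import paths are placeholders and not TorchWM's real registry. SketchModelSpec, resolve_spec, and list_names are hypothetical.

```python
from typing import NamedTuple, Optional, Tuple


class SketchModelSpec(NamedTuple):
    """Same fields, in the same order, as the documented ModelSpec."""
    name: str
    import_path: str
    config_path: Optional[str] = None
    description: str = ""
    aliases: Tuple[str, ...] = ()


# Placeholder entries for illustration only.
REGISTRY = (
    SketchModelSpec("dreamer", "example.dreamer:Agent", aliases=("dreamer-v1",)),
    SketchModelSpec("genie-small", "example.genie:create_small", aliases=("genie_small",)),
)


def resolve_spec(name: str) -> SketchModelSpec:
    """Return the spec whose canonical name or alias matches `name`."""
    key = name.strip().lower()
    for spec in REGISTRY:
        if key == spec.name or key in spec.aliases:
            return spec
    raise KeyError(f"unknown model {name!r}; known: {[s.name for s in REGISTRY]}")


def list_names() -> list:
    """Canonical names only, mirroring what list_models() is documented to return."""
    return [spec.name for spec in REGISTRY]


print(resolve_spec("Genie_Small").name)  # genie-small
```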
- world_models.make_env(env_id, backend='auto', **kwargs)[source]#
Create an environment through a consistent TorchWM entry point.
- Parameters:
env_id (str) – Environment id, XML path, Unity executable path, or backend-specific id.
backend (str) – One of list_env_backends(); "auto" tries TorchWM’s compatibility helper.
**kwargs (Any) – Backend-specific options.
- Return type:
Any
- class world_models.IRISAgent(config, action_size, device)[source]#
Bases: Module
Complete IRIS agent with world model and policy.
Combines:
- Discrete autoencoder (encoder + decoder)
- Transformer world model
- Actor-critic for policy and value learning
- Parameters:
config (IRISConfig)
action_size (int)
device (device)
- forward_actor_critic(frames, hidden=None)[source]#
Forward pass through actor-critic.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W)
hidden (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state
- Returns:
action_logits: (B, T, action_size); values: (B, T); hidden_state: (h, c)
- act(frame, epsilon=0.0, temperature=1.0)[source]#
Sample action from policy.
- Parameters:
frame (Tensor) – Single frame (B, C, H, W)
epsilon (float) – Random action probability
temperature (float) – Action distribution temperature
- Returns:
Selected actions (B,)
- Return type:
- imagine_rollout(initial_frame, horizon=20)[source]#
Generate imagined trajectories using world model.
- Parameters:
initial_frame (Tensor) – Starting frame (B, C, H, W)
horizon (int) – Number of steps to imagine
- Returns:
trajectory: dictionary with imagined rollout data
- update_autoencoder(frames)[source]#
Update discrete autoencoder.
- Parameters:
frames (Tensor) – Training frames (B, C, H, W)
- Returns:
losses: dictionary of loss values
- update_transformer(frames, actions, rewards, terminals)[source]#
Update transformer world model.
- Parameters:
frames (Tensor) – Frame sequence
actions (Tensor) – Actions taken
rewards (Tensor) – Rewards received
terminals (Tensor) – Terminal flags
- Returns:
losses: dictionary of loss values
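IRISAgent.act takes both epsilon and temperature. A common way to combine the two is epsilon-greedy exploration layered on top of sampling from a temperature-scaled softmax. This sketch assumes that combination; it is not taken from the TorchWM source. It handles one row of logits using only the standard library. sample_action is a hypothetical name.

```python
import math
import random


def sample_action(logits, epsilon=0.0, temperature=1.0, rng=random):
    """Pick an action index from one row of policy logits (hypothetical helper)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if rng.random() < epsilon:
        return rng.randrange(len(logits))  # explore: uniform random action
    scaled = [x / temperature for x in logits]
    top = max(scaled)  # subtract the max before exp for numerical stability
    weights = [math.exp(s - top) for s in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


rng = random.Random(0)
print(sample_action([0.5, 2.0, -1.0], epsilon=0.1, temperature=1.0, rng=rng))
```

Lower temperatures push the sampled action toward the arg-max logit, and epsilon=0.0 turns off random exploration entirely.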
- world_models.compute_lambda_return(rewards, values, discounts, lambda_coef=0.95)[source]#
Compute λ-return target for value function training.
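The docstring does not state the recursion. The sketch below assumes compute_lambda_return follows the Dreamer λ-return G_t = r_t + d_t((1 - λ)V_{t+1} + λG_{t+1}) with G_T = V_T, where d_t is the per-step discount (typically γ times a non-terminal flag). It handles one trajectory as plain lists, whereas the documented function works on (B, T) tensors. lambda_returns is a hypothetical name.

```python
def lambda_returns(rewards, values, discounts, lambda_coef=0.95):
    """Lambda-returns for one trajectory.

    rewards, discounts: length T; values: length T + 1 (the last entry bootstraps).
    """
    T = len(rewards)
    if len(values) != T + 1 or len(discounts) != T:
        raise ValueError("expected len(values) == T + 1 and len(discounts) == T")
    returns = [0.0] * T
    next_return = values[T]  # G_T = V_T: bootstrap from the final value estimate
    for t in reversed(range(T)):
        next_return = rewards[t] + discounts[t] * (
            (1.0 - lambda_coef) * values[t + 1] + lambda_coef * next_return
        )
        returns[t] = next_return
    return returns


print(lambda_returns([1.0, 1.0], [0.0, 0.0, 0.0], [1.0, 1.0], lambda_coef=0.5))  # [1.5, 1.0]
```

λ = 1 gives the discounted Monte Carlo return bootstrapped only at the horizon. λ = 0 gives one-step TD targets.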
- Parameters:
rewards (Tensor) – Rewards (B, T)
values (Tensor) – Value estimates (B, T+1)
discounts (Tensor) – Discount factors (B, T)
lambda_coef (float) – Lambda parameter for bootstrapping
- Returns:
λ-return targets (B, T)
- Return type:
- class world_models.ModularRSSM(encoder, decoder, backbone, reward_decoder=None)[source]#
Bases: Module
Modular RSSM with swappable encoder, decoder, and backbone.
This class lets researchers experiment with different:
- Encoders: Conv, MLP, ViT
- Decoders: Conv, MLP
- Backbones: GRU, LSTM, Transformer
Example
>>> encoder = ConvEncoder((3, 64, 64), embed_size=1024)
>>> decoder = ConvDecoder(32, 200, (3, 64, 64))
>>> backbone = GRUBackbone(action_size=6, stoch_size=32, deter_size=200, hidden_size=200, embed_size=1024)
>>> rssm = ModularRSSM(encoder, decoder, backbone)
- Parameters:
encoder (EncoderBase)
decoder (DecoderBase)
backbone (BackboneBase)
reward_decoder (DecoderBase | None)
- property stoch_size: int#
- property deter_size: int#
- property embed_size: int#
- init_state(batch_size, device)[source]#
- Parameters:
batch_size (int)
device (device)
- Return type:
Dict[str, Tensor]
- observe_step(prev_state, prev_action, obs, nonterm=1.0)[source]#
- Parameters:
prev_state (Dict[str, Tensor])
prev_action (Tensor)
obs (Tensor)
nonterm (Any)
- Return type:
Tuple[Dict[str, Tensor], Dict[str, Tensor]]
- imagine_step(prev_state, prev_action, nonterm=1.0)[source]#
- Parameters:
prev_state (Dict[str, Tensor])
prev_action (Tensor)
nonterm (Any)
- Return type:
Dict[str, Tensor]
- observe_rollout(obs, actions, nonterms, prev_state, horizon)[source]#
- Parameters:
obs (Tensor)
actions (Tensor)
nonterms (Tensor)
prev_state (Dict[str, Tensor])
horizon (int)
- Return type:
Tuple[Dict[str, Tensor], Dict[str, Tensor]]
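observe_step, imagine_step, and observe_rollout have no docstrings. A typical RSSM filtering loop calls the one-step observe function once per time step and feeds each posterior back in as the next previous state. The nonterm flag zeroes the carried state at episode boundaries. This sketch assumes that pattern; the source does not confirm it. It also assumes observe_step returns (prior, posterior) in that order. States are scalar toys, toy_step is hypothetical, and the code only illustrates control flow.

```python
from typing import Callable, Dict, List, Tuple

State = Dict[str, float]
StepFn = Callable[[State, float, float, float], Tuple[State, State]]


def observe_rollout_sketch(step: StepFn, obs: List[float], actions: List[float],
                           nonterms: List[float], prev_state: State,
                           horizon: int) -> Tuple[List[State], List[State]]:
    """Unroll a one-step observe function for `horizon` steps, collecting priors and posteriors."""
    priors, posteriors = [], []
    for t in range(horizon):
        prior, posterior = step(prev_state, actions[t], obs[t], nonterms[t])
        priors.append(prior)
        posteriors.append(posterior)
        prev_state = posterior  # filtering: the next step starts from the posterior
    return priors, posteriors


def toy_step(prev_state: State, action: float, obs: float, nonterm: float) -> Tuple[State, State]:
    """Toy scalar step: nonterm == 0 wipes the carried state at an episode boundary."""
    deter = nonterm * prev_state["deter"] + action
    prior = {"deter": deter}
    posterior = {"deter": 0.5 * deter + 0.5 * obs}  # blend the prediction with the observation
    return prior, posterior


priors, posteriors = observe_rollout_sketch(
    toy_step, obs=[1.0, 1.0], actions=[0.0, 0.0], nonterms=[1.0, 1.0],
    prev_state={"deter": 0.0}, horizon=2,
)
print([p["deter"] for p in posteriors])  # [0.5, 0.75]
```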
- world_models.create_modular_rssm(encoder_type='conv', decoder_type='conv', backbone_type='gru', obs_shape=(3, 64, 64), action_size=6, stoch_size=32, deter_size=200, embed_size=1024, hidden_size=200, activation='elu', **kwargs)[source]#
Factory function to create a modular RSSM with specified components.
- Parameters:
encoder_type (str) – Type of encoder (“conv”, “mlp”, “vit”)
decoder_type (str) – Type of decoder (“conv”, “mlp”)
backbone_type (str) – Type of backbone (“gru”, “lstm”, “transformer”)
obs_shape (Tuple[int, int, int] | Tuple[int]) – Shape of observations (C, H, W) for images or (D,) for state
action_size (int) – Action space dimension
stoch_size (int) – Stochastic latent dimension
deter_size (int) – Deterministic hidden dimension
embed_size (int) – Encoder embedding dimension
hidden_size (int) – Hidden layer dimension
activation (str) – Activation function name
- Returns:
Configured ModularRSSM instance
- Return type:
- class world_models.Genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=512, tokenizer_decoder_dim=1024, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=1024, action_decoder_dim=1024, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, encoder_depth=12, decoder_depth=20, latent_action_depth=20, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#
Bases: Module
Genie: Generative Interactive Environment.
A generative model trained from video-only data that can be used as an interactive environment. It contains three key components:
1. Video Tokenizer: converts raw video frames into discrete tokens.
2. Latent Action Model (LAM): infers latent actions between frames.
3. Dynamics Model: predicts future frames given past frames and latent actions.
Based on the paper “Genie: Generative Interactive Environments” (arXiv:2402.15391).
Training follows two phases, as in the paper:
1. Train the video tokenizer first (on raw video frames).
2. Co-train the LAM (from pixels) and the dynamics model (on video tokens).
The LAM uses VQ-VAE training with:
- Encoder: takes x_{1:t} and x_{t+1} → outputs latent actions
- Decoder: takes x_{1:t-1} (masked) + actions → reconstructs x_t
- Auxiliary variance loss to prevent action collapse
At inference, latent actions are detached (stop-gradient) before they are passed to the dynamics model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_decoder_dim (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
encoder_depth (int)
decoder_depth (int)
latent_action_depth (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- forward(video, mask_prob=0.5, training_phase='all')[source]#
Full forward pass through all components.
- Parameters:
video (Tensor) – (B, C, T, H, W) input video
mask_prob (float) – Probability for random masking in dynamics (0.5-1.0)
training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”
- Returns:
Dictionary containing losses and predictions
- Return type:
Dict[str, Tensor]
- training_step(video, mask_prob=0.5, training_phase='all')[source]#
Single training step computing all losses.
- Parameters:
video (Tensor) – (B, C, T, H, W) input video
mask_prob (float) – Probability for random masking in dynamics
training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”
- Returns:
Dictionary containing all losses for backpropagation
- Return type:
Dict[str, Tensor]
- encode_video(video)[source]#
Encode video to discrete tokens.
- Parameters:
video (Tensor) – (B, C, T, H, W)
- Returns:
video_tokens: (B, T, H*W)
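H*W in the token shape is the number of tokens per frame. Suppose the tokenizer cuts each square frame into non-overlapping patch_size × patch_size patches. With VideoTokenizerConfig's defaults (image_size=64, patch_size=4), that gives (64 / 4)² = 256 tokens per frame, which matches STTransformerConfig.num_patches_per_frame. LatentActionModelConfig's patch_size=16 would give 16 patches per frame. tokens_per_frame is a hypothetical helper.

```python
def tokens_per_frame(image_size=64, patch_size=4):
    """Tokens per frame when a square image is cut into non-overlapping square patches."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    side = image_size // patch_size  # token-grid height and width
    return side * side


print(tokens_per_frame(64, 4), tokens_per_frame(64, 16))  # 256 16
```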
- infer_actions(frames)[source]#
Infer latent actions from a sequence of frames.
- Parameters:
frames (Tensor) – (B, C, T, H, W) video frames
- Returns:
latent_actions: (B, T-1) inferred latent action indices
- generate(prompt_frame, num_frames=16, actions=None, use_maskgit=True)[source]#
Generate video frames given a prompt frame and actions.
- Parameters:
prompt_frame (Tensor) – (B, C, H, W) initial frame
num_frames (int) – Total number of frames to generate
actions (Tensor | None) – (B, num_frames-1) latent action indices, or None for random
use_maskgit (bool) – Whether to use MaskGIT sampling
- Returns:
generated_video: (B, C, num_frames, H, W)
- play(current_frame, action, current_frames=None)[source]#
Play step: generate the next frame given the current frame and action.
- Parameters:
current_frame (Tensor) – (B, C, H, W) current frame
action (Tensor) – (B,) latent action indices
current_frames (Tensor | None) – (B, C, T, H, W) history frames, or None for first frame
- Returns:
next_frame: (B, C, H, W)
- class world_models.LatentActionModel(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#
Bases: Module
Latent Action Model (LAM) for unsupervised action learning.
Learns discrete latent actions from unlabeled video frames using a VQ-VAE based objective. The model infers latent actions between frames that encode the most meaningful changes for future frame prediction.
Based on the Genie paper: learns actions from Internet videos without action labels.
Components:
- Encoder: takes all previous frames x_{1:t} and the next frame x_{t+1} → outputs latent actions
- Decoder: takes previous frames x_{1:t-1} and latent actions a_{1:t-1} → predicts the next frame x_t
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- encode(x_prev, x_next)[source]#
Encode frames to latent actions.
- Parameters:
x_prev (Tensor) – Previous frames (B, C, T, H, W)
x_next (Tensor) – Next frame (B, C, H, W)
- Returns:
latent_actions: discrete latent action indices (B, T); z_q: quantized embeddings (B, T, embedding_dim)
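LatentActionModel.encode returns both a discrete index and a quantized embedding. In a standard VQ-VAE lookup, each encoder output is replaced by its nearest codebook vector. The index becomes the latent action and the vector becomes z_q. The sketch below assumes squared Euclidean distance. It uses a toy codebook of 3 entries in 2 dimensions, whereas the documented default is vocab_size=8 entries of embedding_dim=32. quantize is a hypothetical name.

```python
def quantize(z, codebook):
    """Return (index, code) of the codebook vector nearest to z under squared L2 distance."""
    best_idx, best_dist = 0, float("inf")
    for idx, code in enumerate(codebook):
        dist = sum((a - b) ** 2 for a, b in zip(z, code))
        if dist < best_dist:
            best_idx, best_dist = idx, dist
    return best_idx, codebook[best_idx]


codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]  # toy 3-entry codebook
print(quantize([0.9, 0.2], codebook))  # (1, [1.0, 0.0])
```

VQ-VAE training typically passes gradients through this non-differentiable lookup with a straight-through estimator. It also adds a commitment term scaled by a weight such as commitment_weight. The sketch covers only the forward lookup.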
- class world_models.DynamicsModel(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases: Module
Dynamics Model for action-controllable video generation.
A decoder-only transformer that predicts future frame tokens given past frame tokens and latent actions. Uses MaskGIT for training and sampling.
Based on the Genie paper: uses a cross-entropy loss with random masking during training and MaskGIT iterative refinement at inference.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- forward(video_tokens, actions, mask_prob=0.0)[source]#
Forward pass for training.
- Parameters:
video_tokens (Tensor) – (B, T, H*W) - token indices for frames 1 to T
actions (Tensor) – (B, T) - latent action indices for frames 1 to T
mask_prob (float) – Probability of masking input tokens (Bernoulli 0.5-1.0)
- Returns:
logits: (B, T, H*W, vocab_size)
- sample(prompt_tokens, prompt_actions, num_frames, sampler=None)[source]#
Sample future frames using MaskGIT.
- Parameters:
prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens
prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames
num_frames (int) – Total number of frames to generate
sampler (MaskGITSampler | None) – MaskGIT sampler instance
- Returns:
generated_tokens: (B, num_frames, N)
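MaskGIT decodes all tokens of a frame in a few parallel refinement steps instead of one token at a time. Each step predicts every masked token, keeps the most confident predictions, and re-masks the rest according to a schedule. The sketch below assumes the cosine schedule from the original MaskGIT paper; the TorchWM MaskGITSampler may use a different one. It lists how many tokens remain masked after each of GenieConfig's default maskgit_steps=25 steps for a 256-token frame. maskgit_schedule is a hypothetical name.

```python
import math


def maskgit_schedule(num_tokens=256, steps=25):
    """Tokens left masked after each refinement step under a cosine schedule."""
    remaining = []
    for t in range(steps):
        ratio = math.cos(math.pi / 2 * (t + 1) / steps)  # falls from ~1 to 0
        remaining.append(math.floor(num_tokens * ratio))
    return remaining


schedule = maskgit_schedule(num_tokens=256, steps=25)
print(schedule[0], schedule[-1])  # 255 0
```

The cosine shape unmasks only a few tokens early, when predictions are least reliable, and commits most tokens in the final steps.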
- autoregressive_sample(prompt_tokens, prompt_actions, num_frames, temperature=1.0)[source]#
Simple autoregressive sampling (token by token).
- Parameters:
prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens
prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames
num_frames (int) – Total number of frames to generate
temperature (float) – Sampling temperature
- Returns:
generated_tokens: (B, num_frames, N)
- world_models.create_genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, action_vocab_size=8, action_embedding_dim=32, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#
Factory function to create a Genie model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
action_vocab_size (int)
action_embedding_dim (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- world_models.create_genie_small(num_frames=16, image_size=64, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#
Create a smaller Genie model for development/testing.
- Parameters:
num_frames (int)
image_size (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- world_models.create_genie_large(num_frames=16, image_size=64, use_bfloat16=True, action_pooling='mean', window_attention_heads=1)[source]#
Create an approximation of the full 11B-parameter Genie model.
- Parameters:
num_frames (int)
image_size (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- world_models.create_latent_action_model(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, action_pooling='mean', window_attention_heads=1)[source]#
Factory function to create a Latent Action Model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- world_models.create_dynamics_model(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4)[source]#
Factory function to create a Dynamics Model.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
- Return type:
- class world_models.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)[source]#
Bases: Module
Recurrent State-Space Model used by Dreamer for latent dynamics learning.
The RSSM is the core world model component that learns compact representations of environment dynamics. It maintains a hybrid state consisting of:
Deterministic State (h): A recurrent hidden state updated by a GRU, capturing sequential/temporal information and deterministic transitions.
Stochastic State (s): A latent variable representing stochastic, multi-modal uncertainty in the environment (e.g., ambiguous observations).
The model operates in two modes:
Observe Mode: Updates states using actual observations from the environment. Uses the representation model: p(s_t | h_t, obs_t)
Imagine Mode: Predicts future states without observations. Uses the transition/prior model: p(s_t | h_t)
- Architecture:
Input: previous state (h_{t-1}, s_{t-1}) and action a_{t-1}
Process: a GRU updates the deterministic state; an MLP computes the stochastic prior/posterior
Output: updated state (h_t, s_t) and distributions
- State Representation:
deter (h): GRU hidden state, captures sequential context
stoch (s): Stochastic latent, multi-modal uncertainty
mean/std: Parameters of the stochastic distribution
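The following minimal sketch of a single RSSM step makes observe and imagine mode concrete. It assumes a GRUCell for the deterministic update, small ELU MLP heads for the prior and posterior, and a softplus std with a 0.1 floor. These choices and the class name TinyRSSMCell are illustrative assumptions, not the torchwm implementation. The nonterm mask is omitted for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Independent, Normal


class TinyRSSMCell(nn.Module):
    """Hypothetical single-step RSSM sketch (not the torchwm RSSM)."""

    def __init__(self, action_size, stoch_size=30, deter_size=200, hidden_size=200, obs_embed_size=256):
        super().__init__()
        self.pre_gru = nn.Linear(stoch_size + action_size, hidden_size)
        self.gru = nn.GRUCell(hidden_size, deter_size)
        self.prior_head = nn.Sequential(
            nn.Linear(deter_size, hidden_size), nn.ELU(), nn.Linear(hidden_size, 2 * stoch_size))
        self.post_head = nn.Sequential(
            nn.Linear(deter_size + obs_embed_size, hidden_size), nn.ELU(), nn.Linear(hidden_size, 2 * stoch_size))

    def _stats(self, head, x):
        mean, raw_std = head(x).chunk(2, dim=-1)
        std = F.softplus(raw_std) + 0.1  # assumed floor keeps std strictly positive
        stoch = Independent(Normal(mean, std), 1).rsample()  # reparameterized sample
        return mean, std, stoch

    def imagine_step(self, prev_state, prev_action):
        # h_t = GRU(h_{t-1}, f(s_{t-1}, a_{t-1})), then prior p(s_t | h_t)
        x = F.elu(self.pre_gru(torch.cat([prev_state["stoch"], prev_action], dim=-1)))
        deter = self.gru(x, prev_state["deter"])
        mean, std, stoch = self._stats(self.prior_head, deter)
        return {"mean": mean, "std": std, "stoch": stoch, "deter": deter}

    def observe_step(self, prev_state, prev_action, obs_embed):
        prior = self.imagine_step(prev_state, prev_action)
        # Posterior p(s_t | h_t, obs_t) reuses the same h_t: the GRU advances once per step
        mean, std, stoch = self._stats(self.post_head, torch.cat([prior["deter"], obs_embed], dim=-1))
        posterior = {"mean": mean, "std": std, "stoch": stoch, "deter": prior["deter"]}
        return posterior, prior
```

Sharing `deter` between posterior and prior mirrors the `observe_step` contract documented below.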
- Usage with DreamerAgent:
rssm = RSSM(
    action_size=action_dim,
    stoch_size=30,        # Stochastic state dimension
    deter_size=200,       # Deterministic (GRU) state dimension
    hidden_size=200,      # MLP hidden layer size
    obs_embed_size=256,   # Observation embedding from encoder
    activation='elu',
)
# Observe with an actual observation (returns the posterior and the prior)
posterior, prior = rssm.observe_step(prev_state, prev_action, obs_embed)
# Imagine the future without an observation
prior = rssm.imagine_step(current_state, action)
- Training:
The RSSM is trained by maximizing the ELBO (evidence lower bound):
KL divergence between the posterior and the prior trains the prior to capture environment dynamics
Reconstruction loss from the decoder ensures the state captures observation information
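The following sketch writes the negative ELBO out as a loss. It assumes diagonal Gaussian prior/posterior parameters and a decoder that returns a distribution with `log_prob`. The clamp at `free_nats` follows the original Dreamer recipe and is an assumption about torchwm. The function name `world_model_loss` is hypothetical, and the reward and discount terms are omitted.

```python
import torch
from torch.distributions import Independent, Normal, kl_divergence


def world_model_loss(post_mean, post_std, prior_mean, prior_std, obs_dist, obs, free_nats=3.0):
    """Hypothetical negative-ELBO sketch: reconstruction + KL(posterior || prior)."""
    posterior = Independent(Normal(post_mean, post_std), 1)
    prior = Independent(Normal(prior_mean, prior_std), 1)
    # Independent sums the KL over latent dims; then average over the batch
    kl = kl_divergence(posterior, prior).mean()
    # "Free nats": do not penalize KL below this threshold (assumed Dreamer-style trick)
    kl = torch.clamp(kl, min=free_nats)
    recon = -obs_dist.log_prob(obs).mean()
    return recon + kl, recon, kl
```

Minimizing this loss is equivalent to maximizing the ELBO described above.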
- Reference:
Dream to Control: Learning Behaviors by Latent Imagination, Hafner et al., ICLR 2020 - https://arxiv.org/abs/1912.01603
- init_state(batch_size, device)[source]#
Initialize RSSM state with zeros.
- Parameters:
batch_size – Number of parallel sequences
device – torch device for tensors
- Returns:
Dictionary containing zero-initialized state components:
mean, std: Stochastic distribution parameters
stoch: Stochastic state sample
deter: Deterministic GRU hidden state
- Return type:
dict
- get_dist(mean, std)[source]#
Create an Independent Normal distribution from mean and std.
- Parameters:
mean – Location parameter
std – Scale parameter
- Returns:
Independent Normal distribution with given parameters
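The following example illustrates what "Independent Normal" means for the shapes. It assumes one event dimension (`reinterpreted_batch_ndims=1`), which is typical for a flat stochastic state. With that setting, `log_prob` returns one value per batch element rather than one per latent dimension.

```python
import torch
from torch.distributions import Independent, Normal

mean = torch.zeros(8, 30)  # (batch, stoch_size)
std = torch.ones(8, 30)
dist = Independent(Normal(mean, std), 1)  # last dim becomes the event dim
print(dist.batch_shape, dist.event_shape)  # torch.Size([8]) torch.Size([30])
log_p = dist.log_prob(torch.zeros(8, 30))  # summed over the 30 latent dims: shape (8,)
```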
- observe_step(prev_state, prev_action, obs_embed, nonterm=1.0)[source]#
Update state using actual observation (observe mode).
In observe mode, the RSSM first computes a transition prior from the previous state and action, then refines the stochastic state using the actual observation embedding to form the posterior.
- Parameters:
prev_state – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})
prev_action – Previous action a_{t-1}, shape (B, action_size)
obs_embed – Observation embedding from encoder, shape (B, obs_embed_size)
nonterm – Termination mask (1.0 = continue, 0.0 = terminal)
- Returns:
A tuple (posterior, prior) of state dictionaries. The posterior incorporates observation information; the prior is the transition prediction before observation. Both share the same deterministic state because the GRU is only advanced once per timestep.
- imagine_step(prev_state, prev_action, nonterm=1.0)[source]#
Predict next state without observation (imagine mode).
In imagine mode, the RSSM predicts future states using only the prior distribution. This is used for planning and policy learning where actual observations are not available.
- Parameters:
prev_state – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})
prev_action – Previous action a_{t-1}, shape (B, action_size)
nonterm – Termination mask (1.0 = continue, 0.0 = terminal)
- Returns:
Dictionary with the predicted state, containing:
deter: Predicted deterministic state
mean, std, stoch: Prior stochastic state distribution
- Return type:
dict
- get_prior(prev_state, prev_action, nonterm=1.0)[source]#
Compute prior distribution over stochastic state.
The prior represents the model’s belief about the stochastic state before observing the actual outcome.
- Parameters:
prev_state – Previous state dictionary
prev_action – Previous action
nonterm – Termination mask
- Returns:
Dictionary with prior state (no observation information)
- get_posterior(prev_state, prev_action, obs_embed, nonterm=1.0)[source]#
Compute posterior distribution over stochastic state.
The posterior incorporates observation information to produce a more accurate state estimate.
- Parameters:
prev_state – Previous state dictionary
prev_action – Previous action
obs_embed – Observation embedding
nonterm – Termination mask
- Returns:
Dictionary with posterior state (observation-informed). Note that the previous-state shape (B, ...) is preserved; the batch dimension is not flattened.
- detach_state(state)[source]#
Detach state tensors from computation graph.
Used during DreamerV2 training to prevent gradient flow through the observation/update pathway.
- Parameters:
state – State dictionary with tensor values
- Returns:
Detached state dictionary
- seq_to_batch(state_dict)[source]#
Convert sequence state to batch format.
- Parameters:
state_dict – Dictionary with sequence-dimension tensors (T, B, …)
- Returns:
Dictionary with batch-dimension tensors (B*T, …)
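A one-line sketch of this conversion, assuming every value has leading (T, B) dimensions, is shown below. The element count is T*B. A plain `reshape` keeps time-major order, so row `t * B + b` holds timestep t of sequence b. The function is a hypothetical re-implementation, not the method's source.

```python
import torch


def seq_to_batch(state_dict):
    """Hypothetical sketch: merge leading (T, B) dims, (T, B, ...) -> (T*B, ...)."""
    return {k: v.reshape(-1, *v.shape[2:]) for k, v in state_dict.items()}
```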
- observe_rollout(obs_embed, actions, nonterms, init_state, seq_len)[source]#
Process a sequence of observations (observe mode rollout).
At each timestep we run observe_step once to obtain the transition prior (the prediction given the previous state and action) and the observation-informed posterior. The posterior is then used as the previous state for the next step, matching the standard Dreamer inference pattern.
- Parameters:
obs_embed – Observation embeddings, shape (T+1, B, obs_embed_size)
actions – Actions, shape (T, B, action_size)
nonterms – Non-termination flags, shape (T, B, 1)
init_state – Initial state dictionary
seq_len – Sequence length T
- Returns:
A tuple (prior, posterior):
prior: Dictionary with prior states stacked along the time axis.
posterior: Dictionary with posterior states stacked along the time axis.
- Return type:
tuple
- imagine_rollout(policy, init_state, horizon)[source]#
Generate imagined trajectory using policy (imagine mode rollout).
- Parameters:
policy – Actor network that outputs actions from state features
init_state – Initial state dictionary
horizon – Number of steps to imagine
- Returns:
Dictionary with imagined states for each step
- forward(x, u)[source]#
Forward pass for training (computes sequence of states).
- Parameters:
x – Observations, shape (B, T+1, C, H, W)
u – Actions, shape (B, T, action_size)
- Returns:
A tuple (states, priors, posteriors):
states: List of state dictionaries for each timestep
priors: List of prior distributions (tuples of mean, std)
posteriors: List of posterior distributions (tuples of mean, std)
- Return type:
tuple
- class world_models.RecurrentStateSpaceModel(action_size, state_size=200, latent_size=30, hidden_size=200, embed_size=1024, activation_function='relu')[source]#
Bases: Module
A Recurrent State Space Model (RSSM) for modeling latent dynamics in sequential data.
- get_init_state(enc, h_t=None, s_t=None, a_t=None, mean=False)[source]#
Returns the initial posterior given the observation.
- deterministic_state_fwd(h_t, s_t, a_t)[source]#
- Deterministic transition update that accepts:
a_t shaped [B, action_size]
a_t shaped [action_size] (unbatched) -> expanded to [B, action_size]
a_t shaped [B] or scalar -> reshaped appropriately
Ensures a_t is 2D and matches batch dimension of h_t before concatenation.
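The following sketch shows the shape handling listed above. `normalize_action` is a hypothetical helper, not the method's code. When B equals action_size, a 1-D input is ambiguous; this sketch treats it as an unbatched action vector. It also broadcasts a [B] input across action dimensions if action_size > 1, which is an assumption.

```python
import torch


def normalize_action(a_t, batch_size, action_size):
    """Hypothetical helper: coerce a_t to shape [batch_size, action_size]."""
    a_t = torch.as_tensor(a_t, dtype=torch.float32)
    if a_t.dim() == 0:
        a_t = a_t.reshape(1, 1)  # scalar -> [1, 1], broadcast below
    elif a_t.dim() == 1:
        if a_t.shape[0] == action_size:
            a_t = a_t.unsqueeze(0)  # unbatched [action_size] -> [1, action_size]
        elif a_t.shape[0] == batch_size:
            a_t = a_t.unsqueeze(1)  # one value per batch element [B] -> [B, 1]
        else:
            raise ValueError(f"cannot interpret action of shape {tuple(a_t.shape)}")
    if a_t.dim() != 2:
        raise ValueError("expected a 2D action after normalization")
    # expand() broadcasts size-1 dims and raises if the shape is incompatible
    return a_t.expand(batch_size, action_size)
```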
- state_prior(h_t, sample=False)[source]#
Returns the prior distribution over the latent state given the deterministic state.
- state_posterior(h_t, e_t, sample=False)[source]#
Returns the posterior distribution over the latent state given the deterministic state and the observation embedding e_t.
- forward(x, u)[source]#
Forward through the RSSM for a batch of sequences.
- Inputs:
x: Tensor [B, T+1, C, H, W] (observations including initial frame)
u: Tensor [B, T, action_size] (actions for T steps)
- Returns:
A tuple (states, priors, posteriors):
states: list[T] of tensors [B, state_size]
priors: list[T] of tuples (mean, std), each [B, latent_size]
posteriors: list[T] of tuples (mean, std), each [B, latent_size]
- Return type:
tuple
- class world_models.ConvEncoder(input_shape, embed_size, activation, depth=32)[source]#
Bases: Module
Convolutional observation encoder used by Dreamer world models.
This encoder transforms raw image observations (typically RGB frames from environments like Atari or DeepMind Control) into compact latent embeddings that can be processed by the RSSM (Recurrent State-Space Model).
- Architecture:
Input: (B, C, H, W) raw images, values in [-0.5, 0.5]
Process: 4 convolutional layers with stride 2, each roughly halving the spatial dimensions
Output: (B, embed_size) compact representation
The encoder uses a depth-doubling pattern: 32 -> 64 -> 128 -> 256 channels (with the default depth=32). After the convolutions, a fully connected layer projects the flattened 1024 features (for 64x64 inputs) to the desired embedding size.
- Usage with Dreamer:
encoder = ConvEncoder(
    input_shape=(3, 64, 64),  # RGB 64x64 images
    embed_size=256,           # RSSM observation embedding size
    activation='relu',        # Activation function
)
obs_embedding = encoder(observation)  # (B, 256)
- Parameters:
input_shape – Tuple (C, H, W) for input images, typically (3, 64, 64)
embed_size – Output embedding dimension, typically 256 or 1024
activation – Activation function name (‘relu’, ‘elu’, ‘tanh’, etc.)
depth – Base channel depth for first layer (default 32)
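The following sketch shows how the 1024-feature figure arises. It assumes kernel_size=4 with stride 2 and no padding, as in the original Dreamer encoder. Under that assumption, a 64x64 input shrinks 64 -> 31 -> 14 -> 6 -> 2, and 256 channels x 2 x 2 = 1024. The kernel size and the class name TinyConvEncoder are assumptions, and the linear layer is sized for 64x64 inputs only.

```python
import torch
import torch.nn as nn


class TinyConvEncoder(nn.Module):
    """Hypothetical Dreamer-style encoder sketch (kernel 4, stride 2, no padding)."""

    def __init__(self, in_channels=3, embed_size=256, depth=32):
        super().__init__()
        chans = [in_channels, depth, 2 * depth, 4 * depth, 8 * depth]  # 3 -> 32 -> 64 -> 128 -> 256
        layers = []
        for c_in, c_out in zip(chans[:-1], chans[1:]):
            layers += [nn.Conv2d(c_in, c_out, kernel_size=4, stride=2), nn.ReLU()]
        self.convs = nn.Sequential(*layers)
        # 64x64 input: 64 -> 31 -> 14 -> 6 -> 2 spatially, so 256 * 2 * 2 = 1024 features
        self.fc = nn.Linear(8 * depth * 2 * 2, embed_size)

    def forward(self, obs):
        h = self.convs(obs).flatten(start_dim=1)  # (B, 1024)
        return self.fc(h)  # (B, embed_size)
```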
- class world_models.CNNEncoder(embedding_size, activation_function='relu')[source]#
Bases: Module
A Convolutional Neural Network (CNN) encoder for processing image inputs.
- class world_models.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)[source]#
Bases: Module
Convolutional decoder for reconstructing observations from latent states.
Part of Dreamer’s world model, this decoder reconstructs image observations from the combined stochastic (s) and deterministic (h) RSSM states.
- Architecture:
Input: Concatenated [stoch_state, deter_state], shape (B, stoch+deter)
Process: Dense projection + 4 transposed convolutions (upsampling 2x each)
Output: Independent Normal distribution over observation pixels
The decoder mirrors the ConvEncoder’s structure but in reverse (transposed convs instead of regular convs). This creates a symmetric autoencoder where the encoder and decoder can be trained jointly to learn compressed representations.
- Output Distribution:
Returns torch.distributions.Independent(Normal(mean, std), len(shape)).
This allows computing log_prob(observation) for the reconstruction loss.
- Usage in Dreamer world model:
decoder = ConvDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(3, 64, 64),  # RGB images
    activation='relu',
)
obs_dist = decoder(latent_features)  # Returns a distribution
log_prob = obs_dist.log_prob(target_observation)
- Training:
The reconstruction loss is -log_prob(observation).
This encourages the RSSM to learn states that capture observation information.
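The following sketch implements the "dense projection + 4 transposed convolutions (upsampling 2x each)" layout. It assumes the dense layer produces a 256-channel 4x4 map. Each ConvTranspose2d(kernel_size=4, stride=2, padding=1) then doubles the resolution: 4 -> 8 -> 16 -> 32 -> 64. The 4x4 starting map, the kernel settings, the unit std, and the class name TinyConvDecoder are all illustrative assumptions.

```python
import torch
import torch.nn as nn
from torch.distributions import Independent, Normal


class TinyConvDecoder(nn.Module):
    """Hypothetical decoder: dense -> 4x4 map -> four 2x transposed-conv upsamplings."""

    def __init__(self, stoch_size=30, deter_size=200, out_channels=3, depth=32):
        super().__init__()
        self.depth = depth
        self.fc = nn.Linear(stoch_size + deter_size, 8 * depth * 4 * 4)
        chans = [8 * depth, 4 * depth, 2 * depth, depth, out_channels]  # 256 -> 128 -> 64 -> 32 -> 3
        layers = []
        for i, (c_in, c_out) in enumerate(zip(chans[:-1], chans[1:])):
            # kernel 4, stride 2, padding 1 doubles H and W
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            if i < len(chans) - 2:  # no activation after the final pixel layer
                layers.append(nn.ReLU())
        self.deconvs = nn.Sequential(*layers)

    def forward(self, features):
        h = self.fc(features).view(-1, 8 * self.depth, 4, 4)
        mean = self.deconvs(h)  # (B, C, 64, 64)
        # Unit std; (C, H, W) form one event, so log_prob returns one value per sample
        return Independent(Normal(mean, 1.0), 3)
```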
- class world_models.CNNDecoder(state_size, latent_size, embedding_size, activation_function='relu')[source]#
Bases: Module
A Convolutional Neural Network (CNN) decoder for reconstructing image outputs.
- class world_models.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist, num_buckets=255, symlog_range=10.0)[source]#
Bases: Module
MLP decoder for reward/value/discount prediction from latent features.
Part of Dreamer’s world model, this decoder predicts scalar quantities (rewards, values, discount factors) from RSSM latent states.
- Architecture:
Input: [stoch_state, deter_state] concatenated, shape (B, stoch+deter)
Process: MLP with configurable layers and hidden units
Output: Predicted quantity with distribution (normal, binary, or raw)
- Supports three output types:
‘normal’: Gaussian distribution for regression (rewards, values)
‘binary’: Bernoulli distribution for binary classification (discount)
‘none’: Raw tensor for non-probabilistic outputs
- Usage:
reward_decoder = DenseDecoder(
    stoch_size=30, deter_size=200, output_shape=(1,),
    n_layers=2, units=400, activation='elu', dist='normal',
)
reward_dist = reward_decoder(latent_features)
reward_loss = -reward_dist.log_prob(target_reward)
- For discount prediction (binary):
discount_decoder = DenseDecoder(
    stoch_size=30, deter_size=200, output_shape=(1,),
    n_layers=2, units=400, activation='elu',
    dist='binary',  # Bernoulli for P(continue)
)
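The following sketch shows how the three `dist` options might map to output heads. It assumes a unit-variance Gaussian for 'normal' and logits-parameterized Bernoulli for 'binary', with `output_shape` treated as the event shape. The class name TinyDenseDecoder is hypothetical. The real class also takes num_buckets and symlog_range, which this sketch does not model.

```python
import math

import torch
import torch.nn as nn
from torch.distributions import Bernoulli, Independent, Normal


class TinyDenseDecoder(nn.Module):
    """Hypothetical MLP head returning a distribution chosen by `dist`."""

    def __init__(self, in_size, output_shape=(1,), n_layers=2, units=400, dist="normal"):
        super().__init__()
        layers, size = [], in_size  # in_size = stoch_size + deter_size
        for _ in range(n_layers):
            layers += [nn.Linear(size, units), nn.ELU()]
            size = units
        layers.append(nn.Linear(size, math.prod(output_shape)))
        self.net = nn.Sequential(*layers)
        self.output_shape = tuple(output_shape)
        self.dist = dist

    def forward(self, features):
        out = self.net(features).reshape(*features.shape[:-1], *self.output_shape)
        ndims = len(self.output_shape)  # treat output_shape as the event shape
        if self.dist == "normal":
            return Independent(Normal(out, 1.0), ndims)  # unit-variance Gaussian
        if self.dist == "binary":
            return Independent(Bernoulli(logits=out), ndims)  # P(continue)
        if self.dist == "none":
            return out  # raw tensor
        raise ValueError(f"unknown dist {self.dist!r}")
```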
- class world_models.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)[source]#
Bases: Module
Dreamer actor head producing squashed continuous actions from latent features.
Outputs a transformed Gaussian policy with optional deterministic mode and utility for additive exploration noise.
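The constructor's min_std, init_std, and mean_scale defaults match the parameterization used by the original Dreamer actor. The following sketch shows that transform, assuming torchwm follows it. The function name is hypothetical. The resulting mean and std would parameterize a Normal that is then squashed by TanhBijector.

```python
import math

import torch
import torch.nn.functional as F


def dreamer_action_params(raw_mean, raw_std, mean_scale=5.0, init_std=5.0, min_std=1e-4):
    """Hypothetical sketch of Dreamer-style actor mean/std parameterization."""
    raw_init_std = math.log(math.exp(init_std) - 1.0)  # inverse softplus of init_std
    # Soft-clip the pre-tanh mean into (-mean_scale, mean_scale)
    mean = mean_scale * torch.tanh(raw_mean / mean_scale)
    # A zero network output yields std close to init_std; min_std keeps it positive
    std = F.softplus(raw_std + raw_init_std) + min_std
    return mean, std
```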
- class world_models.TanhBijector[source]#
Bases: Transform
Bijective tanh transform for squashing Gaussian distributions to [-1, 1].
This transformation is essential for Dreamer’s action policy. Raw neural network outputs are Gaussian distributions over R^n, but actions in continuous control environments are typically bounded in [-1, 1]. The tanh bijector provides:
Bijective mapping: tanh is invertible (with atanh as inverse)
Stable log-det Jacobian: Computable for gradient-based training
Bounded actions: tanh squashes samples into (-1, 1), so no clipping is needed during inference
- Math:
Forward: y = tanh(x)
Inverse: x = atanh(y) = 0.5 * log((1+y)/(1-y))
Log-det: log|dy/dx| = log(1 - tanh(x)^2) = 2*(log(2) - x - softplus(-2x))
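The following numerical check, in float64, confirms that the stable log-det expression agrees with the direct form log(1 - tanh(x)^2). It also compares against PyTorch's built-in TanhTransform, which is available in recent PyTorch releases and uses the same formula. The stable form avoids computing log of a number that underflows toward 0 as |x| grows.

```python
import math

import torch
import torch.nn.functional as F
from torch.distributions.transforms import TanhTransform

x = torch.linspace(-5.0, 5.0, steps=101, dtype=torch.float64)
# Direct formula: log(1 - tanh(x)^2); precision degrades as |x| grows
naive = torch.log1p(-torch.tanh(x) ** 2)
# Stable rewrite from the Math section
stable = 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))
# PyTorch's TanhTransform computes the same log-det Jacobian
builtin = TanhTransform().log_abs_det_jacobian(x, torch.tanh(x))
max_err = (naive - stable).abs().max().item()
```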
- Usage with Dreamer ActionDecoder:
dist = TransformedDistribution(Normal(mean, std), TanhBijector())
action = dist.sample()  # Bounded to [-1, 1]
- Reference:
Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor, Haarnoja et al., 2018 (appendix on enforcing action bounds).
- property sign#
- class world_models.SampleDist(dist, samples=100)[source]#
Bases: object
Distribution wrapper that estimates statistics via Monte Carlo sampling.
Provides approximate mean, mode, and entropy helpers for transformed distributions where analytic forms may be inconvenient.
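The following sketch shows the Monte Carlo idea applied to a tanh-squashed Gaussian. It assumes the approach of the original Dreamer code: the mean is the sample average, the mode is the sample with the highest log-probability per batch element, and entropy is the negative mean log-probability. The class name TinySampleDist is hypothetical.

```python
import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform


class TinySampleDist:
    """Hypothetical Monte Carlo statistics for a distribution with rsample/log_prob."""

    def __init__(self, dist, samples=100):
        self._dist = dist
        self._samples = samples

    def mean(self):
        return self._dist.rsample((self._samples,)).mean(0)

    def mode(self):
        sample = self._dist.rsample((self._samples,))  # (S, B, A)
        logprob = self._dist.log_prob(sample)  # (S, B)
        best = torch.argmax(logprob, dim=0)  # most likely sample per batch element
        return sample[best, torch.arange(sample.shape[1])]  # (B, A)

    def entropy(self):
        sample = self._dist.rsample((self._samples,))
        return -self._dist.log_prob(sample).mean(0)  # (B,)


# Tanh-squashed Gaussian policy over 2 action dims for a batch of 4
base = Normal(torch.zeros(4, 2), torch.ones(4, 2))
policy = Independent(TransformedDistribution(base, [TanhTransform(cache_size=1)]), 1)
sd = TinySampleDist(policy, samples=500)
```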
- property name#
- class world_models.IRISEncoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, in_channels=3, base_channels=64, num_residual_blocks=2, frame_shape=(3, 64, 64))[source]#
Bases: Module
CNN encoder for the IRIS discrete autoencoder.
Encodes image observations into latent features, which are then quantized into discrete tokens using the VectorQuantizer.
- Architecture:
4 convolutional layers with residual blocks
Self-attention at 8x8 and 16x16 resolutions
Vector quantization to produce discrete tokens
- Parameters:
vocab_size (int)
tokens_per_frame (int)
embedding_dim (int)
in_channels (int)
base_channels (int)
num_residual_blocks (int)
frame_shape (Tuple[int, int, int])
- forward(x)[source]#
Encode images to discrete tokens.
- Parameters:
x (Tensor) – Input images (B, C, H, W) - should be 64x64
- Returns:
A tuple (z_q, indices, vq_loss):
z_q: Quantized tokens (B, C, H’, W’)
indices: Token indices (B, H’, W’)
vq_loss: Dictionary with VQ loss components
- Return type:
tuple
- class world_models.IRISDecoder(vocab_size=512, embedding_dim=512, base_channels=32, out_channels=3, frame_shape=(3, 64, 64))[source]#
Bases: Module
CNN decoder for the IRIS discrete autoencoder.
Decodes discrete tokens back into image observations. Uses transposed convolutions to upsample from 4x4 to 64x64.
- Parameters:
vocab_size (int)
embedding_dim (int)
base_channels (int)
out_channels (int)
frame_shape (Tuple[int, int, int])
- forward(z)[source]#
Decode tokens to images.
- Parameters:
z (Tensor) – Token embeddings (B, C, H, W) - e.g., (B, 512, 4, 4)
- Returns:
reconstructed: Reconstructed images (B, C, H, W), e.g., (B, 3, 64, 64)
- class world_models.VideoTokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, commitment_weight=0.25, use_ema=False, ema_decay=0.99)[source]#
Bases: Module
Video Tokenizer using a VQ-VAE with a Spatiotemporal Transformer.
This is a core component of Genie (Google DeepMind, 2024), used to compress raw video frames into discrete latent tokens that can be processed by downstream models like the LatentActionModel and DynamicsModel.
The tokenizer uses the Vector Quantized Variational Autoencoder (VQ-VAE) objective to learn a discrete codebook of video representations. Unlike a standard VQ-VAE, it uses a Spatiotemporal (ST) Transformer in both the encoder and decoder to better capture temporal dynamics in videos.
- Architecture:
Patch Embedding: Convert (B, C, T, H, W) video to patch tokens
Encoder ST-Transformer: Process spatial-temporal patches
Vector Quantization: Discretize continuous embeddings to codebook entries
Decoder ST-Transformer: Reconstruct video from quantized tokens
Patch Unembedding: Convert tokens back to video frames
- Key Features:
Causal processing: Each frame’s encoding only uses previous frames
Discrete tokens: Enables autoregressive prediction with latent actions
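Memory efficient: Factorizes attention into spatial attention (within each frame) and temporal attention (across frames at each patch position), reducing the cost from O((T·N)²) for full ViT attention over all T·N patch tokens to O(T·N² + N·T²)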
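The following arithmetic supports the memory-efficiency claim by counting query-key pairs in one attention layer. It uses the VideoTokenizer defaults (num_frames=16, image_size=64, patch_size=4), which give 256 patches per frame. The count ignores heads, causal masking, and constant factors, so treat it as an order-of-magnitude comparison.

```python
def attention_pairs(num_frames, patches_per_frame):
    n_tokens = num_frames * patches_per_frame
    full = n_tokens ** 2  # every token attends to every token
    spatial = num_frames * patches_per_frame ** 2  # attention within each frame
    temporal = patches_per_frame * num_frames ** 2  # attention across frames, per patch position
    return full, spatial + temporal


# 64x64 frames with patch_size=4 -> (64 // 4) ** 2 = 256 patches per frame
full, factorized = attention_pairs(16, (64 // 4) ** 2)
print(full, factorized)  # 16777216 1114112
```

For these settings, the factorized layout needs roughly 15x fewer pairs.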
- Usage with Genie:
tokenizer = VideoTokenizer(
    num_frames=16, image_size=64, patch_size=4, vocab_size=1024, embedding_dim=32,
)
reconstructed, indices, loss_dict = tokenizer(video_frames)
# For discrete token input to the dynamics model:
token_embeddings = tokenizer.decode_indices(indices)
- Training:
The tokenizer is trained with the VQ-VAE objective:
Reconstruction loss: MSE between the input and reconstructed video
Codebook (VQ) loss: Moves codebook embeddings toward the encoder outputs (encourages learning useful codes)
Commitment loss: Penalizes encoder outputs for drifting from their codebook entries (weighted by commitment_weight)
- Reference:
Genie: Generative Interactive Environments, Bruce et al., Google DeepMind, 2024 - https://arxiv.org/abs/2402.15391
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
use_ema (bool)
ema_decay (float)
- encode(x)[source]#
Encode video to discrete tokens.
- Parameters:
x (Tensor) – Video tensor (B, C, T, H, W)
- Returns:
A tuple (z_q, indices, vq_loss):
z_q: Quantized embeddings (B, T, H’, W’, embedding_dim)
indices: Token indices (B, T, H’, W’)
vq_loss: Dictionary with VQ loss components
- Return type:
tuple
- decode_indices(indices)[source]#
Decode token indices to embeddings for video frames.
- Parameters:
indices (Tensor) – Token indices (B, T, H’, W’) or (B, T, N) where N = H’*W’
- Returns:
z_q: Quantized embeddings (B, T, H’, W’, embedding_dim)
- world_models.create_video_tokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False)[source]#
Factory function to create a Video Tokenizer.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
use_ema (bool)
- Return type:
- class world_models.VectorQuantizer(vocab_size=512, embedding_dim=512, commitment_weight=0.25)[source]#
Bases: Module
Vector Quantizer for a discrete autoencoder.
Implements the VQ-VAE quantization from “Neural Discrete Representation Learning” (van den Oord et al., 2017).
Uses exponential moving averages for codebook updates and straight-through estimator for gradient flow.
- Parameters:
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
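The following sketch shows the standard gradient-based VQ-VAE step: nearest-codebook lookup, codebook and commitment losses, and the straight-through estimator. It follows van den Oord et al. The function name and the flat (N, D) input layout are illustrative assumptions, and the class may organize this differently (for example, with the EMA codebook updates mentioned above).

```python
import torch
import torch.nn.functional as F


def vector_quantize(z_e, codebook, commitment_weight=0.25):
    """Hypothetical VQ step. z_e: (N, D) encoder outputs; codebook: (K, D)."""
    distances = torch.cdist(z_e, codebook)  # (N, K) Euclidean distances
    indices = distances.argmin(dim=1)  # nearest code per vector
    z_q = codebook[indices]  # (N, D)
    codebook_loss = F.mse_loss(z_q, z_e.detach())  # moves codes toward encoder outputs
    commitment_loss = F.mse_loss(z_e, z_q.detach())  # keeps encoder outputs near their codes
    # Straight-through: forward pass uses z_q, gradients flow back into z_e unchanged
    z_q_st = z_e + (z_q - z_e).detach()
    return z_q_st, indices, codebook_loss + commitment_weight * commitment_loss
```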
- class world_models.VectorQuantizerEMA(vocab_size=512, embedding_dim=512, commitment_weight=0.25, ema_decay=0.99, epsilon=1e-05)[source]#
Bases: Module
Vector Quantizer with Exponential Moving Average updates.
Uses EMA updates for the codebook instead of gradient-based updates, which leads to more stable training.
- Parameters:
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
ema_decay (float)
epsilon (float)
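The following sketch shows one EMA codebook update in the standard formulation. Running per-code usage counts and assigned-vector sums are decayed by ema_decay. Laplace smoothing with epsilon keeps rarely used codes from dividing by a near-zero count. The function name and the in-place buffer arguments are assumptions for illustration.

```python
import torch
import torch.nn.functional as F


def ema_codebook_update(z_e, indices, cluster_size, embed_sum, decay=0.99, epsilon=1e-5):
    """Hypothetical EMA step. z_e: (N, D); indices: (N,); cluster_size: (K,); embed_sum: (K, D)."""
    K = cluster_size.shape[0]
    one_hot = F.one_hot(indices, K).type(z_e.dtype)  # (N, K) code assignments
    # Exponential moving averages of code usage and assigned-vector sums (updated in place)
    cluster_size.mul_(decay).add_(one_hot.sum(0), alpha=1 - decay)
    embed_sum.mul_(decay).add_(one_hot.t() @ z_e, alpha=1 - decay)
    # Laplace smoothing: total count n is preserved, no code count reaches zero
    n = cluster_size.sum()
    smoothed = (cluster_size + epsilon) / (n + K * epsilon) * n
    return embed_sum / smoothed.unsqueeze(1)  # new codebook (K, D)
```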
- class world_models.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]#
Bases: object
Fixed-size replay buffer for Dreamer with image observations and transitions.
Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world-model training.
- Key Features:
Ring buffer with fixed capacity (FIFO eviction when full)
Stores raw uint8 images to save memory
Samples sequences (not single transitions) for temporal modeling
Validates sampled sequences don’t span episode boundaries
- Memory Layout:
observations: (capacity, C, H, W) uint8 images
actions: (capacity, action_dim) float32
rewards: (capacity,) float32
terminals: (capacity,) float32 (1.0 = terminal, 0.0 = continue)
- Sampling Process:
Random start index (avoiding episode boundaries)
Collect sequence of length seq_len with wraparound
Validate no terminal in middle of sequence
Return batch of sequences
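The following sketch shows one way to implement the sampling steps above with rejection sampling. A random window is drawn with wraparound and rejected if it crosses the write pointer (which would join the newest and oldest data) or contains a terminal before its last step. The class TinySequenceBuffer and the rejection strategy are assumptions, not the ReplayBuffer source.

```python
import numpy as np


class TinySequenceBuffer:
    """Hypothetical ring buffer illustrating sequence sampling with boundary checks."""

    def __init__(self, capacity, obs_shape, action_size):
        self.capacity = capacity
        self.obs = np.zeros((capacity, *obs_shape), dtype=np.uint8)
        self.actions = np.zeros((capacity, action_size), dtype=np.float32)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.terminals = np.zeros(capacity, dtype=np.float32)
        self.idx = 0  # next write position (the oldest slot once full)
        self.full = False

    def add(self, obs, action, reward, done):
        self.obs[self.idx] = obs
        self.actions[self.idx] = action
        self.rewards[self.idx] = reward
        self.terminals[self.idx] = float(done)
        self.idx = (self.idx + 1) % self.capacity
        self.full = self.full or self.idx == 0  # FIFO eviction starts after the first wrap

    def _sample_window(self, seq_len, rng, max_tries=1000):
        high = self.capacity if self.full else self.idx - seq_len + 1
        for _ in range(max_tries):
            start = rng.integers(0, high)
            idxs = (start + np.arange(seq_len)) % self.capacity  # wraparound
            crosses_seam = self.idx in idxs[1:]  # would join newest and oldest data
            terminal_inside = self.terminals[idxs[:-1]].any()  # only the last step may be terminal
            if not crosses_seam and not terminal_inside:
                return idxs
        raise RuntimeError("no valid window found")

    def sample(self, batch_size, seq_len, seed=None):
        if not self.full and self.idx < seq_len:
            raise ValueError("not enough transitions for one sequence")
        rng = np.random.default_rng(seed)
        idxs = np.stack([self._sample_window(seq_len, rng) for _ in range(batch_size)], axis=1)
        # Time-major batch: (seq_len, batch, ...)
        return self.obs[idxs], self.actions[idxs], self.rewards[idxs], self.terminals[idxs]
```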
- Usage with Dreamer:
buffer = ReplayBuffer(
    size=100000,            # Max transitions to store
    obs_shape=(3, 64, 64),  # RGB images
    action_size=6,          # Continuous action dim
    seq_len=50,             # Sequence length for training
    batch_size=50,          # Parallel sequences per batch
)
# Add transitions during interaction (obs is a {"image": ...} dict)
buffer.add(obs, action, reward, done)
# Sample a batch for world model training
obs_batch, action_batch, reward_batch, term_batch = buffer.sample()
# Shapes: (seq_len, batch, C, H, W), (seq_len, batch, action_dim), etc.
- Memory Efficiency:
Uses uint8 for images (1 byte per pixel vs 4 for float32)
Sequences share observations (overlapping windows)
Configurable capacity based on available system memory
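The following arithmetic quantifies the uint8 saving for the image storage alone, using the values from the usage example (size=100000, obs_shape=(3, 64, 64)). Actions, rewards, and terminals add comparatively little.

```python
import numpy as np

capacity, obs_shape = 100_000, (3, 64, 64)
n_values = capacity * int(np.prod(obs_shape))  # pixel values stored
uint8_gib = n_values * np.dtype(np.uint8).itemsize / 2**30
float32_gib = n_values * np.dtype(np.float32).itemsize / 2**30
print(f"{uint8_gib:.2f} GiB vs {float32_gib:.2f} GiB")  # 1.14 GiB vs 4.58 GiB
```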
Note
add() accepts observations as {“image”: …} dicts, but only the image arrays are stored and returned, for training efficiency.
- Parameters:
size (int)
obs_shape (Tuple[int, ...])
action_size (int)
seq_len (int)
batch_size (int)
- add(obs, ac, rew, done)[source]#
Add a transition to the buffer.
- Parameters:
obs (dict) – Observation dict with ‘image’ key containing the observation
ac (ndarray) – Action taken, shape (action_size,)
rew (float) – Reward received, scalar
done (float) – Terminal flag, 1.0 if episode ended, 0.0 otherwise
- Return type:
None
- class world_models.Memory(size=None)[source]#
Bases: deque
Episode-based replay memory for PlaNet/RSSM training.
Stores episodes as variable-length trajectories and supports sampling sub-sequences for training. Implements a ring-buffer style eviction when capacity is reached.
- Features:
Stores complete episodes as lists of transitions
Samples contiguous sub-sequences for sequence models
Supports time-major formatting (time-first) for RNN input
Memory usage estimation to prevent OOM errors
- Parameters:
size (int, optional) – Maximum number of episodes to store. If None, deque grows without limit (useful for unpickling).
- episodes#
Collection of Episode objects.
- Type:
deque
- eps_lengths#
Length of each episode.
- Type:
deque
- size#
Total number of transitions across all episodes.
- Type:
property
Example
>>> memory = Memory(size=100)
>>> memory.append([episode1, episode2])
>>> obs, actions, rewards, terminals, lengths = memory.sample(batch_size=32, tracelen=50)
- property size#
- sample(batch_size, tracelen=1, time_first=False)[source]#
Sample random sub-sequences from stored episodes.
Randomly selects episodes and starting positions to create batches of contiguous sequences for training sequence models.
- Parameters:
batch_size (int) – Number of sequences to sample.
tracelen (int) – Length of each sequence (default: 1).
time_first (bool) – If True, returns tensors with time dimension first (T, B, …) instead of batch first (B, T, …).
- Returns:
- (observations, actions, rewards, terminals, lengths)
observations: (batch, tracelen+1, *obs_shape) or (tracelen+1, batch, …)
actions: (batch, tracelen, action_dim) or (tracelen, batch, …)
rewards: (batch, tracelen) or (tracelen, batch)
terminals: (batch, tracelen) or (tracelen, batch)
lengths: (batch,) original episode lengths for each sample
- Return type:
tuple
- Raises:
ValueError – If memory is empty or no episodes meet minimum length.
MemoryError – If estimated memory usage exceeds 200 MiB threshold.
- class world_models.Episode(postprocess_fn=None)[source]#
Bases: object
Records the agent’s interaction with the environment for a single episode.
Stores observations, actions, rewards, and terminal flags during a single trajectory. At termination, converts all lists to numpy arrays for efficient batch processing.
- x#
Observations collected during the episode.
- Type:
list or np.ndarray
- u#
Actions taken.
- Type:
list or np.ndarray
- r#
Rewards received.
- Type:
list or np.ndarray
- t#
Terminal flags (0.0 = continue, 1.0 = terminal).
- Type:
list or np.ndarray
- info#
Additional episode metadata.
- Type:
dict
- Parameters:
postprocess_fn (callable, optional) – Function to apply to observations before storing (e.g., normalization). Default: identity function.
Example
>>> episode = Episode()
>>> episode.append(obs, action, reward, False)
>>> episode.append(obs, action, reward, True)
>>> episode.terminate(final_obs)
>>> print(episode.x.shape)  # Now a numpy array
- property size#
- class world_models.IRISReplayBuffer(size, obs_shape, action_size, seq_len=20, batch_size=64)[source]#
Bases: object
Replay buffer for IRIS (Imagination with auto-Regression over an Inner Speech) training.
Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world model training.
- Features:
Ring buffer with fixed capacity (FIFO eviction when full)
Stores uint8 images for memory efficiency
Samples sequences with validation to avoid episode boundaries
Supports sequence sampling for temporal learning
- Memory Layout:
observations: (capacity, C, H, W) uint8
actions: (capacity, action_size) float32
rewards: (capacity,) float32
terminals: (capacity,) float32
- Parameters:
size (int) – Maximum number of transitions to store.
obs_shape (tuple) – Shape of observations as (C, H, W).
action_size (int) – Dimension of actions.
seq_len (int) – Length of sequences to sample (default: 20).
batch_size (int) – Number of sequences per batch (default: 64).
- size#
Buffer capacity.
- Type:
int
- obs_shape#
Observation shape.
- Type:
tuple
- action_size#
Action dimension.
- Type:
int
- seq_len#
Sequence length.
- Type:
int
- batch_size#
Batch size.
- Type:
int
- steps#
Total transitions added.
- Type:
int
- episodes#
Number of episode terminations observed.
- Type:
int
- add(obs, action, reward, terminal)[source]#
Add a transition to the buffer.
- Parameters:
obs (ndarray) – Observation array with shape (C, H, W).
action (ndarray) – Action array with shape (action_size,).
reward (float) – Scalar reward value.
terminal (bool) – Boolean indicating if episode terminated.
- sample_sequence(seq_len=None)[source]#
Sample a batch of sequences for world model training.
- Returns:
A tuple of observations (batch_size, seq_len+1, C, H, W), actions (batch_size, seq_len, action_size), rewards (batch_size, seq_len), and terminals (batch_size, seq_len).
- Parameters:
seq_len (int | None)
- sample_single()[source]#
Sample a single transition for online updates.
- Return type:
Tuple[ndarray, ndarray, float, float]
- property buffer_capacity#
Returns the total capacity of the buffer.
- class world_models.IRISOnPolicyBuffer(max_steps=1000)[source]#
Bases: object
On-policy buffer for collecting trajectories during environment interaction.
Used to store the current episode data before adding to the main replay buffer. Unlike the main replay buffer, this collects trajectories in a list-based structure that’s cleared after each episode.
- Useful for:
Collecting complete episode trajectories
Storing data before batch processing
Temporary storage during environment interaction
- Parameters:
max_steps (int) – Maximum number of steps to store (default: 1000).
- max_steps#
Maximum buffer capacity.
- Type:
int
- observations#
List of observations.
- Type:
list
- actions#
List of actions.
- Type:
list
- rewards#
List of rewards.
- Type:
list
- terminals#
List of terminal flags.
- Type:
list
- class world_models.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]#
Bases: Module
Diffusion Transformer model for image denoising and generation.
The module maps noisy images and timesteps to predicted noise residuals and also provides a classmethod training entrypoint for common datasets.
- classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]#
- class world_models.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]#
Bases: Module
Patchify an image into a sequence of learnable patch tokens.
Used in Vision Transformers (ViT) and DiT to convert 2D images into sequences of token embeddings that can be processed by transformers.
- Process:
Conv2d with kernel_size=stride=patch_size extracts non-overlapping patches and linearly projects each one to embed_dim in a single operation
Learnable positional embeddings are added for spatial information
Input: (B, C, H, W) images
Output: (B, N, embed_dim) where N = (H/patch_size) * (W/patch_size)
- Parameters:
img_size – Image size (assumes square), e.g., 32 for CIFAR
patch_size – Size of each patch (typically 4, 8, or 16)
in_channels – Number of input channels (3 for RGB)
embed_dim – Output dimension for each patch token
- Usage with DiT:
patch_embed = PatchEmbed(img_size=32, patch_size=4, in_channels=3, embed_dim=256)
tokens = patch_embed(images)  # (B, 64, 256) for a 32x32 image with patch_size=4
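The following sketch implements the patchify process described above: a single Conv2d with kernel_size == stride == patch_size, a flatten into a row-major token sequence, and learned positional embeddings. The class name TinyPatchEmbed and the zero initialization of the positional embeddings are simplifying assumptions; ViT-style models often use truncated-normal or sinusoidal initialization.

```python
import torch
import torch.nn as nn


class TinyPatchEmbed(nn.Module):
    """Hypothetical ViT/DiT-style patch embedding sketch."""

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=256):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # kernel_size == stride == patch_size: each output position sees exactly one patch
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))

    def forward(self, x):
        x = self.proj(x)  # (B, embed_dim, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)  # (B, N, embed_dim), row-major patch order
        return x + self.pos_embed
```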
- class world_models.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]#
Bases: Module
Reconstruct image-like tensors from patch-token sequences.
The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.
- class world_models.DDPM(timesteps, beta_start, beta_end, device)[source]#
Bases: object
Utility class implementing forward and reverse DDPM diffusion steps.
Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).
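The following sketch shows the forward-noising helper, x_t = sqrt(ᾱ_t) * x_0 + sqrt(1 - ᾱ_t) * ε with ᾱ_t the cumulative product of (1 - β). It assumes a linear β schedule. The defaults mirror the DiT.train signature (timesteps=1000, beta_start=0.0001, beta_end=0.02). TinyDDPMSchedule is a hypothetical name, and the reverse `sample` loop is omitted.

```python
import torch


class TinyDDPMSchedule:
    """Hypothetical DDPM schedule with a forward-noising helper."""

    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        self.betas = torch.linspace(beta_start, beta_end, timesteps)  # linear schedule (assumed)
        self.alphas = 1.0 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)  # cumulative products

    def q_sample(self, x0, t, noise=None):
        """Sample x_t ~ q(x_t | x_0) for integer timesteps t of shape (B,)."""
        if noise is None:
            noise = torch.randn_like(x0)
        abar = self.alpha_bars[t].view(-1, *([1] * (x0.dim() - 1)))  # broadcast over C, H, W
        return abar.sqrt() * x0 + (1.0 - abar).sqrt() * noise
```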
- class world_models.ActorCriticNetwork(obs_channels=3, action_dim=18, channels=(32, 32, 64, 64), lstm_dim=512)[source]#
Bases: Module
Actor-Critic network for DIAMOND RL training. Shared CNN-LSTM trunk with separate policy and value heads.
- Parameters:
obs_channels (int)
action_dim (int)
channels (Tuple[int, ...])
lstm_dim (int)
- forward(obs, hidden_state=None)[source]#
Forward pass of actor-critic network.
- Parameters:
obs (Tensor) – Observations [B, T, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
A tuple (policy_logits, values, hidden_state):
policy_logits: [B, T, action_dim]
values: [B, T, 1]
hidden_state: (h, c)
- Return type:
tuple
- get_action(obs, hidden_state=None, deterministic=False)[source]#
Get action from a single observation.
- Parameters:
obs (Tensor) – Single observation [B, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
deterministic (bool) – If True, take argmax; else sample
- Returns:
A tuple (action, hidden_state):
action: Selected action [B]
hidden_state: (h, c)
- Return type:
tuple
- get_actions(obs, hidden_state=None, deterministic=False)[source]#
Batched version of get_action.
- Parameters:
obs (Tensor) – Tensor of shape [B, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state tuple matching batch size
deterministic (bool) – If True, take argmax; else sample from policy
- Returns:
A tuple of the selected actions (LongTensor of shape [B]) and hidden_state, the updated LSTM hidden state tuple.
- get_value(obs, hidden_state=None)[source]#
Get value for a single observation.
- Parameters:
obs (Tensor)
hidden_state (Tuple[Tensor, Tensor] | None)
- Return type:
Tuple[Tensor, Tuple[Tensor, Tensor] | None]
Initialize LSTM hidden states.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
Get LSTM hidden size.
- Return type:
int
- class world_models.RewardTerminationModel(obs_channels=3, action_dim=18, channels=(32, 32, 32, 32), lstm_dim=512, cond_dim=128)[source]#
Bases: Module
Reward and termination prediction model. CNN + LSTM architecture following the DIAMOND paper specifications.
- Parameters:
obs_channels (int) – Number of observation channels (3 for RGB)
action_dim (int) – Number of possible actions
channels (Tuple[int, ...]) – List of channel sizes for conv blocks
lstm_dim (int) – LSTM hidden dimension
cond_dim (int) – Conditioning dimension for adaptive norm
- forward(obs, actions, hidden_state=None)[source]#
Forward pass of reward/termination model.
- Parameters:
obs (Tensor) – Observations [B, T, C, H, W]
actions (Tensor) – Actions [B, T]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
A tuple (reward_logits, termination_logits, hidden_state): reward logits [B, T, 3] over the reward classes -1, 0, 1; termination logits [B, T, 2]; and the updated (h, c) hidden states.
- Return type:
tuple
- predict(obs, actions, hidden_state=None)[source]#
Predict reward and termination for a single step.
- Parameters:
obs (Tensor) – Single observation [B, C, H, W]
actions (Tensor) – Single action [B]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
A tuple (reward, terminated, hidden_state): reward is a tensor of predicted reward classes (values -1, 0, 1); terminated is a bool tensor of predicted terminations; hidden_state is the updated (h, c) hidden states.
- Return type:
tuple
Initialize LSTM hidden states.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
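The reward head above scores three classes and the termination head scores two. A minimal plain-Python sketch of decoding one step's logits into the values predict reports follows. The class orderings (index 0/1/2 for reward -1/0/1, index 1 for "terminated") and all helper names are assumptions, not taken from the library:

```python
from typing import List, Sequence, Tuple

# Assumed class order for the 3 reward logits (not confirmed by the source).
REWARD_VALUES = (-1, 0, 1)


def argmax(xs: Sequence[float]) -> int:
    """Index of the largest element (the first one on ties)."""
    return max(range(len(xs)), key=lambda i: xs[i])


def decode_step(reward_logits: Sequence[float],
                termination_logits: Sequence[float]) -> Tuple[int, bool]:
    """Map one step's logits to (reward, terminated)."""
    reward = REWARD_VALUES[argmax(reward_logits)]
    # Assumed: termination class index 1 means "episode ended".
    terminated = argmax(termination_logits) == 1
    return reward, terminated


def decode_batch(reward_logits: List[Sequence[float]],
                 termination_logits: List[Sequence[float]]) -> List[Tuple[int, bool]]:
    """Decode a batch of per-step logits, one (reward, terminated) pair per row."""
    return [decode_step(r, t) for r, t in zip(reward_logits, termination_logits)]
```

For example, decode_batch([[0.1, 2.0, -1.0], [-3.0, 0.0, 4.0]], [[1.0, -1.0], [-2.0, 3.0]]) returns [(0, False), (1, True)].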
- world_models.sinusoidal_time_embedding(timesteps, dim)[source]#
Create sinusoidal timestep embeddings for diffusion conditioning.
This function generates positional-style embeddings for diffusion timesteps, following the same pattern as transformer positional encodings. The embeddings encode the noise level (t) and are used to condition the diffusion model.
- Math:
embedding[t] = [sin(t/10000^(2i/d)), cos(t/10000^(2i/d))] for i in [0, d/2)
- Parameters:
timesteps – Tensor of integer timesteps, shape (B,) or (B, 1)
dim – Embedding dimension (must be even)
- Returns:
Tensor of shape (B, dim) with sinusoidal embeddings
- Usage with DiT:
t = torch.tensor([0, 500, 1000])  # Timesteps
emb = sinusoidal_time_embedding(t, dim=256)  # (3, 256)
# Condition the model:
# - Add the timestep embedding to the MLP input
# - Use AdaLN for adaptive normalization
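The Math line above can be checked with a small pure-Python sketch. It assumes the concatenated layout [all sines, all cosines] with frequencies 10000^(-2i/dim). Some implementations interleave sin and cos instead, and the library's exact layout is not stated here. sinusoidal_embedding and embed_batch are hypothetical names:

```python
import math
from typing import List


def sinusoidal_embedding(t: float, dim: int) -> List[float]:
    """Embed one timestep as [sin(t*w_i)...] + [cos(t*w_i)...], w_i = 10000**(-2i/dim)."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    freqs = [10000 ** (-2 * i / dim) for i in range(half)]
    # Assumed layout: all sines first, then all cosines (concatenated, not interleaved).
    return [math.sin(t * w) for w in freqs] + [math.cos(t * w) for w in freqs]


def embed_batch(timesteps: List[int], dim: int) -> List[List[float]]:
    """Embed a batch of timesteps; the result has shape (len(timesteps), dim)."""
    return [sinusoidal_embedding(t, dim) for t in timesteps]


emb = embed_batch([0, 500, 1000], dim=256)  # 3 rows of 256 values, like (3, 256)
```

At t = 0 every sine is 0 and every cosine is 1. A tensor implementation would compute the same values with broadcasting instead of Python loops.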
- class world_models.STTransformer(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]#
Bases:
Module
Spatiotemporal Transformer for video modeling.
Contains L = depth spatiotemporal blocks with interleaved spatial and temporal attention.
- Parameters:
num_frames (int)
num_patches_per_frame (int)
dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
qk_scale (float | None)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
norm_layer (type[Module])
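Factorizing attention into spatial and temporal passes is usually motivated by cost. The sketch below gives a rough count of query-key pairs per head and per layer, using the defaults num_frames=16 and num_patches_per_frame=256. It assumes spatial attention runs within each frame and temporal attention runs across frames at each patch position. That layout is an assumption consistent with the description above, and the function names are hypothetical:

```python
def full_attention_pairs(frames: int, patches: int) -> int:
    """Query-key pairs if every token attends to every token in the clip."""
    tokens = frames * patches
    return tokens * tokens


def factorized_attention_pairs(frames: int, patches: int) -> int:
    """Pairs for one spatial pass (per frame) plus one temporal pass (per patch)."""
    spatial = frames * patches * patches   # each frame: patches x patches
    temporal = patches * frames * frames   # each patch position: frames x frames
    return spatial + temporal


full = full_attention_pairs(16, 256)        # 16_777_216
fact = factorized_attention_pairs(16, 256)  # 1_114_112, roughly 15x fewer
```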
- class world_models.MultiHeadSelfAttention(d, n_heads=2)[source]#
Bases:
Module
Multi-head scaled dot-product self-attention over sequence tokens.
This module projects the input sequence into query/key/value heads, performs attention independently per head, and merges the heads back into the original feature dimension. It is used as a lightweight transformer attention block.
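The split, attend, and merge steps described above can be sketched in plain Python. To keep the sketch short, it assumes identity query/key/value projections and no output projection. The real module projects the input first. Function names are hypothetical:

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def split_heads(x: Matrix, n_heads: int) -> List[Matrix]:
    """(seq, d) -> n_heads matrices of shape (seq, d // n_heads)."""
    d = len(x[0])
    if d % n_heads != 0:
        raise ValueError("d must be divisible by n_heads")
    dh = d // n_heads
    return [[row[h * dh:(h + 1) * dh] for row in x] for h in range(n_heads)]


def attend(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Scaled dot-product attention for a single head."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        weights = softmax([scale * sum(a * b for a, b in zip(qi, kj)) for kj in k])
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


def multi_head_self_attention(x: Matrix, n_heads: int = 2) -> Matrix:
    """Identity Q/K/V projections (a simplification); heads are re-concatenated."""
    heads = [attend(h, h, h) for h in split_heads(x, n_heads)]
    # Merge: concatenate each token's per-head outputs back to width d.
    return [sum((head[i] for head in heads), []) for i in range(len(x))]
```

With a single token every attention weight is 1, so the output equals the input. With more tokens, each head mixes only the sub-vectors it owns.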
- world_models.MultiHeadAttention#
alias of
MultiHeadSelfAttention
- class world_models.AdaLNNormalization(d_model, t_dim)[source]#
Bases:
Module
Adaptive layer normalization conditioned on an external embedding.
The module applies RMS normalization and predicts per-channel scale/shift from a conditioning vector (for example diffusion timestep embeddings).
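Here is a minimal sketch of the conditioning path. It assumes a single linear layer maps the conditioning vector to per-channel scale and shift, and that the modulation is normalized * (1 + scale) + shift. This is a common DiT-style choice, but the module's exact parameterization is not stated. With zero weights the layer reduces to plain RMS normalization, which is why some DiT-style models zero-initialize the modulation layer. Names are hypothetical:

```python
import math
from typing import List, Sequence


def rms_normalize(x: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Divide by the root-mean-square of x (no mean subtraction)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def adaln(x: Sequence[float], cond: Sequence[float],
          weight: Sequence[Sequence[float]], bias: Sequence[float],
          eps: float = 1e-6) -> List[float]:
    """Adaptive RMS norm: [scale | shift] = weight @ cond + bias, then modulate.

    weight has 2 * len(x) rows and len(cond) columns (hypothetical layout).
    """
    d = len(x)
    params = [sum(w * c for w, c in zip(row, cond)) + b for row, b in zip(weight, bias)]
    scale, shift = params[:d], params[d:]
    normed = rms_normalize(x, eps)
    # Assumed modulation: normed * (1 + scale) + shift, so zero params = identity.
    return [n * (1.0 + s) + t for n, s, t in zip(normed, scale, shift)]
```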
- class world_models.RMSNorm(dim, eps=1e-06)[source]#
Bases:
Module
Root Mean Square Layer Normalization with a learned gain parameter.
RMSNorm rescales activations using their RMS magnitude without centering, providing a lightweight normalization alternative to LayerNorm.
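The RMSNorm computation can be sketched as follows. The sketch assumes eps is added inside the square root, as in many implementations, and uses a per-channel gain that defaults to 1. rms_norm is a hypothetical name:

```python
import math
from typing import List, Optional, Sequence


def rms_norm(x: Sequence[float], gain: Optional[Sequence[float]] = None,
             eps: float = 1e-6) -> List[float]:
    """y_i = g_i * x_i / sqrt(mean(x**2) + eps): rescale only, no centering, no bias."""
    g = list(gain) if gain is not None else [1.0] * len(x)
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [gi * v / rms for gi, v in zip(g, x)]
```

Because nothing is subtracted, rms_norm([1.0, 3.0]) and rms_norm([2.0, 4.0]) give different outputs. LayerNorm would map both inputs to the same [-1, 1] after centering.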
- class world_models.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]#
Bases:
object
Model-predictive controller using the Cross-Entropy Method (CEM) with an RSSM.
Plans actions by optimizing a sequence of future actions in the RSSM’s latent space, using CEM to refine candidate action sequences based on predicted returns.
- Algorithm:
Initialize Gaussian distribution over action sequences
Sample N candidate action sequences
Rollout each sequence in RSSM latent space
Score by predicted cumulative rewards
Keep top K candidates, fit Gaussian to them
Repeat for T iterations
Execute first action from best sequence
- Why latent space planning?
Images are high-dimensional; latent states are compact
Enables thousands of rollouts in parallel
A compact latent state is typically easier for the dynamics model to predict than raw pixels
- Parameters:
model – RSSM instance for latent dynamics
planning_horizon – Number of future steps to plan (H)
num_candidates – Number of action sequences to sample (N)
num_iterations – CEM refinement iterations (T)
top_candidates – Number of best candidates to keep (K)
device – torch device
- Usage with Planet agent:
policy = RSSMPolicy(
    model=rssm, planning_horizon=12, num_candidates=1000,
    num_iterations=8, top_candidates=100, device='cuda'
)
policy.reset()
action = policy.poll(observation)  # (1, action_dim)
# For continuous control:
next_obs, reward, done, info = env.step(action)
- Comparison with Dreamer:
RSSMPolicy: Online planning, chooses actions by optimization at each step
DreamerActor: Train actor network to predict actions from states
Dreamer is more sample-efficient for complex tasks; CEM is more flexible
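The Algorithm steps above can be sketched in plain Python. This is a minimal illustration, not the library's RSSMPolicy. Actions are 1-D scalars, and score is a hypothetical stand-in for "roll the sequence out in the RSSM latent space and sum predicted rewards." The defaults mirror the usage example (planning_horizon=12, num_candidates=1000, num_iterations=8, top_candidates=100):

```python
import random
from typing import Callable, List, Optional, Sequence


def cem_plan(score: Callable[[Sequence[float]], float], horizon: int = 12,
             num_candidates: int = 1000, num_iterations: int = 8,
             top_candidates: int = 100, seed: int = 0) -> float:
    """Cross-Entropy Method over open-loop sequences of scalar actions.

    Returns the first action of the best-scoring sequence seen.
    """
    rng = random.Random(seed)
    mean = [0.0] * horizon  # initial Gaussian over action sequences
    std = [1.0] * horizon
    best_seq: Optional[List[float]] = None
    best_score = float("-inf")
    for _ in range(num_iterations):
        # Sample N candidate sequences from the current diagonal Gaussian.
        cands = [[rng.gauss(m, s) for m, s in zip(mean, std)]
                 for _ in range(num_candidates)]
        # "Roll out" and score each candidate, best first.
        scored = sorted(((score(c), c) for c in cands), key=lambda p: p[0], reverse=True)
        if scored[0][0] > best_score:
            best_score, best_seq = scored[0]
        # Refit the Gaussian to the top-K elites.
        elites = [c for _, c in scored[:top_candidates]]
        mean = [sum(e[t] for e in elites) / top_candidates for t in range(horizon)]
        std = [(sum((e[t] - mean[t]) ** 2 for e in elites) / top_candidates) ** 0.5
               for t in range(horizon)]
    assert best_seq is not None
    return best_seq[0]  # execute only the first action, then replan next step
```

With score = lambda seq: -sum((a - 0.5) ** 2 for a in seq), the returned action lands close to 0.5. Replanning from scratch at every step costs far more computation per environment step than a single forward pass through a trained actor.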
- class world_models.IRISActor(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Actor network for the IRIS (Imagination with auto-Regression over an Inner Speech) policy.
Takes reconstructed frames as input and outputs action logits for policy control. Uses a CNN feature extractor followed by an LSTM for temporal processing. Supports a burn-in mechanism for initializing the hidden state with context frames.
- Architecture:
CNN: Extracts features from input frames (3x64x64 -> 512)
LSTM: Processes temporal sequences with configurable layers
Linear: Maps hidden states to action logits
- Parameters:
action_size (int) – Number of discrete actions.
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
- action_size#
Number of discrete actions.
- Type:
int
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
- forward(frames, hidden_state=None, burn_in_frames=None)[source]#
Forward pass through actor.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W) or (B, C, H, W)
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple for LSTM state
burn_in_frames (Tensor | None) – Frames to use for initializing hidden state
- Returns:
A tuple (action_logits, hidden_state): action logits of shape (B, T, action_size), or (B, action_size) for single-frame input, and the updated (h, c) tuple.
- Return type:
tuple
Initialize LSTM hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- get_action(frame, temperature=1.0, deterministic=False)[source]#
Get action from a single frame.
- Parameters:
frame (Tensor) – Single frame (B, C, H, W)
temperature (float) – Softmax temperature (higher = more random)
deterministic (bool) – If True, return argmax; else sample
- Returns:
action: selected action indices with shape (B,).
- Return type:
Tensor
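The temperature and deterministic arguments of get_action can be illustrated on a single logit vector. This minimal sketch assumes temperature divides the logits before the softmax, which is the usual convention. The source only says higher means more random. Function names are hypothetical:

```python
import math
import random
from typing import List, Optional, Sequence


def softmax_with_temperature(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """softmax(logits / temperature); a higher temperature flattens the distribution."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def select_action(logits: Sequence[float], temperature: float = 1.0,
                  deterministic: bool = False, rng: Optional[random.Random] = None) -> int:
    if deterministic:
        # argmax; temperature does not change which logit is largest
        return max(range(len(logits)), key=lambda i: logits[i])
    probs = softmax_with_temperature(logits, temperature)
    rng = rng or random.Random()
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]
```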
- class world_models.IRISCritic(hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Critic network for IRIS value estimation.
Estimates the value function for given frame sequences. Shares the CNN feature extractor and LSTM backbone with the actor for efficiency, but has a separate value head for estimating expected cumulative rewards.
- Architecture:
CNN: Shared feature extractor with actor (3x64x64 -> 512)
LSTM: Temporal processing with same architecture as actor
Linear: Maps hidden states to scalar values
- Parameters:
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
- Returns:
A tuple (values, hidden_state): value estimates with shape (B, T) and the updated LSTM hidden state (h, c) tuple.
- Return type:
tuple
- Parameters:
hidden_size (int)
num_layers (int)
frame_shape (Tuple[int, int, int])
- forward(frames, hidden_state=None)[source]#
Forward pass through critic.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W)
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple
- Returns:
A tuple (values, hidden_state): value estimates (B, T) and the updated (h, c) tuple.
- Return type:
tuple
Initialize LSTM hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- class world_models.IRISPolicy(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Combined policy module for IRIS (Imagination with auto-Regression over an Inner Speech).
Provides a unified interface for actor-only or actor-critic policies. Used in the IRIS algorithm where the actor generates actions from reconstructed frames and the critic estimates value functions for training.
- Parameters:
action_size (int) – Number of discrete actions.
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
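Example
>>> import torch
>>> from world_models import IRISPolicy
>>> policy = IRISPolicy(
...     action_size=18,
...     hidden_size=512,
...     num_layers=4,
...     frame_shape=(3, 64, 64),
... )
>>> frame = torch.rand(1, 3, 64, 64)  # one 3x64x64 frame, batch of 1
>>> action = policy.act(frame, temperature=1.0, deterministic=False)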
- forward(frames)[source]#
Get action logits from frames.
- Parameters:
frames (Tensor)
- Return type:
Tensor
- act(frame, temperature=1.0, deterministic=False)[source]#
Sample action from policy.
- Parameters:
frame (Tensor)
temperature (float)
deterministic (bool)
- Return type:
Tensor
Initialize hidden state.
- Parameters:
batch_size (int)
device (device)
- class world_models.CNNFeatureExtractor(frame_shape=(3, 64, 64), output_size=512)[source]#
Bases:
Module
CNN feature extractor shared between actor and critic networks.
Processes input frames through a series of convolutional layers to produce fixed-size feature vectors: four stride-2 Conv2d + ReLU layers (3 -> 32 -> 64 -> 128 -> 256 channels), followed by a linear projection to output_size.
- Architecture:
Conv layers: 32 -> 64 -> 128 -> 256 channels
Each conv has stride=2 for spatial downsampling
Final linear layer projects to desired output dimension
- Parameters:
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
output_size (int) – Size of output feature vector (default: 512).
- frame_shape#
Input frame shape.
- Type:
tuple
- output_size#
Output feature dimension.
- Type:
int
- Returns:
features: feature vectors with shape (B, output_size).
- Return type:
Tensor
- Parameters:
frame_shape (Tuple[int, int, int])
output_size (int)
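The source does not give kernel sizes or padding, so the flattened size feeding the final linear layer depends on those choices. The sketch below applies the standard Conv2d output-size formula (dilation 1) and compares kernel 4 with padding 1 against kernel 4 with padding 0. Both configurations are assumptions:

```python
from typing import Sequence


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a Conv2d layer with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(frame_hw: int = 64, channels: Sequence[int] = (32, 64, 128, 256),
                       kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
    """Size of the flattened conv output that feeds the final linear projection."""
    size = frame_hw
    for _ in channels:  # one stride-2 conv per channel stage
        size = conv_out(size, kernel, stride, padding)
    return channels[-1] * size * size
```

With kernel 4, stride 2, and padding 1, a 64x64 frame shrinks 64 → 32 → 16 → 8 → 4, giving 256 * 4 * 4 = 4096 inputs to the projection. With padding 0 it shrinks 64 → 31 → 14 → 6 → 2, giving 1024.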
- class world_models.DreamerConfig[source]#
Bases:
object
Configuration container for Dreamer training, evaluation, and environment setup.
This class centralizes environment backend selection (DMC/Gym/MuJoCo/Robotics/Unity/Brax), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.
- class world_models.JEPAConfig[source]#
Bases:
object
Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.
- class world_models.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#
Bases:
object
Default configuration values for Diffusion Transformer (DiT) training.
The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.
- Parameters:
DATASET (str)
BATCH (int)
EPOCHS (int)
LR (float)
IMG_SIZE (int)
CHANNELS (int)
PATCH (int)
WIDTH (int)
DEPTH (int)
HEADS (int)
DROP (float)
BETA_START (float)
BETA_END (float)
TIMESTEPS (int)
EMA (bool)
EMA_DECAY (float)
WORKDIR (str)
ROOT_PATH (str)
- DATASET: str = 'CIFAR10'#
- BATCH: int = 128#
- EPOCHS: int = 3#
- LR: float = 0.0002#
- IMG_SIZE: int = 32#
- CHANNELS: int = 3#
- PATCH: int = 4#
- WIDTH: int = 384#
- DEPTH: int = 6#
- HEADS: int = 6#
- DROP: float = 0.1#
- BETA_START: float = 0.0001#
- BETA_END: float = 0.02#
- TIMESTEPS: int = 1000#
- EMA: bool = True#
- EMA_DECAY: float = 0.999#
- WORKDIR: str = './dit_demo'#
- ROOT_PATH: str = './data'#
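Two quantities follow directly from these defaults. The sketch below computes the number of patch tokens per image and a diffusion noise schedule. It assumes ViT-style non-overlapping patches and a linear beta schedule from BETA_START to BETA_END over TIMESTEPS steps. The linear schedule is the DDPM convention; the schedule type is not stated in the source. Function names are hypothetical:

```python
from typing import List


def num_patches(img_size: int = 32, patch: int = 4) -> int:
    """Tokens per image for a ViT-style patchifier: (IMG_SIZE / PATCH) ** 2."""
    if img_size % patch != 0:
        raise ValueError("IMG_SIZE must be divisible by PATCH")
    return (img_size // patch) ** 2


def linear_betas(beta_start: float = 1e-4, beta_end: float = 0.02,
                 timesteps: int = 1000) -> List[float]:
    """Evenly spaced betas from BETA_START to BETA_END (both inclusive)."""
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def alpha_bars(betas: List[float]) -> List[float]:
    """Cumulative product of (1 - beta_t): the signal fraction kept after t steps."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out
```

With the defaults, a 32x32 image becomes 8 x 8 = 64 tokens of width 384. alpha_bar decays from 0.9999 at the first step to nearly zero by step 1000.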
- world_models.get_dit_config(**overrides)[source]#
Returns a DiTConfig instance with default values overridden by the provided keyword arguments.
- Example usage:
cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)
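The following minimal sketch shows the override pattern, using dataclasses.replace on a hypothetical DiTConfigSketch that mirrors a few DiTConfig fields. Whether the real get_dit_config rejects unknown keys this way is not stated in the source:

```python
from dataclasses import dataclass, replace


@dataclass
class DiTConfigSketch:  # hypothetical stand-in mirroring a few DiTConfig fields
    DATASET: str = "CIFAR10"
    BATCH: int = 128
    EPOCHS: int = 3
    LR: float = 2e-4
    IMG_SIZE: int = 32
    PATCH: int = 4


def get_config_sketch(**overrides) -> DiTConfigSketch:
    """Start from the defaults, then apply keyword overrides.

    dataclasses.replace raises TypeError for a key that is not a field.
    """
    return replace(DiTConfigSketch(), **overrides)


cfg = get_config_sketch(BATCH=64, EPOCHS=10, LR=1e-3)
```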
- class world_models.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', seed: int = 0, obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, device: str = <factory>, log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, operator_state_dim: int = 32, operator_action_dim: int = 4)[source]#
Bases:
object
- Parameters:
preset (str | None)
game (str)
seed (int)
obs_size (int)
frameskip (int)
max_noop (int)
terminate_on_life_loss (bool)
reward_clip (List[int])
num_conditioning_frames (int)
diffusion_channels (List[int])
diffusion_res_blocks (int)
diffusion_cond_dim (int)
sigma_data (float)
sigma_min (float)
sigma_max (float)
rho (int)
p_mean (float)
p_std (float)
sampling_method (str)
num_sampling_steps (int)
reward_channels (List[int])
reward_res_blocks (int)
reward_cond_dim (int)
reward_lstm_dim (int)
burn_in_length (int)
actor_channels (List[int])
actor_res_blocks (int)
actor_lstm_dim (int)
num_epochs (int)
training_steps_per_epoch (int)
batch_size (int)
environment_steps_per_epoch (int)
epsilon_greedy (float)
imagination_horizon (int)
discount_factor (float)
entropy_weight (float)
lambda_returns (float)
learning_rate (float)
adam_epsilon (float)
weight_decay_diffusion (float)
weight_decay_reward (float)
weight_decay_actor (float)
device (str)
log_interval (int)
eval_interval (int)
save_interval (int)
operator_state_dim (int)
operator_action_dim (int)
- preset: str | None = None#
- game: str = 'Breakout-v5'#
- seed: int = 0#
- obs_size: int = 64#
- frameskip: int = 4#
- max_noop: int = 30#
- terminate_on_life_loss: bool = True#
- reward_clip: List[int]#
- num_conditioning_frames: int = 4#
- diffusion_channels: List[int]#
- diffusion_res_blocks: int = 2#
- diffusion_cond_dim: int = 256#
- sigma_data: float = 0.5#
- sigma_min: float = 0.002#
- sigma_max: float = 80.0#
- rho: int = 7#
- p_mean: float = -0.4#
- p_std: float = 1.2#
- sampling_method: str = 'euler'#
- num_sampling_steps: int = 3#
- reward_channels: List[int]#
- reward_res_blocks: int = 2#
- reward_cond_dim: int = 128#
- reward_lstm_dim: int = 512#
- burn_in_length: int = 4#
- actor_channels: List[int]#
- actor_res_blocks: int = 1#
- actor_lstm_dim: int = 512#
- num_epochs: int = 1000#
- training_steps_per_epoch: int = 400#
- batch_size: int = 32#
- environment_steps_per_epoch: int = 100#
- epsilon_greedy: float = 0.01#
- imagination_horizon: int = 15#
- discount_factor: float = 0.985#
- entropy_weight: float = 0.001#
- lambda_returns: float = 0.95#
- learning_rate: float = 0.0001#
- adam_epsilon: float = 1e-08#
- weight_decay_diffusion: float = 0.01#
- weight_decay_reward: float = 0.01#
- weight_decay_actor: float = 0.0#
- device: str#
- log_interval: int = 10#
- eval_interval: int = 50#
- save_interval: int = 100#
- operator_state_dim: int = 32#
- operator_action_dim: int = 4#
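sigma_min, sigma_max, and rho match the parameters of the EDM noise-level schedule (Karras et al.), which DIAMOND builds on. Assuming DiamondConfig uses that schedule for its num_sampling_steps denoising steps, the noise levels can be sketched as below. karras_sigmas is a hypothetical name. The EDM formulation also appends a final sigma of 0, which is omitted here:

```python
from typing import List


def karras_sigmas(num_steps: int = 3, sigma_min: float = 0.002,
                  sigma_max: float = 80.0, rho: float = 7.0) -> List[float]:
    """Noise levels spaced linearly in sigma**(1/rho), from sigma_max down to sigma_min."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    hi, lo = sigma_max ** (1.0 / rho), sigma_min ** (1.0 / rho)
    return [(hi + i / (num_steps - 1) * (lo - hi)) ** rho for i in range(num_steps)]
```

With the defaults this gives roughly [80.0, 2.52, 0.002]. A large rho pulls the intermediate levels toward the low-noise end.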
- class world_models.IRISConfig[source]#
Bases:
object
Configuration for IRIS (Imagination with auto-Regression over an Inner Speech).
Based on the paper “Transformers are Sample-Efficient World Models”. Implements a discrete autoencoder + autoregressive Transformer for sample-efficient RL.
- class world_models.GenieConfig(num_frames=8, image_size=32, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=256, action_encoder_depth=4, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#
Bases:
object
Configuration for Genie model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
tokenizer_num_heads (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_num_heads (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
sample_temperature (float)
maskgit_steps (int)
- num_frames: int = 8#
- image_size: int = 32#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 256#
- tokenizer_decoder_dim: int = 512#
- tokenizer_encoder_depth: int = 4#
- tokenizer_decoder_depth: int = 8#
- tokenizer_num_heads: int = 8#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 256#
- action_encoder_depth: int = 4#
- action_num_heads: int = 8#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 4#
- learning_rate: float = 3e-05#
- weight_decay: float = 0.0001#
- warmup_steps: int = 5000#
- max_steps: int = 125000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- sample_temperature: float = 2.0#
- maskgit_steps: int = 25#
- class world_models.GenieSmallConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=512, action_encoder_depth=8, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=2, learning_rate=0.0001, weight_decay=0.0001, warmup_steps=1000, max_steps=50000, mask_prob_min=0.5, mask_prob_max=1.0)[source]#
Bases:
object
Small configuration for development/testing.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
tokenizer_num_heads (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_num_heads (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 256#
- tokenizer_decoder_dim: int = 512#
- tokenizer_encoder_depth: int = 4#
- tokenizer_decoder_depth: int = 8#
- tokenizer_num_heads: int = 8#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 512#
- action_encoder_depth: int = 8#
- action_num_heads: int = 8#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 2#
- learning_rate: float = 0.0001#
- weight_decay: float = 0.0001#
- warmup_steps: int = 1000#
- max_steps: int = 50000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- class world_models.STTransformerConfig(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases:
object
Configuration for Spatiotemporal Transformer.
- Parameters:
num_frames (int)
num_patches_per_frame (int)
dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- num_frames: int = 16#
- num_patches_per_frame: int = 256#
- dim: int = 768#
- depth: int = 12#
- num_heads: int = 12#
- mlp_ratio: float = 4.0#
- qkv_bias: bool = True#
- drop_rate: float = 0.0#
- attn_drop_rate: float = 0.0#
- drop_path_rate: float = 0.0#
- class world_models.VideoTokenizerConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False, ema_decay=0.99, commitment_weight=0.25)[source]#
Bases:
object
Configuration for Video Tokenizer.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
use_ema (bool)
ema_decay (float)
commitment_weight (float)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- encoder_dim: int = 512#
- decoder_dim: int = 1024#
- encoder_depth: int = 12#
- decoder_depth: int = 20#
- num_heads: int = 16#
- patch_size: int = 4#
- vocab_size: int = 1024#
- embedding_dim: int = 32#
- use_ema: bool = False#
- ema_decay: float = 0.99#
- commitment_weight: float = 0.25#
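The tokenizer defaults determine how many discrete tokens a clip becomes. This sketch assumes non-overlapping square patches, each quantized to one codebook index per frame. That is a simplification, and the function names are hypothetical:

```python
def tokens_per_frame(image_size: int, patch_size: int) -> int:
    """Non-overlapping square patches per frame: (image_size / patch_size) ** 2."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    return (image_size // patch_size) ** 2


def clip_tokens(num_frames: int, image_size: int, patch_size: int) -> int:
    """Total discrete tokens for a clip, assuming one token per patch per frame."""
    return num_frames * tokens_per_frame(image_size, patch_size)


def bits_per_token(vocab_size: int) -> int:
    """Bits needed to index one entry of a codebook with vocab_size entries."""
    return (vocab_size - 1).bit_length()
```

With VideoTokenizerConfig's defaults, each 64x64 frame becomes 16 x 16 = 256 tokens, matching STTransformerConfig's num_patches_per_frame=256. A 16-frame clip becomes 4096 tokens, each a 10-bit index into the 1024-entry codebook. LatentActionModelConfig's patch_size=16 gives only 4 x 4 = 16 patches per frame.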
- class world_models.LatentActionModelConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=1024, encoder_depth=20, num_heads=16, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#
Bases:
object
Configuration for Latent Action Model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
encoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- encoder_dim: int = 1024#
- encoder_depth: int = 20#
- num_heads: int = 16#
- patch_size: int = 16#
- vocab_size: int = 8#
- embedding_dim: int = 32#
- commitment_weight: float = 1.0#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- class world_models.DynamicsModelConfig(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases:
object
Configuration for Dynamics Model.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- num_frames: int = 16#
- image_size: int = 64#
- vocab_size: int = 1024#
- embedding_dim: int = 32#
- action_vocab_size: int = 8#
- dim: int = 5120#
- depth: int = 48#
- num_heads: int = 36#
- patch_size: int = 4#
- mlp_ratio: float = 4.0#
- qkv_bias: bool = True#
- drop_rate: float = 0.0#
- attn_drop_rate: float = 0.0#
- drop_path_rate: float = 0.0#
- class world_models.OperatorABC[source]#
Bases:
ABC
Abstract base class for operators that preprocess inputs for inference pipelines.
- class world_models.DreamerOperator(image_size=64, action_dim=6)[source]#
Bases:
OperatorABC
Operator for Dreamer model preprocessing: normalizes observations and encodes actions.
- Parameters:
image_size (int)
action_dim (int)
- class world_models.JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)[source]#
Bases:
OperatorABC
Operator for JEPA model preprocessing: handles image/video masking and patch processing.
- Parameters:
image_size (int)
patch_size (int)
mask_ratio (float)
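The JEPAOperator defaults fix how many patches are hidden. With image_size=224 and patch_size=16 there are 14 x 14 = 196 patches, and mask_ratio=0.75 masks 147 of them, leaving 49 visible. JEPA papers typically use block-wise masks. This sketch uses uniform random masking for brevity, and the operator's actual masking strategy is not stated in the source. random_patch_mask is a hypothetical name:

```python
import random
from typing import List, Tuple


def random_patch_mask(image_size: int = 224, patch_size: int = 16, mask_ratio: float = 0.75,
                      seed: int = 0) -> Tuple[List[int], List[int]]:
    """Split patch indices into (visible, masked) by uniform random masking."""
    n = (image_size // patch_size) ** 2
    num_masked = int(n * mask_ratio)
    order = list(range(n))
    random.Random(seed).shuffle(order)
    masked = sorted(order[:num_masked])
    visible = sorted(order[num_masked:])
    return visible, masked
```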
- class world_models.IrisOperator(seq_length=512, vocab_size=32000)[source]#
Bases:
OperatorABC
Operator for Iris transformer model: formats sequences and embeddings.
- Parameters:
seq_length (int)
vocab_size (int)
- class world_models.PlaNetOperator(state_dim=32, action_dim=4)[source]#
Bases:
OperatorABC
Operator for PlaNet model preprocessing: encodes environment states and transitions.
- Parameters:
state_dim (int)
action_dim (int)
- world_models.get_operator(name, **kwargs)[source]#
Factory function to get inference operators by name.
- Parameters:
name (str) – One of 'dreamer', 'jepa', 'iris', 'planet'
**kwargs – Operator-specific configuration
- Returns:
Configured OperatorABC instance
Example
>>> op = get_operator('dreamer', image_size=64, action_dim=6)
>>> processed = op.process({'image': image, 'action': action})
- class world_models.RewardModel(belief_size, state_size, hidden_size, activation_function='relu')[source]#
Bases:
Module
Predict scalar rewards from Dreamer latent belief and state vectors.
Implemented as an MLP used for model-based reward supervision and imagined rollout return estimation.
- class world_models.ValueModel(belief_size, state_size, hidden_size, activation_function='relu')[source]#
Bases:
Module
Estimate scalar value from Dreamer latent belief and state vectors.
This MLP is trained on imagined returns and used for actor/value updates.
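Dreamer-style value targets over an imagined rollout are commonly λ-returns, which blend bootstrapped value estimates with longer reward sums. This sketch shows the standard backward recursion. The discount and lam defaults are illustrative, and whether ValueModel's training loop uses exactly this form is not stated here. lambda_returns is a hypothetical name:

```python
from typing import List, Sequence


def lambda_returns(rewards: Sequence[float], values: Sequence[float],
                   bootstrap: float, discount: float = 0.99, lam: float = 0.95) -> List[float]:
    """TD(lambda) targets computed backward over an imagined rollout.

    rewards[t] and values[t] are predictions at imagined step t (t = 0..H-1);
    bootstrap is the value estimate after the final step.
    G_t = r_t + discount * ((1 - lam) * v_{t+1} + lam * G_{t+1}), with G_H = bootstrap.
    """
    horizon = len(rewards)
    next_values = list(values[1:]) + [bootstrap]
    returns = [0.0] * horizon
    last = bootstrap
    for t in reversed(range(horizon)):
        last = rewards[t] + discount * ((1 - lam) * next_values[t] + lam * last)
        returns[t] = last
    return returns
```

With lam = 1 the values are ignored and the targets become discounted reward sums plus the bootstrap. With lam = 0 each target is the one-step TD target r_t + discount * v_{t+1}.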
User-facing convenience APIs for TorchWM.
The lower-level modules remain available for research workflows, but this module collects the common discovery and construction paths behind small, predictable factory functions.
- class world_models.api.EnvBackendSpec(name, factory_path, description='', aliases=())[source]#
Bases:
NamedTuple
Metadata describing an environment backend available through make_env.
- Parameters:
name (str)
factory_path (str)
description (str)
aliases (tuple[str, ...])
- name: str#
Alias for field number 0
- factory_path: str#
Alias for field number 1
- description: str#
Alias for field number 2
- aliases: tuple[str, ...]#
Alias for field number 3
- class world_models.api.ModelSpec(name, import_path, config_path=None, description='', aliases=())[source]#
Bases:
NamedTuple
Metadata describing a model available through create_model().
- Parameters:
name (str)
import_path (str)
config_path (str | None)
description (str)
aliases (tuple[str, ...])
- name: str#
Alias for field number 0
- import_path: str#
Alias for field number 1
- config_path: str | None#
Alias for field number 2
- description: str#
Alias for field number 3
- aliases: tuple[str, ...]#
Alias for field number 4
- world_models.api.create_config(model, **overrides)[source]#
Create the default config object for model and apply overrides.
Examples
>>> cfg = create_config("dreamer", env="walker-walk", seed=7)
>>> cfg.env
'walker-walk'
- Parameters:
model (str)
overrides (Any)
- Return type:
Any
- world_models.api.create_model(model, config=None, **overrides)[source]#
Instantiate a model or agent from a simple string name.
config is optional for models that define a config class. Keyword overrides are applied to the config when possible; otherwise they are passed directly to the underlying constructor or factory.
Examples
>>> agent = create_model("dreamer", env="walker-walk", total_steps=1000)
>>> genie = create_model("genie-small", image_size=32)
- Parameters:
model (str)
config (Any | None)
overrides (Any)
- Return type:
Any
- world_models.api.get_env_backend_spec(name)[source]#
Return metadata for an environment backend name or alias.
- Parameters:
name (str)
- Return type:
- world_models.api.get_model_spec(name)[source]#
Return metadata for a model name or alias.
- Parameters:
name (str)
- Return type:
- world_models.api.list_env_backends()[source]#
Return canonical backend names accepted by make_env().
- Return type:
list[str]
- world_models.api.list_envs(model=None)[source]#
List known environment ids, optionally filtered by model family.
- Parameters:
model (str | None)
- Return type:
list[str] | dict[str, list[str]]
- world_models.api.list_models()[source]#
Return canonical model names accepted by create_model().
- Return type:
list[str]
- world_models.api.make_env(env_id, backend='auto', **kwargs)[source]#
Create an environment with a consistent TorchWM entry point.
- Parameters:
env_id (str) – Environment id, XML path, Unity executable path, or backend-specific id.
backend (str) – One of list_env_backends(); "auto" tries TorchWM’s compatibility helper.
**kwargs (Any) – Backend-specific options.
- Return type:
Any
Models sub-module - Core world model implementations.
- Exported Components:
- Agents (High-level training wrappers):
DreamerAgent: High-level Dreamer training API
JEPAAgent: JEPA agent for self-supervised learning
Planet: PlaNet planning agent
VisionTransformer: Vision Transformer for image encoding
ModularRSSM: Modular RSSM with swappable components
Genie: Generative Interactive Environment model
- Core Models:
Dreamer: Core Dreamer implementation with RSSM, actor, critic
RSSM: Recurrent State-Space Model (Dreamer-style)
RecurrentStateSpaceModel: PlaNet-style RSSM
LatentActionModel: Latent action learning for Genie
DynamicsModel: Future frame prediction for Genie
- Factory Functions:
create_genie, create_genie_small, create_genie_large
create_modular_rssm
create_latent_action_model, create_dynamics_model
Model catalog#
Core model families#
Key classes: Dreamer, DreamerAgent, RSSM, RecurrentStateSpaceModel, Planet, ModularRSSM, JEPAAgent, VisionTransformer, IRISAgent, IRISTransformer, IRISWorldModel, Genie, LatentActionModel, and DynamicsModel.
- world_models.models.dreamer.make_env(args)[source]#
Construct a Dreamer-compatible environment from DreamerConfig options.
Supports DMC, Gym/Gymnasium, MuJoCo, Gymnasium Robotics, Brax, and Unity ML-Agents backends and applies the standard wrapper stack: action repeat, action normalization, and time limit.
- world_models.models.dreamer.preprocess_obs(obs)[source]#
Convert raw uint8 image observations to Dreamer float input space.
Images are scaled from [0, 255] to roughly [-0.5, 0.5], matching the normalization expected by Dreamer encoders.
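A minimal sketch of that mapping, assuming the common Dreamer convention x / 255 - 0.5, which reaches exactly -0.5 and 0.5 at the extremes. The library may additionally quantize or add noise, which would explain the "roughly". Both helper names are hypothetical:

```python
from typing import List, Sequence


def preprocess_obs_sketch(pixels: Sequence[int]) -> List[float]:
    """Map uint8 pixel values in [0, 255] to floats in [-0.5, 0.5] via x / 255 - 0.5."""
    return [p / 255.0 - 0.5 for p in pixels]


def postprocess_obs_sketch(values: Sequence[float]) -> List[int]:
    """Inverse map for visualization: round back to integers and clip to [0, 255]."""
    return [min(255, max(0, round((v + 0.5) * 255.0))) for v in values]
```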
- class world_models.models.dreamer.Dreamer(args, obs_shape, action_size, device, restore=False)[source]#
Bases:
object
Core Dreamer training system combining world model, actor, and value nets.
This class owns model construction, replay sampling, imagination rollouts, loss computation, optimization steps, evaluation loops, and checkpoint I/O.
- class world_models.models.dreamer.DreamerAgent(config=None, **kwargs)[source]#
Bases:
object
High-level user API for running Dreamer experiments end to end.
It builds environments from config, initializes seeds and logging, instantiates Dreamer, and exposes simple train() / evaluate() methods.
- class world_models.models.dreamer_rssm.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)[source]#
Bases:
Module
Recurrent State-Space Model used by Dreamer for latent dynamics learning.
The RSSM is the core world model component that learns compact representations of environment dynamics. It maintains a hybrid state consisting of:
Deterministic State (h): A recurrent hidden state updated by a GRU, capturing sequential/temporal information and deterministic transitions.
Stochastic State (s): A latent variable representing stochastic, multi-modal uncertainty in the environment (e.g., ambiguous observations).
The model operates in two modes:
Observe Mode: Updates states using actual observations from the environment. Uses the representation model: p(s_t | h_t, obs_t)
Imagine Mode: Predicts future states without observations. Uses the transition/prior model: p(s_t | h_t)
- Architecture:
Input: Previous state (h_{t-1}, s_{t-1}) and action a_{t-1}
Process: GRU updates the deterministic state; an MLP computes the stochastic prior/posterior
Output: Updated state (h_t, s_t) and distributions
- State Representation:
deter (h): GRU hidden state, captures sequential context
stoch (s): Stochastic latent, multi-modal uncertainty
mean/std: Parameters of the stochastic distribution
- Usage with DreamerAgent:
from world_models.models.dreamer_rssm import RSSM
rssm = RSSM(
    action_size=action_dim,
    stoch_size=30,  # Stochastic state dimension
    deter_size=200,  # Deterministic (GRU) state dimension
    hidden_size=200,  # MLP hidden layer size
    obs_embed_size=256,  # Observation embedding from encoder
    activation='elu',
)
# Observe with actual observation; observe_step returns (posterior, prior)
posterior, prior = rssm.observe_step(prev_state, prev_action, obs_embed)
# Imagine the next state without an observation
next_prior = rssm.imagine_step(current_state, action)
- Training:
The RSSM is trained by maximizing the ELBO (Evidence Lower Bound):
- KL divergence between the posterior and the prior encourages the prior to capture environment dynamics.
- Reconstruction loss from the decoder ensures the state captures observation information.
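The KL term in this objective compares two diagonal Gaussians over the stochastic state. The pure-Python sketch below computes the closed-form KL(q ‖ p), with q as the posterior and p as the prior, summed over latent dimensions. `diag_gaussian_kl` is a hypothetical helper, and refinements a Dreamer implementation might add (for example KL balancing or free nats) are omitted.

```python
import math
from typing import Sequence

def diag_gaussian_kl(mean_q: Sequence[float], std_q: Sequence[float],
                     mean_p: Sequence[float], std_p: Sequence[float]) -> float:
    """KL(q || p) for diagonal Gaussians, summed over dimensions.

    q plays the role of the posterior and p the prior.
    """
    total = 0.0
    for mq, sq, mp, sp in zip(mean_q, std_q, mean_p, std_p):
        total += (math.log(sp / sq)
                  + (sq ** 2 + (mq - mp) ** 2) / (2.0 * sp ** 2)
                  - 0.5)
    return total

# Identical distributions give zero KL; shifting the posterior mean increases it.
print(diag_gaussian_kl([0.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0]))  # 0.0
print(diag_gaussian_kl([1.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0]))  # 0.5
```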
- Reference:
Dream to Control: Learning Behaviors by Latent Imagination. Hafner et al., ICLR 2020. https://arxiv.org/abs/1912.01603
- init_state(batch_size, device)[source]#
Initialize RSSM state with zeros.
- Parameters:
batch_size – Number of parallel sequences
device – torch device for tensors
- Returns:
Dictionary containing zero-initialized state components:
mean, std: Stochastic distribution parameters
stoch: Stochastic state sample
deter: Deterministic GRU hidden state
- get_dist(mean, std)[source]#
Create an Independent Normal distribution from mean and std.
- Parameters:
mean – Location parameter
std – Scale parameter
- Returns:
Independent Normal distribution with given parameters
- observe_step(prev_state, prev_action, obs_embed, nonterm=1.0)[source]#
Update state using actual observation (observe mode).
In observe mode, the RSSM first computes a transition prior from the previous state and action, then refines the stochastic state using the actual observation embedding to form the posterior.
- Parameters:
prev_state – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})
prev_action – Previous action a_{t-1}, shape (B, action_size)
obs_embed – Observation embedding from encoder, shape (B, obs_embed_size)
nonterm – Termination mask (1.0 = continue, 0.0 = terminal)
- Returns:
A tuple (posterior, prior) of state dictionaries. The posterior incorporates observation information; the prior is the transition prediction before observation. Both share the same deterministic state because the GRU is only advanced once per timestep.
- imagine_step(prev_state, prev_action, nonterm=1.0)[source]#
Predict next state without observation (imagine mode).
In imagine mode, the RSSM predicts future states using only the prior distribution. This is used for planning and policy learning where actual observations are not available.
- Parameters:
prev_state – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})
prev_action – Previous action a_{t-1}, shape (B, action_size)
nonterm – Termination mask (1.0 = continue, 0.0 = terminal)
- Returns:
Dictionary with predicted state containing:
deter: Predicted deterministic state
mean, std, stoch: Prior stochastic state distribution
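To make the observe/imagine split concrete, here is a toy NumPy sketch of one step in each mode. It is not the library's RSSM: a single tanh layer stands in for the GRU, random untrained matrices stand in for the prior and posterior MLPs, and the helper names, sizes, and the softplus(x) + 0.1 std parameterization are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
B, A, S, D, E = 2, 3, 4, 5, 6  # batch, action, stoch, deter, obs-embed sizes (toy values)

# Random, untrained weights standing in for the GRU and the prior/posterior MLPs.
W_rec = rng.normal(scale=0.1, size=(S + A + D, D))
W_prior = rng.normal(scale=0.1, size=(D, 2 * S))
W_post = rng.normal(scale=0.1, size=(D + E, 2 * S))

def softplus(x):
    return np.log1p(np.exp(x))

def init_state(batch_size):
    zeros_s = np.zeros((batch_size, S))
    return {"mean": zeros_s, "std": zeros_s.copy(), "stoch": zeros_s.copy(),
            "deter": np.zeros((batch_size, D))}

def _stats_to_state(stats, deter):
    mean, raw_std = np.split(stats, 2, axis=-1)
    std = softplus(raw_std) + 0.1  # keep std strictly positive
    stoch = mean + std * rng.standard_normal(mean.shape)  # reparameterized sample
    return {"mean": mean, "std": std, "stoch": stoch, "deter": deter}

def imagine_step(prev_state, prev_action, nonterm=1.0):
    # Advance the deterministic state once, then predict the prior over stoch.
    x = np.concatenate([prev_state["stoch"] * nonterm, prev_action,
                        prev_state["deter"] * nonterm], axis=-1)
    deter = np.tanh(x @ W_rec)
    return _stats_to_state(deter @ W_prior, deter)

def observe_step(prev_state, prev_action, obs_embed, nonterm=1.0):
    prior = imagine_step(prev_state, prev_action, nonterm)
    # The posterior reuses the prior's deterministic state and adds the observation.
    post_in = np.concatenate([prior["deter"], obs_embed], axis=-1)
    posterior = _stats_to_state(post_in @ W_post, prior["deter"])
    return posterior, prior

state = init_state(B)
action = np.zeros((B, A))
posterior, prior = observe_step(state, action, rng.normal(size=(B, E)))
imagined = imagine_step(posterior, action)
print(posterior["stoch"].shape, imagined["deter"].shape)  # (2, 4) (2, 5)
```

In this toy version the posterior and prior from one observe_step share a single deterministic state, mirroring the observe_step description; imagine_step alone is what a rollout would call repeatedly during policy learning.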
- get_prior(prev_state, prev_action, nonterm=1.0)[source]#
Compute prior distribution over stochastic state.
The prior represents the model’s belief about the stochastic state before observing the actual outcome.
- Parameters:
prev_state – Previous state dictionary
prev_action – Previous action
nonterm – Termination mask
- Returns:
Dictionary with prior state (no observation information)
- get_posterior(prev_state, prev_action, obs_embed, nonterm=1.0)[source]#
Compute posterior distribution over stochastic state.
The posterior incorporates observation information to produce a more accurate state estimate.
- Parameters:
prev_state – Previous state dictionary
prev_action – Previous action
obs_embed – Observation embedding
nonterm – Termination mask
- Returns:
Dictionary with posterior state (observation-informed). Note that the previous-state shape (B, ...) is preserved; the batch dimension is not flattened.
- detach_state(state)[source]#
Detach state tensors from computation graph.
Used during DreamerV2 training to prevent gradient flow through the observation/update pathway.
- Parameters:
state – State dictionary with tensor values
- Returns:
Detached state dictionary
- seq_to_batch(state_dict)[source]#
Convert sequence state to batch format.
- Parameters:
state_dict – Dictionary with sequence-dimension tensors (T, B, …)
- Returns:
Dictionary with batch-dimension tensors (B*T, …)
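The docstring writes the output shape as (B*T, …). A minimal NumPy sketch follows, assuming the conversion is a plain reshape of the two leading axes; the helper name is hypothetical. With row-major reshapes the flattened rows are ordered time-major: all B entries of step 0, then step 1, and so on.

```python
import numpy as np

def seq_to_batch_sketch(state_dict):
    """Flatten the leading (T, B) axes of every tensor into one axis of size T*B."""
    return {k: v.reshape(-1, *v.shape[2:]) for k, v in state_dict.items()}

T, B = 3, 2
states = {"deter": np.arange(T * B * 4).reshape(T, B, 4),
          "stoch": np.zeros((T, B, 5))}
flat = seq_to_batch_sketch(states)
print(flat["deter"].shape, flat["stoch"].shape)  # (6, 4) (6, 5)
```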
- observe_rollout(obs_embed, actions, nonterms, init_state, seq_len)[source]#
Process a sequence of observations (observe mode rollout).
At each timestep we run observe_step once to obtain the transition prior (the prediction given the previous state and action) and the observation-informed posterior. The posterior is then used as the previous state for the next step, matching the standard Dreamer inference pattern.
- Parameters:
obs_embed – Observation embeddings, shape (T+1, B, obs_embed_size)
actions – Actions, shape (T, B, action_size)
nonterms – Non-termination flags, shape (T, B, 1)
init_state – Initial state dictionary
seq_len – Sequence length T
- Returns:
prior: Dictionary with prior states stacked along the time axis.
posterior: Dictionary with posterior states stacked along the time axis.
- imagine_rollout(policy, init_state, horizon)[source]#
Generate imagined trajectory using policy (imagine mode rollout).
- Parameters:
policy – Actor network that outputs actions from state features
init_state – Initial state dictionary
horizon – Number of steps to imagine
- Returns:
Dictionary with imagined states for each step
- forward(x, u)[source]#
Forward pass for training (computes sequence of states).
- Parameters:
x – Observations, shape (B, T+1, C, H, W)
u – Actions, shape (B, T, action_size)
- Returns:
states: List of state dictionaries for each timestep
priors: List of prior distributions (tuples of mean, std)
posteriors: List of posterior distributions (tuples of mean, std)
- class world_models.models.rssm.RecurrentStateSpaceModel(action_size, state_size=200, latent_size=30, hidden_size=200, embed_size=1024, activation_function='relu')[source]#
Bases:
Module
A Recurrent State Space Model (RSSM) for modeling latent dynamics in sequential data.
- get_init_state(enc, h_t=None, s_t=None, a_t=None, mean=False)[source]#
Returns the initial posterior given the observation.
- deterministic_state_fwd(h_t, s_t, a_t)[source]#
- Deterministic transition update that accepts:
a_t shaped [B, action_size]
a_t shaped [action_size] (unbatched) -> expanded to [B, action_size]
a_t shaped [B] or scalar -> reshaped appropriately
Ensures a_t is 2D and matches batch dimension of h_t before concatenation.
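The accepted action shapes can be illustrated with a hypothetical NumPy helper. The handling of 1-D and scalar inputs below is an assumption (the docstring only says they are "reshaped appropriately"): a length-action_size vector is shared across the batch, a length-B vector is treated as one scalar action per sample when action_size == 1, and a scalar is repeated everywhere.

```python
import numpy as np

def normalize_action(a_t, batch_size: int, action_size: int) -> np.ndarray:
    """Return a_t as a 2-D array of shape (batch_size, action_size).

    Hypothetical helper mirroring the cases listed for deterministic_state_fwd;
    the handling of 1-D and scalar inputs is an assumption.
    """
    a = np.asarray(a_t, dtype=np.float32)
    if a.ndim == 0:
        # Scalar: repeat it for every batch element and action dimension.
        return np.full((batch_size, action_size), float(a), dtype=np.float32)
    if a.ndim == 1:
        if a.shape[0] == action_size:
            # Unbatched action vector: share it across the batch.
            return np.broadcast_to(a, (batch_size, action_size)).copy()
        if a.shape[0] == batch_size and action_size == 1:
            # One scalar action per batch element.
            return a.reshape(batch_size, 1)
        raise ValueError(f"cannot interpret 1-D action of shape {a.shape}")
    if a.shape == (batch_size, action_size):
        return a
    raise ValueError(f"expected ({batch_size}, {action_size}), got {a.shape}")

h_t = np.zeros((4, 8))  # deterministic state with batch size 4
print(normalize_action([0.1, -0.2], h_t.shape[0], 2).shape)  # (4, 2)
print(normalize_action(0.5, h_t.shape[0], 1).shape)          # (4, 1)
```

Once the action is 2-D with a matching batch dimension, it can be concatenated with the stochastic state along the last axis before the recurrent update.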
- state_prior(h_t, sample=False)[source]#
Returns the prior distribution over the latent state given the deterministic state.
- state_posterior(h_t, e_t, sample=False)[source]#
Returns the posterior distribution over the latent state given the deterministic state and the observation embedding e_t.
- forward(x, u)[source]#
Forward through the RSSM for a batch of sequences.
Inputs:
x: Tensor [B, T+1, C, H, W] (observations including initial frame)
u: Tensor [B, T, action_size] (actions for T steps)
- Returns:
states: list[T] of tensors [B, state_size]
priors: list[T] of tuples (mean, std), each [B, latent_size]
posteriors: list[T] of tuples (mean, std), each [B, latent_size]
- class world_models.models.planet.Planet(env, bit_depth=5, device=None, state_size=200, latent_size=30, embedding_size=1024, memory_size=100, policy_cfg=None, headless=False, max_episode_steps=None, action_repeats=1, results_dir=None)[source]#
Bases:
object
High-level Planet wrapper.
- Usage example:
from world_models.models.planet import Planet
p = Planet(env='CartPole-v1', bit_depth=5)
p.train(epochs=50)
- warmup(n_episodes=1, random_policy=True)[source]#
Collect n_episodes of rollouts into memory (used as warmup).
- train(epochs=100, steps_per_epoch=150, batch_size=32, H=50, beta=1.0, save_every=25, record_grads=False, results_dir=None, scheduler_type='step', scheduler_kwargs=None)[source]#
High-level training loop. Delegates single-step training to the existing train function.
- Parameters:
scheduler_type (str) – Type of scheduler to use (“step”, “cosine”, “exponential”, “plateau”, None)
scheduler_kwargs (dict) – Additional arguments for the scheduler
Modular RSSM with swappable encoder/decoder/backbone components.
This module provides a flexible architecture for world model research, allowing researchers to easily swap different encoder, decoder, and backbone implementations for ablations and experimentation.
- class world_models.models.modular_rssm.EncoderBase(*args, **kwargs)[source]#
Bases:
Module, ABC
Abstract base class for observation encoders.
- Parameters:
args (Any)
kwargs (Any)
- abstractmethod forward(obs)[source]#
Encode observations to embeddings.
- Parameters:
obs (Tensor)
- Return type:
Tensor
- embed_size: int#
- class world_models.models.modular_rssm.DecoderBase(*args, **kwargs)[source]#
Bases:
Module, ABC
Abstract base class for observation decoders.
- Parameters:
args (Any)
kwargs (Any)
- class world_models.models.modular_rssm.BackboneBase(*args, **kwargs)[source]#
Bases:
Module, ABC
Abstract base class for recurrent dynamics backbones.
- Parameters:
args (Any)
kwargs (Any)
- abstractmethod forward(state, action, obs_embed=None, nonterm=1.0)[source]#
Process one step of dynamics. Returns (prior, posterior).
- Parameters:
state (Dict[str, Tensor])
action (Tensor)
obs_embed (Tensor | None)
nonterm (float)
- Return type:
Tuple[Dict[str, Tensor], Dict[str, Tensor]]
- abstractmethod init_state(batch_size, device)[source]#
Initialize hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
Dict[str, Tensor]
- stoch_size: int#
- deter_size: int#
- class world_models.models.modular_rssm.ConvEncoder(input_shape, embed_size, activation='elu', depth=32)[source]#
Bases:
EncoderBase
Convolutional encoder from Dreamer (image observations).
- Parameters:
input_shape (Tuple[int, int, int])
embed_size (int)
activation (str)
depth (int)
- class world_models.models.modular_rssm.MLPEncoder(input_dim, embed_size, hidden_sizes=[256, 256], activation='elu')[source]#
Bases:
EncoderBase
MLP encoder for state-based observations.
- Parameters:
input_dim (int)
embed_size (int)
hidden_sizes (List[int])
activation (str)
- class world_models.models.modular_rssm.ViTEncoder(input_shape, embed_size, patch_size=8, depth=6, num_heads=8, mlp_ratio=4.0, activation='gelu')[source]#
Bases:
EncoderBase
Vision Transformer encoder for image observations.
- Parameters:
input_shape (Tuple[int, int, int])
embed_size (int)
patch_size (int)
depth (int)
num_heads (int)
mlp_ratio (float)
activation (str)
- class world_models.models.modular_rssm.TransformerBlock(embed_size, num_heads, mlp_ratio, activation)[source]#
Bases:
Module
Transformer block for ViT encoder.
- Parameters:
embed_size (int)
num_heads (int)
mlp_ratio (float)
activation (str)
- class world_models.models.modular_rssm.ConvDecoder(stoch_size, deter_size, output_shape, activation='elu', depth=32)[source]#
Bases:
DecoderBase
Convolutional decoder for image observations.
- Parameters:
stoch_size (int)
deter_size (int)
output_shape (Tuple[int, int, int])
activation (str)
depth (int)
- class world_models.models.modular_rssm.MLPDecoder(stoch_size, deter_size, output_dim, hidden_sizes=[256, 256], activation='elu', dist='normal')[source]#
Bases:
DecoderBase
MLP decoder for state-based observations.
- Parameters:
stoch_size (int)
deter_size (int)
output_dim (int)
hidden_sizes (List[int])
activation (str)
dist (str)
- class world_models.models.modular_rssm.GRUBackbone(action_size, stoch_size, deter_size, hidden_size, embed_size, activation='elu')[source]#
Bases:
BackboneBase
GRU-based recurrent dynamics backbone (standard RSSM).
- Parameters:
action_size (int)
stoch_size (int)
deter_size (int)
hidden_size (int)
embed_size (int)
activation (str)
- property embedding_size: int#
- class world_models.models.modular_rssm.LSTMBackbone(action_size, stoch_size, deter_size, hidden_size, embed_size, activation='elu')[source]#
Bases:
BackboneBase
LSTM-based recurrent dynamics backbone.
- Parameters:
action_size (int)
stoch_size (int)
deter_size (int)
hidden_size (int)
embed_size (int)
activation (str)
- property embedding_size: int#
- class world_models.models.modular_rssm.TransformerBackbone(action_size, stoch_size, deter_size, embed_size, num_heads=4, num_layers=2, activation='gelu')[source]#
Bases:
BackboneBase
Transformer-based dynamics backbone for long-range dependencies.
- Parameters:
action_size (int)
stoch_size (int)
deter_size (int)
embed_size (int)
num_heads (int)
num_layers (int)
activation (str)
- property embedding_size: int#
- class world_models.models.modular_rssm.ModularRSSM(encoder, decoder, backbone, reward_decoder=None)[source]#
Bases:
Module
Modular RSSM with swappable encoder, decoder, and backbone.
This class allows researchers to easily experiment with different components:
- Encoders: Conv, MLP, ViT
- Decoders: Conv, MLP
- Backbones: GRU, LSTM, Transformer
Example
>>> from world_models.models.modular_rssm import ConvEncoder, ConvDecoder, GRUBackbone, ModularRSSM
>>> encoder = ConvEncoder((3, 64, 64), embed_size=1024)
>>> decoder = ConvDecoder(32, 200, (3, 64, 64))
>>> backbone = GRUBackbone(action_size=6, stoch_size=32, deter_size=200, hidden_size=200, embed_size=1024)
>>> rssm = ModularRSSM(encoder, decoder, backbone)
- Parameters:
encoder (EncoderBase)
decoder (DecoderBase)
backbone (BackboneBase)
reward_decoder (DecoderBase | None)
- property stoch_size: int#
- property deter_size: int#
- property embed_size: int#
- init_state(batch_size, device)[source]#
- Parameters:
batch_size (int)
device (device)
- Return type:
Dict[str, Tensor]
- observe_step(prev_state, prev_action, obs, nonterm=1.0)[source]#
- Parameters:
prev_state (Dict[str, Tensor])
prev_action (Tensor)
obs (Tensor)
nonterm (Any)
- Return type:
Tuple[Dict[str, Tensor], Dict[str, Tensor]]
- imagine_step(prev_state, prev_action, nonterm=1.0)[source]#
- Parameters:
prev_state (Dict[str, Tensor])
prev_action (Tensor)
nonterm (Any)
- Return type:
Dict[str, Tensor]
- observe_rollout(obs, actions, nonterms, prev_state, horizon)[source]#
- Parameters:
obs (Tensor)
actions (Tensor)
nonterms (Tensor)
prev_state (Dict[str, Tensor])
horizon (int)
- Return type:
Tuple[Dict[str, Tensor], Dict[str, Tensor]]
- world_models.models.modular_rssm.create_modular_rssm(encoder_type='conv', decoder_type='conv', backbone_type='gru', obs_shape=(3, 64, 64), action_size=6, stoch_size=32, deter_size=200, embed_size=1024, hidden_size=200, activation='elu', **kwargs)[source]#
Factory function to create a modular RSSM with specified components.
- Parameters:
encoder_type (str) – Type of encoder (“conv”, “mlp”, “vit”)
decoder_type (str) – Type of decoder (“conv”, “mlp”)
backbone_type (str) – Type of backbone (“gru”, “lstm”, “transformer”)
obs_shape (Tuple[int, int, int] | Tuple[int]) – Shape of observations (C, H, W) for images or (D,) for state
action_size (int) – Action space dimension
stoch_size (int) – Stochastic latent dimension
deter_size (int) – Deterministic hidden dimension
embed_size (int) – Encoder embedding dimension
hidden_size (int) – Hidden layer dimension
activation (str) – Activation function name
- Returns:
Configured ModularRSSM instance
- Return type:
ModularRSSM
- class world_models.models.jepa_agent.JEPAAgent(config=None, **kwargs)[source]#
Bases:
object
Convenience interface for configuring and launching JEPA training runs.
Accepts a JEPAConfig plus keyword overrides, prepares output folders, and delegates execution to the JEPA training entrypoint.
- Parameters:
config (JEPAConfig | None)
- world_models.models.vit.get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)[source]#
Generate fixed 2D sine/cosine positional embeddings on a square patch grid.
Returns NumPy embeddings used to initialize non-trainable transformer position encodings, with optional prepended class-token embedding.
- world_models.models.vit.get_2d_sincos_pos_embed_from_grid(embed_dim, grid)[source]#
Build 2D sine/cosine embeddings from precomputed meshgrid coordinates.
The final embedding concatenates independent encodings for vertical and horizontal coordinates.
- world_models.models.vit.get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)[source]#
Generate 1D sine/cosine positional embeddings for integer positions.
Useful for sequence-style positional encoding and as a building block for 2D embedding construction.
- world_models.models.vit.get_1d_sincos_pos_embed_from_grid(embed_dim, pos)[source]#
Generate 1D sine/cosine positional embeddings from explicit positions.
Positions are projected onto a log-frequency basis and encoded with sine and cosine components.
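The helpers above follow the fixed sine/cosine construction popularized by MAE- and I-JEPA-style ViTs. Below is a NumPy sketch of that construction, assuming frequency base 10000, sines before cosines, and half of the channels per grid axis; the function names are illustrative, and the library's helpers may differ in details such as axis order.

```python
import numpy as np

def get_1d_sincos_from_grid(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    """Encode positions (M,) as (M, embed_dim): sines in the first half, cosines in the second."""
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000 ** omega                        # log-spaced frequencies
    out = np.einsum("m,d->md", pos.reshape(-1), omega)  # (M, embed_dim/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)

def get_2d_sincos(embed_dim: int, grid_size: int, cls_token: bool = False) -> np.ndarray:
    """Half the channels encode the row index, half the column index; optional zero CLS row."""
    assert embed_dim % 2 == 0
    coords = np.arange(grid_size, dtype=np.float64)
    grid_w, grid_h = np.meshgrid(coords, coords)        # each (grid_size, grid_size)
    emb_h = get_1d_sincos_from_grid(embed_dim // 2, grid_h)
    emb_w = get_1d_sincos_from_grid(embed_dim // 2, grid_w)
    emb = np.concatenate([emb_h, emb_w], axis=1)        # (grid_size**2, embed_dim)
    if cls_token:
        emb = np.concatenate([np.zeros((1, embed_dim)), emb], axis=0)
    return emb

pos_embed = get_2d_sincos(embed_dim=16, grid_size=4, cls_token=True)
print(pos_embed.shape)  # (17, 16)
```

Because the table is a deterministic function of position, it can be copied into a non-trainable parameter once at initialization.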
- world_models.models.vit.drop_path(x, drop_prob=0.0, training=False)[source]#
Apply stochastic depth (DropPath) regularization to residual branches.
Randomly drops entire residual paths per sample during training and scales the surviving activations to preserve expected magnitude.
- Parameters:
drop_prob (float)
training (bool)
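A NumPy sketch of stochastic depth as commonly implemented in timm-style ViTs: one Bernoulli keep/drop draw per sample, broadcast over the remaining dimensions, with survivors scaled by 1/keep_prob. `drop_path_sketch` is a hypothetical name; the library's drop_path operates on torch tensors.

```python
from typing import Optional

import numpy as np

def drop_path_sketch(x: np.ndarray, drop_prob: float = 0.0, training: bool = False,
                     rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Stochastic depth: zero whole samples of a residual branch and rescale the rest."""
    if drop_prob == 0.0 or not training:
        return x  # identity at inference or when disabled
    if rng is None:
        rng = np.random.default_rng()
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dimensions.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = (rng.random(mask_shape) < keep_prob).astype(x.dtype)
    return x * mask / keep_prob

branch = np.ones((8, 4, 16))  # (batch, tokens, channels)
out = drop_path_sketch(branch, drop_prob=0.25, training=True, rng=np.random.default_rng(0))
# Each sample is either dropped entirely (0.0) or scaled up to 1 / 0.75.
print(sorted({float(v) for v in out[:, 0, 0]}))
```

Dividing by keep_prob keeps the expected value of the branch output unchanged, so no rescaling is needed at inference time.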
- class world_models.models.vit.DropPath(drop_prob=None)[source]#
Bases:
Module
Module wrapper around the functional drop_path stochastic depth utility.
- class world_models.models.vit.MLP(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>, drop=0.0)[source]#
Bases:
Module
Two-layer feed-forward network used inside transformer blocks.
Applies linear projection, activation, dropout, and output projection in the standard Vision Transformer MLP pattern.
- class world_models.models.vit.Attention(dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)[source]#
Bases:
Module
Multi-head self-attention block for token sequences.
Computes QKV projections, scaled dot-product attention, and output projection with configurable dropout.
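A NumPy sketch of multi-head self-attention follows. It assumes the fused QKV projection layout common in timm/I-JEPA code, with default scale head_dim ** -0.5 when qk_scale is not given. Biases and dropout are omitted, and `multi_head_self_attention` is a hypothetical helper.

```python
from typing import Optional

import numpy as np

def multi_head_self_attention(x, w_qkv, w_proj, num_heads: int, qk_scale: Optional[float] = None):
    """x: (B, N, C); w_qkv: (C, 3C); w_proj: (C, C). Dropout and biases are omitted."""
    B, N, C = x.shape
    head_dim = C // num_heads
    scale = qk_scale or head_dim ** -0.5
    qkv = (x @ w_qkv).reshape(B, N, 3, num_heads, head_dim).transpose(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]                      # each (B, heads, N, head_dim)
    scores = (q @ k.transpose(0, 1, 3, 2)) * scale        # (B, heads, N, N)
    scores -= scores.max(axis=-1, keepdims=True)          # numerical stability
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)              # softmax over keys
    out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, N, C)  # merge heads
    return out @ w_proj, attn

rng = np.random.default_rng(0)
B, N, C, H = 2, 5, 16, 4
x = rng.normal(size=(B, N, C))
out, attn = multi_head_self_attention(x, rng.normal(size=(C, 3 * C)),
                                      rng.normal(size=(C, C)), num_heads=H)
print(out.shape, attn.shape)  # (2, 5, 16) (2, 4, 5, 5)
```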
- class world_models.models.vit.Block(dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]#
Bases:
Module
Transformer encoder block combining attention and MLP residual branches.
Each branch uses pre-normalization and optional stochastic depth.
- class world_models.models.vit.PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)[source]#
Bases:
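Module
Image-to-patch embedding: splits an image into non-overlapping patch_size × patch_size patches and projects each patch to an embed_dim-dimensional vector.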
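With the default signature values (img_size=224, patch_size=16, in_chans=3, embed_dim=768), an image yields (224 / 16)² = 196 patches of 3 · 16 · 16 = 768 values each. The NumPy sketch below flattens patches and applies a linear projection. This is mathematically equivalent to a Conv2d with kernel_size = stride = patch_size, which is how standard ViT code usually implements it; that PatchEmbed does exactly this is an assumption, and `patchify` is a hypothetical helper.

```python
import numpy as np

def patchify(img: np.ndarray, patch_size: int) -> np.ndarray:
    """(C, H, W) image -> (num_patches, C * patch_size**2) flattened patches, row-major order."""
    c, h, w = img.shape
    assert h % patch_size == 0 and w % patch_size == 0
    gh, gw = h // patch_size, w // patch_size
    patches = img.reshape(c, gh, patch_size, gw, patch_size)
    patches = patches.transpose(1, 3, 0, 2, 4)          # (gh, gw, C, p, p)
    return patches.reshape(gh * gw, c * patch_size * patch_size)

rng = np.random.default_rng(0)
img = rng.normal(size=(3, 224, 224))
patches = patchify(img, patch_size=16)                  # (196, 768)
w_embed = rng.normal(size=(patches.shape[1], 768)) * 0.02  # untrained projection
tokens = patches @ w_embed                              # one 768-d embedding per patch
print(patches.shape, tokens.shape)
```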
- class world_models.models.vit.ConvEmbed(channels, strides, img_size=224, in_chans=3, batch_norm=True)[source]#
Bases:
Module
3x3 convolution stem for ViT, following ViTC models.
- class world_models.models.vit.VisionTransformerPredictor(num_patches, embed_dim=768, predictor_embed_dim=384, depth=6, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, init_std=0.02, **kwargs)[source]#
Bases:
Module
Vision Transformer predictor network.
- class world_models.models.vit.VisionTransformer(img_size=[224], patch_size=16, in_chans=3, embed_dim=768, predictor_embed_dim=384, depth=12, predictor_depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, init_std=0.02, **kwargs)[source]#
Bases:
Module
Vision Transformer.
- world_models.models.vit.vit_predictor(**kwargs)[source]#
Factory for a JEPA predictor transformer with sensible defaults.
- world_models.models.vit.vit_tiny(patch_size=16, **kwargs)[source]#
Factory for a tiny Vision Transformer encoder backbone.
- world_models.models.vit.vit_small(patch_size=16, **kwargs)[source]#
Factory for a small Vision Transformer encoder backbone.
- world_models.models.vit.vit_base(patch_size=16, **kwargs)[source]#
Factory for a base Vision Transformer encoder backbone.
- world_models.models.vit.vit_large(patch_size=16, **kwargs)[source]#
Factory for a large Vision Transformer encoder backbone.
- world_models.models.vit.vit_huge(patch_size=16, **kwargs)[source]#
Factory for a huge Vision Transformer encoder backbone.
- world_models.models.vit.vit_giant(patch_size=16, **kwargs)[source]#
Factory for a giant Vision Transformer encoder backbone.
- world_models.models.iris_agent.compute_lambda_return(rewards, values, discounts, lambda_coef=0.95)[source]#
Compute λ-return target for value function training.
- Parameters:
rewards (Tensor) – Rewards (B, T)
values (Tensor) – Value estimates (B, T+1)
discounts (Tensor) – Discount factors (B, T)
lambda_coef (float) – Lambda parameter for bootstrapping
- Returns:
λ-return targets (B, T)
- Return type:
Tensor
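A pure-Python sketch of the standard TD(λ) recursion used for Dreamer-style value targets, for a single sequence (apply it per batch row). It assumes bootstrapping from the final value estimate V_T; whether compute_lambda_return uses exactly this convention is an assumption, and `lambda_return` is a hypothetical helper.

```python
from typing import List

def lambda_return(rewards: List[float], values: List[float],
                  discounts: List[float], lambda_coef: float = 0.95) -> List[float]:
    """λ-returns for one sequence: len(values) == len(rewards) + 1.

    G_t = r_t + d_t * ((1 - λ) * V_{t+1} + λ * G_{t+1}), with G_T = V_T.
    """
    T = len(rewards)
    assert len(values) == T + 1 and len(discounts) == T
    returns = [0.0] * T
    next_return = values[T]  # bootstrap from the final value estimate
    for t in reversed(range(T)):
        next_return = rewards[t] + discounts[t] * (
            (1.0 - lambda_coef) * values[t + 1] + lambda_coef * next_return)
        returns[t] = next_return
    return returns

# λ = 1: discounted return bootstrapped only at the end; λ = 0: one-step TD targets.
print(lambda_return([1.0, 1.0], [0.0, 0.0, 10.0], [0.5, 0.5], lambda_coef=1.0))  # [4.0, 6.0]
print(lambda_return([1.0, 1.0], [0.0, 0.0, 10.0], [0.5, 0.5], lambda_coef=0.0))  # [1.0, 6.0]
```

Intermediate values of lambda_coef, such as the default 0.95, trade the variance of long returns against the bias of the learned value estimates.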
- class world_models.models.iris_agent.IRISAgent(config, action_size, device)[source]#
Bases:
Module
Complete IRIS Agent with world model and policy.
Combines:
- Discrete autoencoder (encoder + decoder)
- Transformer world model
- Actor-Critic for policy and value learning
- Parameters:
config (IRISConfig)
action_size (int)
device (device)
- forward_actor_critic(frames, hidden=None)[source]#
Forward pass through actor-critic.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W)
hidden (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state
- Returns:
action_logits: (B, T, action_size)
values: (B, T)
hidden_state: (h, c)
- act(frame, epsilon=0.0, temperature=1.0)[source]#
Sample action from policy.
- Parameters:
frame (Tensor) – Single frame (B, C, H, W)
epsilon (float) – Random action probability
temperature (float) – Action distribution temperature
- Returns:
Selected actions (B,)
- Return type:
Tensor
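A pure-Python sketch of how epsilon and temperature typically combine when sampling from policy logits: with probability epsilon a uniform random action is taken, otherwise an action is drawn from softmax(logits / temperature). The exact order of operations inside act is an assumption, and `sample_action` is a hypothetical helper working on one logit vector rather than a (B,) batch.

```python
import math
import random
from typing import List, Optional

def sample_action(logits: List[float], epsilon: float = 0.0, temperature: float = 1.0,
                  rng: Optional[random.Random] = None) -> int:
    """Epsilon-greedy exploration on top of temperature-scaled softmax sampling."""
    if rng is None:
        rng = random.Random()
    if rng.random() < epsilon:
        return rng.randrange(len(logits))           # uniform random action
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]     # unnormalized softmax
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

rng = random.Random(0)
logits = [2.0, 0.5, -1.0]
# A low temperature concentrates probability on the highest-logit action (index 0 here).
print([sample_action(logits, temperature=0.05, rng=rng) for _ in range(5)])
```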
- imagine_rollout(initial_frame, horizon=20)[source]#
Generate imagined trajectories using world model.
- Parameters:
initial_frame (Tensor) – Starting frame (B, C, H, W)
horizon (int) – Number of steps to imagine
- Returns:
trajectory: Dictionary with imagined rollout data
- update_autoencoder(frames)[source]#
Update discrete autoencoder.
- Parameters:
frames (Tensor) – Training frames (B, C, H, W)
- Returns:
losses: Dictionary of loss values
- update_transformer(frames, actions, rewards, terminals)[source]#
Update transformer world model.
- Parameters:
frames (Tensor) – Frame sequence
actions (Tensor) – Actions taken
rewards (Tensor) – Rewards received
terminals (Tensor) – Terminal flags
- Returns:
losses: Dictionary of loss values
- class world_models.models.iris_transformer.IRISTransformer(vocab_size=512, tokens_per_frame=16, action_size=18, embed_dim=256, num_layers=10, num_heads=4, dropout=0.1)[source]#
Bases:
Module
Autoregressive Transformer for world modeling.
Models the dynamics of the environment by predicting:
- Next frame tokens (transition model)
- Rewards
- Episode termination
The Transformer operates on sequences of interleaved frame tokens and actions.
- Parameters:
vocab_size (int)
tokens_per_frame (int)
action_size (int)
embed_dim (int)
num_layers (int)
num_heads (int)
dropout (float)
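One plausible interleaved layout, following the token order described in the IRIS paper (the K tokens of frame t, then action a_t), gives a sequence of length T · (K + 1). The pure-Python sketch below tags tokens as "frame" or "action"; a real model would instead look them up in separate embedding tables. The exact layout used inside IRISTransformer.forward is an assumption, and `interleave` is a hypothetical helper.

```python
from typing import List, Tuple

def interleave(frame_tokens: List[List[int]], actions: List[int]) -> List[Tuple[str, int]]:
    """Flatten per-step frame tokens (T, K) and actions (T,) into one sequence.

    Layout per timestep: K frame tokens followed by one action token, so the
    sequence length is T * (K + 1).
    """
    assert len(frame_tokens) == len(actions)
    seq: List[Tuple[str, int]] = []
    for tokens_t, action_t in zip(frame_tokens, actions):
        seq.extend(("frame", tok) for tok in tokens_t)
        seq.append(("action", action_t))
    return seq

tokens = [[5, 9, 2, 7], [5, 1, 2, 7]]  # T=2 steps, K=4 tokens per frame (toy values)
actions = [3, 0]
seq = interleave(tokens, actions)
print(len(seq))   # 10
print(seq[:5])    # [('frame', 5), ('frame', 9), ('frame', 2), ('frame', 7), ('action', 3)]
```

With the default tokens_per_frame=16, each timestep would occupy 17 positions in the sequence.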
- forward(tokens, actions, mask=None)[source]#
Forward pass through the Transformer world model.
- Parameters:
tokens (Tensor) – Frame tokens (B, T, K) where T is timesteps
actions (Tensor) – Actions (B, T)
mask (Tensor | None) – Optional attention mask
- Returns:
token_logits: Next token predictions (B, T, K, vocab_size)
rewards: Predicted rewards (B, T)
terminations: Predicted terminations (B, T, 2)
- predict_next_tokens(tokens, actions)[source]#
Predict the next frame tokens autoregressively.
Used during imagination rollouts.
- Parameters:
tokens (Tensor) – Current frame tokens (B, K)
actions (Tensor) – Actions taken (B,)
- Returns:
token_logits: Next frame token predictions (B, K, vocab_size)
action_hidden: Hidden states for reward prediction (B, embed_dim)
- sample_next_tokens(tokens, actions, temperature=1.0)[source]#
Sample next tokens from the distribution.
- Parameters:
tokens (Tensor) – Current frame tokens (B, K)
actions (Tensor) – Actions taken (B,)
temperature (float) – Sampling temperature (higher = more random)
- Returns:
sampled_tokens: Sampled token indices (B, K)
log_probs: Log probabilities of sampled tokens (B, K)
- class world_models.models.iris_transformer.IRISWorldModel(encoder, decoder, transformer)[source]#
Bases:
Module
Complete IRIS World Model combining autoencoder and transformer.
This is the core component that learns environment dynamics entirely in the “imaginary” latent space.
- Parameters:
encoder (Module)
decoder (Module)
transformer (IRISTransformer)
- forward(observations, actions)[source]#
Full world model forward pass.
- Parameters:
observations (Tensor) – Image sequence (B, T+1, C, H, W)
actions (Tensor) – Actions (B, T)
- Returns:
predictions: Dictionary with predicted tokens, rewards, terminations
losses: Dictionary with loss components
- imagine(initial_tokens, policy, horizon=20, temperature=1.0)[source]#
Generate imagined trajectories.
- Parameters:
initial_tokens (Tensor) – Initial frame tokens (B, K)
policy (Module) – Policy network to sample actions
horizon (int) – Number of steps to imagine
temperature (float) – Sampling temperature for token prediction
- Returns:
imagined: Dictionary with imagined trajectories
- class world_models.models.genie.Genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=512, tokenizer_decoder_dim=1024, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=1024, action_decoder_dim=1024, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, encoder_depth=12, decoder_depth=20, latent_action_depth=20, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#
Bases:
Module
Genie: Generative Interactive Environment.
A generative model trained from video-only data that can be used as an interactive environment. It contains three key components:
1. Video Tokenizer: Converts raw video frames into discrete tokens
2. Latent Action Model (LAM): Infers latent actions between frames
3. Dynamics Model: Predicts future frames given past frames and latent actions
Based on the paper “Genie: Generative Interactive Environments” (arXiv:2402.15391).
Training follows the two phases described in the paper:
1. Train the video tokenizer first (on raw video frames)
2. Co-train the LAM (from pixels) and the dynamics model (on video tokens)
The LAM uses VQ-VAE training with:
- Encoder: Takes x_{1:t} and x_{t+1} → outputs latent actions
- Decoder: Takes x_{1:t-1} (masked) + actions → reconstructs x_t
- Auxiliary variance loss to prevent action collapse
Latent actions are passed to the dynamics model with a stop-gradient (detached).
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_decoder_dim (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
encoder_depth (int)
decoder_depth (int)
latent_action_depth (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- forward(video, mask_prob=0.5, training_phase='all')[source]#
Full forward pass through all components.
- Parameters:
video (Tensor) – (B, C, T, H, W) input video
mask_prob (float) – Probability for random masking in dynamics (0.5-1.0)
training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”
- Returns:
Dictionary containing losses and predictions
- Return type:
Dict[str, Tensor]
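The mask_prob argument controls random masking of the tokens fed to the dynamics model, in the MaskGIT style of predicting masked tokens. A pure-Python sketch of per-token Bernoulli masking follows. Treating mask_prob as a fixed per-token rate is a simplification (how forward() chooses or samples the rate is not specified here), and `MASK_ID` and `mask_tokens` are hypothetical names.

```python
import random
from typing import List, Optional

MASK_ID = -1  # hypothetical placeholder id for a masked token

def mask_tokens(tokens: List[int], mask_prob: float, rng: Optional[random.Random] = None):
    """Replace each token with MASK_ID independently with probability mask_prob.

    Returns the masked sequence and a boolean list marking the masked positions,
    which is where a MaskGIT-style training loss would be computed.
    """
    if rng is None:
        rng = random.Random()
    is_masked = [rng.random() < mask_prob for _ in tokens]
    masked = [MASK_ID if m else tok for tok, m in zip(tokens, is_masked)]
    return masked, is_masked

rng = random.Random(0)
frame_tokens = [rng.randrange(1024) for _ in range(64)]  # vocab of 1024, as in tokenizer_vocab_size
masked, is_masked = mask_tokens(frame_tokens, mask_prob=0.5, rng=rng)
print(sum(is_masked), "of", len(frame_tokens), "tokens masked")
```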
- training_step(video, mask_prob=0.5, training_phase='all')[source]#
Single training step computing all losses.
- Parameters:
video (Tensor) – (B, C, T, H, W) input video
mask_prob (float) – Probability for random masking in dynamics
training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”
- Returns:
Dictionary containing all losses for backpropagation
- Return type:
Dict[str, Tensor]
- encode_video(video)[source]#
Encode video to discrete tokens.
- Parameters:
video (Tensor) – (B, C, T, H, W)
- Returns:
video_tokens: (B, T, H*W)
- infer_actions(frames)[source]#
Infer latent actions from a sequence of frames.
- Parameters:
frames (Tensor) – (B, C, T, H, W) video frames
- Returns:
latent_actions: (B, T-1) inferred latent action indices
- generate(prompt_frame, num_frames=16, actions=None, use_maskgit=True)[source]#
Generate video frames given a prompt frame and actions.
- Parameters:
prompt_frame (Tensor) – (B, C, H, W) initial frame
num_frames (int) – Total number of frames to generate
actions (Tensor | None) – (B, num_frames-1) latent action indices, or None for random
use_maskgit (bool) – Whether to use MaskGIT sampling
- Returns:
generated_video: (B, C, num_frames, H, W)
- play(current_frame, action, current_frames=None)[source]#
Play step: generate the next frame given the current frame and action.
- Parameters:
current_frame (Tensor) – (B, C, H, W) current frame
action (Tensor) – (B,) latent action indices
current_frames (Tensor | None) – (B, C, T, H, W) history frames, or None for first frame
- Returns:
next_frame: (B, C, H, W)
- world_models.models.genie.create_genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, action_vocab_size=8, action_embedding_dim=32, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#
Factory function to create a Genie model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
action_vocab_size (int)
action_embedding_dim (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
Genie
- world_models.models.genie.create_genie_small(num_frames=16, image_size=64, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#
Create a smaller Genie model for development/testing.
- Parameters:
num_frames (int)
image_size (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
Genie
- world_models.models.genie.create_genie_large(num_frames=16, image_size=64, use_bfloat16=True, action_pooling='mean', window_attention_heads=1)[source]#
Create an approximation of the full 11B-parameter Genie model.
- Parameters:
num_frames (int)
image_size (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- class world_models.models.latent_action_model.LatentActionModel(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#
Bases:
Module
Latent Action Model (LAM) for unsupervised action learning.
Learns discrete latent actions from unlabeled video frames using a VQ-VAE-based objective. The model infers latent actions between frames that encode the changes most useful for predicting future frames.
Based on the Genie paper: learns actions from Internet videos without action labels.
Components:
- Encoder: takes all previous frames x_{1:t} and the next frame x_{t+1} and outputs latent actions.
- Decoder: takes previous frames x_{1:t-1} and latent actions a_{1:t-1} and predicts the next frame x_t.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- encode(x_prev, x_next)[source]#
Encode frames to latent actions.
- Parameters:
x_prev (Tensor) – Previous frames (B, C, T, H, W)
x_next (Tensor) – Next frame (B, C, H, W)
- Returns:
latent_actions – Discrete latent action indices (B, T)
z_q – Quantized embeddings (B, T, embedding_dim)
- world_models.models.latent_action_model.create_latent_action_model(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, action_pooling='mean', window_attention_heads=1)[source]#
Factory function to create a Latent Action Model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- class world_models.models.dynamics_model.MaskGITSampler(num_steps=25, temperature=2.0, mask_schedule='cosine')[source]#
Bases:
object
MaskGIT sampling for token-based video generation.
Uses iterative refinement with a mask schedule to progressively reveal tokens during generation.
- Parameters:
num_steps (int)
temperature (float)
mask_schedule (str)
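As a rough illustration of a cosine mask schedule, the sketch below computes how many of N tokens remain masked after each of num_steps refinement steps and performs one confidence-based reveal step. It assumes the MaskGIT convention (Chang et al., 2022) of keeping a fraction cos(π/2 · r) of tokens masked at progress r, and it applies temperature to the sampling softmax as a simplification. The helpers cosine_mask_counts and reveal_step are hypothetical and are not taken from MaskGITSampler's source.

```python
import math

import torch


def cosine_mask_counts(num_tokens: int, num_steps: int) -> list[int]:
    """Hypothetical helper: number of tokens still masked after each refinement step."""
    counts = []
    for step in range(num_steps):
        progress = (step + 1) / num_steps
        # Cosine schedule: the masked fraction decays from ~1 to 0.
        counts.append(math.floor(num_tokens * math.cos(math.pi / 2 * progress)))
    return counts


def reveal_step(logits, tokens, mask, num_to_keep_masked, temperature=1.0):
    """One MaskGIT-style step: sample every masked token, then re-mask the least confident.

    logits: (N, V); tokens: (N,) long; mask: (N,) bool, True = still masked.
    """
    probs = torch.softmax(logits / temperature, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    confidence = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
    # Already revealed tokens keep their value and can never be re-masked.
    tokens = torch.where(mask, sampled, tokens)
    confidence = torch.where(mask, confidence, torch.full_like(confidence, float("inf")))
    new_mask = torch.zeros_like(mask)
    if num_to_keep_masked > 0:
        lowest = torch.topk(confidence, num_to_keep_masked, largest=False).indices
        new_mask[lowest] = True
    return tokens, new_mask
```

In a full sampler, the random logits would be replaced by the dynamics model's predictions for the current partially masked frame, and the loop would run once per entry returned by cosine_mask_counts.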
- class world_models.models.dynamics_model.DynamicsModel(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases:
Module
Dynamics Model for action-controllable video generation.
A decoder-only transformer that predicts future frame tokens given past frame tokens and latent actions. Uses MaskGIT for training and sampling.
Based on the Genie paper: uses a cross-entropy loss with random masking during training and MaskGIT iterative refinement at inference.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- forward(video_tokens, actions, mask_prob=0.0)[source]#
Forward pass for training.
- Parameters:
video_tokens (Tensor) – (B, T, H*W) - token indices for frames 1 to T
actions (Tensor) – (B, T) - latent action indices for frames 1 to T
mask_prob (float) – Bernoulli probability of masking each input token (the Genie paper samples this rate uniformly from 0.5 to 1.0 during training)
- Returns:
logits – (B, T, H*W, vocab_size)
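The training-time masking described for forward can be sketched as follows: draw one masking rate per sample uniformly from [0.5, 1.0] (as in the Genie paper), replace masked positions with a reserved mask-token id, and compute cross-entropy only on the masked positions. The helper names, the choice of vocab_size as the mask-token id, and the mean-over-masked-tokens reduction are assumptions for illustration, not DynamicsModel's actual code.

```python
import torch
import torch.nn.functional as F


def mask_tokens(video_tokens: torch.Tensor, mask_token_id: int,
                low: float = 0.5, high: float = 1.0):
    """Hypothetical helper: mask tokens with a per-sample Bernoulli rate drawn from U(low, high).

    video_tokens: (B, T, N) long tensor of codebook indices.
    Returns the masked input and a bool mask (True where a token was hidden).
    """
    B = video_tokens.shape[0]
    rate = torch.empty(B, 1, 1, device=video_tokens.device).uniform_(low, high)
    mask = torch.rand(video_tokens.shape, device=video_tokens.device) < rate
    return video_tokens.masked_fill(mask, mask_token_id), mask


def masked_cross_entropy(logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor):
    """Cross-entropy averaged over masked positions only.

    logits: (B, T, N, V); targets: (B, T, N); mask: (B, T, N) bool.
    """
    loss = F.cross_entropy(logits.flatten(0, 2), targets.flatten(), reduction="none")
    weights = mask.flatten().float()
    return (loss * weights).sum() / weights.sum().clamp(min=1.0)
```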
- sample(prompt_tokens, prompt_actions, num_frames, sampler=None)[source]#
Sample future frames using MaskGIT.
- Parameters:
prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens
prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames
num_frames (int) – Total number of frames to generate
sampler (MaskGITSampler | None) – MaskGIT sampler instance
- Returns:
generated_tokens – (B, num_frames, N)
- autoregressive_sample(prompt_tokens, prompt_actions, num_frames, temperature=1.0)[source]#
Simple autoregressive sampling (token by token).
- Parameters:
prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens
prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames
num_frames (int) – Total number of frames to generate
temperature (float) – Sampling temperature
- Returns:
generated_tokens – (B, num_frames, N)
- world_models.models.dynamics_model.create_dynamics_model(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4)[source]#
Factory function to create a Dynamics Model.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
- Return type:
Diffusion and DIAMOND components#
Key classes: DDPM, DiT, DiffusionUNet, EDMPreconditioner, EulerSampler, RewardTerminationModel, and ActorCriticNetwork.
Diffusion sub-module - Diffusion model components for world models.
- Exported Components:
DiT: Diffusion Transformer model
PatchEmbed: Image patch embedding
PatchUnEmbed: Patch unembedding (decode tokens to image)
DDPM: Denoising Diffusion Probabilistic Model implementation
ActorCriticNetwork: DIAMOND actor-critic network
RewardTerminationModel: Reward/termination prediction model
sinusoidal_time_embedding: Time embedding for diffusion models
- class world_models.models.diffusion.DDPM.DDPM(timesteps, beta_start, beta_end, device)[source]#
Bases:
object
Utility class implementing forward and reverse DDPM diffusion steps.
Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).
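In standard DDPM (Ho et al., 2020) the forward noising step behind a helper like q_sample is x_t = sqrt(ᾱ_t) · x_0 + sqrt(1 − ᾱ_t) · ε, where ᾱ_t is the cumulative product of (1 − β_s). The sketch below assumes a linear β schedule between beta_start and beta_end; the function names are hypothetical, and the class may precompute or parameterize its schedule differently.

```python
import torch


def linear_alpha_bars(timesteps: int, beta_start: float, beta_end: float) -> torch.Tensor:
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule (hypothetical helper)."""
    betas = torch.linspace(beta_start, beta_end, timesteps)
    return torch.cumprod(1.0 - betas, dim=0)


def q_sample(x0: torch.Tensor, t: torch.Tensor, alpha_bars: torch.Tensor, noise=None):
    """Noise clean inputs x0 (B, ...) to integer timesteps t (B,)."""
    if noise is None:
        noise = torch.randn_like(x0)
    # Reshape (B,) coefficients so they broadcast over all non-batch dimensions.
    a = alpha_bars[t].view(-1, *([1] * (x0.dim() - 1)))
    return a.sqrt() * x0 + (1.0 - a).sqrt() * noise
```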
- world_models.models.diffusion.DiT.sinusoidal_time_embedding(timesteps, dim)[source]#
Create sinusoidal timestep embeddings for diffusion conditioning.
This function generates positional-style embeddings for diffusion timesteps, following the same pattern as transformer positional encodings. The embeddings encode the noise level (t) and are used to condition the diffusion model.
- Math:
embedding[t] = [sin(t/10000^(2i/d)), cos(t/10000^(2i/d))] for i in [0, d/2)
- Parameters:
timesteps – Tensor of integer timesteps, shape (B,) or (B, 1)
dim – Embedding dimension (must be even)
- Returns:
Tensor of shape (B, dim) with sinusoidal embeddings
- Usage with DiT:
t = torch.tensor([0, 500, 1000])  # Timesteps
emb = sinusoidal_time_embedding(t, dim=256)  # (3, 256)
# Condition the model:
# - Add the timestep embedding to the MLP input
# - Use AdaLN for adaptive normalization
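A minimal implementation consistent with the formula above is sketched below. It assumes the common layout that concatenates all sine terms before all cosine terms (rather than interleaving them) and casts timesteps to float; sinusoidal_time_embedding_sketch is a hypothetical name, and the library function's exact layout may differ.

```python
import math

import torch


def sinusoidal_time_embedding_sketch(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    """(B,) or (B, 1) timesteps -> (B, dim) embedding [sin(t * w_i), cos(t * w_i)]."""
    assert dim % 2 == 0, "dim must be even"
    t = timesteps.float().reshape(-1, 1)                    # (B, 1)
    i = torch.arange(dim // 2, dtype=torch.float32)         # (dim/2,)
    freqs = torch.exp(-math.log(10000.0) * 2 * i / dim)     # w_i = 10000^(-2i/d)
    angles = t * freqs                                      # (B, dim/2)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
```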
- class world_models.models.diffusion.DiT.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]#
Bases:
Module
Patchify an image into a sequence of learnable patch tokens.
Used in Vision Transformers (ViT) and DiT to convert 2D images into sequences of token embeddings that can be processed by transformers.
- Process:
Conv2d with kernel_size=stride=patch_size extracts non-overlapping patches
The same Conv2d acts as the linear projection of each patch to embed_dim
Learnable positional embeddings are added for spatial information
Input: (B, C, H, W) images
Output: (B, N, embed_dim), where N = (H/patch_size) * (W/patch_size)
- Parameters:
img_size – Image size (assumes square), e.g., 32 for CIFAR
patch_size – Size of each patch (typically 4, 8, or 16)
in_channels – Number of input channels (3 for RGB)
embed_dim – Output dimension for each patch token
- Usage with DiT:
patch_embed = PatchEmbed(img_size=32, patch_size=4, in_channels=3, embed_dim=256)
tokens = patch_embed(images)  # (B, 64, 256) for a 32x32 image with patch_size=4
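The shape bookkeeping above can be reproduced with a single strided Conv2d and a learnable positional table. This is an assumed reconstruction for illustration only: PatchEmbedSketch is a hypothetical name, and initializing the positional embedding to zeros is a simplification.

```python
import torch
import torch.nn as nn


class PatchEmbedSketch(nn.Module):
    def __init__(self, img_size: int, patch_size: int, in_channels: int, embed_dim: int):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # kernel_size == stride == patch_size: each output position sees exactly one patch.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)                   # (B, D, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)   # (B, N, D)
        return x + self.pos
```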
- class world_models.models.diffusion.DiT.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]#
Bases:
Module
Reconstruct image-like tensors from patch-token sequences.
The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.
- class world_models.models.diffusion.DiT.TransformerBlock(d_model, n_heads, mlp_ratio, drop, t_dim)[source]#
Bases:
Module
Conditioned transformer block used inside the DiT backbone.
Each block applies adaptive layer-normalized self-attention and MLP residual updates conditioned on timestep embeddings.
- class world_models.models.diffusion.DiT.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]#
Bases:
Module
Diffusion Transformer model for image denoising and generation.
The module maps noisy images and timesteps to predicted noise and also provides a classmethod training entry point for common datasets.
- classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]#
- class world_models.models.diffusion.diamond_diffusion.AdaptiveGroupNorm(num_groups, num_channels, cond_dim)[source]#
Bases:
Module
Adaptive Group Normalization that conditions on actions and diffusion time.
- Parameters:
num_groups (int)
num_channels (int)
cond_dim (int)
- class world_models.models.diffusion.diamond_diffusion.ResBlock(in_channels, out_channels, cond_dim, dropout=0.0)[source]#
Bases:
Module
Residual block with adaptive group normalization.
- Parameters:
in_channels (int)
out_channels (int)
cond_dim (int)
dropout (float)
- class world_models.models.diffusion.diamond_diffusion.AttentionBlock(channels, cond_dim)[source]#
Bases:
Module
Self-attention block for U-Net.
- Parameters:
channels (int)
cond_dim (int)
- class world_models.models.diffusion.diamond_diffusion.TimestepEmbedding(dim, freq_dim=256)[source]#
Bases:
Module
Sinusoidal timestep embedding.
- Parameters:
dim (int)
freq_dim (int)
- class world_models.models.diffusion.diamond_diffusion.DownBlock(in_channels, out_channels, cond_dim, num_res_blocks=2, attention=False)[source]#
Bases:
Module
Downsampling block for U-Net encoder.
- Parameters:
in_channels (int)
out_channels (int)
cond_dim (int)
num_res_blocks (int)
attention (bool)
- class world_models.models.diffusion.diamond_diffusion.UpBlock(in_channels, out_channels, cond_dim, num_res_blocks=2, attention=False)[source]#
Bases:
Module
Upsampling block for U-Net decoder.
- Parameters:
in_channels (int)
out_channels (int)
cond_dim (int)
num_res_blocks (int)
attention (bool)
- class world_models.models.diffusion.diamond_diffusion.DiffusionUNet(obs_channels=3, num_conditioning_frames=4, base_channels=64, channel_multipliers=(1, 1, 1, 1), num_res_blocks=2, cond_dim=256, action_dim=18)[source]#
Bases:
Module
U-Net architecture for the EDM diffusion world model. Uses frame stacking for observation conditioning and adaptive group normalization for action conditioning.
- Parameters:
obs_channels (int)
num_conditioning_frames (int)
base_channels (int)
channel_multipliers (Tuple[int, ...])
num_res_blocks (int)
cond_dim (int)
action_dim (int)
- forward(x, t, obs_history, actions)[source]#
Forward pass of the diffusion model.
- Parameters:
x (Tensor) – Noised observation at timestep t [B, C, H, W]
t (Tensor) – Diffusion timestep [B]
obs_history (Tensor) – Past observations for conditioning [B, L, C, H, W]
actions (Tensor) – Past actions [B, L]
- Returns:
Predicted clean observation [B, C, H, W]
- Return type:
Tensor
- class world_models.models.diffusion.diamond_diffusion.EDMPreconditioner(sigma_data=0.5, p_mean=-0.4, p_std=1.2)[source]#
Bases:
object
EDM preconditioner following Karras et al. (2022).
- Parameters:
sigma_data (float)
p_mean (float)
p_std (float)
- get_preconditioners(sigma)[source]#
Compute EDM preconditioners for given noise levels.
- Returns:
Dictionary with c_skip, c_out, c_in, c_noise
- Parameters:
sigma (Tensor)
- Return type:
dict
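For reference, the preconditioners defined by Karras et al. (2022) are c_skip = σ_data² / (σ² + σ_data²), c_out = σ · σ_data / sqrt(σ² + σ_data²), c_in = 1 / sqrt(σ² + σ_data²) and c_noise = ln(σ) / 4, and training noise levels are drawn with ln σ ~ N(p_mean, p_std²). The sketch below writes these out and applies them as D(x; σ) = c_skip · x + c_out · F(c_in · x, c_noise). It is an assumption that EDMPreconditioner uses exactly these expressions, and the function names are hypothetical.

```python
import torch


def edm_preconditioners(sigma: torch.Tensor, sigma_data: float = 0.5) -> dict:
    """EDM scalings for noise levels sigma of shape (B,)."""
    denom = sigma ** 2 + sigma_data ** 2
    return {
        "c_skip": sigma_data ** 2 / denom,
        "c_out": sigma * sigma_data / denom.sqrt(),
        "c_in": 1.0 / denom.sqrt(),
        "c_noise": sigma.log() / 4.0,
    }


def sample_log_normal_sigma(batch_size: int, p_mean: float = -0.4, p_std: float = 1.2) -> torch.Tensor:
    """Draw sigma with ln(sigma) ~ N(p_mean, p_std^2)."""
    return torch.exp(p_mean + p_std * torch.randn(batch_size))


def edm_denoise(model, x: torch.Tensor, sigma: torch.Tensor, sigma_data: float = 0.5, **cond):
    """D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise) for x of shape (B, C, H, W)."""
    c = {k: v.view(-1, 1, 1, 1) for k, v in edm_preconditioners(sigma, sigma_data).items()}
    return c["c_skip"] * x + c["c_out"] * model(c["c_in"] * x, c["c_noise"].flatten(), **cond)
```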
- sample_noise_level(batch_size, device)[source]#
Sample per-example noise levels from a log-normal distribution.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tensor
- denoise(model, x, sigma, **kwargs)[source]#
Apply EDM denoising with preconditioners.
- Parameters:
model – Diffusion model
x (Tensor) – Noised input [B, C, H, W]
sigma (Tensor) – Noise level [B]
**kwargs – Additional conditioning (obs_history, actions)
- Returns:
Denoised prediction [B, C, H, W]
- Return type:
Tensor
- class world_models.models.diffusion.diamond_diffusion.EulerSampler(sigma_min=0.002, sigma_max=80.0, rho=7, num_steps=3, edm_precond=None)[source]#
Bases:
object
Euler method sampler for reverse diffusion.
- Parameters:
sigma_min (float)
sigma_max (float)
rho (int)
num_steps (int)
edm_precond (EDMPreconditioner | None)
- sample(model, shape, device, obs_history=None, actions=None)[source]#
Generate samples using Euler method.
- Parameters:
model (Module) – Diffusion model
shape (Tuple[int, ...]) – Output shape [B, C, H, W]
device (device) – Device to run on
obs_history (Tensor | None) – Conditioning observations [B, L, C, H, W]
actions (Tensor | None) – Conditioning actions [B, L]
- Returns:
Generated samples [B, C, H, W]
- Return type:
Tensor
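The sigma_min, sigma_max and rho arguments match the Karras et al. (2022) noise schedule σ_i = (σ_max^(1/ρ) + i/(N−1) · (σ_min^(1/ρ) − σ_max^(1/ρ)))^ρ. The sketch below builds that schedule and takes plain Euler steps on the probability-flow ODE dx/dσ = (x − D(x; σ)) / σ, ending at σ = 0. It assumes a denoiser callable denoise(x, sigma) that returns an estimate of the clean sample and omits the obs_history/actions conditioning; it is not EulerSampler's actual code, and both function names are hypothetical.

```python
import torch


def karras_sigmas(num_steps: int, sigma_min: float = 0.002, sigma_max: float = 80.0,
                  rho: float = 7.0) -> torch.Tensor:
    """Decreasing noise levels sigma_0 > ... > sigma_{N-1}, followed by a final 0."""
    i = torch.arange(num_steps, dtype=torch.float64)
    inv_max, inv_min = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    sigmas = (inv_max + i / max(num_steps - 1, 1) * (inv_min - inv_max)) ** rho
    return torch.cat([sigmas, sigmas.new_zeros(1)])


@torch.no_grad()
def euler_sample(denoise, shape, num_steps: int = 3, **schedule_kwargs) -> torch.Tensor:
    """denoise(x, sigma) -> clean-sample estimate, with sigma given as a (B,) tensor."""
    sigmas = karras_sigmas(num_steps, **schedule_kwargs).float()
    x = torch.randn(shape) * sigmas[0]                        # start from pure noise
    for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoise(x, s_cur.expand(shape[0]))) / s_cur  # dx/dsigma
        x = x + (s_next - s_cur) * d                          # Euler step
    return x
```

Because the last step goes to σ = 0, the final Euler update returns exactly the denoiser's estimate at the smallest nonzero noise level.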
- class world_models.models.diffusion.reward_termination.ConvBlock(in_channels, out_channels, cond_dim, stride=2)[source]#
Bases:
Module
Convolutional block with adaptive group normalization.
- Parameters:
in_channels (int)
out_channels (int)
cond_dim (int)
stride (int)
- class world_models.models.diffusion.reward_termination.RewardTerminationModel(obs_channels=3, action_dim=18, channels=(32, 32, 32, 32), lstm_dim=512, cond_dim=128)[source]#
Bases:
Module
Reward and termination prediction model. CNN + LSTM architecture following the DIAMOND paper specifications.
- Parameters:
obs_channels (int) – Number of observation channels (3 for RGB)
action_dim (int) – Number of possible actions
channels (Tuple[int, ...]) – List of channel sizes for conv blocks
lstm_dim (int) – LSTM hidden dimension
cond_dim (int) – Conditioning dimension for adaptive norm
- forward(obs, actions, hidden_state=None)[source]#
Forward pass of reward/termination model.
- Parameters:
obs (Tensor) – Observations [B, T, C, H, W]
actions (Tensor) – Actions [B, T]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
reward_logits – Reward predictions [B, T, 3] (classes for rewards -1, 0, 1)
termination_logits – Termination predictions [B, T, 2]
hidden_state – Updated (h, c) hidden states
- predict(obs, actions, hidden_state=None)[source]#
Predict reward and termination for a single step.
- Parameters:
obs (Tensor) – Single observation [B, C, H, W]
actions (Tensor) – Single action [B]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
reward – Predicted reward classes as a tensor (values -1, 0, 1)
terminated – Predicted termination (bool tensor)
hidden_state – Updated (h, c) hidden states
Initialize LSTM hidden states.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- class world_models.models.diffusion.reward_termination.RewardTerminationLoss[source]#
Bases:
Module
Loss function for reward and termination prediction.
- forward(reward_logits, termination_logits, rewards, terminated)[source]#
Compute loss for reward and termination predictions.
- Parameters:
reward_logits (Tensor) – [B, T, 3]
termination_logits (Tensor) – [B, T, 2]
rewards (Tensor) – Rewards as class indices [B, T] (values -1, 0, 1 mapped to 0, 1, 2)
terminated (Tensor) – Termination flags [B, T]
- Returns:
total_loss, reward_loss, termination_loss
- Return type:
Tuple[Tensor, Tensor, Tensor]
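One plausible reading of this loss is two cross-entropy terms: one over the three reward classes (the sign of the reward, shifted from {-1, 0, 1} to {0, 1, 2}) and one over the two termination classes. The sketch below assumes an unweighted sum with mean reduction and maps raw rewards to classes through their sign; reward_to_class and reward_termination_loss are hypothetical names, not the class's methods.

```python
import torch
import torch.nn.functional as F


def reward_to_class(rewards: torch.Tensor) -> torch.Tensor:
    """Map raw rewards to class indices via their sign: -1 -> 0, 0 -> 1, +1 -> 2."""
    return (torch.sign(rewards) + 1).long()


def reward_termination_loss(reward_logits, termination_logits, reward_classes, terminated):
    """reward_logits (B, T, 3), termination_logits (B, T, 2), targets (B, T)."""
    reward_loss = F.cross_entropy(reward_logits.flatten(0, 1), reward_classes.flatten())
    termination_loss = F.cross_entropy(termination_logits.flatten(0, 1),
                                       terminated.long().flatten())
    return reward_loss + termination_loss, reward_loss, termination_loss
```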
- class world_models.models.diffusion.actor_critic.ActorCriticNetwork(obs_channels=3, action_dim=18, channels=(32, 32, 64, 64), lstm_dim=512)[source]#
Bases:
Module
Actor-Critic network for DIAMOND RL training. Shared CNN-LSTM trunk with separate policy and value heads.
- Parameters:
obs_channels (int)
action_dim (int)
channels (Tuple[int, ...])
lstm_dim (int)
- forward(obs, hidden_state=None)[source]#
Forward pass of actor-critic network.
- Parameters:
obs (Tensor) – Observations [B, T, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
policy_logits – [B, T, action_dim]
values – [B, T, 1]
hidden_state – (h, c)
- get_action(obs, hidden_state=None, deterministic=False)[source]#
Get action from a single observation.
- Parameters:
obs (Tensor) – Single observation [B, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
deterministic (bool) – If True, take argmax; else sample
- Returns:
action – Selected action [B]
hidden_state – (h, c)
- get_actions(obs, hidden_state=None, deterministic=False)[source]#
Batched version of get_action.
- Parameters:
obs (Tensor) – Tensor of shape [B, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state tuple matching batch size
deterministic (bool) – If True, take argmax; else sample from policy
- Returns:
LongTensor of shape [B] (selected actions)
hidden_state – updated LSTM hidden state tuple
- Return type:
- get_value(obs, hidden_state=None)[source]#
Get value for a single observation.
- Parameters:
obs (Tensor)
hidden_state (Tuple[Tensor, Tensor] | None)
- Return type:
Tuple[Tensor, Tuple[Tensor, Tensor] | None]
Initialize LSTM hidden states.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
Get LSTM hidden size.
- Return type:
int
- class world_models.models.diffusion.actor_critic.RLLoss(discount_factor=0.985, lambda_returns=0.95, entropy_weight=0.001)[source]#
Bases:
Module
RL loss functions for DIAMOND. Implements REINFORCE with value baseline and λ-returns.
- Parameters:
discount_factor (float)
lambda_returns (float)
entropy_weight (float)
- compute_lambda_returns(rewards, values, dones)[source]#
Compute λ-returns.
- Parameters:
rewards (Tensor) – [B, T]
values (Tensor) – [B, T+1]
dones (Tensor) – [B, T]
- Returns:
[B, T]
- Return type:
- policy_loss(policy_logits, actions, lambda_returns, values)[source]#
Compute policy loss with REINFORCE and entropy regularization.
- Parameters:
policy_logits (Tensor) – [B, T, A]
actions (Tensor) – [B, T]
lambda_returns (Tensor) – [B, T]
values (Tensor) – [B, T+1]
- Returns:
policy_loss – scalar
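The λ-return recursion described in the DIAMOND paper is G_t = r_t + γ(1 − d_t)[(1 − λ)V_{t+1} + λG_{t+1}], bootstrapped from the final value. The sketch below implements it for the documented shapes (values has T+1 entries) together with a REINFORCE loss that uses the detached advantage G_t − V_t as the baseline-corrected return plus an entropy bonus. Defaults mirror RLLoss's signature; the helper names and the mean reductions are assumptions, not RLLoss's code.

```python
import torch
import torch.nn.functional as F


def lambda_returns(rewards, values, dones, gamma: float = 0.985, lam: float = 0.95):
    """rewards, dones: (B, T); values: (B, T+1) -> lambda-returns (B, T)."""
    T = rewards.shape[1]
    not_done = 1.0 - dones.float()
    returns = torch.empty_like(rewards)
    next_return = values[:, T]                              # bootstrap from the final value
    for t in reversed(range(T)):
        blended = (1 - lam) * values[:, t + 1] + lam * next_return
        next_return = rewards[:, t] + gamma * not_done[:, t] * blended
        returns[:, t] = next_return
    return returns


def reinforce_loss(policy_logits, actions, returns, values, entropy_weight: float = 0.001):
    """policy_logits (B, T, A); actions, returns (B, T); values (B, T+1)."""
    log_probs = F.log_softmax(policy_logits, dim=-1)
    taken = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)   # log pi(a_t | s_t)
    advantage = (returns - values[:, :-1]).detach()                    # value baseline V(s_t)
    entropy = -(log_probs.exp() * log_probs).sum(-1)
    return -(taken * advantage).mean() - entropy_weight * entropy.mean()
```

With γ = λ = 1 and no terminations the recursion reduces to the sum of future rewards plus the final value, which is a convenient sanity check.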
Vision, tokenization, and layers#
Key classes: ConvEncoder, ConvDecoder, DenseDecoder, ActionDecoder, CNNEncoder, CNNDecoder, IRISEncoder, IRISDecoder, DiscreteAutoencoder, VectorQuantizer, VectorQuantizerEMA, VideoTokenizer, MultiHeadSelfAttention, and STTransformer.
- class world_models.vision.dreamer_encoder.ConvEncoder(input_shape, embed_size, activation, depth=32)[source]#
Bases:
Module
Convolutional observation encoder used by Dreamer world models.
This encoder transforms raw image observations (typically RGB frames from environments like Atari or DeepMind Control) into compact latent embeddings that can be processed by the RSSM (Recurrent State-Space Model).
- Architecture:
Input: (B, C, H, W) raw images, values in [-0.5, 0.5]
Process: 4 convolutional layers with stride 2, each roughly halving the spatial dimensions
Output: (B, embed_size) compact representation
The encoder doubles the channel count at each layer: 32 -> 64 -> 128 -> 256 channels with the default depth=32. After the convolutions, a fully connected layer projects the 1024 flattened features to the desired embedding size.
- Usage with Dreamer:
encoder = ConvEncoder(
    input_shape=(3, 64, 64),  # RGB 64x64 images
    embed_size=256,  # RSSM observation embedding size
    activation='relu',  # Activation function
)
obs_embedding = encoder(observation)  # (B, 256)
- Parameters:
input_shape – Tuple (C, H, W) for input images, typically (3, 64, 64)
embed_size – Output embedding dimension, typically 256 or 1024
activation – Activation function name ('relu', 'elu', 'tanh', etc.)
depth – Base channel depth for first layer (default 32)
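The 1024-feature figure follows from the layer arithmetic if, as in the original Dreamer encoder, each convolution uses kernel size 4, stride 2 and no padding. That kernel size is an assumption here, since it is not documented above. The sketch builds such a stack and checks the flattened size; conv_stack_sketch and out_size are hypothetical helpers and omit the final projection to embed_size.

```python
import torch
import torch.nn as nn


def conv_stack_sketch(in_channels: int = 3, depth: int = 32) -> nn.Sequential:
    """Four stride-2, kernel-4 convolutions with depth, 2*depth, 4*depth, 8*depth channels."""
    layers, channels = [], in_channels
    for mult in (1, 2, 4, 8):
        layers += [nn.Conv2d(channels, depth * mult, kernel_size=4, stride=2), nn.ReLU()]
        channels = depth * mult
    return nn.Sequential(*layers, nn.Flatten())


def out_size(size: int, kernel: int = 4, stride: int = 2) -> int:
    """Spatial size after one unpadded convolution."""
    return (size - kernel) // stride + 1
```

Under these assumptions the spatial size goes 64 -> 31 -> 14 -> 6 -> 2, so the last layer yields 256 channels x 2 x 2 = 1024 features.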
- class world_models.vision.dreamer_decoder.TanhBijector[source]#
Bases:
Transform
Bijective tanh transform for squashing Gaussian distributions to [-1, 1].
This transformation is essential for Dreamer’s action policy. The policy network outputs the parameters of a Gaussian over R^n, but actions in continuous-control environments are typically bounded in [-1, 1]. The tanh bijector provides:
Bijective mapping: tanh is invertible (with atanh as inverse)
Stable log-det Jacobian: Computable for gradient-based training
Bounded actions: every sampled action lies in [-1, 1] without explicit clipping
- Math:
Forward: y = tanh(x)
Inverse: x = atanh(y) = 0.5 * log((1+y)/(1-y))
Log-det: log|dy/dx| = 2*(log(2) - x - softplus(-2x))
- Usage with Dreamer ActionDecoder:
dist = TransformedDistribution(Normal(mean, std), TanhBijector())
action = dist.sample()  # Bounded to [-1, 1]
- Reference:
Building a Scalable Deep RL Library by Learning from Mistakes, Haarnoja et al.
- property sign#
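The log-det expression above is a numerically stable rewrite of log(1 − tanh²(x)): the two agree for moderate x, but the naive form becomes -inf once tanh(x) rounds to ±1 in float32 (around |x| ≈ 10). The snippet checks the formula and builds a squashed Gaussian with torch.distributions.transforms.TanhTransform, PyTorch's built-in counterpart. It does not call TanhBijector itself, and tanh_log_det is a hypothetical helper.

```python
import math

import torch
import torch.nn.functional as F
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform


def tanh_log_det(x: torch.Tensor) -> torch.Tensor:
    """log |d tanh(x) / dx| = 2 * (log 2 - x - softplus(-2x))."""
    return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))


x = torch.linspace(-3.0, 3.0, 7)
stable = tanh_log_det(x)
naive = torch.log(1.0 - torch.tanh(x) ** 2)  # same values here, but -inf for large |x|

# Squashed Gaussian policy using PyTorch's built-in tanh transform.
dist = TransformedDistribution(Normal(torch.zeros(3), torch.ones(3)), TanhTransform(cache_size=1))
action = dist.sample()  # bounded to [-1, 1]
```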
- class world_models.vision.dreamer_decoder.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)[source]#
Bases:
Module
Convolutional decoder for reconstructing observations from latent states.
Part of Dreamer’s world model, this decoder reconstructs image observations from the combined stochastic (s) and deterministic (h) RSSM states.
- Architecture:
Input: Concatenated [stoch_state, deter_state], shape (B, stoch+deter)
Process: Dense projection + 4 transposed convolutions (upsampling 2x each)
Output: Independent Normal distribution over observation pixels
The decoder mirrors the ConvEncoder’s structure but in reverse (transposed convs instead of regular convs). This creates a symmetric autoencoder where the encoder and decoder can be trained jointly to learn compressed representations.
- Output Distribution:
Returns torch.distributions.Independent(Normal(mean, std), len(shape)). This allows computing log_prob(observation) for the reconstruction loss.
- Usage in Dreamer world model:
decoder = ConvDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(3, 64, 64),  # RGB images
    activation='relu',
)
obs_dist = decoder(latent_features)  # Returns a distribution
log_prob = obs_dist.log_prob(target_observation)
- Training:
The reconstruction loss is -log_prob(observation), which encourages the RSSM to learn states that capture observation information.
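Wrapping the per-pixel Normal in Independent(..., len(shape)) makes log_prob sum over (C, H, W) and return one value per batch element, which is what the reconstruction loss needs. The snippet shows this with plain torch.distributions; the zero mean stands in for a decoder output, and the unit standard deviation mirrors the original Dreamer decoder but is an assumption here.

```python
import torch
from torch.distributions import Independent, Normal

B, shape = 4, (3, 64, 64)
mean = torch.zeros(B, *shape)                            # stand-in for decoder output
obs_dist = Independent(Normal(mean, 1.0), len(shape))    # event dims = (C, H, W)
target = torch.zeros(B, *shape)
recon_loss = -obs_dist.log_prob(target).mean()           # log_prob has shape (B,)
```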
- class world_models.vision.dreamer_decoder.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist, num_buckets=255, symlog_range=10.0)[source]#
Bases:
Module
MLP decoder for reward/value/discount prediction from latent features.
Part of Dreamer’s world model, this decoder predicts scalar quantities (rewards, values, discount factors) from RSSM latent states.
- Architecture:
Input: [stoch_state, deter_state] concatenated, shape (B, stoch+deter)
Process: MLP with configurable layers and hidden units
Output: Predicted quantity with distribution (normal, binary, or raw)
- Supports three output types:
'normal': Gaussian distribution for regression (rewards, values)
'binary': Bernoulli distribution for binary classification (discount)
'none': Raw tensor for non-probabilistic outputs
- Usage:
reward_decoder = DenseDecoder(
    stoch_size=30, deter_size=200, output_shape=(1,),
    n_layers=2, units=400, activation='elu', dist='normal',
)
reward_dist = reward_decoder(latent_features)
reward_loss = -reward_dist.log_prob(target_reward)
- For discount prediction (binary):
discount_decoder = DenseDecoder(
    stoch_size=30, deter_size=200, output_shape=(1,),
    n_layers=2, units=400, activation='elu',
    dist='binary',  # Bernoulli for P(continue)
)
- class world_models.vision.dreamer_decoder.SampleDist(dist, samples=100)[source]#
Bases:
object
Distribution wrapper that estimates statistics via Monte Carlo sampling.
Provides approximate mean, mode, and entropy helpers for transformed distributions where analytic forms may be inconvenient.
- property name#
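The Monte Carlo estimates described above can be sketched as follows, in the style of the original Dreamer code: the mean is the sample average, the mode is the sample with the highest log-probability, and the entropy is the negative mean log-probability. SampleDistSketch is a hypothetical class, and it is an assumption that SampleDist uses exactly these estimators; this version expects a distribution with a one-dimensional batch shape.

```python
import torch
from torch.distributions import Independent, Normal


class SampleDistSketch:
    def __init__(self, dist, samples: int = 100):
        self._dist = dist
        self._samples = samples

    def mean(self) -> torch.Tensor:
        return self._dist.sample((self._samples,)).mean(0)

    def mode(self) -> torch.Tensor:
        sample = self._dist.rsample((self._samples,))   # (S, B, ...)
        logprob = self._dist.log_prob(sample)           # (S, B)
        best = logprob.argmax(0)                        # (B,) most likely sample per row
        batch = torch.arange(sample.shape[1])
        return sample[best, batch]                      # (B, ...)

    def entropy(self) -> torch.Tensor:
        sample = self._dist.rsample((self._samples,))
        return -self._dist.log_prob(sample).mean(0)
```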
- class world_models.vision.dreamer_decoder.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)[source]#
Bases:
Module
Dreamer actor head producing squashed continuous actions from latent features.
Outputs a transformed Gaussian policy, with an optional deterministic mode and a utility for adding exploration noise.
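The min_std, init_std and mean_scale arguments match the policy head of the original Dreamer: the raw output is split into a mean part and a std part, the mean is soft-clamped as mean_scale · tanh(mean / mean_scale), and the std is softplus(std + raw_init_std) + min_std, where raw_init_std is chosen so that a zero raw output gives roughly init_std. The sketch shows only that parameterization; it is an assumption that ActionDecoder uses exactly these formulas, and dreamer_action_dist is a hypothetical name.

```python
import math

import torch
import torch.nn.functional as F
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform


def dreamer_action_dist(raw: torch.Tensor, min_std: float = 1e-4, init_std: float = 5.0,
                        mean_scale: float = 5.0):
    """raw: (B, 2 * action_size) network output -> tanh-squashed Gaussian policy."""
    raw_init_std = math.log(math.exp(init_std) - 1.0)   # inverse softplus of init_std
    mean, std = raw.chunk(2, dim=-1)
    mean = mean_scale * torch.tanh(mean / mean_scale)    # pre-tanh mean in (-mean_scale, mean_scale)
    std = F.softplus(std + raw_init_std) + min_std
    squashed = TransformedDistribution(Normal(mean, std), TanhTransform(cache_size=1))
    return Independent(squashed, 1)                      # one event per action vector
```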
- class world_models.vision.planet_encoder.CNNEncoder(embedding_size, activation_function='relu')[source]#
Bases:
Module
A Convolutional Neural Network (CNN) encoder for processing image inputs.
- class world_models.vision.planet_decoder.CNNDecoder(state_size, latent_size, embedding_size, activation_function='relu')[source]#
Bases:
Module
A Convolutional Neural Network (CNN) decoder for reconstructing image outputs.
- class world_models.vision.iris_encoder.IRISEncoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, in_channels=3, base_channels=64, num_residual_blocks=2, frame_shape=(3, 64, 64))[source]#
Bases:
Module
CNN Encoder for IRIS discrete autoencoder.
Encodes image observations into latent features, which are then quantized into discrete tokens using the VectorQuantizer.
- Architecture:
4 convolutional layers with residual blocks
Self-attention at 8x8 and 16x16 resolutions
Vector quantization to produce discrete tokens
- Parameters:
vocab_size (int)
tokens_per_frame (int)
embedding_dim (int)
in_channels (int)
base_channels (int)
num_residual_blocks (int)
frame_shape (Tuple[int, int, int])
- forward(x)[source]#
Encode images to discrete tokens.
- Parameters:
x (Tensor) – Input images (B, C, H, W) - should be 64x64
- Returns:
z_q – Quantized tokens (B, C, H’, W’)
indices – Token indices (B, H’, W’)
vq_loss – Dictionary with VQ loss components
- class world_models.vision.iris_encoder.ResidualBlock(channels)[source]#
Bases:
Module
Residual block for encoder.
- Parameters:
channels (int)
- class world_models.vision.iris_encoder.SelfAttentionBlock(channels)[source]#
Bases:
Module
Self-attention block for encoder.
Applies spatial self-attention to capture long-range dependencies.
- Parameters:
channels (int)
- class world_models.vision.iris_decoder.IRISDecoder(vocab_size=512, embedding_dim=512, base_channels=32, out_channels=3, frame_shape=(3, 64, 64))[source]#
Bases:
Module
CNN Decoder for IRIS discrete autoencoder.
Decodes discrete tokens back into image observations. Uses transposed convolutions to upsample from 4x4 to 64x64.
- Parameters:
vocab_size (int)
embedding_dim (int)
base_channels (int)
out_channels (int)
frame_shape (Tuple[int, int, int])
- forward(z)[source]#
Decode tokens to images.
- Parameters:
z (Tensor) – Token embeddings (B, C, H, W) - e.g., (B, 512, 4, 4)
- Returns:
reconstructed – Reconstructed images (B, C, H, W), e.g., (B, 3, 64, 64)
- class world_models.vision.iris_decoder.UpsampleBlock(in_channels, mid_channels, out_channels)[source]#
Bases:
Module
Upsampling block with optional residual connection.
- Parameters:
in_channels (int)
mid_channels (int)
out_channels (int)
- class world_models.vision.iris_decoder.ResidualBlock(channels)[source]#
Bases:
Module
Residual block for decoder.
- Parameters:
channels (int)
- class world_models.vision.iris_decoder.DiscreteAutoencoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, base_channels=64, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Complete Discrete Autoencoder combining encoder and decoder.
Used for training the VQ-VAE component of IRIS.
- Parameters:
vocab_size (int)
tokens_per_frame (int)
embedding_dim (int)
base_channels (int)
frame_shape (Tuple[int, int, int])
- class world_models.vision.vq_layer.VectorQuantizer(vocab_size=512, embedding_dim=512, commitment_weight=0.25)[source]#
Bases:
Module
Vector Quantizer for discrete autoencoder.
Implements the VQ-VAE quantization from: “Neural Discrete Representation Learning” (van den Oord et al., 2017)
Uses a straight-through estimator for gradient flow. For codebook updates via exponential moving averages, see VectorQuantizerEMA.
- Parameters:
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
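The VQ-VAE quantization step (van den Oord et al., 2017) can be sketched as: snap each encoder vector to its nearest codebook entry, add a codebook loss and a commitment loss weighted by commitment_weight, and pass gradients straight through to the encoder. This sketch trains the codebook by gradient descent; vector_quantize and its loss dictionary keys are hypothetical and are not VectorQuantizer's code.

```python
import torch
import torch.nn.functional as F


def vector_quantize(z: torch.Tensor, codebook: torch.Tensor, commitment_weight: float = 0.25):
    """z: (..., D) encoder outputs; codebook: (K, D). Returns z_q, indices, loss dict."""
    flat = z.reshape(-1, z.shape[-1])                      # (M, D)
    distances = torch.cdist(flat, codebook)                # (M, K) Euclidean distances
    indices = distances.argmin(dim=-1)                     # nearest code per vector
    quantized = codebook[indices].view_as(z)
    codebook_loss = F.mse_loss(quantized, z.detach())      # moves codes toward encoder outputs
    commitment_loss = F.mse_loss(z, quantized.detach())    # keeps encoder near its codes
    # Straight-through estimator: forward uses the codes, backward copies grads to z.
    z_q = z + (quantized - z).detach()
    losses = {"codebook": codebook_loss, "commitment": commitment_weight * commitment_loss}
    return z_q, indices.view(z.shape[:-1]), losses
```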
- class world_models.vision.vq_layer.VectorQuantizerEMA(vocab_size=512, embedding_dim=512, commitment_weight=0.25, ema_decay=0.99, epsilon=1e-05)[source]#
Bases:
Module
Vector Quantizer with Exponential Moving Average updates.
Uses EMA updates for the codebook instead of gradient-based updates, which typically makes training more stable.
- Parameters:
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
ema_decay (float)
epsilon (float)
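The EMA codebook update from the VQ-VAE paper's appendix tracks, for each code i, a decayed assignment count N_i and a decayed sum m_i of the encoder vectors assigned to it, then sets e_i = m_i / N_i; epsilon provides Laplace smoothing so that unused codes never divide by zero. The sketch shows one update step under those assumptions; ema_codebook_update and its buffer layout are hypothetical.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def ema_codebook_update(flat_z, indices, cluster_size, embed_sum, decay=0.99, epsilon=1e-5):
    """flat_z: (M, D); indices: (M,); cluster_size: (K,); embed_sum: (K, D).

    Updates the running statistics in place and returns the new codebook (K, D).
    """
    K = cluster_size.shape[0]
    one_hot = F.one_hot(indices, K).type_as(flat_z)               # (M, K) assignments
    cluster_size.mul_(decay).add_(one_hot.sum(0), alpha=1 - decay)
    embed_sum.mul_(decay).add_(one_hot.t() @ flat_z, alpha=1 - decay)
    # Laplace smoothing keeps every count strictly positive.
    n = cluster_size.sum()
    smoothed = (cluster_size + epsilon) / (n + K * epsilon) * n
    return embed_sum / smoothed.unsqueeze(1)
```

With decay=0 each code becomes (almost exactly) the mean of the vectors assigned to it in the current batch.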
- class world_models.vision.video_tokenizer.VideoTokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, commitment_weight=0.25, use_ema=False, ema_decay=0.99)[source]#
Bases:
Module
Video Tokenizer using VQ-VAE with Spatiotemporal Transformer.
This is a core component of Genie (Google DeepMind, 2024), used to compress raw video frames into discrete latent tokens that can be processed by downstream models like the LatentActionModel and DynamicsModel.
The tokenizer uses a Vector Quantized Variational Autoencoder (VQ-VAE) objective to learn a discrete codebook of video representations. Unlike a standard VQ-VAE, it uses a Spatiotemporal (ST) Transformer in both the encoder and decoder to better capture temporal dynamics in videos.
- Architecture:
Patch Embedding: Convert (B, C, T, H, W) video to patch tokens
Encoder ST-Transformer: Process spatial-temporal patches
Vector Quantization: Discretize continuous embeddings to codebook entries
Decoder ST-Transformer: Reconstruct video from quantized tokens
Patch Unembedding: Convert tokens back to video frames
- Key Features:
Causal processing: Each frame’s encoding only uses previous frames
Discrete tokens: Enables autoregressive prediction with latent actions
Memory efficient: factorized spatial and temporal attention costs O(T·N² + N·T²) per layer instead of the O((T·N)²) of full ViT attention over all T·N tokens
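To make the memory claim concrete: with the defaults image_size=64 and patch_size=4, each frame yields N = (64/4)² = 256 tokens, so 16 frames give 4,096 tokens. Full attention scores every pair, (T·N)² = 16,777,216 entries per head and layer, while factorized attention needs T·N² + N·T² = 1,114,112, roughly 15x fewer. The hypothetical helper below reproduces this arithmetic; it counts attention-score entries only and ignores causal masking, projections, and MLPs.

```python
def attention_score_counts(num_frames: int, image_size: int, patch_size: int) -> tuple[int, int]:
    """Return (full, factorized) numbers of attention scores per head per layer."""
    n = (image_size // patch_size) ** 2          # spatial tokens per frame
    t = num_frames
    full = (t * n) ** 2                          # every token attends to every token
    factorized = t * n ** 2 + n * t ** 2         # spatial within frames + temporal per position
    return full, factorized
```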
- Usage with Genie:
tokenizer = VideoTokenizer(
    num_frames=16, image_size=64, patch_size=4, vocab_size=1024, embedding_dim=32
)
reconstructed, indices, loss_dict = tokenizer(video_frames)
# For discrete token input to the dynamics model:
token_embeddings = tokenizer.decode_indices(indices)
- Training:
The tokenizer is trained with the VQ-VAE objective:
- Reconstruction loss: MSE between the input and reconstructed video
- Codebook (VQ) loss: moves codebook embeddings toward the encoder outputs so that useful codes are learned
- Commitment loss: penalizes encoder outputs for drifting away from their codebook entries
- Reference:
Genie: Generative Interactive Environments Bruce et al., Google DeepMind, 2024 - https://arxiv.org/abs/2402.15391
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
use_ema (bool)
ema_decay (float)
- encode(x)[source]#
Encode video to discrete tokens.
- Parameters:
x (Tensor) – Video tensor (B, C, T, H, W)
- Returns:
z_q – Quantized embeddings (B, T, H’, W’, embedding_dim)
indices – Token indices (B, T, H’, W’)
vq_loss – Dictionary with VQ loss components
- decode_indices(indices)[source]#
Decode token indices to embeddings for video frames.
- Parameters:
indices (Tensor) – Token indices (B, T, H’, W’) or (B, T, N) where N = H’*W’
- Returns:
z_q – Quantized embeddings (B, T, H’, W’, embedding_dim)
- world_models.vision.video_tokenizer.create_video_tokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False)[source]#
Factory function to create a Video Tokenizer.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
use_ema (bool)
- Return type:
VideoTokenizer
Blocks sub-module: Transformer blocks and attention mechanisms.
- Exported Components:
- Transformers:
STTransformer: Spatiotemporal Transformer for video processing
STSpatialAttention: Spatial attention layer
STTemporalAttention: Temporal attention layer
STBlock: Combined spatiotemporal transformer block
- Attention:
MultiHeadSelfAttention: Multi-head self-attention
MultiHeadAttention: Generic multi-head attention (alias)
Attention: Attention mechanism
- Normalization:
RMSNorm: Root Mean Square Layer Normalization
AdaLNNormalization: Adaptive Layer Normalization
- class world_models.blocks.mhsa.MultiHeadSelfAttention(d, n_heads=2)[source]#
Bases:
Module
Multi-head scaled dot-product self-attention over sequence tokens.
This module projects the input sequence into query/key/value heads, performs attention independently per head, and merges the heads back into the original feature dimension. It is used as a lightweight transformer attention block.
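The projection, per-head attention, and merge steps can be sketched without dependencies for a single, unbatched sequence. The sketch assumes bias-free (d, d) projection matrices and the usual 1/sqrt(head_dim) score scaling; the module itself works on PyTorch tensors, and its exact projection layout is not documented here. All function names are illustrative.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """(n, k) @ (k, m) -> (n, m)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def multi_head_self_attention(x: Matrix, w_q: Matrix, w_k: Matrix, w_v: Matrix,
                              w_o: Matrix, n_heads: int) -> Matrix:
    """x has shape (seq_len, d); each weight matrix is (d, d). Returns (seq_len, d)."""
    d = len(x[0])
    if d % n_heads != 0:
        raise ValueError("d must be divisible by n_heads")
    head_dim = d // n_heads
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    merged: Matrix = [[] for _ in x]
    for h in range(n_heads):
        cols = slice(h * head_dim, (h + 1) * head_dim)  # this head's feature slice
        q_h = [row[cols] for row in q]
        k_h = [row[cols] for row in k]
        v_h = [row[cols] for row in v]
        scores = [[sum(qa * kb for qa, kb in zip(qi, kj)) / math.sqrt(head_dim) for kj in k_h]
                  for qi in q_h]
        weights = [softmax(row) for row in scores]  # each row sums to 1
        for i, row in enumerate(matmul(weights, v_h)):
            merged[i].extend(row)  # concatenate heads back to width d
    return matmul(merged, w_o)


rng = random.Random(0)
d, seq_len = 4, 3


def random_matrix(rows: int, cols: int) -> Matrix:
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


w_q, w_k, w_v, w_o = (random_matrix(d, d) for _ in range(4))
x = random_matrix(seq_len, d)
out = multi_head_self_attention(x, w_q, w_k, w_v, w_o, n_heads=2)
```

Because nothing in this computation depends on token order, permuting the input rows permutes the output rows in the same way; position information has to come from embeddings added before attention.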
- class world_models.blocks.st_transformer.STSpatialAttention(dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)[source]#
Bases:
Module
Spatial attention layer for spatiotemporal transformer.
Processes video tokens by attending over spatial positions (H*W) within each time step independently. Captures within-frame spatial relationships.
Input: (B, T, N, C), with batch size B, T time steps, N spatial positions (H*W), and C channels
Output: (B, T, N, C), the same shape, containing spatially attended features
- Architecture:
QKV projection: Linear(dim, dim*3)
Reshape to multi-head attention format
Attention: softmax(Q @ K^T / sqrt(d_k)) @ V
Output projection
- Usage in ST-Transformer:
Applied to video tokens of shape (B, T, N, C) to capture within-frame spatial structure (e.g., object positions).
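Spatial and temporal attention differ only in which tokens share a softmax. The sketch lists, for a (B, T, N) token grid, the groups each layer attends within. In tensor code this corresponds to reshaping (B, T, N, C) to (B*T, N, C) for spatial attention and to (B*N, T, C) for temporal attention; that STSpatialAttention and STTemporalAttention reshape exactly this way is an assumption, and the helper names are hypothetical.

```python
from itertools import product
from typing import List, Tuple

Coord = Tuple[int, int, int]  # (batch index b, frame t, spatial position n)


def spatial_groups(batch: int, frames: int, patches: int) -> List[List[Coord]]:
    """One attention group per (b, t): the N patches of a single frame."""
    return [[(b, t, n) for n in range(patches)]
            for b, t in product(range(batch), range(frames))]


def temporal_groups(batch: int, frames: int, patches: int) -> List[List[Coord]]:
    """One attention group per (b, n): the same patch position in every frame."""
    return [[(b, t, n) for t in range(frames)]
            for b, n in product(range(batch), range(patches))]


spatial = spatial_groups(batch=2, frames=3, patches=4)
temporal = temporal_groups(batch=2, frames=3, patches=4)
print(len(spatial), len(spatial[0]))    # B*T = 6 groups of N = 4 tokens
print(len(temporal), len(temporal[0]))  # B*N = 8 groups of T = 3 tokens
```

Every token appears in exactly one spatial group and exactly one temporal group, so stacking both layers lets information flow between any two tokens in two hops.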
- Parameters:
dim (int)
num_heads (int)
qkv_bias (bool)
qk_scale (float | None)
attn_drop (float)
proj_drop (float)
- class world_models.blocks.st_transformer.STTemporalAttention(dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)[source]#
Bases:
Module
Temporal attention layer with causal masking for spatiotemporal transformer.
Processes video tokens by attending over time steps (T) at each spatial position independently. Uses causal masking so that each frame attends only to itself and earlier frames (important for autoregressive video generation).
Input: (B, T, N, C), with batch size B, T time steps, N spatial positions, and C channels
Output: (B, T, N, C), the same shape, containing temporally attended features
- Key Feature: Causal masking
Frame t can only attend to frames 0…t (itself and earlier frames)
Prevents information leakage from future frames
Essential for autoregressive video generation models
- Usage in Genie VideoTokenizer:
Applied after STSpatialAttention to model temporal dynamics. The causal mask ensures generation is autoregressive.
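A sketch of the causal mask over T frames and its effect on the attention weights, assuming the usual convention that disallowed scores are set to -inf before the softmax. With all scores equal, frame t spreads its attention uniformly over frames 0…t. Broadcasting the mask over spatial positions and heads is left out, and the function names are illustrative.

```python
import math
from typing import List


def causal_mask(num_frames: int) -> List[List[bool]]:
    """mask[t][s] is True when frame t may attend to frame s, i.e. s <= t."""
    return [[s <= t for s in range(num_frames)] for t in range(num_frames)]


def masked_softmax(scores: List[List[float]], mask: List[List[bool]]) -> List[List[float]]:
    out = []
    for row, allowed in zip(scores, mask):
        logits = [v if ok else -math.inf for v, ok in zip(row, allowed)]
        m = max(logits)  # finite, because the diagonal is always allowed
        exps = [math.exp(v - m) for v in logits]  # exp(-inf) == 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


weights = masked_softmax([[0.0] * 4 for _ in range(4)], causal_mask(4))
for row in weights:
    print([round(w, 2) for w in row])
```

Keeping the diagonal allowed matters: if frame t could see only frames 0…t-1, frame 0 would have no valid keys and its softmax row would be undefined.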
- Parameters:
dim (int)
num_heads (int)
qkv_bias (bool)
qk_scale (float | None)
attn_drop (float)
proj_drop (float)
- class world_models.blocks.st_transformer.STMLP(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>, drop=0.0)[source]#
Bases:
Module
MLP for ST-Transformer block.
- Parameters:
in_features (int)
hidden_features (int | None)
out_features (int | None)
act_layer (type[Module])
drop (float)
- class world_models.blocks.st_transformer.STTransformerBlock(dim, num_heads=8, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]#
Bases:
Module
Combined spatiotemporal transformer block with interleaved attention.
- A single block applies:
Spatial attention (within each time frame)
Temporal attention (across frames with causal mask)
MLP projection
The order is: x -> + SpatialAttn -> + TemporalAttn -> + MLP -> x
This interleaved design captures both spatial structure and temporal dynamics efficiently; it is used in Genie’s VideoTokenizer and DynamicsModel.
- Parameters:
dim (int) – Feature dimension (must match patch embedding dimension)
num_heads (int) – Number of attention heads
mlp_ratio (float) – MLP hidden dim = dim * mlp_ratio
drop (float) – Dropout rate
attn_drop (float) – Dropout rate applied to the attention weights
drop_path (float) – Stochastic depth rate for drop path regularization
norm_layer (type[Module]) – Normalization layer class (default: nn.LayerNorm)
qkv_bias (bool)
qk_scale (float | None)
act_layer (type[Module])
- Usage in Genie:
# VideoTokenizer encoder (12 layers)
encoder = STTransformer(
    num_frames=16,
    num_patches_per_frame=256,  # 16x16 for 64x64 images with patch_size=4
    dim=512, depth=12, num_heads=16,
)
encoded = encoder(tokens)  # (B, T*N, C)
# Dynamics model decoder (24 layers)
decoder = STTransformer(
    num_frames=16, num_patches_per_frame=256, dim=1024, depth=24, num_heads=16
)
decoded = decoder(tokens)
- class world_models.blocks.st_transformer.DropPath(drop_prob=0.0)[source]#
Bases:
Module
Drop paths (stochastic depth) per sample.
- Parameters:
drop_prob (float)
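DropPath is applied to the residual branches of a block such as STTransformerBlock (its drop_path argument), so a dropped sample skips that sublayer and keeps its input. The sketch follows the common stochastic-depth formulation: a per-sample Bernoulli keep decision, survivors rescaled by 1 / (1 - drop_prob), and identity at inference. It is an illustration with hypothetical helper names, not this module's source.

```python
import random
from typing import List, Optional

Batch = List[List[float]]  # one inner list of features per sample


def drop_path(branch: Batch, drop_prob: float, training: bool,
              rng: Optional[random.Random] = None) -> Batch:
    """Zero each sample's whole branch output with probability drop_prob (training only)."""
    if drop_prob == 0.0 or not training:
        return branch  # identity at inference
    rng = rng or random.Random()
    keep_prob = 1.0 - drop_prob
    out = []
    for sample in branch:
        keep = rng.random() < keep_prob
        # Survivors are scaled by 1 / keep_prob so the expected output is unchanged.
        out.append([v / keep_prob if keep else 0.0 for v in sample])
    return out


def residual_add(x: Batch, branch_out: Batch) -> Batch:
    """x + DropPath(F(x)): a dropped sample simply keeps its input."""
    return [[a + b for a, b in zip(xs, bs)] for xs, bs in zip(x, branch_out)]


x = [[1.0, 1.0] for _ in range(4)]
branch = [[0.5, 0.5] for _ in range(4)]  # stand-in for F(x), e.g. an attention sublayer
y = residual_add(x, drop_path(branch, drop_prob=0.5, training=True, rng=random.Random(0)))
```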
- class world_models.blocks.st_transformer.STTransformer(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]#
Bases:
Module
Spatiotemporal Transformer for video modeling.
Stacks depth spatiotemporal blocks, each with interleaved spatial and temporal attention.
- Parameters:
num_frames (int)
num_patches_per_frame (int)
dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
qk_scale (float | None)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
norm_layer (type[Module])
- world_models.blocks.st_transformer.create_st_transformer(num_frames=16, patch_size=4, img_size=64, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Factory function to create an ST-Transformer.
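The factory takes img_size and patch_size, while STTransformer takes num_patches_per_frame. Assuming square images cut into non-overlapping square patches (consistent with the "16x16 for 64x64 images with patch_size=4" comment in the STTransformerBlock usage example), the relationship is sketched below; the helper name is hypothetical.

```python
def num_patches_per_frame(img_size: int, patch_size: int) -> int:
    """Patches per frame for a square image cut into non-overlapping square patches."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    side = img_size // patch_size  # patches along each edge
    return side * side


print(num_patches_per_frame(64, 4))  # 16 x 16 grid, the STTransformer default of 256
```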
- Parameters:
num_frames (int)
patch_size (int)
img_size (int)
dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- Return type:
STTransformer
Configuration objects#
Lazy config exports.
Configuration modules can have optional training dependencies, so the package initializer avoids importing every config eagerly.
- class world_models.configs.DreamerConfig[source]#
Bases:
object
Configuration container for Dreamer training, evaluation, and environment setup.
This class centralizes environment backend selection (DMC/Gym/MuJoCo/Robotics/Unity/Brax), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.
- class world_models.configs.JEPAConfig[source]#
Bases:
object
Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.
- class world_models.configs.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#
Bases:
object
Default configuration values for Diffusion Transformer (DiT) training.
The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.
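BETA_START, BETA_END, and TIMESTEPS describe a diffusion noise schedule. Assuming the linear DDPM schedule these defaults are commonly paired with (the page does not state the schedule shape), the sketch computes the per-step betas and the cumulative products alpha_bar_t that set how noisy x_t is. It also shows the token count implied by IMG_SIZE and PATCH, assuming non-overlapping square patches. The function name is hypothetical.

```python
from typing import List, Tuple


def linear_beta_schedule(beta_start: float = 1e-4, beta_end: float = 0.02,
                         timesteps: int = 1000) -> Tuple[List[float], List[float]]:
    """Return (betas, alpha_bars) for a linear DDPM noise schedule."""
    step = (beta_end - beta_start) / (timesteps - 1)
    betas = [beta_start + i * step for i in range(timesteps)]
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta  # alpha_bar_t = product over s <= t of (1 - beta_s)
        alpha_bars.append(running)
    return betas, alpha_bars


betas, alpha_bars = linear_beta_schedule()
tokens_per_image = (32 // 4) ** 2  # IMG_SIZE=32, PATCH=4 -> 8 x 8 = 64 tokens
```

By the last step alpha_bar is close to zero, so x_T is almost pure Gaussian noise, which is what lets sampling start from noise.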
- Parameters:
DATASET (str)
BATCH (int)
EPOCHS (int)
LR (float)
IMG_SIZE (int)
CHANNELS (int)
PATCH (int)
WIDTH (int)
DEPTH (int)
HEADS (int)
DROP (float)
BETA_START (float)
BETA_END (float)
TIMESTEPS (int)
EMA (bool)
EMA_DECAY (float)
WORKDIR (str)
ROOT_PATH (str)
- DATASET: str = 'CIFAR10'#
- BATCH: int = 128#
- EPOCHS: int = 3#
- LR: float = 0.0002#
- IMG_SIZE: int = 32#
- CHANNELS: int = 3#
- PATCH: int = 4#
- WIDTH: int = 384#
- DEPTH: int = 6#
- HEADS: int = 6#
- DROP: float = 0.1#
- BETA_START: float = 0.0001#
- BETA_END: float = 0.02#
- TIMESTEPS: int = 1000#
- EMA: bool = True#
- EMA_DECAY: float = 0.999#
- WORKDIR: str = './dit_demo'#
- ROOT_PATH: str = './data'#
- world_models.configs.get_dit_config(**overrides)[source]#
Returns a DiTConfig instance with default values overridden by the provided keyword arguments.
- Example usage:
cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)
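One plausible way to implement this override behaviour, assuming DiTConfig is a dataclass (its signature and field listing suggest so, but the page does not say). DiTConfigSketch and get_config_sketch are hypothetical stand-ins carrying only three of DiTConfig's fields; whether the real get_dit_config rejects misspelled keys, as this sketch does, is not documented.

```python
from dataclasses import dataclass, fields, replace


@dataclass
class DiTConfigSketch:
    """Hypothetical stand-in carrying three of DiTConfig's fields and defaults."""
    BATCH: int = 128
    EPOCHS: int = 3
    LR: float = 0.0002


def get_config_sketch(**overrides) -> DiTConfigSketch:
    """Start from the defaults and apply keyword overrides, rejecting unknown names."""
    known = {f.name for f in fields(DiTConfigSketch)}
    unknown = sorted(set(overrides) - known)
    if unknown:
        raise TypeError(f"unknown config fields: {unknown}")
    return replace(DiTConfigSketch(), **overrides)


cfg = get_config_sketch(BATCH=64, EPOCHS=10, LR=1e-3)
```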
- class world_models.configs.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', seed: int = 0, obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, device: str = <factory>, log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, operator_state_dim: int = 32, operator_action_dim: int = 4)[source]#
Bases:
object
- Parameters:
preset (str | None)
game (str)
seed (int)
obs_size (int)
frameskip (int)
max_noop (int)
terminate_on_life_loss (bool)
reward_clip (List[int])
num_conditioning_frames (int)
diffusion_channels (List[int])
diffusion_res_blocks (int)
diffusion_cond_dim (int)
sigma_data (float)
sigma_min (float)
sigma_max (float)
rho (int)
p_mean (float)
p_std (float)
sampling_method (str)
num_sampling_steps (int)
reward_channels (List[int])
reward_res_blocks (int)
reward_cond_dim (int)
reward_lstm_dim (int)
burn_in_length (int)
actor_channels (List[int])
actor_res_blocks (int)
actor_lstm_dim (int)
num_epochs (int)
training_steps_per_epoch (int)
batch_size (int)
environment_steps_per_epoch (int)
epsilon_greedy (float)
imagination_horizon (int)
discount_factor (float)
entropy_weight (float)
lambda_returns (float)
learning_rate (float)
adam_epsilon (float)
weight_decay_diffusion (float)
weight_decay_reward (float)
weight_decay_actor (float)
device (str)
log_interval (int)
eval_interval (int)
save_interval (int)
operator_state_dim (int)
operator_action_dim (int)
- preset: str | None = None#
- game: str = 'Breakout-v5'#
- seed: int = 0#
- obs_size: int = 64#
- frameskip: int = 4#
- max_noop: int = 30#
- terminate_on_life_loss: bool = True#
- reward_clip: List[int]#
- num_conditioning_frames: int = 4#
- diffusion_channels: List[int]#
- diffusion_res_blocks: int = 2#
- diffusion_cond_dim: int = 256#
- sigma_data: float = 0.5#
- sigma_min: float = 0.002#
- sigma_max: float = 80.0#
- rho: int = 7#
- p_mean: float = -0.4#
- p_std: float = 1.2#
- sampling_method: str = 'euler'#
- num_sampling_steps: int = 3#
- reward_channels: List[int]#
- reward_res_blocks: int = 2#
- reward_cond_dim: int = 128#
- reward_lstm_dim: int = 512#
- burn_in_length: int = 4#
- actor_channels: List[int]#
- actor_res_blocks: int = 1#
- actor_lstm_dim: int = 512#
- num_epochs: int = 1000#
- training_steps_per_epoch: int = 400#
- batch_size: int = 32#
- environment_steps_per_epoch: int = 100#
- epsilon_greedy: float = 0.01#
- imagination_horizon: int = 15#
- discount_factor: float = 0.985#
- entropy_weight: float = 0.001#
- lambda_returns: float = 0.95#
- learning_rate: float = 0.0001#
- adam_epsilon: float = 1e-08#
- weight_decay_diffusion: float = 0.01#
- weight_decay_reward: float = 0.01#
- weight_decay_actor: float = 0.0#
- device: str#
- log_interval: int = 10#
- eval_interval: int = 50#
- save_interval: int = 100#
- operator_state_dim: int = 32#
- operator_action_dim: int = 4#
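The sigma_min, sigma_max, rho, p_mean, and p_std fields match the parameter names of the EDM diffusion formulation (Karras et al., 2022), on which DIAMOND builds. Assuming they play their EDM roles, the sketch shows the sampling noise levels for num_sampling_steps and how a training noise level is drawn. sigma_data (the data standard deviation used for EDM preconditioning) and the final step to sigma = 0 that EDM samplers usually append are omitted. This illustrates EDM with hypothetical helper names; it is not DiamondConfig's consumer code.

```python
import math
import random
from typing import List


def karras_sigmas(num_steps: int = 3, sigma_min: float = 0.002,
                  sigma_max: float = 80.0, rho: float = 7.0) -> List[float]:
    """Noise levels from sigma_max down to sigma_min, evenly spaced in sigma**(1/rho)."""
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    if num_steps == 1:
        return [sigma_max]
    hi, lo = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    return [(hi + i / (num_steps - 1) * (lo - hi)) ** rho for i in range(num_steps)]


def sample_training_sigma(rng: random.Random, p_mean: float = -0.4, p_std: float = 1.2) -> float:
    """Training noise level with log(sigma) drawn from Normal(p_mean, p_std)."""
    return math.exp(rng.gauss(p_mean, p_std))


sigmas = karras_sigmas()  # num_sampling_steps=3 with the DiamondConfig defaults
print([round(s, 3) for s in sigmas])
```

A large rho concentrates the steps at low noise levels, which is where fine image detail is resolved.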
- class world_models.configs.IRISConfig[source]#
Bases:
object
Configuration for IRIS (Imagination with auto-Regression over an Inner Speech).
Based on the paper “Transformers are Sample-Efficient World Models”. Implements a discrete autoencoder and an autoregressive Transformer for sample-efficient RL.
- class world_models.configs.GenieConfig(num_frames=8, image_size=32, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=256, action_encoder_depth=4, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#
Bases:
object
Configuration for the Genie model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
tokenizer_num_heads (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_num_heads (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
sample_temperature (float)
maskgit_steps (int)
- num_frames: int = 8#
- image_size: int = 32#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 256#
- tokenizer_decoder_dim: int = 512#
- tokenizer_encoder_depth: int = 4#
- tokenizer_decoder_depth: int = 8#
- tokenizer_num_heads: int = 8#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 256#
- action_encoder_depth: int = 4#
- action_num_heads: int = 8#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 4#
- learning_rate: float = 3e-05#
- weight_decay: float = 0.0001#
- warmup_steps: int = 5000#
- max_steps: int = 125000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- sample_temperature: float = 2.0#
- maskgit_steps: int = 25#
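mask_prob_min, mask_prob_max, and maskgit_steps point to MaskGIT-style training and decoding for the dynamics model, as in the Genie paper, which samples a masking rate between 0.5 and 1 during training. The sketch assumes a uniformly drawn mask ratio per example and the cosine schedule from the original MaskGIT paper for how many tokens stay masked after each decoding step; neither detail is confirmed for this implementation, the 256-token frame is an arbitrary example, and the helper names are hypothetical.

```python
import math
import random
from typing import List


def sample_mask_ratio(rng: random.Random, mask_prob_min: float = 0.5,
                      mask_prob_max: float = 1.0) -> float:
    """Training: fraction of video tokens to mask, drawn uniformly per example."""
    return rng.uniform(mask_prob_min, mask_prob_max)


def maskgit_masked_counts(num_tokens: int, steps: int = 25) -> List[int]:
    """Inference: tokens still masked after each of `steps` iterations (cosine schedule)."""
    return [math.floor(num_tokens * math.cos(math.pi / 2 * (s + 1) / steps))
            for s in range(steps)]


rng = random.Random(0)
ratio = sample_mask_ratio(rng)
counts = maskgit_masked_counts(num_tokens=256)  # hypothetical 16 x 16 token frame
```

The cosine schedule commits only a few tokens in early iterations, when the prediction is least certain, and fills in the rest quickly at the end.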
- class world_models.configs.GenieSmallConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=512, action_encoder_depth=8, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=2, learning_rate=0.0001, weight_decay=0.0001, warmup_steps=1000, max_steps=50000, mask_prob_min=0.5, mask_prob_max=1.0)[source]#
Bases:
object
Small configuration for development/testing.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
tokenizer_num_heads (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_num_heads (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 256#
- tokenizer_decoder_dim: int = 512#
- tokenizer_encoder_depth: int = 4#
- tokenizer_decoder_depth: int = 8#
- tokenizer_num_heads: int = 8#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 512#
- action_encoder_depth: int = 8#
- action_num_heads: int = 8#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 2#
- learning_rate: float = 0.0001#
- weight_decay: float = 0.0001#
- warmup_steps: int = 1000#
- max_steps: int = 50000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- class world_models.configs.STTransformerConfig(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases:
object
Configuration for the Spatiotemporal Transformer.
- Parameters:
num_frames (int)
num_patches_per_frame (int)
dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- num_frames: int = 16#
- num_patches_per_frame: int = 256#
- dim: int = 768#
- depth: int = 12#
- num_heads: int = 12#
- mlp_ratio: float = 4.0#
- qkv_bias: bool = True#
- drop_rate: float = 0.0#
- attn_drop_rate: float = 0.0#
- drop_path_rate: float = 0.0#
- class world_models.configs.VideoTokenizerConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False, ema_decay=0.99, commitment_weight=0.25)[source]#
Bases:
object
Configuration for the Video Tokenizer.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
use_ema (bool)
ema_decay (float)
commitment_weight (float)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- encoder_dim: int = 512#
- decoder_dim: int = 1024#
- encoder_depth: int = 12#
- decoder_depth: int = 20#
- num_heads: int = 16#
- patch_size: int = 4#
- vocab_size: int = 1024#
- embedding_dim: int = 32#
- use_ema: bool = False#
- ema_decay: float = 0.99#
- commitment_weight: float = 0.25#
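When use_ema is True, VQ-VAE tokenizers commonly replace the codebook loss with exponential-moving-average codebook updates (described in the original VQ-VAE paper), with ema_decay as the decay rate. The sketch shows that standard update in pure Python; it is not VideoTokenizer's code, the function name is hypothetical, and it omits the Laplace smoothing of cluster sizes that many implementations add.

```python
from typing import List, Sequence

Vector = List[float]


def ema_codebook_update(codebook: List[Vector], cluster_size: List[float],
                        embed_sum: List[Vector], z_e: Sequence[Vector],
                        indices: Sequence[int], decay: float = 0.99,
                        eps: float = 1e-5) -> None:
    """In-place EMA update moving each code toward the mean of its assigned encoder outputs."""
    dim = len(codebook[0])
    for k in range(len(codebook)):
        assigned = [z for z, i in zip(z_e, indices) if i == k]
        # Running count and running sum of the vectors assigned to code k.
        cluster_size[k] = decay * cluster_size[k] + (1 - decay) * len(assigned)
        batch_sum = [sum(z[j] for z in assigned) for j in range(dim)]
        embed_sum[k] = [decay * m + (1 - decay) * s for m, s in zip(embed_sum[k], batch_sum)]
        codebook[k] = [m / (cluster_size[k] + eps) for m in embed_sum[k]]


# Initialising cluster_size to ones and embed_sum to a copy of the codebook
# means an unused code decays numerator and denominator alike and keeps its value.
codebook = [[0.0, 0.0], [5.0, 5.0]]
cluster_size = [1.0, 1.0]
embed_sum = [list(c) for c in codebook]
ema_codebook_update(codebook, cluster_size, embed_sum,
                    z_e=[[1.0, 1.0], [3.0, 3.0], [10.0, 10.0]], indices=[0, 0, 1], decay=0.0)
```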
- class world_models.configs.LatentActionModelConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=1024, encoder_depth=20, num_heads=16, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#
Bases:
object
Configuration for the Latent Action Model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
encoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- encoder_dim: int = 1024#
- encoder_depth: int = 20#
- num_heads: int = 16#
- patch_size: int = 16#
- vocab_size: int = 8#
- embedding_dim: int = 32#
- commitment_weight: float = 1.0#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- class world_models.configs.DynamicsModelConfig(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases:
object
Configuration for the Dynamics Model.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- num_frames: int = 16#
- image_size: int = 64#
- vocab_size: int = 1024#
- embedding_dim: int = 32#
- action_vocab_size: int = 8#
- dim: int = 5120#
- depth: int = 48#
- num_heads: int = 36#
- patch_size: int = 4#
- mlp_ratio: float = 4.0#
- qkv_bias: bool = True#
- drop_rate: float = 0.0#
- attn_drop_rate: float = 0.0#
- drop_path_rate: float = 0.0#
- class world_models.configs.dreamer_config.DreamerConfig[source]#
Bases:
object
Configuration container for Dreamer training, evaluation, and environment setup.
This class centralizes environment backend selection (DMC/Gym/MuJoCo/Robotics/Unity/Brax), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.
- class world_models.configs.jepa_config.JEPAConfig[source]#
Bases:
object
Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.
- class world_models.configs.iris_config.IRISConfig[source]#
Bases:
object
Configuration for IRIS (Imagination with auto-Regression over an Inner Speech).
Based on the paper “Transformers are Sample-Efficient World Models”. Implements a discrete autoencoder and an autoregressive Transformer for sample-efficient RL.
- class world_models.configs.genie_config.GenieConfig(num_frames=8, image_size=32, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=256, action_encoder_depth=4, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#
Bases:
object
Configuration for the Genie model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
tokenizer_num_heads (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_num_heads (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
sample_temperature (float)
maskgit_steps (int)
- num_frames: int = 8#
- image_size: int = 32#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 256#
- tokenizer_decoder_dim: int = 512#
- tokenizer_encoder_depth: int = 4#
- tokenizer_decoder_depth: int = 8#
- tokenizer_num_heads: int = 8#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 256#
- action_encoder_depth: int = 4#
- action_num_heads: int = 8#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 4#
- learning_rate: float = 3e-05#
- weight_decay: float = 0.0001#
- warmup_steps: int = 5000#
- max_steps: int = 125000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- sample_temperature: float = 2.0#
- maskgit_steps: int = 25#
- class world_models.configs.genie_config.GenieSmallConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=512, action_encoder_depth=8, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=2, learning_rate=0.0001, weight_decay=0.0001, warmup_steps=1000, max_steps=50000, mask_prob_min=0.5, mask_prob_max=1.0)[source]#
Bases:
object
Small configuration for development/testing.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
tokenizer_num_heads (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_num_heads (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 256#
- tokenizer_decoder_dim: int = 512#
- tokenizer_encoder_depth: int = 4#
- tokenizer_decoder_depth: int = 8#
- tokenizer_num_heads: int = 8#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 512#
- action_encoder_depth: int = 8#
- action_num_heads: int = 8#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 2#
- learning_rate: float = 0.0001#
- weight_decay: float = 0.0001#
- warmup_steps: int = 1000#
- max_steps: int = 50000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- class world_models.configs.genie_config.STTransformerConfig(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases:
object
Configuration for the Spatiotemporal Transformer.
- Parameters:
num_frames (int)
num_patches_per_frame (int)
dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- num_frames: int = 16#
- num_patches_per_frame: int = 256#
- dim: int = 768#
- depth: int = 12#
- num_heads: int = 12#
- mlp_ratio: float = 4.0#
- qkv_bias: bool = True#
- drop_rate: float = 0.0#
- attn_drop_rate: float = 0.0#
- drop_path_rate: float = 0.0#
- class world_models.configs.genie_config.VideoTokenizerConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False, ema_decay=0.99, commitment_weight=0.25)[source]#
Bases:
object
Configuration for the Video Tokenizer.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
use_ema (bool)
ema_decay (float)
commitment_weight (float)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- encoder_dim: int = 512#
- decoder_dim: int = 1024#
- encoder_depth: int = 12#
- decoder_depth: int = 20#
- num_heads: int = 16#
- patch_size: int = 4#
- vocab_size: int = 1024#
- embedding_dim: int = 32#
- use_ema: bool = False#
- ema_decay: float = 0.99#
- commitment_weight: float = 0.25#
- class world_models.configs.genie_config.LatentActionModelConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=1024, encoder_depth=20, num_heads=16, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#
Bases:
object
Configuration for the Latent Action Model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
encoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- encoder_dim: int = 1024#
- encoder_depth: int = 20#
- num_heads: int = 16#
- patch_size: int = 16#
- vocab_size: int = 8#
- embedding_dim: int = 32#
- commitment_weight: float = 1.0#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- class world_models.configs.genie_config.DynamicsModelConfig(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases:
object
Configuration for the Dynamics Model.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- num_frames: int = 16#
- image_size: int = 64#
- vocab_size: int = 1024#
- embedding_dim: int = 32#
- action_vocab_size: int = 8#
- dim: int = 5120#
- depth: int = 48#
- num_heads: int = 36#
- patch_size: int = 4#
- mlp_ratio: float = 4.0#
- qkv_bias: bool = True#
- drop_rate: float = 0.0#
- attn_drop_rate: float = 0.0#
- drop_path_rate: float = 0.0#
- class world_models.configs.dit_config.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#
Bases:
object
Default configuration values for Diffusion Transformer (DiT) training.
The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.
- Parameters:
DATASET (str)
BATCH (int)
EPOCHS (int)
LR (float)
IMG_SIZE (int)
CHANNELS (int)
PATCH (int)
WIDTH (int)
DEPTH (int)
HEADS (int)
DROP (float)
BETA_START (float)
BETA_END (float)
TIMESTEPS (int)
EMA (bool)
EMA_DECAY (float)
WORKDIR (str)
ROOT_PATH (str)
- DATASET: str = 'CIFAR10'#
- BATCH: int = 128#
- EPOCHS: int = 3#
- LR: float = 0.0002#
- IMG_SIZE: int = 32#
- CHANNELS: int = 3#
- PATCH: int = 4#
- WIDTH: int = 384#
- DEPTH: int = 6#
- HEADS: int = 6#
- DROP: float = 0.1#
- BETA_START: float = 0.0001#
- BETA_END: float = 0.02#
- TIMESTEPS: int = 1000#
- EMA: bool = True#
- EMA_DECAY: float = 0.999#
- WORKDIR: str = './dit_demo'#
- ROOT_PATH: str = './data'#
- world_models.configs.dit_config.get_dit_config(**overrides)[source]#
Returns a DiTConfig instance with default values overridden by the provided keyword arguments.
- Example usage:
cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)
- class world_models.configs.diamond_config.ModelPreset(diffusion_channels, diffusion_res_blocks, diffusion_cond_dim, reward_channels, reward_lstm_dim, actor_channels, actor_lstm_dim)[source]#
Bases:
object
Model architecture preset for different hardware tiers.
- Parameters:
diffusion_channels (List[int])
diffusion_res_blocks (int)
diffusion_cond_dim (int)
reward_channels (List[int])
reward_lstm_dim (int)
actor_channels (List[int])
actor_lstm_dim (int)
- diffusion_channels: List[int]#
- diffusion_res_blocks: int#
- diffusion_cond_dim: int#
- reward_channels: List[int]#
- reward_lstm_dim: int#
- actor_channels: List[int]#
- actor_lstm_dim: int#
- class world_models.configs.diamond_config.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', seed: int = 0, obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, device: str = <factory>, log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, operator_state_dim: int = 32, operator_action_dim: int = 4)[source]#
Bases:
object
- Parameters:
preset (str | None)
game (str)
seed (int)
obs_size (int)
frameskip (int)
max_noop (int)
terminate_on_life_loss (bool)
reward_clip (List[int])
num_conditioning_frames (int)
diffusion_channels (List[int])
diffusion_res_blocks (int)
diffusion_cond_dim (int)
sigma_data (float)
sigma_min (float)
sigma_max (float)
rho (int)
p_mean (float)
p_std (float)
sampling_method (str)
num_sampling_steps (int)
reward_channels (List[int])
reward_res_blocks (int)
reward_cond_dim (int)
reward_lstm_dim (int)
burn_in_length (int)
actor_channels (List[int])
actor_res_blocks (int)
actor_lstm_dim (int)
num_epochs (int)
training_steps_per_epoch (int)
batch_size (int)
environment_steps_per_epoch (int)
epsilon_greedy (float)
imagination_horizon (int)
discount_factor (float)
entropy_weight (float)
lambda_returns (float)
learning_rate (float)
adam_epsilon (float)
weight_decay_diffusion (float)
weight_decay_reward (float)
weight_decay_actor (float)
device (str)
log_interval (int)
eval_interval (int)
save_interval (int)
operator_state_dim (int)
operator_action_dim (int)
- preset: str | None = None#
- game: str = 'Breakout-v5'#
- seed: int = 0#
- obs_size: int = 64#
- frameskip: int = 4#
- max_noop: int = 30#
- terminate_on_life_loss: bool = True#
- reward_clip: List[int]#
- num_conditioning_frames: int = 4#
- diffusion_channels: List[int]#
- diffusion_res_blocks: int = 2#
- diffusion_cond_dim: int = 256#
- sigma_data: float = 0.5#
- sigma_min: float = 0.002#
- sigma_max: float = 80.0#
- rho: int = 7#
- p_mean: float = -0.4#
- p_std: float = 1.2#
- sampling_method: str = 'euler'#
- num_sampling_steps: int = 3#
- reward_channels: List[int]#
- reward_res_blocks: int = 2#
- reward_cond_dim: int = 128#
- reward_lstm_dim: int = 512#
- burn_in_length: int = 4#
- actor_channels: List[int]#
- actor_res_blocks: int = 1#
- actor_lstm_dim: int = 512#
- num_epochs: int = 1000#
- training_steps_per_epoch: int = 400#
- batch_size: int = 32#
- environment_steps_per_epoch: int = 100#
- epsilon_greedy: float = 0.01#
- imagination_horizon: int = 15#
- discount_factor: float = 0.985#
- entropy_weight: float = 0.001#
- lambda_returns: float = 0.95#
- learning_rate: float = 0.0001#
- adam_epsilon: float = 1e-08#
- weight_decay_diffusion: float = 0.01#
- weight_decay_reward: float = 0.01#
- weight_decay_actor: float = 0.0#
- device: str#
- log_interval: int = 10#
- eval_interval: int = 50#
- save_interval: int = 100#
- operator_state_dim: int = 32#
- operator_action_dim: int = 4#
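The sigma_*, rho, and p_* fields match the parameter names of the EDM formulation (Karras et al., 2022), which DIAMOND builds on. The sketch below assumes they keep their EDM meaning and shows three pieces:

- a Karras noise schedule for sampling with num_sampling_steps steps;
- log-normal noise levels for training;
- the standard preconditioning scalings.

How DiamondConfig's values are consumed internally is not documented here, and the helper names are hypothetical.

```python
import math
import random


def karras_sigmas(num_steps=3, sigma_min=0.002, sigma_max=80.0, rho=7):
    """Noise levels interpolated in sigma**(1/rho) space, plus a final 0.0."""
    if num_steps < 2:
        return [sigma_max, 0.0]
    hi = sigma_max ** (1.0 / rho)
    lo = sigma_min ** (1.0 / rho)
    sigmas = [(hi + i / (num_steps - 1) * (lo - hi)) ** rho for i in range(num_steps)]
    return sigmas + [0.0]


def sample_training_sigma(p_mean=-0.4, p_std=1.2, rng=random):
    """Training noise level with ln(sigma) ~ N(p_mean, p_std**2)."""
    return math.exp(rng.gauss(p_mean, p_std))


def edm_scalings(sigma, sigma_data=0.5):
    """EDM preconditioning coefficients (c_skip, c_out, c_in)."""
    denom = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / denom
    c_out = sigma * sigma_data / math.sqrt(denom)
    c_in = 1.0 / math.sqrt(denom)
    return c_skip, c_out, c_in
```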
Training entry points#
- world_models.training.train_jepa.main(args, resume_preempt=False)[source]#
Run JEPA training using a nested config dict or JEPAConfig instance.
This entrypoint initializes the distributed context, data pipeline, masking, models, optimizers/schedulers, and checkpointing, then runs the full epoch loop.
- class world_models.training.train_iris.IRISTrainer(game='ALE/Pong-v5', device='cuda', seed=42, config=None)[source]#
Bases:
object
Training loop for IRIS on the Atari 100k benchmark.
- Parameters:
game (str)
device (str)
seed (int)
config (IRISConfig | None)
- preprocess_frame(frame)[source]#
Preprocess frame: resize to 64x64 and normalize.
- Parameters:
frame (ndarray)
- Return type:
ndarray
- collect_experience(num_steps, epsilon=0.01)[source]#
Collect experience from the environment.
- Parameters:
num_steps (int) – Number of steps to collect
epsilon (float) – Random action probability
- Returns:
Mean episode return
- Return type:
float
- train_epoch(epoch)[source]#
Train for one epoch.
- Parameters:
epoch (int) – Current epoch number
- Returns:
Dictionary of metrics
- Return type:
dict
- get_epsilon(epoch)[source]#
Return the exploration epsilon for the given epoch, with decay applied.
- Parameters:
epoch (int)
- Return type:
float
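collect_experience takes epsilon as a random-action probability, and get_epsilon decays it per epoch, but the decay shape is not documented. The sketch below uses a linear decay purely as an assumption. Its start, end, and decay_epochs parameters are hypothetical, not IRISTrainer attributes.

```python
import random


def linear_epsilon(epoch, start=1.0, end=0.01, decay_epochs=100):
    """Hypothetical linear decay from `start` to `end`, then held at `end`."""
    frac = min(epoch / decay_epochs, 1.0)
    return start + frac * (end - start)


def epsilon_greedy(greedy_action, num_actions, epsilon, rng=random):
    """With probability epsilon pick a uniformly random action, else the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(num_actions)
    return greedy_action
```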
- evaluate(num_episodes=100, render=False)[source]#
Evaluate agent performance.
- Parameters:
num_episodes (int) – Number of evaluation episodes
render (bool) – If True, also return video frames and per-step latent vectors
- Returns:
If render is False (default): dict with evaluation metrics. If render is True: tuple (episode_returns_array, videos_list, latents_array).
- Return type:
dict | tuple
- class world_models.training.train_genie.GenieConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=512, tokenizer_decoder_dim=1024, tokenizer_encoder_depth=12, tokenizer_decoder_depth=20, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=1024, action_encoder_depth=20, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#
Bases:
object
Configuration for Genie training.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
sample_temperature (float)
maskgit_steps (int)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 512#
- tokenizer_decoder_dim: int = 1024#
- tokenizer_encoder_depth: int = 12#
- tokenizer_decoder_depth: int = 20#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 1024#
- action_encoder_depth: int = 20#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 4#
- learning_rate: float = 3e-05#
- weight_decay: float = 0.0001#
- warmup_steps: int = 5000#
- max_steps: int = 125000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- sample_temperature: float = 2.0#
- maskgit_steps: int = 25#
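The masking fields line up with the Genie paper, which trains the dynamics model with a Bernoulli masking rate drawn uniformly between 0.5 and 1. The maskgit_steps field suggests MaskGIT-style iterative decoding at sampling time. The sketch below assumes MaskGIT's cosine schedule to decide how many tokens stay masked after each decoding step. Whether GenieTrainer uses exactly this schedule is not stated, and both helper names are hypothetical.

```python
import math
import random


def sample_training_mask(num_tokens, mask_prob_min=0.5, mask_prob_max=1.0, rng=random):
    """Draw a masking rate uniformly in [min, max], then mask each token with it."""
    rate = rng.uniform(mask_prob_min, mask_prob_max)
    return [rng.random() < rate for _ in range(num_tokens)], rate


def maskgit_remaining(num_tokens, step, total_steps=25):
    """Tokens still masked after 0-based `step` under a cosine schedule."""
    ratio = math.cos(math.pi / 2 * (step + 1) / total_steps)
    return max(0, math.floor(num_tokens * ratio))
```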
- class world_models.training.train_genie.VideoDataset(video_paths, num_frames=16, image_size=64)[source]#
Bases:
Dataset
Dataset for video data.
- Parameters:
video_paths (list)
num_frames (int)
image_size (int)
- class world_models.training.train_genie.GenieTrainer(model, config, device=None)[source]#
Bases:
object
Trainer for the Genie model.
- Parameters:
model (Module)
config (GenieConfig)
device (device | None)
- train_step(batch)[source]#
Single training step.
- Parameters:
batch (Tensor) – (B, C, T, H, W) video batch
- Returns:
Dictionary of losses
- Return type:
Dict[str, Tensor]
- validate(val_batch)[source]#
Validation step.
- Parameters:
val_batch (Tensor) – (B, C, T, H, W) validation video batch
- Returns:
Dictionary of validation metrics
- Return type:
Dict[str, Tensor]
- train(train_dataloader, val_dataloader=None, num_steps=None, log_interval=100, val_interval=1000)[source]#
Full training loop.
- Parameters:
train_dataloader (DataLoader) – Training data loader
val_dataloader (DataLoader | None) – Validation data loader (optional)
num_steps (int | None) – Number of training steps (uses config.max_steps if None)
log_interval (int) – Logging frequency
val_interval (int) – Validation frequency
- world_models.training.train_genie.create_genie_trainer(config=None, device=None)[source]#
Factory function that creates a Genie trainer and its model.
- Parameters:
config (GenieConfig | None)
device (device | None)
- Return type:
Tuple[GenieTrainer, Module]
- world_models.training.train_planet.train(memory, rssm, optimizer, device, N=32, H=50, beta=1.0, grads=False)[source]#
Training implementation following Learning Latent Dynamics for Planning from Pixels (arXiv:1811.04551):
- (a) The standard variational bound,
using only single-step predictions.
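The standard variational bound sums one term per step. Each term is the reconstruction log-likelihood ln p(o_t | s_t), minus the KL divergence between two distributions: the filtering posterior q(s_t | o_1:t, a_1:t-1) and the one-step prior p(s_t | s_t-1, a_t-1). The sketch below uses only the standard library and diagonal Gaussians. Reading beta as a weight on the KL term is an assumption based on the parameter name. Free nats, which the PlaNet paper uses, are omitted.

```python
import math


def gaussian_kl(mu_q, std_q, mu_p, std_p):
    """KL(q || p) for diagonal Gaussians, summed over dimensions."""
    return sum(
        math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
        for mq, sq, mp, sp in zip(mu_q, std_q, mu_p, std_p)
    )


def single_step_bound(recon_log_probs, posteriors, priors, beta=1.0):
    """Sum over t of log p(o_t | s_t) - beta * KL(q_t || p_t).

    posteriors and priors are per-step (mean_list, std_list) pairs.
    """
    total = 0.0
    for log_p, (mq, sq), (mp, sp) in zip(recon_log_probs, posteriors, priors):
        total += log_p - beta * gaussian_kl(mq, sq, mp, sp)
    return total
```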
- world_models.training.train_planet.main()[source]#
Example PlaNet/RSSM training script with rollout collection and evaluation.
Builds environment/model/policy objects, iteratively trains on replayed episodes, and periodically saves videos and checkpoints.
- world_models.training.train_rssm.train_rssm(memory, model, optimizer, record_grads=True)[source]#
Train an RSSM on replayed trajectories for one optimization phase.
Samples batches from memory, computes reconstruction and KL objectives across rollout steps, and returns aggregated loss metrics.
- world_models.training.train_rssm.evaluate(memory, model, path, eps)[source]#
Run one RSSM reconstruction/prediction evaluation and save visual outputs.
Decodes priors/posteriors for a sampled sequence and writes frame grids for qualitative inspection.
- world_models.training.train_rssm.main()[source]#
Standalone training loop for RSSM with generated replay fallback support.
Initializes environment/policy/memory, trains over episodes, logs metrics, and periodically evaluates and checkpoints the model.
- class world_models.training.train_diamond.DiamondAgent(config)[source]#
Bases:
object
DIAMOND (DIffusion As a Model Of eNvironment Dreams): an RL agent trained entirely within a diffusion world model.
- Parameters:
config (DiamondConfig)
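Three DiamondConfig fields point to λ-returns computed over imagined rollouts, as in the DIAMOND paper's actor-critic: imagination_horizon, discount_factor, and lambda_returns. The sketch below implements the standard backward recursion G_t = r_t + γ(1 - d_t)[(1 - λ)V(s_t+1) + λG_t+1]. It is bootstrapped from the value of the state after the last imagined step. How DiamondAgent computes its targets is not shown here, and the function is hypothetical.

```python
def lambda_returns(rewards, values, dones, bootstrap, gamma=0.985, lam=0.95):
    """Backward recursion for lambda-returns over one imagined trajectory.

    rewards, values, dones: length-H lists for steps t = 0..H-1.
    bootstrap: value estimate V(s_H) for the state after the last step.
    """
    horizon = len(rewards)
    returns = [0.0] * horizon
    next_return = bootstrap
    for t in reversed(range(horizon)):
        next_value = bootstrap if t == horizon - 1 else values[t + 1]
        cont = gamma * (1.0 - dones[t])  # no bootstrapping past a terminal
        returns[t] = rewards[t] + cont * ((1.0 - lam) * next_value + lam * next_return)
        next_return = returns[t]
    return returns
```

Setting lam=1 recovers discounted Monte Carlo returns. Setting lam=0 recovers one-step TD targets.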
- evaluate(num_episodes=1)[source]#
Evaluate the agent.
- Parameters:
num_episodes (int)
- Return type:
float
- save_checkpoint(path=None)[source]#
Save model checkpoint.
- Parameters:
path (str | PathLike | None) – Optional destination for the checkpoint. Three cases apply:
- If path is None, the checkpoint is written to checkpoints/diamond/checkpoint.pt (the legacy default).
- If path is a bare filename, it is written to checkpoints/diamond/<filename>.
- If path contains a directory component (absolute or relative), it is used as-is.
- load_checkpoint(path=None)[source]#
Load model checkpoint.
- Parameters:
path (str | None) – Optional path to the checkpoint. If None, the default checkpoints/diamond/checkpoint.pt is loaded. If a bare filename is provided, checkpoints/diamond/<filename> is tried. A path with directory components is used directly.
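The path rules shared by save_checkpoint and load_checkpoint can be summarized as a small resolver. This sketch covers only the documented behavior. resolve_checkpoint_path and DEFAULT_DIR are hypothetical names, not DiamondAgent methods.

```python
from pathlib import Path

DEFAULT_DIR = Path("checkpoints") / "diamond"


def resolve_checkpoint_path(path=None):
    """Map a user-supplied path onto the documented checkpoint locations."""
    if path is None:
        # Legacy default location.
        return DEFAULT_DIR / "checkpoint.pt"
    p = Path(path)
    if len(p.parts) == 1:
        # Bare filename such as "run1.pt". pathlib also reduces "./run1.pt"
        # to a single part, so that spelling lands here too.
        return DEFAULT_DIR / p.name
    # Anything with a directory component (absolute or relative) is used as-is.
    return p
```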
- world_models.training.train_diamond.train_diamond(game, seed=0, preset=None, device='cpu')[source]#
Train DIAMOND on a specific game.
- Parameters:
game (str)
seed (int)
preset (str | None)
device (str)
- class world_models.training.rl_harness.ActorCritic(obs_shape, action_dim, hidden_dim=256)[source]#
Bases:
Module
Simple actor-critic network for the RL harness.
- Parameters:
obs_shape (tuple)
action_dim (int)
hidden_dim (int)
- class world_models.training.rl_harness.PPOTrainer(vec_env, device='cpu', lr=0.0003, gamma=0.99, gae_lambda=0.95, clip_ratio=0.2, num_epochs=10, batch_size=64, max_grad_norm=0.5, entropy_coeff=0.01, value_coeff=0.5)[source]#
Bases:
object
Simple PPO trainer for testing vectorized environments.
- Parameters:
vec_env (TorchVectorizedEnv)
device (str)
lr (float)
gamma (float)
gae_lambda (float)
clip_ratio (float)
num_epochs (int)
batch_size (int)
max_grad_norm (float)
entropy_coeff (float)
value_coeff (float)
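The clip_ratio, value_coeff, and entropy_coeff parameters correspond to the terms of the usual PPO objective (Schulman et al., 2017). The per-sample sketch below assumes PPOTrainer combines them in the standard way. A real trainer would average this loss over minibatches of tensors. ppo_loss is a hypothetical helper.

```python
import math


def ppo_loss(logp_new, logp_old, advantage, value, value_target, entropy,
             clip_ratio=0.2, value_coeff=0.5, entropy_coeff=0.01):
    """Per-sample PPO loss: clipped policy term + value MSE - entropy bonus."""
    ratio = math.exp(logp_new - logp_old)
    clipped = min(max(ratio, 1.0 - clip_ratio), 1.0 + clip_ratio)
    # Pessimistic choice: the smaller of the unclipped and clipped surrogates.
    surrogate = min(ratio * advantage, clipped * advantage)
    value_loss = (value - value_target) ** 2
    return -surrogate + value_coeff * value_loss - entropy_coeff * entropy
```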
- collect_trajectories(num_steps)[source]#
Collect trajectories using the vectorized environment.
- Parameters:
num_steps (int)
- Return type:
Dict[str, Tensor]
- compute_gae(rewards, values, dones)[source]#
Compute Generalized Advantage Estimation.
- Parameters:
rewards (Tensor)
values (Tensor)
dones (Tensor)
- Return type:
Tensor
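compute_gae implements Generalized Advantage Estimation:

- δ_t = r_t + γV(s_t+1)(1 - d_t) - V(s_t)
- A_t = δ_t + γλ(1 - d_t)A_t+1

The list-based sketch below uses PPOTrainer's default gamma and gae_lambda. How the real method bootstraps past the final step is not documented, and its signature has no bootstrap argument. The last_value parameter here is therefore an assumption.

```python
def compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=0.95, last_value=0.0):
    """Return per-step advantages for one trajectory (lists of floats)."""
    horizon = len(rewards)
    advantages = [0.0] * horizon
    running = 0.0
    for t in reversed(range(horizon)):
        next_value = last_value if t == horizon - 1 else values[t + 1]
        not_done = 1.0 - dones[t]  # cut bootstrapping at episode ends
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        running = delta + gamma * gae_lambda * not_done * running
        advantages[t] = running
    return advantages
```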
Memory, controllers, and inference operators#
- class world_models.memory.dreamer_memory.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]#
Bases:
object
Fixed-size replay buffer for Dreamer that stores image observations and transitions.
Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world-model training.
- Key Features:
Ring buffer with fixed capacity (FIFO eviction when full)
Stores raw uint8 images to save memory
Samples sequences (not single transitions) for temporal modeling
Validates sampled sequences don’t span episode boundaries
- Memory Layout:
observations: (capacity, C, H, W) uint8 images
actions: (capacity, action_dim) float32
rewards: (capacity,) float32
terminals: (capacity,) float32 (1.0 = terminal, 0.0 = continue)
- Sampling Process:
Random start index (avoiding episode boundaries)
Collect sequence of length seq_len with wraparound
Validate no terminal in middle of sequence
Return batch of sequences
- Usage with Dreamer:
buffer = ReplayBuffer(
    size=100000,            # Max transitions to store
    obs_shape=(3, 64, 64),  # RGB images
    action_size=6,          # Continuous action dim
    seq_len=50,             # Sequence length for training
    batch_size=50,          # Parallel sequences per batch
)
# Add transitions during interaction
buffer.add(obs, action, reward, done)
# Sample batch for world model training
obs_batch, action_batch, reward_batch, term_batch = buffer.sample()
# Shapes: (seq_len, batch, C, H, W), (seq_len, batch, action_dim), etc.
- Memory Efficiency:
Uses uint8 for images (1 byte per pixel vs 4 for float32)
Sequences share observations (overlapping windows)
Configurable capacity based on available system memory
Note
The buffer stores observations as {"image": …} dicts but returns just the image arrays for training efficiency.
- Parameters:
size (int)
obs_shape (Tuple[int, ...])
action_size (int)
seq_len (int)
batch_size (int)
- add(obs, ac, rew, done)[source]#
Add a transition to the buffer.
- Parameters:
obs (dict) – Observation dict with ‘image’ key containing the observation
ac (ndarray) – Action taken, shape (action_size,)
rew (float) – Reward received, scalar
done (float) – Terminal flag, 1.0 if episode ended, 0.0 otherwise
- Return type:
None
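The ring-buffer sampling rules above (wraparound, no terminal inside a sequence) can be sketched without NumPy. This is an illustration, not the ReplayBuffer implementation, and it differs in three ways:

- RingSequenceBuffer is hypothetical.
- It stores Python tuples instead of preallocated arrays.
- It also rejects windows that straddle the write cursor, where the newest and oldest entries meet.

```python
import random


class RingSequenceBuffer:
    """Hypothetical fixed-capacity buffer that samples contiguous sequences."""

    def __init__(self, capacity, seq_len):
        self.capacity = capacity
        self.seq_len = seq_len
        self.data = [None] * capacity
        self.idx = 0  # next write position (the oldest entry once full)
        self.full = False

    def add(self, obs, action, reward, done):
        self.data[self.idx] = (obs, action, reward, float(done))
        self.idx = (self.idx + 1) % self.capacity  # FIFO overwrite when full
        self.full = self.full or self.idx == 0

    def __len__(self):
        return self.capacity if self.full else self.idx

    def _window(self, start):
        idxs = [(start + i) % self.capacity for i in range(self.seq_len)]
        # Once full, the write cursor marks where newest and oldest data meet;
        # it may start a window but must not appear inside one.
        if self.full and self.idx in idxs[1:]:
            return None
        # A terminal is allowed only at the last step of the window.
        if any(self.data[i][3] for i in idxs[:-1]):
            return None
        return idxs

    def sample_sequence(self, rng=random, max_tries=1000):
        n = len(self)
        if n < self.seq_len:
            raise ValueError("not enough transitions stored")
        # Without wraparound, starts must leave room for a full window.
        num_starts = n if self.full else n - self.seq_len + 1
        for _ in range(max_tries):
            idxs = self._window(rng.randrange(num_starts))
            if idxs is not None:
                return [self.data[i] for i in idxs]
        raise RuntimeError("no valid sequence found")
```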
- class world_models.memory.dreamer_memory.Memory(capacity=10000)[source]#
Bases:
object
Simple deque-based memory for storing transitions.
Used by PlaNet for online planning. Stores recent transitions and provides random sampling for policy updates.
- Parameters:
capacity (int) – Maximum number of transitions to store
- Usage:
import random
memory = Memory(capacity=10000)
memory.append(obs, action, reward, done, info)
batch = random.sample(memory, k=32)
- class world_models.memory.dreamer_memory.Episode(observation, action=None, reward=None, terminal=None, info=None)[source]#
Bases:
object
Stores a single episode for PlaNet’s imagination and planning.
An episode is a sequence of (observation, action, reward) tuples collected during environment interaction. Episodes are used for computing returns and training value functions.
- Parameters:
observation – Initial observation
action – First action (optional)
reward – Initial reward (optional)
terminal – Initial terminal flag (optional)
info – Additional info dict (optional)
- Usage:
episode = Episode(obs, info=info)
episode.append(action, obs, reward, done, info)
episodes = [episode for _ in range(num_episodes)]
# Use with Planet agent for planning
imag_state, imag_reward, imag_action = planet.imagine(episodes)
- class world_models.memory.planet_memory.Episode(postprocess_fn=None)[source]#
Bases:
object
Records the agent’s interaction with the environment for a single episode.
Stores observations, actions, rewards, and terminal flags during a single trajectory. At termination, converts all lists to numpy arrays for efficient batch processing.
- x#
Observations collected during the episode.
- Type:
list or np.ndarray
- u#
Actions taken.
- Type:
list or np.ndarray
- r#
Rewards received.
- Type:
list or np.ndarray
- t#
Terminal flags (0.0 = continue, 1.0 = terminal).
- Type:
list or np.ndarray
- info#
Additional episode metadata.
- Type:
dict
- Parameters:
postprocess_fn (callable, optional) – Function to apply to observations before storing (e.g., normalization). Default: identity function.
Example
>>> episode = Episode()
>>> episode.append(obs, action, reward, False)
>>> episode.append(obs, action, reward, True)
>>> episode.terminate(final_obs)
>>> print(episode.x.shape)  # Now a numpy array
- property size#
- class world_models.memory.planet_memory.Memory(size=None)[source]#
Bases:
deque
Episode-based replay memory for PlaNet/RSSM training.
Stores episodes as variable-length trajectories and supports sampling sub-sequences for training. Implements a ring-buffer style eviction when capacity is reached.
- Features:
Stores complete episodes as lists of transitions
Samples contiguous sub-sequences for sequence models
Supports time-major formatting (time-first) for RNN input
Memory usage estimation to prevent OOM errors
- Parameters:
size (int, optional) – Maximum number of episodes to store. If None, deque grows without limit (useful for unpickling).
- episodes#
Collection of Episode objects.
- Type:
deque
- eps_lengths#
Length of each episode.
- Type:
deque
- size#
Total number of transitions across all episodes.
- Type:
property
Example
>>> memory = Memory(size=100)
>>> memory.append([episode1, episode2])
>>> obs, actions, rewards, terminals, lengths = memory.sample(batch_size=32, tracelen=50)
- property size#
- sample(batch_size, tracelen=1, time_first=False)[source]#
Sample random sub-sequences from stored episodes.
Randomly selects episodes and starting positions to create batches of contiguous sequences for training sequence models.
- Parameters:
batch_size (int) – Number of sequences to sample.
tracelen (int) – Length of each sequence (default: 1).
time_first (bool) – If True, returns tensors with time dimension first (T, B, …) instead of batch first (B, T, …).
- Returns:
- (observations, actions, rewards, terminals, lengths)
observations: (batch, tracelen+1, *obs_shape) or (tracelen+1, batch, …)
actions: (batch, tracelen, action_dim) or (tracelen, batch, …)
rewards: (batch, tracelen) or (tracelen, batch)
terminals: (batch, tracelen) or (tracelen, batch)
lengths: (batch,) original episode lengths for each sample
- Return type:
tuple
- Raises:
ValueError – If memory is empty or no episodes meet minimum length.
MemoryError – If estimated memory usage exceeds 200 MiB threshold.
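The sampling contract of sample can be illustrated with plain lists: tracelen + 1 observations per window, tracelen actions, and an optional time-first layout. The sketch below assumes each episode is a dict of per-step lists. sample_subsequences and the dict layout are hypothetical, and the memory-usage check is omitted.

```python
import random


def sample_subsequences(episodes, batch_size, tracelen=1, time_first=False, rng=random):
    """Sample contiguous windows from variable-length episodes.

    Each episode is a dict with 'obs' (length L + 1) and 'act', 'rew',
    'term' (length L). Returns (obs, act, rew, term, lengths).
    """
    eligible = [ep for ep in episodes if len(ep["act"]) >= tracelen]
    if not eligible:
        raise ValueError("no episode has at least `tracelen` transitions")
    obs, act, rew, term, lengths = [], [], [], [], []
    for _ in range(batch_size):
        ep = rng.choice(eligible)
        n = len(ep["act"])
        start = rng.randrange(n - tracelen + 1)
        obs.append(ep["obs"][start:start + tracelen + 1])  # one extra observation
        act.append(ep["act"][start:start + tracelen])
        rew.append(ep["rew"][start:start + tracelen])
        term.append(ep["term"][start:start + tracelen])
        lengths.append(n)
    if time_first:
        # Transpose (batch, time) -> (time, batch).
        obs, act, rew, term = ([list(col) for col in zip(*x)] for x in (obs, act, rew, term))
    return obs, act, rew, term, lengths
```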
- class world_models.memory.iris_memory.IRISReplayBuffer(size, obs_shape, action_size, seq_len=20, batch_size=64)[source]#
Bases:
object
Replay buffer for IRIS (Imagination with auto-Regression over an Inner Speech) training.
Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world model training.
- Features:
Ring buffer with fixed capacity (FIFO eviction when full)
Stores uint8 images for memory efficiency
Samples sequences with validation to avoid episode boundaries
Supports sequence sampling for temporal learning
- Memory Layout:
observations: (capacity, C, H, W) uint8
actions: (capacity, action_size) float32
rewards: (capacity,) float32
terminals: (capacity,) float32
- Parameters:
size (int) – Maximum number of transitions to store.
obs_shape (tuple) – Shape of observations as (C, H, W).
action_size (int) – Dimension of actions.
seq_len (int) – Length of sequences to sample (default: 20).
batch_size (int) – Number of sequences per batch (default: 64).
- size#
Buffer capacity.
- Type:
int
- obs_shape#
Observation shape.
- Type:
tuple
- action_size#
Action dimension.
- Type:
int
- seq_len#
Sequence length.
- Type:
int
- batch_size#
Batch size.
- Type:
int
- steps#
Total transitions added.
- Type:
int
- episodes#
Number of episode terminations observed.
- Type:
int
- add(obs, action, reward, terminal)[source]#
Add a transition to the buffer.
- Parameters:
obs (ndarray) – Observation array with shape (C, H, W).
action (ndarray) – Action array with shape (action_size,).
reward (float) – Scalar reward value.
terminal (bool) – Boolean indicating if episode terminated.
- sample_sequence(seq_len=None)[source]#
Sample a batch of sequences for world model training.
- Returns:
observations: (batch_size, seq_len+1, C, H, W); actions: (batch_size, seq_len, action_size); rewards: (batch_size, seq_len); terminals: (batch_size, seq_len)
- Return type:
tuple
- Parameters:
seq_len (int | None)
- sample_single()[source]#
Sample a single transition for online updates.
- Return type:
Tuple[ndarray, ndarray, float, float]
- property buffer_capacity#
Returns the total capacity of the buffer.
- class world_models.memory.iris_memory.IRISOnPolicyBuffer(max_steps=1000)[source]#
Bases:
object
On-policy buffer for collecting trajectories during environment interaction.
Used to store the current episode data before adding to the main replay buffer. Unlike the main replay buffer, this collects trajectories in a list-based structure that’s cleared after each episode.
- Useful for:
Collecting complete episode trajectories
Storing data before batch processing
Temporary storage during environment interaction
- Parameters:
max_steps (int) – Maximum number of steps to store (default: 1000).
- max_steps#
Maximum buffer capacity.
- Type:
int
- observations#
List of observations.
- Type:
list
- actions#
List of actions.
- Type:
list
- rewards#
List of rewards.
- Type:
list
- terminals#
List of terminal flags.
- Type:
list
- class world_models.controller.rssm_policy.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]#
Bases:
object
Model-predictive controller using the Cross-Entropy Method (CEM) with an RSSM.
Plans actions by optimizing a sequence of future actions in the RSSM’s latent space, using CEM to refine action sequences based on predicted returns.
- Algorithm:
Initialize Gaussian distribution over action sequences
Sample N candidate action sequences
Rollout each sequence in RSSM latent space
Score by predicted cumulative rewards
Keep top K candidates, fit Gaussian to them
Repeat for T iterations
Execute first action from best sequence
- Why latent space planning?
Images are high-dimensional; latent states are compact
Enables thousands of rollouts in parallel
Dynamics model is more accurate in latent space
- Parameters:
model – RSSM instance for latent dynamics
planning_horizon – Number of future steps to plan (H)
num_candidates – Number of action sequences to sample (N)
num_iterations – CEM refinement iterations (T)
top_candidates – Number of best candidates to keep (K)
device – torch device
- Usage with Planet agent:
policy = RSSMPolicy(
    model=rssm,
    planning_horizon=12,
    num_candidates=1000,
    num_iterations=8,
    top_candidates=100,
    device='cuda',
)
policy.reset()
action = policy.poll(observation)  # (1, action_dim)
# For continuous control:
next_obs, reward, done, info = env.step(action)
- Comparison with Dreamer:
RSSMPolicy: online planning; chooses actions by optimization at every step
DreamerActor: trains an actor network to predict actions from states
Dreamer is more sample-efficient for complex tasks; CEM is more flexible
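The algorithm steps above can be sketched with the standard library. score_fn is a hypothetical stand-in for rolling a candidate sequence through the RSSM and summing predicted rewards. Following the PlaNet paper, the sketch returns the first action of the final mean rather than the first action of the single best candidate. Which of the two RSSMPolicy executes is not confirmed here.

```python
import random
import statistics


def cem_plan(score_fn, action_dim, horizon=12, num_candidates=1000,
             num_iterations=8, top_candidates=100, rng=None):
    """Cross-Entropy Method over action sequences; returns the first mean action.

    score_fn maps a sequence (horizon x action_dim nested lists) to a
    predicted return.
    """
    rng = rng or random.Random(0)
    mean = [[0.0] * action_dim for _ in range(horizon)]
    std = [[1.0] * action_dim for _ in range(horizon)]
    for _ in range(num_iterations):
        # Sample N candidate sequences from the current diagonal Gaussian.
        candidates = [
            [[rng.gauss(mean[t][d], std[t][d]) for d in range(action_dim)]
             for t in range(horizon)]
            for _ in range(num_candidates)
        ]
        # Keep the K highest-scoring sequences and refit the Gaussian to them.
        elite = sorted(candidates, key=score_fn, reverse=True)[:top_candidates]
        for t in range(horizon):
            for d in range(action_dim):
                values = [seq[t][d] for seq in elite]
                mean[t][d] = statistics.fmean(values)
                std[t][d] = statistics.pstdev(values) + 1e-6
    return mean[0]
```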
- class world_models.controller.iris_policy.IRISActor(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Actor network for IRIS (Imagination with auto-Regression over an Inner Speech) policy.
Takes reconstructed frames as input and outputs action logits for policy control. Uses a CNN feature extractor followed by an LSTM for temporal processing. Supports a burn-in mechanism for initializing the hidden state with context frames.
- Architecture:
CNN: Extracts features from input frames (3x64x64 -> 512)
LSTM: Processes temporal sequences with configurable layers
Linear: Maps hidden states to action logits
- Parameters:
action_size (int) – Number of discrete actions.
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
- action_size#
Number of discrete actions.
- Type:
int
- hidden_size#
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
- forward(frames, hidden_state=None, burn_in_frames=None)[source]#
Forward pass through actor.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W) or (B, C, H, W)
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple for LSTM state
burn_in_frames (Tensor | None) – Frames to use for initializing hidden state
- Returns:
action_logits: Action logits (B, T, action_size) or (B, action_size); hidden_state: Updated (h, c) tuple
- Return type:
tuple
Initialize LSTM hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- get_action(frame, temperature=1.0, deterministic=False)[source]#
Get action from a single frame.
- Parameters:
frame (Tensor) – Single frame (B, C, H, W)
temperature (float) – Softmax temperature (higher = more random)
deterministic (bool) – If True, return argmax; else sample
- Returns:
action: Selected action indices with shape (B,)
- Return type:
Tensor
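The temperature and deterministic arguments follow the usual softmax-sampling pattern. Logits are divided by the temperature before the softmax, and deterministic mode takes the argmax. Below is a list-based sketch of that logic. sample_action is hypothetical, and the real method operates on batched tensors.

```python
import math
import random


def sample_action(logits, temperature=1.0, deterministic=False, rng=random):
    """Pick an action index from unnormalized logits."""
    if deterministic:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - m) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]
```

Higher temperatures flatten the distribution toward uniform. Temperatures near zero approach the argmax.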
- class world_models.controller.iris_policy.IRISCritic(hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Critic network for IRIS value estimation.
Estimates the value function for given frame sequences. Shares the CNN feature extractor and LSTM backbone with the actor for efficiency, but has a separate value head for estimating expected cumulative rewards.
- Architecture:
CNN: Shared feature extractor with actor (3x64x64 -> 512)
LSTM: Temporal processing with same architecture as actor
Linear: Maps hidden states to scalar values
- Parameters:
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
- hidden_size#
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
- Returns:
values: Value estimates with shape (B, T); hidden_state: Updated LSTM hidden state (h, c) tuple.
- Return type:
tuple
- Parameters:
hidden_size (int)
num_layers (int)
frame_shape (Tuple[int, int, int])
- forward(frames, hidden_state=None)[source]#
Forward pass through critic.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W)
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple
- Returns:
values: Value estimates (B, T); hidden_state: Updated (h, c) tuple
- Return type:
tuple
Initialize LSTM hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- class world_models.controller.iris_policy.CNNFeatureExtractor(frame_shape=(3, 64, 64), output_size=512)[source]#
Bases:
Module
CNN feature extractor shared between actor and critic networks.
Processes input frames through a series of convolutional layers to produce fixed-size feature vectors. Architecture: four stride-2 Conv2d + ReLU stages (3 -> 32 -> 64 -> 128 -> 256 channels), followed by a linear projection to output_size.
- Architecture:
Conv layers: 32 -> 64 -> 128 -> 256 channels
Each conv has stride=2 for spatial downsampling
Final linear layer projects to desired output dimension
- Parameters:
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
output_size (int) – Size of output feature vector (default: 512).
- frame_shape#
Input frame shape.
- Type:
tuple
- output_size#
Output feature dimension.
- Type:
int
- Returns:
features: Feature vectors with shape (B, output_size).
- Return type:
Tensor
- Parameters:
frame_shape (Tuple[int, int, int])
output_size (int)
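The docstring fixes the channel progression and stride but not the kernel size or padding. Those two values determine how many features reach the final linear layer. The sketch below applies the standard Conv2d output-size formula. kernel=4 is an assumption, shown with and without padding=1, and the helper names are hypothetical.

```python
def conv_out(size, kernel, stride, padding=0):
    """Spatial size after one Conv2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(frame_hw=64, channels=(32, 64, 128, 256), kernel=4, stride=2, padding=0):
    """Size of the flattened conv output fed to the final linear layer."""
    size = frame_hw
    for _ in channels:
        size = conv_out(size, kernel, stride, padding)
    return channels[-1] * size * size
```

For a 64x64 frame, the unpadded variant shrinks the feature map to 64 -> 31 -> 14 -> 6 -> 2. With padding=1 it halves cleanly to 4x4 instead. The linear layer therefore sees 1024 inputs in the first case and 4096 in the second.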
- class world_models.controller.iris_policy.IRISPolicy(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Combined policy module for IRIS (Imagination with auto-Regression over an Inner Speech).
Provides a unified interface for actor-only or actor-critic policies. Used in the IRIS algorithm where the actor generates actions from reconstructed frames and the critic estimates value functions for training.
- Parameters:
action_size (int) – Number of discrete actions.
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
- hidden_size#
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
Example
>>> policy = IRISPolicy(
...     action_size=18,
...     hidden_size=512,
...     num_layers=4,
...     frame_shape=(3, 64, 64),
... )
>>> action = policy.act(frame, temperature=1.0, deterministic=False)
- forward(frames)[source]#
Get action logits from frames.
- Parameters:
frames (Tensor)
- Return type:
Tensor
- act(frame, temperature=1.0, deterministic=False)[source]#
Sample action from policy.
- Parameters:
frame (Tensor)
temperature (float)
deterministic (bool)
- Return type:
Tensor
Initialize hidden state.
- Parameters:
batch_size (int)
device (device)
- class world_models.controller.rollout_generator.RolloutGenerator(env, device, policy=None, max_episode_steps=None, episode_gen=None, name=None, enable_streaming_video=False, streaming_video_path=None, streaming_video_fps=20, streaming_video_format='mp4')[source]#
Bases:
object
Rollout generator class.
- class world_models.inference.operators.OperatorABC[source]
Bases:
ABC
Abstract base class for operators that preprocess inputs for inference pipelines.
- abstractmethod process(inputs)[source]
Process raw inputs into standardized tensor format for model consumption.
- Parameters:
inputs (Any) – Raw input data (dict, tensor, or other formats)
- Returns:
Dict of processed tensors ready for model input
- Return type:
Dict[str, Tensor]
- class world_models.inference.operators.DreamerOperator(image_size=64, action_dim=6)[source]
Bases:
OperatorABC
Operator for Dreamer model preprocessing: normalizes observations and encodes actions.
- Parameters:
image_size (int)
action_dim (int)
- process(inputs)[source]
Process Dreamer inputs: image observation and action.
Expected inputs: {'image': PIL.Image or tensor, 'action': tensor or list}
- Parameters:
inputs (Dict[str, Any])
- Return type:
Dict[str, Tensor]
- class world_models.inference.operators.JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)[source]
Bases:
OperatorABC
Operator for JEPA model preprocessing: handles image/video masking and patch processing.
- Parameters:
image_size (int)
patch_size (int)
mask_ratio (float)
- process(inputs)[source]
Process JEPA inputs: images with masking.
Expected inputs: {'images': list of PIL Images or tensors, 'mask': optional tensor}
- Parameters:
inputs (Dict[str, Any])
- Return type:
Dict[str, Tensor]
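With the defaults, a 224x224 image split into 16x16 patches gives a 14x14 grid of 196 patches. A mask_ratio of 0.75 masks 147 of them. The sketch below reproduces that arithmetic using uniform random patch masking. This masking scheme is an assumption: I-JEPA itself samples contiguous target blocks, and JEPAOperator's actual strategy is not documented here. random_patch_mask is hypothetical.

```python
import random


def random_patch_mask(image_size=224, patch_size=16, mask_ratio=0.75, rng=random):
    """Return sorted (visible, masked) patch indices for one image."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    num_patches = (image_size // patch_size) ** 2
    num_masked = int(num_patches * mask_ratio)
    order = list(range(num_patches))
    rng.shuffle(order)  # uniform random choice of which patches to hide
    return sorted(order[num_masked:]), sorted(order[:num_masked])
```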
- class world_models.inference.operators.IrisOperator(seq_length=512, vocab_size=32000)[source]
Bases:
OperatorABC
Operator for Iris transformer model: formats sequences and embeddings.
- Parameters:
seq_length (int)
vocab_size (int)
- process(inputs)[source]
Process Iris inputs: token sequences and optional embeddings.
Expected inputs: {‘tokens’: list of ints or tensor, ‘embeddings’: optional tensor}
- Parameters:
inputs (Dict[str, Any])
- Return type:
Dict[str, Tensor]
- class world_models.inference.operators.PlaNetOperator(state_dim=32, action_dim=4)[source]
Bases:
OperatorABC
Operator for PlaNet model preprocessing: encodes environment states and transitions.
- Parameters:
state_dim (int)
action_dim (int)
- process(inputs)[source]
Process PlaNet inputs: state observations and actions.
Expected inputs: {‘obs’: tensor or image, ‘action’: tensor, ‘reward’: float, ‘done’: bool}
- Parameters:
inputs (Dict[str, Any])
- Return type:
Dict[str, Tensor]
- world_models.inference.operators.get_operator(name, **kwargs)[source]
Factory function to get inference operators by name.
- Parameters:
name (str) – One of ‘dreamer’, ‘jepa’, ‘iris’, ‘planet’
**kwargs – Operator-specific configuration
- Returns:
Configured OperatorABC instance
Example
>>> op = get_operator('dreamer', image_size=64, action_dim=6)
>>> processed = op.process({'image': image, 'action': action})
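The name-based factory pattern that get_operator exposes can be sketched as below. This is a hypothetical, self-contained illustration: OperatorBase, ToyDreamerOperator, the registry dict, the case-insensitive lookup and the [0, 1] pixel scaling are all assumptions rather than TorchWM's implementation, and plain lists stand in for tensors so the sketch runs without PyTorch.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class OperatorBase(ABC):
    """Stand-in for OperatorABC: subclasses must implement process()."""

    @abstractmethod
    def process(self, inputs: Any) -> Dict[str, Any]:
        ...


class ToyDreamerOperator(OperatorBase):
    """Hypothetical operator: scales uint8 pixels to [0, 1] and checks the action size."""

    def __init__(self, image_size: int = 64, action_dim: int = 6):
        self.image_size = image_size
        self.action_dim = action_dim

    def process(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        pixels = [p / 255.0 for p in inputs["image"]]
        action = list(inputs["action"])
        if len(action) != self.action_dim:
            raise ValueError(f"expected {self.action_dim} action dims, got {len(action)}")
        return {"image": pixels, "action": action}


# Hypothetical registry mapping lower-case names to operator classes.
_REGISTRY = {"dreamer": ToyDreamerOperator}


def get_toy_operator(name: str, **kwargs: Any) -> OperatorBase:
    """Look up an operator class by name and construct it with operator-specific kwargs."""
    try:
        cls = _REGISTRY[name.lower()]
    except KeyError:
        raise ValueError(f"unknown operator {name!r}; choose from {sorted(_REGISTRY)}") from None
    return cls(**kwargs)


op = get_toy_operator("Dreamer", image_size=64, action_dim=2)
out = op.process({"image": [0, 255], "action": [0.5, -0.5]})
print(out)  # {'image': [0.0, 1.0], 'action': [0.5, -0.5]}
```

Keeping the registry separate from the classes lets new operators be added without touching the lookup function.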
- class world_models.inference.operators.base.OperatorABC[source]#
Bases:
ABC
Abstract base class for operators that preprocess inputs for inference pipelines.
- class world_models.inference.operators.dreamer_operator.DreamerOperator(image_size=64, action_dim=6)[source]#
Bases:
OperatorABC
Operator for Dreamer model preprocessing: normalizes observations and encodes actions.
- Parameters:
image_size (int)
action_dim (int)
- class world_models.inference.operators.planet_operator.PlaNetOperator(state_dim=32, action_dim=4)[source]#
Bases:
OperatorABC
Operator for PlaNet model preprocessing: encodes environment states and transitions.
- Parameters:
state_dim (int)
action_dim (int)
- class world_models.inference.operators.iris_operator.IrisOperator(seq_length=512, vocab_size=32000)[source]#
Bases:
OperatorABC
Operator for Iris transformer model: formats sequences and embeddings.
- Parameters:
seq_length (int)
vocab_size (int)
- class world_models.inference.operators.jepa_operator.JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)[source]#
Bases:
OperatorABC
Operator for JEPA model preprocessing: handles image/video masking and patch processing.
- Parameters:
image_size (int)
patch_size (int)
mask_ratio (float)
Datasets, environments, and transforms#
Environment adapters#
The environment APIs below mirror the dedicated environment guide pages: DMC, Gym/Gymnasium, Atari/ALE, MuJoCo, Unity ML-Agents, and vectorization utilities. DIAMOND-style Atari support is intentionally not listed as an environment adapter because it is Atari preprocessing rather than a separate environment family.
- world_models.envs.make_atari_env(env_id, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, **kwargs)[source]#
Create any Atari environment from the Arcade Learning Environment (ALE).
- Parameters:
env_id (str) – The id of the Atari environment to create.
obs_type (str) – The type of observation to return (“rgb” or “ram”).
frameskip (int) – The number of frames to skip between actions.
repeat_action_probability (float) – The probability of repeating the last action.
full_action_space (bool) – Whether to use the full action space.
max_episode_steps (Optional[int]) – Maximum number of steps per episode.
**kwargs – Additional keyword arguments for environment configuration.
- Returns:
The created Atari environment.
- Return type:
gym.Env
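repeat_action_probability controls ALE's "sticky actions": on each emulator frame, with that probability the previously executed action is repeated instead of the agent's new choice, while frameskip repeats the agent's action for several frames and sums the rewards. The sketch below simulates those semantics on a toy reward function. It is an illustration under that reading of the parameters, not ALE's implementation; sticky_step and toy_frame_reward are hypothetical.

```python
import random
from typing import Callable, Tuple


def sticky_step(action: int, prev_action: int, frame_reward: Callable[[int], float],
                rng: random.Random, frameskip: int = 4,
                repeat_action_probability: float = 0.25) -> Tuple[float, int]:
    """Simulate one agent step of `frameskip` emulator frames with sticky actions.

    Returns the summed reward and the action executed on the last frame.
    """
    total = 0.0
    executed = prev_action
    for _ in range(frameskip):
        # With probability p the emulator keeps whatever action it executed last.
        if rng.random() >= repeat_action_probability:
            executed = action
        total += frame_reward(executed)
    return total, executed


def toy_frame_reward(a: int) -> float:
    """Hypothetical reward: action 1 earns 1.0 per frame, every other action earns 0.0."""
    return 1.0 if a == 1 else 0.0


reward, last = sticky_step(1, prev_action=0, frame_reward=toy_frame_reward, rng=random.Random(0))
print(reward, last)
```

With repeat_action_probability=0.0 the new action always takes effect, so the step above would earn 4.0; sticky actions make the outcome stochastic, which is why ALE recommends them for evaluation.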
- world_models.envs.list_available_atari_envs()[source]#
Get a list of all Atari environments available in the Arcade Learning Environment (ALE).
- Returns:
List of available Atari environment IDs.
- Return type:
list[str]
- world_models.envs.make_atari_vector_env(game, num_envs, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, seed=None, **kwargs)[source]#
Create vectorized Atari environments using ALE’s native AtariVectorEnv.
- Parameters:
game (str) – The name of the Atari game (e.g., “pong”, “breakout”).
num_envs (int) – Number of parallel environments.
obs_type (str) – The type of observation to return (“rgb” or “ram”).
frameskip (int) – The number of frames to skip between actions.
repeat_action_probability (float) – The probability of repeating the last action.
full_action_space (bool) – Whether to use the full action space.
max_episode_steps (Optional[int]) – Maximum number of steps per episode.
seed (Optional[int]) – Random seed for reproducibility.
**kwargs – Additional keyword arguments for environment configuration.
- Returns:
The vectorized Atari environment.
- Return type:
AtariVectorEnv
- class world_models.envs.MuJoCoImageEnv(xml_path=None, *, xml_string=None, binary_path=None, assets=None, seed=0, size=(64, 64), camera=None, reward_fn=None, terminal_fn=None, frame_skip=1, reset_noise_scale=0.0, default_control_range=(-1.0, 1.0))[source]#
Bases:
object
Native MuJoCo environment adapter for pixel-based world-model training.
The adapter uses the low-level mujoco Python package directly: models are compiled from MJCF XML strings/files or MJB binaries via mujoco.MjModel; simulation state lives in mujoco.MjData; actions are written to data.ctrl; and images are produced with mujoco.Renderer. Observations follow TorchWM’s Dreamer-style contract: {"image": uint8[C, H, W]}.
Native MuJoCo models do not define task rewards or episode termination by themselves, so callers can supply reward_fn and terminal_fn callbacks. By default, rewards are 0.0 and episodes terminate only through external wrappers such as TimeLimit.
- Parameters:
xml_path (str | Path | None)
xml_string (str | None)
binary_path (str | Path | None)
assets (dict[str, bytes] | None)
seed (int)
size (tuple[int, int])
camera (str | int | None)
reward_fn (RewardFn | None)
terminal_fn (TerminalFn | None)
frame_skip (int)
reset_noise_scale (float)
default_control_range (tuple[float, float])
- property observation_space#
- property action_space#
- world_models.envs.make_mujoco_env(model=None, *, backend='auto', seed=0, size=(64, 64), render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#
Create a MuJoCo image environment from either a Gymnasium task id or an MJCF/MJB model.
- Parameters:
model (str | Path | None) – Either a Gymnasium MuJoCo task id such as "Humanoid-v4", an MJCF XML path/string, or an MJB binary path.
backend (str) – "auto" infers native vs Gymnasium task mode. Use "native" for MJCF/MJB, "gymnasium" for task ids, or "robotics" for Gymnasium Robotics registrations.
seed (int) – Seed forwarded to the image wrapper.
size (tuple[int, int]) – Target (height, width) image size.
render_mode (str) – Render mode used for Gymnasium MuJoCo task ids.
gym_kwargs (dict[str, Any] | None) – Optional keyword arguments forwarded to gymnasium.make in task-id mode. Extra **kwargs are also forwarded there.
**kwargs – Native MuJoCoImageEnv options for MJCF/MJB mode, or environment-constructor options for Gymnasium task-id mode.
- Returns:
A TorchWM image environment returning {"image": uint8[C, H, W]}.
- world_models.envs.make_mujoco_env_from_config(args, size)[source]#
Build a MuJoCo image environment from a DreamerConfig-like object.
- Parameters:
size (tuple[int, int])
- world_models.envs.list_gymnasium_robotics_envs()[source]#
List all Gymnasium Robotics ids registered by the installed package.
Returns an empty list when the optional dependency is not installed. When it is installed, the list is derived from Gymnasium’s registry rather than a hand-maintained subset, so newly added Robotics environments are exposed automatically.
- Return type:
list[str]
- world_models.envs.make_robotics_env(env, *, seed=0, size=(64, 64), render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#
Create a TorchWM image wrapper for a Gymnasium Robotics environment.
- Parameters:
env (str) – Any environment id registered by gymnasium-robotics.
seed (int) – Seed forwarded to GymImageEnv.
size (tuple[int, int]) – Target (height, width) image size.
render_mode (str) – Render mode forwarded to gymnasium.make.
gym_kwargs (dict[str, Any] | None) – Optional keyword arguments forwarded to gymnasium.make.
**kwargs – Additional keyword arguments forwarded to gymnasium.make.
- Returns:
A GymImageEnv that emits {"image": uint8[C, H, W]} observations.
- world_models.envs.register_gymnasium_robotics_envs()[source]#
Import Gymnasium Robotics so its environments are registered with Gymnasium.
Gymnasium moved the legacy MuJoCo v2/v3 task registrations into the external gymnasium-robotics package. Current Gymnasium Robotics versions register environments on import, while older plugin-style installations may rely on gymnasium.register_envs; this helper supports both paths.
- class world_models.envs.GymImageEnv(env, seed=0, size=(64, 64), render_mode='rgb_array')[source]#
Bases:
object
Gym-like environment wrapper that always returns image observations.
This wrapper normalizes diverse environment interfaces to return consistent image-based observations suitable for pixel-based world models like Dreamer.
- Features:
Supports environment IDs (string) and pre-built environment objects.
Synthesizes RGB images from vector observations for pixel-based training.
Exposes continuous action spaces mapped to the [-1, 1] range.
Converts discrete actions to one-hot vectors.
Returns observations as dict {“image”: (C, H, W)} with uint8 values.
- Parameters:
env – Either a string environment ID (e.g., “Pendulum-v1”) or a pre-built gym environment instance.
seed (int) – Random seed for environment reset (default: 0).
size (tuple) – Target image size as (height, width) (default: (64, 64)).
render_mode (str) – Render mode for environment (default: “rgb_array”).
- observation_space#
Dict space with “image” key containing (C, H, W) Box.
- action_space#
Box space with actions in the [-1, 1] range.
- max_episode_steps#
Maximum steps per episode (default: 1000).
- property observation_space#
- property action_space#
- property max_episode_steps#
- world_models.envs.make_gym_env(env, **kwargs)[source]#
Create a GymImageEnv wrapper for generic Gym/Gymnasium environments.
- Parameters:
env – Either a string environment ID (e.g., “Pendulum-v1”) or a pre-built gym environment instance.
**kwargs – Additional keyword arguments passed to GymImageEnv, including:
  - seed (int): Random seed for the environment (default: 0).
  - size (tuple): Target image size as (height, width) (default: (64, 64)).
  - render_mode (str): Render mode for the environment (default: “rgb_array”).
- Returns:
A wrapper that always returns image observations in the format {“image”: (C, H, W)}, suitable for pixel-based world models.
- Return type:
GymImageEnv
- class world_models.envs.UnityMLAgentsEnv(file_name, behavior_name=None, seed=0, size=(64, 64), worker_id=0, base_port=5005, no_graphics=True, time_scale=20.0, quality_level=1, max_episode_steps=1000)[source]#
Bases:
object
Gym-like wrapper for Unity ML-Agents environments.
Provides a unified interface for Unity-based environments, converting observations to image format compatible with pixel-based world models.
- Features:
Supports single-agent control with continuous action spaces.
Returns observations as {“image”: (C, H, W)} with uint8 values.
Normalizes actions to [-1, 1] range.
Includes rendered frames in observations for visual policies.
- Parameters:
file_name (str) – Path to the Unity environment binary.
behavior_name (str, optional) – Name of the behavior to use. If None, uses the first available behavior.
seed (int) – Random seed for environment (default: 0).
size (tuple) – Target image size as (height, width) (default: (64, 64)).
worker_id (int) – Worker ID for multi-environment setup (default: 0).
base_port (int) – Base port for Unity environment communication (default: 5005).
no_graphics (bool) – Disable graphics rendering for faster simulation (default: True).
time_scale (float) – Simulation time scale multiplier (default: 20.0).
quality_level (int) – Graphics quality level 0-5 (default: 1).
max_episode_steps (int) – Maximum steps per episode (default: 1000).
- observation_space#
Dict space with “image” key containing (3, H, W) Box.
- action_space#
Box space with actions in the [-1, 1] range.
- max_episode_steps#
Maximum steps per episode.
- Raises:
ValueError – If no behaviors found or action space is not continuous.
RuntimeError – If no agents available after reset.
- property observation_space#
- property action_space#
- property max_episode_steps#
- world_models.envs.make_unity_mlagents_env(env_id=None, **kwargs)[source]#
Create a Unity ML-Agents environment wrapper.
Factory function that instantiates a UnityMLAgentsEnv with the provided keyword arguments. Suitable for integrating Unity-based environments with Dreamer-style world model pipelines.
- Parameters:
**kwargs – Keyword arguments passed to UnityMLAgentsEnv, including:
  - file_name (str): Path to the Unity environment binary.
  - behavior_name (str, optional): Name of the behavior to use.
  - seed (int): Random seed (default: 0).
  - size (tuple): Image size as (height, width) (default: (64, 64)).
  - worker_id (int): Worker ID for multi-environment setup (default: 0).
  - base_port (int): Base port for communication (default: 5005).
  - no_graphics (bool): Disable graphics rendering (default: True).
  - time_scale (float): Simulation time scale (default: 20.0).
  - quality_level (int): Graphics quality level (default: 1).
  - max_episode_steps (int): Max steps per episode (default: 1000).
env_id (str | None)
- Returns:
A Gym-compatible wrapper for Unity environments.
- Return type:
UnityMLAgentsEnv
- class world_models.envs.DeepMindControlEnv(name, seed, size=(64, 64), camera=None)[source]#
Bases:
object
Gym-style adapter for DeepMind Control Suite tasks.
The wrapper exposes DMC observations and actions through Gym spaces and adds a rendered RGB image to each observation dict so image-based world model pipelines can train consistently across backends.
- Features:
Parses domain-task names (e.g., “cheetah-run” -> domain=“cheetah”, task=“run”)
Automatically handles special cases like “cup” -> “ball_in_cup”
Renders RGB images at configurable resolution
Returns observations as dict with both state vectors and images
- Parameters:
name (str) – Environment name in format “domain-task” (e.g., “cheetah-run”).
seed (int) – Random seed for environment initialization.
size (tuple) – Target image size as (height, width) (default: (64, 64)).
camera (int, optional) – Camera ID for rendering. Defaults to 0 for most domains, 2 for quadruped.
- observation_space#
Dict space with state keys and “image”.
- Type:
gym.spaces.Dict
- action_space#
Continuous action space from DMC spec.
- Type:
gym.spaces.Box
Example
>>> env = DeepMindControlEnv("cheetah-run", seed=0, size=(64, 64))
>>> obs = env.reset()
>>> print(obs.keys())  # dict_keys(['position', 'velocity', 'image'])
- property observation_space#
- property action_space#
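The name handling listed under Features and the camera default can be sketched as a small pure function. parse_dmc_name is hypothetical and covers only the cases documented above (split on the first hyphen, "cup" -> "ball_in_cup", camera 2 for quadruped and 0 otherwise); the real wrapper may handle more. The resulting pair is the kind of input dm_control's suite.load(domain_name, task_name) expects.

```python
from typing import Optional, Tuple


def parse_dmc_name(name: str, camera: Optional[int] = None) -> Tuple[str, str, int]:
    """Hypothetical helper: split "domain-task" into (domain, task, camera_id)."""
    domain, task = name.split("-", 1)   # split on the first hyphen only
    if domain == "cup":
        domain = "ball_in_cup"          # DMC registers this domain as "ball_in_cup"
    if camera is None:
        camera = 2 if domain == "quadruped" else 0
    return domain, task, camera


print(parse_dmc_name("cheetah-run"))     # ('cheetah', 'run', 0)
print(parse_dmc_name("cup-catch"))       # ('ball_in_cup', 'catch', 0)
print(parse_dmc_name("quadruped-walk"))  # ('quadruped', 'walk', 2)
```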
- class world_models.envs.BraxImageEnv(env, seed=0, size=(64, 64), backend=None, episode_length=None, auto_reset=False, jit=True, suppress_warp_warnings=True, **env_kwargs)[source]#
Bases:
object
Gym-like adapter for training TorchWM world models on Brax tasks.
Brax environments are functional JAX environments: reset consumes a PRNG key and returns a state, while step consumes the previous state plus an action and returns the next state. This adapter stores the Brax state between calls and converts state observations into image observations compatible with pixel-based TorchWM agents such as Dreamer.
If a Brax renderer is not available, vector observations are rendered as deterministic feature-band images so training code can still consume a pixel stream. The original vector observation is also exposed through info["vector_observation"] after step for diagnostics.
- Parameters:
seed (int)
size (tuple[int, int])
backend (str | None)
episode_length (int | None)
auto_reset (bool)
jit (bool)
suppress_warp_warnings (bool)
- property observation_space#
- property action_space#
- property max_episode_steps#
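The exact feature-band layout BraxImageEnv uses is not documented here, so the sketch below shows one plausible deterministic scheme under stated assumptions: each feature owns a contiguous band of rows, and its value is squashed with tanh into [0, 255] so unbounded observations still map to valid pixels. feature_band_image is hypothetical, and nested lists stand in for a uint8 array.

```python
import math
from typing import List, Sequence


def feature_band_image(obs: Sequence[float], height: int = 64, width: int = 64,
                       channels: int = 3) -> List[List[List[int]]]:
    """Hypothetical renderer: map a vector to a (C, H, W) grid of ints in [0, 255]."""
    n = max(len(obs), 1)
    image = []
    for _ in range(channels):
        plane = []
        for row in range(height):
            feature = min(row * n // height, n - 1)  # feature that owns this row band
            value = obs[feature] if obs else 0.0
            intensity = int(round((math.tanh(value) + 1.0) * 127.5))
            plane.append([intensity] * width)
        image.append(plane)
    return image


img = feature_band_image([0.0, 10.0], height=4, width=2)
print(img[0])  # [[128, 128], [128, 128], [255, 255], [255, 255]]
```

Because the mapping depends only on the observation, identical states always produce identical images, which keeps the pixel stream learnable.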
- world_models.envs.make_brax_env(env, **kwargs)[source]#
Create a TorchWM image wrapper for Brax environments.
- Parameters:
env – Brax environment name (for example, "ant") or a pre-built Brax environment object exposing reset(rng) and step(state, action).
**kwargs – Additional keyword arguments passed to BraxImageEnv.
- Returns:
A Gym-like wrapper that returns {"image": (C, H, W)} observations and exposes continuous actions in the Brax [-1, 1] range.
- Return type:
BraxImageEnv
- class world_models.envs.TimeLimit(env, duration)[source]#
Bases:
object
Terminate episodes after a fixed number of wrapper steps.
If the wrapped environment does not provide a discount flag at timeout, the wrapper injects a default discount of 1.0 for downstream learners.
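A minimal sketch of that behavior, assuming the classic four-tuple (obs, reward, done, info) step API used by Dreamer-style wrappers. SketchTimeLimit and CountingEnv are hypothetical stand-ins, not TorchWM classes.

```python
from typing import Any, Dict, Optional, Tuple


class CountingEnv:
    """Hypothetical toy environment that never terminates on its own."""

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: Any) -> Tuple[int, float, bool, Dict[str, Any]]:
        self.t += 1
        return self.t, 1.0, False, {}


class SketchTimeLimit:
    """Force done=True after `duration` steps and default the discount to 1.0 on timeout."""

    def __init__(self, env: Any, duration: int):
        self._env = env
        self._duration = duration
        self._step: Optional[int] = None

    def reset(self) -> Any:
        self._step = 0
        return self._env.reset()

    def step(self, action: Any) -> Tuple[Any, float, bool, Dict[str, Any]]:
        if self._step is None:
            raise RuntimeError("call reset() before step()")
        obs, reward, done, info = self._env.step(action)
        self._step += 1
        if self._step >= self._duration:
            done = True
            # A timeout is not a true terminal state, so learners should keep bootstrapping.
            info.setdefault("discount", 1.0)
            self._step = None
        return obs, reward, done, info


env = SketchTimeLimit(CountingEnv(), duration=3)
env.reset()
results = [env.step(0) for _ in range(3)]
print([r[2] for r in results], results[-1][3])  # [False, False, True] {'discount': 1.0}
```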
- class world_models.envs.ActionRepeat(env, amount)[source]#
Bases:
object
Repeat each action for a fixed number of environment steps.
Rewards are accumulated and the loop stops early if the environment terminates, mirroring common action-repeat behavior in world model papers.
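The accumulate-and-stop-early loop can be sketched as follows, again assuming a four-tuple step API. SketchActionRepeat and EndsAtThree are hypothetical names used only for illustration.

```python
from typing import Any, Dict, Tuple


class EndsAtThree:
    """Hypothetical toy environment: reward 1.0 per step, terminates on step 3."""

    def __init__(self) -> None:
        self.t = 0

    def step(self, action: Any) -> Tuple[int, float, bool, Dict[str, Any]]:
        self.t += 1
        return self.t, 1.0, self.t >= 3, {}


class SketchActionRepeat:
    """Repeat an action `amount` times, summing rewards and stopping early on done."""

    def __init__(self, env: Any, amount: int):
        self._env = env
        self._amount = amount

    def step(self, action: Any) -> Tuple[Any, float, bool, Dict[str, Any]]:
        done = False
        total_reward = 0.0
        current_step = 0
        obs: Any = None
        info: Dict[str, Any] = {}
        while current_step < self._amount and not done:
            obs, reward, done, info = self._env.step(action)
            total_reward += reward
            current_step += 1
        return obs, total_reward, done, info


env = SketchActionRepeat(EndsAtThree(), amount=4)
obs, reward, done, _ = env.step(0)
print(obs, reward, done)  # 3 3.0 True
```

Only three of the four requested repeats run, because the toy environment terminates on its third step.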
- class world_models.envs.NormalizeActions(env)[source]#
Bases:
object
Expose a normalized [-1, 1] action space for bounded continuous controls.
Incoming normalized actions are mapped back to the wrapped environment action bounds before stepping the environment.
- property action_space#
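The mapping back to the environment bounds is the affine map a_env = low + (a + 1) / 2 × (high − low). The sketch below applies it per dimension. Passing dimensions with infinite bounds through unchanged is an assumption borrowed from common Dreamer-style wrappers, and denormalize is a hypothetical helper.

```python
import math
from typing import List, Sequence


def denormalize(action: Sequence[float], low: Sequence[float],
                high: Sequence[float]) -> List[float]:
    """Hypothetical helper: map actions in [-1, 1] back to [low, high] per dimension."""
    out = []
    for a, lo, hi in zip(action, low, high):
        if math.isfinite(lo) and math.isfinite(hi):
            out.append(lo + (a + 1.0) / 2.0 * (hi - lo))
        else:
            out.append(a)  # assumption: unbounded dimensions are passed through untouched
    return out


print(denormalize([-1.0, 0.0, 1.0], low=[0.0, -2.0, 10.0], high=[1.0, 2.0, 20.0]))
# [0.0, 0.0, 20.0]
```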
- class world_models.envs.ObsDict(env, key='obs')[source]#
Bases:
object
Convert scalar/array observations into a dictionary observation format.
This harmonizes outputs for code paths that expect keyed observations (for example {“image”: …} style world model inputs).
- property observation_space#
- property action_space#
- class world_models.envs.OneHotAction(env)[source]#
Bases:
object
Wrap discrete-action environments to accept one-hot action vectors.
The wrapper validates one-hot inputs and converts them to integer action indices before forwarding to the underlying environment.
- property action_space#
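The validate-then-convert step can be sketched as a pure function: take the argmax, rebuild the expected one-hot vector, and reject anything that does not match. one_hot_to_index is a hypothetical helper and uses exact comparison; the actual wrapper may use a tolerance.

```python
from typing import Sequence


def one_hot_to_index(action: Sequence[float], num_actions: int) -> int:
    """Hypothetical helper: validate a one-hot vector and return the index of its 1."""
    if len(action) != num_actions:
        raise ValueError(f"expected length {num_actions}, got {len(action)}")
    index = max(range(num_actions), key=lambda i: action[i])
    reference = [0.0] * num_actions
    reference[index] = 1.0
    if [float(a) for a in action] != reference:
        raise ValueError(f"invalid one-hot action: {list(action)}")
    return index


print(one_hot_to_index([0, 0, 1, 0], 4))  # 2
```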
- class world_models.envs.RewardObs(env)[source]#
Bases:
object
Augment observations with the latest scalar reward under obs[“reward”].
Useful for agents that consume reward as part of the observation stream during model learning or recurrent policy inference.
- property observation_space#
- class world_models.envs.ResizeImage(env, size=(64, 64))[source]#
Bases:
object
Resize image-like observation entries to a target spatial size.
The wrapper discovers image keys from env.obs_space, applies nearest-neighbor resizing, and updates the advertised observation space shapes.
- property obs_space#
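Nearest-neighbor resizing picks, for each output pixel, the source pixel at index floor(r × src_h / new_h) (and likewise for columns). The sketch below shows that index mapping on a single-channel grid. nearest_resize is hypothetical, and the wrapper itself may delegate to an image library whose rounding convention differs slightly.

```python
from typing import List, Sequence, Tuple


def nearest_resize(image: Sequence[Sequence[int]], size: Tuple[int, int]) -> List[List[int]]:
    """Hypothetical helper: nearest-neighbor resize of an (H, W) grid to size=(new_h, new_w)."""
    src_h, src_w = len(image), len(image[0])
    new_h, new_w = size
    return [
        [image[r * src_h // new_h][c * src_w // new_w] for c in range(new_w)]
        for r in range(new_h)
    ]


print(nearest_resize([[1, 2], [3, 4]], (4, 4)))
# [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
```

Nearest-neighbor copies existing pixel values instead of blending them, so uint8 observations keep their exact values after resizing.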
- class world_models.envs.RenderImage(env, key='image')[source]#
Bases:
object
Inject RGB renders from env.render(“rgb_array”) into observations.
This is useful when the base environment returns non-image observations but a rendered camera view is needed for world-model training.
- property obs_space#
- class world_models.envs.SelectAction(env, key)[source]#
Bases:
Wrapper
Gym wrapper for dictionary actions that forwards only the selected key.
This enables integration with policies that emit action dicts while the environment expects a single tensor/array action payload.
- world_models.envs.make_env(env_id, **kwargs)[source]#
Compatibility helper: create an environment by delegating to package-specific factories when available, falling back to gym.make.
This preserves older callers that expect make_env to exist.
- Parameters:
env_id (str)
- class world_models.envs.dmc.DeepMindControlEnv(name, seed, size=(64, 64), camera=None)[source]#
Bases:
object
Gym-style adapter for DeepMind Control Suite tasks.
The wrapper exposes DMC observations and actions through Gym spaces and adds a rendered RGB image to each observation dict so image-based world model pipelines can train consistently across backends.
- Features:
Parses domain-task names (e.g., “cheetah-run” -> domain=“cheetah”, task=“run”)
Automatically handles special cases like “cup” -> “ball_in_cup”
Renders RGB images at configurable resolution
Returns observations as dict with both state vectors and images
- Parameters:
name (str) – Environment name in format “domain-task” (e.g., “cheetah-run”).
seed (int) – Random seed for environment initialization.
size (tuple) – Target image size as (height, width) (default: (64, 64)).
camera (int, optional) – Camera ID for rendering. Defaults to 0 for most domains, 2 for quadruped.
- observation_space#
Dict space with state keys and “image”.
- Type:
gym.spaces.Dict
- action_space#
Continuous action space from DMC spec.
- Type:
gym.spaces.Box
Example
>>> env = DeepMindControlEnv("cheetah-run", seed=0, size=(64, 64))
>>> obs = env.reset()
>>> print(obs.keys())  # dict_keys(['position', 'velocity', 'image'])
- property observation_space#
- property action_space#
- world_models.envs.gym_env.make_gym_env(env, **kwargs)[source]#
Create a GymImageEnv wrapper for generic Gym/Gymnasium environments.
- Parameters:
env – Either a string environment ID (e.g., “Pendulum-v1”) or a pre-built gym environment instance.
**kwargs – Additional keyword arguments passed to GymImageEnv, including:
  - seed (int): Random seed for the environment (default: 0).
  - size (tuple): Target image size as (height, width) (default: (64, 64)).
  - render_mode (str): Render mode for the environment (default: “rgb_array”).
- Returns:
A wrapper that always returns image observations in the format {“image”: (C, H, W)}, suitable for pixel-based world models.
- Return type:
GymImageEnv
- class world_models.envs.gym_env.GymImageEnv(env, seed=0, size=(64, 64), render_mode='rgb_array')[source]#
Bases:
object
Gym-like environment wrapper that always returns image observations.
This wrapper normalizes diverse environment interfaces to return consistent image-based observations suitable for pixel-based world models like Dreamer.
- Features:
Supports environment IDs (string) and pre-built environment objects.
Synthesizes RGB images from vector observations for pixel-based training.
Exposes continuous action spaces mapped to the [-1, 1] range.
Converts discrete actions to one-hot vectors.
Returns observations as dict {“image”: (C, H, W)} with uint8 values.
- Parameters:
env – Either a string environment ID (e.g., “Pendulum-v1”) or a pre-built gym environment instance.
seed (int) – Random seed for environment reset (default: 0).
size (tuple) – Target image size as (height, width) (default: (64, 64)).
render_mode (str) – Render mode for environment (default: “rgb_array”).
- observation_space#
Dict space with “image” key containing (C, H, W) Box.
- action_space#
Box space with actions in the [-1, 1] range.
- max_episode_steps#
Maximum steps per episode (default: 1000).
- property observation_space#
- property action_space#
- property max_episode_steps#
- world_models.envs.ale_atari_env.make_atari_env(env_id, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, **kwargs)[source]#
Create any Atari environment from the Arcade Learning Environment (ALE).
- Parameters:
env_id (str) – The id of the Atari environment to create.
obs_type (str) – The type of observation to return (“rgb” or “ram”).
frameskip (int) – The number of frames to skip between actions.
repeat_action_probability (float) – The probability of repeating the last action.
full_action_space (bool) – Whether to use the full action space.
max_episode_steps (Optional[int]) – Maximum number of steps per episode.
**kwargs – Additional keyword arguments for environment configuration.
- Returns:
The created Atari environment.
- Return type:
gym.Env
- world_models.envs.ale_atari_env.list_available_atari_envs()[source]#
Get a list of all Atari environments available in the Arcade Learning Environment (ALE).
- Returns:
List of available Atari environment IDs.
- Return type:
list[str]
- world_models.envs.ale_atari_vector_env.make_atari_vector_env(game, num_envs, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, seed=None, **kwargs)[source]#
Create vectorized Atari environments using ALE’s native AtariVectorEnv.
- Parameters:
game (str) – The name of the Atari game (e.g., “pong”, “breakout”).
num_envs (int) – Number of parallel environments.
obs_type (str) – The type of observation to return (“rgb” or “ram”).
frameskip (int) – The number of frames to skip between actions.
repeat_action_probability (float) – The probability of repeating the last action.
full_action_space (bool) – Whether to use the full action space.
max_episode_steps (Optional[int]) – Maximum number of steps per episode.
seed (Optional[int]) – Random seed for reproducibility.
**kwargs – Additional keyword arguments for environment configuration.
- Returns:
The vectorized Atari environment.
- Return type:
AtariVectorEnv
- world_models.envs.mujoco_env.make_mujoco_env_from_config(args, size)[source]#
Build a MuJoCo image environment from a DreamerConfig-like object.
- Parameters:
size (tuple[int, int])
- class world_models.envs.mujoco_env.MuJoCoImageEnv(xml_path=None, *, xml_string=None, binary_path=None, assets=None, seed=0, size=(64, 64), camera=None, reward_fn=None, terminal_fn=None, frame_skip=1, reset_noise_scale=0.0, default_control_range=(-1.0, 1.0))[source]#
Bases:
object
Native MuJoCo environment adapter for pixel-based world-model training.
The adapter uses the low-level mujoco Python package directly: models are compiled from MJCF XML strings/files or MJB binaries via mujoco.MjModel; simulation state lives in mujoco.MjData; actions are written to data.ctrl; and images are produced with mujoco.Renderer. Observations follow TorchWM’s Dreamer-style contract: {"image": uint8[C, H, W]}.
Native MuJoCo models do not define task rewards or episode termination by themselves, so callers can supply reward_fn and terminal_fn callbacks. By default, rewards are 0.0 and episodes terminate only through external wrappers such as TimeLimit.
- Parameters:
xml_path (str | Path | None)
xml_string (str | None)
binary_path (str | Path | None)
assets (dict[str, bytes] | None)
seed (int)
size (tuple[int, int])
camera (str | int | None)
reward_fn (RewardFn | None)
terminal_fn (TerminalFn | None)
frame_skip (int)
reset_noise_scale (float)
default_control_range (tuple[float, float])
- property observation_space#
- property action_space#
- world_models.envs.mujoco_env.make_mujoco_env(model=None, *, backend='auto', seed=0, size=(64, 64), render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#
Create a MuJoCo image environment from either a Gymnasium task id or an MJCF/MJB model.
- Parameters:
model (str | Path | None) – Either a Gymnasium MuJoCo task id such as "Humanoid-v4", an MJCF XML path/string, or an MJB binary path.
backend (str) – "auto" infers native vs Gymnasium task mode. Use "native" for MJCF/MJB, "gymnasium" for task ids, or "robotics" for Gymnasium Robotics registrations.
seed (int) – Seed forwarded to the image wrapper.
size (tuple[int, int]) – Target (height, width) image size.
render_mode (str) – Render mode used for Gymnasium MuJoCo task ids.
gym_kwargs (dict[str, Any] | None) – Optional keyword arguments forwarded to gymnasium.make in task-id mode. Extra **kwargs are also forwarded there.
**kwargs – Native MuJoCoImageEnv options for MJCF/MJB mode, or environment-constructor options for Gymnasium task-id mode.
- Returns:
A TorchWM image environment returning {"image": uint8[C, H, W]}.
- world_models.envs.robotics_env.is_moved_mujoco_error(exc)[source]#
Return whether Gymnasium reported the v2/v3 MuJoCo move.
- Parameters:
exc (BaseException)
- Return type:
bool
- world_models.envs.robotics_env.register_gymnasium_robotics_envs()[source]#
Import Gymnasium Robotics so its environments are registered with Gymnasium.
Gymnasium moved the legacy MuJoCo v2/v3 task registrations into the external gymnasium-robotics package. Current Gymnasium Robotics versions register environments on import, while older plugin-style installations may rely on gymnasium.register_envs; this helper supports both paths.
- world_models.envs.robotics_env.list_gymnasium_robotics_envs()[source]#
List all Gymnasium Robotics ids registered by the installed package.
Returns an empty list when the optional dependency is not installed. When it is installed, the list is derived from Gymnasium’s registry rather than a hand-maintained subset, so newly added Robotics environments are exposed automatically.
- Return type:
list[str]
- world_models.envs.robotics_env.make_gymnasium_env_with_robotics_fallback(env, *, render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#
Create a Gymnasium env and retry after Robotics registration if needed.
- Parameters:
env (str)
render_mode (str)
gym_kwargs (dict[str, Any] | None)
- world_models.envs.robotics_env.make_robotics_env(env, *, seed=0, size=(64, 64), render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#
Create a TorchWM image wrapper for a Gymnasium Robotics environment.
- Parameters:
env (str) – Any environment id registered by gymnasium-robotics.
seed (int) – Seed forwarded to GymImageEnv.
size (tuple[int, int]) – Target (height, width) image size.
render_mode (str) – Render mode forwarded to gymnasium.make.
gym_kwargs (dict[str, Any] | None) – Optional keyword arguments forwarded to gymnasium.make.
**kwargs – Additional keyword arguments forwarded to gymnasium.make.
- Returns:
A GymImageEnv that emits {"image": uint8[C, H, W]} observations.
- world_models.envs.unity_env.make_unity_mlagents_env(env_id=None, **kwargs)[source]#
Create a Unity ML-Agents environment wrapper.
Factory function that instantiates a UnityMLAgentsEnv with the provided keyword arguments. Suitable for integrating Unity-based environments with Dreamer-style world model pipelines.
- Parameters:
**kwargs – Keyword arguments passed to UnityMLAgentsEnv, including:
  - file_name (str): Path to the Unity environment binary.
  - behavior_name (str, optional): Name of the behavior to use.
  - seed (int): Random seed (default: 0).
  - size (tuple): Image size as (height, width) (default: (64, 64)).
  - worker_id (int): Worker ID for multi-environment setup (default: 0).
  - base_port (int): Base port for communication (default: 5005).
  - no_graphics (bool): Disable graphics rendering (default: True).
  - time_scale (float): Simulation time scale (default: 20.0).
  - quality_level (int): Graphics quality level (default: 1).
  - max_episode_steps (int): Max steps per episode (default: 1000).
env_id (str | None)
- Returns:
A Gym-compatible wrapper for Unity environments.
- Return type:
UnityMLAgentsEnv
- class world_models.envs.unity_env.UnityMLAgentsEnv(file_name, behavior_name=None, seed=0, size=(64, 64), worker_id=0, base_port=5005, no_graphics=True, time_scale=20.0, quality_level=1, max_episode_steps=1000)[source]#
Bases:
object
Gym-like wrapper for Unity ML-Agents environments.
Provides a unified interface for Unity-based environments, converting observations to image format compatible with pixel-based world models.
- Features:
Supports single-agent control with continuous action spaces.
Returns observations as {“image”: (C, H, W)} with uint8 values.
Normalizes actions to the [-1, 1] range.
Includes rendered frames in observations for visual policies.
- Parameters:
file_name (str) – Path to the Unity environment binary.
behavior_name (str, optional) – Name of the behavior to use. If None, uses the first available behavior.
seed (int) – Random seed for environment (default: 0).
size (tuple) – Target image size as (height, width) (default: (64, 64)).
worker_id (int) – Worker ID for multi-environment setup (default: 0).
base_port (int) – Base port for Unity environment communication (default: 5005).
no_graphics (bool) – Disable graphics rendering for faster simulation (default: True).
time_scale (float) – Simulation time scale multiplier (default: 20.0).
quality_level (int) – Graphics quality level 0-5 (default: 1).
max_episode_steps (int) – Maximum steps per episode (default: 1000).
- observation_space#
Dict space with “image” key containing (3, H, W) Box.
- action_space#
Box space with actions in the [-1, 1] range.
- max_episode_steps#
Maximum steps per episode.
- Raises:
ValueError – If no behaviors found or action space is not continuous.
RuntimeError – If no agents available after reset.
- property observation_space#
- property action_space#
- property max_episode_steps#
- class world_models.envs.vector_env.SimWorker(worker_id, env_factory, num_envs, command_queue, result_queue, seed=None)[source]#
Bases:
Process
Worker process that manages a batch of environment instances. Handles batched stepping for parallel rollouts.
- Parameters:
worker_id (int)
env_factory (Callable)
num_envs (int)
command_queue (Queue)
result_queue (Queue)
seed (Optional[int])
- class world_models.envs.vector_env.VectorizedEnv(env_factory, num_workers=2, envs_per_worker=4, seed=None)[source]#
Bases:
ABC
Abstract base class for vectorized environments. Manages multiple worker processes for parallel simulation.
- Parameters:
env_factory (Callable)
num_workers (int)
envs_per_worker (int)
seed (Optional[int])
- class world_models.envs.vector_env.TorchVectorizedEnv(*args, **kwargs)[source]#
Bases:
VectorizedEnv
TorchWM-compatible vectorized environment. Returns batched tensors suitable for PyTorch training.
- class world_models.envs.wrappers.TimeLimit(env, duration)[source]#
Bases:
object
Terminate episodes after a fixed number of wrapper steps.
If the wrapped environment does not provide a discount flag at timeout, the wrapper injects a default discount of 1.0 for downstream learners.
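A minimal sketch of the timeout logic described above, assuming an old-style Gym step() that returns (obs, reward, done, info) and assuming the default discount is written to info["discount"]. CountingEnv and TimeLimitSketch are hypothetical names for illustration, not TorchWM classes.

```python
class CountingEnv:
    """Hypothetical toy environment that never terminates on its own."""

    def reset(self):
        self.t = 0
        return {"t": self.t}

    def step(self, action):
        self.t += 1
        return {"t": self.t}, 1.0, False, {}


class TimeLimitSketch:
    """Force done=True after `duration` steps (illustrative sketch)."""

    def __init__(self, env, duration):
        self._env = env
        self._duration = duration
        self._step = None

    def reset(self):
        self._step = 0
        return self._env.reset()

    def step(self, action):
        if self._step is None:
            raise RuntimeError("Call reset() before step().")
        obs, reward, done, info = self._env.step(action)
        self._step += 1
        if self._step >= self._duration:
            done = True
            # Assumed convention: a timeout is not a true terminal state,
            # so advertise discount 1.0 unless the env already set one.
            info.setdefault("discount", 1.0)
            self._step = None
        return obs, reward, done, info


env = TimeLimitSketch(CountingEnv(), duration=3)
env.reset()
results = [env.step(None) for _ in range(3)]
dones = [done for _, _, done, _ in results]
print(dones)           # [False, False, True]
print(results[-1][3])  # {'discount': 1.0}
```

Injecting a discount of 1.0 lets a learner treat the cut-off as truncation rather than termination, so value targets keep bootstrapping past the time limit.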
- class world_models.envs.wrappers.ActionRepeat(env, amount)[source]#
Bases:
object
Repeat each action for a fixed number of environment steps.
Rewards are accumulated and the loop stops early if the environment terminates, mirroring common action-repeat behavior in world model papers.
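The action-repeat loop can be sketched as follows, again assuming a (obs, reward, done, info) step API. RewardEnv and ActionRepeatSketch are hypothetical stand-ins, not TorchWM classes.

```python
class RewardEnv:
    """Hypothetical toy env: reward 1.0 per step, terminates after `horizon` steps."""

    def __init__(self, horizon):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        done = self.t >= self.horizon
        return self.t, 1.0, done, {}


class ActionRepeatSketch:
    """Repeat an action `amount` times, summing rewards and stopping early on done."""

    def __init__(self, env, amount):
        self._env = env
        self._amount = amount

    def reset(self):
        return self._env.reset()

    def step(self, action):
        done = False
        total_reward = 0.0
        repeats = 0
        obs, info = None, {}
        while repeats < self._amount and not done:
            obs, reward, done, info = self._env.step(action)
            total_reward += reward
            repeats += 1
        return obs, total_reward, done, info


env = ActionRepeatSketch(RewardEnv(horizon=5), amount=4)
env.reset()
first = env.step(0)   # four inner steps
second = env.step(0)  # stops after one inner step because the episode ends
print(first[1], second[1], second[2])  # 4.0 1.0 True
```

Checking done before every inner step ensures the repeated action is never applied to an episode that has already terminated.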
- class world_models.envs.wrappers.NormalizeActions(env)[source]#
Bases:
object
Expose a normalized [-1, 1] action space for bounded continuous controls.
Incoming normalized actions are mapped back to the wrapped environment action bounds before stepping the environment.
- property action_space#
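The mapping from the normalized range back to the wrapped bounds is the affine map low + (a + 1) / 2 * (high - low). A NumPy sketch, assuming per-dimension bounds and assuming dimensions with infinite bounds are passed through unchanged (denormalize_action is a hypothetical helper, not a TorchWM function):

```python
import numpy as np


def denormalize_action(action, low, high):
    """Map an action in [-1, 1] back to [low, high] per dimension."""
    action = np.asarray(action, dtype=np.float64)
    low = np.asarray(low, dtype=np.float64)
    high = np.asarray(high, dtype=np.float64)
    bounded = np.isfinite(low) & np.isfinite(high)
    # Substitute finite placeholders so unbounded dims never produce inf - inf = nan.
    safe_low = np.where(bounded, low, -1.0)
    safe_high = np.where(bounded, high, 1.0)
    scaled = safe_low + (action + 1.0) / 2.0 * (safe_high - safe_low)
    # Assumption of this sketch: unbounded dimensions are passed through as-is.
    return np.where(bounded, scaled, action)


low = np.array([0.0, -2.0, -np.inf])
high = np.array([10.0, 2.0, np.inf])
print(denormalize_action([-1.0, 0.5, 0.3], low, high))
```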
- class world_models.envs.wrappers.ObsDict(env, key='obs')[source]#
Bases:
object
Convert scalar/array observations into a dictionary observation format.
This harmonizes outputs for code paths that expect keyed observations (for example {"image": …} style world model inputs).
- property observation_space#
- property action_space#
- class world_models.envs.wrappers.OneHotAction(env)[source]#
Bases:
object
Wrap discrete-action environments to accept one-hot action vectors.
The wrapper validates one-hot inputs and converts them to integer action indices before forwarding to the underlying environment.
- property action_space#
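A sketch of the validation and conversion step, assuming "valid" means exactly one entry equal to 1 and all others 0. one_hot_to_index is a hypothetical helper; the wrapper's exact checks and error messages may differ.

```python
import numpy as np


def one_hot_to_index(action, num_actions):
    """Check that `action` is a valid one-hot vector and return its index."""
    action = np.asarray(action, dtype=np.float32)
    if action.shape != (num_actions,):
        raise ValueError(f"expected shape ({num_actions},), got {action.shape}")
    index = int(np.argmax(action))
    # Rebuild the one-hot vector for that index and require a match.
    reference = np.zeros(num_actions, dtype=np.float32)
    reference[index] = 1.0
    if not np.allclose(action, reference):
        raise ValueError(f"not a one-hot action: {action.tolist()}")
    return index


print(one_hot_to_index([0, 0, 1, 0], num_actions=4))  # 2
```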
- class world_models.envs.wrappers.RewardObs(env)[source]#
Bases:
object
Augment observations with the latest scalar reward under obs["reward"].
Useful for agents that consume reward as part of the observation stream during model learning or recurrent policy inference.
- property observation_space#
- class world_models.envs.wrappers.ResizeImage(env, size=(64, 64))[source]#
Bases:
object
Resize image-like observation entries to a target spatial size.
The wrapper discovers image keys from env.obs_space, applies nearest-neighbor resizing, and updates the advertised observation space shapes.
- property obs_space#
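A NumPy sketch of nearest-neighbor resizing by index selection. resize_nearest is a hypothetical helper, and the integer index mapping (out_index * in_size // out_size) is one common convention; TorchWM may round differently or delegate to an image library.

```python
import numpy as np


def resize_nearest(image, size):
    """Nearest-neighbor resize of an (H, W) or (H, W, C) array to size=(height, width)."""
    in_h, in_w = image.shape[:2]
    out_h, out_w = size
    # For each output row/column, select the source index it maps onto.
    rows = np.arange(out_h) * in_h // out_h
    cols = np.arange(out_w) * in_w // out_w
    # Broadcasting the index arrays picks a full (out_h, out_w) grid; channels ride along.
    return image[rows[:, None], cols[None, :]]


frame = np.arange(16, dtype=np.uint8).reshape(4, 4)
small = resize_nearest(frame, (2, 2))
print(small.tolist())  # [[0, 2], [8, 10]]
```

Nearest-neighbor keeps pixel values exact (no interpolation), which preserves the uint8 palette of game frames.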
- class world_models.envs.wrappers.RenderImage(env, key='image')[source]#
Bases:
object
Inject RGB renders from env.render("rgb_array") into observations.
This is useful when the base environment returns non-image observations but a rendered camera view is needed for world-model training.
- property obs_space#
- class world_models.envs.wrappers.UUID(env)[source]#
Bases:
Wrapper
Gym wrapper that tracks a unique run identifier per environment reset.
The ID combines timestamp and UUID and can be used to tag episodes or artifacts generated during data collection.
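The exact ID format is not documented here. This sketch assumes a UTC timestamp joined to a uuid4 hex string and refreshed on every reset; make_run_id and RunIdTracker are hypothetical names.

```python
import datetime
import uuid


def make_run_id():
    """Build an identifier from a UTC timestamp plus a random UUID hex string."""
    timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%S")
    return f"{timestamp}-{uuid.uuid4().hex}"


class RunIdTracker:
    """Hypothetical helper that refreshes the identifier on every reset."""

    def __init__(self):
        self.id = None

    def reset(self):
        self.id = make_run_id()
        return self.id


tracker = RunIdTracker()
first, second = tracker.reset(), tracker.reset()
print(first != second)  # True
```

The timestamp makes IDs sortable by time, while the UUID keeps two resets within the same second distinct.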
Atari preprocessing helpers#
These helpers wrap Atari environments for specific training recipes. They are not separate environment families.
- class world_models.envs.diamond_atari.DiamondAtariWrapper(env, frameskip=4, max_noop=30, terminate_on_life_loss=True, reward_clip=True, resize=(64, 64))[source]#
Bases:
Wrapper
Atari wrapper for DIAMOND following the paper's specifications:
  - frameskip: number of frames to skip (default 4)
  - max_noop: maximum number of no-op actions at reset (default 30)
  - terminate_on_life_loss: terminate the episode when a life is lost (default True)
  - reward_clip: clip rewards to [-1, 0, 1] (default True)
  - resize: resize observations to the specified size (default 64x64)
- Parameters:
env (Env)
frameskip (int)
max_noop (int)
terminate_on_life_loss (bool)
reward_clip (bool)
resize (Tuple[int, int] | None)
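A sketch of two of the listed preprocessing steps, assuming reward_clip maps each reward to its sign and that no-op resets take a uniformly random number of no-ops between 1 and max_noop, as in common Atari pipelines. NoopCounterEnv, clip_reward, and noop_reset are hypothetical names; using action 0 as the no-op is also an assumption (it is NOOP in the standard ALE action set).

```python
import numpy as np


def clip_reward(reward):
    """Map a raw reward to its sign: one of -1.0, 0.0, 1.0."""
    return float(np.sign(reward))


class NoopCounterEnv:
    """Hypothetical stand-in for an Atari env that counts no-op steps."""

    def __init__(self):
        self.steps = 0

    def reset(self):
        self.steps = 0
        return self.steps

    def step(self, action):
        self.steps += 1
        return self.steps, 0.0, False, {}


def noop_reset(env, max_noop, rng, noop_action=0):
    """Reset, then apply 1..max_noop no-op actions to randomize the start state."""
    obs = env.reset()
    for _ in range(int(rng.integers(1, max_noop + 1))):
        obs, _, done, _ = env.step(noop_action)
        if done:
            obs = env.reset()
    return obs


print([clip_reward(r) for r in (-7.0, 0.0, 3.5)])  # [-1.0, 0.0, 1.0]
env = NoopCounterEnv()
noop_reset(env, max_noop=30, rng=np.random.default_rng(0))
print(1 <= env.steps <= 30)  # True
```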
- world_models.envs.diamond_atari.make_diamond_atari_env(game, frameskip=4, max_noop=30, terminate_on_life_loss=True, reward_clip=True, resize=(64, 64), seed=None)[source]#
Create a DIAMOND-compatible Atari environment.
- Parameters:
game (str) – Atari game name (e.g., "Breakout-v5")
frameskip (int) – Number of frames to skip between actions
max_noop (int) – Maximum number of noop actions at reset
terminate_on_life_loss (bool) – Whether to terminate on life loss
reward_clip (bool) – Whether to clip rewards to [-1, 0, 1]
resize (Tuple[int, int]) – Target size for observations
seed (int | None) – Random seed
- Returns:
Configured Atari environment
- Return type:
Datasets and transforms#
- class world_models.datasets.video_datasets.DatasetConfig(num_frames=16, image_size=64, batch_size=4, num_workers=4, pin_memory=True, shuffle=True)[source]#
Bases:
object
Base configuration for datasets.
- Parameters:
num_frames (int)
image_size (int)
batch_size (int)
num_workers (int)
pin_memory (bool)
shuffle (bool)
- num_frames: int = 16#
- image_size: int = 64#
- batch_size: int = 4#
- num_workers: int = 4#
- pin_memory: bool = True#
- shuffle: bool = True#
- class world_models.datasets.video_datasets.VideoDatasetBase(data_source, num_frames=16, image_size=64, transform=None, normalize=True)[source]#
Bases:
Dataset
Base class for video datasets.
All video datasets should inherit from this class and implement the _load_video method.
- Parameters:
data_source (str | Path | List[str] | List[Path])
num_frames (int)
image_size (int)
transform (Callable | None)
normalize (bool)
- data_source: str | Path | List[str] | List[Path]#
- video_paths: Sequence[Path | int]#
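Subclasses differ mainly in how _load_video reads frames. The per-item clip sampling such a dataset typically performs can be sketched in NumPy, assuming frames arrive as a (T, H, W, C) uint8 array, that normalize=True scales pixels to [0, 1], and that short videos are padded by repeating the last frame. sample_clip is a hypothetical helper; TorchWM's actual sampling, normalization range, and output layout may differ.

```python
import numpy as np


def sample_clip(video, num_frames, rng, normalize=True):
    """Pick a random contiguous clip from a (T, H, W, C) uint8 video.

    Returns float32 of shape (num_frames, C, H, W).
    """
    total = video.shape[0]
    if total >= num_frames:
        start = int(rng.integers(0, total - num_frames + 1))
        clip = video[start:start + num_frames]
    else:
        # Assumption: pad short videos by repeating the final frame.
        pad = np.repeat(video[-1:], num_frames - total, axis=0)
        clip = np.concatenate([video, pad], axis=0)
    clip = clip.transpose(0, 3, 1, 2).astype(np.float32)  # (T, C, H, W)
    if normalize:
        clip /= 255.0
    return clip


video = np.random.default_rng(0).integers(0, 256, size=(40, 64, 64, 3)).astype(np.uint8)
clip = sample_clip(video, num_frames=16, rng=np.random.default_rng(1))
print(clip.shape)  # (16, 3, 64, 64)
```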
- class world_models.datasets.video_datasets.VideoFolderDataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, extensions=('.mp4', '.avi', '.mkv', '.webm', '.mov'), recursive=True)[source]#
Bases:
VideoDatasetBase
Dataset that loads videos from a folder.
Supports common video formats by default: .mp4, .avi, .mkv, .webm, and .mov.
- Usage:
dataset = VideoFolderDataset(
    data_source="/path/to/videos", num_frames=16, image_size=64
)
- Parameters:
data_source (str | Path | List[str] | List[Path])
num_frames (int)
image_size (int)
transform (Callable | None)
normalize (bool)
extensions (Tuple[str, ...])
recursive (bool)
- class world_models.datasets.video_datasets.ImageFolderDataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, extensions=('.jpg', '.jpeg', '.png', '.bmp'), image_sort_key=None)[source]#
Bases:
VideoDatasetBase
Dataset that loads image sequences from folders.
Each subfolder is treated as one video sequence.
- Usage:
dataset = ImageFolderDataset(
    data_source="/path/to/images", num_frames=16, image_size=64
)
- Parameters:
data_source (str | Path | List[str] | List[Path])
num_frames (int)
image_size (int)
transform (Callable | None)
normalize (bool)
extensions (Tuple[str, ...])
image_sort_key (Callable | None)
- class world_models.datasets.video_datasets.NumPyDataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, key=None)[source]#
Bases:
VideoDatasetBase
Dataset that loads videos from NumPy files.
Supports .npy and .npz files.
- Usage:
dataset = NumPyDataset(
    data_source="/path/to/videos.npy", num_frames=16, image_size=64
)
- Parameters:
data_source (str | Path)
num_frames (int)
image_size (int)
transform (Callable | None)
normalize (bool)
key (str | None)
- class world_models.datasets.video_datasets.RLEnvironmentDataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, obs_key='observations')[source]#
Bases:
VideoDatasetBase
Dataset for RL environment recordings.
Loads trajectories stored as:
  - .npz files with 'observations' and 'actions' keys
  - a directory of episode folders
- Usage:
dataset = RLEnvironmentDataset(
    data_source="/path/to/rl_episodes", num_frames=16, image_size=64
)
- Parameters:
data_source (str | Path)
num_frames (int)
image_size (int)
transform (Callable | None)
normalize (bool)
obs_key (str)
- class world_models.datasets.video_datasets.HDF5Dataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, key='videos', memmap=False)[source]#
Bases:
VideoDatasetBase
Dataset that loads videos from HDF5 files.
Supports pre-processed video datasets stored in HDF5 format. Expected structure: an HDF5 file with a 'videos' dataset of shape (N, T, H, W, C) or (N, T, C, H, W).
- Usage:
dataset = HDF5Dataset(
    data_source="/path/to/videos.h5", num_frames=16, image_size=64
)
- Parameters:
data_source (str | Path)
num_frames (int)
image_size (int)
transform (Callable | None)
normalize (bool)
key (str)
memmap (bool)
- world_models.datasets.video_datasets.create_video_dataloader(dataset_type, data_source, num_frames=16, image_size=64, batch_size=4, num_workers=4, shuffle=True, pin_memory=True, **kwargs)[source]#
Factory function to create video dataloaders.
- Parameters:
dataset_type (str) – Type of dataset ("video_folder", "image_folder", "numpy", "rl")
data_source (str | Path | List[str]) – Path or list of paths to data
num_frames (int) – Number of frames per video
image_size (int) – Target image size (height and width)
batch_size (int) – Batch size for dataloader
num_workers (int) – Number of workers for data loading
shuffle (bool) – Whether to shuffle data
pin_memory (bool) – Whether to pin memory for faster GPU transfer
**kwargs – Additional arguments for specific dataset types
- Returns:
Tuple of (dataset, dataloader)
- Return type:
Tuple[Dataset, DataLoader]
- Usage:
dataset, loader = create_video_dataloader(
    dataset_type="video_folder",
    data_source="/path/to/videos",
    num_frames=16,
    image_size=64,
    batch_size=4,
)
- class world_models.datasets.video_datasets.VideoDatasetConfig(num_frames=16, image_size=64, batch_size=4, num_workers=4, pin_memory=True, shuffle=True, dataset_type='video_folder', data_source='', extensions=('.mp4', '.avi', '.mkv'), recursive=True, obs_key='observations')[source]#
Bases:
DatasetConfig
Configuration for video datasets.
- Parameters:
num_frames (int)
image_size (int)
batch_size (int)
num_workers (int)
pin_memory (bool)
shuffle (bool)
dataset_type (str)
data_source (str)
extensions (Tuple[str, ...])
recursive (bool)
obs_key (str)
- dataset_type: str = 'video_folder'#
- data_source: str = ''#
- extensions: Tuple[str, ...] = ('.mp4', '.avi', '.mkv')#
- recursive: bool = True#
- obs_key: str = 'observations'#
- world_models.datasets.video_datasets.create_video_dataset_from_config(config)[source]#
Create video dataset and dataloader from config.
- Parameters:
config (VideoDatasetConfig)
- Return type:
Tuple[Dataset, DataLoader]
TinyWorlds Dataset Loaders
Loads pre-processed video datasets from HuggingFace for training Genie-style world models. Based on: AlmondGod/tinyworlds
Available datasets:
  - PICO_DOOM: Minimal Doom gameplay
  - PONG: Classic Pong
  - ZELDA: Zelda Ocarina of Time (2D)
  - POLE_POSITION: Racing game
  - SONIC: Sonic the Hedgehog
- class world_models.datasets.tinyworlds.TinyWorldsConfig(dataset_name='SONIC', num_frames=16, image_size=64, batch_size=4, num_workers=4, cache_dir=None, split='train')[source]#
Bases:
object
Configuration for TinyWorlds datasets.
- Parameters:
dataset_name (str)
num_frames (int)
image_size (int)
batch_size (int)
num_workers (int)
cache_dir (str | None)
split (str)
- dataset_name: str = 'SONIC'#
- num_frames: int = 16#
- image_size: int = 64#
- batch_size: int = 4#
- num_workers: int = 4#
- cache_dir: str | None = None#
- split: str = 'train'#
- class world_models.datasets.tinyworlds.TinyWorldsDataset(dataset_name='SONIC', num_frames=16, image_size=64, split='train', cache_dir=None, download=True, data_file=None)[source]#
Bases:
Dataset
Dataset for TinyWorlds game video data.
Loads pre-processed frames from a Hugging Face datasets repository.
- Parameters:
dataset_name (str)
num_frames (int)
image_size (int)
split (str)
cache_dir (str | None)
download (bool)
data_file (str | None)
- DATASET_CONFIGS = {'PICO_DOOM': {'description': 'Minimal Doom gameplay', 'filename': 'picodoom_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}, 'POLE_POSITION': {'description': 'Racing game', 'filename': 'pole_position_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}, 'PONG': {'description': 'Classic Pong', 'filename': 'pong_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}, 'SONIC': {'description': 'Sonic the Hedgehog', 'filename': 'sonic_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}, 'ZELDA': {'description': 'Zelda Ocarina of Time (2D)', 'filename': 'zelda_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}}#
- class world_models.datasets.tinyworlds.TinyWorldsDataLoader[source]#
Bases:
object
Factory class for creating TinyWorlds dataloaders.
- DATASET_NAMES = ['PICO_DOOM', 'PONG', 'ZELDA', 'POLE_POSITION', 'SONIC']#
- static create_dataloader(dataset_name='SONIC', num_frames=16, image_size=64, batch_size=4, num_workers=4, shuffle=True, cache_dir=None, download=True, data_file=None)[source]#
Create a dataloader for TinyWorlds dataset.
- Parameters:
dataset_name (str) – Name of the game dataset (PICO_DOOM, PONG, ZELDA, POLE_POSITION, SONIC)
num_frames (int) – Number of frames per video sequence
image_size (int) – Target image size (will resize frames)
batch_size (int) – Batch size
num_workers (int) – Number of data loading workers
shuffle (bool) – Whether to shuffle the data
cache_dir (str | None) – Directory to cache downloaded datasets
download (bool) – Whether to download if not cached
data_file (str | None)
- Returns:
Tuple of (dataset, dataloader)
- Return type:
Tuple[TinyWorldsDataset, DataLoader]
- Usage:
dataset, loader = TinyWorldsDataLoader.create_dataloader(
    dataset_name="SONIC", num_frames=16, image_size=64, batch_size=4
)
- world_models.datasets.tinyworlds.create_tinyworlds_dataloader(dataset_name='SONIC', num_frames=16, image_size=64, batch_size=4, num_workers=4, shuffle=True, cache_dir=None, download=True, data_file=None)[source]#
Factory function to create TinyWorlds dataloaders.
- Parameters:
dataset_name (str) – Name of the game dataset (PICO_DOOM, PONG, ZELDA, POLE_POSITION, SONIC)
num_frames (int) – Number of frames per video sequence
image_size (int) – Target image size
batch_size (int) – Batch size
num_workers (int) – Number of data loading workers
shuffle (bool) – Whether to shuffle
cache_dir (str | None) – Cache directory for datasets
download (bool) – Download if not cached
data_file (str | None)
- Returns:
Tuple of (dataset, dataloader)
- Return type:
Tuple[TinyWorldsDataset, DataLoader]
- Usage:
dataset, loader = create_tinyworlds_dataloader(
    dataset_name="SONIC", num_frames=16, batch_size=4
)
for batch in loader:
    # batch shape: (B, T, C, H, W)
    ...
- world_models.datasets.tinyworlds.download_all_datasets(cache_dir=None)[source]#
Download all available TinyWorlds datasets.
- Parameters:
cache_dir (str | None) – Directory to cache downloaded datasets
- Returns:
Dictionary mapping dataset names to local file paths
- Return type:
Dict[str, str | None]
- class world_models.datasets.diamond_dataset.ReplayBuffer(capacity=1000, obs_shape=(64, 64, 3), action_dim=1, device='cpu')[source]#
Bases:
object
Replay buffer for storing environment interactions. Stores (observation, action, reward, done, next_observation) tuples.
- Parameters:
capacity (int)
obs_shape (Tuple[int, int, int])
action_dim (int)
device (str)
- add(obs, action, reward, done, next_obs)[source]#
Add a transition to the buffer.
- Parameters:
obs (ndarray)
action (int)
reward (float)
done (bool)
next_obs (ndarray)
- sample(batch_size)[source]#
Sample a random batch of transitions.
- Parameters:
batch_size (int)
- Return type:
Dict[str, Tensor]
- sample_sequence(batch_size, sequence_length, burn_in=0)[source]#
Sample a sequence of transitions for training.
- Parameters:
batch_size (int) – Number of sequences to sample
sequence_length (int) – Total sequence length (burn_in + horizon)
burn_in (int) – Number of initial frames to use for conditioning
- Returns:
Dictionary with tensors of shape (batch_size, sequence_length, …)
- Return type:
Dict[str, Tensor]
- is_ready(min_size)[source]#
Check if buffer has enough samples.
- Parameters:
min_size (int)
- Return type:
bool
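A NumPy-only sketch of a ring buffer with contiguous sequence sampling, loosely following the add / sample_sequence / is_ready methods above. RingReplaySketch is hypothetical: it omits next_obs (the next observation is simply the following row), returns NumPy arrays instead of torch tensors, and does not stop windows from crossing episode boundaries, which a real implementation may handle using the done flags.

```python
import numpy as np


class RingReplaySketch:
    """FIFO replay buffer that overwrites the oldest entries when full."""

    def __init__(self, capacity, obs_shape):
        self.capacity = capacity
        self.obs = np.zeros((capacity, *obs_shape), dtype=np.uint8)
        self.actions = np.zeros(capacity, dtype=np.int64)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.dones = np.zeros(capacity, dtype=bool)
        self.index = 0  # next write position
        self.size = 0   # number of valid entries

    def add(self, obs, action, reward, done):
        self.obs[self.index] = obs
        self.actions[self.index] = action
        self.rewards[self.index] = reward
        self.dones[self.index] = done
        self.index = (self.index + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def is_ready(self, min_size):
        return self.size >= min_size

    def sample_sequence(self, batch_size, sequence_length, rng):
        if self.size < sequence_length:
            raise ValueError("not enough transitions for one sequence")
        # Windows start at or after the oldest entry and never wrap past the newest.
        oldest = (self.index - self.size) % self.capacity
        offsets = rng.integers(0, self.size - sequence_length + 1, size=batch_size)
        idx = (oldest + offsets[:, None] + np.arange(sequence_length)) % self.capacity
        return {
            "obs": self.obs[idx],          # (B, L, *obs_shape)
            "action": self.actions[idx],   # (B, L)
            "reward": self.rewards[idx],
            "done": self.dones[idx],
        }


buffer = RingReplaySketch(capacity=8, obs_shape=(4, 4, 3))
for t in range(10):  # the two oldest transitions get overwritten
    buffer.add(np.full((4, 4, 3), t, dtype=np.uint8), action=t % 2, reward=float(t), done=False)
batch = buffer.sample_sequence(batch_size=3, sequence_length=5, rng=np.random.default_rng(0))
print(batch["obs"].shape)  # (3, 5, 4, 4, 3)
```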
- class world_models.datasets.diamond_dataset.SequenceDataset(replay_buffer, sequence_length=5, burn_in=4)[source]#
Bases:
Dataset
PyTorch Dataset for sampling sequences from the replay buffer. Used for training the diffusion world model.
- Parameters:
replay_buffer (ReplayBuffer)
sequence_length (int)
burn_in (int)
- world_models.datasets.diamond_dataset.collate_fn(batch)[source]#
Collate function for the dataloader.
- Parameters:
batch (List[Dict[str, Tensor]])
- Return type:
Dict[str, Tensor]
- world_models.datasets.cifar10.make_cifar10(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, drop_last=True, train=True, download=False)[source]#
Create CIFAR-10 dataset and distributed dataloader.
Factory function that creates a CIFAR-10 dataset with the provided transforms and returns a tuple of (dataset, dataloader, sampler) for use in JEPA or diffusion training pipelines.
- Parameters:
transform – Transforms to apply to images (e.g., RandomCrop, ColorJitter).
batch_size (int) – Number of samples per batch.
collator (callable, optional) – Custom collate function for batching (e.g., mask collator for JEPA).
pin_mem (bool) – Whether to pin memory for faster GPU transfer (default: True).
num_workers (int) – Number of data loading workers (default: 8).
world_size (int) – Number of distributed processes (default: 1).
rank (int) – Rank of current process in distributed setting (default: 0).
root_path (str, optional) – Path to store/load CIFAR-10 data.
drop_last (bool) – Whether to drop incomplete final batch (default: True).
train (bool) – Whether to load train or test split (default: True).
download (bool) – Whether to download dataset if not present (default: False).
- Returns:
- (dataset, dataloader, sampler)
dataset: torchvision.datasets.CIFAR10 instance
dataloader: torch.utils.data.DataLoader with distributed sampling
sampler: torch.utils.data.distributed.DistributedSampler
- Return type:
tuple
Example
>>> transform = make_transforms(crop_size=224)
>>> dataset, loader, sampler = make_cifar10(
...     transform=transform,
...     batch_size=256,
...     root_path="./data",
...     download=True
... )
- world_models.datasets.imagenet1k.make_imagenet1k(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, training=True, copy_data=False, drop_last=True, subset_file=None)[source]#
Build an ImageNet-1K dataset and dataloader with distributed sampling.
Factory function that creates an ImageNet dataset and returns a tuple of (dataset, dataloader, sampler) for use in JEPA or other self-supervised training pipelines.
- Supports:
Optional data staging from network storage to local scratch
Subset filtering via text file listing allowed image IDs
Distributed sampling for multi-GPU training
- Parameters:
transform (Any) – Transforms to apply to images.
batch_size (int) – Number of samples per batch.
collator (callable, optional) – Custom collate function (e.g., mask collator).
pin_mem (bool) – Whether to pin memory for GPU transfer (default: True).
num_workers (int) – Number of data loading workers (default: 8).
world_size (int) – Number of distributed processes (default: 1).
rank (int) – Rank of current process (default: 0).
root_path (str, optional) – Root path containing ImageNet data.
image_folder (str, optional) – Subfolder containing ImageNet data.
training (bool) – Load train or validation split (default: True).
copy_data (bool) – Copy data locally for faster loading (default: False).
drop_last (bool) – Drop incomplete final batch (default: True).
subset_file (str, optional) – Path to file listing allowed image IDs.
- Returns:
- (dataset, dataloader, sampler)
dataset: ImageNet dataset instance
dataloader: DataLoader with distributed sampling
sampler: DistributedSampler instance
- Return type:
tuple
- class world_models.datasets.imagenet1k.ImageNet(root, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', transform=None, train=True, job_id=None, local_rank=None, copy_data=True, index_targets=False)[source]#
Bases:
ImageFolder
ImageNet dataset wrapper with optional local copy/extract workflow.
Extends torchvision.datasets.ImageFolder to support data staging from network storage to local scratch space for faster multi-process training on cluster environments (e.g., SLURM).
- Features:
Optional data copying from network storage to local /scratch
Extracts tar archives automatically on first access
Supports train/validation splits
Optional target indexing for balanced sampling
- class world_models.datasets.imagenet1k.ImageNetSubset(dataset, subset_file)[source]#
Bases:
object
View over an ImageNet dataset filtered by an explicit image-id list.
The subset file contains target image names; only matching samples are kept while preserving transforms and label mapping from the base dataset.
- property classes#
- world_models.datasets.imagenet1k.copy_imgnt_locally(root, suffix, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', job_id=None, local_rank=None)[source]#
Copy and extract ImageNet archives to per-job local scratch storage.
In SLURM environments this reduces network filesystem pressure by unpacking once per job and synchronizing worker processes with a signal file.
- world_models.datasets.imagenet1k.make_imagefolder(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, drop_last=True, val_split=None)[source]#
Create an ImageFolder dataset loader for custom folder-structured datasets.
Supports optional train/validation split and distributed sampling, making it a drop-in replacement for ImageNet loaders in training scripts.
- Parameters:
transform (Any)
batch_size (int)
collator (Any)
pin_mem (bool)
num_workers (int)
world_size (int)
rank (int)
root_path (str | None)
image_folder (str | None)
drop_last (bool)
val_split (float | None)
- Return type:
Tuple[Dataset, DataLoader, DistributedSampler]
- world_models.transforms.transforms.make_transforms(crop_size=224, crop_scale=(0.3, 1.0), color_jitter=1.0, horizontal_flip=False, color_distortion=False, gaussian_blur=False, normalization=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))[source]#
Compose image augmentations and normalization for vision model training.
Supports random crops, optional flip/color distortion/blur, and returns a torchvision.transforms.Compose pipeline.
Masking and JEPA helpers#
Masks sub-module - Masking strategies for JEPA and masked training.
This package provides various masking collator classes for generating encoder/predictor masks during masked representation learning.
- Usage:
from world_models.masks import MaskCollator, DefaultCollator
collator = MaskCollator(input_size=(64, 64), patch_size=8)
- world_models.masks.MultiblockMaskCollator#
alias of MaskCollator
- world_models.masks.RandomMaskCollator#
alias of MaskCollator
- class world_models.masks.DefaultCollator[source]#
Bases:
object
Simple collator that returns batch data and no masking metadata.
This is used when training code expects the JEPA-style collator return shape (batch, masks_enc, masks_pred) but masking is disabled.
- class world_models.masks.default.DefaultCollator[source]#
Bases:
object
Simple collator that returns batch data and no masking metadata.
This is used when training code expects the JEPA-style collator return shape (batch, masks_enc, masks_pred) but masking is disabled.
- class world_models.masks.multiblock.MaskCollator(input_size=(224, 224), patch_size=16, enc_mask_scale=(0.2, 0.8), pred_mask_scale=(0.5, 1.0), aspect_ratio=(0.3, 3.0), nenc=1, npred=2, min_keep=4, allow_overlap=False)[source]#
Bases:
object
Generate multi-block encoder and predictor masks for JEPA training.
For each sample, this collator samples predictor target blocks and context encoder blocks (optionally non-overlapping), then returns masked patch indices aligned across the batch.
- class world_models.masks.random.MaskCollator(ratio=(0.4, 0.6), input_size=(224, 224), patch_size=16)[source]#
Bases:
object
Generate random context/prediction patch splits for masked training.
A random permutation of patch indices is sampled per image; a configurable fraction is assigned to context and the remainder to prediction targets.
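A sketch of the random split, assuming ratio=(low, high) bounds the per-image fraction of patches kept as context (following the wording above; some implementations instead treat it as the masked fraction) and that batch alignment truncates every row to the shortest mask. Both functions are hypothetical NumPy illustrations rather than the TorchWM collator, which works with torch tensors.

```python
import numpy as np


def random_patch_split(input_size, patch_size, ratio, rng):
    """Randomly split one image's patch indices into (context, prediction) sets."""
    num_patches = (input_size[0] // patch_size) * (input_size[1] // patch_size)
    # Assumption: `ratio` bounds the fraction of patches kept as context.
    context_fraction = rng.uniform(ratio[0], ratio[1])
    num_context = int(num_patches * context_fraction)
    perm = rng.permutation(num_patches)
    return perm[:num_context], perm[num_context:]


def collate_random_masks(batch_size, input_size=(224, 224), patch_size=16,
                         ratio=(0.4, 0.6), seed=0):
    """Build batch-aligned masks by truncating every row to the shortest one."""
    rng = np.random.default_rng(seed)
    splits = [random_patch_split(input_size, patch_size, ratio, rng)
              for _ in range(batch_size)]
    min_ctx = min(len(ctx) for ctx, _ in splits)
    min_pred = min(len(pred) for _, pred in splits)
    masks_enc = np.stack([ctx[:min_ctx] for ctx, _ in splits])
    masks_pred = np.stack([pred[:min_pred] for _, pred in splits])
    return masks_enc, masks_pred


masks_enc, masks_pred = collate_random_masks(batch_size=4, input_size=(64, 64), patch_size=8)
# Both arrays have 4 rows; the column counts depend on the sampled fractions.
print(masks_enc.shape, masks_pred.shape)
```

Truncating an unsorted permutation drops a random subset of patches, so the aligned masks stay uniformly random rather than biased toward low patch indices.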
- world_models.helpers.jepa_helper.load_checkpoint(device, r_path, encoder, predictor, target_encoder, opt, scaler)[source]#
Load JEPA training state from disk into model and optimizer objects.
Restores encoder, predictor, optional target encoder, optimizer state, and optional AMP scaler, returning the resumed epoch for training restart.
- world_models.helpers.jepa_helper.init_model(device, patch_size=16, model_name='vit_base', crop_size=224, pred_depth=6, pred_emb_dim=384)[source]#
Initialize JEPA encoder and predictor modules with ViT backbones.
Applies truncated-normal parameter initialization, moves modules to the requested device, and returns (encoder, predictor).
- world_models.helpers.jepa_helper.init_opt(encoder, predictor, iterations_per_epoch, start_lr, ref_lr, warmup, num_epochs, wd=1e-06, final_wd=1e-06, final_lr=0.0, use_bfloat16=False, ipe_scale=1.25)[source]#
Build optimizer, AMP scaler, LR scheduler, and weight-decay scheduler for JEPA.
Parameters are grouped to exclude bias/norm tensors from weight decay, matching typical transformer training best practices.
Benchmarks and reports#
Benchmarks sub-module - Benchmark runners and adapters for world models.
This package provides tools for running standardized evaluations of world models (Dreamer, IRIS, DIAMOND) across multiple seeds and computing aggregate metrics.
- Usage:
from world_models.benchmarks import BenchmarkRunner, DiamondAdapter
runner = BenchmarkRunner(adapter_cls=DiamondAdapter, …)
- class world_models.benchmarks.BenchmarkRunner(adapter_cls, out_dir='results')[source]#
Bases:
object
Run evaluations for adapters across seeds and export results.
- Usage:
runner = BenchmarkRunner(adapter_cls=adapters.DiamondAdapter)
results = runner.run(env_spec=env_spec, seeds=[0, 1], num_episodes=5)
Here env_spec describes the environment to evaluate, for example the Atari game "Breakout-v5".
- Parameters:
adapter_cls (Callable[..., adapters.BaseAdapter])
out_dir (str)
- run(env_spec=None, seeds=None, num_episodes=5, checkpoint=None, extra_kwargs=None)[source]#
Run benchmark.
Returns a results dict with per-seed episode returns and computed metrics.
- Parameters:
env_spec (Any | None)
seeds (List[int] | None)
num_episodes (int)
checkpoint (str | None)
extra_kwargs (Dict[str, Any] | None)
- Return type:
Dict[str, Any]
- class world_models.benchmarks.MultiAgentBenchmarkRunner(adapters, out_dir='results')[source]#
Bases:
object
Run evaluations for multiple adapters on the same environment.
- Usage:
runner = MultiAgentBenchmarkRunner(adapters=[adapters.DiamondAdapter, adapters.IRISAdapter])
results = runner.run_all(env_spec=env_spec, seeds=[0, 1], num_episodes=5)
Here env_spec is a dict describing the shared environment, for example the Atari game "Breakout-v5".
- Parameters:
adapters (List[Callable[..., adapters.BaseAdapter]])
out_dir (str)
- run_all(env_spec, seeds=None, num_episodes=5, checkpoints=None, extra_kwargs=None, train_epochs=None)[source]#
Run benchmarks for all adapters on the same environment.
Returns a results dict with results for each adapter.
- Parameters:
env_spec (Dict[str, Any])
seeds (List[int] | None)
num_episodes (int)
checkpoints (Dict[str, str] | None)
extra_kwargs (Dict[str, Any] | None)
train_epochs (int | None)
- Return type:
Dict[str, Any]
- class world_models.benchmarks.BaseAdapter(env_spec=None, seed=0, **kwargs)[source]#
Bases:
object
- Parameters:
env_spec (Any | None)
seed (int)
- class world_models.benchmarks.DreamerAdapter(env_spec=None, seed=0, **kwargs)[source]#
Bases:
BaseAdapter
- Parameters:
env_spec (Any | None)
seed (int)
- class world_models.benchmarks.IRISAdapter(env_spec=None, seed=0, **kwargs)[source]#
Bases:
BaseAdapter
- Parameters:
env_spec (Any | None)
seed (int)
- class world_models.benchmarks.DiamondAdapter(env_spec=None, seed=0, **kwargs)[source]#
Bases:
BaseAdapter
- Parameters:
env_spec (Any | None)
seed (int)
- world_models.benchmarks.compute_aggregate_metrics(per_seed_means)[source]#
- Parameters:
per_seed_means (Iterable[float])
- Return type:
Dict[str, float]
- world_models.benchmarks.bootstrap_ci(values, num_samples=1000, alpha=0.05)[source]#
Compute a simple bootstrap (1 - alpha) confidence interval on the mean.
- Parameters:
values (List[float])
num_samples (int)
alpha (float)
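The docstring does not say which bootstrap variant is used. A common choice is the percentile bootstrap, sketched below under that assumption (bootstrap_mean_ci and its seed argument are hypothetical):

```python
import numpy as np


def bootstrap_mean_ci(values, num_samples=1000, alpha=0.05, seed=0):
    """Percentile bootstrap (1 - alpha) confidence interval for the mean."""
    values = np.asarray(values, dtype=np.float64)
    rng = np.random.default_rng(seed)
    # Each row is one resample (with replacement) of the original values.
    idx = rng.integers(0, len(values), size=(num_samples, len(values)))
    means = values[idx].mean(axis=1)
    lower = np.percentile(means, 100 * alpha / 2)
    upper = np.percentile(means, 100 * (1 - alpha / 2))
    return float(lower), float(upper)


returns = [12.0, 15.5, 9.0, 14.0, 11.5]
low, high = bootstrap_mean_ci(returns)
print(low, high)
```

With only a handful of seeds, the interval is wide and itself noisy; raising num_samples reduces Monte Carlo error in the bounds but cannot compensate for few seeds.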
- world_models.benchmarks.iqm_of_array(values)[source]#
Compute the Interquartile Mean (IQM) of an array of values.
IQM is the mean of values that lie between the 25th and 75th percentiles (inclusive). This is a robust central tendency measure used in RL benchmark reporting.
- Parameters:
values (Iterable[float])
- Return type:
float
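Following the definition above, a NumPy sketch. iqm is a hypothetical name, and the fallback to the median when no value lies inside the band (which can happen with only two samples) is an assumption of this sketch.

```python
import numpy as np


def iqm(values):
    """Mean of the values within the [25th, 75th] percentile band (inclusive)."""
    values = np.asarray(list(values), dtype=np.float64)
    q25, q75 = np.percentile(values, [25, 75])
    middle = values[(values >= q25) & (values <= q75)]
    # Fallback (assumption): the band can be empty for very small samples.
    return float(middle.mean()) if middle.size else float(np.median(values))


print(iqm([1.0, 2.0, 3.0, 4.0, 100.0]))  # 3.0
```

The outlier 100.0 pulls the plain mean to 22.0 but leaves the IQM at 3.0. For comparison, scipy.stats.trim_mean(values, 0.25) computes a closely related 25% trimmed mean by discarding whole observations from each tail; the two can differ slightly on small samples.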
- class world_models.benchmarks.runner.BenchmarkRunner(adapter_cls, out_dir='results')[source]#
Bases:
object
Run evaluations for adapters across seeds and export results.
- Usage:
runner = BenchmarkRunner(adapter_cls=adapters.DiamondAdapter)
results = runner.run(env_spec=env_spec, seeds=[0, 1], num_episodes=5)
Here env_spec describes the environment to evaluate, for example the Atari game "Breakout-v5".
- Parameters:
adapter_cls (Callable[..., adapters.BaseAdapter])
out_dir (str)
- run(env_spec=None, seeds=None, num_episodes=5, checkpoint=None, extra_kwargs=None)[source]#
Run benchmark.
Returns a results dict with per-seed episode returns and computed metrics.
- Parameters:
env_spec (Any | None)
seeds (List[int] | None)
num_episodes (int)
checkpoint (str | None)
extra_kwargs (Dict[str, Any] | None)
- Return type:
Dict[str, Any]
- class world_models.benchmarks.runner.MultiAgentBenchmarkRunner(adapters, out_dir='results')[source]#
Bases:
object
Run evaluations for multiple adapters on the same environment.
- Usage:
runner = MultiAgentBenchmarkRunner(adapters=[adapters.DiamondAdapter, adapters.IRISAdapter])
results = runner.run_all(env_spec=env_spec, seeds=[0, 1], num_episodes=5)
Here env_spec is a dict describing the shared environment, for example the Atari game "Breakout-v5".
- Parameters:
adapters (List[Callable[..., adapters.BaseAdapter]])
out_dir (str)
- run_all(env_spec, seeds=None, num_episodes=5, checkpoints=None, extra_kwargs=None, train_epochs=None)[source]#
Run benchmarks for all adapters on the same environment.
Returns a results dict with results for each adapter.
- Parameters:
env_spec (Dict[str, Any])
seeds (List[int] | None)
num_episodes (int)
checkpoints (Dict[str, str] | None)
extra_kwargs (Dict[str, Any] | None)
train_epochs (int | None)
- Return type:
Dict[str, Any]
- class world_models.benchmarks.adapters.BaseAdapter(env_spec=None, seed=0, **kwargs)[source]#
Bases:
object
- Parameters:
env_spec (Any | None)
seed (int)
- class world_models.benchmarks.adapters.DiamondAdapter(env_spec=None, seed=0, **kwargs)[source]#
Bases:
BaseAdapter
- Parameters:
env_spec (Any | None)
seed (int)
- class world_models.benchmarks.adapters.IRISAdapter(env_spec=None, seed=0, **kwargs)[source]#
Bases:
BaseAdapter
- Parameters:
env_spec (Any | None)
seed (int)
- class world_models.benchmarks.adapters.DreamerAdapter(env_spec=None, seed=0, **kwargs)[source]#
Bases:
BaseAdapter
- Parameters:
env_spec (Any | None)
seed (int)
- class world_models.benchmarks.adapters.DreamerV1Adapter(env_spec=None, seed=0, **kwargs)[source]#
Bases:
DreamerAdapter
- Parameters:
env_spec (Any | None)
seed (int)
- class world_models.benchmarks.adapters.DreamerV2Adapter(env_spec=None, seed=0, **kwargs)[source]#
Bases: DreamerAdapter
- Parameters:
env_spec (Any | None)
seed (int)
- world_models.benchmarks.metrics.compute_aggregate_metrics(per_seed_means)[source]#
- Parameters:
per_seed_means (Iterable[float])
- Return type:
Dict[str, float]
- world_models.benchmarks.metrics.bootstrap_ci(values, num_samples=1000, alpha=0.05)[source]#
Compute a simple bootstrap (1 - alpha) confidence interval for the mean.
- Parameters:
values (List[float])
num_samples (int)
alpha (float)
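The following is a minimal percentile-bootstrap sketch. It assumes the interval is read off the alpha/2 and 1 - alpha/2 percentiles of the resampled statistic; the library's resampling and percentile conventions may differ. The `statistic` and `seed` arguments are hypothetical extensions. Passing an IQM function as `statistic` gives an analogue of bootstrap_iqm_ci().

```python
import random
import statistics
from typing import Callable, List, Optional, Sequence, Tuple


def bootstrap_ci_sketch(
    values: Sequence[float],
    num_samples: int = 1000,
    alpha: float = 0.05,
    statistic: Callable[[Sequence[float]], float] = statistics.fmean,
    seed: Optional[int] = None,
) -> Tuple[float, float]:
    """Percentile bootstrap (1 - alpha) confidence interval for `statistic`."""
    rng = random.Random(seed)
    n = len(values)
    # Resample with replacement and record the statistic of each resample.
    stats: List[float] = sorted(
        statistic(rng.choices(values, k=n)) for _ in range(num_samples)
    )
    # Empirical alpha/2 and 1 - alpha/2 quantiles of the bootstrap distribution.
    lo_idx = int((alpha / 2) * (num_samples - 1))
    hi_idx = int((1 - alpha / 2) * (num_samples - 1))
    return stats[lo_idx], stats[hi_idx]
```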
- world_models.benchmarks.metrics.iqm_of_array(values)[source]#
Compute the Interquartile Mean (IQM) of an array of values.
IQM is the mean of values that lie between the 25th and 75th percentiles (inclusive). This is a robust central tendency measure used in RL benchmark reporting.
- Parameters:
values (Iterable[float])
- Return type:
float
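The IQM definition above can be sketched with the standard library. This version assumes linear-interpolation quartiles (`statistics.quantiles` with `method="inclusive"`, which matches NumPy's default percentile). The library may compute percentiles differently.

```python
import statistics
from typing import Iterable


def iqm_sketch(values: Iterable[float]) -> float:
    """Interquartile mean: mean of the values within [Q1, Q3], bounds inclusive."""
    data = sorted(values)
    # Linear-interpolation quartiles; the median is not needed.
    q1, _, q3 = statistics.quantiles(data, n=4, method="inclusive")
    middle = [v for v in data if q1 <= v <= q3]
    return statistics.fmean(middle)
```

The outlier 1000 in the second test below has no effect on the result. This insensitivity to outliers is why IQM is preferred over the plain mean when reporting RL benchmarks.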
- world_models.benchmarks.metrics.bootstrap_iqm_ci(values, num_samples=1000, alpha=0.05)[source]#
Bootstrap a confidence interval for the IQM.
Returns (lower, upper) percentiles of the bootstrap IQM distribution.
- Parameters:
values (List[float])
num_samples (int)
alpha (float)
- world_models.benchmarks.reporting.export_csv(results, path)[source]#
- Parameters:
results (Dict[str, Any])
path (str)
Utilities#
- world_models.utils.dreamer_utils.symlog(x)[source]#
Symmetric log transform used by Dreamer V2 for reward/value targets.
Defined as sign(x) * log(1 + |x|). This compresses large positive or negative values into a range that is easier to predict with a categorical distribution over a bounded set of buckets.
- Parameters:
x (Tensor)
- Return type:
Tensor
- world_models.utils.dreamer_utils.symexp(x)[source]#
Inverse of symlog(). Defined as sign(x) * (exp(|x|) - 1).
- Parameters:
x (Tensor)
- Return type:
Tensor
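The two documented formulas translate directly to PyTorch. This sketch restates them using `log1p`/`expm1` for accuracy near zero. It is not copied from the library source.

```python
import torch


def symlog(x: torch.Tensor) -> torch.Tensor:
    # sign(x) * log(1 + |x|); log1p stays accurate for small |x|.
    return torch.sign(x) * torch.log1p(torch.abs(x))


def symexp(x: torch.Tensor) -> torch.Tensor:
    # sign(x) * (exp(|x|) - 1); expm1 stays accurate for small |x|.
    return torch.sign(x) * torch.expm1(torch.abs(x))
```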
- class world_models.utils.dreamer_utils.TwoHotEncoder(num_buckets=255, symlog_range=10.0)[source]#
Bases: object
Two-hot encoding for symlog targets (Dreamer V2 reward/value heads).
A target value is softly assigned to the two nearest buckets on a uniform grid spanning [-symlog_range, symlog_range]. The categorical logits produced by a network can then be decoded back into a real value by computing the expected bucket center.
- Parameters:
num_buckets (int) – Number of buckets in the categorical distribution.
symlog_range (float) – Maximum absolute value (in symlog space) covered by the grid. Values outside the range are clipped to the boundary buckets.
- register_buffers()[source]#
Allocate the bucket-center buffer on CPU. Use to() to move it.
- Return type:
None
- encode(target)[source]#
Two-hot encode a real-valued target into soft bucket probabilities.
- Parameters:
target (Tensor) – Tensor of arbitrary shape containing real-valued targets.
- Returns:
Tensor with an extra final dimension of size num_buckets containing the soft two-hot distribution. The encoding assumes the target is already in symlog space, matching Dreamer V2.
- Return type:
Tensor
- decode(logits)[source]#
Decode categorical logits into the expected real-valued prediction.
The logits are passed through a softmax, and the resulting probabilities weight the bucket centers to give an expected value. That value is then passed through symexp() to invert the symlog transform.
- Parameters:
logits (Tensor) – Tensor with a final dimension of num_buckets.
- Returns:
Tensor with the same shape as logits minus the last dimension.
- Return type:
Tensor
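Below is a minimal functional sketch of two-hot encoding and expected-value decoding on a uniform grid. The helpers `make_bucket_centers`, `two_hot_encode`, and `two_hot_decode` are hypothetical standalone functions, not the TwoHotEncoder methods. The bucket construction inside TwoHotEncoder may differ.

```python
import torch
import torch.nn.functional as F


def make_bucket_centers(num_buckets: int = 255, symlog_range: float = 10.0) -> torch.Tensor:
    # Uniform grid of bucket centers over [-symlog_range, symlog_range].
    return torch.linspace(-symlog_range, symlog_range, num_buckets)


def two_hot_encode(target: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
    """Map symlog-space targets of any shape to (..., num_buckets) soft labels."""
    k = centers.numel()
    low, high = centers[0].item(), centers[-1].item()
    t = target.clamp(low, high)  # out-of-range values land on the boundary buckets
    pos = (t - low) / (high - low) * (k - 1)  # continuous bucket index in [0, k - 1]
    lower = pos.floor().long().clamp(0, k - 2)
    w_upper = pos - lower.to(pos.dtype)  # weight on the upper neighbour
    w_lower = 1.0 - w_upper
    return (
        F.one_hot(lower, k).to(t.dtype) * w_lower.unsqueeze(-1)
        + F.one_hot(lower + 1, k).to(t.dtype) * w_upper.unsqueeze(-1)
    )


def two_hot_decode(logits: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
    """Softmax, take the expected bucket center, then invert symlog."""
    probs = torch.softmax(logits, dim=-1)
    expected = (probs * centers).sum(dim=-1)
    return torch.sign(expected) * torch.expm1(torch.abs(expected))
```

The two weights linearly interpolate between neighbouring centers, so the expected value of an encoded target equals the target itself. As a result, decoding the log of an encoding recovers the original real value.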
- world_models.utils.dreamer_utils.get_parameters(modules)[source]#
Given a list of torch modules, returns a list of their parameters.
- Parameters:
modules (Iterable[Module]) – iterable of modules
- Returns:
a list of parameters
- class world_models.utils.dreamer_utils.FreezeParameters(modules)[source]#
Bases: object
Context manager that temporarily disables gradients for the given modules.
Useful during imagination or target-network forward passes, where the parameters of certain components should not receive gradients, for both speed and correctness.
- Parameters:
modules (Iterable[Module])
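This sketch shows the context-manager pattern for get_parameters() and FreezeParameters. It assumes the class toggles `requires_grad` and restores the previous flags on exit. The names with a `_sketch`/`Sketch` suffix are hypothetical stand-ins for the documented API.

```python
from typing import Iterable, List

import torch
from torch import nn


def get_parameters_sketch(modules: Iterable[nn.Module]) -> List[nn.Parameter]:
    # Concatenate the parameters of every module into one list.
    params: List[nn.Parameter] = []
    for module in modules:
        params += list(module.parameters())
    return params


class FreezeParametersSketch:
    """Set requires_grad=False on entry and restore the saved flags on exit."""

    def __init__(self, modules: Iterable[nn.Module]):
        self.modules = list(modules)
        self.saved: List[bool] = []

    def __enter__(self):
        params = get_parameters_sketch(self.modules)
        self.saved = [p.requires_grad for p in params]
        for p in params:
            p.requires_grad_(False)
        return self

    def __exit__(self, exc_type, exc, tb):
        for p, flag in zip(get_parameters_sketch(self.modules), self.saved):
            p.requires_grad_(flag)
        return False  # do not swallow exceptions
```

Freezing parameters stops gradients from accumulating into those parameters. Gradients can still flow through the frozen module to any inputs that require grad. Dreamer-style actor updates depend on this, because they backpropagate through a fixed world model.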
- class world_models.utils.dreamer_utils.Logger(log_dir, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', video_format='gif', video_fps=20)[source]#
Bases: object
Experiment logger for scalars and video rollouts (GIF by default), with optional WandB logging.
Provides helpers to write scalar metrics, dump pickle snapshots, and save video previews during Dreamer training/evaluation.
- world_models.utils.dreamer_utils.compute_return(rewards, values, discounts, td_lam, last_value)[source]#
Compute TD(lambda) returns from imagined rewards, values, and discounts.
Implements backward recursion used by Dreamer actor/value objectives.
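A minimal sketch of the lambda-return recursion follows. It assumes time-major tensors of shape (T, ...) and a bootstrap value `last_value` for the step after the horizon. This reference does not confirm that the library's shapes and bootstrapping match exactly.

```python
import torch


def compute_return_sketch(rewards, values, discounts, td_lam, last_value):
    """TD(lambda) returns over time-major (T, ...) tensors.

    R_t = r_t + gamma_t * ((1 - lambda) * v_{t+1} + lambda * R_{t+1}),
    with v_T = R_T = last_value.
    """
    # Values shifted one step ahead, with the bootstrap value appended.
    next_values = torch.cat([values[1:], last_value.unsqueeze(0)], dim=0)
    returns = []
    acc = last_value
    for t in reversed(range(rewards.shape[0])):
        acc = rewards[t] + discounts[t] * ((1 - td_lam) * next_values[t] + td_lam * acc)
        returns.append(acc)
    return torch.stack(returns[::-1], dim=0)
```

Two limiting cases are useful for checking the recursion. With lambda = 1 it reduces to bootstrapped discounted Monte Carlo returns. With lambda = 0 it reduces to one-step TD targets r_t + gamma_t * v_{t+1}.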
- world_models.utils.jepa_utils.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2, b=2.0)[source]#
Initialize a tensor in-place from a truncated normal distribution.
Values are drawn from N(mean, std) truncated to the interval [a, b], so every value lies within [a, b].
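Here is a sketch of the usual inverse-CDF method for truncated-normal initialization. The bounds `a` and `b` are treated as absolute values, not multiples of `std`. PyTorch also ships `torch.nn.init.trunc_normal_` with the same parameter names. Whether this package's version is implemented identically is an assumption.

```python
import math

import torch


def trunc_normal_sketch_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fill `tensor` in place with samples from N(mean, std) truncated to [a, b]."""

    def norm_cdf(x: float) -> float:
        # Standard normal CDF.
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    with torch.no_grad():
        lo = norm_cdf((a - mean) / std)
        hi = norm_cdf((b - mean) / std)
        # Uniform on [2*lo - 1, 2*hi - 1]; erfinv maps it to the truncated normal.
        tensor.uniform_(2 * lo - 1, 2 * hi - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.0)).add_(mean)
        # Guard against rounding pushing samples just past the bounds.
        tensor.clamp_(min=a, max=b)
    return tensor
```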
- world_models.utils.jepa_utils.repeat_interleave_batch(x, B, repeat)[source]#
Repeat each batch chunk multiple times while preserving chunk ordering.
Used in JEPA masking code to align context and target token batches.
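A sketch of chunk-wise repetition follows. It assumes `x` is split along dim 0 into `len(x) // B` chunks of `B` rows, and that each chunk is repeated `repeat` times consecutively. The helper name is hypothetical.

```python
import torch


def repeat_interleave_batch_sketch(x: torch.Tensor, B: int, repeat: int) -> torch.Tensor:
    """Split x into chunks of B rows and repeat each chunk `repeat` times in order."""
    num_chunks = len(x) // B
    return torch.cat(
        [torch.cat([x[i * B:(i + 1) * B]] * repeat, dim=0) for i in range(num_chunks)],
        dim=0,
    )
```

This differs from `torch.repeat_interleave`, which repeats individual rows: `[0, 0, 1, 1, ...]`. The chunk-wise version keeps each batch chunk contiguous, so every repeat lines up with the corresponding group of masks.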
- class world_models.utils.jepa_utils.WarmupCosineSchedule(optimizer, warmup_steps, start_lr, ref_lr, T_max, last_epoch=-1, final_lr=0.0)[source]#
Bases: object
Learning-rate schedule with linear warmup followed by cosine decay.
Updates optimizer parameter-group LRs on each call to step().
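The schedule can be written as a pure function of the step count. Its parameters mirror the constructor, and it assumes the cosine phase spans the remaining `T_max - warmup_steps` steps. The step indexing used by the class's step() may be offset by one relative to this sketch.

```python
import math


def warmup_cosine_lr(step, warmup_steps, start_lr, ref_lr, T_max, final_lr=0.0):
    """Learning rate at `step` for linear warmup followed by cosine decay."""
    if step < warmup_steps:
        # Linear ramp from start_lr to ref_lr.
        return start_lr + (step / warmup_steps) * (ref_lr - start_lr)
    # Cosine decay from ref_lr to final_lr, held at final_lr after T_max.
    progress = min(1.0, (step - warmup_steps) / max(1, T_max - warmup_steps))
    return final_lr + (ref_lr - final_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
```

To apply the value, assign it to every entry of `optimizer.param_groups` as `group["lr"] = lr`. The class performs that assignment on each step().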
- class world_models.utils.jepa_utils.CosineWDSchedule(optimizer, ref_wd, T_max, final_wd=0.0)[source]#
Bases: object
Cosine scheduler for optimizer weight decay values.
Skips parameter groups flagged with WD_exclude to keep bias/norm decay at zero.
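This sketch performs one weight-decay update over a list of parameter-group dicts, such as `optimizer.param_groups`. It assumes a plain cosine from `ref_wd` to `final_wd` over `T_max` steps and treats any truthy `WD_exclude` key as an opt-out. The function name is hypothetical.

```python
import math
from typing import Any, Dict, List


def cosine_wd_step(param_groups: List[Dict[str, Any]], step: int, ref_wd: float,
                   T_max: int, final_wd: float = 0.0) -> float:
    """Set cosine-scheduled weight decay on every group not flagged WD_exclude."""
    progress = min(1.0, step / T_max)
    new_wd = final_wd + (ref_wd - final_wd) * 0.5 * (1.0 + math.cos(math.pi * progress))
    for group in param_groups:
        # Bias and norm groups opt out so their decay stays at its initial value (zero).
        if not group.get("WD_exclude", False):
            group["weight_decay"] = new_wd
    return new_wd
```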
- world_models.utils.jepa_utils.gpu_timer(closure, log_timings=True)[source]#
Measure CUDA execution time for a closure and return (result, elapsed_ms).
Falls back to -1 elapsed time when CUDA timing is unavailable.
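CUDA timing of this kind is typically done with paired `torch.cuda.Event` objects. The sketch below assumes that approach and returns -1.0 when CUDA is unavailable or timing is disabled. The function name is hypothetical.

```python
import torch


def gpu_timer_sketch(closure, log_timings=True):
    """Run closure(); return (result, elapsed_ms), or -1.0 ms without CUDA timing."""
    log_timings = log_timings and torch.cuda.is_available()
    elapsed_ms = -1.0
    if log_timings:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    result = closure()
    if log_timings:
        end.record()
        # Wait for queued kernels so both events hold final timestamps.
        torch.cuda.synchronize()
        elapsed_ms = start.elapsed_time(end)
    return result, elapsed_ms
```

Host-side wall-clock timers are unreliable here because CUDA kernels launch asynchronously. The events are recorded on the GPU stream, so they measure when the queued work actually runs.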
- class world_models.utils.jepa_utils.CSVLogger(fname, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', *argv)[source]#
Bases: object
Lightweight CSV logger with per-column printf-style formatting and WandB support.
- class world_models.utils.jepa_utils.AverageMeter[source]#
Bases: object
Track running statistics (val, avg, min, max, sum, count) for metrics.
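Below is a minimal meter that tracks the statistics listed above. The `reset()`/`update(val, n)` method names are assumptions. It treats each `val` as the mean of `n` observations.

```python
class AverageMeterSketch:
    """Running val/avg/min/max/sum/count for a scalar metric."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0
        self.min = float("inf")
        self.max = float("-inf")

    def update(self, val: float, n: int = 1) -> None:
        # `val` is the mean of `n` observations (for example, a batch loss).
        self.val = val
        self.min = min(self.min, val)
        self.max = max(self.max, val)
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
```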
- world_models.utils.jepa_utils.grad_logger(named_params)[source]#
Aggregate gradient norm statistics over model parameters for logging.
Also exposes first/last qkv-layer gradient norms when available.
- world_models.utils.jepa_utils.init_distributed(port=40112, rank_and_world_size=(None, None))[source]#
Initialize torch distributed process groups when the environment supports it.
Returns (world_size, rank) and gracefully falls back to single-process mode.
- class world_models.utils.jepa_utils.AllGather(*args, **kwargs)[source]#
Bases: Function
Autograd-aware all-gather operation across distributed workers.
Forward concatenates worker tensors; backward reduces and slices gradients.
- class world_models.utils.jepa_utils.AllReduceSum(*args, **kwargs)[source]#
Bases: Function
Autograd function that sums tensors across distributed workers in the forward pass.
- class world_models.utils.jepa_utils.AllReduce(*args, **kwargs)[source]#
Bases: Function
Autograd function that all-reduces and averages tensors across workers.
Used to synchronize scalar losses for consistent distributed logging/training.
- world_models.utils.data_utils.create_efficient_dataloader(dataset, batch_size, num_workers=None, pin_memory=True, prefetch_factor=2, persistent_workers=True)[source]#
Create a DataLoader tuned for throughput and memory use through pinned memory, prefetching, and persistent workers.
- Parameters:
dataset (Dataset)
batch_size (int)
num_workers (int | None)
pin_memory (bool)
prefetch_factor (int)
persistent_workers (bool)
- Return type:
DataLoader
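Here is a sketch of how these options map onto `torch.utils.data.DataLoader`. The default of `min(8, os.cpu_count())` workers is a hypothetical choice. Skipping `prefetch_factor` and `persistent_workers` in single-process mode reflects DataLoader's own requirement that they be used only with worker processes.

```python
import os
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset


def make_loader(dataset: Dataset, batch_size: int, num_workers: Optional[int] = None,
                pin_memory: bool = True, prefetch_factor: int = 2,
                persistent_workers: bool = True) -> DataLoader:
    if num_workers is None:
        # Hypothetical default: one worker per core, capped at 8.
        num_workers = min(8, os.cpu_count() or 1)
    kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        # Page-locked host memory only speeds up host-to-GPU copies.
        "pin_memory": pin_memory and torch.cuda.is_available(),
    }
    if num_workers > 0:
        # These options are only valid when worker processes are used.
        kwargs["prefetch_factor"] = prefetch_factor
        kwargs["persistent_workers"] = persistent_workers
    return DataLoader(dataset, **kwargs)
```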
- world_models.utils.data_utils.prefetch_iterator(iterator, buffer_size=3)[source]#
Add prefetching to any iterator.
- Parameters:
iterator (Iterator)
buffer_size (int)
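One common way to add prefetching is a background thread that fills a bounded queue. This sketch assumes `buffer_size` bounds the number of items read ahead. Errors raised by the source iterator are re-raised in the consumer.

```python
import queue
import threading
from typing import Iterator, TypeVar

T = TypeVar("T")


def prefetch_iterator_sketch(iterator: Iterator[T], buffer_size: int = 3) -> Iterator[T]:
    """Yield items from `iterator` while a background thread reads ahead."""
    buf: "queue.Queue" = queue.Queue(maxsize=buffer_size)

    def producer() -> None:
        try:
            for item in iterator:
                buf.put((False, item))  # blocks once buffer_size items are waiting
        except BaseException as exc:  # forward the error to the consumer
            buf.put((True, exc))
            return
        buf.put((True, None))  # end-of-stream marker

    threading.Thread(target=producer, daemon=True).start()
    while True:
        done, payload = buf.get()
        if done:
            if payload is not None:
                raise payload
            return
        yield payload
```

The worker is a daemon thread, so an abandoned, partially consumed iterator does not keep the process alive.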
- world_models.utils.jit_utils.jit_compile_function(func)[source]#
JIT compile a function for performance.
- Parameters:
func (Callable)
- Return type:
Callable
- world_models.utils.jit_utils.jit_compile_module(module)[source]#
JIT compile a PyTorch module.
- Parameters:
module (Module)
- Return type:
Module
- world_models.utils.memory_utils.apply_gradient_checkpointing(model, checkpoint_ratio=0.5)[source]#
Apply gradient checkpointing to reduce memory usage during training.
- Parameters:
model (Module)
checkpoint_ratio (float)
- world_models.utils.memory_utils.enable_mixed_precision(model, scaler=None)[source]#
Enable mixed precision training.
- Parameters:
model (Module)
scaler (GradScaler | None)
- world_models.utils.memory_utils.optimize_memory_efficient_ops()[source]#
Configure PyTorch settings for memory-efficient operations.
- world_models.utils.logging_utils.setup_logging(name, level='INFO', log_file=None)[source]#
Set up consistent logging for torchwm modules.
- Parameters:
name (str)
level (str)
log_file (str | None)
- Return type:
Logger
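A standard-library sketch of this kind of helper is shown below. The message format and the rule of attaching handlers only once per logger name are assumptions, not documented behavior.

```python
import logging
from typing import Optional


def setup_logging_sketch(name: str, level: str = "INFO",
                         log_file: Optional[str] = None) -> logging.Logger:
    logger = logging.getLogger(name)
    logger.setLevel(getattr(logging, level.upper()))
    # Attach handlers only once so repeated calls do not duplicate output.
    if not logger.handlers:
        fmt = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
        stream = logging.StreamHandler()
        stream.setFormatter(fmt)
        logger.addHandler(stream)
        if log_file is not None:
            file_handler = logging.FileHandler(log_file)
            file_handler.setFormatter(fmt)
            logger.addHandler(file_handler)
    return logger
```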
- world_models.utils.utils.to_tensor_obs(image)[source]#
Convert an input NumPy image to a channel-first 64x64 torch tensor.
- world_models.utils.utils.postprocess_img(image, depth)[source]#
Postprocess an image observation for storage, converting a float32 NumPy array in [-0.5, 0.5] to a uint8 NumPy array in [0, 255].
- world_models.utils.utils.preprocess_img(image, depth)[source]#
Preprocess an observation in place, converting a float32 Tensor in [0, 255] to [-0.5, 0.5]. Also adds random noise to the observations.
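The following sketches the preprocessing and postprocessing pair. It assumes `depth` is the observation bit depth and that the noise is uniform dequantization noise within one quantization bin, as in common PlaNet implementations. The source only says "some noise", so treat both functions as assumptions.

```python
import numpy as np
import torch


def preprocess_img_sketch(image: torch.Tensor, depth: int) -> torch.Tensor:
    """In place: quantize [0, 255] floats to `depth` bits, center, add noise."""
    # Drop the low (8 - depth) bits, then rescale to [-0.5, 0.5).
    image.div_(2 ** (8 - depth)).floor_().div_(2 ** depth).sub_(0.5)
    # Uniform noise within one quantization bin (dequantization).
    image.add_(torch.rand_like(image).div_(2 ** depth))
    return image


def postprocess_img_sketch(image: np.ndarray, depth: int) -> np.ndarray:
    """Map [-0.5, 0.5] floats back to uint8 [0, 255] at `depth`-bit resolution."""
    quantized = np.floor((image + 0.5) * 2 ** depth) * 2 ** (8 - depth)
    return np.clip(quantized, 0, 255).astype(np.uint8)
```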
- world_models.utils.utils.bottle(func, *tensors)[source]#
Apply func, which operates on inputs of shape N x D, to inputs of shape N x T x D.
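A sketch of the "bottle" pattern follows. It assumes all inputs share the same leading (N, T) dimensions and that `func` returns a tensor whose first dimension is N*T.

```python
import torch


def bottle_sketch(func, *tensors: torch.Tensor) -> torch.Tensor:
    """Apply a function over flattened (N*T, ...) batches of (N, T, ...) inputs."""
    n, t = tensors[0].shape[:2]
    # Merge batch and time so func sees ordinary (N*T, D) batches.
    flat = [x.reshape(n * t, *x.shape[2:]) for x in tensors]
    out = func(*flat)
    # Restore the (N, T) leading dimensions on the output.
    return out.reshape(n, t, *out.shape[1:])
```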
- world_models.utils.utils.get_combined_params(*models)[source]#
Returns the combined parameter list of all the models given as input.
- world_models.utils.utils.save_video(frames, path, name)[source]#
Saves a video containing frames.
- Accepts frames in either:
(T, C, H, W) float in [0,1]
(T, H, W, C) float in [0,1]
Produces {path}/{name}.mp4 and a debug PNG {path}/{name}_debug_frame.png with per-channel statistics printed to stdout.
- world_models.utils.utils.combine_videos(video_dir, output_name='combined.mp4', pattern='vid_*.mp4', fps=25, resize=True)[source]#
Combine all videos matching pattern in video_dir into a single MP4 file. Returns the output filepath (string).
Example
combine_videos("results/planet", output_name="all_training.mp4")
- world_models.utils.utils.ensure_results_dir_exists(results_dir)[source]#
Simple helper to validate a results directory exists. Raises FileNotFoundError if not present.
- world_models.utils.utils.save_frames(target, pred_prior, pred_posterior, name, n_rows=5)[source]#
Save side-by-side target, prior-prediction, and posterior-prediction frames.
The function accepts tensors with an optional time dimension and writes a PNG grid to {name}.png. Spatial sizes are aligned per timestep before concatenation, and values are normalized to [0, 1] when needed.
- world_models.utils.utils.get_mask(tensor, lengths)[source]#
Build a batch-first validity mask from sequence lengths.
tensor may be a tensor/array with shape (N, T, ...) or (N,). The returned mask marks valid timesteps with ones up to each element in lengths and preserves device/dtype conventions from the input.
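This sketch covers only the (N, T, ...) tensor case of the mask construction, using broadcasting against `arange(T)`. The (N,) input form described above is not handled here.

```python
import torch


def get_mask_sketch(tensor: torch.Tensor, lengths) -> torch.Tensor:
    """(N, T) mask with ones at timesteps t < lengths[i] for each row i."""
    num_steps = tensor.shape[1]
    lengths = torch.as_tensor(lengths, device=tensor.device)
    steps = torch.arange(num_steps, device=tensor.device)
    # Broadcast (1, T) against (N, 1) to compare every step with every length.
    return (steps.unsqueeze(0) < lengths.unsqueeze(1)).to(tensor.dtype)
```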
- world_models.utils.utils.load_memory(path, device)[source]#
Loads an experience replay buffer (backwards-compatible with older pickle formats). Converts legacy list/.data formats into the current Memory(episodes) object.
- world_models.utils.utils.flatten_dict(data, sep='.', prefix='')[source]#
Flattens a nested dict into a single-level dict.
Example
{'a': 2, 'b': {'c': 20}} -> {'a': 2, 'b.c': 20}
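A recursive sketch that reproduces the example above is shown below. It assumes `prefix` is joined to each key with `sep`. The helper name is hypothetical.

```python
from typing import Any, Dict


def flatten_dict_sketch(data: Dict[str, Any], sep: str = ".", prefix: str = "") -> Dict[str, Any]:
    flat: Dict[str, Any] = {}
    for key, value in data.items():
        full_key = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            # Recurse, carrying the joined key path as the new prefix.
            flat.update(flatten_dict_sketch(value, sep=sep, prefix=full_key))
        else:
            flat[full_key] = value
    return flat
```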
- world_models.utils.utils.normalize_frames_for_saving(frames)[source]#
Ensure frames are in shape (T, H, W, 3) with float values in [0,1]. Handles inputs in (T, C, H, W) or (T, H, W, C), repeats single-channel -> RGB, drops alpha if present, and maps [-0.5,0.5] -> [0,1] when detected.
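The normalization steps above can be sketched with NumPy. Two parts are assumptions, not the library's exact logic: the layout-detection heuristic (axis 1 is channels if it has 1, 3, or 4 entries and the last axis does not) and the detection rule for the [-0.5, 0.5] convention (any negative value).

```python
import numpy as np


def normalize_frames_sketch(frames: np.ndarray) -> np.ndarray:
    """Return float frames of shape (T, H, W, 3) with values in [0, 1]."""
    frames = np.asarray(frames, dtype=np.float32)
    # Heuristic: axis 1 holds channels if it looks like one and the last axis does not.
    if frames.shape[1] in (1, 3, 4) and frames.shape[-1] not in (1, 3, 4):
        frames = frames.transpose(0, 2, 3, 1)
    if frames.shape[-1] == 1:
        frames = np.repeat(frames, 3, axis=-1)  # grayscale -> RGB
    elif frames.shape[-1] == 4:
        frames = frames[..., :3]  # drop alpha
    if frames.min() < 0.0:
        frames = frames + 0.5  # assume the [-0.5, 0.5] observation convention
    return np.clip(frames, 0.0, 1.0)
```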
- class world_models.utils.utils.TensorBoardMetrics(path)[source]#
Bases: object
Plots and (optionally) stores metrics for an experiment.
- world_models.utils.utils.apply_model(model, inputs, ignore_dim=None)[source]#
Placeholder helper for generic model application across input structures.
Currently not implemented; kept as an extension hook for future utility code.
- world_models.utils.utils.plot_metrics(metrics, path, prefix)[source]#
Render and save line plots for each metric series in a dictionary.
- world_models.utils.utils.lineplot(xs, ys, title, path='', xaxis='episode')[source]#
Create a Plotly line plot for scalar, dict, or ensemble-series data.
Supports uncertainty-band plotting when ys is a 2D array.
- class world_models.utils.utils.TorchImageEnvWrapper(env, bit_depth, observation_shape=None, act_rep=2)[source]#
Bases: object
Torch environment wrapper that wraps a gym env so that interactions use Tensors. Observations are returned in image form.
- property observation_size#
- property action_size#
- property max_episode_steps#
Return environment max episode steps (compatible with TimeLimit/spec).
- world_models.utils.utils.apply_masks(x, masks)[source]#
Gather token subsets from patch sequences using index masks.
Each mask selects token positions from x; selected groups are concatenated along the batch dimension.
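A `torch.gather`-based sketch of this token selection is shown below. It assumes `x` has shape (B, N, D) and each mask is a (B, K) tensor of int64 token indices.

```python
from typing import List

import torch


def apply_masks_sketch(x: torch.Tensor, masks: List[torch.Tensor]) -> torch.Tensor:
    """x: (B, N, D); each mask: (B, K) indices -> output (len(masks) * B, K, D)."""
    gathered = []
    for mask in masks:
        # Expand indices over the feature dimension so gather picks whole tokens.
        index = mask.unsqueeze(-1).expand(-1, -1, x.size(-1))
        gathered.append(torch.gather(x, dim=1, index=index))
    # Concatenate the per-mask selections along the batch dimension.
    return torch.cat(gathered, dim=0)
```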
- world_models.utils.utils.visualize_latent_tsne(latents, labels=None, save_path=None, perplexity=30)[source]#
Visualize latent representations using t-SNE.
- Parameters:
latents – torch.Tensor of shape (N, D) or numpy array
labels – optional list or array of labels for coloring
save_path – path to save the plot (HTML for plotly)
perplexity – t-SNE perplexity parameter
- world_models.utils.utils.visualize_latent_umap(latents, labels=None, save_path=None, n_neighbors=15)[source]#
Visualize latent representations using UMAP.
- Parameters:
latents – torch.Tensor of shape (N, D) or numpy array
labels – optional list or array of labels for coloring
save_path – path to save the plot (HTML for plotly)
n_neighbors – UMAP n_neighbors parameter
- class world_models.utils.utils.StreamingVideoWriter(path, fps=20, frame_shape=None, format='mp4')[source]#
Bases: object
A class for streaming video writing that saves frames in real time.
- Parameters:
path – output video file path
fps – frames per second
frame_shape – (height, width) of frames
format – 'mp4' or 'avi'