API Reference
This reference is generated from source docstrings and grouped by workflow. Use the World Models Study Guide for conceptual explanations and this page for exact classes, functions, and module-level APIs.
Public package surface
These modules expose the most common imports and lazy constructors.
Primary module: torchwm. Implementation modules are documented below for API completeness.
Use torchwm for common workflows:
import torchwm
agent = torchwm.create_model("dreamer", env="walker-walk")
torchwm is the friendly top-level package for TorchWM and the recommended public namespace for users.
- class torchwm.EnvBackendSpec(name, factory_path, description='', aliases=())[source]#
Bases: NamedTuple
Metadata describing an environment backend available through make_env.
- Parameters:
name (str)
factory_path (str)
description (str)
aliases (tuple[str, ...])
- name: str#
Alias for field number 0
- factory_path: str#
Alias for field number 1
- description: str#
Alias for field number 2
- aliases: tuple[str, ...]#
Alias for field number 3
- class torchwm.ModelSpec(name, import_path, config_path=None, description='', aliases=())[source]#
Bases: NamedTuple
Metadata describing a model available through create_model().
- Parameters:
name (str)
import_path (str)
config_path (str | None)
description (str)
aliases (tuple[str, ...])
- name: str#
Alias for field number 0
- import_path: str#
Alias for field number 1
- config_path: str | None#
Alias for field number 2
- description: str#
Alias for field number 3
- aliases: tuple[str, ...]#
Alias for field number 4
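The spec classes above are plain NamedTuples, so fields can be read by name or by position, and the aliases field suggests that lookups accept alternative names. The following is a minimal sketch of that pattern, not TorchWM's actual registry: the local ModelSpec stand-in, the _REGISTRY list, the lookup helper, the placeholder import path, and the case-insensitive matching are all assumptions made for illustration.

```python
from typing import NamedTuple, Optional


class ModelSpec(NamedTuple):
    """Local stand-in that mirrors the documented field order."""
    name: str
    import_path: str
    config_path: Optional[str] = None
    description: str = ""
    aliases: tuple = ()


# Hypothetical registry entry; the import path is a placeholder.
_REGISTRY = [
    ModelSpec("dreamer", "pkg.dreamer:Agent", aliases=("dreamer-v1",)),
]


def lookup(name: str) -> ModelSpec:
    """Resolve a canonical name or alias to its spec (case-insensitive here)."""
    key = name.lower()
    for spec in _REGISTRY:
        if key == spec.name or key in spec.aliases:
            return spec
    raise KeyError(f"unknown model: {name!r}")


spec = lookup("Dreamer-V1")
# Field access by name and by position refer to the same value.
print(spec.name, spec[0], spec.aliases)
```

Keeping the canonical name separate from its aliases lets a listing function return one entry per model while lookups stay forgiving.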
- torchwm.create_config(model, **overrides)[source]#
Create the default config object for model and apply overrides.
Examples
>>> cfg = create_config("dreamer", env="walker-walk", seed=7)
>>> cfg.env
'walker-walk'
- Parameters:
model (str)
overrides (Any)
- Return type:
Any
- torchwm.create_model(model, config=None, **overrides)[source]#
Instantiate a model or agent from a simple string name.
config is optional for models that define a config class. Keyword overrides are applied to the config when possible; otherwise they are passed directly to the underlying constructor/factory.
Examples
>>> agent = create_model("dreamer", env="walker-walk", total_steps=1000)
>>> genie = create_model("genie-small", image_size=32)
- Parameters:
model (str)
config (Any | None)
overrides (Any)
- Return type:
Any
- torchwm.get_env_backend_spec(name)[source]#
Return metadata for an environment backend name or alias.
- Parameters:
name (str)
- Return type:
EnvBackendSpec
- torchwm.get_model_spec(name)[source]#
Return metadata for a model name or alias.
- Parameters:
name (str)
- Return type:
ModelSpec
- torchwm.list_env_backends()[source]#
Return canonical backend names accepted by make_env().
- Return type:
list[str]
- torchwm.list_envs(model=None)[source]#
List known environment ids, optionally filtered by model family.
- Parameters:
model (str | None)
- Return type:
list[str] | dict[str, list[str]]
- torchwm.list_models()[source]#
Return canonical model names accepted by create_model().
- Return type:
list[str]
- torchwm.make_env(env_id, backend='auto', **kwargs)[source]#
Create an environment with a consistent TorchWM entry point.
- Parameters:
env_id (str) – Environment id, XML path, Unity executable path, or backend-specific id.
backend (str) – One of list_env_backends(); "auto" tries TorchWM’s compatibility helper.
**kwargs (Any) – Backend-specific options.
- Return type:
Any
- torchwm.export_any(obj, path, format='onnx', *, example_inputs=None, target=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)[source]#
Export any TorchWM model/agent or a target module contained by it.
- Parameters:
obj (Any)
path (str | Path)
format (str)
example_inputs (Any | None)
target (str | None)
input_names (list[str] | None)
output_names (list[str] | None)
dynamic_axes (dict[str, dict[int, str]] | None)
opset_version (int)
kwargs (Any)
- Return type:
Path
- torchwm.export_model(module, path, format='onnx', *, example_inputs=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)[source]#
Export a torch.nn.Module to ONNX, TorchScript, or TensorRT.
- Parameters:
module (Module)
path (str | Path)
format (str)
example_inputs (Any | None)
input_names (list[str] | None)
output_names (list[str] | None)
dynamic_axes (dict[str, dict[int, str]] | None)
opset_version (int)
kwargs (Any)
- Return type:
Path
- class torchwm.ExportableAgentMixin[source]#
Bases: object
Mixin for non-nn.Module agents that delegates to the shared exporter.
- export(path, format='onnx', *, example_inputs=None, target=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)[source]#
Export this agent or one of its contained modules for deployment.
- Parameters:
path (str | Path)
format (str)
example_inputs (Any | None)
target (str | None)
input_names (list[str] | None)
output_names (list[str] | None)
dynamic_axes (dict[str, dict[int, str]] | None)
opset_version (int)
kwargs (Any)
- Return type:
Path
- class torchwm.IRISAgent(config, action_size, device)[source]#
Bases: Module
Complete IRIS agent with world model and policy.
Combines:
- Discrete autoencoder (encoder + decoder)
- Transformer world model
- Actor-critic for policy and value learning
- Parameters:
config (IRISConfig)
action_size (int)
device (device)
- classmethod from_config(config=None, *, action_size, device=None, **overrides)[source]#
Build an IRIS agent from a config object, dict, YAML file, or YAML string.
- Parameters:
config (IRISConfig | dict[str, Any] | str | Path | None)
action_size (int)
device (device | str | None)
overrides (Any)
- Return type:
- classmethod from_pretrained(pretrained_model_name_or_path, *, action_size=None, device=None, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, **overrides)[source]#
Load an IRIS agent checkpoint from a local path/directory or HF Hub.
- Parameters:
pretrained_model_name_or_path (str | Path)
action_size (int | None)
device (device | str | None)
config (IRISConfig | dict[str, Any] | str | Path | None)
checkpoint_filename (str | None)
config_filename (str)
repo_type (str | None)
revision (str | None)
overrides (Any)
- Return type:
- forward_actor_critic(frames, hidden=None)[source]#
Forward pass through actor-critic.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W)
hidden (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state
- Returns:
action_logits: (B, T, action_size)
values: (B, T)
hidden_state: (h, c)
- act(frame, epsilon=0.0, temperature=1.0)[source]#
Sample action from policy.
- Parameters:
frame (Tensor) – Single frame (B, C, H, W)
epsilon (float) – Random action probability
temperature (float) – Action distribution temperature
- Returns:
Selected actions (B,)
- Return type:
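The epsilon and temperature parameters of act can be read as epsilon-greedy exploration layered on temperature-scaled sampling. The sketch below illustrates that reading for a single set of logits in plain Python. It is an assumption about the general technique, not IRISAgent's actual code, and sample_action and its rng argument are hypothetical names.

```python
import math
import random


def sample_action(logits, epsilon=0.0, temperature=1.0, rng=random):
    """Pick one action index from unnormalized logits."""
    # With probability epsilon, ignore the policy and act uniformly at random.
    if rng.random() < epsilon:
        return rng.randrange(len(logits))
    # Lower temperatures sharpen the distribution; higher ones flatten it.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


rng = random.Random(0)
print(sample_action([0.0, 0.0, 50.0], temperature=0.1, rng=rng))
```

With epsilon=0.0 the random branch is never taken, so the call reduces to sampling from the softmax of the scaled logits.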
- imagine_rollout(initial_frame, horizon=20)[source]#
Generate imagined trajectories using world model.
- Parameters:
initial_frame (Tensor) – Starting frame (B, C, H, W)
horizon (int) – Number of steps to imagine
- Returns:
trajectory: Dictionary with imagined rollout data
- update_autoencoder(frames)[source]#
Update discrete autoencoder.
- Parameters:
frames (Tensor) – Training frames (B, C, H, W)
- Returns:
losses: Dictionary of loss values
- update_transformer(frames, actions, rewards, terminals)[source]#
Update transformer world model.
- Parameters:
frames (Tensor) – Frame sequence
actions (Tensor) – Actions taken
rewards (Tensor) – Rewards received
terminals (Tensor) – Terminal flags
- Returns:
losses: Dictionary of loss values
- torchwm.compute_lambda_return(rewards, values, discounts, lambda_coef=0.95)[source]#
Compute λ-return target for value function training.
- Parameters:
rewards (Tensor) – Rewards (B, T)
values (Tensor) – Value estimates (B, T+1)
discounts (Tensor) – Discount factors (B, T)
lambda_coef (float) – Lambda parameter for bootstrapping
- Returns:
λ-return targets (B, T)
- Return type:
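The documented shapes (T rewards, T discounts, T+1 values) match the usual backward recursion for λ-returns, G_t = r_t + γ_t ((1 - λ) V_{t+1} + λ G_{t+1}), bootstrapped from the last value. The sketch below applies that recursion to a single sequence in plain Python. Whether compute_lambda_return uses exactly this form is an assumption, and lambda_return is a hypothetical name.

```python
def lambda_return(rewards, values, discounts, lambda_coef=0.95):
    """rewards, discounts: length T; values: length T + 1. Returns length T."""
    horizon = len(rewards)
    assert len(values) == horizon + 1 and len(discounts) == horizon
    returns = [0.0] * horizon
    next_return = values[-1]  # bootstrap from the final value estimate
    for t in reversed(range(horizon)):
        # Blend the one-step bootstrap with the longer-horizon return.
        blended = (1.0 - lambda_coef) * values[t + 1] + lambda_coef * next_return
        next_return = rewards[t] + discounts[t] * blended
        returns[t] = next_return
    return returns


# lambda_coef=0 gives one-step TD targets; lambda_coef=1 gives discounted returns
# that bootstrap only from the last value.
print(lambda_return([1.0, 1.0], [0.0, 2.0, 4.0], [0.5, 0.5], lambda_coef=0.0))
print(lambda_return([1.0, 1.0], [0.0, 2.0, 4.0], [0.5, 0.5], lambda_coef=1.0))
```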
- class torchwm.ModularRSSM(encoder, decoder, backbone, reward_decoder=None)[source]#
Bases: Module
Modular RSSM with swappable encoder, decoder, and backbone.
This class allows researchers to easily experiment with different:
- Encoders: Conv, MLP, ViT
- Decoders: Conv, MLP
- Backbones: GRU, LSTM, Transformer
Example
>>> encoder = ConvEncoder((3, 64, 64), embed_size=1024)
>>> decoder = ConvDecoder(32, 200, (3, 64, 64))
>>> backbone = GRUBackbone(action_size=6, stoch_size=32, deter_size=200, hidden_size=200, embed_size=1024)
>>> rssm = ModularRSSM(encoder, decoder, backbone)
- Parameters:
encoder (EncoderBase)
decoder (DecoderBase)
backbone (BackboneBase)
reward_decoder (DecoderBase | None)
- property stoch_size: int#
- property deter_size: int#
- property embed_size: int#
- init_state(batch_size, device)[source]#
- Parameters:
batch_size (int)
device (device)
- Return type:
Dict[str, Tensor]
- observe_step(prev_state, prev_action, obs, nonterm=1.0)[source]#
- Parameters:
prev_state (Dict[str, Tensor])
prev_action (Tensor)
obs (Tensor)
nonterm (Any)
- Return type:
Tuple[Dict[str, Tensor], Dict[str, Tensor]]
- imagine_step(prev_state, prev_action, nonterm=1.0)[source]#
- Parameters:
prev_state (Dict[str, Tensor])
prev_action (Tensor)
nonterm (Any)
- Return type:
Dict[str, Tensor]
- observe_rollout(obs, actions, nonterms, prev_state, horizon)[source]#
- Parameters:
obs (Tensor)
actions (Tensor)
nonterms (Tensor)
prev_state (Dict[str, Tensor])
horizon (int)
- Return type:
Tuple[Dict[str, Tensor], Dict[str, Tensor]]
- torchwm.create_modular_rssm(encoder_type='conv', decoder_type='conv', backbone_type='gru', obs_shape=(3, 64, 64), action_size=6, stoch_size=32, deter_size=200, embed_size=1024, hidden_size=200, activation='elu', **kwargs)[source]#
Factory function to create a modular RSSM with specified components.
- Parameters:
encoder_type (str) – Type of encoder (“conv”, “mlp”, “vit”)
decoder_type (str) – Type of decoder (“conv”, “mlp”)
backbone_type (str) – Type of backbone (“gru”, “lstm”, “transformer”)
obs_shape (Tuple[int, int, int] | Tuple[int]) – Shape of observations (C, H, W) for images or (D,) for state
action_size (int) – Action space dimension
stoch_size (int) – Stochastic latent dimension
deter_size (int) – Deterministic hidden dimension
embed_size (int) – Encoder embedding dimension
hidden_size (int) – Hidden layer dimension
activation (str) – Activation function name
kwargs (Any)
- Returns:
Configured ModularRSSM instance
- Return type:
ModularRSSM
- class torchwm.Genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=512, tokenizer_decoder_dim=1024, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=1024, action_decoder_dim=1024, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, encoder_depth=12, decoder_depth=20, latent_action_depth=20, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#
Bases: Module
Genie: Generative Interactive Environment.
A generative model trained from video-only data that can be used as an interactive environment. Contains three key components:
1. Video Tokenizer: Converts raw video frames into discrete tokens
2. Latent Action Model (LAM): Infers latent actions between frames
3. Dynamics Model: Predicts future frames given past frames and latent actions
Based on the paper “Genie: Generative Interactive Environments” (arXiv:2402.15391).
Training follows two phases as per the paper:
1. Train video tokenizer first (on video tokens)
2. Co-train LAM (from pixels) and dynamics model (on video tokens)
The LAM uses VQ-VAE training with:
- Encoder: Takes x_{1:t} and x_{t+1} → outputs latent actions
- Decoder: Takes x_{1:t-1} (masked) + actions → reconstructs x_t
- Auxiliary variance loss to prevent action collapse
At inference, latent actions are passed through a stop-gradient before they reach the dynamics model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_decoder_dim (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
encoder_depth (int)
decoder_depth (int)
latent_action_depth (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- classmethod from_config(config=None, **overrides)[source]#
Build Genie from a config object, dict, YAML file, or YAML string.
- Parameters:
config (GenieConfig | GenieSmallConfig | dict[str, Any] | str | Path | None)
overrides (Any)
- Return type:
- classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, map_location=None, **overrides)[source]#
Load Genie weights from a local path/directory or HF Hub.
- Parameters:
pretrained_model_name_or_path (str | Path)
config (GenieConfig | dict[str, Any] | str | Path | None)
checkpoint_filename (str | None)
config_filename (str)
repo_type (str | None)
revision (str | None)
map_location (str | device | None)
overrides (Any)
- Return type:
- save_pretrained(path)[source]#
Save Genie weights and config in a from_pretrained-compatible format.
- Parameters:
path (str | Path)
- Return type:
None
- forward(video, mask_prob=0.5, training_phase='all')[source]#
Full forward pass through all components.
- Parameters:
video (Tensor) – (B, C, T, H, W) input video
mask_prob (float) – Probability for random masking in dynamics (0.5-1.0)
training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”
- Returns:
Dictionary containing losses and predictions
- Return type:
Dict[str, Tensor]
- training_step(video, mask_prob=0.5, training_phase='all')[source]#
Single training step computing all losses.
- Parameters:
video (Tensor) – (B, C, T, H, W) input video
mask_prob (float) – Probability for random masking in dynamics
training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”
- Returns:
Dictionary containing all losses for backpropagation
- Return type:
Dict[str, Tensor]
- encode_video(video)[source]#
Encode video to discrete tokens.
- Parameters:
video (Tensor) – (B, C, T, H, W)
- Returns:
video_tokens: (B, T, H*W)
- infer_actions(frames)[source]#
Infer latent actions from a sequence of frames.
- Parameters:
frames (Tensor) – (B, C, T, H, W) video frames
- Returns:
latent_actions: (B, T-1) inferred latent action indices
- generate(prompt_frame, num_frames=16, actions=None, use_maskgit=True)[source]#
Generate video frames given a prompt frame and actions.
- Parameters:
prompt_frame (Tensor) – (B, C, H, W) initial frame
num_frames (int) – Total number of frames to generate
actions (Tensor | None) – (B, num_frames-1) latent action indices, or None for random
use_maskgit (bool) – Whether to use MaskGIT sampling
- Returns:
generated_video: (B, C, num_frames, H, W)
- play(current_frame, action, current_frames=None)[source]#
Play step - generate next frame given current frame and action.
- Parameters:
current_frame (Tensor) – (B, C, H, W) current frame
action (Tensor) – (B,) latent action indices
current_frames (Tensor | None) – (B, C, T, H, W) history frames, or None for first frame
- Returns:
next_frame: (B, C, H, W)
- class torchwm.LatentActionModel(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#
Bases: Module
Latent Action Model (LAM) for unsupervised action learning.
Learns discrete latent actions from unlabeled video frames using a VQ-VAE based objective. The model infers latent actions between frames that encode the most meaningful changes for future frame prediction.
Based on the Genie paper: learns actions without action labels from Internet videos.
Components:
- Encoder: Takes all previous frames x_{1:t} and next frame x_{t+1} → outputs latent actions
- Decoder: Takes previous frames x_{1:t-1} and latent actions a_{1:t-1} → predicts next frame x_t
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- encode(x_prev, x_next)[source]#
Encode frames to latent actions.
- Parameters:
x_prev (Tensor) – Previous frames (B, C, T, H, W)
x_next (Tensor) – Next frame (B, C, H, W)
- Returns:
latent_actions: Discrete latent action indices (B, T)
z_q: Quantized embeddings (B, T, embedding_dim)
- class torchwm.DynamicsModel(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, gradient_checkpointing=True)[source]#
Bases: Module
Dynamics Model for action-controllable video generation.
A decoder-only transformer that predicts future frame tokens given past frame tokens and latent actions. Uses MaskGIT for training and sampling.
Based on Genie paper - uses cross-entropy loss with random masking during training, and MaskGIT iterative refinement at inference.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
gradient_checkpointing (bool)
- forward(video_tokens, actions, mask_prob=0.0)[source]#
Forward pass for training.
- Parameters:
video_tokens (Tensor) – (B, T, H*W) - token indices for frames 1 to T
actions (Tensor) – (B, T) - latent action indices for frames 1 to T
mask_prob (float) – Probability of masking input tokens (Bernoulli 0.5-1.0)
- Returns:
logits: (B, T, H*W, vocab_size)
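The mask_prob parameter describes independent (Bernoulli) masking of input tokens, with rates in the 0.5 to 1.0 range during training. The sketch below shows that masking step on a flat list of token ids. It illustrates the idea only: MASK_ID and mask_tokens are hypothetical, and the real model may represent masked positions differently, for example with a reserved vocabulary index or a learned embedding.

```python
import random

MASK_ID = -1  # hypothetical sentinel for a masked position


def mask_tokens(tokens, mask_prob, rng):
    """Replace each token with MASK_ID independently with probability mask_prob."""
    masked, is_masked = [], []
    for tok in tokens:
        hide = rng.random() < mask_prob
        masked.append(MASK_ID if hide else tok)
        is_masked.append(hide)
    return masked, is_masked


rng = random.Random(0)
rate = rng.uniform(0.5, 1.0)  # draw one masking rate from the documented range
masked, is_masked = mask_tokens([5, 17, 3, 42, 8, 11], rate, rng)
print(rate, masked)
```

The boolean list marks which positions were hidden, which is what a cross-entropy loss restricted to masked positions would need.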
- sample(prompt_tokens, prompt_actions, num_frames, sampler=None)[source]#
Sample future frames using MaskGIT.
- Parameters:
prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens
prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames
num_frames (int) – Total number of frames to generate
sampler (MaskGITSampler | None) – MaskGIT sampler instance
- Returns:
generated_tokens: (B, num_frames, N)
- autoregressive_sample(prompt_tokens, prompt_actions, num_frames, temperature=1.0)[source]#
Simple autoregressive sampling (token by token).
- Parameters:
prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens
prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames
num_frames (int) – Total number of frames to generate
temperature (float) – Sampling temperature
- Returns:
generated_tokens: (B, num_frames, N)
- torchwm.create_genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, action_vocab_size=8, action_embedding_dim=32, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#
Factory function to create a Genie model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
action_vocab_size (int)
action_embedding_dim (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- torchwm.create_genie_small(num_frames=16, image_size=64, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#
Create a smaller Genie model for development/testing.
- Parameters:
num_frames (int)
image_size (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- torchwm.create_genie_large(num_frames=16, image_size=64, use_bfloat16=True, action_pooling='mean', window_attention_heads=1)[source]#
Create the full 11B parameter Genie model (approximate).
- Parameters:
num_frames (int)
image_size (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- torchwm.create_latent_action_model(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, action_pooling='mean', window_attention_heads=1)[source]#
Factory function to create a Latent Action Model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- torchwm.create_dynamics_model(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4)[source]#
Factory function to create a Dynamics Model.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
- Return type:
- class torchwm.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)[source]#
Bases: Module
Recurrent State-Space Model used by Dreamer for latent dynamics learning.
The RSSM is the core world model component that learns compact representations of environment dynamics. It maintains a hybrid state consisting of:
Deterministic State (h) – A recurrent hidden state updated by a GRU, capturing sequential/temporal information and deterministic transitions.
Stochastic State (s) – A latent variable representing stochastic, multi-modal uncertainty in the environment (e.g., ambiguous observations).
The model operates in two modes:
Observe Mode – Updates states using actual observations from the environment. Uses the representation model: p(s_t | h_t, obs_t)
Imagine Mode – Predicts future states without observations. Uses the transition/prior model: p(s_t | h_t)
Architecture
Input: Previous state (h_{t-1}, s_{t-1}) and action a_{t-1}
Process: GRU updates deterministic state, MLP computes stochastic prior/posterior
Output: Updated state (h_t, s_t) and distributions
State Representation
deter (h): GRU hidden state, captures sequential context
stoch (s): Stochastic latent, multi-modal uncertainty
mean/std: Parameters of the stochastic distribution
Usage with DreamerAgent:
rssm = RSSM(
    action_size=action_dim,
    stoch_size=30,        # Stochastic state dimension
    deter_size=200,       # Deterministic (GRU) state dimension
    hidden_size=200,      # MLP hidden layer size
    obs_embed_size=256,   # Observation embedding from encoder
    activation='elu'
)
# Observe with actual observation; observe_step returns (posterior, prior)
posterior, prior = rssm.observe_step(prev_state, prev_action, obs_embed)
# Imagine future without observation
imagined = rssm.imagine_step(current_state, action)
Training
The RSSM is trained by maximizing the ELBO (Evidence Lower Bound):
- KL divergence between prior and posterior encourages the prior to capture environment dynamics
- Reconstruction loss from the decoder ensures the state captures observation information
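Because the prior and posterior are described as Normal distributions parameterized by mean and std, the KL term has a closed form. The sketch below computes KL(posterior || prior) for diagonal Gaussians in plain Python, summing over latent dimensions. The function name and the list-based inputs are assumptions for illustration. The library itself works with torch distributions, and any KL balancing or free-nats clipping it may apply is not shown.

```python
import math


def kl_diag_gaussians(mean_q, std_q, mean_p, std_p):
    """KL(q || p) between diagonal Gaussians, summed over dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(mean_q, std_q, mean_p, std_p):
        # Per-dimension closed form for two univariate Normals.
        total += math.log(sp / sq) + (sq * sq + (mq - mp) ** 2) / (2.0 * sp * sp) - 0.5
    return total


# q plays the role of the posterior, p the role of the prior.
print(kl_diag_gaussians([0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0]))  # prints 0.5
```

The divergence is zero only when both distributions coincide, so minimizing it pulls the prior's prediction toward what the posterior infers from the observation.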
- Reference:
Dreamer: Scalable Reinforcement Learning Using World Models Hafner et al., 2020 - https://arxiv.org/abs/1912.01603
- Parameters:
action_size (int)
stoch_size (int)
deter_size (int)
hidden_size (int)
obs_embed_size (int)
activation (str)
- init_state(batch_size, device)[source]#
Initialize RSSM state with zeros.
- Parameters:
batch_size (int) – Number of parallel sequences
device (device) – torch device for tensors
- Returns:
Dictionary containing zero-initialized state components:
mean, std: Stochastic distribution parameters
stoch: Stochastic state sample
deter: Deterministic GRU hidden state
- get_dist(mean, std)[source]#
Create an Independent Normal distribution from mean and std.
- Parameters:
mean (Tensor) – Location parameter
std (Tensor) – Scale parameter
- Returns:
Independent Normal distribution with given parameters
- Return type:
Independent
- observe_step(prev_state, prev_action, obs_embed, nonterm=tensor(1.))[source]#
Update state using actual observation (observe mode).
In observe mode, the RSSM first computes a transition prior from the previous state and action, then refines the stochastic state using the actual observation embedding to form the posterior.
- Parameters:
prev_state (dict) – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})
prev_action (Tensor) – Previous action a_{t-1}, shape (B, action_size)
obs_embed (Tensor) – Observation embedding from encoder, shape (B, obs_embed_size)
nonterm (Tensor) – Termination mask (1.0 = continue, 0.0 = terminal)
- Returns:
A tuple (posterior, prior) of state dictionaries. The posterior incorporates observation information; the prior is the transition prediction before observation. Both share the same deterministic state because the GRU is only advanced once per timestep.
- Return type:
Tuple[dict, dict]
- imagine_step(prev_state, prev_action, nonterm=tensor(1.))[source]#
Predict next state without observation (imagine mode).
In imagine mode, the RSSM predicts future states using only the prior distribution. This is used for planning and policy learning where actual observations are not available.
- Parameters:
prev_state (dict) – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})
prev_action (Tensor) – Previous action a_{t-1}, shape (B, action_size)
nonterm (Tensor) – Termination mask (1.0 = continue, 0.0 = terminal)
- Returns:
Dictionary with the predicted state, containing:
deter: Predicted deterministic state
mean, std, stoch: Prior stochastic state distribution
- get_prior(prev_state, prev_action, nonterm=tensor(1.))[source]#
- Parameters:
prev_state (dict)
prev_action (Tensor)
nonterm (Tensor)
- Return type:
dict
- get_posterior(prev_state, prev_action, obs_embed, nonterm=tensor(1.))[source]#
Compute posterior distribution over stochastic state.
The posterior incorporates observation information to produce a more accurate state estimate.
- Parameters:
prev_state (dict) – Previous state dictionary
prev_action (Tensor) – Previous action
obs_embed (Tensor) – Observation embedding
nonterm (Tensor) – Termination mask
- Returns:
Dictionary with posterior state (observation-informed). Note that the previous-state shape (B, ...) is preserved; the batch dimension is not flattened.
- Return type:
dict
- detach_state(state)[source]#
Detach state tensors from computation graph.
Used during DreamerV2 training to prevent gradient flow through the observation/update pathway.
- Parameters:
state (dict) – State dictionary with tensor values
- Returns:
Detached state dictionary
- Return type:
dict
- seq_to_batch(state_dict)[source]#
Convert sequence state to batch format.
- Parameters:
state_dict (dict) – Dictionary with sequence-dimension tensors (T, B, …)
- Returns:
Dictionary with batch-dimension tensors (B*T, …)
- Return type:
dict
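Flattening a (T, B, …) sequence into a single leading dimension can be sketched with nested lists. The version below merges the time and batch axes in time-major order, so element (t, b) lands at index t * B + b. This ordering is an assumption based on the documented input layout, and the helper is a plain-Python stand-in rather than the tensor implementation.

```python
def seq_to_batch(state_dict):
    """Merge the leading (T, B) axes of every entry into one axis of length T * B."""
    return {
        key: [item for step in seq for item in step]
        for key, seq in state_dict.items()
    }


# T = 2 timesteps, B = 3 sequences; each leaf records its own (t, b) position.
state = {"deter": [[(t, b) for b in range(3)] for t in range(2)]}
flat = seq_to_batch(state)
print(len(flat["deter"]), flat["deter"][4])  # prints 6 (1, 1)
```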
- observe_rollout(obs_embed, actions, nonterms, init_state, seq_len)[source]#
Process a sequence of observations (observe mode rollout).
At each timestep we run observe_step once to obtain the transition prior (the prediction given the previous state and action) and the observation-informed posterior. The posterior is then used as the previous state for the next step, matching the standard Dreamer inference pattern.
- Parameters:
obs_embed (Tensor) – Observation embeddings, shape (T+1, B, obs_embed_size)
actions (Tensor) – Actions, shape (T, B, action_size)
nonterms (Tensor) – Non-termination flags, shape (T, B, 1)
init_state (dict) – Initial state dictionary
seq_len (int) – Sequence length T
- Returns:
prior: Dictionary with prior states stacked along the time axis
posterior: Dictionary with posterior states stacked along the time axis
- imagine_rollout(policy, init_state, horizon)[source]#
Generate imagined trajectory using policy (imagine mode rollout).
- Parameters:
policy (Module) – Actor network that outputs actions from state features
init_state (dict) – Initial state dictionary
horizon (int) – Number of steps to imagine
- Returns:
Dictionary with imagined states for each step
- Return type:
dict
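An imagined rollout alternates between querying the policy and advancing the model without observations. The loop below sketches that control flow with toy scalar states. The step_fn and policy_fn callables are hypothetical stand-ins for imagine_step and the actor, and leaving the initial state out of the returned trajectory is an assumption.

```python
def imagine_rollout(step_fn, policy_fn, init_state, horizon):
    """Roll a model forward for `horizon` steps using only its own predictions."""
    state, trajectory = init_state, []
    for _ in range(horizon):
        action = policy_fn(state)       # the actor picks an action from the state
        state = step_fn(state, action)  # the prior predicts the next state
        trajectory.append(state)
    return trajectory


# Toy dynamics: the state is a float and the policy steers it halfway to zero.
trajectory = imagine_rollout(
    step_fn=lambda s, a: s + a,
    policy_fn=lambda s: -0.5 * s,
    init_state=8.0,
    horizon=3,
)
print(trajectory)  # prints [4.0, 2.0, 1.0]
```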
- forward(x, u)[source]#
Forward pass for training (computes sequence of states).
- Parameters:
x (Tensor) – Observations, shape (B, T+1, C, H, W)
u (Tensor) – Actions, shape (B, T, action_size)
- Returns:
states: List of state dictionaries for each timestep
priors: List of prior distributions (tuples of mean, std)
posteriors: List of posterior distributions (tuples of mean, std)
- class torchwm.RecurrentStateSpaceModel(action_size, state_size=200, latent_size=30, hidden_size=200, embed_size=1024, activation_function='relu')[source]#
Bases: Module
A Recurrent State Space Model (RSSM) for modeling latent dynamics in sequential data.
- Parameters:
action_size (int)
state_size (int)
latent_size (int)
hidden_size (int)
embed_size (int)
activation_function (str)
- get_init_state(enc, h_t=None, s_t=None, a_t=None, mean=False)[source]#
Returns the initial posterior given the observation.
- Parameters:
enc (Tensor)
h_t (Tensor | None)
s_t (Tensor | None)
a_t (Tensor | None)
mean (bool)
- Return type:
tuple[Tensor, Tensor]
- deterministic_state_fwd(h_t, s_t, a_t)[source]#
Deterministic transition update.
Ensures a_t is 2D and matches batch dimension of h_t before concatenation. Accepts a_t shaped [B, action_size], [action_size] (expanded to [B, action_size]), or [B]/scalar (reshaped appropriately).
- Parameters:
h_t (Tensor)
s_t (Tensor)
a_t (Tensor)
- Return type:
Tensor
- state_prior(h_t, sample=False)[source]#
Returns the prior distribution over the latent state given the deterministic state
- Parameters:
h_t (Tensor)
sample (bool)
- Return type:
tuple[Tensor, Tensor] | Tensor
- state_posterior(h_t, e_t, sample=False)[source]#
Returns the posterior distribution over the latent state given the deterministic state and the observation embedding.
- Parameters:
h_t (Tensor)
e_t (Tensor)
sample (bool)
- Return type:
tuple[Tensor, Tensor] | Tensor
- rollout_prior(act, h_t, s_t)[source]#
- Parameters:
act (Tensor)
h_t (Tensor)
s_t (Tensor)
- Return type:
tuple[Tensor, Tensor]
- forward(x, u)[source]#
Forward through the RSSM for a batch of sequences.
- Parameters:
x (Tensor) – Tensor [B, T+1, C, H, W] (observations including initial frame)
u (Tensor) – Tensor [B, T, action_size] (actions for T steps)
- Returns:
states: list[T] of tensors [B, state_size]
priors: list[T] of tuples (mean, std), each [B, latent_size]
posteriors: list[T] of tuples (mean, std), each [B, latent_size]
- class torchwm.ConvEncoder(input_shape, embed_size, activation, depth=32)[source]#
Bases: Module
Convolutional observation encoder used by Dreamer world models.
This encoder transforms raw image observations (typically RGB frames from environments like Atari or DeepMind Control) into compact latent embeddings that can be processed by the RSSM (Recurrent State-Space Model).
Input: (B, C, H, W) raw images, values in [-0.5, 0.5]
Process: 4 convolutional layers with stride 2, halving spatial dimensions
Output: (B, embed_size) compact representation
The encoder uses a depth doubling pattern: 32 -> 64 -> 128 -> 256 channels. After convolutions, a fully connected layer projects from 1024 features to the desired embedding size.
Usage with Dreamer:
encoder = ConvEncoder(
    input_shape=(3, 64, 64),  # RGB 64x64 images
    embed_size=256,           # RSSM observation embedding size
    activation='relu'         # Activation function
)
obs_embedding = encoder(observation)  # (B, 256)
- Parameters:
input_shape (tuple) – Tuple (C, H, W) for input images, typically (3, 64, 64)
embed_size (int) – Output embedding dimension, typically 256 or 1024
activation (str) – Activation function name (‘relu’, ‘elu’, ‘tanh’, etc.)
depth (int) – Base channel depth for first layer (default 32)
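The 1024-feature figure quoted above can be checked by hand. The sketch below assumes 4x4 kernels with no padding; that kernel size is an assumption chosen because it reproduces the stated figure, not a value read from the torchwm source, and `conv_out` and `encoder_flat_features` are hypothetical helper names.

```python
def conv_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 0) -> int:
    """Output length of one convolution along a single spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_flat_features(height: int = 64, width: int = 64,
                          depth: int = 32, layers: int = 4) -> int:
    """Flattened feature count after `layers` stride-2 convolutions."""
    channels = depth
    for layer in range(layers):
        height, width = conv_out(height), conv_out(width)
        channels = depth * 2 ** layer  # 32 -> 64 -> 128 -> 256 for depth=32
    return channels * height * width


# Assumed kernel=4, padding=0: 64 -> 31 -> 14 -> 6 -> 2 along each axis.
features = encoder_flat_features()
print(features)
```

Under these assumptions the spatial size is roughly, not exactly, halved at each layer, and the final 256 x 2 x 2 feature map flattens to the 1024 inputs of the projection layer.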
- class torchwm.CNNEncoder(embedding_size, activation_function='relu')[source]#
Bases:
Module
A Convolutional Neural Network (CNN) encoder for processing image inputs.
- Parameters:
embedding_size (int)
activation_function (str)
- class torchwm.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)[source]#
Bases:
Module
Convolutional decoder for reconstructing observations from latent states.
Part of Dreamer’s world model, this decoder reconstructs image observations from the combined stochastic (s) and deterministic (h) RSSM states.
Input: Concatenated [stoch_state, deter_state], shape (B, stoch+deter)
Process: Dense projection + 4 transposed convolutions (upsampling 2x each)
Output: Independent Normal distribution over observation pixels
The decoder mirrors the ConvEncoder’s structure but in reverse (transposed convs instead of regular convs). This creates a symmetric autoencoder where the encoder and decoder can be trained jointly to learn compressed representations.
Returns torch.distributions.Independent(Normal(mean, std), len(shape)), allowing log_prob(observation) computation for the reconstruction loss.
Usage in Dreamer world model:
decoder = ConvDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(3, 64, 64),  # RGB images
    activation='relu',
)
obs_dist = decoder(latent_features)  # Returns distribution
log_prob = obs_dist.log_prob(target_observation)
The reconstruction loss is -log_prob(observation), which encourages the RSSM to learn states that capture observation information.
- Parameters:
stoch_size (int)
deter_size (int)
output_shape (tuple[int, ...])
activation (str)
depth (int)
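To make the reconstruction loss concrete, the following pure-Python sketch evaluates the log-probability of an independent Normal over a few "pixels". The unit standard deviation and the function name are assumptions made for illustration; the decoder's actual standard deviation is not stated in this reference.

```python
import math


def independent_normal_log_prob(x, mean, std=1.0):
    """Sum of per-element Gaussian log-densities (elements treated as independent)."""
    log_norm = math.log(std) + 0.5 * math.log(2.0 * math.pi)
    return sum(-0.5 * ((xi - mi) / std) ** 2 - log_norm
               for xi, mi in zip(x, mean))


target = [0.1, -0.2, 0.3]
# Loss when the predicted mean equals the target: only the constant term remains.
perfect = -independent_normal_log_prob(target, target)
# Loss when the predicted mean is all zeros: larger by half the squared error.
worse = -independent_normal_log_prob(target, [0.0, 0.0, 0.0])
```

With a fixed standard deviation, minimizing this loss is equivalent to minimizing squared error up to a constant.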
- class torchwm.CNNDecoder(state_size, latent_size, embedding_size, activation_function='relu')[source]#
Bases:
Module
A Convolutional Neural Network (CNN) decoder for reconstructing image outputs.
- Parameters:
state_size (int)
latent_size (int)
embedding_size (int)
activation_function (str)
- class torchwm.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist, num_buckets=255, symlog_range=10.0)[source]#
Bases:
Module
MLP decoder for reward/value/discount prediction from latent features.
Part of Dreamer’s world model, this decoder predicts scalar quantities (rewards, values, discount factors) from RSSM latent states.
Input: [stoch_state, deter_state] concatenated, shape (B, stoch+deter)
Process: MLP with configurable layers and hidden units
Output: Predicted quantity with distribution (normal, binary, or raw)
Supports three output types:
'normal': Gaussian distribution for regression (rewards, values)
'binary': Bernoulli distribution for binary classification (discount)
'none': Raw tensor for non-probabilistic outputs
Usage:
reward_decoder = DenseDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(1,),
    n_layers=2,
    units=400,
    activation='elu',
    dist='normal',
)
reward_dist = reward_decoder(latent_features)
reward_loss = -reward_dist.log_prob(target_reward)
For discount prediction (binary):
discount_decoder = DenseDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(1,),
    n_layers=2,
    units=400,
    activation='elu',
    dist='binary',
)
- Parameters:
stoch_size (int)
deter_size (int)
output_shape (tuple[int, ...])
n_layers (int)
units (int)
activation (str)
dist (str)
num_buckets (int)
symlog_range (float)
- class torchwm.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)[source]#
Bases:
Module
Dreamer actor head producing squashed continuous actions from latent features.
Outputs a transformed Gaussian policy with optional deterministic mode and utility for additive exploration noise.
- Parameters:
action_size (int)
stoch_size (int)
deter_size (int)
n_layers (int)
units (int)
activation (str)
min_std (float)
init_std (float)
mean_scale (float)
- class torchwm.TanhBijector[source]#
Bases:
Transform
Bijective tanh transform for squashing Gaussian distributions to [-1, 1].
This transformation is essential for Dreamer’s action policy. Raw neural network outputs are Gaussian distributions over R^n, but actions in continuous control environments are typically bounded in [-1, 1]. The tanh bijector provides:
Bijective mapping: tanh is invertible (with atanh as inverse)
Stable log-det Jacobian: Computable for gradient-based training
Clipped actions: During inference, actions are naturally bounded
Forward: y = tanh(x)
Inverse: x = atanh(y) = 0.5 * log((1+y)/(1-y))
Log-det: log|dy/dx| = 2*(log(2) - x - softplus(-2x))
Usage with Dreamer ActionDecoder:
dist = TransformedDistribution(
    Normal(mean, std),
    TanhBijector(),
)
action = dist.sample()  # Bounded to [-1, 1]
- Reference:
Building a Scalable Deep RL Library by Learning from Mistakes, Haarnoja et al.
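The log-det expression above is an algebraic rewrite of log(1 - tanh(x)^2) that avoids taking the log of a number that rounds to zero. The pure-Python check below is a sketch with hypothetical function names; it compares the two forms where both are well defined and evaluates the stable form where the naive one fails.

```python
import math


def softplus(x: float) -> float:
    # log(1 + exp(x)), evaluated without overflow for large |x|
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def log_det_stable(x: float) -> float:
    """2 * (log 2 - x - softplus(-2x)), the form quoted in the docstring."""
    return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))


def log_det_naive(x: float) -> float:
    """log(1 - tanh(x)^2); breaks once tanh(x) rounds to exactly 1.0."""
    return math.log(1.0 - math.tanh(x) ** 2)


# The two forms agree for moderate inputs.
max_gap = max(abs(log_det_stable(x) - log_det_naive(x))
              for x in (-3.0, -0.5, 0.0, 0.5, 3.0))
# At x = 30 the naive form would evaluate log(0); the stable form stays finite.
far = log_det_stable(30.0)
```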
- property sign: int#
- class torchwm.SampleDist(dist, samples=100)[source]#
Bases:
object
Distribution wrapper that estimates statistics via Monte Carlo sampling.
Provides approximated mean, mode, and entropy helpers for transformed distributions where analytic forms may be inconvenient.
- Parameters:
dist (Any)
samples (int)
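As a rough illustration of why sampling is used, the sketch below estimates the mean of a tanh-squashed Gaussian, which has no simple closed form. It uses only the standard library; `mc_mean` and `squashed_sample` are hypothetical names and do not mirror the SampleDist implementation.

```python
import math
import random


def mc_mean(sample_fn, samples: int = 100) -> float:
    """Average of `samples` independent draws from `sample_fn`."""
    draws = [sample_fn() for _ in range(samples)]
    return sum(draws) / len(draws)


rng = random.Random(0)


def squashed_sample() -> float:
    # One draw from tanh(N(0.5, 1)), a stand-in for a squashed action distribution.
    return math.tanh(rng.gauss(0.5, 1.0))


estimate = mc_mean(squashed_sample, samples=10_000)
naive = math.tanh(0.5)  # squashing the Gaussian mean is not the same as the mean
```

The estimate differs from `tanh(0.5)`, which is why statistics of a transformed distribution are estimated from samples rather than by pushing the base mean through the transform.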
- property name: str#
- class torchwm.IRISEncoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, in_channels=3, base_channels=64, num_residual_blocks=2, frame_shape=(3, 64, 64))[source]#
Bases:
Module
CNN Encoder for IRIS discrete autoencoder.
Encodes image observations into latent features, which are then quantized into discrete tokens using the VectorQuantizer.
- Architecture:
4 convolutional layers with residual blocks
Self-attention at 8x8 and 16x16 resolutions
Vector quantization to produce discrete tokens
- Parameters:
vocab_size (int)
tokens_per_frame (int)
embedding_dim (int)
in_channels (int)
base_channels (int)
num_residual_blocks (int)
frame_shape (Tuple[int, int, int])
- forward(x)[source]#
Encode images to discrete tokens.
- Parameters:
x (Tensor) – Input images (B, C, H, W) - should be 64x64
- Returns:
z_q: Quantized tokens (B, C, H’, W’)
indices: Token indices (B, H’, W’)
vq_loss: Dictionary with VQ loss components
- Return type:
tuple
- class torchwm.IRISDecoder(vocab_size=512, embedding_dim=512, base_channels=32, out_channels=3, frame_shape=(3, 64, 64))[source]#
Bases:
Module
CNN Decoder for IRIS discrete autoencoder.
Decodes discrete tokens back into image observations. Uses transposed convolutions to upsample from 4x4 to 64x64.
- Parameters:
vocab_size (int)
embedding_dim (int)
base_channels (int)
out_channels (int)
frame_shape (Tuple[int, int, int])
- forward(z)[source]#
Decode tokens to images.
- Parameters:
z (Tensor) – Token embeddings (B, C, H, W) - e.g., (B, 512, 4, 4)
- Returns:
reconstructed: Reconstructed images (B, C, H, W), e.g., (B, 3, 64, 64)
- Return type:
Tensor
- class torchwm.VideoTokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, commitment_weight=0.25, use_ema=False, ema_decay=0.99)[source]#
Bases:
Module
Video Tokenizer using VQ-VAE with Spatiotemporal Transformer.
This is a core component of Genie (Google DeepMind, 2024), used to compress raw video frames into discrete latent tokens that can be processed by downstream models like the LatentActionModel and DynamicsModel.
The tokenizer uses Vector Quantized Variational Autoencoder (VQ-VAE) objective to learn a discrete codebook of video representations. Unlike standard VQ-VAE, this uses a Spatiotemporal (ST) Transformer in both encoder and decoder to better capture temporal dynamics in videos.
Architecture
Patch Embedding: Convert (B, C, T, H, W) video to patch tokens
Encoder ST-Transformer: Process spatial-temporal patches
Vector Quantization: Discretize continuous embeddings to codebook entries
Decoder ST-Transformer: Reconstruct video from quantized tokens
Patch Unembedding: Convert tokens back to video frames
Key Features
Causal processing: Each frame’s encoding only uses previous frames
Discrete tokens: Enables autoregressive prediction with latent actions
Memory efficient: Uses ST-Transformer instead of full ViT to reduce complexity
Usage with Genie:
tokenizer = VideoTokenizer(
    num_frames=16,
    image_size=64,
    patch_size=4,
    vocab_size=1024,
    embedding_dim=32,
)
reconstructed, indices, loss_dict = tokenizer(video_frames)
# For discrete token input to dynamics model:
token_embeddings = tokenizer.decode_indices(indices)
The tokenizer is trained with the VQ-VAE objective:
Reconstruction loss: MSE between input and reconstructed video
VQ (codebook) loss: Moves codebook embeddings toward the encoder outputs
Commitment loss: Penalizes encoder outputs drifting from the codebook
- Reference:
Genie: Generative Interactive Environments Bruce et al., Google DeepMind, 2024 - https://arxiv.org/abs/2402.15391
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
use_ema (bool)
ema_decay (float)
- encode(x)[source]#
Encode video to discrete tokens.
- Parameters:
x (Tensor) – Video tensor (B, C, T, H, W)
- Returns:
z_q: Quantized embeddings (B, T, H’, W’, embedding_dim)
indices: Token indices (B, T, H’, W’)
vq_loss: Dictionary with VQ loss components
- Return type:
tuple
- decode_indices(indices)[source]#
Decode token indices to embeddings for video frames.
- Parameters:
indices (Tensor) – Token indices (B, T, H’, W’) or (B, T, N) where N = H’ x W’
- Returns:
z_q: Quantized embeddings (B, T, H’, W’, embedding_dim)
- Return type:
Tensor
- torchwm.create_video_tokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False)[source]#
Factory function to create a Video Tokenizer.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
use_ema (bool)
- Return type:
VideoTokenizer
- class torchwm.VectorQuantizer(vocab_size=512, embedding_dim=512, commitment_weight=0.25)[source]#
Bases:
Module
Vector Quantizer for discrete autoencoder.
Implements the VQ-VAE quantization from: “Neural Discrete Representation Learning” (Van Den Oord et al., 2017)
Uses exponential moving averages for codebook updates and straight-through estimator for gradient flow.
- Parameters:
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
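The quantization step can be sketched without any tensor library: each encoder output is replaced by its nearest codebook entry, and the commitment term penalizes the distance between the two. The helper names below are hypothetical, and the mean-over-elements reduction is an assumption rather than the documented behavior of VectorQuantizer.

```python
def quantize(vectors, codebook):
    """Map each vector to the index and value of its nearest codebook entry."""
    indices = []
    for v in vectors:
        dists = [sum((a - b) ** 2 for a, b in zip(v, e)) for e in codebook]
        indices.append(dists.index(min(dists)))
    return indices, [codebook[i] for i in indices]


def commitment_loss(vectors, quantized, weight=0.25):
    """weight * mean squared distance between encoder outputs and their codes."""
    total = sum((a - b) ** 2
                for v, q in zip(vectors, quantized)
                for a, b in zip(v, q))
    return weight * total / (len(vectors) * len(vectors[0]))


codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]]
encoder_out = [[0.9, 1.2], [-0.1, 0.1], [-0.8, 0.7]]
indices, z_q = quantize(encoder_out, codebook)
loss = commitment_loss(encoder_out, z_q)
# In an autograd framework the straight-through estimator would forward z_q
# while copying gradients to encoder_out; that part is not modeled here.
```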
- class torchwm.VectorQuantizerEMA(vocab_size=512, embedding_dim=512, commitment_weight=0.25, ema_decay=0.99, epsilon=1e-05)[source]#
Bases:
Module
Vector Quantizer with Exponential Moving Average updates.
Uses EMA updates for the codebook instead of gradient-based updates, which leads to more stable training.
- Parameters:
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
ema_decay (float)
epsilon (float)
- class torchwm.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]#
Bases:
object
Fixed-size replay buffer for Dreamer with image observations and transitions.
Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world-model training.
Key Features
Ring buffer with fixed capacity (FIFO eviction when full)
Stores raw uint8 images to save memory
Samples sequences (not single transitions) for temporal modeling
Validates sampled sequences don’t span episode boundaries
Memory Layout
observations: (capacity, C, H, W) uint8 images
actions: (capacity, action_dim) float32
rewards: (capacity,) float32
terminals: (capacity,) float32 (1.0 = terminal, 0.0 = continue)
Sampling Process
Random start index (avoiding episode boundaries)
Collect sequence of length seq_len with wraparound
Validate no terminal in middle of sequence
Return batch of sequences
Usage with Dreamer:
buffer = ReplayBuffer(
    size=100000,            # Max transitions to store
    obs_shape=(3, 64, 64),  # RGB images
    action_size=6,          # Continuous action dim
    seq_len=50,             # Sequence length for training
    batch_size=50,          # Parallel sequences per batch
)
# Add transitions during interaction
buffer.add(obs, action, reward, done)
# Sample batch for world model training
obs_batch, action_batch, reward_batch, term_batch = buffer.sample()
Memory Efficiency
Uses uint8 for images (1 byte per pixel vs 4 for float32)
Sequences share observations (overlapping windows)
Configurable capacity based on available system memory
Note
The buffer stores observations as {“image”: …} dicts but returns just the image arrays for training efficiency.
- Parameters:
size (int)
obs_shape (Tuple[int, ...])
action_size (int)
seq_len (int)
batch_size (int)
- add(obs, ac, rew, done)[source]#
Add a transition to the buffer.
- Parameters:
obs (dict) – Observation dict with ‘image’ key containing the observation
ac (ndarray) – Action taken, shape (action_size,)
rew (float) – Reward received, scalar
done (float) – Terminal flag, 1.0 if episode ended, 0.0 otherwise
- Return type:
None
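The ring-buffer and boundary-validation behavior described for this class can be illustrated with a small standard-library sketch. `TinySequenceBuffer` is a hypothetical stand-in that stores only observations and terminal flags; the rejection-sampling loop and the rule for skipping windows that cross the write position are assumptions about one reasonable way to implement the sampling process, not the torchwm code.

```python
import random


class TinySequenceBuffer:
    """Ring buffer that samples fixed-length windows without mid-window terminals."""

    def __init__(self, size: int, seq_len: int, seed: int = 0):
        self.size, self.seq_len = size, seq_len
        self.obs = [None] * size
        self.terminals = [0.0] * size
        self.idx = 0       # next write position
        self.full = False  # True once the buffer has wrapped around
        self.rng = random.Random(seed)

    def add(self, obs, done: float) -> None:
        self.obs[self.idx] = obs
        self.terminals[self.idx] = float(done)
        self.idx = (self.idx + 1) % self.size
        self.full = self.full or self.idx == 0

    def _window(self, start: int):
        idxs = [(start + k) % self.size for k in range(self.seq_len)]
        if self.idx in idxs[1:]:
            return None  # window would jump from the newest to the oldest data
        if any(self.terminals[i] for i in idxs[:-1]):
            return None  # an episode ends before the last step of the window
        return idxs

    def sample(self, max_tries: int = 1000):
        upper = self.size if self.full else self.idx - self.seq_len + 1
        if upper <= 0:
            raise ValueError("not enough transitions stored")
        for _ in range(max_tries):
            idxs = self._window(self.rng.randrange(upper))
            if idxs is not None:
                return [self.obs[i] for i in idxs]
        raise ValueError("no valid window found")


buf = TinySequenceBuffer(size=10, seq_len=3)
for t in range(7):
    buf.add(obs=t, done=1.0 if t == 3 else 0.0)  # episode ends at step 3
windows = [buf.sample() for _ in range(50)]
```

With the terminal at step 3, only the windows `[0, 1, 2]`, `[1, 2, 3]`, and `[4, 5, 6]` can be returned: a terminal is allowed on the last step of a window but not before it.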
- class torchwm.Memory(size=None)[source]#
Bases:
deque
Episode-based replay memory for PlaNet/RSSM training.
Stores episodes as variable-length trajectories and supports sampling sub-sequences for training. Implements a ring-buffer style eviction when capacity is reached.
Stores complete episodes as lists of transitions
Samples contiguous sub-sequences for sequence models
Supports time-major formatting (time-first) for RNN input
Memory usage estimation to prevent OOM errors
- Parameters:
size (int, optional) – Maximum number of episodes to store. If None, deque grows without limit (useful for unpickling).
- episodes#
Collection of Episode objects.
- Type:
deque
- eps_lengths#
Length of each episode.
- Type:
deque
- size#
Total number of transitions across all episodes.
- Type:
property
Example:
memory = Memory(size=100)
memory.append([episode1, episode2])
batch, lengths = memory.sample(batch_size=32, tracelen=50)
- property size: int#
- sample(batch_size, tracelen=1, time_first=False)[source]#
Sample random sub-sequences from stored episodes.
Randomly selects episodes and starting positions to create batches of contiguous sequences for training sequence models.
- Parameters:
batch_size (int) – Number of sequences to sample.
tracelen (int) – Length of each sequence (default: 1).
time_first (bool) – If True, returns tensors with time dimension first (T, B, …) instead of batch first (B, T, …).
- Returns:
- (observations, actions, rewards, terminals, lengths)
observations: (batch, tracelen+1, *obs_shape) or (tracelen+1, batch, …)
actions: (batch, tracelen, action_dim) or (tracelen, batch, …)
rewards: (batch, tracelen) or (tracelen, batch)
terminals: (batch, tracelen) or (tracelen, batch)
lengths: (batch,) original episode lengths for each sample
- Return type:
tuple
- Raises:
ValueError – If memory is empty or no episodes meet minimum length.
MemoryError – If estimated memory usage exceeds 200 MiB threshold.
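To get a feel for when the 200 MiB threshold might be reached, the sketch below estimates the size of the observation tensor alone. The formula, the float32 assumption, and the name `batch_bytes` are illustrative; the library's own estimate may include other arrays or use a different dtype.

```python
from math import prod

MIB = 2 ** 20


def batch_bytes(batch_size, tracelen, obs_shape, bytes_per_value=4):
    """Bytes needed for the observation tensor of one sampled batch."""
    return batch_size * (tracelen + 1) * prod(obs_shape) * bytes_per_value


# 32 sequences of 50 steps (51 frames each) of 3x64x64 float32 images.
small = batch_bytes(32, 50, (3, 64, 64))
# Same sequences with a batch of 128.
large = batch_bytes(128, 50, (3, 64, 64))
print(small / MIB, large / MIB)
```

Under these assumptions a batch of 32 stays well below the threshold, while a batch of 128 exceeds it.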
- class torchwm.Episode(postprocess_fn=None)[source]#
Bases:
object
Records the agent’s interaction with the environment for a single episode.
Stores observations, actions, rewards, and terminal flags during a single trajectory. At termination, converts all lists to numpy arrays for efficient batch processing.
- x#
Observations collected during the episode.
- Type:
list or np.ndarray
- u#
Actions taken.
- Type:
list or np.ndarray
- r#
Rewards received.
- Type:
list or np.ndarray
- t#
Terminal flags (0.0 = continue, 1.0 = terminal).
- Type:
list or np.ndarray
- info#
Additional episode metadata.
- Type:
dict
- Parameters:
postprocess_fn (callable, optional) – Function to apply to observations before storing (e.g., normalization). Default: identity function.
Example:
episode = Episode()
episode.append(obs, action, reward, False)
episode.append(obs, action, reward, True)
episode.terminate(final_obs)
print(episode.x.shape)  # Now a numpy array
- property size: int#
- class torchwm.IRISReplayBuffer(size, obs_shape, action_size, seq_len=20, batch_size=64)[source]#
Bases:
object
Replay buffer for IRIS (Imagination with auto-Regression over an Inner Speech) training.
Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world model training.
- Features:
Ring buffer with fixed capacity (FIFO eviction when full)
Stores uint8 images for memory efficiency
Samples sequences with validation to avoid episode boundaries
Supports sequence sampling for temporal learning
- Memory Layout:
observations: (capacity, C, H, W) uint8
actions: (capacity, action_size) float32
rewards: (capacity,) float32
terminals: (capacity,) float32
- Parameters:
size (int) – Maximum number of transitions to store.
obs_shape (tuple) – Shape of observations as (C, H, W).
action_size (int) – Dimension of actions.
seq_len (int) – Length of sequences to sample (default: 20).
batch_size (int) – Number of sequences per batch (default: 64).
- size#
Buffer capacity.
- Type:
int
- obs_shape#
Observation shape.
- Type:
tuple
- action_size#
Action dimension.
- Type:
int
- seq_len#
Sequence length.
- Type:
int
- batch_size#
Batch size.
- Type:
int
- steps#
Total transitions added.
- Type:
int
- episodes#
Number of episode terminations observed.
- Type:
int
- add(obs, action, reward, terminal)[source]#
Add a transition to the buffer.
- Parameters:
obs (ndarray) – Observation array with shape (C, H, W).
action (ndarray) – Action array with shape (action_size,).
reward (float) – Scalar reward value.
terminal (bool) – Boolean indicating if episode terminated.
- Return type:
None
- sample_sequence(seq_len=None)[source]#
Sample a batch of sequences for world model training.
- Returns:
observations: (batch_size, seq_len+1, C, H, W)
actions: (batch_size, seq_len, action_size)
rewards: (batch_size, seq_len)
terminals: (batch_size, seq_len)
- Return type:
tuple
- Parameters:
seq_len (int | None)
- sample_single()[source]#
Sample a single transition for online updates.
- Return type:
Tuple[ndarray, ndarray, float, float]
- property buffer_capacity: int#
Returns the total capacity of the buffer.
- class torchwm.IRISOnPolicyBuffer(max_steps=1000)[source]#
Bases:
object
On-policy buffer for collecting trajectories during environment interaction.
Used to store the current episode data before adding to the main replay buffer. Unlike the main replay buffer, this collects trajectories in a list-based structure that’s cleared after each episode.
- Useful for:
Collecting complete episode trajectories
Storing data before batch processing
Temporary storage during environment interaction
- Parameters:
max_steps (int) – Maximum number of steps to store (default: 1000).
- max_steps#
Maximum buffer capacity.
- Type:
int
- observations#
List of observations.
- Type:
list
- actions#
List of actions.
- Type:
list
- rewards#
List of rewards.
- Type:
list
- terminals#
List of terminal flags.
- Type:
list
- class torchwm.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]#
Bases:
Module
Diffusion Transformer model for image denoising and generation.
The module maps noisy images and timesteps to predicted noise residuals and also provides a classmethod training entrypoint for common datasets.
- Parameters:
img_size (int)
patch_size (int)
in_channels (int)
d_model (int)
depth (int)
heads (int)
drop (float)
t_dim (int)
- classmethod from_config(config=None, **overrides)[source]#
Build DiT from a config object, dict, YAML file, or YAML string.
- classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, map_location=None, **overrides)[source]#
Load DiT weights from a local path/directory or HF Hub.
- save_pretrained(path)[source]#
Save DiT weights and config in a from_pretrained-compatible format.
- Parameters:
path (str | Path)
- Return type:
None
- classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]#
- Parameters:
epochs (int)
dataset (Any)
batch_size (int)
lr (float)
img_size (int)
channels (int)
patch (int)
width (int)
depth (int)
heads (int)
drop (float)
timesteps (int)
beta_start (float)
beta_end (float)
ema (bool)
ema_decay (float)
workdir (str)
root_path (str)
image_folder (str | None)
crop_size (int)
download (bool)
copy_data (bool)
subset_file (str | None)
val_split (float | None)
- Return type:
None
- torchwm.create_dit(config=None, **overrides)[source]#
Create a DiT from a DiTConfig or keyword overrides.
The public factory API works with config objects, while DiT itself has a compact constructor. This adapter keeps the lower-level model constructor unchanged and maps the public config fields onto the expected arguments.
- Parameters:
config (Any)
overrides (Any)
- Return type:
DiT
- class torchwm.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]#
Bases:
Module
Patchify an image into a sequence of learnable patch tokens.
Used in Vision Transformers (ViT) and DiT to convert 2D images into sequences of token embeddings that can be processed by transformers.
- Process:
Conv2d with kernel_size=stride=patch_size extracts non-overlapping patches
Each patch is projected to embed_dim via linear layer (Conv2d)
Learnable positional embeddings are added for spatial information
Input: (B, C, H, W) images
Output: (B, N, embed_dim) where N = (H/patch_size) * (W/patch_size)
- Parameters:
img_size (int) – Image size (assumes square), e.g., 32 for CIFAR
patch_size (int) – Size of each patch (typically 4, 8, or 16)
in_channels (int) – Number of input channels (3 for RGB)
embed_dim (int) – Output dimension for each patch token
- Usage with DiT:
patch_embed = PatchEmbed(img_size=32, patch_size=4, in_channels=3, embed_dim=256)
tokens = patch_embed(images)  # (B, 64, 256) for 32x32 image with patch_size=4
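The patch extraction step, without the learned projection or positional embeddings, can be sketched on a single-channel grid in plain Python. `patchify` is a hypothetical helper used only to show how non-overlapping patches and the token count N arise.

```python
def patchify(image, patch_size):
    """Split an H x W grid (list of rows) into flattened non-overlapping patches."""
    height, width = len(image), len(image[0])
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patches.append([image[top + dy][left + dx]
                            for dy in range(patch_size)
                            for dx in range(patch_size)])
    return patches


# A 4x4 "image" whose pixel values are their row-major positions 0..15.
image = [[row * 4 + col for col in range(4)] for row in range(4)]
patches = patchify(image, patch_size=2)

# Token count for the 32x32, patch_size=4 case from the usage example.
num_tokens = (32 // 4) * (32 // 4)
```

A Conv2d with kernel_size equal to stride performs this split and the linear projection of each patch in one operation.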
- class torchwm.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]#
Bases:
Module
Reconstruct image-like tensors from patch-token sequences.
The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.
- Parameters:
img_size (int)
patch_size (int)
embed_dim (int)
out_channels (int)
- class torchwm.DDPM(timesteps, beta_start, beta_end)[source]#
Bases:
Module
Utility module implementing forward and reverse DDPM diffusion steps.
Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).
- Parameters:
timesteps (int)
beta_start (float)
beta_end (float)
- q_sample(x_start, t, noise=None)[source]#
- Parameters:
x_start (Tensor)
t (Tensor)
noise (Tensor | None)
- Return type:
Tensor
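The forward noising step has the closed form x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise, where abar_t is the cumulative product of (1 - beta). The scalar sketch below assumes a linear beta schedule between beta_start and beta_end; the schedule shape and the helper names are assumptions, since this reference does not state how DDPM builds its schedule.

```python
import math


def linear_betas(timesteps, beta_start, beta_end):
    """Evenly spaced noise variances from beta_start to beta_end."""
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def alpha_bars(betas):
    """Cumulative products of (1 - beta_t)."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


def q_sample(x_start, t, noise, abar):
    """Noised value at step t for scalar x_start and noise."""
    return math.sqrt(abar[t]) * x_start + math.sqrt(1.0 - abar[t]) * noise


abar = alpha_bars(linear_betas(1000, 1e-4, 0.02))
early = q_sample(1.0, 0, noise=0.0, abar=abar)   # almost all signal
late = q_sample(1.0, 999, noise=0.0, abar=abar)  # almost no signal left
```

With the noise set to zero the result shows how much of the original signal survives: nearly all of it at the first step and less than one percent at the last.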
- class torchwm.ActorCriticNetwork(obs_channels=3, action_dim=18, channels=(32, 32, 64, 64), lstm_dim=512)[source]#
Bases:
Module
Actor-Critic network for DIAMOND RL training. Shared CNN-LSTM trunk with separate policy and value heads.
- Parameters:
obs_channels (int)
action_dim (int)
channels (Tuple[int, ...])
lstm_dim (int)
- forward(obs, hidden_state=None)[source]#
Forward pass of actor-critic network.
- Parameters:
obs (Tensor) – Observations [B, T, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
policy_logits: [B, T, action_dim]
values: [B, T, 1]
hidden_state: (h, c)
- Return type:
tuple
- get_action(obs, hidden_state=None, deterministic=False)[source]#
Get action from a single observation.
- Parameters:
obs (Tensor) – Single observation [B, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
deterministic (bool) – If True, take argmax; else sample
- Returns:
action: Selected action [B]
hidden_state: (h, c)
- Return type:
tuple
- get_actions(obs, hidden_state=None, deterministic=False)[source]#
Batched version of get_action.
- Parameters:
obs (Tensor) – Tensor of shape [B, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state tuple matching batch size
deterministic (bool) – If True, take argmax; else sample from policy
- Returns:
LongTensor of selected actions with shape [B]
hidden_state: updated LSTM hidden state tuple
- Return type:
tuple
- get_value(obs, hidden_state=None)[source]#
Get value for a single observation.
- Parameters:
obs (Tensor)
hidden_state (Tuple[Tensor, Tensor] | None)
- Return type:
Tuple[Tensor, Tuple[Tensor, Tensor] | None]
Initialize LSTM hidden states.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
Get LSTM hidden size.
- Return type:
int
- class torchwm.RewardTerminationModel(obs_channels=3, action_dim=18, channels=(32, 32, 32, 32), lstm_dim=512, cond_dim=128)[source]#
Bases:
Module
Reward and termination prediction model. CNN + LSTM architecture following DIAMOND paper specifications.
- Parameters:
obs_channels (int) – Number of observation channels (3 for RGB)
action_dim (int) – Number of possible actions
channels (Tuple[int, ...]) – List of channel sizes for conv blocks
lstm_dim (int) – LSTM hidden dimension
cond_dim (int) – Conditioning dimension for adaptive norm
- forward(obs, actions, hidden_state=None)[source]#
Forward pass of reward/termination model.
- Parameters:
obs (Tensor) – Observations [B, T, C, H, W]
actions (Tensor) – Actions [B, T]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
reward_logits: Reward predictions [B, T, 3] (for -1, 0, 1)
termination_logits: Termination predictions [B, T, 2]
hidden_state: Updated (h, c) hidden states
- Return type:
tuple
- predict(obs, actions, hidden_state=None)[source]#
Predict reward and termination for a single step.
- Parameters:
obs (Tensor) – Single observation [B, C, H, W]
actions (Tensor) – Single action [B]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
reward: Predicted reward classes as tensor (values -1, 0, 1)
terminated: Predicted termination tensor (bool tensor)
hidden_state: Updated (h, c) hidden states
- Return type:
tuple
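The step from three-way reward logits to a reward in {-1, 0, 1} is an argmax followed by a shift. The sketch below assumes that class index 0 corresponds to -1 and that termination class 1 means "terminated"; both orderings and the function names are assumptions, not documented behavior.

```python
def logits_to_reward(reward_logits):
    """Pick the most likely of three classes and shift indices 0,1,2 to -1,0,1."""
    best = max(range(len(reward_logits)), key=lambda i: reward_logits[i])
    return best - 1


def logits_to_terminated(termination_logits):
    """Class 1 is assumed to mean 'terminated'."""
    return termination_logits[1] > termination_logits[0]


rewards = [logits_to_reward(logits)
           for logits in ([2.0, 0.1, -1.0], [0.0, 3.0, 0.5], [-1.0, 0.2, 1.5])]
done = logits_to_terminated([0.3, 1.2])
```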
Initialize LSTM hidden states.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- torchwm.sinusoidal_time_embedding(timesteps, dim)[source]#
Create sinusoidal timestep embeddings for diffusion conditioning.
This function generates positional-style embeddings for diffusion timesteps, following the same pattern as transformer positional encodings. The embeddings encode the noise level (t) and are used to condition the diffusion model.
- Math:
embedding[t] = [sin(t/10000^(2i/d)), cos(t/10000^(2i/d))] for i in [0, d/2)
- Parameters:
timesteps (Tensor) – Tensor of integer timesteps, shape (B,) or (B, 1)
dim (int) – Embedding dimension (must be even)
- Returns:
Tensor of shape (B, dim) with sinusoidal embeddings
- Return type:
Tensor
- Usage with DiT:
t = torch.tensor([0, 500, 1000])  # Timesteps
emb = sinusoidal_time_embedding(t, dim=256)  # (3, 256)
# Condition the model:
# - Add the timestep embedding to the MLP input
# - Use AdaLN for adaptive normalization
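The formula in the Math entry can be evaluated for a single timestep with the standard library. The layout below (all sine terms followed by all cosine terms) is an assumption; the torchwm function may interleave them instead, and `sinusoidal_embedding` is a hypothetical name.

```python
import math


def sinusoidal_embedding(t: float, dim: int):
    """Return sin(t * w_i) for all i, then cos(t * w_i), with w_i = 10000^(-2i/dim)."""
    if dim % 2:
        raise ValueError("dim must be even")
    half = dim // 2
    freqs = [10000.0 ** (-2.0 * i / dim) for i in range(half)]
    return [math.sin(t * w) for w in freqs] + [math.cos(t * w) for w in freqs]


emb0 = sinusoidal_embedding(0, 8)      # sines are 0, cosines are 1 at t = 0
emb500 = sinusoidal_embedding(500, 8)
```

Each (sin, cos) pair has unit norm, so the embedding magnitude does not depend on the timestep.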
- class torchwm.STTransformer(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, gradient_checkpointing=False)[source]#
Bases:
Module
Spatiotemporal Transformer for video modeling.
Contains L spatiotemporal blocks with interleaved spatial and temporal attention.
- Parameters:
num_frames (int)
num_patches_per_frame (int)
dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
qk_scale (float | None)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
norm_layer (type[Module])
gradient_checkpointing (bool)
- class torchwm.MultiHeadSelfAttention(d, n_heads=2)[source]#
Bases:
Module
Multi-head scaled dot-product self-attention over sequence tokens.
This module projects the input sequence into query/key/value heads, performs attention independently per head, and merges the heads back into the original feature dimension. It is used as a lightweight transformer attention block.
- Parameters:
d (int)
n_heads (int)
- torchwm.MultiHeadAttention#
alias of
MultiHeadSelfAttention
- class torchwm.AdaLNNormalization(d_model, t_dim)[source]#
Bases:
Module
Adaptive layer normalization conditioned on an external embedding.
The module applies RMS normalization and predicts per-channel scale/shift from a conditioning vector (for example diffusion timestep embeddings).
- Parameters:
d_model (int)
t_dim (int)
- class torchwm.RMSNorm(dim, eps=1e-06)[source]#
Bases:
Module
Root Mean Square Layer Normalization with a learned gain parameter.
RMSNorm rescales activations using their RMS magnitude without centering, providing a lightweight normalization alternative to LayerNorm.
- Parameters:
dim (int)
eps (float)
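The rescaling described above can be written in a few lines of plain Python. This is a sketch for a single feature vector; placing `eps` inside the square root is an assumption about the implementation, and `rms_norm` is a hypothetical name.

```python
import math


def rms_norm(x, gain=None, eps=1e-6):
    """Divide by the root mean square of x (no mean subtraction), then apply gain."""
    gain = gain or [1.0] * len(x)
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for g, v in zip(gain, x)]


y = rms_norm([3.0, -4.0, 0.0, 5.0])
out_rms = math.sqrt(sum(v * v for v in y) / len(y))  # close to 1
out_mean = sum(y) / len(y)                           # not forced to 0
```

Unlike LayerNorm, the output mean is left untouched; only the magnitude is normalized.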
- class torchwm.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]#
Bases:
object
Model-predictive controller using the Cross-Entropy Method (CEM) with an RSSM.
Plans by optimizing a sequence of future actions in the RSSM’s latent space. Each planning step samples candidate action sequences, rolls them forward in latent space, scores them by predicted return, and refits a Gaussian proposal to the top-performing candidates.
- Algorithm:
Initialize Gaussian distribution over action sequences
Sample N candidate action sequences
Rollout each sequence in RSSM latent space
Score by predicted cumulative rewards
Keep top K candidates, fit Gaussian to them
Repeat for T iterations
Execute first action from best sequence
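The loop above can be sketched on a toy problem with the standard library alone. Here a hand-written scoring function stands in for the RSSM rollout and reward prediction, actions are scalars, and the final proposal mean is used as the plan; all names and the toy objective are assumptions for illustration.

```python
import random
import statistics


def cem_plan(score_fn, horizon, num_candidates=200, top_candidates=20,
             num_iterations=10, seed=0):
    """Cross-Entropy Method over action sequences of length `horizon`."""
    rng = random.Random(seed)
    mean, std = [0.0] * horizon, [1.0] * horizon   # initial Gaussian proposal
    for _ in range(num_iterations):
        # Sample N candidate sequences from the current proposal.
        candidates = [[rng.gauss(m, s) for m, s in zip(mean, std)]
                      for _ in range(num_candidates)]
        # Keep the K highest-scoring sequences.
        elites = sorted(candidates, key=score_fn, reverse=True)[:top_candidates]
        # Refit the proposal to the elites, one time step at a time.
        columns = list(zip(*elites))
        mean = [statistics.fmean(c) for c in columns]
        std = [statistics.pstdev(c) for c in columns]
    return mean


# Toy "predicted return": highest when the sequence matches a hidden target.
target = [0.5, -0.3, 0.8]


def toy_return(seq):
    return -sum((a - b) ** 2 for a, b in zip(seq, target))


plan = cem_plan(toy_return, horizon=3)
first_action = plan[0]  # only the first action would be executed
```

In a real planner the scoring function is the expensive part, since each candidate must be rolled forward through the world model.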
- Parameters:
model (Any)
planning_horizon (int)
num_candidates (int)
num_iterations (int)
top_candidates (int)
device (device | str)
- rssm#
The RSSM world model.
- N#
Number of candidate action sequences to sample.
- K#
Number of top candidates to use for updating the proposal.
- T#
Number of CEM iterations per planning step.
- H#
Planning horizon (number of future steps to consider).
- d#
Action dimensionality.
- device#
Device to run computations on.
- state_size#
Hidden state dimensionality.
- latent_size#
Latent state dimensionality.
Example
>>> policy = RSSMPolicy(
...     model=rssm,
...     planning_horizon=12,
...     num_candidates=1000,
...     num_iterations=5,
...     top_candidates=100,
...     device='cuda'
... )
>>> policy.reset()
>>> action = policy.poll(observation)
- reset()[source]#
Reset the policy state.
Initializes the hidden state, latent state, and action to zeros. Should be called at the beginning of each episode.
- Return type:
None
- poll(observation, explore=False)[source]#
Get action for given observation.
- Parameters:
observation (Tensor) – Current observation tensor of shape (channels, height, width).
explore (bool) – If True, add exploration noise to the selected action.
- Returns:
Action tensor of shape (1, action_size).
- Return type:
Tensor
- class torchwm.IRISActor(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Actor network for the IRIS (Imagination with auto-Regression over an Inner Speech) policy.
Takes reconstructed frames as input and outputs action logits for policy control. Uses a CNN feature extractor followed by an LSTM for temporal processing. Supports a burn-in mechanism for initializing the hidden state with context frames.
- Architecture:
CNN: Extracts features from input frames (3x64x64 -> 512)
LSTM: Processes temporal sequences with configurable layers
Linear: Maps hidden states to action logits
- Parameters:
action_size (int) – Number of discrete actions.
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
- action_size#
Number of discrete actions.
- Type:
int
- hidden_size#
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
- forward(frames, hidden_state=None, burn_in_frames=None)[source]#
Forward pass through actor.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W) or (B, C, H, W)
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple for LSTM state
burn_in_frames (Tensor | None) – Frames to use for initializing hidden state
- Returns:
action_logits – Action logits of shape (B, T, action_size) or (B, action_size).
hidden_state – Updated (h, c) tuple.
Initialize LSTM hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- get_action(frame, temperature=1.0, deterministic=False)[source]#
Get action from a single frame.
- Parameters:
frame (Tensor) – Single frame (B, C, H, W)
temperature (float) – Softmax temperature (higher = more random)
deterministic (bool) – If True, return argmax; else sample
- Returns:
action – Selected action indices of shape (B,).
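The interaction between temperature and deterministic can be illustrated without tensors. This is a sketch under assumptions, not the get_action source: `select_action` is a hypothetical helper that works on a plain list of logits for a single frame.

```python
import math
import random


def select_action(logits, temperature=1.0, deterministic=False, rng=None):
    """Pick an action index from unnormalized logits."""
    if deterministic:
        # Greedy choice: temperature does not change the argmax.
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    # Sample an index with probability proportional to the softmax weights.
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


logits = [2.0, 0.5, -1.0]
greedy = select_action(logits, deterministic=True)
sampled = select_action(logits, temperature=2.0, rng=random.Random(0))
```

Dividing the logits by a temperature above 1 flattens the distribution, which is why higher values produce more random actions.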
- class torchwm.IRISCritic(hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Critic network for IRIS value estimation.
Estimates the value function for given frame sequences. Shares the CNN feature extractor and LSTM backbone with the actor for efficiency, but has a separate value head for estimating expected cumulative rewards.
- Architecture:
CNN: Shared feature extractor with actor (3x64x64 -> 512)
LSTM: Temporal processing with same architecture as actor
Linear: Maps hidden states to scalar values
- Parameters:
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
- hidden_size#
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
- Returns:
values – Value estimates with shape (B, T).
hidden_state – Updated LSTM hidden state (h, c) tuple.
- Parameters:
hidden_size (int)
num_layers (int)
frame_shape (Tuple[int, int, int])
- forward(frames, hidden_state=None)[source]#
Forward pass through critic.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W)
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple
- Returns:
values – Value estimates of shape (B, T).
hidden_state – Updated (h, c) tuple.
Initialize LSTM hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- class torchwm.IRISPolicy(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Combined policy module for IRIS (Imagination with auto-Regression over an Inner Speech).
Provides a unified interface for actor-only or actor-critic policies. Used in the IRIS algorithm where the actor generates actions from reconstructed frames and the critic estimates value functions for training.
- Parameters:
action_size (int) – Number of discrete actions.
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
- hidden_size#
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
Example
>>> policy = IRISPolicy(
...     action_size=18,
...     hidden_size=512,
...     num_layers=4,
...     frame_shape=(3, 64, 64)
... )
>>> action = policy.act(frame, temperature=1.0, deterministic=False)
- forward(frames)[source]#
Get action logits from frames.
- Parameters:
frames (Tensor)
- Return type:
Tensor
- act(frame, temperature=1.0, deterministic=False)[source]#
Sample action from policy.
- Parameters:
frame (Tensor)
temperature (float)
deterministic (bool)
- Return type:
Tensor
Initialize hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
tuple[Tensor, Tensor]
- class torchwm.CNNFeatureExtractor(frame_shape=(3, 64, 64), output_size=512)[source]#
Bases:
Module
CNN feature extractor shared between actor and critic networks.
Processes input frames through four stride-2 convolutional blocks (Conv2d followed by ReLU, starting at 3 -> 32 channels) and a final linear projection to output_size, producing fixed-size feature vectors.
- Architecture:
Conv layers: 32 -> 64 -> 128 -> 256 channels
Each conv has stride=2 for spatial downsampling
Final linear layer projects to desired output dimension
- Parameters:
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
output_size (int) – Size of output feature vector (default: 512).
- frame_shape#
Input frame shape.
- Type:
tuple
- output_size#
Output feature dimension.
- Type:
int
- Returns:
features – Feature vectors with shape (B, output_size).
- Parameters:
frame_shape (Tuple[int, int, int])
output_size (int)
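The size of the tensor entering the final linear layer follows from the stride-2 downsampling described above. The sketch below works through that arithmetic; `conv_out` and `flattened_features` are hypothetical helpers, and kernel=4 with padding=1 is an assumption, since this reference states only the stride and the channel widths.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(frame_shape=(3, 64, 64), channels=(32, 64, 128, 256),
                       kernel=4, stride=2, padding=1):
    """Number of features entering the final linear layer.

    kernel and padding are assumed values; only stride=2 and the channel
    widths come from the reference text.
    """
    _, height, width = frame_shape
    for _ in channels:
        # Each conv block shrinks both spatial axes.
        height = conv_out(height, kernel, stride, padding)
        width = conv_out(width, kernel, stride, padding)
    return channels[-1] * height * width


n_features = flattened_features()  # 256 channels * 4 * 4 under these assumptions
```

With a different kernel or padding the spatial size after each block changes, so the input width of the linear projection should be derived from the actual layer settings.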
- class torchwm.DreamerConfig(env_backend='dmc', env='walker-walk', env_instance=None, image_size=(64, 64), gym_render_mode='rgb_array', dmlab_action_repeat=4, dmlab_action_set=None, dmlab_observations=None, dmlab_config=None, dmlab_renderer='hardware', procgen_distribution_mode='easy', procgen_num_levels=0, procgen_start_level=None, mujoco_xml_path=None, mujoco_xml_string=None, mujoco_binary_path=None, mujoco_camera=None, mujoco_frame_skip=1, mujoco_reset_noise_scale=0.0, brax_backend='generalized', brax_jit=True, brax_auto_reset=False, brax_suppress_warp_warnings=True, unity_file_name=None, unity_behavior_name=None, unity_worker_id=0, unity_base_port=5005, unity_no_graphics=True, unity_time_scale=20.0, unity_quality_level=1, algo='Dreamerv1', exp_name='lr1e-3', train=True, evaluate=False, seed=1, no_gpu=False, max_episode_length=1000, buffer_size=800000, time_limit=1000, cnn_activation_function='relu', dense_activation_function='elu', obs_embed_size=1024, num_units=400, deter_size=200, stoch_size=30, action_repeat=2, action_noise=0.3, total_steps=5000000, seed_steps=5000, update_steps=100, collect_steps=1000, batch_size=50, train_seq_len=50, imagine_horizon=15, use_disc_model=False, free_nats=3.0, discount=0.99, td_lambda=0.95, kl_loss_coeff=1.0, kl_alpha=0.8, disc_loss_coeff=10.0, num_buckets=255, symlog_range=10.0, model_learning_rate=0.0006, actor_learning_rate=8e-05, value_learning_rate=8e-05, adam_epsilon=1e-07, grad_clip_norm=100.0, use_amp=True, test=False, test_interval=10000, test_episodes=10, scalar_freq=1000, log_video_freq=-1, max_videos_to_save=2, video_format='gif', video_fps=20, checkpoint_interval=10000, checkpoint_path='', restore=False, experience_replay='', render=False, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', log_dir='runs', logdir=None, data_dir=None, log_level='INFO', log_file=None, enable_tensorboard=False, enable_console_metrics=True, enable_jsonl=True, jsonl_filename='metrics.jsonl', log_system_stats_freq=1000, detect_anomaly=False)[source]#
Bases:
SerializableConfigMixin
Configuration container for Dreamer training, evaluation, and environment setup.
This class centralizes environment backend selection (DMC/DMLab/Gym/MuJoCo/Robotics/Unity/Brax), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.
- Parameters:
env_backend (str)
env (str)
env_instance (Any)
image_size (tuple[int, int])
gym_render_mode (str)
dmlab_action_repeat (int)
dmlab_action_set (Any)
dmlab_observations (Any)
dmlab_config (Any)
dmlab_renderer (str)
procgen_distribution_mode (str)
procgen_num_levels (int)
procgen_start_level (Any)
mujoco_xml_path (Any)
mujoco_xml_string (Any)
mujoco_binary_path (Any)
mujoco_camera (Any)
mujoco_frame_skip (int)
mujoco_reset_noise_scale (float)
brax_backend (str)
brax_jit (bool)
brax_auto_reset (bool)
brax_suppress_warp_warnings (bool)
unity_file_name (Any)
unity_behavior_name (Any)
unity_worker_id (int)
unity_base_port (int)
unity_no_graphics (bool)
unity_time_scale (float)
unity_quality_level (int)
algo (str)
exp_name (str)
train (bool)
evaluate (bool)
seed (int)
no_gpu (bool)
max_episode_length (int)
buffer_size (int)
time_limit (int)
cnn_activation_function (str)
dense_activation_function (str)
obs_embed_size (int)
num_units (int)
deter_size (int)
stoch_size (int)
action_repeat (int)
action_noise (float)
total_steps (int)
seed_steps (int)
update_steps (int)
collect_steps (int)
batch_size (int)
train_seq_len (int)
imagine_horizon (int)
use_disc_model (bool)
free_nats (float)
discount (float)
td_lambda (float)
kl_loss_coeff (float)
kl_alpha (float)
disc_loss_coeff (float)
num_buckets (int)
symlog_range (float)
model_learning_rate (float)
actor_learning_rate (float)
value_learning_rate (float)
adam_epsilon (float)
grad_clip_norm (float)
use_amp (bool)
test (bool)
test_interval (int)
test_episodes (int)
scalar_freq (int)
log_video_freq (int)
max_videos_to_save (int)
video_format (str)
video_fps (int)
checkpoint_interval (int)
checkpoint_path (str)
restore (bool)
experience_replay (str)
render (bool)
enable_wandb (bool)
wandb_api_key (str)
wandb_project (str)
wandb_entity (str)
log_dir (str)
logdir (Any)
data_dir (Any)
log_level (str)
log_file (Any)
enable_tensorboard (bool)
enable_console_metrics (bool)
enable_jsonl (bool)
jsonl_filename (str)
log_system_stats_freq (int)
detect_anomaly (bool)
- env_backend: str = 'dmc'#
- env: str = 'walker-walk'#
- env_instance: Any = None#
- image_size: tuple[int, int] = (64, 64)#
- gym_render_mode: str = 'rgb_array'#
- dmlab_action_repeat: int = 4#
- dmlab_action_set: Any = None#
- dmlab_observations: Any = None#
- dmlab_config: Any = None#
- dmlab_renderer: str = 'hardware'#
- procgen_distribution_mode: str = 'easy'#
- procgen_num_levels: int = 0#
- procgen_start_level: Any = None#
- mujoco_xml_path: Any = None#
- mujoco_xml_string: Any = None#
- mujoco_binary_path: Any = None#
- mujoco_camera: Any = None#
- mujoco_frame_skip: int = 1#
- mujoco_reset_noise_scale: float = 0.0#
- brax_backend: str = 'generalized'#
- brax_jit: bool = True#
- brax_auto_reset: bool = False#
- brax_suppress_warp_warnings: bool = True#
- unity_file_name: Any = None#
- unity_behavior_name: Any = None#
- unity_worker_id: int = 0#
- unity_base_port: int = 5005#
- unity_no_graphics: bool = True#
- unity_time_scale: float = 20.0#
- unity_quality_level: int = 1#
- algo: str = 'Dreamerv1'#
- exp_name: str = 'lr1e-3'#
- train: bool = True#
- evaluate: bool = False#
- seed: int = 1#
- no_gpu: bool = False#
- max_episode_length: int = 1000#
- buffer_size: int = 800000#
- time_limit: int = 1000#
- cnn_activation_function: str = 'relu'#
- dense_activation_function: str = 'elu'#
- obs_embed_size: int = 1024#
- num_units: int = 400#
- deter_size: int = 200#
- stoch_size: int = 30#
- action_repeat: int = 2#
- action_noise: float = 0.3#
- total_steps: int = 5000000#
- seed_steps: int = 5000#
- update_steps: int = 100#
- collect_steps: int = 1000#
- batch_size: int = 50#
- train_seq_len: int = 50#
- imagine_horizon: int = 15#
- use_disc_model: bool = False#
- free_nats: float = 3.0#
- discount: float = 0.99#
- td_lambda: float = 0.95#
- kl_loss_coeff: float = 1.0#
- kl_alpha: float = 0.8#
- disc_loss_coeff: float = 10.0#
- num_buckets: int = 255#
- symlog_range: float = 10.0#
- model_learning_rate: float = 0.0006#
- actor_learning_rate: float = 8e-05#
- value_learning_rate: float = 8e-05#
- adam_epsilon: float = 1e-07#
- grad_clip_norm: float = 100.0#
- use_amp: bool = True#
- test: bool = False#
- test_interval: int = 10000#
- test_episodes: int = 10#
- scalar_freq: int = 1000#
- log_video_freq: int = -1#
- max_videos_to_save: int = 2#
- video_format: str = 'gif'#
- video_fps: int = 20#
- checkpoint_interval: int = 10000#
- checkpoint_path: str = ''#
- restore: bool = False#
- experience_replay: str = ''#
- render: bool = False#
- enable_wandb: bool = False#
- wandb_api_key: str = ''#
- wandb_project: str = 'torchwm'#
- wandb_entity: str = ''#
- log_dir: str = 'runs'#
- logdir: Any = None#
- data_dir: Any = None#
- log_level: str = 'INFO'#
- log_file: Any = None#
- enable_tensorboard: bool = False#
- enable_console_metrics: bool = True#
- enable_jsonl: bool = True#
- jsonl_filename: str = 'metrics.jsonl'#
- log_system_stats_freq: int = 1000#
- detect_anomaly: bool = False#
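The discount and td_lambda fields are the usual inputs to a TD(lambda) return over an imagined trajectory. The sketch below shows the standard backward recursion on plain lists; `lambda_returns` is a hypothetical helper, and this reference does not confirm that DreamerAgent computes its targets in exactly this form.

```python
def lambda_returns(rewards, values, bootstrap, discount=0.99, td_lambda=0.95):
    """Compute TD(lambda) targets backwards over a trajectory.

    rewards[t] and values[t] describe step t; bootstrap is the value estimate
    after the final step, so values[t + 1] (or bootstrap) is the next-state
    value used for step t.
    """
    next_values = list(values[1:]) + [bootstrap]
    returns = [0.0] * len(rewards)
    carry = bootstrap
    for t in reversed(range(len(rewards))):
        # Blend the one-step target with the longer return carried from t + 1.
        carry = rewards[t] + discount * (
            (1.0 - td_lambda) * next_values[t] + td_lambda * carry
        )
        returns[t] = carry
    return returns


targets = lambda_returns([1.0, 0.0, 1.0], [0.5, 0.4, 0.3], bootstrap=0.2)
```

Setting td_lambda to 0 reduces each target to the one-step TD target, while td_lambda of 1 gives the discounted sum of rewards plus the bootstrapped final value.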
- class torchwm.JEPAConfig[source]#
Bases:
SerializableConfigMixin
Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.
- classmethod from_dict(values)[source]#
Load flat field values or the nested trainer dictionary.
- Parameters:
values (Dict[str, Any])
- Return type:
JEPAConfig
- class torchwm.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#
Bases:
SerializableConfigMixin
Default configuration values for Diffusion Transformer (DiT) training.
The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.
Field names use UPPER_CASE for backward compatibility with the original DiT codebase. Snake-case aliases are accepted via __getattr__ and get_dit_config().
- Parameters:
DATASET (str)
BATCH (int)
EPOCHS (int)
LR (float)
IMG_SIZE (int)
CHANNELS (int)
PATCH (int)
WIDTH (int)
DEPTH (int)
HEADS (int)
DROP (float)
BETA_START (float)
BETA_END (float)
TIMESTEPS (int)
EMA (bool)
EMA_DECAY (float)
WORKDIR (str)
ROOT_PATH (str)
- DATASET: str = 'CIFAR10'#
- BATCH: int = 128#
- EPOCHS: int = 3#
- LR: float = 0.0002#
- IMG_SIZE: int = 32#
- CHANNELS: int = 3#
- PATCH: int = 4#
- WIDTH: int = 384#
- DEPTH: int = 6#
- HEADS: int = 6#
- DROP: float = 0.1#
- BETA_START: float = 0.0001#
- BETA_END: float = 0.02#
- TIMESTEPS: int = 1000#
- EMA: bool = True#
- EMA_DECAY: float = 0.999#
- WORKDIR: str = './dit_demo'#
- ROOT_PATH: str = './data'#
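One way to accept snake-case names on a config whose fields are UPPER_CASE is an attribute fallback plus key normalization in the factory. The sketch below illustrates that pattern only; `ToyConfig` and `get_toy_config` are hypothetical stand-ins, not the DiTConfig or get_dit_config source.

```python
from dataclasses import dataclass, fields, replace


@dataclass
class ToyConfig:
    """Hypothetical stand-in for a config with UPPER_CASE field names."""
    BATCH: int = 128
    EPOCHS: int = 3
    LR: float = 2e-4

    def __getattr__(self, name):
        # Runs only when normal lookup fails, so cfg.batch falls back to cfg.BATCH.
        upper = name.upper()
        if upper != name and upper in {f.name for f in fields(self)}:
            return getattr(self, upper)
        raise AttributeError(name)


def get_toy_config(**overrides):
    """Accept BATCH=... or batch=... by normalizing keys to UPPER_CASE."""
    valid = {f.name for f in fields(ToyConfig)}
    normalized = {}
    for key, value in overrides.items():
        upper = key.upper()
        if upper not in valid:
            raise TypeError(f"unknown config field: {key!r}")
        normalized[upper] = value
    return replace(ToyConfig(), **normalized)


cfg = get_toy_config(batch=64, LR=1e-3)
```

Because __getattr__ is consulted only after regular attribute lookup fails, the UPPER_CASE fields keep their normal dataclass behavior and the alias adds no cost to direct access.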
- torchwm.get_dit_config(**overrides)[source]#
Returns a DiTConfig instance with default values overridden by the provided keyword arguments.
Both UPPER_CASE and snake_case override keys are accepted.
- Example usage:
cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)
cfg = get_dit_config(batch=64, epochs=10, lr=1e-3)
- Parameters:
overrides (Any)
- Return type:
DiTConfig
- class torchwm.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', seed: int = 0, obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, data_loader_num_workers: int = 4, pin_memory: bool = True, persistent_workers: bool = True, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, use_amp: bool = True, device: str = <factory>, log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, operator_state_dim: int = 32, operator_action_dim: int = 4)[source]#
Bases:
SerializableConfigMixin
- Parameters:
preset (str | None)
game (str)
seed (int)
obs_size (int)
frameskip (int)
max_noop (int)
terminate_on_life_loss (bool)
reward_clip (List[int])
num_conditioning_frames (int)
diffusion_channels (List[int])
diffusion_res_blocks (int)
diffusion_cond_dim (int)
sigma_data (float)
sigma_min (float)
sigma_max (float)
rho (int)
p_mean (float)
p_std (float)
sampling_method (str)
num_sampling_steps (int)
reward_channels (List[int])
reward_res_blocks (int)
reward_cond_dim (int)
reward_lstm_dim (int)
burn_in_length (int)
actor_channels (List[int])
actor_res_blocks (int)
actor_lstm_dim (int)
num_epochs (int)
training_steps_per_epoch (int)
batch_size (int)
environment_steps_per_epoch (int)
epsilon_greedy (float)
data_loader_num_workers (int)
pin_memory (bool)
persistent_workers (bool)
imagination_horizon (int)
discount_factor (float)
entropy_weight (float)
lambda_returns (float)
learning_rate (float)
adam_epsilon (float)
weight_decay_diffusion (float)
weight_decay_reward (float)
weight_decay_actor (float)
use_amp (bool)
device (str)
log_interval (int)
eval_interval (int)
save_interval (int)
operator_state_dim (int)
operator_action_dim (int)
- preset: str | None = None#
- game: str = 'Breakout-v5'#
- seed: int = 0#
- obs_size: int = 64#
- frameskip: int = 4#
- max_noop: int = 30#
- terminate_on_life_loss: bool = True#
- reward_clip: List[int]#
- num_conditioning_frames: int = 4#
- diffusion_channels: List[int]#
- diffusion_res_blocks: int = 2#
- diffusion_cond_dim: int = 256#
- sigma_data: float = 0.5#
- sigma_min: float = 0.002#
- sigma_max: float = 80.0#
- rho: int = 7#
- p_mean: float = -0.4#
- p_std: float = 1.2#
- sampling_method: str = 'euler'#
- num_sampling_steps: int = 3#
- reward_channels: List[int]#
- reward_res_blocks: int = 2#
- reward_cond_dim: int = 128#
- reward_lstm_dim: int = 512#
- burn_in_length: int = 4#
- actor_channels: List[int]#
- actor_res_blocks: int = 1#
- actor_lstm_dim: int = 512#
- num_epochs: int = 1000#
- training_steps_per_epoch: int = 400#
- batch_size: int = 32#
- environment_steps_per_epoch: int = 100#
- epsilon_greedy: float = 0.01#
- data_loader_num_workers: int = 4#
- pin_memory: bool = True#
- persistent_workers: bool = True#
- imagination_horizon: int = 15#
- discount_factor: float = 0.985#
- entropy_weight: float = 0.001#
- lambda_returns: float = 0.95#
- learning_rate: float = 0.0001#
- adam_epsilon: float = 1e-08#
- weight_decay_diffusion: float = 0.01#
- weight_decay_reward: float = 0.01#
- weight_decay_actor: float = 0.0#
- use_amp: bool = True#
- device: str#
- log_interval: int = 10#
- eval_interval: int = 50#
- save_interval: int = 100#
- operator_state_dim: int = 32#
- operator_action_dim: int = 4#
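The fields sigma_min, sigma_max, rho, and num_sampling_steps match the parameters of the noise-level schedule commonly used in EDM-style diffusion samplers. The sketch below computes that schedule with the defaults listed above; `karras_sigmas` is a hypothetical helper, and it is an assumption that DiamondConfig feeds these values into this exact formula.

```python
def karras_sigmas(num_steps=3, sigma_min=0.002, sigma_max=80.0, rho=7):
    """Noise levels from sigma_max down to sigma_min, spaced in sigma**(1/rho)."""
    if num_steps == 1:
        return [sigma_max]
    high = sigma_max ** (1.0 / rho)
    low = sigma_min ** (1.0 / rho)
    # Interpolate linearly in the warped space, then map back with **rho.
    return [
        (high + i / (num_steps - 1) * (low - high)) ** rho
        for i in range(num_steps)
    ]


sigmas = karras_sigmas()  # three decreasing noise levels
```

Larger rho values concentrate the steps near sigma_min. Samplers typically append a final noise level of zero, which this sketch leaves out.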
- class torchwm.IRISConfig[source]#
Bases:
SerializableConfigMixin
Configuration for IRIS (Imagination with auto-Regression over an Inner Speech).
Based on the paper “Transformers are Sample-Efficient World Models”. Implements a discrete autoencoder plus an autoregressive Transformer for sample-efficient RL.
- class torchwm.GenieConfig(num_frames=8, image_size=32, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=256, action_encoder_depth=4, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#
Bases:
SerializableConfigMixin
Configuration for Genie model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
tokenizer_num_heads (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_num_heads (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
sample_temperature (float)
maskgit_steps (int)
- num_frames: int = 8#
- image_size: int = 32#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 256#
- tokenizer_decoder_dim: int = 512#
- tokenizer_encoder_depth: int = 4#
- tokenizer_decoder_depth: int = 8#
- tokenizer_num_heads: int = 8#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 256#
- action_encoder_depth: int = 4#
- action_num_heads: int = 8#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 4#
- learning_rate: float = 3e-05#
- weight_decay: float = 0.0001#
- warmup_steps: int = 5000#
- max_steps: int = 125000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- sample_temperature: float = 2.0#
- maskgit_steps: int = 25#
- class torchwm.GenieSmallConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=512, action_encoder_depth=8, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=2, learning_rate=0.0001, weight_decay=0.0001, warmup_steps=1000, max_steps=50000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#
Bases:
SerializableConfigMixin
Small configuration for development/testing.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
tokenizer_num_heads (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_num_heads (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
sample_temperature (float)
maskgit_steps (int)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 256#
- tokenizer_decoder_dim: int = 512#
- tokenizer_encoder_depth: int = 4#
- tokenizer_decoder_depth: int = 8#
- tokenizer_num_heads: int = 8#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 512#
- action_encoder_depth: int = 8#
- action_num_heads: int = 8#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 2#
- learning_rate: float = 0.0001#
- weight_decay: float = 0.0001#
- warmup_steps: int = 1000#
- max_steps: int = 50000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- sample_temperature: float = 2.0#
- maskgit_steps: int = 25#
- class torchwm.STTransformerConfig(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases:
SerializableConfigMixin
Configuration for Spatiotemporal Transformer.
- Parameters:
num_frames (int)
num_patches_per_frame (int)
dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- num_frames: int = 16#
- num_patches_per_frame: int = 256#
- dim: int = 768#
- depth: int = 12#
- num_heads: int = 12#
- mlp_ratio: float = 4.0#
- qkv_bias: bool = True#
- drop_rate: float = 0.0#
- attn_drop_rate: float = 0.0#
- drop_path_rate: float = 0.0#
- class torchwm.VideoTokenizerConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False, ema_decay=0.99, commitment_weight=0.25)[source]#
Bases:
SerializableConfigMixin
Configuration for Video Tokenizer.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
use_ema (bool)
ema_decay (float)
commitment_weight (float)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- encoder_dim: int = 512#
- decoder_dim: int = 1024#
- encoder_depth: int = 12#
- decoder_depth: int = 20#
- num_heads: int = 16#
- patch_size: int = 4#
- vocab_size: int = 1024#
- embedding_dim: int = 32#
- use_ema: bool = False#
- ema_decay: float = 0.99#
- commitment_weight: float = 0.25#
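The defaults above imply a fixed token count per clip. The sketch below works through that arithmetic, assuming square frames split into non-overlapping square patches; `tokens_per_clip` is a hypothetical helper, not part of torchwm.

```python
def tokens_per_clip(num_frames=16, image_size=64, patch_size=4):
    """Return (patches per frame, total tokens) for square frames and patches."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size   # patches along one image edge
    per_frame = per_side * per_side
    return per_frame, num_frames * per_frame


per_frame, total = tokens_per_clip()
```

Under these assumptions the per-frame count is 256, the same value as the num_patches_per_frame default of STTransformerConfig.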
- class torchwm.LatentActionModelConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=1024, encoder_depth=20, num_heads=16, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#
Bases:
SerializableConfigMixin
Configuration for Latent Action Model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
encoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- encoder_dim: int = 1024#
- encoder_depth: int = 20#
- num_heads: int = 16#
- patch_size: int = 16#
- vocab_size: int = 8#
- embedding_dim: int = 32#
- commitment_weight: float = 1.0#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- class torchwm.DynamicsModelConfig(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases:
SerializableConfigMixin
Configuration for Dynamics Model.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- num_frames: int = 16#
- image_size: int = 64#
- vocab_size: int = 1024#
- embedding_dim: int = 32#
- action_vocab_size: int = 8#
- dim: int = 5120#
- depth: int = 48#
- num_heads: int = 36#
- patch_size: int = 4#
- mlp_ratio: float = 4.0#
- qkv_bias: bool = True#
- drop_rate: float = 0.0#
- attn_drop_rate: float = 0.0#
- drop_path_rate: float = 0.0#
- class torchwm.OperatorABC(*, device=None)[source]#
Bases:
Module, ABC
Structured base class for inference operators.
Operators use a consistent pipeline:
preprocess converts raw inputs into tensors.
forward performs model/operator-specific tensor computation.
postprocess formats the final output mapping.
Subclasses may also declare input_specs and output_specs to validate required tensor keys, shapes, and dtypes.
OperatorABC inherits from torch.nn.Module, so operators support to(device), train(), and eval() just like model modules.
- Parameters:
device (torch.device | str | None)
- input_specs: Mapping[str, TensorSpec] = {}#
- output_specs: Mapping[str, TensorSpec] = {}#
- abstractmethod preprocess(inputs)[source]#
Convert raw inputs into a tensor mapping ready for forward.
- Parameters:
inputs (Any)
- Return type:
dict[str, Tensor]
- forward(inputs)[source]#
Run tensor computation for this operator.
Preprocessing-only operators can rely on this identity implementation. Operators that wrap a model should override this method.
- Parameters:
inputs (dict[str, Tensor])
- Return type:
dict[str, Tensor]
- postprocess(outputs)[source]#
Format validated forward outputs for consumers.
- Parameters:
outputs (dict[str, Tensor])
- Return type:
dict[str, Tensor]
- process(inputs)[source]#
Process raw inputs through preprocess, forward, and postprocess stages.
- Parameters:
inputs (Any)
- Return type:
dict[str, Tensor]
- batch(inputs)[source]#
Preprocess a sequence of inputs and stack matching tensor keys.
- Parameters:
inputs (Sequence[Any])
- Return type:
dict[str, Tensor]
- to(*args, **kwargs)[source]#
Move module parameters/buffers and remember the target tensor device.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
OperatorABC
- classmethod validate_mapping(values, specs, *, label)[source]#
Validate tensor keys, shapes, and dtypes against optional specs.
- Parameters:
values (Mapping[str, Tensor])
specs (Mapping[str, TensorSpec])
label (str)
- Return type:
None
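The three-stage pipeline described for OperatorABC can be mirrored without tensors to show the control flow. This is a sketch under assumptions: `ToyOperator` and `ScalePixels` are hypothetical classes that use plain lists, and the spec validation performed by the real process() is omitted.

```python
from abc import ABC, abstractmethod


class ToyOperator(ABC):
    """Hypothetical, tensor-free mirror of the three-stage operator pipeline."""

    @abstractmethod
    def preprocess(self, inputs):
        """Turn raw inputs into a dict of numeric lists."""

    def forward(self, inputs):
        # Identity by default, as described for preprocessing-only operators.
        return inputs

    def postprocess(self, outputs):
        # Pass-through formatting step; subclasses may reshape or rename keys.
        return outputs

    def process(self, inputs):
        # Chain the stages in the documented order.
        return self.postprocess(self.forward(self.preprocess(inputs)))


class ScalePixels(ToyOperator):
    def preprocess(self, inputs):
        # Map 8-bit pixel values to the [0, 1] range.
        return {"image": [p / 255.0 for p in inputs["image"]]}


result = ScalePixels().process({"image": [0, 51, 255]})
```

Keeping forward as an identity default lets a subclass implement only preprocess, which is the one abstract method in the base class.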
- class torchwm.TensorSpec(shape=None, dtype=None, required=True)[source]#
Bases:
object
Optional tensor contract used to validate operator inputs or outputs.
- Parameters:
shape (tuple[int | None, ...] | None) – Expected shape. Use None as a wildcard for dimensions that may vary, such as batch size.
dtype (dtype | None) – Expected tensor dtype.
required (bool) – Whether the key must be present in the mapping being validated.
- shape: tuple[int | None, ...] | None = None#
- dtype: dtype | None = None#
- required: bool = True#
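The wildcard rule for shape can be expressed as a small comparison on tuples. The sketch below is one plausible reading of that rule, not the validate_mapping source; `shape_matches` is a hypothetical helper.

```python
def shape_matches(actual, expected):
    """Compare shapes, treating None in `expected` as "any size"."""
    if expected is None:
        return True  # no shape constraint at all
    if len(actual) != len(expected):
        return False  # rank mismatch
    return all(want is None or got == want
               for got, want in zip(actual, expected))


ok = shape_matches((8, 3, 64, 64), (None, 3, 64, 64))   # batch size is free
bad = shape_matches((8, 1, 64, 64), (None, 3, 64, 64))  # channel mismatch
```

A spec with shape=None skips the shape check entirely, which is distinct from a shape whose individual dimensions are None.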
- class torchwm.DreamerOperator(image_size=64, action_dim=6)[source]#
Bases:
OperatorABC
Operator for Dreamer model preprocessing: normalizes observations and encodes actions.
- Parameters:
image_size (int)
action_dim (int)
- class torchwm.JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)[source]#
Bases:
OperatorABC
Operator for JEPA model preprocessing: handles image/video masking and patch processing.
- Parameters:
image_size (int)
patch_size (int)
mask_ratio (float)
- class torchwm.IrisOperator(seq_length=512, vocab_size=32000)[source]#
Bases:
OperatorABC
Operator for Iris transformer model: formats sequences and embeddings.
- Parameters:
seq_length (int)
vocab_size (int)
- class torchwm.PlaNetOperator(state_dim=32, action_dim=4)[source]#
Bases:
OperatorABC
Operator for PlaNet model preprocessing: encodes environment states and transitions.
- Parameters:
state_dim (int)
action_dim (int)
- torchwm.get_operator(name, **kwargs)[source]#
Factory function to get inference operators by name.
- Parameters:
name (str) – One of ‘dreamer’, ‘jepa’, ‘iris’, ‘planet’
**kwargs (Any) – Operator-specific configuration
- Returns:
Configured OperatorABC instance
- Return type:
OperatorABC
Example
>>> op = get_operator('dreamer', image_size=64, action_dim=6)
>>> processed = op.process({'image': image, 'action': action})
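A name-based factory of this kind is often a dictionary lookup followed by a constructor call. The sketch below shows that pattern with hypothetical names (`DreamerOp`, `PlaNetOp`, `_REGISTRY`, `get_toy_operator`); it does not reproduce how get_operator is implemented.

```python
class DreamerOp:
    """Hypothetical operator class used only for this illustration."""

    def __init__(self, image_size=64, action_dim=6):
        self.image_size = image_size
        self.action_dim = action_dim


class PlaNetOp:
    """Second hypothetical operator class."""

    def __init__(self, state_dim=32, action_dim=4):
        self.state_dim = state_dim
        self.action_dim = action_dim


# Map public names to constructors; the keys are illustrative.
_REGISTRY = {"dreamer": DreamerOp, "planet": PlaNetOp}


def get_toy_operator(name, **kwargs):
    """Look up an operator class by name and build it with kwargs."""
    try:
        cls = _REGISTRY[name]
    except KeyError:
        known = ", ".join(sorted(_REGISTRY))
        raise ValueError(f"unknown operator {name!r}; expected one of: {known}") from None
    return cls(**kwargs)


op = get_toy_operator("dreamer", image_size=64, action_dim=6)
```

Raising an error that lists the known names makes a mistyped operator name easier to diagnose than a bare KeyError.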
- class torchwm.RewardModel(belief_size, state_size, hidden_size, activation_function='relu')[source]#
Bases:
Module
Predict scalar rewards from Dreamer latent belief and state vectors.
Implemented as an MLP used for model-based reward supervision and imagined rollout return estimation.
- Parameters:
belief_size (int)
state_size (int)
hidden_size (int)
activation_function (str)
- class torchwm.ValueModel(belief_size, state_size, hidden_size, activation_function='relu')[source]#
Bases:
Module
Estimate scalar value from Dreamer latent belief and state vectors.
This MLP is trained on imagined returns and used for actor/value updates.
- Parameters:
belief_size (int)
state_size (int)
hidden_size (int)
activation_function (str)
- torchwm.DreamerRewardModel#
alias of
RewardModel
- torchwm.DreamerValueModel#
alias of
ValueModel
Primary modules: world_models, world_models.models, world_models.configs, world_models.catalog, world_models.envs, and world_models.inference.
TorchWM public API.
This package keeps imports lightweight while still exposing a friendly top-level surface. Common workflows can use the small factory helpers:
import torchwm
cfg = torchwm.create_config("dreamer", env="walker-walk")
agent = torchwm.create_model("dreamer", cfg)
env = torchwm.make_env("CartPole-v1", backend="gym")
Lower-level research components remain available as lazy top-level exports, for
example from torchwm import DreamerAgent, ConvEncoder, ReplayBuffer.
- class world_models.EnvBackendSpec(name, factory_path, description='', aliases=())[source]#
Bases:
NamedTuple
Metadata describing an environment backend available through make_env.
- Parameters:
name (str)
factory_path (str)
description (str)
aliases (tuple[str, ...])
- name: str#
Alias for field number 0
- factory_path: str#
Alias for field number 1
- description: str#
Alias for field number 2
- aliases: tuple[str, ...]#
Alias for field number 3
- class world_models.ModelSpec(name, import_path, config_path=None, description='', aliases=())[source]#
Bases:
NamedTuple
Metadata describing a model available through create_model().
- Parameters:
name (str)
import_path (str)
config_path (str | None)
description (str)
aliases (tuple[str, ...])
- name: str#
Alias for field number 0
- import_path: str#
Alias for field number 1
- config_path: str | None#
Alias for field number 2
- description: str#
Alias for field number 3
- aliases: tuple[str, ...]#
Alias for field number 4
- world_models.create_config(model, **overrides)[source]#
Create the default config object for model and apply overrides.
Examples
>>> cfg = create_config("dreamer", env="walker-walk", seed=7)
>>> cfg.env
'walker-walk'
- Parameters:
model (str)
overrides (Any)
- Return type:
Any
- world_models.create_model(model, config=None, **overrides)[source]#
Instantiate a model or agent from a simple string name.
config is optional for models that define a config class. Keyword overrides are applied to the config when possible; otherwise they are passed directly to the underlying constructor or factory.
Examples
>>> agent = create_model("dreamer", env="walker-walk", total_steps=1000)
>>> genie = create_model("genie-small", image_size=32)
- Parameters:
model (str)
config (Any | None)
overrides (Any)
- Return type:
Any
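The override rule described for create_model() (apply keywords to the config when they name a config field, otherwise forward them to the constructor) can be sketched as follows. This is only an illustration under stated assumptions: ToyConfig and split_overrides are hypothetical names, and the real TorchWM implementation may resolve overrides differently.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Dict, Tuple


@dataclass
class ToyConfig:
    """Hypothetical stand-in for a model config class (not a TorchWM class)."""
    env: str = "cartpole-balance"
    seed: int = 0
    total_steps: int = 100_000


def split_overrides(
    config: ToyConfig, overrides: Dict[str, Any]
) -> Tuple[ToyConfig, Dict[str, Any]]:
    """Apply overrides that name config fields; return the rest unchanged."""
    field_names = {f.name for f in fields(config)}
    config_updates = {k: v for k, v in overrides.items() if k in field_names}
    constructor_kwargs = {k: v for k, v in overrides.items() if k not in field_names}
    # dataclasses.replace returns a new config and leaves the default untouched.
    return replace(config, **config_updates), constructor_kwargs


cfg, extra = split_overrides(
    ToyConfig(), {"env": "walker-walk", "total_steps": 1000, "image_size": 32}
)
```

Here env and total_steps land on the config, while image_size is left over for the constructor call.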
- world_models.get_env_backend_spec(name)[source]#
Return metadata for an environment backend name or alias.
- Parameters:
name (str)
- Return type:
- world_models.get_model_spec(name)[source]#
Return metadata for a model name or alias.
- Parameters:
name (str)
- Return type:
- world_models.list_env_backends()[source]#
Return canonical backend names accepted by make_env().
- Return type:
list[str]
- world_models.list_envs(model=None)[source]#
List known environment ids, optionally filtered by model family.
- Parameters:
model (str | None)
- Return type:
list[str] | dict[str, list[str]]
- world_models.list_models()[source]#
Return canonical model names accepted by create_model().
- Return type:
list[str]
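A name-or-alias lookup of the kind get_model_spec() and list_models() describe might look like the sketch below. The registry contents, the import paths, and the case-insensitive normalization are assumptions made for illustration; only the ModelSpec field layout comes from this reference.

```python
from typing import Dict, List, NamedTuple, Optional, Tuple


class ModelSpec(NamedTuple):
    name: str
    import_path: str
    config_path: Optional[str] = None
    description: str = ""
    aliases: Tuple[str, ...] = ()


# Hypothetical entries: the import paths are placeholders, not real TorchWM paths.
_SPECS = (
    ModelSpec("dreamer", "example_pkg.dreamer:Agent", aliases=("dreamer-v1",)),
    ModelSpec("genie-small", "example_pkg.genie:create_small", aliases=("genie_small",)),
)


def _normalize(name: str) -> str:
    # Assumption: lookups ignore case and surrounding whitespace.
    return name.strip().lower()


# Map every canonical name and alias to its spec.
_LOOKUP: Dict[str, ModelSpec] = {}
for _spec in _SPECS:
    for _key in (_spec.name, *_spec.aliases):
        _LOOKUP[_normalize(_key)] = _spec


def get_model_spec(name: str) -> ModelSpec:
    """Return the spec registered under a canonical name or alias."""
    try:
        return _LOOKUP[_normalize(name)]
    except KeyError:
        known = ", ".join(sorted(s.name for s in _SPECS))
        raise KeyError(f"Unknown model {name!r}; known models: {known}") from None


def list_models() -> List[str]:
    """Return canonical names only, without aliases."""
    return sorted(s.name for s in _SPECS)
```

Keeping aliases out of list_models() means the listing stays short while the lookup still accepts alternative spellings.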
- world_models.make_env(env_id, backend='auto', **kwargs)[source]#
Create an environment with a consistent TorchWM entry point.
- Parameters:
env_id (str) – Environment id, XML path, Unity executable path, or backend-specific id.
backend (str) – One of list_env_backends(); "auto" tries TorchWM’s compatibility helper.
**kwargs (Any) – Backend-specific options.
- Return type:
Any
- world_models.export_any(obj, path, format='onnx', *, example_inputs=None, target=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)[source]#
Export any TorchWM model/agent or a target module contained by it.
- Parameters:
obj (Any)
path (str | Path)
format (str)
example_inputs (Any | None)
target (str | None)
input_names (list[str] | None)
output_names (list[str] | None)
dynamic_axes (dict[str, dict[int, str]] | None)
opset_version (int)
kwargs (Any)
- Return type:
Path
- world_models.export_model(module, path, format='onnx', *, example_inputs=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)[source]#
Export a torch.nn.Module to ONNX, TorchScript, or TensorRT.
- Parameters:
module (Module)
path (str | Path)
format (str)
example_inputs (Any | None)
input_names (list[str] | None)
output_names (list[str] | None)
dynamic_axes (dict[str, dict[int, str]] | None)
opset_version (int)
kwargs (Any)
- Return type:
Path
- class world_models.ExportableAgentMixin[source]#
Bases:
object
Mixin for non-nn.Module agents that delegates to the shared exporter.
- export(path, format='onnx', *, example_inputs=None, target=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)[source]#
Export this agent or one of its contained modules for deployment.
- Parameters:
path (str | Path)
format (str)
example_inputs (Any | None)
target (str | None)
input_names (list[str] | None)
output_names (list[str] | None)
dynamic_axes (dict[str, dict[int, str]] | None)
opset_version (int)
kwargs (Any)
- Return type:
Path
- class world_models.IRISAgent(config, action_size, device)[source]#
Bases:
Module
Complete IRIS Agent with world model and policy.
Combines:
- Discrete autoencoder (encoder + decoder)
- Transformer world model
- Actor-Critic for policy and value learning
- Parameters:
config (IRISConfig)
action_size (int)
device (device)
- classmethod from_config(config=None, *, action_size, device=None, **overrides)[source]#
Build an IRIS agent from a config object, dict, YAML file, or YAML string.
- Parameters:
config (IRISConfig | dict[str, Any] | str | Path | None)
action_size (int)
device (device | str | None)
overrides (Any)
- Return type:
- classmethod from_pretrained(pretrained_model_name_or_path, *, action_size=None, device=None, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, **overrides)[source]#
Load an IRIS agent checkpoint from a local path/directory or HF Hub.
- Parameters:
pretrained_model_name_or_path (str | Path)
action_size (int | None)
device (device | str | None)
config (IRISConfig | dict[str, Any] | str | Path | None)
checkpoint_filename (str | None)
config_filename (str)
repo_type (str | None)
revision (str | None)
overrides (Any)
- Return type:
- forward_actor_critic(frames, hidden=None)[source]#
Forward pass through actor-critic.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W)
hidden (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state
- Returns:
action_logits: (B, T, action_size)
values: (B, T)
hidden_state: (h, c)
- act(frame, epsilon=0.0, temperature=1.0)[source]#
Sample action from policy.
- Parameters:
frame (Tensor) – Single frame (B, C, H, W)
epsilon (float) – Random action probability
temperature (float) – Action distribution temperature
- Returns:
Selected actions (B,)
- Return type:
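The epsilon and temperature arguments of act() suggest epsilon-greedy exploration on top of temperature-scaled sampling. The sketch below shows that combination for a single list of logits in plain Python; select_action and softmax are hypothetical helpers, and whether IRISAgent.act combines the two in exactly this order is an assumption.

```python
import math
import random
from typing import List, Optional


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax, shifted by the max for numerical stability."""
    scaled = [x / temperature for x in logits]
    top = max(scaled)
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def select_action(
    logits: List[float],
    epsilon: float = 0.0,
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """With probability epsilon pick a uniform action, else sample the policy."""
    rng = rng or random.Random()
    if rng.random() < epsilon:
        return rng.randrange(len(logits))
    probs = softmax(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]
```

A temperature close to zero concentrates the distribution on the largest logit, while larger values flatten it toward uniform.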
- imagine_rollout(initial_frame, horizon=20)[source]#
Generate imagined trajectories using world model.
- Parameters:
initial_frame (Tensor) – Starting frame (B, C, H, W)
horizon (int) – Number of steps to imagine
- Returns:
trajectory: Dictionary with imagined rollout data
- update_autoencoder(frames)[source]#
Update discrete autoencoder.
- Parameters:
frames (Tensor) – Training frames (B, C, H, W)
- Returns:
losses: Dictionary of loss values
- update_transformer(frames, actions, rewards, terminals)[source]#
Update transformer world model.
- Parameters:
frames (Tensor) – Frame sequence
actions (Tensor) – Actions taken
rewards (Tensor) – Rewards received
terminals (Tensor) – Terminal flags
- Returns:
losses: Dictionary of loss values
- world_models.compute_lambda_return(rewards, values, discounts, lambda_coef=0.95)[source]#
Compute λ-return target for value function training.
- Parameters:
rewards (Tensor) – Rewards (B, T)
values (Tensor) – Value estimates (B, T+1)
discounts (Tensor) – Discount factors (B, T)
lambda_coef (float) – Lambda parameter for bootstrapping
- Returns:
λ-return targets (B, T)
- Return type:
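For reference, a common form of the λ-return recursion is G_t = r_t + γ_t · ((1 − λ) · V_{t+1} + λ · G_{t+1}), bootstrapped from the final value estimate. The sketch below applies it to one sequence using plain Python lists; it assumes this recursion and the (T) / (T+1) shapes listed above, and it is not taken from the compute_lambda_return source.

```python
from typing import List


def lambda_return(
    rewards: List[float],
    values: List[float],
    discounts: List[float],
    lambda_coef: float = 0.95,
) -> List[float]:
    """λ-return targets for one sequence: T rewards/discounts, T+1 values."""
    horizon = len(rewards)
    assert len(values) == horizon + 1 and len(discounts) == horizon
    targets = [0.0] * horizon
    next_return = values[horizon]  # bootstrap from the last value estimate
    for t in reversed(range(horizon)):
        # Blend the one-step bootstrap with the longer-horizon return.
        blended = (1.0 - lambda_coef) * values[t + 1] + lambda_coef * next_return
        targets[t] = rewards[t] + discounts[t] * blended
        next_return = targets[t]
    return targets
```

Setting lambda_coef to 0 recovers one-step TD targets, and setting it to 1 gives the discounted sum of rewards plus the bootstrapped final value.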
- class world_models.ModularRSSM(encoder, decoder, backbone, reward_decoder=None)[source]#
Bases:
Module
Modular RSSM with swappable encoder, decoder, and backbone.
This class lets researchers experiment with different:
- Encoders: Conv, MLP, ViT
- Decoders: Conv, MLP
- Backbones: GRU, LSTM, Transformer
Example
>>> encoder = ConvEncoder((3, 64, 64), embed_size=1024)
>>> decoder = ConvDecoder(32, 200, (3, 64, 64))
>>> backbone = GRUBackbone(action_size=6, stoch_size=32, deter_size=200, hidden_size=200, embed_size=1024)
>>> rssm = ModularRSSM(encoder, decoder, backbone)
- Parameters:
encoder (EncoderBase)
decoder (DecoderBase)
backbone (BackboneBase)
reward_decoder (DecoderBase | None)
- property stoch_size: int#
- property deter_size: int#
- property embed_size: int#
- init_state(batch_size, device)[source]#
- Parameters:
batch_size (int)
device (device)
- Return type:
Dict[str, Tensor]
- observe_step(prev_state, prev_action, obs, nonterm=1.0)[source]#
- Parameters:
prev_state (Dict[str, Tensor])
prev_action (Tensor)
obs (Tensor)
nonterm (Any)
- Return type:
Tuple[Dict[str, Tensor], Dict[str, Tensor]]
- imagine_step(prev_state, prev_action, nonterm=1.0)[source]#
- Parameters:
prev_state (Dict[str, Tensor])
prev_action (Tensor)
nonterm (Any)
- Return type:
Dict[str, Tensor]
- observe_rollout(obs, actions, nonterms, prev_state, horizon)[source]#
- Parameters:
obs (Tensor)
actions (Tensor)
nonterms (Tensor)
prev_state (Dict[str, Tensor])
horizon (int)
- Return type:
Tuple[Dict[str, Tensor], Dict[str, Tensor]]
- world_models.create_modular_rssm(encoder_type='conv', decoder_type='conv', backbone_type='gru', obs_shape=(3, 64, 64), action_size=6, stoch_size=32, deter_size=200, embed_size=1024, hidden_size=200, activation='elu', **kwargs)[source]#
Factory function to create a modular RSSM with specified components.
- Parameters:
encoder_type (str) – Type of encoder (“conv”, “mlp”, “vit”)
decoder_type (str) – Type of decoder (“conv”, “mlp”)
backbone_type (str) – Type of backbone (“gru”, “lstm”, “transformer”)
obs_shape (Tuple[int, int, int] | Tuple[int]) – Shape of observations (C, H, W) for images or (D,) for state
action_size (int) – Action space dimension
stoch_size (int) – Stochastic latent dimension
deter_size (int) – Deterministic hidden dimension
embed_size (int) – Encoder embedding dimension
hidden_size (int) – Hidden layer dimension
activation (str) – Activation function name
kwargs (Any)
- Returns:
Configured ModularRSSM instance
- Return type:
- class world_models.Genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=512, tokenizer_decoder_dim=1024, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=1024, action_decoder_dim=1024, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, encoder_depth=12, decoder_depth=20, latent_action_depth=20, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#
Bases:
Module
Genie: Generative Interactive Environment.
A generative model trained from video-only data that can be used as an interactive environment. It contains three key components:
1. Video Tokenizer: Converts raw video frames into discrete tokens
2. Latent Action Model (LAM): Infers latent actions between frames
3. Dynamics Model: Predicts future frames given past frames and latent actions
Based on “Genie: Generative Interactive Environments” paper (arXiv:2402.15391).
Training follows two phases, as in the paper:
1. Train the video tokenizer first (on video tokens)
2. Co-train the LAM (from pixels) and the dynamics model (on video tokens)
The LAM uses VQ-VAE training with:
- Encoder: Takes x1:t and x_{t+1} → outputs latent actions
- Decoder: Takes x1:t-1 (masked) + actions → reconstructs x_t
- Auxiliary variance loss to prevent action collapse
At inference, latent actions are stopgrad’d when passed to dynamics model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_decoder_dim (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
encoder_depth (int)
decoder_depth (int)
latent_action_depth (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- classmethod from_config(config=None, **overrides)[source]#
Build Genie from a config object, dict, YAML file, or YAML string.
- Parameters:
config (GenieConfig | GenieSmallConfig | dict[str, Any] | str | Path | None)
overrides (Any)
- Return type:
- classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, map_location=None, **overrides)[source]#
Load Genie weights from a local path/directory or HF Hub.
- Parameters:
pretrained_model_name_or_path (str | Path)
config (GenieConfig | dict[str, Any] | str | Path | None)
checkpoint_filename (str | None)
config_filename (str)
repo_type (str | None)
revision (str | None)
map_location (str | device | None)
overrides (Any)
- Return type:
- save_pretrained(path)[source]#
Save Genie weights and config in a from_pretrained-compatible format.
- Parameters:
path (str | Path)
- Return type:
None
- forward(video, mask_prob=0.5, training_phase='all')[source]#
Full forward pass through all components.
- Parameters:
video (Tensor) – (B, C, T, H, W) input video
mask_prob (float) – Probability for random masking in dynamics (0.5-1.0)
training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”
- Returns:
Dictionary containing losses and predictions
- Return type:
Dict[str, Tensor]
- training_step(video, mask_prob=0.5, training_phase='all')[source]#
Single training step computing all losses.
- Parameters:
video (Tensor) – (B, C, T, H, W) input video
mask_prob (float) – Probability for random masking in dynamics
training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”
- Returns:
Dictionary containing all losses for backpropagation
- Return type:
Dict[str, Tensor]
- encode_video(video)[source]#
Encode video to discrete tokens.
- Parameters:
video (Tensor) – (B, C, T, H, W)
- Returns:
video_tokens: (B, T, H*W)
- infer_actions(frames)[source]#
Infer latent actions from a sequence of frames.
- Parameters:
frames (Tensor) – (B, C, T, H, W) video frames
- Returns:
latent_actions: (B, T-1) inferred latent action indices
- generate(prompt_frame, num_frames=16, actions=None, use_maskgit=True)[source]#
Generate video frames given a prompt frame and actions.
- Parameters:
prompt_frame (Tensor) – (B, C, H, W) initial frame
num_frames (int) – Total number of frames to generate
actions (Tensor | None) – (B, num_frames-1) latent action indices, or None for random
use_maskgit (bool) – Whether to use MaskGIT sampling
- Returns:
generated_video: (B, C, num_frames, H, W)
- play(current_frame, action, current_frames=None)[source]#
Play step - generate next frame given current frame and action.
- Parameters:
current_frame (Tensor) – (B, C, H, W) current frame
action (Tensor) – (B,) latent action indices
current_frames (Tensor | None) – (B, C, T, H, W) history frames, or None for first frame
- Returns:
next_frame: (B, C, H, W)
- class world_models.LatentActionModel(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#
Bases:
Module
Latent Action Model (LAM) for unsupervised action learning.
Learns discrete latent actions from unlabeled video frames using a VQ-VAE based objective. The model infers latent actions between frames that encode the most meaningful changes for future frame prediction.
Based on Genie paper - learns actions without action labels from Internet videos.
Components:
- Encoder: Takes all previous frames x1:t and next frame x_t+1 → outputs latent actions
- Decoder: Takes previous frames x1:t-1 and latent actions a1:t-1 → predicts next frame x_t
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- encode(x_prev, x_next)[source]#
Encode frames to latent actions.
- Parameters:
x_prev (Tensor) – Previous frames (B, C, T, H, W)
x_next (Tensor) – Next frame (B, C, H, W)
- Returns:
latent_actions: Discrete latent action indices (B, T)
z_q: Quantized embeddings (B, T, embedding_dim)
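The step that turns continuous encoder outputs into discrete latent action indices is, in a VQ-VAE, a nearest-neighbour lookup in a learned codebook. The sketch below shows only that lookup, on small Python lists; quantize is a hypothetical helper, the three-entry codebook is made up, and the straight-through gradient and commitment loss used during training are omitted.

```python
from typing import List, Sequence, Tuple


def quantize(
    vectors: Sequence[Sequence[float]],
    codebook: Sequence[Sequence[float]],
) -> Tuple[List[int], List[List[float]]]:
    """Map each vector to the index and value of its nearest codebook entry."""
    indices: List[int] = []
    quantized: List[List[float]] = []
    for vec in vectors:
        # Squared Euclidean distance to every code.
        distances = [sum((a - b) ** 2 for a, b in zip(vec, code)) for code in codebook]
        idx = min(range(len(codebook)), key=distances.__getitem__)
        indices.append(idx)
        quantized.append(list(codebook[idx]))
    return indices, quantized


# Toy codebook with 3 codes of dimension 2 (the defaults above use 8 codes of dimension 32).
codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]]
actions, z_q = quantize([[0.9, 0.8], [-0.2, 0.1], [-0.9, 1.2]], codebook)
```

The returned indices play the role of latent_actions and the selected codes the role of z_q.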
- class world_models.DynamicsModel(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, gradient_checkpointing=True)[source]#
Bases:
Module
Dynamics Model for action-controllable video generation.
A decoder-only transformer that predicts future frame tokens given past frame tokens and latent actions. Uses MaskGIT for training and sampling.
Based on Genie paper - uses cross-entropy loss with random masking during training, and MaskGIT iterative refinement at inference.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
gradient_checkpointing (bool)
- forward(video_tokens, actions, mask_prob=0.0)[source]#
Forward pass for training.
- Parameters:
video_tokens (Tensor) – (B, T, H*W) - token indices for frames 1 to T
actions (Tensor) – (B, T) - latent action indices for frames 1 to T
mask_prob (float) – Probability of masking input tokens (Bernoulli 0.5-1.0)
- Returns:
logits: (B, T, H*W, vocab_size)
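The mask_prob argument controls Bernoulli masking of input tokens during training, with a rate in the 0.5 to 1.0 range. A minimal sketch of that masking on a flat token list follows; MASK_TOKEN is a hypothetical sentinel, and drawing the rate uniformly from [0.5, 1.0] is an assumption about how callers choose mask_prob, not something this reference states.

```python
import random
from typing import List, Tuple

MASK_TOKEN = -1  # hypothetical sentinel; a real model would use a reserved token id


def mask_tokens(
    tokens: List[int], mask_prob: float, rng: random.Random
) -> Tuple[List[int], List[bool]]:
    """Independently replace each token with MASK_TOKEN with probability mask_prob."""
    masked: List[int] = []
    is_masked: List[bool] = []
    for tok in tokens:
        hide = rng.random() < mask_prob
        masked.append(MASK_TOKEN if hide else tok)
        is_masked.append(hide)
    return masked, is_masked


rng = random.Random(0)
mask_prob = rng.uniform(0.5, 1.0)  # assumed: one masking rate per training batch
masked, is_masked = mask_tokens([5, 17, 3, 42], mask_prob, rng)
```

The boolean list marks which positions would contribute to the cross-entropy loss on masked tokens.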
- sample(prompt_tokens, prompt_actions, num_frames, sampler=None)[source]#
Sample future frames using MaskGIT.
- Parameters:
prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens
prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames
num_frames (int) – Total number of frames to generate
sampler (MaskGITSampler | None) – MaskGIT sampler instance
- Returns:
generated_tokens: (B, num_frames, N)
- autoregressive_sample(prompt_tokens, prompt_actions, num_frames, temperature=1.0)[source]#
Simple autoregressive sampling (token by token).
- Parameters:
prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens
prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames
num_frames (int) – Total number of frames to generate
temperature (float) – Sampling temperature
- Returns:
generated_tokens: (B, num_frames, N)
- world_models.create_genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, action_vocab_size=8, action_embedding_dim=32, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#
Factory function to create a Genie model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
action_vocab_size (int)
action_embedding_dim (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- world_models.create_genie_small(num_frames=16, image_size=64, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#
Create a smaller Genie model for development/testing.
- Parameters:
num_frames (int)
image_size (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- world_models.create_genie_large(num_frames=16, image_size=64, use_bfloat16=True, action_pooling='mean', window_attention_heads=1)[source]#
Create the full 11B parameter Genie model (approximate).
- Parameters:
num_frames (int)
image_size (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- world_models.create_latent_action_model(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, action_pooling='mean', window_attention_heads=1)[source]#
Factory function to create a Latent Action Model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- world_models.create_dynamics_model(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4)[source]#
Factory function to create a Dynamics Model.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
- Return type:
- class world_models.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)[source]#
Bases:
Module
Recurrent State-Space Model used by Dreamer for latent dynamics learning.
The RSSM is the core world model component that learns compact representations of environment dynamics. It maintains a hybrid state consisting of:
Deterministic State (h) – A recurrent hidden state updated by a GRU, capturing sequential/temporal information and deterministic transitions.
Stochastic State (s) – A latent variable representing stochastic, multi-modal uncertainty in the environment (e.g., ambiguous observations).
The model operates in two modes:
Observe Mode – Updates states using actual observations from the environment. Uses the representation model: p(s_t | h_t, obs_t)
Imagine Mode – Predicts future states without observations. Uses the transition/prior model: p(s_t | h_t)
Architecture
Input: Previous state (h_{t-1}, s_{t-1}) and action a_{t-1}
Process: GRU updates deterministic state, MLP computes stochastic prior/posterior
Output: Updated state (h_t, s_t) and distributions
State Representation
deter (h): GRU hidden state, captures sequential context
stoch (s): Stochastic latent, multi-modal uncertainty
mean/std: Parameters of the stochastic distribution
Usage with DreamerAgent:
rssm = RSSM(
    action_size=action_dim,
    stoch_size=30,       # Stochastic state dimension
    deter_size=200,      # Deterministic (GRU) state dimension
    hidden_size=200,     # MLP hidden layer size
    obs_embed_size=256,  # Observation embedding from encoder
    activation='elu'
)
# Observe with actual observation; observe_step returns (posterior, prior)
posterior, prior = rssm.observe_step(prev_state, prev_action, obs_embed)
# Imagine future without observation
imagined = rssm.imagine_step(current_state, action)
Training
The RSSM is trained by maximizing the ELBO (Evidence Lower Bound):
- KL divergence between prior and posterior encourages the prior to capture environment dynamics
- Reconstruction loss from decoder ensures state captures observation info
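Because both the prior and the posterior are diagonal Gaussians parameterized by mean and std, the KL term of the ELBO has a closed form. The sketch below computes KL(posterior ‖ prior) for one state vector in plain Python; kl_diag_gaussians is a hypothetical helper, and details such as free nats or KL balancing that a Dreamer implementation may add are left out.

```python
import math
from typing import Sequence


def kl_diag_gaussians(
    mean_q: Sequence[float],
    std_q: Sequence[float],
    mean_p: Sequence[float],
    std_p: Sequence[float],
) -> float:
    """KL(q || p) between diagonal Gaussians, summed over dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(mean_q, std_q, mean_p, std_p):
        total += (
            math.log(sp / sq)
            + (sq ** 2 + (mq - mp) ** 2) / (2.0 * sp ** 2)
            - 0.5
        )
    return total


# q is the posterior, p is the prior.
kl = kl_diag_gaussians([1.0, 1.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0])
```

With unit standard deviations, each dimension contributes half the squared difference of the means.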
- Reference:
Dreamer: Scalable Reinforcement Learning Using World Models Hafner et al., 2020 - https://arxiv.org/abs/1912.01603
- Parameters:
action_size (int)
stoch_size (int)
deter_size (int)
hidden_size (int)
obs_embed_size (int)
activation (str)
- init_state(batch_size, device)[source]#
Initialize RSSM state with zeros.
- Parameters:
batch_size (int) – Number of parallel sequences
device (device) – torch device for tensors
- Returns:
Dictionary containing zero-initialized state components:
mean, std: Stochastic distribution parameters
stoch: Stochastic state sample
deter: Deterministic GRU hidden state
- get_dist(mean, std)[source]#
Create an Independent Normal distribution from mean and std.
- Parameters:
mean (Tensor) – Location parameter
std (Tensor) – Scale parameter
- Returns:
Independent Normal distribution with given parameters
- Return type:
Independent
- observe_step(prev_state, prev_action, obs_embed, nonterm=tensor(1.))[source]#
Update state using actual observation (observe mode).
In observe mode, the RSSM first computes a transition prior from the previous state and action, then refines the stochastic state using the actual observation embedding to form the posterior.
- Parameters:
prev_state (dict) – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})
prev_action (Tensor) – Previous action a_{t-1}, shape (B, action_size)
obs_embed (Tensor) – Observation embedding from encoder, shape (B, obs_embed_size)
nonterm (Tensor) – Termination mask (1.0 = continue, 0.0 = terminal)
- Returns:
A tuple (posterior, prior) of state dictionaries. The posterior incorporates observation information; the prior is the transition prediction before observation. Both share the same deterministic state because the GRU is only advanced once per timestep.
- Return type:
Tuple[dict, dict]
- imagine_step(prev_state, prev_action, nonterm=tensor(1.))[source]#
Predict next state without observation (imagine mode).
In imagine mode, the RSSM predicts future states using only the prior distribution. This is used for planning and policy learning where actual observations are not available.
- Parameters:
prev_state (dict) – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})
prev_action (Tensor) – Previous action a_{t-1}, shape (B, action_size)
nonterm (Tensor) – Termination mask (1.0 = continue, 0.0 = terminal)
- Returns:
Dictionary with the predicted state, containing:
deter: Predicted deterministic state
mean, std, stoch: Prior stochastic state distribution
- get_prior(prev_state, prev_action, nonterm=tensor(1.))[source]#
- Parameters:
prev_state (dict)
prev_action (Tensor)
nonterm (Tensor)
- Return type:
dict
- get_posterior(prev_state, prev_action, obs_embed, nonterm=tensor(1.))[source]#
Compute posterior distribution over stochastic state.
The posterior incorporates observation information to produce a more accurate state estimate.
- Parameters:
prev_state (dict) – Previous state dictionary
prev_action (Tensor) – Previous action
obs_embed (Tensor) – Observation embedding
nonterm (Tensor) – Termination mask
- Returns:
Dictionary with posterior state (observation-informed). Note that the previous-state shape (B, ...) is preserved; the batch dimension is not flattened.
- Return type:
dict
- detach_state(state)[source]#
Detach state tensors from computation graph.
Used during DreamerV2 training to prevent gradient flow through the observation/update pathway.
- Parameters:
state (dict) – State dictionary with tensor values
- Returns:
Detached state dictionary
- Return type:
dict
- seq_to_batch(state_dict)[source]#
Convert sequence state to batch format.
- Parameters:
state_dict (dict) – Dictionary with sequence-dimension tensors (T, B, …)
- Returns:
Dictionary with batch-dimension tensors (B*T, …)
- Return type:
dict
- observe_rollout(obs_embed, actions, nonterms, init_state, seq_len)[source]#
Process a sequence of observations (observe mode rollout).
At each timestep we run observe_step once to obtain the transition prior (the prediction given the previous state and action) and the observation-informed posterior. The posterior is then used as the previous state for the next step, matching the standard Dreamer inference pattern.
- Parameters:
obs_embed (Tensor) – Observation embeddings, shape (T+1, B, obs_embed_size)
actions (Tensor) – Actions, shape (T, B, action_size)
nonterms (Tensor) – Non-termination flags, shape (T, B, 1)
init_state (dict) – Initial state dictionary
seq_len (int) – Sequence length T
- Returns:
prior: Dictionary with prior states stacked along the time axis
posterior: Dictionary with posterior states stacked along the time axis
- imagine_rollout(policy, init_state, horizon)[source]#
Generate imagined trajectory using policy (imagine mode rollout).
- Parameters:
policy (Module) – Actor network that outputs actions from state features
init_state (dict) – Initial state dictionary
horizon (int) – Number of steps to imagine
- Returns:
Dictionary with imagined states for each step
- Return type:
dict
- forward(x, u)[source]#
Forward pass for training (computes sequence of states).
- Parameters:
x (Tensor) – Observations, shape (B, T+1, C, H, W)
u (Tensor) – Actions, shape (B, T, action_size)
- Returns:
states: List of state dictionaries for each timestep
priors: List of prior distributions (tuples of mean, std)
posteriors: List of posterior distributions (tuples of mean, std)
- class world_models.RecurrentStateSpaceModel(action_size, state_size=200, latent_size=30, hidden_size=200, embed_size=1024, activation_function='relu')[source]#
Bases:
Module
A Recurrent State Space Model (RSSM) for modeling latent dynamics in sequential data.
- Parameters:
action_size (int)
state_size (int)
latent_size (int)
hidden_size (int)
embed_size (int)
activation_function (str)
- get_init_state(enc, h_t=None, s_t=None, a_t=None, mean=False)[source]#
Returns the initial posterior given the observation.
- Parameters:
enc (Tensor)
h_t (Tensor | None)
s_t (Tensor | None)
a_t (Tensor | None)
mean (bool)
- Return type:
tuple[Tensor, Tensor]
- deterministic_state_fwd(h_t, s_t, a_t)[source]#
Deterministic transition update.
Ensures a_t is 2D and matches batch dimension of h_t before concatenation. Accepts a_t shaped [B, action_size], [action_size] (expanded to [B, action_size]), or [B]/scalar (reshaped appropriately).
- Parameters:
h_t (Tensor)
s_t (Tensor)
a_t (Tensor)
- Return type:
Tensor
- state_prior(h_t, sample=False)[source]#
Returns the prior distribution over the latent state given the deterministic state.
- Parameters:
h_t (Tensor)
sample (bool)
- Return type:
tuple[Tensor, Tensor] | Tensor
- state_posterior(h_t, e_t, sample=False)[source]#
Returns the posterior distribution over the latent state given the deterministic state and the encoded observation.
- Parameters:
h_t (Tensor)
e_t (Tensor)
sample (bool)
- Return type:
tuple[Tensor, Tensor] | Tensor
- rollout_prior(act, h_t, s_t)[source]#
- Parameters:
act (Tensor)
h_t (Tensor)
s_t (Tensor)
- Return type:
tuple[Tensor, Tensor]
- forward(x, u)[source]#
Forward through the RSSM for a batch of sequences.
- Parameters:
x (Tensor) – Tensor [B, T+1, C, H, W] (observations including initial frame)
u (Tensor) – Tensor [B, T, action_size] (actions for T steps)
- Returns:
states: list[T] of tensors [B, state_size]
priors: list[T] of tuples (mean, std), each [B, latent_size]
posteriors: list[T] of tuples (mean, std), each [B, latent_size]
- class world_models.ConvEncoder(input_shape, embed_size, activation, depth=32)[source]#
Bases:
Module
Convolutional observation encoder used by Dreamer world models.
This encoder transforms raw image observations (typically RGB frames from environments like Atari or DeepMind Control) into compact latent embeddings that can be processed by the RSSM (Recurrent State-Space Model).
Input: (B, C, H, W) raw images, values in [-0.5, 0.5]
Process: 4 convolutional layers with stride 2, halving spatial dimensions
Output: (B, embed_size) compact representation
The encoder uses a depth doubling pattern: 32 -> 64 -> 128 -> 256 channels. After convolutions, a fully connected layer projects from 1024 features to the desired embedding size.
Usage with Dreamer:
encoder = ConvEncoder(
    input_shape=(3, 64, 64),  # RGB 64x64 images
    embed_size=256,           # RSSM observation embedding size
    activation='relu'         # Activation function
)
obs_embedding = encoder(observation)  # (B, 256)
- Parameters:
input_shape (tuple) – Tuple (C, H, W) for input images, typically (3, 64, 64)
embed_size (int) – Output embedding dimension, typically 256 or 1024
activation (str) – Activation function name (‘relu’, ‘elu’, ‘tanh’, etc.)
depth (int) – Base channel depth for first layer (default 32)
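The 1024-feature figure quoted above can be checked with a little arithmetic. The sketch below assumes 4x4 kernels with stride 2 and no padding, which is a common Dreamer choice but is not stated on this page; `conv_out` and `encoder_flat_features` are hypothetical helper names, not part of `world_models`.

```python
def conv_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 0) -> int:
    """Spatial size after one convolution (standard floor formula)."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_flat_features(input_shape=(3, 64, 64), depth=32, layers=4):
    """Flattened feature count after `layers` stride-2 convolutions.

    Assumes channels double each layer: depth, 2*depth, 4*depth, 8*depth.
    """
    _, h, w = input_shape
    for _ in range(layers):
        h, w = conv_out(h), conv_out(w)
    channels = depth * 2 ** (layers - 1)
    return channels * h * w


# 64 -> 31 -> 14 -> 6 -> 2 spatially, then 256 channels * 2 * 2
print(encoder_flat_features())  # 1024
```

Under these assumptions, a different input resolution changes the flattened size, so the final projection layer would need to match.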
- class world_models.CNNEncoder(embedding_size, activation_function='relu')[source]#
Bases: Module
A Convolutional Neural Network (CNN) encoder for processing image inputs.
- Parameters:
embedding_size (int)
activation_function (str)
- class world_models.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)[source]#
Bases: Module
Convolutional decoder for reconstructing observations from latent states.
Part of Dreamer’s world model, this decoder reconstructs image observations from the combined stochastic (s) and deterministic (h) RSSM states.
Input: Concatenated [stoch_state, deter_state], shape (B, stoch+deter)
Process: Dense projection + 4 transposed convolutions (upsampling 2x each)
Output: Independent Normal distribution over observation pixels
The decoder mirrors the ConvEncoder’s structure but in reverse (transposed convs instead of regular convs). This creates a symmetric autoencoder where the encoder and decoder can be trained jointly to learn compressed representations.
Returns torch.distributions.Independent(Normal(mean, std), len(shape)), allowing log_prob(observation) computation for the reconstruction loss.
Usage in Dreamer world model:
decoder = ConvDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(3, 64, 64),  # RGB images
    activation='relu'
)
obs_dist = decoder(latent_features)  # Returns distribution
log_prob = obs_dist.log_prob(target_observation)
The reconstruction loss is -log_prob(observation), which encourages the RSSM to learn states that capture observation information.
- Parameters:
stoch_size (int)
deter_size (int)
output_shape (tuple[int, ...])
activation (str)
depth (int)
- class world_models.CNNDecoder(state_size, latent_size, embedding_size, activation_function='relu')[source]#
Bases: Module
A Convolutional Neural Network (CNN) decoder for reconstructing image outputs.
- Parameters:
state_size (int)
latent_size (int)
embedding_size (int)
activation_function (str)
- class world_models.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist, num_buckets=255, symlog_range=10.0)[source]#
Bases: Module
MLP decoder for reward/value/discount prediction from latent features.
Part of Dreamer’s world model, this decoder predicts scalar quantities (rewards, values, discount factors) from RSSM latent states.
Input: [stoch_state, deter_state] concatenated, shape (B, stoch+deter)
Process: MLP with configurable layers and hidden units
Output: Predicted quantity with distribution (normal, binary, or raw)
Supports three output types:
'normal': Gaussian distribution for regression (rewards, values)
'binary': Bernoulli distribution for binary classification (discount)
'none': Raw tensor for non-probabilistic outputs
Usage:
reward_decoder = DenseDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(1,),
    n_layers=2,
    units=400,
    activation='elu',
    dist='normal'
)
reward_dist = reward_decoder(latent_features)
reward_loss = -reward_dist.log_prob(target_reward)
For discount prediction (binary):
discount_decoder = DenseDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(1,),
    n_layers=2,
    units=400,
    activation='elu',
    dist='binary'
)
- Parameters:
stoch_size (int)
deter_size (int)
output_shape (tuple[int, ...])
n_layers (int)
units (int)
activation (str)
dist (str)
num_buckets (int)
symlog_range (float)
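To make the `-log_prob(target)` losses above concrete, here is a dependency-free sketch of the two negative log-likelihoods. It assumes a unit-variance Gaussian for 'normal' and a logit-parameterized Bernoulli for 'binary'; both are assumptions about the decoder's internals, and `normal_nll` and `bernoulli_nll` are hypothetical names.

```python
import math


def normal_nll(pred_mean: float, target: float, std: float = 1.0) -> float:
    """Negative log-likelihood of `target` under Normal(pred_mean, std)."""
    z = (target - pred_mean) / std
    return 0.5 * z * z + math.log(std) + 0.5 * math.log(2.0 * math.pi)


def bernoulli_nll(logit: float, target: float) -> float:
    """Negative log-likelihood of a 0/1 `target` under Bernoulli(sigmoid(logit)).

    Written as softplus(logit) - target * logit, which stays finite for
    large positive or negative logits.
    """
    softplus = max(logit, 0.0) + math.log1p(math.exp(-abs(logit)))
    return softplus - target * logit
```

With a fixed standard deviation, the Gaussian loss differs from half the squared error only by a constant, which is why a 'normal' head behaves like MSE regression in that case.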
- class world_models.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)[source]#
Bases: Module
Dreamer actor head producing squashed continuous actions from latent features.
Outputs a transformed Gaussian policy with optional deterministic mode and utility for additive exploration noise.
- Parameters:
action_size (int)
stoch_size (int)
deter_size (int)
n_layers (int)
units (int)
activation (str)
min_std (float)
init_std (float)
mean_scale (float)
- class world_models.TanhBijector[source]#
Bases: Transform
Bijective tanh transform for squashing Gaussian distributions to [-1, 1].
This transformation is essential for Dreamer’s action policy. Raw neural network outputs are Gaussian distributions over R^n, but actions in continuous control environments are typically bounded in [-1, 1]. The tanh bijector provides:
Bijective mapping: tanh is invertible (with atanh as inverse)
Stable log-det Jacobian: Computable for gradient-based training
Clipped actions: During inference, actions are naturally bounded
Forward: y = tanh(x)
Inverse: x = atanh(y) = 0.5 * log((1+y)/(1-y))
Log-det: log|dy/dx| = 2*(log(2) - x - softplus(-2x))
Usage with Dreamer ActionDecoder:
dist = TransformedDistribution(
    Normal(mean, std),
    TanhBijector()
)
action = dist.sample()  # Bounded to [-1, 1]
- Reference:
Building a Scalable Deep RL Library by Learning from Mistakes, Haarnoja et al.
- property sign: int#
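The forward, inverse, and log-det formulas listed above can be checked numerically. This is a scalar, pure-Python sketch rather than the library's tensor implementation; the function names are hypothetical.

```python
import math


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def tanh_forward(x: float) -> float:
    return math.tanh(x)


def tanh_inverse(y: float) -> float:
    """atanh written as 0.5 * log((1 + y) / (1 - y))."""
    return 0.5 * math.log((1.0 + y) / (1.0 - y))


def log_det_stable(x: float) -> float:
    """log|dy/dx| in the form 2 * (log(2) - x - softplus(-2x))."""
    return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))


def log_det_naive(x: float) -> float:
    """log(1 - tanh(x)^2); breaks down once tanh(x) rounds to +/-1."""
    return math.log(1.0 - math.tanh(x) ** 2)


for x in (-2.0, 0.0, 0.5, 3.0):
    assert abs(tanh_inverse(tanh_forward(x)) - x) < 1e-9
    assert abs(log_det_stable(x) - log_det_naive(x)) < 1e-9
```

The two log-det forms agree for moderate inputs, but only the softplus form stays finite for large |x|, where `1 - tanh(x)^2` underflows to zero.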
- class world_models.SampleDist(dist, samples=100)[source]#
Bases: object
Distribution wrapper that estimates statistics via Monte Carlo sampling.
Provides approximated mean, mode, and entropy helpers for transformed distributions where analytic forms may be inconvenient.
- Parameters:
dist (Any)
samples (int)
- property name: str#
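As a rough illustration of how sample-based statistics can stand in for analytic ones, the sketch below estimates mean, mode, and entropy from draws of a plain Gaussian. It is an assumption that SampleDist follows this recipe (sample mean, highest-log-prob sample, negative mean log-prob); the Gaussian stand-in and the helper names are hypothetical, and far more samples than the default 100 are used so the estimates are tight.

```python
import math
import random


def gaussian_log_prob(x: float, mu: float, sigma: float) -> float:
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def sample_stats(mu=0.3, sigma=0.5, samples=20000, seed=0):
    """Monte Carlo estimates of (mean, mode, entropy) for Normal(mu, sigma)."""
    rng = random.Random(seed)
    xs = [rng.gauss(mu, sigma) for _ in range(samples)]
    log_ps = [gaussian_log_prob(x, mu, sigma) for x in xs]
    mean = sum(xs) / samples
    # Mode: the drawn sample with the highest log-probability.
    mode = xs[max(range(samples), key=log_ps.__getitem__)]
    # Entropy: negative average log-probability of the samples.
    entropy = -sum(log_ps) / samples
    return mean, mode, entropy
```

The same three estimators apply unchanged to a tanh-squashed distribution, as long as samples and their log-probabilities are available.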
- class world_models.IRISEncoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, in_channels=3, base_channels=64, num_residual_blocks=2, frame_shape=(3, 64, 64))[source]#
Bases: Module
CNN Encoder for IRIS discrete autoencoder.
Encodes image observations into latent features, which are then quantized into discrete tokens using the VectorQuantizer.
- Architecture:
4 convolutional layers with residual blocks
Self-attention at 8x8 and 16x16 resolutions
Vector quantization to produce discrete tokens
- Parameters:
vocab_size (int)
tokens_per_frame (int)
embedding_dim (int)
in_channels (int)
base_channels (int)
num_residual_blocks (int)
frame_shape (Tuple[int, int, int])
- forward(x)[source]#
Encode images to discrete tokens.
- Parameters:
x (Tensor) – Input images (B, C, H, W) - should be 64x64
- Returns:
z_q: Quantized tokens (B, C, H’, W’)
indices: Token indices (B, H’, W’)
vq_loss: Dictionary with VQ loss components
- class world_models.IRISDecoder(vocab_size=512, embedding_dim=512, base_channels=32, out_channels=3, frame_shape=(3, 64, 64))[source]#
Bases: Module
CNN Decoder for IRIS discrete autoencoder.
Decodes discrete tokens back into image observations. Uses transposed convolutions to upsample from 4x4 to 64x64.
- Parameters:
vocab_size (int)
embedding_dim (int)
base_channels (int)
out_channels (int)
frame_shape (Tuple[int, int, int])
- forward(z)[source]#
Decode tokens to images.
- Parameters:
z (Tensor) – Token embeddings (B, C, H, W) - e.g., (B, 512, 4, 4)
- Returns:
reconstructed: Reconstructed images (B, C, H, W), e.g., (B, 3, 64, 64)
- class world_models.VideoTokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, commitment_weight=0.25, use_ema=False, ema_decay=0.99)[source]#
Bases: Module
Video Tokenizer using VQ-VAE with Spatiotemporal Transformer.
This is a core component of Genie (Google DeepMind, 2024), used to compress raw video frames into discrete latent tokens that can be processed by downstream models like the LatentActionModel and DynamicsModel.
The tokenizer uses a Vector Quantized Variational Autoencoder (VQ-VAE) objective to learn a discrete codebook of video representations. Unlike a standard VQ-VAE, it uses a Spatiotemporal (ST) Transformer in both the encoder and the decoder to better capture temporal dynamics in videos.
Architecture
Patch Embedding: Convert (B, C, T, H, W) video to patch tokens
Encoder ST-Transformer: Process spatial-temporal patches
Vector Quantization: Discretize continuous embeddings to codebook entries
Decoder ST-Transformer: Reconstruct video from quantized tokens
Patch Unembedding: Convert tokens back to video frames
Key Features
Causal processing: Each frame’s encoding only uses previous frames
Discrete tokens: Enables autoregressive prediction with latent actions
Memory efficient: Uses ST-Transformer instead of full ViT to reduce complexity
Usage with Genie:
tokenizer = VideoTokenizer(
    num_frames=16,
    image_size=64,
    patch_size=4,
    vocab_size=1024,
    embedding_dim=32
)
reconstructed, indices, loss_dict = tokenizer(video_frames)
# For discrete token input to dynamics model:
token_embeddings = tokenizer.decode_indices(indices)
The tokenizer is trained with a VQ-VAE objective:
Reconstruction loss: MSE between input and reconstructed video
VQ loss: Pulls codebook embeddings toward the encoder outputs
Commitment loss: Penalizes encoder outputs drifting from codebook
- Reference:
Genie: Generative Interactive Environments Bruce et al., Google DeepMind, 2024 - https://arxiv.org/abs/2402.15391
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
use_ema (bool)
ema_decay (float)
- encode(x)[source]#
Encode video to discrete tokens.
- Parameters:
x (Tensor) – Video tensor (B, C, T, H, W)
- Returns:
z_q: Quantized embeddings (B, T, H’, W’, embedding_dim)
indices: Token indices (B, T, H’, W’)
vq_loss: Dictionary with VQ loss components
- decode_indices(indices)[source]#
Decode token indices to embeddings for video frames.
- Parameters:
indices (Tensor) – Token indices (B, T, H’, W’) or (B, T, N) where N = H’ x W’
- Returns:
z_q: Quantized embeddings (B, T, H’, W’, embedding_dim)
- world_models.create_video_tokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False)[source]#
Factory function to create a Video Tokenizer.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
use_ema (bool)
- Return type:
- class world_models.VectorQuantizer(vocab_size=512, embedding_dim=512, commitment_weight=0.25)[source]#
Bases: Module
Vector Quantizer for discrete autoencoder.
Implements the VQ-VAE quantization from: “Neural Discrete Representation Learning” (Van Den Oord et al., 2017)
Uses exponential moving averages for codebook updates and straight-through estimator for gradient flow.
- Parameters:
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
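The quantization step from the cited paper can be sketched for a single vector as a nearest-neighbour lookup plus a weighted loss. This is a minimal pure-Python illustration, not the library's batched implementation; `quantize` and `squared_distance` are hypothetical names, and the summed (rather than averaged) loss is an assumption.

```python
def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(z, codebook, commitment_weight=0.25):
    """Return (index, quantized vector, loss) for one encoder output `z`."""
    index = min(range(len(codebook)), key=lambda k: squared_distance(z, codebook[k]))
    z_q = list(codebook[index])
    # The codebook and commitment terms have the same forward value; they
    # differ only in where a stop-gradient would sit during training.
    codebook_loss = squared_distance(z, z_q)
    commitment_loss = squared_distance(z, z_q)
    loss = codebook_loss + commitment_weight * commitment_loss
    return index, z_q, loss


codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
index, z_q, loss = quantize([0.9, 0.1], codebook)
print(index, z_q)  # 1 [1.0, 0.0]
```

In an autograd framework, the straight-through estimator would then pass `z + stop_gradient(z_q - z)` downstream so that gradients reach the encoder despite the discrete lookup.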
- class world_models.VectorQuantizerEMA(vocab_size=512, embedding_dim=512, commitment_weight=0.25, ema_decay=0.99, epsilon=1e-05)[source]#
Bases: Module
Vector Quantizer with Exponential Moving Average updates.
Uses EMA updates for the codebook instead of gradient-based updates, which leads to more stable training.
- Parameters:
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
ema_decay (float)
epsilon (float)
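One way to picture the EMA codebook update is the sketch below, which follows the usual VQ-VAE recipe: track a moving average of per-code usage counts and of the summed encoder outputs, then set each code to their ratio with Laplace smoothing. Whether VectorQuantizerEMA uses exactly these buffers is an assumption; `ema_codebook_update` is a hypothetical helper.

```python
def ema_codebook_update(codebook, cluster_size, embed_sum, assignments, inputs,
                        ema_decay=0.99, epsilon=1e-5):
    """One EMA step; mutates and returns (codebook, cluster_size, embed_sum)."""
    vocab_size, dim = len(codebook), len(codebook[0])
    counts = [0.0] * vocab_size
    sums = [[0.0] * dim for _ in range(vocab_size)]
    for k, z in zip(assignments, inputs):
        counts[k] += 1.0
        for d in range(dim):
            sums[k][d] += z[d]
    for k in range(vocab_size):
        cluster_size[k] = ema_decay * cluster_size[k] + (1 - ema_decay) * counts[k]
        for d in range(dim):
            embed_sum[k][d] = ema_decay * embed_sum[k][d] + (1 - ema_decay) * sums[k][d]
    total = sum(cluster_size)
    for k in range(vocab_size):
        # Laplace smoothing keeps rarely used codes from dividing by ~0.
        smoothed = (cluster_size[k] + epsilon) / (total + vocab_size * epsilon) * total
        codebook[k] = [s / smoothed for s in embed_sum[k]]
    return codebook, cluster_size, embed_sum
```

Because the codebook moves by averaging rather than by gradient steps, only the commitment term needs to be backpropagated in this scheme.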
- class world_models.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]#
Bases: object
Fixed-size replay buffer for Dreamer with image observations and transitions.
Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world-model training.
Key Features
Ring buffer with fixed capacity (FIFO eviction when full)
Stores raw uint8 images to save memory
Samples sequences (not single transitions) for temporal modeling
Validates sampled sequences don’t span episode boundaries
Memory Layout
observations: (capacity, C, H, W) uint8 images
actions: (capacity, action_dim) float32
rewards: (capacity,) float32
terminals: (capacity,) float32 (1.0 = terminal, 0.0 = continue)
Sampling Process
Random start index (avoiding episode boundaries)
Collect sequence of length seq_len with wraparound
Validate no terminal in middle of sequence
Return batch of sequences
Usage with Dreamer:
buffer = ReplayBuffer(
    size=100000,            # Max transitions to store
    obs_shape=(3, 64, 64),  # RGB images
    action_size=6,          # Continuous action dim
    seq_len=50,             # Sequence length for training
    batch_size=50           # Parallel sequences per batch
)
# Add transitions during interaction
buffer.add(obs, action, reward, done)
# Sample batch for world model training
obs_batch, action_batch, reward_batch, term_batch = buffer.sample()
Memory Efficiency
Uses uint8 for images (1 byte per pixel vs 4 for float32)
Sequences share observations (overlapping windows)
Configurable capacity based on available system memory
Note
The buffer stores observations as {“image”: …} dicts but returns just the image arrays for training efficiency.
- Parameters:
size (int)
obs_shape (Tuple[int, ...])
action_size (int)
seq_len (int)
batch_size (int)
- add(obs, ac, rew, done)[source]#
Add a transition to the buffer.
- Parameters:
obs (dict) – Observation dict with ‘image’ key containing the observation
ac (ndarray) – Action taken, shape (action_size,)
rew (float) – Reward received, scalar
done (float) – Terminal flag, 1.0 if episode ended, 0.0 otherwise
- Return type:
None
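The sampling process described for this buffer (random start, wraparound, rejection of windows that cross an episode boundary) can be sketched on index lists alone. The rejection rules below are an assumption about the validation logic, and `sample_sequence_indices` is a hypothetical helper.

```python
import random


def sample_sequence_indices(terminals, seq_len, write_idx, full, rng):
    """Pick indices for one contiguous sequence from a ring buffer.

    terminals: list of 0.0/1.0 flags, one per slot.
    write_idx: next slot to be overwritten; full: whether the buffer wrapped.
    Loops until a valid window is found, so at least one must exist.
    """
    capacity = len(terminals)
    filled = capacity if full else write_idx
    if filled < seq_len:
        raise ValueError("not enough transitions stored")
    while True:
        start = rng.randrange(capacity if full else filled - seq_len + 1)
        idxs = [(start + i) % capacity for i in range(seq_len)]
        # Reject windows that cross the write pointer (newest/oldest seam).
        if full and write_idx in idxs[1:]:
            continue
        # Reject windows with a terminal anywhere but the last step.
        if any(terminals[i] for i in idxs[:-1]):
            continue
        return idxs


terminals = [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
rng = random.Random(0)
window = sample_sequence_indices(terminals, seq_len=3, write_idx=0, full=True, rng=rng)
```

A terminal is allowed only in the final position, so every sampled window lies inside one episode.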
- class world_models.Memory(size=None)[source]#
Bases: deque
Episode-based replay memory for PlaNet/RSSM training.
Stores episodes as variable-length trajectories and supports sampling sub-sequences for training. Implements a ring-buffer style eviction when capacity is reached.
Stores complete episodes as lists of transitions
Samples contiguous sub-sequences for sequence models
Supports time-major formatting (time-first) for RNN input
Memory usage estimation to prevent OOM errors
- Parameters:
size (int, optional) – Maximum number of episodes to store. If None, deque grows without limit (useful for unpickling).
- episodes#
Collection of Episode objects.
- Type:
deque
- eps_lengths#
Length of each episode.
- Type:
deque
- size#
Total number of transitions across all episodes.
- Type:
property
Example:
memory = Memory(size=100)
memory.append([episode1, episode2])
obs, actions, rewards, terminals, lengths = memory.sample(batch_size=32, tracelen=50)
- property size: int#
- sample(batch_size, tracelen=1, time_first=False)[source]#
Sample random sub-sequences from stored episodes.
Randomly selects episodes and starting positions to create batches of contiguous sequences for training sequence models.
- Parameters:
batch_size (int) – Number of sequences to sample.
tracelen (int) – Length of each sequence (default: 1).
time_first (bool) – If True, returns tensors with time dimension first (T, B, …) instead of batch first (B, T, …).
- Returns:
- (observations, actions, rewards, terminals, lengths)
observations: (batch, tracelen+1, *obs_shape) or (tracelen+1, batch, …)
actions: (batch, tracelen, action_dim) or (tracelen, batch, …)
rewards: (batch, tracelen) or (tracelen, batch)
terminals: (batch, tracelen) or (tracelen, batch)
lengths: (batch,) original episode lengths for each sample
- Return type:
tuple
- Raises:
ValueError – If memory is empty or no episodes meet minimum length.
MemoryError – If estimated memory usage exceeds 200 MiB threshold.
- class world_models.Episode(postprocess_fn=None)[source]#
Bases: object
Records the agent’s interaction with the environment for a single episode.
Stores observations, actions, rewards, and terminal flags during a single trajectory. At termination, converts all lists to numpy arrays for efficient batch processing.
- x#
Observations collected during the episode.
- Type:
list or np.ndarray
- u#
Actions taken.
- Type:
list or np.ndarray
- r#
Rewards received.
- Type:
list or np.ndarray
- t#
Terminal flags (0.0 = continue, 1.0 = terminal).
- Type:
list or np.ndarray
- info#
Additional episode metadata.
- Type:
dict
- Parameters:
postprocess_fn (callable, optional) – Function to apply to observations before storing (e.g., normalization). Default: identity function.
Example:
episode = Episode()
episode.append(obs, action, reward, False)
episode.append(obs, action, reward, True)
episode.terminate(final_obs)
print(episode.x.shape)  # Now a numpy array
- property size: int#
- class world_models.IRISReplayBuffer(size, obs_shape, action_size, seq_len=20, batch_size=64)[source]#
Bases: object
Replay buffer for IRIS (Imagination with auto-Regression over an Inner Speech) training.
Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world model training.
- Features:
Ring buffer with fixed capacity (FIFO eviction when full)
Stores uint8 images for memory efficiency
Samples sequences with validation to avoid episode boundaries
Supports sequence sampling for temporal learning
- Memory Layout:
observations: (capacity, C, H, W) uint8
actions: (capacity, action_size) float32
rewards: (capacity,) float32
terminals: (capacity,) float32
- Parameters:
size (int) – Maximum number of transitions to store.
obs_shape (tuple) – Shape of observations as (C, H, W).
action_size (int) – Dimension of actions.
seq_len (int) – Length of sequences to sample (default: 20).
batch_size (int) – Number of sequences per batch (default: 64).
- size#
Buffer capacity.
- Type:
int
- obs_shape#
Observation shape.
- Type:
tuple
- action_size#
Action dimension.
- Type:
int
- seq_len#
Sequence length.
- Type:
int
- batch_size#
Batch size.
- Type:
int
- steps#
Total transitions added.
- Type:
int
- episodes#
Number of episode terminations observed.
- Type:
int
- add(obs, action, reward, terminal)[source]#
Add a transition to the buffer.
- Parameters:
obs (ndarray) – Observation array with shape (C, H, W).
action (ndarray) – Action array with shape (action_size,).
reward (float) – Scalar reward value.
terminal (bool) – Boolean indicating if episode terminated.
- Return type:
None
- sample_sequence(seq_len=None)[source]#
Sample a batch of sequences for world model training.
- Parameters:
seq_len (int | None)
- Returns:
observations: (batch_size, seq_len+1, C, H, W)
actions: (batch_size, seq_len, action_size)
rewards: (batch_size, seq_len)
terminals: (batch_size, seq_len)
- sample_single()[source]#
Sample a single transition for online updates.
- Return type:
Tuple[ndarray, ndarray, float, float]
- property buffer_capacity: int#
Returns the total capacity of the buffer.
- class world_models.IRISOnPolicyBuffer(max_steps=1000)[source]#
Bases: object
On-policy buffer for collecting trajectories during environment interaction.
Used to store the current episode data before adding to the main replay buffer. Unlike the main replay buffer, this collects trajectories in a list-based structure that’s cleared after each episode.
- Useful for:
Collecting complete episode trajectories
Storing data before batch processing
Temporary storage during environment interaction
- Parameters:
max_steps (int) – Maximum number of steps to store (default: 1000).
- max_steps#
Maximum buffer capacity.
- Type:
int
- observations#
List of observations.
- Type:
list
- actions#
List of actions.
- Type:
list
- rewards#
List of rewards.
- Type:
list
- terminals#
List of terminal flags.
- Type:
list
- class world_models.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]#
Bases: Module
Diffusion Transformer model for image denoising and generation.
The module maps noisy images and timesteps to predicted noise residuals and also provides a classmethod training entrypoint for common datasets.
- Parameters:
img_size (int)
patch_size (int)
in_channels (int)
d_model (int)
depth (int)
heads (int)
drop (float)
t_dim (int)
- classmethod from_config(config=None, **overrides)[source]#
Build DiT from a config object, dict, YAML file, or YAML string.
- classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, map_location=None, **overrides)[source]#
Load DiT weights from a local path/directory or HF Hub.
- save_pretrained(path)[source]#
Save DiT weights and config in a from_pretrained-compatible format.
- Parameters:
path (str | Path)
- Return type:
None
- classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]#
- Parameters:
epochs (int)
dataset (Any)
batch_size (int)
lr (float)
img_size (int)
channels (int)
patch (int)
width (int)
depth (int)
heads (int)
drop (float)
timesteps (int)
beta_start (float)
beta_end (float)
ema (bool)
ema_decay (float)
workdir (str)
root_path (str)
image_folder (str | None)
crop_size (int)
download (bool)
copy_data (bool)
subset_file (str | None)
val_split (float | None)
- Return type:
None
- world_models.create_dit(config=None, **overrides)[source]#
Create a DiT from a DiTConfig or keyword overrides.
The public factory API works with config objects, while DiT itself has a compact constructor. This adapter keeps the lower-level model constructor unchanged and maps the public config fields onto the expected arguments.
- Parameters:
config (Any)
overrides (Any)
- Return type:
- class world_models.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]#
Bases: Module
Patchify an image into a sequence of learnable patch tokens.
Used in Vision Transformers (ViT) and DiT to convert 2D images into sequences of token embeddings that can be processed by transformers.
- Process:
Conv2d with kernel_size=stride=patch_size extracts non-overlapping patches
Each patch is projected to embed_dim via linear layer (Conv2d)
Learnable positional embeddings are added for spatial information
Input: (B, C, H, W) images
Output: (B, N, embed_dim) where N = (H/patch_size) * (W/patch_size)
- Parameters:
img_size (int) – Image size (assumes square), e.g., 32 for CIFAR
patch_size (int) – Size of each patch (typically 4, 8, or 16)
in_channels (int) – Number of input channels (3 for RGB)
embed_dim (int) – Output dimension for each patch token
- Usage with DiT:
patch_embed = PatchEmbed(img_size=32, patch_size=4, in_channels=3, embed_dim=256)
tokens = patch_embed(images)  # (B, 64, 256) for 32x32 image with patch_size=4
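The token count N = (H/patch_size) * (W/patch_size) can be seen by cutting an image into non-overlapping patches by hand. This single-channel, nested-list sketch only illustrates the patch layout; the learned projection and positional embeddings are omitted, and `patchify` is a hypothetical helper.

```python
def patchify(image, patch_size):
    """Split an (H, W) nested-list image into flattened non-overlapping patches.

    Returns N = (H // patch_size) * (W // patch_size) patches in row-major
    order, each holding patch_size * patch_size pixel values.
    """
    height, width = len(image), len(image[0])
    assert height % patch_size == 0 and width % patch_size == 0
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patches.append([
                image[r][c]
                for r in range(top, top + patch_size)
                for c in range(left, left + patch_size)
            ])
    return patches


image = [[r * 32 + c for c in range(32)] for r in range(32)]
tokens = patchify(image, patch_size=4)
print(len(tokens), len(tokens[0]))  # 64 16
```

A Conv2d with kernel_size = stride = patch_size performs this split and the linear projection in a single operation, which is why it is the usual implementation choice.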
- class world_models.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]#
Bases: Module
Reconstruct image-like tensors from patch-token sequences.
The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.
- Parameters:
img_size (int)
patch_size (int)
embed_dim (int)
out_channels (int)
- class world_models.DDPM(timesteps, beta_start, beta_end)[source]#
Bases: Module
Utility module implementing forward and reverse DDPM diffusion steps.
Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).
- Parameters:
timesteps (int)
beta_start (float)
beta_end (float)
- q_sample(x_start, t, noise=None)[source]#
- Parameters:
x_start (Tensor)
t (Tensor)
noise (Tensor | None)
- Return type:
Tensor
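For orientation, the standard DDPM forward step has a closed form that q_sample most likely mirrors. The sketch below assumes a linear beta schedule between beta_start and beta_end (the constructor arguments suggest this, but the page does not say so) and works on plain lists; all helper names are hypothetical.

```python
import math


def linear_beta_schedule(timesteps, beta_start, beta_end):
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def cumulative_alphas(betas):
    """alpha_bar[t] = product over s <= t of (1 - beta[s])."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def q_sample(x_start, t, noise, alpha_bars):
    """x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise."""
    a = math.sqrt(alpha_bars[t])
    b = math.sqrt(1.0 - alpha_bars[t])
    return [a * x + b * n for x, n in zip(x_start, noise)]


betas = linear_beta_schedule(1000, 1e-4, 0.02)
alpha_bars = cumulative_alphas(betas)
```

Because alpha_bar shrinks toward zero as t grows, late timesteps are almost pure noise while early ones stay close to the clean input.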
- class world_models.ActorCriticNetwork(obs_channels=3, action_dim=18, channels=(32, 32, 64, 64), lstm_dim=512)[source]#
Bases: Module
Actor-Critic network for DIAMOND RL training. Shared CNN-LSTM trunk with separate policy and value heads.
- Parameters:
obs_channels (int)
action_dim (int)
channels (Tuple[int, ...])
lstm_dim (int)
- forward(obs, hidden_state=None)[source]#
Forward pass of actor-critic network.
- Parameters:
obs (Tensor) – Observations [B, T, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
policy_logits: [B, T, action_dim]
values: [B, T, 1]
hidden_state: (h, c)
- get_action(obs, hidden_state=None, deterministic=False)[source]#
Get action from a single observation.
- Parameters:
obs (Tensor) – Single observation [B, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
deterministic (bool) – If True, take argmax; else sample
- Returns:
action: Selected action [B]
hidden_state: (h, c)
- get_actions(obs, hidden_state=None, deterministic=False)[source]#
Batched version of get_action.
- Parameters:
obs (Tensor) – Tensor of shape [B, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state tuple matching batch size
deterministic (bool) – If True, take argmax; else sample from policy
- Returns:
LongTensor of shape [B]
hidden_state: updated LSTM hidden state tuple
- Return type:
- get_value(obs, hidden_state=None)[source]#
Get value for a single observation.
- Parameters:
obs (Tensor)
hidden_state (Tuple[Tensor, Tensor] | None)
- Return type:
Tuple[Tensor, Tuple[Tensor, Tensor] | None]
Initialize LSTM hidden states.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
Get LSTM hidden size.
- Return type:
int
- class world_models.RewardTerminationModel(obs_channels=3, action_dim=18, channels=(32, 32, 32, 32), lstm_dim=512, cond_dim=128)[source]#
Bases: Module
Reward and termination prediction model. CNN + LSTM architecture following DIAMOND paper specifications.
- Parameters:
obs_channels (int) – Number of observation channels (3 for RGB)
action_dim (int) – Number of possible actions
channels (Tuple[int, ...]) – List of channel sizes for conv blocks
lstm_dim (int) – LSTM hidden dimension
cond_dim (int) – Conditioning dimension for adaptive norm
- forward(obs, actions, hidden_state=None)[source]#
Forward pass of reward/termination model.
- Parameters:
obs (Tensor) – Observations [B, T, C, H, W]
actions (Tensor) – Actions [B, T]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
reward_logits: Reward predictions [B, T, 3] (for -1, 0, 1)
termination_logits: Termination predictions [B, T, 2]
hidden_state: Updated (h, c) hidden states
- predict(obs, actions, hidden_state=None)[source]#
Predict reward and termination for a single step.
- Parameters:
obs (Tensor) – Single observation [B, C, H, W]
actions (Tensor) – Single action [B]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
reward: Predicted reward classes as tensor (values -1, 0, 1)
terminated: Predicted termination tensor (bool tensor)
hidden_state: Updated (h, c) hidden states
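The step from the 3-way reward logits and 2-way termination logits to the values returned by predict can be sketched as an argmax plus an offset. The class ordering assumed here, (-1, 0, 1) for rewards and (continue, terminate) for termination, is inferred from the shapes above and is not confirmed by this page; `decode_prediction` is a hypothetical helper.

```python
def decode_prediction(reward_logits, termination_logits):
    """Map one step of logits to (reward in {-1, 0, 1}, terminated flag).

    Assumes reward classes are ordered (-1, 0, 1) and termination classes
    are ordered (continue, terminate).
    """
    reward_class = max(range(3), key=reward_logits.__getitem__)
    terminated = termination_logits[1] > termination_logits[0]
    return reward_class - 1, terminated


print(decode_prediction([0.1, 0.2, 2.0], [1.5, -0.3]))  # (1, False)
```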
Initialize LSTM hidden states.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- world_models.sinusoidal_time_embedding(timesteps, dim)[source]#
Create sinusoidal timestep embeddings for diffusion conditioning.
This function generates positional-style embeddings for diffusion timesteps, following the same pattern as transformer positional encodings. The embeddings encode the noise level (t) and are used to condition the diffusion model.
- Math:
embedding[t] = [sin(t/10000^(2i/d)), cos(t/10000^(2i/d))] for i in [0, d/2)
- Parameters:
timesteps (Tensor) – Tensor of integer timesteps, shape (B,) or (B, 1)
dim (int) – Embedding dimension (must be even)
- Returns:
Tensor of shape (B, dim) with sinusoidal embeddings
- Return type:
Tensor
- Usage with DiT:
t = torch.tensor([0, 500, 1000])  # Timesteps
emb = sinusoidal_time_embedding(t, dim=256)  # (3, 256)
# Condition the model:
# - Add the timestep embedding to the MLP input
# - Use AdaLN for adaptive normalization
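The formula in the Math entry can be written out for a single timestep without any tensor library. This sketch assumes the sine and cosine halves are concatenated rather than interleaved, which the page does not specify; `sinusoidal_embedding` is a hypothetical scalar counterpart of the documented function.

```python
import math


def sinusoidal_embedding(t: float, dim: int, max_period: float = 10000.0):
    """Embed one timestep as dim/2 sines followed by dim/2 cosines."""
    assert dim % 2 == 0, "dim must be even"
    half = dim // 2
    # Frequency i is 1 / max_period^(2i / dim), as in the formula above.
    freqs = [max_period ** (-2.0 * i / dim) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]


emb = sinusoidal_embedding(500, dim=256)
print(len(emb))  # 256
```

Low-index frequencies change quickly with t while high-index ones change slowly, so nearby timesteps get similar but distinguishable embeddings.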
- class world_models.STTransformer(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, gradient_checkpointing=False)[source]#
Bases: Module
Spatiotemporal Transformer for video modeling.
Contains L spatiotemporal blocks with interleaved spatial and temporal attention.
- Parameters:
num_frames (int)
num_patches_per_frame (int)
dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
qk_scale (float | None)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
norm_layer (type[Module])
gradient_checkpointing (bool)
- class world_models.MultiHeadSelfAttention(d, n_heads=2)[source]#
Bases: Module
Multi-head scaled dot-product self-attention over sequence tokens.
This module projects the input sequence into query/key/value heads, performs attention independently per head, and merges the heads back into the original feature dimension. It is used as a lightweight transformer attention block.
- Parameters:
d (int)
n_heads (int)
- world_models.MultiHeadAttention#
alias of MultiHeadSelfAttention
- class world_models.AdaLNNormalization(d_model, t_dim)[source]#
Bases: Module
Adaptive layer normalization conditioned on an external embedding.
The module applies RMS normalization and predicts per-channel scale/shift from a conditioning vector (for example diffusion timestep embeddings).
- Parameters:
d_model (int)
t_dim (int)
- class world_models.RMSNorm(dim, eps=1e-06)[source]#
Bases: Module
Root Mean Square Layer Normalization with a learned gain parameter.
RMSNorm rescales activations using their RMS magnitude without centering, providing a lightweight normalization alternative to LayerNorm.
- Parameters:
dim (int)
eps (float)
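A scalar-list sketch of the normalization described above: divide by the root mean square of the activations, then multiply by a per-element gain. The placement of `eps` inside the square root is an assumption, and `rms_norm` is a hypothetical helper rather than the module itself.

```python
import math


def rms_norm(x, gain=None, eps=1e-6):
    """Scale `x` by the reciprocal of its root mean square, then by `gain`."""
    gain = gain if gain is not None else [1.0] * len(x)
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for g, v in zip(gain, x)]


y = rms_norm([3.0, -4.0, 0.0, 5.0])
```

Unlike LayerNorm, the mean is not subtracted, so the output keeps the sign pattern and relative offsets of the input and only its overall scale changes.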
- class world_models.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]#
Bases: object
Model-predictive controller using Cross-Entropy Method (CEM) with RSSM.
Plans actions by optimizing a sequence of future actions in the RSSM’s latent space.
The policy uses a Cross-Entropy Method (CEM) style loop: it samples candidate action sequences, rolls them forward in latent space, scores predicted returns, and refits a Gaussian proposal to the top-performing candidates.
- Algorithm:
Initialize Gaussian distribution over action sequences
Sample N candidate action sequences
Rollout each sequence in RSSM latent space
Score by predicted cumulative rewards
Keep top K candidates, fit Gaussian to them
Repeat for T iterations
Execute first action from best sequence
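The steps above can be sketched on a toy problem where the "predicted return" is a simple function of the action sequence instead of an RSSM rollout. This is a minimal illustration with 1-D actions; `cem_plan` and the toy score are hypothetical, and the final proposal mean is used as the plan, which may differ from how RSSMPolicy selects its best sequence.

```python
import random
import statistics


def cem_plan(score_fn, horizon, num_candidates, num_iterations, top_candidates, seed=0):
    """Return (first action, full plan) found by CEM for 1-D actions."""
    rng = random.Random(seed)
    mean = [0.0] * horizon  # Gaussian proposal over the action sequence
    std = [1.0] * horizon
    for _ in range(num_iterations):
        candidates = [
            [rng.gauss(mean[h], std[h]) for h in range(horizon)]
            for _ in range(num_candidates)
        ]
        # Highest predicted return first, then keep the top K.
        candidates.sort(key=score_fn, reverse=True)
        elite = candidates[:top_candidates]
        # Refit the proposal to the elite set, one time step at a time.
        for h in range(horizon):
            column = [seq[h] for seq in elite]
            mean[h] = statistics.fmean(column)
            std[h] = statistics.pstdev(column)
    return mean[0], mean


def toy_return(seq):
    """Stand-in for summed predicted rewards: best when every action is 0.5."""
    return -sum((a - 0.5) ** 2 for a in seq)


first_action, plan = cem_plan(toy_return, horizon=3, num_candidates=200,
                              num_iterations=10, top_candidates=20)
```

Only the first action is executed before replanning, so errors in later steps of the plan matter less than the quality of the immediate action.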
- Parameters:
model (Any)
planning_horizon (int)
num_candidates (int)
num_iterations (int)
top_candidates (int)
device (device | str)
- rssm#
The RSSM world model.
- N#
Number of candidate action sequences to sample.
- K#
Number of top candidates to use for updating the proposal.
- T#
Number of CEM iterations per planning step.
- H#
Planning horizon (number of future steps to consider).
- d#
Action dimensionality.
- device#
Device to run computations on.
- state_size#
Hidden state dimensionality.
- latent_size#
Latent state dimensionality.
Example
>>> policy = RSSMPolicy(
...     model=rssm,
...     planning_horizon=12,
...     num_candidates=1000,
...     num_iterations=5,
...     top_candidates=100,
...     device='cuda'
... )
>>> policy.reset()
>>> action = policy.poll(observation)
- reset()[source]#
Reset the policy state.
Initializes the hidden state, latent state, and action to zeros. Should be called at the beginning of each episode.
- Return type:
None
- poll(observation, explore=False)[source]#
Get action for given observation.
- Parameters:
observation (Tensor) – Current observation tensor of shape (channels, height, width).
explore (bool) – If True, add exploration noise to the selected action.
- Returns:
Action tensor of shape (1, action_size).
- Return type:
Tensor
- class world_models.IRISActor(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases: Module
Actor network for IRIS (Imagination with auto-Regression over an Inner Speech) policy.
Takes reconstructed frames as input and outputs action logits for policy control. Uses a CNN feature extractor followed by an LSTM for temporal processing. Supports a burn-in mechanism for initializing the hidden state with context frames.
- Architecture:
CNN: Extracts features from input frames (3x64x64 -> 512)
LSTM: Processes temporal sequences with configurable layers
Linear: Maps hidden states to action logits
- Parameters:
action_size (int) – Number of discrete actions.
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
- action_size#
Number of discrete actions.
- Type:
int
- hidden_size#
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
- forward(frames, hidden_state=None, burn_in_frames=None)[source]#
Forward pass through actor.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W) or (B, C, H, W)
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple for LSTM state
burn_in_frames (Tensor | None) – Frames to use for initializing hidden state
- Returns:
action_logits: Action logits of shape (B, T, action_size) or (B, action_size).
hidden_state: Updated (h, c) tuple.
Initialize LSTM hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- get_action(frame, temperature=1.0, deterministic=False)[source]#
Get action from a single frame.
- Parameters:
frame (Tensor) – Single frame (B, C, H, W)
temperature (float) – Softmax temperature (higher = more random)
deterministic (bool) – If True, return argmax; else sample
- Returns:
action: Selected action indices of shape (B,).
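The effect of `temperature` and `deterministic` can be illustrated on a plain list of logits. This is a sketch under assumptions, not the method's actual code: `select_action` is a hypothetical helper, and it assumes logits are divided by the temperature before the softmax, which is the usual convention.

```python
import math
import random


def select_action(logits, temperature=1.0, deterministic=False, rng=None):
    """Pick an action index from raw logits."""
    if deterministic:
        # Argmax ignores the temperature entirely.
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    # Higher temperature flattens the distribution; it must be positive.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


greedy = select_action([0.1, 2.0, -1.0], deterministic=True)
sampled = select_action([0.1, 2.0, -1.0], temperature=0.5, rng=random.Random(0))
```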
- class world_models.IRISCritic(hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Critic network for IRIS value estimation.
Estimates the value function for given frame sequences. Shares the CNN feature extractor and LSTM backbone with the actor for efficiency, but has a separate value head for estimating expected cumulative rewards.
- Architecture:
CNN: Shared feature extractor with actor (3x64x64 -> 512)
LSTM: Temporal processing with same architecture as actor
Linear: Maps hidden states to scalar values
- Parameters:
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
- hidden_size#
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
- Returns:
values: Value estimates with shape (B, T).
hidden_state: Updated LSTM hidden state (h, c) tuple.
- Parameters:
hidden_size (int)
num_layers (int)
frame_shape (Tuple[int, int, int])
- forward(frames, hidden_state=None)[source]#
Forward pass through critic.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W)
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple
- Returns:
values: Value estimates of shape (B, T).
hidden_state: Updated (h, c) tuple.
Initialize LSTM hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- class world_models.IRISPolicy(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Combined policy module for IRIS (Imagination with auto-Regression over an Inner Speech).
Provides a unified interface for actor-only or actor-critic policies. Used in the IRIS algorithm where the actor generates actions from reconstructed frames and the critic estimates value functions for training.
- Parameters:
action_size (int) – Number of discrete actions.
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
- hidden_size#
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
Example
>>> policy = IRISPolicy(
...     action_size=18,
...     hidden_size=512,
...     num_layers=4,
...     frame_shape=(3, 64, 64)
... )
>>> action = policy.act(frame, temperature=1.0, deterministic=False)
- forward(frames)[source]#
Get action logits from frames.
- Parameters:
frames (Tensor)
- Return type:
Tensor
- act(frame, temperature=1.0, deterministic=False)[source]#
Sample action from policy.
- Parameters:
frame (Tensor)
temperature (float)
deterministic (bool)
- Return type:
Tensor
Initialize hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
tuple[Tensor, Tensor]
- class world_models.CNNFeatureExtractor(frame_shape=(3, 64, 64), output_size=512)[source]#
Bases:
Module
CNN feature extractor shared between actor and critic networks.
Processes input frames through four stride-2 convolution blocks (Conv2d followed by ReLU, starting at 3->32 channels), then a linear projection to output_size, producing fixed-size feature vectors.
- Architecture:
Conv layers: 32 -> 64 -> 128 -> 256 channels
Each conv has stride=2 for spatial downsampling
Final linear layer projects to desired output dimension
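The downsampling arithmetic can be checked with the standard convolution output-size formula. The kernel size (4) and padding (1) below are assumptions chosen for illustration; the reference only states the channel widths and stride=2.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Spatial size after one convolution (standard floor formula)."""
    return (size + 2 * padding - kernel) // stride + 1


channels = [32, 64, 128, 256]  # output channels of the four conv layers
side = 64                      # height and width of the default 3x64x64 frame
for _ in channels:
    side = conv_out(side)      # each stride-2 layer halves the side here

# Number of values the final linear layer would receive under these assumptions.
flat_features = channels[-1] * side * side
```

Under these assumed kernel and padding values the side shrinks 64, 32, 16, 8, 4, so the linear layer would map 256 x 4 x 4 = 4096 inputs to output_size.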
- Parameters:
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
output_size (int) – Size of output feature vector (default: 512).
- frame_shape#
Input frame shape.
- Type:
tuple
- output_size#
Output feature dimension.
- Type:
int
- Returns:
features: Feature vectors with shape (B, output_size).
- Parameters:
frame_shape (Tuple[int, int, int])
output_size (int)
- class world_models.DreamerConfig(env_backend='dmc', env='walker-walk', env_instance=None, image_size=(64, 64), gym_render_mode='rgb_array', dmlab_action_repeat=4, dmlab_action_set=None, dmlab_observations=None, dmlab_config=None, dmlab_renderer='hardware', procgen_distribution_mode='easy', procgen_num_levels=0, procgen_start_level=None, mujoco_xml_path=None, mujoco_xml_string=None, mujoco_binary_path=None, mujoco_camera=None, mujoco_frame_skip=1, mujoco_reset_noise_scale=0.0, brax_backend='generalized', brax_jit=True, brax_auto_reset=False, brax_suppress_warp_warnings=True, unity_file_name=None, unity_behavior_name=None, unity_worker_id=0, unity_base_port=5005, unity_no_graphics=True, unity_time_scale=20.0, unity_quality_level=1, algo='Dreamerv1', exp_name='lr1e-3', train=True, evaluate=False, seed=1, no_gpu=False, max_episode_length=1000, buffer_size=800000, time_limit=1000, cnn_activation_function='relu', dense_activation_function='elu', obs_embed_size=1024, num_units=400, deter_size=200, stoch_size=30, action_repeat=2, action_noise=0.3, total_steps=5000000, seed_steps=5000, update_steps=100, collect_steps=1000, batch_size=50, train_seq_len=50, imagine_horizon=15, use_disc_model=False, free_nats=3.0, discount=0.99, td_lambda=0.95, kl_loss_coeff=1.0, kl_alpha=0.8, disc_loss_coeff=10.0, num_buckets=255, symlog_range=10.0, model_learning_rate=0.0006, actor_learning_rate=8e-05, value_learning_rate=8e-05, adam_epsilon=1e-07, grad_clip_norm=100.0, use_amp=True, test=False, test_interval=10000, test_episodes=10, scalar_freq=1000, log_video_freq=-1, max_videos_to_save=2, video_format='gif', video_fps=20, checkpoint_interval=10000, checkpoint_path='', restore=False, experience_replay='', render=False, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', log_dir='runs', logdir=None, data_dir=None, log_level='INFO', log_file=None, enable_tensorboard=False, enable_console_metrics=True, enable_jsonl=True, jsonl_filename='metrics.jsonl', 
log_system_stats_freq=1000, detect_anomaly=False)[source]#
Bases:
SerializableConfigMixin
Configuration container for Dreamer training, evaluation, and environment setup.
This class centralizes environment backend selection (DMC/DMLab/Gym/MuJoCo/Robotics/Unity/Brax), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.
- Parameters:
env_backend (str)
env (str)
env_instance (Any)
image_size (tuple[int, int])
gym_render_mode (str)
dmlab_action_repeat (int)
dmlab_action_set (Any)
dmlab_observations (Any)
dmlab_config (Any)
dmlab_renderer (str)
procgen_distribution_mode (str)
procgen_num_levels (int)
procgen_start_level (Any)
mujoco_xml_path (Any)
mujoco_xml_string (Any)
mujoco_binary_path (Any)
mujoco_camera (Any)
mujoco_frame_skip (int)
mujoco_reset_noise_scale (float)
brax_backend (str)
brax_jit (bool)
brax_auto_reset (bool)
brax_suppress_warp_warnings (bool)
unity_file_name (Any)
unity_behavior_name (Any)
unity_worker_id (int)
unity_base_port (int)
unity_no_graphics (bool)
unity_time_scale (float)
unity_quality_level (int)
algo (str)
exp_name (str)
train (bool)
evaluate (bool)
seed (int)
no_gpu (bool)
max_episode_length (int)
buffer_size (int)
time_limit (int)
cnn_activation_function (str)
dense_activation_function (str)
obs_embed_size (int)
num_units (int)
deter_size (int)
stoch_size (int)
action_repeat (int)
action_noise (float)
total_steps (int)
seed_steps (int)
update_steps (int)
collect_steps (int)
batch_size (int)
train_seq_len (int)
imagine_horizon (int)
use_disc_model (bool)
free_nats (float)
discount (float)
td_lambda (float)
kl_loss_coeff (float)
kl_alpha (float)
disc_loss_coeff (float)
num_buckets (int)
symlog_range (float)
model_learning_rate (float)
actor_learning_rate (float)
value_learning_rate (float)
adam_epsilon (float)
grad_clip_norm (float)
use_amp (bool)
test (bool)
test_interval (int)
test_episodes (int)
scalar_freq (int)
log_video_freq (int)
max_videos_to_save (int)
video_format (str)
video_fps (int)
checkpoint_interval (int)
checkpoint_path (str)
restore (bool)
experience_replay (str)
render (bool)
enable_wandb (bool)
wandb_api_key (str)
wandb_project (str)
wandb_entity (str)
log_dir (str)
logdir (Any)
data_dir (Any)
log_level (str)
log_file (Any)
enable_tensorboard (bool)
enable_console_metrics (bool)
enable_jsonl (bool)
jsonl_filename (str)
log_system_stats_freq (int)
detect_anomaly (bool)
- env_backend: str = 'dmc'#
- env: str = 'walker-walk'#
- env_instance: Any = None#
- image_size: tuple[int, int] = (64, 64)#
- gym_render_mode: str = 'rgb_array'#
- dmlab_action_repeat: int = 4#
- dmlab_action_set: Any = None#
- dmlab_observations: Any = None#
- dmlab_config: Any = None#
- dmlab_renderer: str = 'hardware'#
- procgen_distribution_mode: str = 'easy'#
- procgen_num_levels: int = 0#
- procgen_start_level: Any = None#
- mujoco_xml_path: Any = None#
- mujoco_xml_string: Any = None#
- mujoco_binary_path: Any = None#
- mujoco_camera: Any = None#
- mujoco_frame_skip: int = 1#
- mujoco_reset_noise_scale: float = 0.0#
- brax_backend: str = 'generalized'#
- brax_jit: bool = True#
- brax_auto_reset: bool = False#
- brax_suppress_warp_warnings: bool = True#
- unity_file_name: Any = None#
- unity_behavior_name: Any = None#
- unity_worker_id: int = 0#
- unity_base_port: int = 5005#
- unity_no_graphics: bool = True#
- unity_time_scale: float = 20.0#
- unity_quality_level: int = 1#
- algo: str = 'Dreamerv1'#
- exp_name: str = 'lr1e-3'#
- train: bool = True#
- evaluate: bool = False#
- seed: int = 1#
- no_gpu: bool = False#
- max_episode_length: int = 1000#
- buffer_size: int = 800000#
- time_limit: int = 1000#
- cnn_activation_function: str = 'relu'#
- dense_activation_function: str = 'elu'#
- obs_embed_size: int = 1024#
- num_units: int = 400#
- deter_size: int = 200#
- stoch_size: int = 30#
- action_repeat: int = 2#
- action_noise: float = 0.3#
- total_steps: int = 5000000#
- seed_steps: int = 5000#
- update_steps: int = 100#
- collect_steps: int = 1000#
- batch_size: int = 50#
- train_seq_len: int = 50#
- imagine_horizon: int = 15#
- use_disc_model: bool = False#
- free_nats: float = 3.0#
- discount: float = 0.99#
- td_lambda: float = 0.95#
- kl_loss_coeff: float = 1.0#
- kl_alpha: float = 0.8#
- disc_loss_coeff: float = 10.0#
- num_buckets: int = 255#
- symlog_range: float = 10.0#
- model_learning_rate: float = 0.0006#
- actor_learning_rate: float = 8e-05#
- value_learning_rate: float = 8e-05#
- adam_epsilon: float = 1e-07#
- grad_clip_norm: float = 100.0#
- use_amp: bool = True#
- test: bool = False#
- test_interval: int = 10000#
- test_episodes: int = 10#
- scalar_freq: int = 1000#
- log_video_freq: int = -1#
- max_videos_to_save: int = 2#
- video_format: str = 'gif'#
- video_fps: int = 20#
- checkpoint_interval: int = 10000#
- checkpoint_path: str = ''#
- restore: bool = False#
- experience_replay: str = ''#
- render: bool = False#
- enable_wandb: bool = False#
- wandb_api_key: str = ''#
- wandb_project: str = 'torchwm'#
- wandb_entity: str = ''#
- log_dir: str = 'runs'#
- logdir: Any = None#
- data_dir: Any = None#
- log_level: str = 'INFO'#
- log_file: Any = None#
- enable_tensorboard: bool = False#
- enable_console_metrics: bool = True#
- enable_jsonl: bool = True#
- jsonl_filename: str = 'metrics.jsonl'#
- log_system_stats_freq: int = 1000#
- detect_anomaly: bool = False#
- class world_models.JEPAConfig[source]#
Bases:
SerializableConfigMixin
Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.
- classmethod from_dict(values)[source]#
Load flat field values or the nested trainer dictionary.
- Parameters:
values (Dict[str, Any])
- Return type:
JEPAConfig
- class world_models.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#
Bases:
SerializableConfigMixin
Default configuration values for Diffusion Transformer (DiT) training.
The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.
Field names use UPPER_CASE for backward compatibility with the original DiT codebase. Snake-case aliases are accepted via __getattr__ and get_dit_config().
- Parameters:
DATASET (str)
BATCH (int)
EPOCHS (int)
LR (float)
IMG_SIZE (int)
CHANNELS (int)
PATCH (int)
WIDTH (int)
DEPTH (int)
HEADS (int)
DROP (float)
BETA_START (float)
BETA_END (float)
TIMESTEPS (int)
EMA (bool)
EMA_DECAY (float)
WORKDIR (str)
ROOT_PATH (str)
- DATASET: str = 'CIFAR10'#
- BATCH: int = 128#
- EPOCHS: int = 3#
- LR: float = 0.0002#
- IMG_SIZE: int = 32#
- CHANNELS: int = 3#
- PATCH: int = 4#
- WIDTH: int = 384#
- DEPTH: int = 6#
- HEADS: int = 6#
- DROP: float = 0.1#
- BETA_START: float = 0.0001#
- BETA_END: float = 0.02#
- TIMESTEPS: int = 1000#
- EMA: bool = True#
- EMA_DECAY: float = 0.999#
- WORKDIR: str = './dit_demo'#
- ROOT_PATH: str = './data'#
- world_models.get_dit_config(**overrides)[source]#
Returns a DiTConfig instance with default values overridden by the provided keyword arguments.
Both UPPER_CASE and snake_case override keys are accepted.
- Example usage:
cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)
cfg = get_dit_config(batch=64, epochs=10, lr=1e-3)
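One way to accept both key styles is to upper-case every override before applying it. The sketch below is an assumption about how such normalization could work, not the library's code: `MiniDiTConfig` and `get_mini_config` are hypothetical stand-ins with only three of the real fields.

```python
from dataclasses import dataclass, fields, replace


@dataclass
class MiniDiTConfig:
    # Hypothetical three-field stand-in for DiTConfig.
    BATCH: int = 128
    EPOCHS: int = 3
    LR: float = 0.0002


def get_mini_config(**overrides):
    """Apply overrides given in either UPPER_CASE or snake_case."""
    valid = {f.name for f in fields(MiniDiTConfig)}
    normalized = {}
    for key, value in overrides.items():
        name = key.upper()  # 'batch' and 'BATCH' both map to 'BATCH'
        if name not in valid:
            raise TypeError(f"unknown config field: {key!r}")
        normalized[name] = value
    return replace(MiniDiTConfig(), **normalized)


upper = get_mini_config(BATCH=64, LR=1e-3)
lower = get_mini_config(batch=64, lr=1e-3)
```

Rejecting unknown keys early, as the sketch does, surfaces typos in override names instead of silently ignoring them.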
- Parameters:
overrides (Any)
- Return type:
DiTConfig
- class world_models.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', seed: int = 0, obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, data_loader_num_workers: int = 4, pin_memory: bool = True, persistent_workers: bool = True, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, use_amp: bool = True, device: str = <factory>, log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, operator_state_dim: int = 32, operator_action_dim: int = 4)[source]#
Bases:
SerializableConfigMixin
- Parameters:
preset (str | None)
game (str)
seed (int)
obs_size (int)
frameskip (int)
max_noop (int)
terminate_on_life_loss (bool)
reward_clip (List[int])
num_conditioning_frames (int)
diffusion_channels (List[int])
diffusion_res_blocks (int)
diffusion_cond_dim (int)
sigma_data (float)
sigma_min (float)
sigma_max (float)
rho (int)
p_mean (float)
p_std (float)
sampling_method (str)
num_sampling_steps (int)
reward_channels (List[int])
reward_res_blocks (int)
reward_cond_dim (int)
reward_lstm_dim (int)
burn_in_length (int)
actor_channels (List[int])
actor_res_blocks (int)
actor_lstm_dim (int)
num_epochs (int)
training_steps_per_epoch (int)
batch_size (int)
environment_steps_per_epoch (int)
epsilon_greedy (float)
data_loader_num_workers (int)
pin_memory (bool)
persistent_workers (bool)
imagination_horizon (int)
discount_factor (float)
entropy_weight (float)
lambda_returns (float)
learning_rate (float)
adam_epsilon (float)
weight_decay_diffusion (float)
weight_decay_reward (float)
weight_decay_actor (float)
use_amp (bool)
device (str)
log_interval (int)
eval_interval (int)
save_interval (int)
operator_state_dim (int)
operator_action_dim (int)
- preset: str | None = None#
- game: str = 'Breakout-v5'#
- seed: int = 0#
- obs_size: int = 64#
- frameskip: int = 4#
- max_noop: int = 30#
- terminate_on_life_loss: bool = True#
- reward_clip: List[int]#
- num_conditioning_frames: int = 4#
- diffusion_channels: List[int]#
- diffusion_res_blocks: int = 2#
- diffusion_cond_dim: int = 256#
- sigma_data: float = 0.5#
- sigma_min: float = 0.002#
- sigma_max: float = 80.0#
- rho: int = 7#
- p_mean: float = -0.4#
- p_std: float = 1.2#
- sampling_method: str = 'euler'#
- num_sampling_steps: int = 3#
- reward_channels: List[int]#
- reward_res_blocks: int = 2#
- reward_cond_dim: int = 128#
- reward_lstm_dim: int = 512#
- burn_in_length: int = 4#
- actor_channels: List[int]#
- actor_res_blocks: int = 1#
- actor_lstm_dim: int = 512#
- num_epochs: int = 1000#
- training_steps_per_epoch: int = 400#
- batch_size: int = 32#
- environment_steps_per_epoch: int = 100#
- epsilon_greedy: float = 0.01#
- data_loader_num_workers: int = 4#
- pin_memory: bool = True#
- persistent_workers: bool = True#
- imagination_horizon: int = 15#
- discount_factor: float = 0.985#
- entropy_weight: float = 0.001#
- lambda_returns: float = 0.95#
- learning_rate: float = 0.0001#
- adam_epsilon: float = 1e-08#
- weight_decay_diffusion: float = 0.01#
- weight_decay_reward: float = 0.01#
- weight_decay_actor: float = 0.0#
- use_amp: bool = True#
- device: str#
- log_interval: int = 10#
- eval_interval: int = 50#
- save_interval: int = 100#
- operator_state_dim: int = 32#
- operator_action_dim: int = 4#
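The fields sigma_min, sigma_max, rho, and num_sampling_steps share their names with the hyperparameters of the EDM noise schedule (Karras et al.). Assuming they are used that way here, which this reference does not confirm, the noise levels visited during sampling would follow the formula sketched below. `karras_sigmas` is a hypothetical helper, not part of the package.

```python
def karras_sigmas(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7):
    """Noise levels from sigma_max down to sigma_min, spaced in sigma**(1/rho)."""
    if num_steps == 1:
        return [sigma_max]
    hi = sigma_max ** (1 / rho)
    lo = sigma_min ** (1 / rho)
    return [
        (hi + i / (num_steps - 1) * (lo - hi)) ** rho
        for i in range(num_steps)
    ]


# Defaults taken from the DiamondConfig signature: three sampling steps.
sigmas = karras_sigmas(3)
```

Larger rho values concentrate more steps near sigma_min, which matters most when the step count is as small as the default of 3.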
- class world_models.IRISConfig[source]#
Bases:
SerializableConfigMixin
Configuration for IRIS (Imagination with auto-Regression over an Inner Speech).
Based on the paper “Transformers are Sample-Efficient World Models”. Implements a discrete autoencoder plus an autoregressive Transformer for sample-efficient RL.
- class world_models.GenieConfig(num_frames=8, image_size=32, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=256, action_encoder_depth=4, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#
Bases:
SerializableConfigMixin
Configuration for Genie model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
tokenizer_num_heads (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_num_heads (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
sample_temperature (float)
maskgit_steps (int)
- num_frames: int = 8#
- image_size: int = 32#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 256#
- tokenizer_decoder_dim: int = 512#
- tokenizer_encoder_depth: int = 4#
- tokenizer_decoder_depth: int = 8#
- tokenizer_num_heads: int = 8#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 256#
- action_encoder_depth: int = 4#
- action_num_heads: int = 8#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 4#
- learning_rate: float = 3e-05#
- weight_decay: float = 0.0001#
- warmup_steps: int = 5000#
- max_steps: int = 125000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- sample_temperature: float = 2.0#
- maskgit_steps: int = 25#
- class world_models.GenieSmallConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=512, action_encoder_depth=8, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=2, learning_rate=0.0001, weight_decay=0.0001, warmup_steps=1000, max_steps=50000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#
Bases:
SerializableConfigMixin
Small configuration for development/testing.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
tokenizer_num_heads (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_num_heads (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
sample_temperature (float)
maskgit_steps (int)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 256#
- tokenizer_decoder_dim: int = 512#
- tokenizer_encoder_depth: int = 4#
- tokenizer_decoder_depth: int = 8#
- tokenizer_num_heads: int = 8#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 512#
- action_encoder_depth: int = 8#
- action_num_heads: int = 8#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 2#
- learning_rate: float = 0.0001#
- weight_decay: float = 0.0001#
- warmup_steps: int = 1000#
- max_steps: int = 50000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- sample_temperature: float = 2.0#
- maskgit_steps: int = 25#
- class world_models.STTransformerConfig(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases:
SerializableConfigMixin
Configuration for Spatiotemporal Transformer.
- Parameters:
num_frames (int)
num_patches_per_frame (int)
dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- num_frames: int = 16#
- num_patches_per_frame: int = 256#
- dim: int = 768#
- depth: int = 12#
- num_heads: int = 12#
- mlp_ratio: float = 4.0#
- qkv_bias: bool = True#
- drop_rate: float = 0.0#
- attn_drop_rate: float = 0.0#
- drop_path_rate: float = 0.0#
- class world_models.VideoTokenizerConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False, ema_decay=0.99, commitment_weight=0.25)[source]#
Bases:
SerializableConfigMixin
Configuration for Video Tokenizer.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
use_ema (bool)
ema_decay (float)
commitment_weight (float)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- encoder_dim: int = 512#
- decoder_dim: int = 1024#
- encoder_depth: int = 12#
- decoder_depth: int = 20#
- num_heads: int = 16#
- patch_size: int = 4#
- vocab_size: int = 1024#
- embedding_dim: int = 32#
- use_ema: bool = False#
- ema_decay: float = 0.99#
- commitment_weight: float = 0.25#
- class world_models.LatentActionModelConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=1024, encoder_depth=20, num_heads=16, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#
Bases:
SerializableConfigMixin
Configuration for Latent Action Model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
encoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- encoder_dim: int = 1024#
- encoder_depth: int = 20#
- num_heads: int = 16#
- patch_size: int = 16#
- vocab_size: int = 8#
- embedding_dim: int = 32#
- commitment_weight: float = 1.0#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- class world_models.DynamicsModelConfig(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases:
SerializableConfigMixin
Configuration for Dynamics Model.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- num_frames: int = 16#
- image_size: int = 64#
- vocab_size: int = 1024#
- embedding_dim: int = 32#
- action_vocab_size: int = 8#
- dim: int = 5120#
- depth: int = 48#
- num_heads: int = 36#
- patch_size: int = 4#
- mlp_ratio: float = 4.0#
- qkv_bias: bool = True#
- drop_rate: float = 0.0#
- attn_drop_rate: float = 0.0#
- drop_path_rate: float = 0.0#
- class world_models.OperatorABC(*, device=None)[source]#
Bases:
Module, ABC
Structured base class for inference operators.
Operators use a consistent pipeline:
preprocess converts raw inputs into tensors.
forward performs model/operator-specific tensor computation.
postprocess formats the final output mapping.
Subclasses may also declare input_specs and output_specs to validate required tensor keys, shapes, and dtypes.
OperatorABC inherits from torch.nn.Module, so operators support to(device), train(), and eval() just like model modules.
- Parameters:
device (torch.device | str | None)
- input_specs: Mapping[str, TensorSpec] = {}#
- output_specs: Mapping[str, TensorSpec] = {}#
- abstractmethod preprocess(inputs)[source]#
Convert raw inputs into a tensor mapping ready for forward.
- Parameters:
inputs (Any)
- Return type:
dict[str, Tensor]
- forward(inputs)[source]#
Run tensor computation for this operator.
Preprocessing-only operators can rely on this identity implementation. Operators that wrap a model should override this method.
- Parameters:
inputs (dict[str, Tensor])
- Return type:
dict[str, Tensor]
- postprocess(outputs)[source]#
Format validated forward outputs for consumers.
- Parameters:
outputs (dict[str, Tensor])
- Return type:
dict[str, Tensor]
- process(inputs)[source]#
Process raw inputs through preprocess, forward, and postprocess stages.
- Parameters:
inputs (Any)
- Return type:
dict[str, Tensor]
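The three-stage flow can be sketched without tensors. `MiniOperator` and `ScalePixels` are hypothetical, tensor-free stand-ins rather than the real classes, and the spec validation that the real operator may run between stages is omitted.

```python
from abc import ABC, abstractmethod


class MiniOperator(ABC):
    """Hypothetical stand-in for the preprocess -> forward -> postprocess flow."""

    @abstractmethod
    def preprocess(self, inputs):
        """Convert raw inputs into a mapping ready for forward."""

    def forward(self, inputs):
        # Identity by default, so preprocessing-only operators need no override.
        return inputs

    def postprocess(self, outputs):
        # Hook for formatting the final mapping; identity by default.
        return outputs

    def process(self, inputs):
        # Chain the three stages in order.
        return self.postprocess(self.forward(self.preprocess(inputs)))


class ScalePixels(MiniOperator):
    def preprocess(self, inputs):
        # Scale 8-bit pixel values into the range [0, 1].
        return {"image": [p / 255.0 for p in inputs["image"]]}


result = ScalePixels().process({"image": [0, 51, 255]})
```

Keeping forward as an identity default mirrors the documented behaviour that operators wrapping a model override it while preprocessing-only operators do not.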
- batch(inputs)[source]#
Preprocess a sequence of inputs and stack matching tensor keys.
- Parameters:
inputs (Sequence[Any])
- Return type:
dict[str, Tensor]
- to(*args, **kwargs)[source]#
Move module parameters/buffers and remember the target tensor device.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
- classmethod validate_mapping(values, specs, *, label)[source]#
Validate tensor keys, shapes, and dtypes against optional specs.
- Parameters:
values (Mapping[str, Tensor])
specs (Mapping[str, TensorSpec])
label (str)
- Return type:
None
- class world_models.TensorSpec(shape=None, dtype=None, required=True)[source]#
Bases:
object
Optional tensor contract used to validate operator inputs or outputs.
- Parameters:
shape (tuple[int | None, ...] | None) – Expected shape. Use None as a wildcard for dimensions that may vary, such as batch size.
dtype (dtype | None) – Expected tensor dtype.
required (bool) – Whether the key must be present in the mapping being validated.
- shape: tuple[int | None, ...] | None = None#
- dtype: dtype | None = None#
- required: bool = True#
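The wildcard rule for `shape` can be illustrated with plain tuples. `shape_matches` is a hypothetical helper written for this sketch; how `validate_mapping` actually performs the comparison is not shown in this reference.

```python
def shape_matches(actual, expected):
    """Compare a concrete shape with a spec where None matches any size."""
    if expected is None:
        # No shape constraint was declared at all.
        return True
    if len(actual) != len(expected):
        # The number of dimensions must agree.
        return False
    return all(want is None or want == got for got, want in zip(actual, expected))


image_spec = (None, 3, 64, 64)  # any batch size, fixed channels and resolution
ok = shape_matches((8, 3, 64, 64), image_spec)
bad_channels = shape_matches((8, 1, 64, 64), image_spec)
```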
- class world_models.DreamerOperator(image_size=64, action_dim=6)[source]#
Bases:
OperatorABC
Operator for Dreamer model preprocessing: normalizes observations and encodes actions.
- Parameters:
image_size (int)
action_dim (int)
- class world_models.JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)[source]#
Bases:
OperatorABC
Operator for JEPA model preprocessing: handles image/video masking and patch processing.
- Parameters:
image_size (int)
patch_size (int)
mask_ratio (float)
- class world_models.IrisOperator(seq_length=512, vocab_size=32000)[source]#
Bases:
OperatorABC
Operator for Iris transformer model: formats sequences and embeddings.
- Parameters:
seq_length (int)
vocab_size (int)
- class world_models.PlaNetOperator(state_dim=32, action_dim=4)[source]#
Bases:
OperatorABC
Operator for PlaNet model preprocessing: encodes environment states and transitions.
- Parameters:
state_dim (int)
action_dim (int)
- world_models.get_operator(name, **kwargs)[source]#
Factory function to get inference operators by name.
- Parameters:
name (str) – One of ‘dreamer’, ‘jepa’, ‘iris’, ‘planet’
**kwargs (Any) – Operator-specific configuration
- Returns:
Configured OperatorABC instance
- Return type:
OperatorABC
Example
>>> op = get_operator('dreamer', image_size=64, action_dim=6)
>>> processed = op.process({'image': image, 'action': action})
- class world_models.RewardModel(belief_size, state_size, hidden_size, activation_function='relu')[source]#
Bases:
Module
Predict scalar rewards from Dreamer latent belief and state vectors.
Implemented as an MLP used for model-based reward supervision and imagined rollout return estimation.
- Parameters:
belief_size (int)
state_size (int)
hidden_size (int)
activation_function (str)
- class world_models.ValueModel(belief_size, state_size, hidden_size, activation_function='relu')[source]#
Bases:
Module
Estimate scalar value from Dreamer latent belief and state vectors.
This MLP is trained on imagined returns and used for actor/value updates.
- Parameters:
belief_size (int)
state_size (int)
hidden_size (int)
activation_function (str)
- world_models.DreamerRewardModel#
alias of
RewardModel
- world_models.DreamerValueModel#
alias of
ValueModel
User-facing convenience APIs for TorchWM.
The lower-level modules remain available for research workflows, but this module collects the common discovery and construction paths behind small, predictable factory functions.
- class world_models.api.EnvBackendSpec(name, factory_path, description='', aliases=())[source]#
Bases:
NamedTuple
Metadata describing an environment backend available through make_env.
- Parameters:
name (str)
factory_path (str)
description (str)
aliases (tuple[str, ...])
- name: str#
Alias for field number 0
- factory_path: str#
Alias for field number 1
- description: str#
Alias for field number 2
- aliases: tuple[str, ...]#
Alias for field number 3
- class world_models.api.ModelSpec(name, import_path, config_path=None, description='', aliases=())[source]#
Bases:
NamedTuple
Metadata describing a model available through create_model().
- Parameters:
name (str)
import_path (str)
config_path (str | None)
description (str)
aliases (tuple[str, ...])
- name: str#
Alias for field number 0
- import_path: str#
Alias for field number 1
- config_path: str | None#
Alias for field number 2
- description: str#
Alias for field number 3
- aliases: tuple[str, ...]#
Alias for field number 4
- world_models.api.create_config(model, **overrides)[source]#
Create the default config object for model and apply overrides.
Examples
>>> cfg = create_config("dreamer", env="walker-walk", seed=7)
>>> cfg.env
'walker-walk'
- Parameters:
model (str)
overrides (Any)
- Return type:
Any
- world_models.api.create_model(model, config=None, **overrides)[source]#
Instantiate a model or agent from a simple string name.
config is optional for models that define a config class. Keyword overrides are applied to the config when possible; otherwise they are passed directly to the underlying constructor/factory.
Examples
>>> agent = create_model("dreamer", env="walker-walk", total_steps=1000)
>>> genie = create_model("genie-small", image_size=32)
- Parameters:
model (str)
config (Any | None)
overrides (Any)
- Return type:
Any
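The override rule described for create_model (apply a keyword to the config when it names a config field, otherwise forward it to the constructor) can be sketched with a dataclass. This is a minimal illustration under assumptions: ToyConfig, its fields, and split_overrides are hypothetical and are not part of TorchWM.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Dict, Tuple


@dataclass
class ToyConfig:
    """Hypothetical config; the field names are illustrative only."""

    env: str = "walker-walk"
    seed: int = 0
    total_steps: int = 100_000


def split_overrides(
    config: ToyConfig, overrides: Dict[str, Any]
) -> Tuple[ToyConfig, Dict[str, Any]]:
    """Apply overrides that name config fields; return the rest untouched."""
    field_names = {f.name for f in fields(config)}
    config_updates = {k: v for k, v in overrides.items() if k in field_names}
    passthrough = {k: v for k, v in overrides.items() if k not in field_names}
    # dataclasses.replace returns a new object, so the defaults are not mutated.
    return replace(config, **config_updates), passthrough


cfg, extra = split_overrides(
    ToyConfig(), {"env": "cheetah-run", "total_steps": 1000, "image_size": 32}
)
```

Here env and total_steps land on the config, while image_size is left over for the constructor.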
- world_models.api.get_env_backend_spec(name)[source]#
Return metadata for an environment backend name or alias.
- Parameters:
name (str)
- Return type:
EnvBackendSpec
- world_models.api.get_model_spec(name)[source]#
Return metadata for a model name or alias.
- Parameters:
name (str)
- Return type:
ModelSpec
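Resolving "a model name or alias" to one spec record is often done with a lookup table keyed by both canonical names and aliases. The sketch below mirrors the documented ModelSpec fields but is otherwise hypothetical: the entries, import paths, and case-insensitive matching are assumptions, not TorchWM behaviour.

```python
from typing import Dict, List, NamedTuple, Optional, Tuple


class ToyModelSpec(NamedTuple):
    """Same field order as the documented ModelSpec; the entries below are made up."""

    name: str
    import_path: str
    config_path: Optional[str] = None
    description: str = ""
    aliases: Tuple[str, ...] = ()


_SPECS = (
    ToyModelSpec("dreamer", "example_pkg.dreamer:Agent", aliases=("dreamer-v1",)),
    ToyModelSpec("planet", "example_pkg.planet:Agent", aliases=("pla-net",)),
)

# One table keyed by canonical names and aliases alike.
_LOOKUP: Dict[str, ToyModelSpec] = {}
for _spec in _SPECS:
    for _key in (_spec.name, *_spec.aliases):
        _LOOKUP[_key.lower()] = _spec


def get_toy_model_spec(name: str) -> ToyModelSpec:
    """Return the spec for a canonical name or alias (case-insensitive here)."""
    key = name.strip().lower()
    if key not in _LOOKUP:
        raise KeyError(f"Unknown model {name!r}")
    return _LOOKUP[key]


def list_toy_models() -> List[str]:
    # Only canonical names are listed; aliases resolve but are not advertised.
    return sorted(spec.name for spec in _SPECS)
```

Because the spec is a NamedTuple, fields are reachable both by name (spec.name) and by position (spec[0]), matching the "Alias for field number" entries above.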
- world_models.api.list_env_backends()[source]#
Return canonical backend names accepted by make_env().
- Return type:
list[str]
- world_models.api.list_envs(model=None)[source]#
List known environment ids, optionally filtered by model family.
- Parameters:
model (str | None)
- Return type:
list[str] | dict[str, list[str]]
- world_models.api.list_models()[source]#
Return canonical model names accepted by create_model().
- Return type:
list[str]
- world_models.api.make_env(env_id, backend='auto', **kwargs)[source]#
Create an environment with a consistent TorchWM entry point.
- Parameters:
env_id (str) – Environment id, XML path, Unity executable path, or backend-specific id.
backend (str) – One of list_env_backends(); "auto" tries TorchWM’s compatibility helper.
**kwargs (Any) – Backend-specific options.
- Return type:
Any
Export utilities for production deployment.
The public entry point is obj.export(path, format="onnx"). Importing this
module installs that method on every torch.nn.Module once, so all TorchWM
models get ONNX, TorchScript, and TensorRT export support without each model
subclassing a TorchWM-specific base class. Non-nn.Module agent wrappers can
inherit ExportableAgentMixin, which uses the same resolver and exporter.
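Installing a method on a base class "once" amounts to an idempotent monkey-patch. The sketch below shows that idea without PyTorch: ToyModule stands in for torch.nn.Module, and install_toy_export_method, the _toy_export_installed flag, and the returned path are hypothetical illustrations, not the TorchWM exporter.

```python
from pathlib import Path


class ToyModule:
    """Stand-in for torch.nn.Module so the sketch runs without PyTorch."""


def _export(self, path, format="onnx"):
    # A real exporter would serialize the model; this only builds the target path.
    return Path(path).with_suffix("." + format)


def install_toy_export_method(base=ToyModule):
    """Attach `export` to `base` unless it is already installed."""
    if getattr(base, "_toy_export_installed", False):
        return False  # already patched; repeated calls do nothing
    base.export = _export
    base._toy_export_installed = True
    return True


class ToyPolicy(ToyModule):
    pass


first = install_toy_export_method()
second = install_toy_export_method()
# Subclasses pick up the method through normal attribute lookup.
target = ToyPolicy().export("policy", format="onnx")
```

Patching the base class means subclasses need no TorchWM-specific parent, which is the design goal stated above.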
- class world_models.export.DreamerPolicyExport(actor)[source]#
Bases:
Module
Traceable Dreamer policy head used by the generic export resolver.
- Parameters:
actor (nn.Module)
- class world_models.export.ExportableAgentMixin[source]#
Bases:
object
Mixin for non-nn.Module agents that delegates to the shared exporter.
- export(path, format='onnx', *, example_inputs=None, target=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)[source]#
Export this agent or one of its contained modules for deployment.
- Parameters:
path (str | Path)
format (str)
example_inputs (Any | None)
target (str | None)
input_names (list[str] | None)
output_names (list[str] | None)
dynamic_axes (dict[str, dict[int, str]] | None)
opset_version (int)
kwargs (Any)
- Return type:
Path
- class world_models.export.IRISActorCriticExport(agent)[source]#
Bases:
Module
Traceable IRIS policy/value head used by the generic export resolver.
- Parameters:
agent (Any)
- world_models.export.export_any(obj, path, format='onnx', *, example_inputs=None, target=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)[source]#
Export any TorchWM model/agent or a target module contained by it.
- Parameters:
obj (Any)
path (str | Path)
format (str)
example_inputs (Any | None)
target (str | None)
input_names (list[str] | None)
output_names (list[str] | None)
dynamic_axes (dict[str, dict[int, str]] | None)
opset_version (int)
kwargs (Any)
- Return type:
Path
- world_models.export.export_model(module, path, format='onnx', *, example_inputs=None, input_names=None, output_names=None, dynamic_axes=None, opset_version=17, **kwargs)[source]#
Export a torch.nn.Module to ONNX, TorchScript, or TensorRT.
- Parameters:
module (Module)
path (str | Path)
format (str)
example_inputs (Any | None)
input_names (list[str] | None)
output_names (list[str] | None)
dynamic_axes (dict[str, dict[int, str]] | None)
opset_version (int)
kwargs (Any)
- Return type:
Path
- world_models.export.install_export_method()[source]#
Install torch.nn.Module.export once for every Torch model class.
- Return type:
None
Models sub-module - Core world model implementations.
- Exported Components:
- Agents (High-level training wrappers):
DreamerAgent: High-level Dreamer training API
JEPAAgent: JEPA agent for self-supervised learning
Planet: PlaNet planning agent
VisionTransformer: Vision Transformer for image encoding
ModularRSSM: Modular RSSM with swappable components
Genie: Generative Interactive Environment model
- Core Models:
Dreamer: Core Dreamer implementation with RSSM, actor, critic
RSSM: Recurrent State-Space Model (Dreamer-style)
RecurrentStateSpaceModel: PlaNet-style RSSM
LatentActionModel: Latent action learning for Genie
DynamicsModel: Future frame prediction for Genie
- Factory Functions:
create_genie, create_genie_small, create_genie_large
create_modular_rssm
create_latent_action_model, create_dynamics_model
Small, import-safe catalog of available environments and backends.
This module replaces the previous world_models.ui.catalog and is safe to import from lightweight CLI tools and tests without pulling in any UI dependencies.
Model catalog#
Core model families#
Key classes: Dreamer, DreamerAgent, RSSM, RecurrentStateSpaceModel, Planet, ModularRSSM, JEPAAgent, VisionTransformer, IRISAgent, IRISTransformer, IRISWorldModel, Genie, LatentActionModel, and DynamicsModel.
- world_models.models.dreamer.get_available_memory()[source]#
Get available physical memory in bytes.
- Return type:
int
- world_models.models.dreamer.make_env(args)[source]#
Construct a Dreamer-compatible environment from DreamerConfig options.
Supports DMC, DMLab, Gym/Gymnasium, MuJoCo, Gymnasium Robotics, Procgen, Brax, BSuite, and Unity ML-Agents backends and applies the standard wrapper stack: action repeat, action normalization, and time limit.
- Parameters:
args (Any)
- Return type:
Any
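The wrapper stack mentioned for make_env can be pictured as nested objects that each intercept step(). The sketch below covers action repeat and a time limit only (action normalization is omitted). ToyEnv and both wrapper classes are hypothetical, and the three-value step() return is a simplification rather than the Gym or TorchWM interface.

```python
class ToyEnv:
    """Hypothetical environment: reward 1.0 per step, never terminates on its own."""

    def reset(self):
        self.t = 0
        return 0.0

    def step(self, action):
        self.t += 1
        return float(self.t), 1.0, False  # observation, reward, done


class ActionRepeat:
    """Repeat each action `amount` times and sum the rewards."""

    def __init__(self, env, amount):
        self.env, self.amount = env, amount

    def reset(self):
        return self.env.reset()

    def step(self, action):
        total, done, obs = 0.0, False, None
        for _ in range(self.amount):
            obs, reward, done = self.env.step(action)
            total += reward
            if done:
                break  # stop repeating once the episode has ended
        return obs, total, done


class TimeLimit:
    """End the episode after `duration` wrapper-level steps."""

    def __init__(self, env, duration):
        self.env, self.duration = env, duration
        self.steps = 0

    def reset(self):
        self.steps = 0
        return self.env.reset()

    def step(self, action):
        obs, reward, done = self.env.step(action)
        self.steps += 1
        return obs, reward, done or self.steps >= self.duration


env = TimeLimit(ActionRepeat(ToyEnv(), amount=2), duration=3)
env.reset()
rewards, done = [], False
while not done:
    _, reward, done = env.step(action=0)
    rewards.append(reward)
```

With this nesting order the time limit counts agent decisions, so three decisions correspond to six underlying environment steps.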
- world_models.models.dreamer.preprocess_obs(obs)[source]#
Convert raw uint8 image observations to Dreamer float input space.
Images are scaled from [0, 255] to roughly [-0.5, 0.5], matching the normalization expected by Dreamer encoders.
- Parameters:
obs (Tensor)
- Return type:
Tensor
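The scaling described for preprocess_obs is a one-line affine map. The sketch below applies it to plain Python integers to make the arithmetic visible. The exact formula (divide by 255, subtract 0.5) is an assumption consistent with the stated range, and preprocess_pixels is a hypothetical helper, not the library function.

```python
def preprocess_pixels(pixels):
    """Map integer pixel values in [0, 255] to floats in [-0.5, 0.5]."""
    return [p / 255.0 - 0.5 for p in pixels]


# Darkest, mid-gray, and brightest pixel values.
scaled = preprocess_pixels([0, 128, 255])
```

Centering the inputs around zero keeps the encoder's first-layer activations balanced between positive and negative values.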
- class world_models.models.dreamer.Dreamer(args, obs_shape, action_size, device, restore=False)[source]#
Bases:
object
Core Dreamer training system combining world model, actor, and value nets.
This class owns model construction, replay sampling, imagination rollouts, loss computation, optimization steps, evaluation loops, and checkpoint I/O.
- Parameters:
args (Any)
obs_shape (Any)
action_size (int)
device (device | str)
restore (bool)
- classmethod from_config(config=None, *, obs_shape=None, action_size=None, device=None, restore=None, **overrides)[source]#
Build a core Dreamer model from a config object, dict, or YAML file.
obs_shape and action_size may be supplied directly. When either is omitted, this method constructs a temporary environment from the config to infer the model shapes.
- Parameters:
config (DreamerConfig | dict[str, Any] | str | Path | None)
obs_shape (tuple[int, ...] | None)
action_size (int | None)
device (str | device | None)
restore (bool | None)
overrides (Any)
- Return type:
- classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, map_location=None, **overrides)[source]#
Load a Dreamer checkpoint from a local path/directory or the HF Hub.
- Parameters:
pretrained_model_name_or_path (str | Path)
config (DreamerConfig | dict[str, Any] | str | Path | None)
checkpoint_filename (str | None)
config_filename (str)
repo_type (str | None)
revision (str | None)
map_location (str | device | None)
overrides (Any)
- Return type:
- parameter_count(trainable_only=False)[source]#
Return the total number of parameters owned by the Dreamer modules.
- Parameters:
trainable_only (bool)
- Return type:
int
- summary()[source]#
Return a compact parameter-count summary for the Dreamer modules.
- Return type:
dict[str, Any]
- world_model_loss(obs, acs, rews, nonterms)[source]#
- Parameters:
obs (Tensor)
acs (Tensor)
rews (Tensor)
nonterms (Tensor)
- Return type:
Tensor
- act_with_world_model(obs, prev_state, prev_action, explore=False)[source]#
- Parameters:
obs (Any)
prev_state (Any)
prev_action (Tensor)
explore (bool)
- Return type:
tuple
- act_and_collect_data(env, collect_steps)[source]#
- Parameters:
env (Any)
collect_steps (int)
- Return type:
ndarray
- evaluate(env, eval_episodes, render=False)[source]#
- Parameters:
env (Any)
eval_episodes (int)
render (bool)
- Return type:
tuple
- class world_models.models.dreamer.DreamerAgent(config=None, **kwargs)[source]#
Bases:
ExportableAgentMixin
High-level user API for running Dreamer experiments end to end.
It builds environments from config, initializes seeds and logging, instantiates Dreamer, and exposes simple train() / evaluate() methods.
- Parameters:
config (Any)
kwargs (Any)
- classmethod from_config(config=None, **overrides)[source]#
Build a high-level Dreamer agent from a config object, dict, or YAML file.
- Parameters:
config (DreamerConfig | dict[str, Any] | str | Path | None)
overrides (Any)
- Return type:
- classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, map_location=None, **overrides)[source]#
Create a Dreamer agent and restore weights from a local path or HF Hub.
- Parameters:
pretrained_model_name_or_path (str | Path)
config (DreamerConfig | dict[str, Any] | str | Path | None)
checkpoint_filename (str | None)
config_filename (str)
repo_type (str | None)
revision (str | None)
map_location (str | device | None)
overrides (Any)
- Return type:
- parameter_count(trainable_only=False)[source]#
Return the total number of Dreamer parameters.
- Parameters:
trainable_only (bool)
- Return type:
int
- class world_models.models.dreamer_rssm.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)[source]#
Bases:
Module
Recurrent State-Space Model used by Dreamer for latent dynamics learning.
The RSSM is the core world model component that learns compact representations of environment dynamics. It maintains a hybrid state consisting of:
Deterministic State (h) – A recurrent hidden state updated by a GRU, capturing sequential/temporal information and deterministic transitions.
Stochastic State (s) – A latent variable representing stochastic, multi-modal uncertainty in the environment (e.g., ambiguous observations).
The model operates in two modes:
Observe Mode – Updates states using actual observations from the environment. Uses the representation model: p(s_t | h_t, obs_t)
Imagine Mode – Predicts future states without observations. Uses the transition/prior model: p(s_t | h_t)
Architecture
Input: Previous state (h_{t-1}, s_{t-1}) and action a_{t-1}
Process: GRU updates deterministic state, MLP computes stochastic prior/posterior
Output: Updated state (h_t, s_t) and distributions
State Representation
deter (h): GRU hidden state, captures sequential context
stoch (s): Stochastic latent, multi-modal uncertainty
mean/std: Parameters of the stochastic distribution
Usage with DreamerAgent:
rssm = RSSM(
    action_size=action_dim,
    stoch_size=30,       # Stochastic state dimension
    deter_size=200,      # Deterministic (GRU) state dimension
    hidden_size=200,     # MLP hidden layer size
    obs_embed_size=256,  # Observation embedding from encoder
    activation='elu'
)
# Observe with actual observation; observe_step returns (posterior, prior)
posterior, prior = rssm.observe_step(prev_state, prev_action, obs_embed)
# Imagine future without observation
imagined = rssm.imagine_step(current_state, action)
Training
The RSSM is trained by maximizing the ELBO (Evidence Lower Bound):
- KL divergence between prior and posterior encourages the prior to capture environment dynamics
- Reconstruction loss from decoder ensures state captures observation info
- Reference:
Dreamer: Scalable Reinforcement Learning Using World Models Hafner et al., 2020 - https://arxiv.org/abs/1912.01603
- Parameters:
action_size (int)
stoch_size (int)
deter_size (int)
hidden_size (int)
obs_embed_size (int)
activation (str)
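The KL term in the training objective above has a closed form when both the prior and posterior are diagonal Gaussians parameterized by mean and std. The sketch below evaluates that formula in plain Python. It assumes the posterior-to-prior direction KL(q ‖ p); kl_diag_gaussians is a hypothetical helper, and the library presumably computes this with torch.distributions.

```python
import math


def kl_diag_gaussians(mean_q, std_q, mean_p, std_p):
    """KL(q || p) for diagonal Gaussians, summed over dimensions.

    Per dimension: log(sp / sq) + (sq^2 + (mq - mp)^2) / (2 * sp^2) - 1/2
    """
    total = 0.0
    for mq, sq, mp, sp in zip(mean_q, std_q, mean_p, std_p):
        total += math.log(sp / sq) + (sq**2 + (mq - mp) ** 2) / (2.0 * sp**2) - 0.5
    return total


# Posterior shifted by one unit in the first dimension, identical elsewhere.
kl_shifted = kl_diag_gaussians([1.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0])
# Identical distributions have zero divergence.
kl_same = kl_diag_gaussians([0.3], [0.7], [0.3], [0.7])
```

The divergence grows as the prior's prediction drifts from the observation-informed posterior, which is what pushes the prior to model the dynamics.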
- init_state(batch_size, device)[source]#
Initialize RSSM state with zeros.
- Parameters:
batch_size (int) – Number of parallel sequences
device (device) – torch device for tensors
- Returns:
Dictionary containing zero-initialized state components:
mean, std: Stochastic distribution parameters
stoch: Stochastic state sample
deter: Deterministic GRU hidden state
- Return type:
dict
- get_dist(mean, std)[source]#
Create an Independent Normal distribution from mean and std.
- Parameters:
mean (Tensor) – Location parameter
std (Tensor) – Scale parameter
- Returns:
Independent Normal distribution with given parameters
- Return type:
Independent
- observe_step(prev_state, prev_action, obs_embed, nonterm=tensor(1.))[source]#
Update state using actual observation (observe mode).
In observe mode, the RSSM first computes a transition prior from the previous state and action, then refines the stochastic state using the actual observation embedding to form the posterior.
- Parameters:
prev_state (dict) – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})
prev_action (Tensor) – Previous action a_{t-1}, shape (B, action_size)
obs_embed (Tensor) – Observation embedding from encoder, shape (B, obs_embed_size)
nonterm (Tensor) – Termination mask (1.0 = continue, 0.0 = terminal)
- Returns:
A tuple (posterior, prior) of state dictionaries. The posterior incorporates observation information; the prior is the transition prediction before observation. Both share the same deterministic state because the GRU is only advanced once per timestep.
- Return type:
Tuple[dict, dict]
- imagine_step(prev_state, prev_action, nonterm=tensor(1.))[source]#
Predict next state without observation (imagine mode).
In imagine mode, the RSSM predicts future states using only the prior distribution. This is used for planning and policy learning where actual observations are not available.
- Parameters:
prev_state (dict) – Dictionary with ‘deter’ (h_{t-1}) and ‘stoch’ (s_{t-1})
prev_action (Tensor) – Previous action a_{t-1}, shape (B, action_size)
nonterm (Tensor) – Termination mask (1.0 = continue, 0.0 = terminal)
- Returns:
Dictionary with predicted state containing:
deter: Predicted deterministic state
mean, std, stoch: Prior stochastic state distribution
- Return type:
dict
- get_prior(prev_state, prev_action, nonterm=tensor(1.))[source]#
- Parameters:
prev_state (dict)
prev_action (Tensor)
nonterm (Tensor)
- Return type:
dict
- get_posterior(prev_state, prev_action, obs_embed, nonterm=tensor(1.))[source]#
Compute posterior distribution over stochastic state.
The posterior incorporates observation information to produce a more accurate state estimate.
- Parameters:
prev_state (dict) – Previous state dictionary
prev_action (Tensor) – Previous action
obs_embed (Tensor) – Observation embedding
nonterm (Tensor) – Termination mask
- Returns:
Dictionary with posterior state (observation-informed). Note that the previous-state shape (B, ...) is preserved; the batch dimension is not flattened.
- Return type:
dict
- detach_state(state)[source]#
Detach state tensors from computation graph.
Used during DreamerV2 training to prevent gradient flow through the observation/update pathway.
- Parameters:
state (dict) – State dictionary with tensor values
- Returns:
Detached state dictionary
- Return type:
dict
- seq_to_batch(state_dict)[source]#
Convert sequence state to batch format.
- Parameters:
state_dict (dict) – Dictionary with sequence-dimension tensors (T, B, …)
- Returns:
Dictionary with batch-dimension tensors (B*T, …)
- Return type:
dict
- observe_rollout(obs_embed, actions, nonterms, init_state, seq_len)[source]#
Process a sequence of observations (observe mode rollout).
At each timestep we run observe_step once to obtain the transition prior (the prediction given the previous state and action) and the observation-informed posterior. The posterior is then used as the previous state for the next step, matching the standard Dreamer inference pattern.
- Parameters:
obs_embed (Tensor) – Observation embeddings, shape (T+1, B, obs_embed_size)
actions (Tensor) – Actions, shape (T, B, action_size)
nonterms (Tensor) – Non-termination flags, shape (T, B, 1)
init_state (dict) – Initial state dictionary
seq_len (int) – Sequence length T
- Returns:
prior: Dictionary with prior states stacked along the time axis
posterior: Dictionary with posterior states stacked along the time axis
- imagine_rollout(policy, init_state, horizon)[source]#
Generate imagined trajectory using policy (imagine mode rollout).
- Parameters:
policy (Module) – Actor network that outputs actions from state features
init_state (dict) – Initial state dictionary
horizon (int) – Number of steps to imagine
- Returns:
Dictionary with imagined states for each step
- Return type:
dict
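An imagine-mode rollout is a loop in which the policy picks an action from the current latent state and the prior transition produces the next state, with no observations involved. The sketch below shows that control flow on scalar toy states. toy_imagine_step, its 0.9/0.1 dynamics, and the constant policy are hypothetical stand-ins for the learned RSSM prior and actor.

```python
from typing import Callable, Dict, List

State = Dict[str, float]


def toy_imagine_step(prev_state: State, action: float) -> State:
    """Hypothetical prior transition; uses the prior mean instead of sampling."""
    deter = 0.9 * prev_state["deter"] + 0.1 * action
    return {"deter": deter, "stoch": deter}


def toy_imagine_rollout(
    policy: Callable[[State], float], init_state: State, horizon: int
) -> Dict[str, List[float]]:
    """Roll the prior forward `horizon` steps, collecting each imagined state."""
    state = init_state
    trajectory: Dict[str, List[float]] = {"deter": [], "stoch": []}
    for _ in range(horizon):
        action = policy(state)  # the actor sees only latent state
        state = toy_imagine_step(state, action)  # no observation is consumed
        for key in trajectory:
            trajectory[key].append(state[key])
    return trajectory


traj = toy_imagine_rollout(lambda s: 1.0, {"deter": 0.0, "stoch": 0.0}, horizon=3)
```

Each step feeds the previous imagined state back in, so prediction errors compound with the horizon; that is one reason imagination horizons are usually kept short.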
- forward(x, u)[source]#
Forward pass for training (computes sequence of states).
- Parameters:
x (Tensor) – Observations, shape (B, T+1, C, H, W)
u (Tensor) – Actions, shape (B, T, action_size)
- Returns:
states: List of state dictionaries for each timestep
priors: List of prior distributions (tuples of mean, std)
posteriors: List of posterior distributions (tuples of mean, std)
- class world_models.models.rssm.RecurrentStateSpaceModel(action_size, state_size=200, latent_size=30, hidden_size=200, embed_size=1024, activation_function='relu')[source]#
Bases:
Module
A Recurrent State Space Model (RSSM) for modeling latent dynamics in sequential data.
- Parameters:
action_size (int)
state_size (int)
latent_size (int)
hidden_size (int)
embed_size (int)
activation_function (str)
- get_init_state(enc, h_t=None, s_t=None, a_t=None, mean=False)[source]#
Returns the initial posterior given the observation.
- Parameters:
enc (Tensor)
h_t (Tensor | None)
s_t (Tensor | None)
a_t (Tensor | None)
mean (bool)
- Return type:
tuple[Tensor, Tensor]
- deterministic_state_fwd(h_t, s_t, a_t)[source]#
Deterministic transition update.
Ensures a_t is 2D and matches batch dimension of h_t before concatenation. Accepts a_t shaped [B, action_size], [action_size] (expanded to [B, action_size]), or [B]/scalar (reshaped appropriately).
- Parameters:
h_t (Tensor)
s_t (Tensor)
a_t (Tensor)
- Return type:
Tensor
- state_prior(h_t, sample=False)[source]#
Returns the prior distribution over the latent state given the deterministic state
- Parameters:
h_t (Tensor)
sample (bool)
- Return type:
tuple[Tensor, Tensor] | Tensor
- state_posterior(h_t, e_t, sample=False)[source]#
Returns the posterior distribution over the latent state given the deterministic state and the observation embedding.
- Parameters:
h_t (Tensor)
e_t (Tensor)
sample (bool)
- Return type:
tuple[Tensor, Tensor] | Tensor
- rollout_prior(act, h_t, s_t)[source]#
- Parameters:
act (Tensor)
h_t (Tensor)
s_t (Tensor)
- Return type:
tuple[Tensor, Tensor]
- forward(x, u)[source]#
Forward through the RSSM for a batch of sequences.
- Parameters:
x (Tensor) – Tensor [B, T+1, C, H, W] (observations including initial frame)
u (Tensor) – Tensor [B, T, action_size] (actions for T steps)
- Returns:
states: list[T] of tensors [B, state_size]
priors: list[T] of tuples (mean, std) each [B, latent_size]
posteriors: list[T] of tuples (mean, std) each [B, latent_size]
- class world_models.models.planet.Planet(env, bit_depth=5, device=None, state_size=200, latent_size=30, embedding_size=1024, memory_size=100, policy_cfg=None, headless=False, max_episode_steps=None, action_repeats=1, results_dir=None)[source]#
Bases:
ExportableAgentMixin
High-level Planet wrapper.
- Usage example:
from world_models.models.planet import Planet
p = Planet(env='CartPole-v1', bit_depth=5)
p.train(epochs=50)
- Parameters:
env (Any)
bit_depth (int)
device (device | None)
state_size (int)
latent_size (int)
embedding_size (int)
memory_size (int)
policy_cfg (dict | None)
headless (bool)
max_episode_steps (int | None)
action_repeats (int)
results_dir (str | None)
- warmup(n_episodes=1, random_policy=True)[source]#
Collect n_episodes of rollouts into memory (used as warmup).
- Parameters:
n_episodes (int)
random_policy (bool)
- Return type:
None
- train(epochs=100, steps_per_epoch=150, batch_size=32, H=50, beta=1.0, save_every=25, record_grads=False, results_dir=None, scheduler_type='step', scheduler_kwargs=None)[source]#
High-level training loop. Delegates single-step training to the existing train function.
- Parameters:
scheduler_type (str) – Type of scheduler to use (“step”, “cosine”, “exponential”, “plateau”, None)
scheduler_kwargs (dict) – Additional arguments for the scheduler
epochs (int)
steps_per_epoch (int)
batch_size (int)
H (int)
beta (float)
save_every (int)
record_grads (bool)
results_dir (str | None)
- Return type:
str
Mixture Density Recurrent Neural Network (MDRNN) model implementation.
This module provides implementations of MDRNN models for world modeling. The MDRNN is used to predict future latent states given current latent states and actions, using a Gaussian Mixture Model (GMM) for the output.
- Reference:
Ha & Schmidhuber (2018). Recurrent World Models Facilitate Policy Evolution. https://arxiv.org/abs/1809.01999
- class world_models.models.mdrnn.MDRNN(latents, actions, hiddens, gaussians)[source]#
Bases:
_MDRNNBase
MDRNN model for multi-step sequence prediction.
This model processes entire sequences of latent states and actions, predicting the next latent state using a Gaussian Mixture Model (GMM). It also predicts rewards and terminal states.
- Parameters:
latents (int) – Dimensionality of latent space (input and output).
actions (int) – Dimensionality of action space.
hiddens (int) – Number of hidden units in LSTM.
gaussians (int) – Number of Gaussian components in GMM output.
Example
>>> mdrnn = MDRNN(latents=32, actions=3, hiddens=256, gaussians=5)
>>> actions = torch.randn(10, 4, 3)  # seq_len, batch, action
>>> latents = torch.randn(10, 4, 32)  # seq_len, batch, latent
>>> mus, sigmas, logpi, rs, ds = mdrnn(actions, latents)
>>> # mus.shape = (10, 4, 5, 32)
- forward(actions, latents)[source]#
Multi-step forward pass through the MDRNN.
- Parameters:
actions (Tensor) – (SEQ_LEN, BSIZE, ASIZE) Tensor of actions.
latents (Tensor) – (SEQ_LEN, BSIZE, LSIZE) Tensor of latent states.
- Returns:
Tuple of:
mus: (SEQ_LEN, BSIZE, N_GAUSS, LSIZE) GMM means
sigmas: (SEQ_LEN, BSIZE, N_GAUSS, LSIZE) GMM standard deviations
logpi: (SEQ_LEN, BSIZE, N_GAUSS) log GMM weights
rs: (SEQ_LEN, BSIZE) predicted rewards
ds: (SEQ_LEN, BSIZE) predicted terminal state logits
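The mus, sigmas, and logpi outputs together define a Gaussian mixture over the next latent, and such a model is typically trained on the negative log-likelihood of the observed next latent. The sketch below computes that quantity for a single timestep and batch element in plain Python. It assumes independent dimensions within each component; gmm_nll is a hypothetical helper, not the library's loss function.

```python
import math


def gmm_nll(target, mus, sigmas, logpi):
    """Negative log-likelihood of `target` under a diagonal Gaussian mixture.

    target: list of LSIZE floats
    mus, sigmas: N_GAUSS lists of LSIZE floats
    logpi: N_GAUSS log mixture weights (should log-sum-exp to 0)
    """
    log_terms = []
    for mu_k, sigma_k, logpi_k in zip(mus, sigmas, logpi):
        # Log density of one component: sum of per-dimension normal log-probs.
        log_prob = sum(
            -0.5 * ((z - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2.0 * math.pi)
            for z, m, s in zip(target, mu_k, sigma_k)
        )
        log_terms.append(logpi_k + log_prob)
    # Log-sum-exp with the maximum subtracted for numerical stability.
    peak = max(log_terms)
    log_likelihood = peak + math.log(sum(math.exp(t - peak) for t in log_terms))
    return -log_likelihood


# One standard-normal component evaluated at its mean.
nll_single = gmm_nll([0.0], [[0.0]], [[1.0]], [0.0])
# Two identical components with equal weights describe the same density.
nll_pair = gmm_nll([0.0], [[0.0], [0.0]], [[1.0], [1.0]], [math.log(0.5)] * 2)
```

Working with log weights and log-sum-exp avoids underflow when component densities are tiny, which is why the model outputs logpi rather than raw weights.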
Return initial hidden state for the LSTM.
- Parameters:
batch_size (int) – Number of sequences in the batch.
- Returns:
Tuple of (h, c) with shapes (batch_size, hiddens).
- Return type:
tuple[Tensor, Tensor]
- class world_models.models.mdrnn.MDRNNCell(latents, actions, hiddens, gaussians)[source]#
Bases:
_MDRNNBase
MDRNN model for single-step forward prediction.
This model processes a single step of latent state and action, predicting the next latent state using a Gaussian Mixture Model (GMM). It also predicts rewards and terminal states. Useful for real-time inference.
- Parameters:
latents (int) – Dimensionality of latent space (input and output).
actions (int) – Dimensionality of action space.
hiddens (int) – Number of hidden units in LSTMCell.
gaussians (int) – Number of Gaussian components in GMM output.
Example
>>> cell = MDRNNCell(latents=32, actions=3, hiddens=256, gaussians=5)
>>> action = torch.randn(4, 3)  # batch, action
>>> latent = torch.randn(4, 32)  # batch, latent
>>> hidden = (torch.randn(4, 256), torch.randn(4, 256))
>>> mus, sigmas, logpi, r, d, next_hidden = cell(action, latent, hidden)
- forward(action, latent, hidden)[source]#
Single-step forward pass through the MDRNN cell.
- Parameters:
action (Tensor) – (BSIZE, ASIZE) Tensor of actions for current batch.
latent (Tensor) – (BSIZE, LSIZE) Tensor of latent states for current batch.
hidden (Tuple[Tensor, Tensor]) – Tuple of (h, c) hidden states for LSTMCell.
- Returns:
Tuple of:
mus: (BSIZE, N_GAUSS, LSIZE) GMM means
sigmas: (BSIZE, N_GAUSS, LSIZE) GMM standard deviations
logpi: (BSIZE, N_GAUSS) log GMM weights
r: (BSIZE,) predicted rewards
d: (BSIZE,) predicted terminal state logits
next_hidden: Tuple of (h, c) next hidden states
- Parameters:
batch_size (int)
- Return type:
Tuple[Tensor, Tensor]
Linear Controller for World Models.
This module provides a simple linear controller that maps latent states and recurrent hidden states to actions. The controller is trained using CMA-ES (Covariance Matrix Adaptation Evolution Strategy).
- Reference:
Ha & Schmidhuber (2018). Recurrent World Models Facilitate Policy Evolution. https://arxiv.org/abs/1809.01999
- class world_models.models.controller.Controller(latent_size, hidden_size, action_size)[source]#
Bases:
Module
Linear controller that maps latent + hidden state to actions.
This is a simple linear controller that takes the latent state and recurrent hidden state as input and outputs actions. It is trained separately from the world model using black-box optimization (CMA-ES).
- Parameters:
latent_size (int)
hidden_size (int)
action_size (int)
- latent_size#
Dimensionality of latent state from VAE.
- hidden_size#
Dimensionality of RSSM hidden state.
- action_size#
Dimensionality of action space.
Example
>>> controller = Controller(latent_size=32, hidden_size=200, action_size=3)
>>> state = torch.cat([latent, hidden], dim=-1)
>>> action = controller(state)
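Black-box optimizers such as CMA-ES work on a flat parameter vector, so a linear controller needs a way to load its weights and bias from one. The sketch below shows that mapping in plain Python. ToyLinearController, load_flat, and the weights-then-bias layout are hypothetical and do not describe the Controller class, and no CMA-ES step is implemented.

```python
from typing import List, Sequence


class ToyLinearController:
    """action = W @ concat(latent, hidden) + b, stored as plain Python lists."""

    def __init__(self, latent_size: int, hidden_size: int, action_size: int) -> None:
        self.in_size = latent_size + hidden_size
        self.action_size = action_size
        self.weights = [[0.0] * self.in_size for _ in range(action_size)]
        self.bias = [0.0] * action_size

    def num_params(self) -> int:
        return self.action_size * (self.in_size + 1)

    def load_flat(self, flat: Sequence[float]) -> None:
        """Unpack a candidate vector: all weight rows first, then the bias."""
        if len(flat) != self.num_params():
            raise ValueError(f"expected {self.num_params()} values, got {len(flat)}")
        n = self.in_size
        for i in range(self.action_size):
            self.weights[i] = list(flat[i * n : (i + 1) * n])
        self.bias = list(flat[self.action_size * n :])

    def __call__(self, latent: Sequence[float], hidden: Sequence[float]) -> List[float]:
        state = list(latent) + list(hidden)
        return [
            sum(w * x for w, x in zip(row, state)) + b
            for row, b in zip(self.weights, self.bias)
        ]


controller = ToyLinearController(latent_size=2, hidden_size=1, action_size=2)
controller.load_flat([1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.5, -0.5])
action = controller([2.0, 3.0], [4.0])
```

The parameter count is action_size × (latent_size + hidden_size + 1), which stays small enough for evolution strategies to search directly.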
Modular RSSM with swappable encoder/decoder/backbone components.
This module provides a flexible architecture for world model research, allowing researchers to easily swap different encoder, decoder, and backbone implementations for ablations and experimentation.
- class world_models.models.modular_rssm.EncoderBase(*args, **kwargs)[source]#
Bases:
Module, ABC
Abstract base class for observation encoders.
- Parameters:
args (Any)
kwargs (Any)
- abstractmethod forward(obs)[source]#
Encode observations to embeddings.
- Parameters:
obs (Tensor)
- Return type:
Tensor
- embed_size: int#
- class world_models.models.modular_rssm.DecoderBase(*args, **kwargs)[source]#
Bases:
Module, ABC
Abstract base class for observation decoders.
- Parameters:
args (Any)
kwargs (Any)
- class world_models.models.modular_rssm.BackboneBase(*args, **kwargs)[source]#
Bases:
Module, ABC
Abstract base class for recurrent dynamics backbones.
- Parameters:
args (Any)
kwargs (Any)
- abstractmethod forward(state, action, obs_embed=None, nonterm=1.0)[source]#
Process one step of dynamics. Returns (prior, posterior).
- Parameters:
state (Dict[str, Tensor])
action (Tensor)
obs_embed (Tensor | None)
nonterm (float)
- Return type:
Tuple[Dict[str, Tensor], Dict[str, Tensor]]
- abstractmethod init_state(batch_size, device)[source]#
Initialize hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
Dict[str, Tensor]
- stoch_size: int#
- deter_size: int#
- class world_models.models.modular_rssm.ConvEncoder(input_shape, embed_size, activation='elu', depth=32)[source]#
Bases:
EncoderBase
Convolutional encoder from Dreamer (image observations).
- Parameters:
input_shape (Tuple[int, int, int])
embed_size (int)
activation (str)
depth (int)
- class world_models.models.modular_rssm.MLPEncoder(input_dim, embed_size, hidden_sizes=[256, 256], activation='elu')[source]#
Bases:
EncoderBase
MLP encoder for state-based observations.
- Parameters:
input_dim (int)
embed_size (int)
hidden_sizes (List[int])
activation (str)
- class world_models.models.modular_rssm.ViTEncoder(input_shape, embed_size, patch_size=8, depth=6, num_heads=8, mlp_ratio=4.0, activation='gelu')[source]#
Bases:
EncoderBase
Vision Transformer encoder for image observations.
- Parameters:
input_shape (Tuple[int, int, int])
embed_size (int)
patch_size (int)
depth (int)
num_heads (int)
mlp_ratio (float)
activation (str)
- class world_models.models.modular_rssm.TransformerBlock(embed_size, num_heads, mlp_ratio, activation)[source]#
Bases:
Module
Transformer block for ViT encoder.
- Parameters:
embed_size (int)
num_heads (int)
mlp_ratio (float)
activation (str)
- class world_models.models.modular_rssm.ConvDecoder(stoch_size, deter_size, output_shape, activation='elu', depth=32)[source]#
Bases:
DecoderBase
Convolutional decoder for image observations.
- Parameters:
stoch_size (int)
deter_size (int)
output_shape (Tuple[int, int, int])
activation (str)
depth (int)
- class world_models.models.modular_rssm.MLPDecoder(stoch_size, deter_size, output_dim, hidden_sizes=[256, 256], activation='elu', dist='normal')[source]#
Bases:
DecoderBase
MLP decoder for state-based observations.
- Parameters:
stoch_size (int)
deter_size (int)
output_dim (int)
hidden_sizes (List[int])
activation (str)
dist (str)
- class world_models.models.modular_rssm.GRUBackbone(action_size, stoch_size, deter_size, hidden_size, embed_size, activation='elu')[source]#
Bases:
BackboneBase
GRU-based recurrent dynamics backbone (standard RSSM).
- Parameters:
action_size (int)
stoch_size (int)
deter_size (int)
hidden_size (int)
embed_size (int)
activation (str)
- property embedding_size: int#
- class world_models.models.modular_rssm.LSTMBackbone(action_size, stoch_size, deter_size, hidden_size, embed_size, activation='elu')[source]#
Bases:
BackboneBase
LSTM-based recurrent dynamics backbone.
- Parameters:
action_size (int)
stoch_size (int)
deter_size (int)
hidden_size (int)
embed_size (int)
activation (str)
- property embedding_size: int#
- class world_models.models.modular_rssm.TransformerBackbone(action_size, stoch_size, deter_size, embed_size, num_heads=4, num_layers=2, activation='gelu')[source]#
Bases:
BackboneBase
Transformer-based dynamics backbone for long-range dependencies.
- Parameters:
action_size (int)
stoch_size (int)
deter_size (int)
embed_size (int)
num_heads (int)
num_layers (int)
activation (str)
- property embedding_size: int#
- class world_models.models.modular_rssm.ModularRSSM(encoder, decoder, backbone, reward_decoder=None)[source]#
Bases:
Module
Modular RSSM with swappable encoder, decoder, and backbone.
This class allows researchers to easily experiment with different:
- Encoders: Conv, MLP, ViT
- Decoders: Conv, MLP
- Backbones: GRU, LSTM, Transformer
Example
>>> encoder = ConvEncoder((3, 64, 64), embed_size=1024)
>>> decoder = ConvDecoder(32, 200, (3, 64, 64))
>>> backbone = GRUBackbone(action_size=6, stoch_size=32, deter_size=200, hidden_size=200, embed_size=1024)
>>> rssm = ModularRSSM(encoder, decoder, backbone)
- Parameters:
encoder (EncoderBase)
decoder (DecoderBase)
backbone (BackboneBase)
reward_decoder (DecoderBase | None)
- property stoch_size: int#
- property deter_size: int#
- property embed_size: int#
- init_state(batch_size, device)[source]#
- Parameters:
batch_size (int)
device (device)
- Return type:
Dict[str, Tensor]
- observe_step(prev_state, prev_action, obs, nonterm=1.0)[source]#
- Parameters:
prev_state (Dict[str, Tensor])
prev_action (Tensor)
obs (Tensor)
nonterm (Any)
- Return type:
Tuple[Dict[str, Tensor], Dict[str, Tensor]]
- imagine_step(prev_state, prev_action, nonterm=1.0)[source]#
- Parameters:
prev_state (Dict[str, Tensor])
prev_action (Tensor)
nonterm (Any)
- Return type:
Dict[str, Tensor]
- observe_rollout(obs, actions, nonterms, prev_state, horizon)[source]#
- Parameters:
obs (Tensor)
actions (Tensor)
nonterms (Tensor)
prev_state (Dict[str, Tensor])
horizon (int)
- Return type:
Tuple[Dict[str, Tensor], Dict[str, Tensor]]
- world_models.models.modular_rssm.create_modular_rssm(encoder_type='conv', decoder_type='conv', backbone_type='gru', obs_shape=(3, 64, 64), action_size=6, stoch_size=32, deter_size=200, embed_size=1024, hidden_size=200, activation='elu', **kwargs)[source]#
Factory function to create a modular RSSM with specified components.
- Parameters:
encoder_type (str) – Type of encoder (“conv”, “mlp”, “vit”)
decoder_type (str) – Type of decoder (“conv”, “mlp”)
backbone_type (str) – Type of backbone (“gru”, “lstm”, “transformer”)
obs_shape (Tuple[int, int, int] | Tuple[int]) – Shape of observations (C, H, W) for images or (D,) for state
action_size (int) – Action space dimension
stoch_size (int) – Stochastic latent dimension
deter_size (int) – Deterministic hidden dimension
embed_size (int) – Encoder embedding dimension
hidden_size (int) – Hidden layer dimension
activation (str) – Activation function name
kwargs (Any)
- Returns:
Configured ModularRSSM instance
- Return type:
- class world_models.models.jepa_agent.JEPAAgent(config=None, **kwargs)[source]#
Bases:
ExportableAgentMixin
Convenience interface for configuring and launching JEPA training runs.
Accepts a JEPAConfig plus keyword overrides, prepares output folders, and delegates execution to the JEPA training entrypoint.
- Parameters:
config (JEPAConfig | None)
kwargs (Any)
- classmethod from_config(config=None, **overrides)[source]#
Build a JEPA agent from a config object, dict, YAML file, or YAML string.
- Parameters:
config (JEPAConfig | dict[str, Any] | str | Path | None)
overrides (Any)
- Return type:
- classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, **overrides)[source]#
Create a JEPA agent from local/HF Hub config and checkpoint metadata.
- Parameters:
pretrained_model_name_or_path (str | Path)
config (JEPAConfig | dict[str, Any] | str | Path | None)
checkpoint_filename (str | None)
config_filename (str)
repo_type (str | None)
revision (str | None)
overrides (Any)
- Return type:
- world_models.models.vit.get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)[source]#
Generate fixed 2D sine/cosine positional embeddings on a square patch grid.
Returns NumPy embeddings used to initialize non-trainable transformer position encodings, with optional prepended class-token embedding.
- Parameters:
embed_dim (int)
grid_size (int)
cls_token (bool)
- Return type:
ndarray
- world_models.models.vit.get_2d_sincos_pos_embed_from_grid(embed_dim, grid)[source]#
Build 2D sine/cosine embeddings from precomputed meshgrid coordinates.
The final embedding concatenates independent encodings for vertical and horizontal coordinates.
- Parameters:
embed_dim (int)
grid (ndarray)
- Return type:
ndarray
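The concatenation of vertical and horizontal encodings can be sketched without NumPy. This is a minimal, dependency-free illustration rather than the library's implementation: the helper names `sincos_1d` and `sincos_2d` are hypothetical, and the frequency base of 10000 and the row-then-column ordering of the two halves are assumptions borrowed from common ViT code.

```python
import math


def sincos_1d(embed_dim, position):
    """Encode one scalar position as [sin..., cos...] of length embed_dim."""
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")
    half = embed_dim // 2
    # Frequencies fall geometrically from 1 down towards 1/10000 (assumed base).
    angles = [position / (10000.0 ** (i / half)) for i in range(half)]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


def sincos_2d(embed_dim, grid_size):
    """Return grid_size**2 rows, each the row encoding followed by the column encoding."""
    if embed_dim % 4 != 0:
        raise ValueError("embed_dim must be divisible by 4")
    table = []
    for row in range(grid_size):
        for col in range(grid_size):
            # Half of the channels describe the row, the other half the column.
            table.append(sincos_1d(embed_dim // 2, row) + sincos_1d(embed_dim // 2, col))
    return table


table = sincos_2d(embed_dim=8, grid_size=2)
print(len(table), len(table[0]))
```

Because the table depends only on `embed_dim` and `grid_size`, it can be computed once and stored as a non-trainable buffer.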
- world_models.models.vit.get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)[source]#
Generate 1D sine/cosine positional embeddings for integer positions.
Useful for sequence-style positional encoding and as a building block for 2D embedding construction.
- Parameters:
embed_dim (int)
grid_size (int)
cls_token (bool)
- Return type:
ndarray
- world_models.models.vit.get_1d_sincos_pos_embed_from_grid(embed_dim, pos)[source]#
Generate 1D sine/cosine positional embeddings from explicit positions.
Positions are projected onto a log-frequency basis and encoded with sine and cosine components.
- Parameters:
embed_dim (int)
pos (ndarray)
- Return type:
ndarray
- world_models.models.vit.drop_path(x, drop_prob=0.0, training=False)[source]#
Apply stochastic depth (DropPath) regularization to residual branches.
Randomly drops entire residual paths per sample during training and scales the surviving activations to preserve expected magnitude.
- Parameters:
x (Tensor)
drop_prob (float)
training (bool)
- Return type:
Tensor
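The per-sample drop and rescale behaviour described above can be sketched on plain Python lists. This is an illustrative sketch only: `drop_path_sketch` and its `rng` argument are hypothetical names, and the real function operates on tensors.

```python
import random


def drop_path_sketch(batch, drop_prob=0.0, training=False, rng=None):
    """Zero out whole samples with probability drop_prob and rescale the survivors."""
    if drop_prob == 0.0 or not training:
        return batch  # identity at evaluation time or when disabled
    rng = rng or random.Random()
    keep_prob = 1.0 - drop_prob
    out = []
    for sample in batch:
        keep = rng.random() < keep_prob
        # Survivors are scaled by 1 / keep_prob so the expected value is unchanged.
        scale = 1.0 / keep_prob if keep else 0.0
        out.append([value * scale for value in sample])
    return out


batch = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(drop_path_sketch(batch, drop_prob=0.5, training=True, rng=random.Random(0)))
```

The mask is drawn once per sample, so a dropped residual branch contributes nothing for that sample while the skip connection still passes the input through.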
- class world_models.models.vit.DropPath(drop_prob=None)[source]#
Bases:
Module
Module wrapper around the functional drop_path stochastic depth utility.
- Parameters:
drop_prob (float | None)
- class world_models.models.vit.MLP(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>, drop=0.0)[source]#
Bases:
Module
Two-layer feed-forward network used inside transformer blocks.
Applies linear projection, activation, dropout, and output projection in the standard Vision Transformer MLP pattern.
- Parameters:
in_features (int)
hidden_features (int | None)
out_features (int | None)
act_layer (type[Module])
drop (float)
- class world_models.models.vit.Attention(dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)[source]#
Bases:
Module
Multi-head self-attention block for token sequences.
Computes QKV projections, scaled dot-product attention, and output projection with configurable dropout.
- Parameters:
dim (int)
num_heads (int)
qkv_bias (bool)
qk_scale (float | None)
attn_drop (float)
proj_drop (float)
- class world_models.models.vit.Block(dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]#
Bases:
Module
Transformer encoder block combining attention and MLP residual branches.
Each branch uses pre-normalization and optional stochastic depth.
- Parameters:
dim (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
qk_scale (float | None)
drop (float)
attn_drop (float)
drop_path (float)
act_layer (type[Module])
norm_layer (type[Module])
- class world_models.models.vit.PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)[source]#
Bases:
Module
Image to Patch Embedding
- Parameters:
img_size (int)
patch_size (int)
in_chans (int)
embed_dim (int)
- class world_models.models.vit.ConvEmbed(channels, strides, img_size=224, in_chans=3, batch_norm=True)[source]#
Bases:
Module
3x3 Convolution stems for ViT following ViTC models
- Parameters:
channels (list[int])
strides (list[int])
img_size (int)
in_chans (int)
batch_norm (bool)
- class world_models.models.vit.VisionTransformerPredictor(num_patches, embed_dim=768, predictor_embed_dim=384, depth=6, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, init_std=0.02, **kwargs)[source]#
Bases:
Module
Vision Transformer
- Parameters:
num_patches (int)
embed_dim (int)
predictor_embed_dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
qk_scale (float | None)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
norm_layer (type[Module])
init_std (float)
kwargs (Any)
- class world_models.models.vit.VisionTransformer(img_size=[224], patch_size=16, in_chans=3, embed_dim=768, predictor_embed_dim=384, depth=12, predictor_depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, init_std=0.02, **kwargs)[source]#
Bases:
Module
Vision Transformer
- Parameters:
img_size (list[int])
patch_size (int)
in_chans (int)
embed_dim (int)
predictor_embed_dim (int)
depth (int)
predictor_depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
qk_scale (float | None)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
norm_layer (type[Module])
init_std (float)
kwargs (Any)
- world_models.models.vit.vit_predictor(**kwargs)[source]#
Factory for a JEPA predictor transformer with sensible defaults.
- Parameters:
kwargs (Any)
- Return type:
- world_models.models.vit.vit_tiny(patch_size=16, **kwargs)[source]#
Factory for a tiny Vision Transformer encoder backbone.
- Parameters:
patch_size (int)
kwargs (Any)
- Return type:
Any
- world_models.models.vit.vit_small(patch_size=16, **kwargs)[source]#
Factory for a small Vision Transformer encoder backbone.
- Parameters:
patch_size (int)
kwargs (Any)
- Return type:
Any
- world_models.models.vit.vit_base(patch_size=16, **kwargs)[source]#
Factory for a base Vision Transformer encoder backbone.
- Parameters:
patch_size (int)
kwargs (Any)
- Return type:
Any
- world_models.models.vit.vit_large(patch_size=16, **kwargs)[source]#
Factory for a large Vision Transformer encoder backbone.
- Parameters:
patch_size (int)
kwargs (Any)
- Return type:
Any
- world_models.models.vit.vit_huge(patch_size=16, **kwargs)[source]#
Factory for a huge Vision Transformer encoder backbone.
- Parameters:
patch_size (int)
kwargs (Any)
- Return type:
Any
- world_models.models.vit.vit_giant(patch_size=16, **kwargs)[source]#
Factory for a giant Vision Transformer encoder backbone.
- Parameters:
patch_size (int)
kwargs (Any)
- Return type:
Any
- world_models.models.iris_agent.compute_lambda_return(rewards, values, discounts, lambda_coef=0.95)[source]#
Compute λ-return target for value function training.
- Parameters:
rewards (Tensor) – Rewards (B, T)
values (Tensor) – Value estimates (B, T+1)
discounts (Tensor) – Discount factors (B, T)
lambda_coef (float) – Lambda parameter for bootstrapping
- Returns:
λ-return targets (B, T)
- Return type:
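For orientation, the λ-return can be written as a backward recursion over one trajectory. The sketch below is an assumption about the usual TD(λ) form, not a copy of the library code: `lambda_return_sketch` is a hypothetical name, and it handles a single sequence of Python floats instead of batched (B, T) tensors.

```python
def lambda_return_sketch(rewards, values, discounts, lambda_coef=0.95):
    """Backward recursion G_t = r_t + d_t * ((1 - lam) * V_{t+1} + lam * G_{t+1})."""
    horizon = len(rewards)
    if len(values) != horizon + 1 or len(discounts) != horizon:
        raise ValueError("expected T rewards, T discounts, and T + 1 values")
    targets = [0.0] * horizon
    next_return = values[-1]  # bootstrap from the final value estimate
    for t in reversed(range(horizon)):
        blended = (1.0 - lambda_coef) * values[t + 1] + lambda_coef * next_return
        next_return = rewards[t] + discounts[t] * blended
        targets[t] = next_return
    return targets


print(lambda_return_sketch([1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0], [0.5, 0.5, 0.5], 1.0))
```

With `lambda_coef=0` each target reduces to the one-step TD target, and with `lambda_coef=1` it becomes the discounted sum of rewards plus the bootstrapped final value.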
- class world_models.models.iris_agent.IRISAgent(config, action_size, device)[source]#
Bases:
Module
Complete IRIS Agent with world model and policy.
Combines:
- Discrete autoencoder (encoder + decoder)
- Transformer world model
- Actor-Critic for policy and value learning
- Parameters:
config (IRISConfig)
action_size (int)
device (device)
- classmethod from_config(config=None, *, action_size, device=None, **overrides)[source]#
Build an IRIS agent from a config object, dict, YAML file, or YAML string.
- Parameters:
config (IRISConfig | dict[str, Any] | str | Path | None)
action_size (int)
device (device | str | None)
overrides (Any)
- Return type:
- classmethod from_pretrained(pretrained_model_name_or_path, *, action_size=None, device=None, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, **overrides)[source]#
Load an IRIS agent checkpoint from a local path/directory or HF Hub.
- Parameters:
pretrained_model_name_or_path (str | Path)
action_size (int | None)
device (device | str | None)
config (IRISConfig | dict[str, Any] | str | Path | None)
checkpoint_filename (str | None)
config_filename (str)
repo_type (str | None)
revision (str | None)
overrides (Any)
- Return type:
- forward_actor_critic(frames, hidden=None)[source]#
Forward pass through actor-critic.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W)
hidden (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state
- Returns:
action_logits: (B, T, action_size)
values: (B, T)
hidden_state: (h, c)
- act(frame, epsilon=0.0, temperature=1.0)[source]#
Sample action from policy.
- Parameters:
frame (Tensor) – Single frame (B, C, H, W)
epsilon (float) – Random action probability
temperature (float) – Action distribution temperature
- Returns:
Selected actions (B,)
- Return type:
- imagine_rollout(initial_frame, horizon=20)[source]#
Generate imagined trajectories using world model.
- Parameters:
initial_frame (Tensor) – Starting frame (B, C, H, W)
horizon (int) – Number of steps to imagine
- Returns:
trajectory: Dictionary with imagined rollout data
- update_autoencoder(frames)[source]#
Update discrete autoencoder.
- Parameters:
frames (Tensor) – Training frames (B, C, H, W)
- Returns:
losses: Dictionary of loss values
- update_transformer(frames, actions, rewards, terminals)[source]#
Update transformer world model.
- Parameters:
frames (Tensor) – Frame sequence
actions (Tensor) – Actions taken
rewards (Tensor) – Rewards received
terminals (Tensor) – Terminal flags
- Returns:
losses: Dictionary of loss values
- class world_models.models.iris_transformer.IRISTransformer(vocab_size=512, tokens_per_frame=16, action_size=18, embed_dim=256, num_layers=10, num_heads=4, dropout=0.1, gradient_checkpointing=False)[source]#
Bases:
Module
Autoregressive Transformer for world modeling.
Models the dynamics of the environment by predicting:
- Next frame tokens (transition model)
- Rewards
- Episode termination
The Transformer operates on sequences of interleaved frame tokens and actions.
- Parameters:
vocab_size (int)
tokens_per_frame (int)
action_size (int)
embed_dim (int)
num_layers (int)
num_heads (int)
dropout (float)
gradient_checkpointing (bool)
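One way to picture the interleaved sequence is K frame tokens followed by one action token per timestep. The sketch below is only an illustration of that layout: `interleave_tokens` is a hypothetical helper, and offsetting action ids by `vocab_size` is an assumption (the model may instead use separate embedding tables).

```python
def interleave_tokens(frame_tokens, actions, vocab_size):
    """Flatten T frames of K tokens plus T actions into one sequence of length T * (K + 1)."""
    if len(frame_tokens) != len(actions):
        raise ValueError("need exactly one action per frame")
    sequence = []
    for tokens, action in zip(frame_tokens, actions):
        sequence.extend(tokens)
        # Shift action ids past the frame vocabulary so the two id ranges do not collide.
        sequence.append(vocab_size + action)
    return sequence


seq = interleave_tokens([[1, 2], [3, 4]], actions=[0, 1], vocab_size=512)
print(seq)
```

With the default `tokens_per_frame=16`, each timestep would occupy 17 positions under this layout.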
- forward(tokens, actions, mask=None)[source]#
Forward pass through the Transformer world model.
- Parameters:
tokens (Tensor) – Frame tokens (B, T, K) where T is timesteps
actions (Tensor) – Actions (B, T)
mask (Tensor | None) – Optional attention mask
- Returns:
token_logits: Next token predictions (B, T, K, vocab_size)
rewards: Predicted rewards (B, T)
terminations: Predicted terminations (B, T, 2)
- predict_next_tokens(tokens, actions)[source]#
Predict the next frame tokens autoregressively.
Used during imagination rollouts.
- Parameters:
tokens (Tensor) – Current frame tokens (B, K)
actions (Tensor) – Actions taken (B,)
- Returns:
token_logits: Next frame token predictions (B, K, vocab_size)
action_hidden: Hidden states for reward prediction (B, embed_dim)
- sample_next_tokens(tokens, actions, temperature=1.0)[source]#
Sample next tokens from the distribution.
- Parameters:
tokens (Tensor) – Current frame tokens (B, K)
actions (Tensor) – Actions taken (B,)
temperature (float) – Sampling temperature (higher = more random)
- Returns:
sampled_tokens: Sampled token indices (B, K)
log_probs: Log probabilities of sampled tokens (B, K)
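The effect of `temperature` can be illustrated for a single token position. This is a stdlib-only sketch under the usual assumption that logits are divided by the temperature before the softmax. `sample_with_temperature` is a hypothetical helper, not the method above.

```python
import math
import random


def sample_with_temperature(logits, temperature=1.0, rng=None):
    """Sample one index from softmax(logits / temperature) and return its log-probability."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    rng = rng or random.Random()
    scaled = [logit / temperature for logit in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    index = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return index, math.log(probs[index])


index, log_prob = sample_with_temperature([0.0, 1.0, 2.0], temperature=1.0, rng=random.Random(0))
print(index, log_prob)
```

Lower temperatures concentrate probability on the largest logit, while higher temperatures flatten the distribution.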
- class world_models.models.iris_transformer.IRISWorldModel(encoder, decoder, transformer)[source]#
Bases:
Module
Complete IRIS World Model combining autoencoder and transformer.
This is the core component that learns environment dynamics entirely in the “imaginary” latent space.
- Parameters:
encoder (Module)
decoder (Module)
transformer (IRISTransformer)
- forward(observations, actions)[source]#
Full world model forward pass.
- Parameters:
observations (Tensor) – Image sequence (B, T+1, C, H, W)
actions (Tensor) – Actions (B, T)
- Returns:
predictions: Dictionary with predicted tokens, rewards, terminations
losses: Dictionary with loss components
- imagine(initial_tokens, policy, horizon=20, temperature=1.0)[source]#
Generate imagined trajectories.
- Parameters:
initial_tokens (Tensor) – Initial frame tokens (B, K)
policy (Module) – Policy network to sample actions
horizon (int) – Number of steps to imagine
temperature (float) – Sampling temperature for token prediction
- Returns:
imagined: Dictionary with imagined trajectories
- class world_models.models.genie.Genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=512, tokenizer_decoder_dim=1024, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=1024, action_decoder_dim=1024, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, encoder_depth=12, decoder_depth=20, latent_action_depth=20, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#
Bases:
Module
Genie: Generative Interactive Environment.
A generative model trained from video-only data that can be used as an interactive environment. Contains three key components:
1. Video Tokenizer: Converts raw video frames into discrete tokens
2. Latent Action Model (LAM): Infers latent actions between frames
3. Dynamics Model: Predicts future frames given past frames and latent actions
Based on “Genie: Generative Interactive Environments” paper (arXiv:2402.15391).
Training follows two phases as per paper:
1. Train video tokenizer first (on video tokens)
2. Co-train LAM (from pixels) and dynamics model (on video tokens)
The LAM uses VQ-VAE training with:
- Encoder: Takes x1:t and x_{t+1} → outputs latent actions
- Decoder: Takes x1:t-1 (masked) + actions → reconstructs x_t
- Auxiliary variance loss to prevent action collapse
At inference, latent actions are stopgrad’d when passed to dynamics model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_decoder_dim (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
encoder_depth (int)
decoder_depth (int)
latent_action_depth (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- classmethod from_config(config=None, **overrides)[source]#
Build Genie from a config object, dict, YAML file, or YAML string.
- Parameters:
config (GenieConfig | GenieSmallConfig | dict[str, Any] | str | Path | None)
overrides (Any)
- Return type:
- classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, map_location=None, **overrides)[source]#
Load Genie weights from a local path/directory or HF Hub.
- Parameters:
pretrained_model_name_or_path (str | Path)
config (GenieConfig | dict[str, Any] | str | Path | None)
checkpoint_filename (str | None)
config_filename (str)
repo_type (str | None)
revision (str | None)
map_location (str | device | None)
overrides (Any)
- Return type:
- save_pretrained(path)[source]#
Save Genie weights and config in a from_pretrained-compatible format.
- Parameters:
path (str | Path)
- Return type:
None
- forward(video, mask_prob=0.5, training_phase='all')[source]#
Full forward pass through all components.
- Parameters:
video (Tensor) – (B, C, T, H, W) input video
mask_prob (float) – Probability for random masking in dynamics (0.5-1.0)
training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”
- Returns:
Dictionary containing losses and predictions
- Return type:
Dict[str, Tensor]
- training_step(video, mask_prob=0.5, training_phase='all')[source]#
Single training step computing all losses.
- Parameters:
video (Tensor) – (B, C, T, H, W) input video
mask_prob (float) – Probability for random masking in dynamics
training_phase (str) – “all”, “tokenizer”, or “lam_dynamics”
- Returns:
Dictionary containing all losses for backpropagation
- Return type:
Dict[str, Tensor]
- encode_video(video)[source]#
Encode video to discrete tokens.
- Parameters:
video (Tensor) – (B, C, T, H, W)
- Returns:
video_tokens: (B, T, H*W)
- infer_actions(frames)[source]#
Infer latent actions from a sequence of frames.
- Parameters:
frames (Tensor) – (B, C, T, H, W) video frames
- Returns:
latent_actions: (B, T-1) inferred latent action indices
- generate(prompt_frame, num_frames=16, actions=None, use_maskgit=True)[source]#
Generate video frames given a prompt frame and actions.
- Parameters:
prompt_frame (Tensor) – (B, C, H, W) initial frame
num_frames (int) – Total number of frames to generate
actions (Tensor | None) – (B, num_frames-1) latent action indices, or None for random
use_maskgit (bool) – Whether to use MaskGIT sampling
- Returns:
generated_video: (B, C, num_frames, H, W)
- play(current_frame, action, current_frames=None)[source]#
Play step - generate next frame given current frame and action.
- Parameters:
current_frame (Tensor) – (B, C, H, W) current frame
action (Tensor) – (B,) latent action indices
current_frames (Tensor | None) – (B, C, T, H, W) history frames, or None for first frame
- Returns:
next_frame: (B, C, H, W)
- world_models.models.genie.create_genie(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, action_vocab_size=8, action_embedding_dim=32, dynamics_dim=5120, dynamics_depth=48, dynamics_num_heads=36, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#
Factory function to create a Genie model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
action_vocab_size (int)
action_embedding_dim (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- world_models.models.genie.create_genie_small(num_frames=16, image_size=64, use_bfloat16=False, action_pooling='mean', window_attention_heads=1)[source]#
Create a smaller Genie model for development/testing.
- Parameters:
num_frames (int)
image_size (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- world_models.models.genie.create_genie_large(num_frames=16, image_size=64, use_bfloat16=True, action_pooling='mean', window_attention_heads=1)[source]#
Create the full 11B parameter Genie model (approximate).
- Parameters:
num_frames (int)
image_size (int)
use_bfloat16 (bool)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- class world_models.models.latent_action_model.LatentActionModel(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#
Bases:
Module
Latent Action Model (LAM) for unsupervised action learning.
Learns discrete latent actions from unlabeled video frames using a VQ-VAE based objective. The model infers latent actions between frames that encode the most meaningful changes for future frame prediction.
Based on Genie paper - learns actions without action labels from Internet videos.
Components:
- Encoder: Takes all previous frames x1:t and next frame x_t+1 → outputs latent actions
- Decoder: Takes previous frames x1:t-1 and latent actions a1:t-1 → predicts next frame x_t
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- encode(x_prev, x_next)[source]#
Encode frames to latent actions.
- Parameters:
x_prev (Tensor) – Previous frames (B, C, T, H, W)
x_next (Tensor) – Next frame (B, C, H, W)
- Returns:
latent_actions: Discrete latent action indices (B, T)
z_q: Quantized embeddings (B, T, embedding_dim)
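The relationship between the discrete action indices and the quantized embeddings can be sketched as a nearest-neighbour codebook lookup, the core step of VQ-VAE quantization. This is a simplified sketch under that assumption: `quantize` is a hypothetical helper and the codebook values are made up for illustration.

```python
def quantize(vector, codebook):
    """Return (index, code) for the codebook entry closest to vector in squared L2 distance."""
    best_index = 0
    best_distance = float("inf")
    for index, code in enumerate(codebook):
        distance = sum((v - c) ** 2 for v, c in zip(vector, code))
        if distance < best_distance:
            best_index, best_distance = index, distance
    # The index plays the role of the latent action; the code is the quantized embedding.
    return best_index, list(codebook[best_index])


codebook = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]  # illustrative entries, not learned values
print(quantize([0.9, 1.2], codebook))
```

With the default `vocab_size=8`, the codebook would hold eight entries, so each transition maps to one of eight latent actions.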
- world_models.models.latent_action_model.create_latent_action_model(num_frames=16, image_size=64, in_channels=3, encoder_dim=256, decoder_dim=512, encoder_depth=4, decoder_depth=4, num_heads=8, patch_size=16, vocab_size=8, embedding_dim=32, action_pooling='mean', window_attention_heads=1)[source]#
Factory function to create a Latent Action Model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- Return type:
- class world_models.models.dynamics_model.MaskGITSampler(num_steps=25, temperature=2.0, mask_schedule='cosine')[source]#
Bases:
object
MaskGIT sampling for token-based video generation.
Uses iterative refinement with a mask schedule to progressively reveal tokens during generation.
- Parameters:
num_steps (int)
temperature (float)
mask_schedule (str)
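A cosine mask schedule can be sketched as the number of tokens still masked after each refinement step. This follows the schedule commonly associated with MaskGIT. `cosine_mask_counts` is a hypothetical helper, and the exact rounding used by this sampler is an assumption.

```python
import math


def cosine_mask_counts(num_tokens, num_steps):
    """Tokens left masked after each step, decaying along a quarter cosine wave."""
    counts = []
    for step in range(num_steps):
        progress = (step + 1) / num_steps
        ratio = math.cos(0.5 * math.pi * progress)  # 1 at the start, 0 at the end
        counts.append(int(math.floor(num_tokens * ratio)))
    return counts


counts = cosine_mask_counts(num_tokens=256, num_steps=25)
print(counts[0], counts[-1])
```

Under this schedule few tokens are revealed in the early steps and most are revealed near the end, when the model has more context to condition on.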
- class world_models.models.dynamics_model.DynamicsModel(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, gradient_checkpointing=True)[source]#
Bases:
Module
Dynamics Model for action-controllable video generation.
A decoder-only transformer that predicts future frame tokens given past frame tokens and latent actions. Uses MaskGIT for training and sampling.
Based on Genie paper - uses cross-entropy loss with random masking during training, and MaskGIT iterative refinement at inference.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
gradient_checkpointing (bool)
- forward(video_tokens, actions, mask_prob=0.0)[source]#
Forward pass for training.
- Parameters:
video_tokens (Tensor) – (B, T, H*W) - token indices for frames 1 to T
actions (Tensor) – (B, T) - latent action indices for frames 1 to T
mask_prob (float) – Probability of masking input tokens (Bernoulli 0.5-1.0)
- Returns:
logits: (B, T, H*W, vocab_size)
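One reading of the "Bernoulli 0.5-1.0" note is that a masking rate is drawn uniformly from [0.5, 1.0] and each token is then masked independently with that rate. The sketch below illustrates that reading only: `mask_tokens` and `mask_id` are hypothetical, and the model's actual masking code may differ.

```python
import random


def mask_tokens(tokens, mask_id, rng=None, low=0.5, high=1.0):
    """Replace each token with mask_id independently, using a rate drawn from [low, high]."""
    rng = rng or random.Random()
    rate = rng.uniform(low, high)  # one masking rate per call
    masked, is_masked = [], []
    for token in tokens:
        hide = rng.random() < rate  # independent Bernoulli draw per token
        masked.append(mask_id if hide else token)
        is_masked.append(hide)
    return masked, is_masked, rate


masked, is_masked, rate = mask_tokens(list(range(16)), mask_id=1024, rng=random.Random(0))
print(rate, sum(is_masked))
```

A cross-entropy loss would then typically be computed on the masked positions, which is why the boolean mask is returned alongside the tokens.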
- sample(prompt_tokens, prompt_actions, num_frames, sampler=None)[source]#
Sample future frames using MaskGIT.
- Parameters:
prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens
prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames
num_frames (int) – Total number of frames to generate
sampler (MaskGITSampler | None) – MaskGIT sampler instance
- Returns:
generated_tokens: (B, num_frames, N)
- autoregressive_sample(prompt_tokens, prompt_actions, num_frames, temperature=1.0)[source]#
Simple autoregressive sampling (token by token).
- Parameters:
prompt_tokens (Tensor) – (B, T_prompt, N) - starting frame tokens
prompt_actions (Tensor) – (B, T_prompt) - actions for prompt frames
num_frames (int) – Total number of frames to generate
temperature (float) – Sampling temperature
- Returns:
generated_tokens: (B, num_frames, N)
- world_models.models.dynamics_model.create_dynamics_model(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4)[source]#
Factory function to create a Dynamics Model.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
- Return type:
Diffusion and DIAMOND components#
Key classes: DDPM, DiT, DiffusionUNet, EDMPreconditioner, EulerSampler, RewardTerminationModel, and ActorCriticNetwork.
Diffusion sub-module - Diffusion model components for world models.
- Exported Components:
DiT: Diffusion Transformer model
PatchEmbed: Image patch embedding
PatchUnEmbed: Patch unembedding (decode tokens to image)
DDPM: Denoising Diffusion Probabilistic Model implementation
ActorCriticNetwork: DIAMOND actor-critic network
RewardTerminationModel: Reward/termination prediction model
sinusoidal_time_embedding: Time embedding for diffusion models
- class world_models.models.diffusion.DDPM.DDPM(timesteps, beta_start, beta_end)[source]#
Bases:
Module
Utility module implementing forward and reverse DDPM diffusion steps.
Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).
- Parameters:
timesteps (int)
beta_start (float)
beta_end (float)
- q_sample(x_start, t, noise=None)[source]#
- Parameters:
x_start (Tensor)
t (Tensor)
noise (Tensor | None)
- Return type:
Tensor
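The forward noising step has the standard closed form x_t = sqrt(ᾱ_t) · x_0 + sqrt(1 − ᾱ_t) · ε, where ᾱ_t is the cumulative product of (1 − β). The sketch below assumes a linear β schedule between `beta_start` and `beta_end`. The helper names are hypothetical and the code works on flat Python lists, not tensors.

```python
import math


def linear_betas(timesteps, beta_start, beta_end):
    """Evenly spaced betas from beta_start to beta_end (assumed schedule)."""
    if timesteps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def alpha_bars(betas):
    """Cumulative product of (1 - beta), precomputed once for all timesteps."""
    products, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        products.append(running)
    return products


def q_sample_sketch(x_start, t, noise, alpha_bar):
    """Noise x_start directly to timestep t without iterating over earlier steps."""
    signal = math.sqrt(alpha_bar[t])
    sigma = math.sqrt(1.0 - alpha_bar[t])
    return [signal * x + sigma * n for x, n in zip(x_start, noise)]


abar = alpha_bars(linear_betas(1000, 1e-4, 0.02))
print(q_sample_sketch([1.0, -1.0], t=999, noise=[0.3, -0.2], alpha_bar=abar))
```

At late timesteps ᾱ_t is close to zero, so the output is dominated by the noise term.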
- world_models.models.diffusion.DiT.sinusoidal_time_embedding(timesteps, dim)[source]#
Create sinusoidal timestep embeddings for diffusion conditioning.
This function generates positional-style embeddings for diffusion timesteps, following the same pattern as transformer positional encodings. The embeddings encode the noise level (t) and are used to condition the diffusion model.
- Math:
embedding[t] = [sin(t/10000^(2i/d)), cos(t/10000^(2i/d))] for i in [0, d/2)
- Parameters:
timesteps (Tensor) – Tensor of integer timesteps, shape (B,) or (B, 1)
dim (int) – Embedding dimension (must be even)
- Returns:
Tensor of shape (B, dim) with sinusoidal embeddings
- Return type:
Tensor
- Usage with DiT:
t = torch.tensor([0, 500, 1000])  # Timesteps
emb = sinusoidal_time_embedding(t, dim=256)  # (3, 256)
# Condition the model:
# - Add the timestep embedding to the MLP input
# - Use AdaLN for adaptive normalization
- class world_models.models.diffusion.DiT.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]#
Bases:
Module
Patchify an image into a sequence of learnable patch tokens.
Used in Vision Transformers (ViT) and DiT to convert 2D images into sequences of token embeddings that can be processed by transformers.
- Process:
1. Conv2d with kernel_size=stride=patch_size extracts non-overlapping patches
2. Each patch is projected to embed_dim via linear layer (Conv2d)
3. Learnable positional embeddings are added for spatial information
Input: (B, C, H, W) images
Output: (B, N, embed_dim) where N = (H/patch_size) * (W/patch_size)
- Parameters:
img_size (int) – Image size (assumes square), e.g., 32 for CIFAR
patch_size (int) – Size of each patch (typically 4, 8, or 16)
in_channels (int) – Number of input channels (3 for RGB)
embed_dim (int) – Output dimension for each patch token
- Usage with DiT:
patch_embed = PatchEmbed(img_size=32, patch_size=4, in_channels=3, embed_dim=256)
tokens = patch_embed(images)  # (B, 64, 256) for 32x32 image with patch_size=4
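The patch count N = (H/patch_size) * (W/patch_size) can be checked with a dependency-free patch extraction. This sketch covers only the splitting step for a single-channel image stored as nested lists. `patchify` is a hypothetical helper, and the learned projection and positional embeddings are left out.

```python
def patchify(image, patch_size):
    """Split an H x W nested list into flattened, non-overlapping patches in row-major order."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patches.append([
                image[top + dy][left + dx]
                for dy in range(patch_size)
                for dx in range(patch_size)
            ])
    return patches


image = [[row * 32 + col for col in range(32)] for row in range(32)]
patches = patchify(image, patch_size=4)
print(len(patches), len(patches[0]))
```

Each flattened patch corresponds to one token before projection, which is where the 64 in the (B, 64, 256) shape comes from.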
- class world_models.models.diffusion.DiT.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]#
Bases:
Module
Reconstruct image-like tensors from patch-token sequences.
The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.
- Parameters:
img_size (int)
patch_size (int)
embed_dim (int)
out_channels (int)
- class world_models.models.diffusion.DiT.TransformerBlock(d_model, n_heads, mlp_ratio, drop, t_dim)[source]#
Bases:
Module
Conditioned transformer block used inside the DiT backbone.
Each block applies adaptive layer-normalized self-attention and MLP residual updates conditioned on timestep embeddings.
- Parameters:
d_model (int)
n_heads (int)
mlp_ratio (float)
drop (float)
t_dim (int)
- class world_models.models.diffusion.DiT.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]#
Bases:
Module
Diffusion Transformer model for image denoising and generation.
The module maps noisy images and timesteps to predicted noise residuals and also provides a classmethod training entrypoint for common datasets.
- Parameters:
img_size (int)
patch_size (int)
in_channels (int)
d_model (int)
depth (int)
heads (int)
drop (float)
t_dim (int)
- classmethod from_config(config=None, **overrides)[source]#
Build DiT from a config object, dict, YAML file, or YAML string.
- classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, map_location=None, **overrides)[source]#
Load DiT weights from a local path/directory or HF Hub.
- save_pretrained(path)[source]#
Save DiT weights and config in a from_pretrained-compatible format.
- Parameters:
path (str | Path)
- Return type:
None
- classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]#
- Parameters:
epochs (int)
dataset (Any)
batch_size (int)
lr (float)
img_size (int)
channels (int)
patch (int)
width (int)
depth (int)
heads (int)
drop (float)
timesteps (int)
beta_start (float)
beta_end (float)
ema (bool)
ema_decay (float)
workdir (str)
root_path (str)
image_folder (str | None)
crop_size (int)
download (bool)
copy_data (bool)
subset_file (str | None)
val_split (float | None)
- Return type:
None
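The `timesteps`, `beta_start`, `beta_end`, and `ema_decay` arguments suggest a DDPM-style noise schedule with an exponential moving average of the weights. The sketch below assumes a linear beta schedule, which the source does not confirm; all function names are hypothetical and operate on plain Python floats rather than tensors.

```python
import math


def linear_beta_schedule(timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances (assumed schedule shape)."""
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def cumulative_alphas(betas):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def noise_sample(x0, eps, alpha_bar):
    """Forward process: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps."""
    a, b = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return [a * x + b * e for x, e in zip(x0, eps)]


def ema_update(ema_value, value, decay=0.999):
    """One EMA step for a single weight."""
    return decay * ema_value + (1.0 - decay) * value


betas = linear_beta_schedule()
alpha_bars = cumulative_alphas(betas)
```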
- world_models.models.diffusion.DiT.create_dit(config=None, **overrides)[source]#
Create a DiT from a DiTConfig or keyword overrides.
The public factory API works with config objects, while DiT itself has a compact constructor. This adapter keeps the lower-level model constructor unchanged and maps the public config fields onto the expected arguments.
- Parameters:
config (Any)
overrides (Any)
- Return type:
- class world_models.models.diffusion.diamond_diffusion.AdaptiveGroupNorm(num_groups, num_channels, cond_dim)[source]#
Bases: Module
Adaptive Group Normalization that conditions on actions and diffusion time.
- Parameters:
num_groups (int)
num_channels (int)
cond_dim (int)
- class world_models.models.diffusion.diamond_diffusion.ResBlock(in_channels, out_channels, cond_dim, dropout=0.0)[source]#
Bases: Module
Residual block with adaptive group normalization.
- Parameters:
in_channels (int)
out_channels (int)
cond_dim (int)
dropout (float)
- class world_models.models.diffusion.diamond_diffusion.AttentionBlock(channels, cond_dim)[source]#
Bases: Module
Self-attention block for U-Net.
- Parameters:
channels (int)
cond_dim (int)
- class world_models.models.diffusion.diamond_diffusion.TimestepEmbedding(dim, freq_dim=256)[source]#
Bases: Module
Sinusoidal timestep embedding.
- Parameters:
dim (int)
freq_dim (int)
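A sinusoidal timestep embedding can be sketched in a few lines. This is an illustration under stated assumptions, not the library's code: the `max_period` of 10000 and the cos-then-sin ordering follow common practice, and any learned projection from `freq_dim` to `dim` is omitted.

```python
import math


def sinusoidal_embedding(t: float, freq_dim: int = 256, max_period: float = 10000.0) -> list:
    """Embed a scalar timestep as freq_dim/2 cosines followed by freq_dim/2 sines."""
    half = freq_dim // 2
    # Geometrically spaced frequencies from 1 down to roughly 1/max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.cos(t * f) for f in freqs] + [math.sin(t * f) for f in freqs]


embedding = sinusoidal_embedding(25.0)
print(len(embedding))
```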
- class world_models.models.diffusion.diamond_diffusion.DownBlock(in_channels, out_channels, cond_dim, num_res_blocks=2, attention=False)[source]#
Bases: Module
Downsampling block for U-Net encoder.
- Parameters:
in_channels (int)
out_channels (int)
cond_dim (int)
num_res_blocks (int)
attention (bool)
- class world_models.models.diffusion.diamond_diffusion.UpBlock(in_channels, out_channels, cond_dim, num_res_blocks=2, attention=False)[source]#
Bases: Module
Upsampling block for U-Net decoder with skip connections.
- Parameters:
in_channels (int)
out_channels (int)
cond_dim (int)
num_res_blocks (int)
attention (bool)
- class world_models.models.diffusion.diamond_diffusion.DiffusionUNet(obs_channels=3, num_conditioning_frames=4, base_channels=64, channel_multipliers=(1, 1, 1, 1), num_res_blocks=2, cond_dim=256, action_dim=18)[source]#
Bases: Module
U-Net architecture for EDM diffusion world model. Uses frame stacking for observation conditioning and adaptive group norm for action conditioning.
- Parameters:
obs_channels (int)
num_conditioning_frames (int)
base_channels (int)
channel_multipliers (Tuple[int, ...])
num_res_blocks (int)
cond_dim (int)
action_dim (int)
- forward(x, t, obs_history, actions)[source]#
Forward pass of the diffusion model.
- Parameters:
x (Tensor) – Noised observation at timestep t [B, C, H, W]
t (Tensor) – Diffusion timestep [B]
obs_history (Tensor) – Past observations for conditioning [B, L, C, H, W]
actions (Tensor) – Past actions [B, L]
- Returns:
Predicted clean observation [B, C, H, W]
- Return type:
Tensor
- class world_models.models.diffusion.diamond_diffusion.EDMPreconditioner(sigma_data=0.5, p_mean=-0.4, p_std=1.2)[source]#
Bases: object
EDM preconditioner following Karras et al. (2022).
- Parameters:
sigma_data (float)
p_mean (float)
p_std (float)
- get_preconditioners(sigma)[source]#
Compute EDM preconditioners for given noise levels.
- Returns:
Dictionary with c_skip, c_out, c_in, c_noise
- Parameters:
sigma (Tensor)
- Return type:
dict
- sample_noise_level(batch_size, device)[source]#
Sample noise level from log-normal distribution.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tensor
- denoise(model, x, sigma, **kwargs)[source]#
Apply EDM denoising with preconditioners.
- Parameters:
model (Module) – Diffusion model
x (Tensor) – Noised input [B, C, H, W]
sigma (Tensor) – Noise level [B]
**kwargs (Any) – Additional conditioning (obs_history, actions)
- Returns:
Denoised prediction [B, C, H, W]
- Return type:
Tensor
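The three methods above correspond to the preconditioning scheme of Karras et al. (2022). The scalar sketch below uses the formulas from that paper; whether the library computes them in exactly this form is an assumption, and the function names and the `model(x_in, c_noise)` callable are hypothetical stand-ins for the tensor-based API.

```python
import math
import random


def edm_preconditioners(sigma: float, sigma_data: float = 0.5) -> dict:
    """Return c_skip, c_out, c_in, c_noise for one noise level."""
    total = sigma ** 2 + sigma_data ** 2
    return {
        "c_skip": sigma_data ** 2 / total,
        "c_out": sigma * sigma_data / math.sqrt(total),
        "c_in": 1.0 / math.sqrt(total),
        "c_noise": math.log(sigma) / 4.0,
    }


def sample_noise_level(rng: random.Random, p_mean: float = -0.4, p_std: float = 1.2) -> float:
    """Log-normal noise level: ln(sigma) ~ Normal(p_mean, p_std^2)."""
    return math.exp(rng.gauss(p_mean, p_std))


def denoise(model, x: float, sigma: float, sigma_data: float = 0.5) -> float:
    """D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise)."""
    c = edm_preconditioners(sigma, sigma_data)
    return c["c_skip"] * x + c["c_out"] * model(c["c_in"] * x, c["c_noise"])


coeffs = edm_preconditioners(sigma=0.5)
sigma = sample_noise_level(random.Random(0))
```

When `sigma` equals `sigma_data`, `c_skip` is 0.5, so the skip path and the network output contribute with comparable weight.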
- class world_models.models.diffusion.diamond_diffusion.EulerSampler(sigma_min=0.002, sigma_max=80.0, rho=7, num_steps=3, edm_precond=None)[source]#
Bases: object
Euler method sampler for reverse diffusion.
- Parameters:
sigma_min (float)
sigma_max (float)
rho (int)
num_steps (int)
edm_precond (EDMPreconditioner | None)
- sample(model, shape, device, obs_history=None, actions=None)[source]#
Generate samples using Euler method.
- Parameters:
model (Module) – Diffusion model
shape (Tuple[int, ...]) – Output shape [B, C, H, W]
device (device) – Device to run on
obs_history (Tensor | None) – Conditioning observations [B, L, C, H, W]
actions (Tensor | None) – Conditioning actions [B, L]
- Returns:
Generated samples [B, C, H, W]
- Return type:
Tensor
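The `sigma_min`, `sigma_max`, and `rho` parameters match the noise-level discretization of Karras et al. (2022). The sketch below shows that schedule and a single Euler update on scalars. It assumes the sampler follows the paper, which the source does not confirm (for example, it may append a final sigma of 0); both function names are hypothetical.

```python
def karras_sigmas(num_steps=3, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels interpolated in sigma^(1/rho) space, from sigma_max to sigma_min."""
    if num_steps == 1:
        return [sigma_max]
    max_inv = sigma_max ** (1.0 / rho)
    min_inv = sigma_min ** (1.0 / rho)
    return [
        (max_inv + i / (num_steps - 1) * (min_inv - max_inv)) ** rho
        for i in range(num_steps)
    ]


def euler_step(x, denoised, sigma, sigma_next):
    """One Euler step of the probability-flow ODE dx/dsigma = (x - D(x; sigma)) / sigma."""
    slope = (x - denoised) / sigma
    return x + (sigma_next - sigma) * slope


sigmas = karras_sigmas()
```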
- class world_models.models.diffusion.reward_termination.ConvBlock(in_channels, out_channels, cond_dim, stride=2)[source]#
Bases: Module
Convolutional block with adaptive group normalization.
- Parameters:
in_channels (int)
out_channels (int)
cond_dim (int)
stride (int)
- class world_models.models.diffusion.reward_termination.RewardTerminationModel(obs_channels=3, action_dim=18, channels=(32, 32, 32, 32), lstm_dim=512, cond_dim=128)[source]#
Bases: Module
Reward and termination prediction model. CNN + LSTM architecture following DIAMOND paper specifications.
- Parameters:
obs_channels (int) – Number of observation channels (3 for RGB)
action_dim (int) – Number of possible actions
channels (Tuple[int, ...]) – List of channel sizes for conv blocks
lstm_dim (int) – LSTM hidden dimension
cond_dim (int) – Conditioning dimension for adaptive norm
- forward(obs, actions, hidden_state=None)[source]#
Forward pass of reward/termination model.
- Parameters:
obs (Tensor) – Observations [B, T, C, H, W]
actions (Tensor) – Actions [B, T]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
reward_logits: Reward predictions [B, T, 3] (for -1, 0, 1)
termination_logits: Termination predictions [B, T, 2]
hidden_state: Updated (h, c) hidden states
- predict(obs, actions, hidden_state=None)[source]#
Predict reward and termination for a single step.
- Parameters:
obs (Tensor) – Single observation [B, C, H, W]
actions (Tensor) – Single action [B]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
reward: Predicted reward classes as tensor (values -1, 0, 1)
terminated: Predicted termination tensor (bool tensor)
hidden_state: Updated (h, c) hidden states
Initialize LSTM hidden states.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- class world_models.models.diffusion.reward_termination.RewardTerminationLoss[source]#
Bases: Module
Loss function for reward and termination prediction.
- forward(reward_logits, termination_logits, rewards, terminated)[source]#
Compute loss for reward and termination predictions.
- Parameters:
reward_logits (Tensor) – [B, T, 3]
termination_logits (Tensor) – [B, T, 2]
rewards (Tensor) – Rewards as class indices [B, T] (values -1, 0, 1 mapped to 0, 1, 2)
terminated (Tensor) – Termination flags [B, T]
- Returns:
total_loss, reward_loss, termination_loss
- Return type:
Tuple[Tensor, Tensor, Tensor]
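The `rewards` argument expects class indices, so raw rewards in {-1, 0, 1} must be shifted to {0, 1, 2} before the loss is computed. The sketch below shows that mapping together with a per-element cross-entropy. Both helpers are hypothetical, and how the library weights the reward and termination terms in `total_loss` is not stated in the source.

```python
import math


def reward_to_class(reward: int) -> int:
    """Map a reward in {-1, 0, 1} to a class index in {0, 1, 2}."""
    if reward not in (-1, 0, 1):
        raise ValueError("reward must be -1, 0, or 1")
    return reward + 1


def cross_entropy(logits, target: int) -> float:
    """Numerically stable cross-entropy for one set of logits and one class index."""
    max_logit = max(logits)
    log_norm = max_logit + math.log(sum(math.exp(v - max_logit) for v in logits))
    return log_norm - logits[target]


reward_loss = cross_entropy([0.2, 1.5, -0.3], reward_to_class(0))
termination_loss = cross_entropy([2.0, -1.0], 0)
```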
- class world_models.models.diffusion.actor_critic.ActorCriticNetwork(obs_channels=3, action_dim=18, channels=(32, 32, 64, 64), lstm_dim=512)[source]#
Bases: Module
Actor-Critic network for DIAMOND RL training. Shared CNN-LSTM trunk with separate policy and value heads.
- Parameters:
obs_channels (int)
action_dim (int)
channels (Tuple[int, ...])
lstm_dim (int)
- forward(obs, hidden_state=None)[source]#
Forward pass of actor-critic network.
- Parameters:
obs (Tensor) – Observations [B, T, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
- Returns:
policy_logits: [B, T, action_dim]
values: [B, T, 1]
hidden_state: (h, c)
- get_action(obs, hidden_state=None, deterministic=False)[source]#
Get action from a single observation.
- Parameters:
obs (Tensor) – Single observation [B, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) hidden states
deterministic (bool) – If True, take argmax; else sample
- Returns:
action: Selected action [B]
hidden_state: (h, c)
- get_actions(obs, hidden_state=None, deterministic=False)[source]#
Batched version of get_action.
- Parameters:
obs (Tensor) – Tensor of shape [B, C, H, W]
hidden_state (Tuple[Tensor, Tensor] | None) – Optional LSTM hidden state tuple matching batch size
deterministic (bool) – If True, take argmax; else sample from policy
- Returns:
LongTensor of shape [B]
hidden_state: updated LSTM hidden state tuple
- Return type:
- get_value(obs, hidden_state=None)[source]#
Get value for a single observation.
- Parameters:
obs (Tensor)
hidden_state (Tuple[Tensor, Tensor] | None)
- Return type:
Tuple[Tensor, Tuple[Tensor, Tensor] | None]
Initialize LSTM hidden states.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
Get LSTM hidden size.
- Return type:
int
- class world_models.models.diffusion.actor_critic.RLLoss(discount_factor=0.985, lambda_returns=0.95, entropy_weight=0.001)[source]#
Bases: Module
RL loss functions for DIAMOND. Implements REINFORCE with value baseline and λ-returns.
- Parameters:
discount_factor (float)
lambda_returns (float)
entropy_weight (float)
- compute_lambda_returns(rewards, values, dones)[source]#
Compute λ-returns.
- Parameters:
rewards (Tensor) – [B, T]
values (Tensor) – [B, T+1]
dones (Tensor) – [B, T]
- Returns:
[B, T]
- Return type:
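The shapes above (`rewards` and `dones` of length T, `values` of length T+1) fit the usual backward recursion for λ-returns. The sketch below handles a single trajectory with Python lists. It is written under the assumption that the library bootstraps from the final value and cuts the recursion at terminal steps; `lambda_returns` here is a standalone hypothetical function, not the method itself.

```python
def lambda_returns(rewards, values, dones, discount=0.985, lam=0.95):
    """G_t = r_t + discount * (1 - d_t) * ((1 - lam) * V_{t+1} + lam * G_{t+1})."""
    horizon = len(rewards)
    returns = [0.0] * horizon
    next_return = values[horizon]  # bootstrap from the value after the last step
    for t in reversed(range(horizon)):
        not_done = 0.0 if dones[t] else 1.0
        blended = (1.0 - lam) * values[t + 1] + lam * next_return
        returns[t] = rewards[t] + discount * not_done * blended
        next_return = returns[t]
    return returns


targets = lambda_returns([1.0, 1.0], [0.0, 2.0, 4.0], [False, False], discount=0.5, lam=0.0)
```

With `lam=0` the recursion reduces to one-step TD targets, and with `lam=1` it becomes the discounted return bootstrapped from the last value.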
- policy_loss(policy_logits, actions, lambda_returns, values)[source]#
Compute policy loss with REINFORCE and entropy regularization.
- Parameters:
policy_logits (Tensor) – [B, T, A]
actions (Tensor) – [B, T]
lambda_returns (Tensor) – [B, T]
values (Tensor) – [B, T+1]
- Returns:
policy_loss: scalar
Vision, tokenization, and layers#
Key classes: ConvEncoder, ConvDecoder, DenseDecoder, ActionDecoder, CNNEncoder, CNNDecoder, IRISEncoder, IRISDecoder, DiscreteAutoencoder, VectorQuantizer, VectorQuantizerEMA, VideoTokenizer, MultiHeadSelfAttention, and STTransformer.
Convolutional Variational Autoencoder (ConvVAE) implementation.
This module provides the ConvVAE model architecture for encoding and decoding images in the World Models framework. The VAE uses a convolutional encoder and decoder with a variational latent space.
- class world_models.vision.VAE.ConvVAE.ConvVAEEncoder(img_channels, latent_size)[source]#
Bases: Module
Convolutional encoder for VAE.
This encoder takes images and produces the parameters (mean and log variance) of a Gaussian distribution in the latent space.
- Parameters:
img_channels (int)
latent_size (int)
- latent_size#
Dimensionality of the latent space.
- img_channels#
Number of input image channels.
Example
>>> encoder = ConvVAEEncoder(img_channels=3, latent_size=32)
>>> mu, logsigma = encoder(images)
- class world_models.vision.VAE.ConvVAE.ConvVAEDecoder(latent_size, img_channels)[source]#
Bases: Module
Convolutional decoder for VAE.
This decoder takes latent vectors and reconstructs images.
- Parameters:
latent_size (int)
img_channels (int)
- latent_size#
Dimensionality of the input latent space.
- img_channels#
Number of output image channels.
- class world_models.vision.VAE.ConvVAE.ConvVAE(img_channels, latent_size)[source]#
Bases: Module
Convolutional Variational Autoencoder.
The ConvVAE is a generative model that encodes images into a latent distribution and reconstructs them. It uses the reparameterization trick to enable backpropagation through the sampling process.
- Parameters:
img_channels (int)
latent_size (int)
- encoder#
ConvVAEEncoder that encodes images to latent parameters.
- decoder#
ConvVAEDecoder that decodes latent vectors to images.
Example
>>> vae = ConvVAE(img_channels=3, latent_size=32)
>>> recon_x, mu, logsigma = vae(images)
>>> # Training loss combines reconstruction and KL divergence
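The reparameterization trick and the KL term of the training loss can be sketched on plain Python lists. This sketch assumes `logsigma` is the log of the standard deviation, as the name suggests. If the encoder actually outputs a log variance, the factors of 2 below would change. Both helpers are hypothetical.

```python
import math
import random


def reparameterize(mu, logsigma, rng: random.Random):
    """z = mu + sigma * eps with eps ~ N(0, 1), so gradients can flow through mu and sigma."""
    return [m + math.exp(ls) * rng.gauss(0.0, 1.0) for m, ls in zip(mu, logsigma)]


def kl_to_standard_normal(mu, logsigma) -> float:
    """KL(N(mu, sigma^2) || N(0, 1)) summed over latent dimensions."""
    return -0.5 * sum(
        1.0 + 2.0 * ls - m * m - math.exp(2.0 * ls) for m, ls in zip(mu, logsigma)
    )


z = reparameterize([0.0, 1.0], [0.0, -1.0], random.Random(0))
kl = kl_to_standard_normal([0.0, 1.0], [0.0, -1.0])
```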
- class world_models.vision.dreamer_encoder.ConvEncoder(input_shape, embed_size, activation, depth=32)[source]#
Bases: Module
Convolutional observation encoder used by Dreamer world models.
This encoder transforms raw image observations (typically RGB frames from environments like Atari or DeepMind Control) into compact latent embeddings that can be processed by the RSSM (Recurrent State-Space Model).
Input: (B, C, H, W) raw images, values in [-0.5, 0.5]
Process: 4 convolutional layers with stride 2, halving spatial dimensions
Output: (B, embed_size) compact representation
The encoder uses a depth doubling pattern: 32 -> 64 -> 128 -> 256 channels. After convolutions, a fully connected layer projects from 1024 features to the desired embedding size.
Usage with Dreamer:
encoder = ConvEncoder(
    input_shape=(3, 64, 64),  # RGB 64x64 images
    embed_size=256,           # RSSM observation embedding size
    activation='relu'         # Activation function
)
obs_embedding = encoder(observation)  # (B, 256)
- Parameters:
input_shape (tuple) – Tuple (C, H, W) for input images, typically (3, 64, 64)
embed_size (int) – Output embedding dimension, typically 256 or 1024
activation (str) – Activation function name (‘relu’, ‘elu’, ‘tanh’, etc.)
depth (int) – Base channel depth for first layer (default 32)
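The 1024-feature figure in the description can be reproduced with the standard convolution output-size formula. The sketch below assumes kernel size 4 and no padding, values chosen because they yield the stated 1024 features for 64x64 inputs; the source does not list them, and both helper names are hypothetical.

```python
def conv_out_size(size: int, kernel: int = 4, stride: int = 2, padding: int = 0) -> int:
    """Spatial output size of one convolution layer."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_feature_count(input_hw: int = 64, depth: int = 32, num_layers: int = 4):
    """Track spatial size and channel count through the stride-2 conv stack."""
    size, channels, sizes = input_hw, depth, []
    for layer in range(num_layers):
        size = conv_out_size(size)
        channels = depth * 2 ** layer  # 32 -> 64 -> 128 -> 256
        sizes.append(size)
    return sizes, channels, channels * size * size


sizes, channels, features = encoder_feature_count()
print(sizes, channels, features)
```

Under these assumptions the spatial size shrinks 64 -> 31 -> 14 -> 6 -> 2, and 256 channels over a 2x2 map give the 1024 flattened features.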
- class world_models.vision.dreamer_decoder.TanhBijector[source]#
Bases: Transform
Bijective tanh transform for squashing Gaussian distributions to [-1, 1].
This transformation is essential for Dreamer’s action policy. Raw neural network outputs are Gaussian distributions over R^n, but actions in continuous control environments are typically bounded in [-1, 1]. The tanh bijector provides:
Bijective mapping: tanh is invertible (with atanh as inverse)
Stable log-det Jacobian: Computable for gradient-based training
Clipped actions: During inference, actions are naturally bounded
Forward: y = tanh(x)
Inverse: x = atanh(y) = 0.5 * log((1+y)/(1-y))
Log-det: log|dy/dx| = 2*(log(2) - x - softplus(-2x))
Usage with Dreamer ActionDecoder:
dist = TransformedDistribution(
    Normal(mean, std),
    TanhBijector()
)
action = dist.sample()  # Bounded to [-1, 1]
- Reference:
Building a Scalable Deep RL Library by Learning from Mistakes, Haarnoja et al.
- property sign: int#
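The log-det expression listed above is an algebraically equivalent but more stable form of log(1 - tanh(x)^2). The check below compares the two numerically in plain Python; the function names are hypothetical and not part of the library.

```python
import math


def softplus(x: float) -> float:
    """Stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def log_det_direct(x: float) -> float:
    """log|d tanh(x)/dx| computed naively; loses precision for large |x|."""
    return math.log(1.0 - math.tanh(x) ** 2)


def log_det_stable(x: float) -> float:
    """The form used in the docstring: 2 * (log(2) - x - softplus(-2x))."""
    return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))


for x in (-2.0, -0.5, 0.0, 0.7, 3.0):
    print(x, log_det_direct(x), log_det_stable(x))
```

For large inputs such as x = 20, `tanh(x)` rounds to 1.0 in floating point and the naive form fails, while the stable form stays finite.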
- class world_models.vision.dreamer_decoder.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)[source]#
Bases: Module
Convolutional decoder for reconstructing observations from latent states.
Part of Dreamer’s world model, this decoder reconstructs image observations from the combined stochastic (s) and deterministic (h) RSSM states.
Input: Concatenated [stoch_state, deter_state], shape (B, stoch+deter)
Process: Dense projection + 4 transposed convolutions (upsampling 2x each)
Output: Independent Normal distribution over observation pixels
The decoder mirrors the ConvEncoder’s structure but in reverse (transposed convs instead of regular convs). This creates a symmetric autoencoder where the encoder and decoder can be trained jointly to learn compressed representations.
Returns torch.distributions.Independent(Normal(mean, std), len(shape)), allowing log_prob(observation) computation for reconstruction loss.
Usage in Dreamer world model:
decoder = ConvDecoder(
    stoch_size=30,
    deter_size=200,
    output_shape=(3, 64, 64),  # RGB images
    activation='relu'
)
obs_dist = decoder(latent_features)  # Returns distribution
log_prob = obs_dist.log_prob(target_observation)
The reconstruction loss is -log_prob(observation), which encourages the RSSM to learn states that capture observation information.
- Parameters:
stoch_size (int)
deter_size (int)
output_shape (tuple[int, ...])
activation (str)
depth (int)
- class world_models.vision.dreamer_decoder.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist, num_buckets=255, symlog_range=10.0)[source]#
Bases: Module
MLP decoder for reward/value/discount prediction from latent features.
Part of Dreamer’s world model, this decoder predicts scalar quantities (rewards, values, discount factors) from RSSM latent states.
Input: [stoch_state, deter_state] concatenated, shape (B, stoch+deter)
Process: MLP with configurable layers and hidden units
Output: Predicted quantity with distribution (normal, binary, or raw)
Supports three output types:
- 'normal': Gaussian distribution for regression (rewards, values)
- 'binary': Bernoulli distribution for binary classification (discount)
- 'none': Raw tensor for non-probabilistic outputs
Usage:
reward_decoder = DenseDecoder(
    stoch_size=30, deter_size=200,
    output_shape=(1,), n_layers=2, units=400,
    activation='elu', dist='normal'
)
reward_dist = reward_decoder(latent_features)
reward_loss = -reward_dist.log_prob(target_reward)
For discount prediction (binary):
discount_decoder = DenseDecoder(
    stoch_size=30, deter_size=200,
    output_shape=(1,), n_layers=2, units=400,
    activation='elu', dist='binary'
)
- Parameters:
stoch_size (int)
deter_size (int)
output_shape (tuple[int, ...])
n_layers (int)
units (int)
activation (str)
dist (str)
num_buckets (int)
symlog_range (float)
- class world_models.vision.dreamer_decoder.SampleDist(dist, samples=100)[source]#
Bases: object
Distribution wrapper that estimates statistics via Monte Carlo sampling.
Provides approximated mean, mode, and entropy helpers for transformed distributions where analytic forms may be inconvenient.
- Parameters:
dist (Any)
samples (int)
- property name: str#
- class world_models.vision.dreamer_decoder.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)[source]#
Bases: Module
Dreamer actor head producing squashed continuous actions from latent features.
Outputs a transformed Gaussian policy with optional deterministic mode and utility for additive exploration noise.
- Parameters:
action_size (int)
stoch_size (int)
deter_size (int)
n_layers (int)
units (int)
activation (str)
min_std (float)
init_std (float)
mean_scale (float)
- class world_models.vision.planet_encoder.CNNEncoder(embedding_size, activation_function='relu')[source]#
Bases: Module
A Convolutional Neural Network (CNN) encoder for processing image inputs.
- Parameters:
embedding_size (int)
activation_function (str)
- class world_models.vision.planet_decoder.CNNDecoder(state_size, latent_size, embedding_size, activation_function='relu')[source]#
Bases: Module
A Convolutional Neural Network (CNN) decoder for reconstructing image outputs.
- Parameters:
state_size (int)
latent_size (int)
embedding_size (int)
activation_function (str)
- class world_models.vision.iris_encoder.IRISEncoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, in_channels=3, base_channels=64, num_residual_blocks=2, frame_shape=(3, 64, 64))[source]#
Bases: Module
CNN Encoder for IRIS discrete autoencoder.
Encodes image observations into latent features, which are then quantized into discrete tokens using the VectorQuantizer.
- Architecture:
4 convolutional layers with residual blocks
Self-attention at 8x8 and 16x16 resolutions
Vector quantization to produce discrete tokens
- Parameters:
vocab_size (int)
tokens_per_frame (int)
embedding_dim (int)
in_channels (int)
base_channels (int)
num_residual_blocks (int)
frame_shape (Tuple[int, int, int])
- forward(x)[source]#
Encode images to discrete tokens.
- Parameters:
x (Tensor) – Input images (B, C, H, W) - should be 64x64
- Returns:
z_q: Quantized tokens (B, C, H’, W’)
indices: Token indices (B, H’, W’)
vq_loss: Dictionary with VQ loss components
- class world_models.vision.iris_encoder.ResidualBlock(channels)[source]#
Bases: Module
Residual block for encoder.
- Parameters:
channels (int)
- class world_models.vision.iris_encoder.SelfAttentionBlock(channels)[source]#
Bases: Module
Self-attention block for encoder.
Applies spatial self-attention to capture long-range dependencies.
- Parameters:
channels (int)
- class world_models.vision.iris_decoder.IRISDecoder(vocab_size=512, embedding_dim=512, base_channels=32, out_channels=3, frame_shape=(3, 64, 64))[source]#
Bases: Module
CNN Decoder for IRIS discrete autoencoder.
Decodes discrete tokens back into image observations. Uses transposed convolutions to upsample from 4x4 to 64x64.
- Parameters:
vocab_size (int)
embedding_dim (int)
base_channels (int)
out_channels (int)
frame_shape (Tuple[int, int, int])
- forward(z)[source]#
Decode tokens to images.
- Parameters:
z (Tensor) – Token embeddings (B, C, H, W) - e.g., (B, 512, 4, 4)
- Returns:
reconstructed: Reconstructed images (B, C, H, W) - e.g., (B, 3, 64, 64)
- class world_models.vision.iris_decoder.UpsampleBlock(in_channels, mid_channels, out_channels)[source]#
Bases: Module
Upsampling block with optional residual connection.
- Parameters:
in_channels (int)
mid_channels (int)
out_channels (int)
- class world_models.vision.iris_decoder.ResidualBlock(channels)[source]#
Bases: Module
Residual block for decoder.
- Parameters:
channels (int)
- class world_models.vision.iris_decoder.DiscreteAutoencoder(vocab_size=512, tokens_per_frame=16, embedding_dim=512, base_channels=64, frame_shape=(3, 64, 64))[source]#
Bases: Module
Complete Discrete Autoencoder combining encoder and decoder.
Used for training the VQVAE component of IRIS.
- Parameters:
vocab_size (int)
tokens_per_frame (int)
embedding_dim (int)
base_channels (int)
frame_shape (Tuple[int, int, int])
- class world_models.vision.vq_layer.VectorQuantizer(vocab_size=512, embedding_dim=512, commitment_weight=0.25)[source]#
Bases: Module
Vector Quantizer for discrete autoencoder.
Implements the VQ-VAE quantization from: “Neural Discrete Representation Learning” (Van Den Oord et al., 2017)
Uses exponential moving averages for codebook updates and straight-through estimator for gradient flow.
- Parameters:
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
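The core of vector quantization is a nearest-neighbour lookup in the codebook plus a weighted commitment penalty. The sketch below shows this on Python lists for a single latent vector. It is a simplified, hypothetical illustration: the real module works on batched tensors, and the straight-through gradient trick is described only in a comment because plain floats carry no gradients.

```python
def squared_distance(a, b) -> float:
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(z, codebook, commitment_weight: float = 0.25):
    """Return (index, z_q, commitment_loss) for one encoder output z."""
    index = min(range(len(codebook)), key=lambda i: squared_distance(z, codebook[i]))
    z_q = codebook[index]
    # Mean squared error between the encoder output and its selected code.
    mse = squared_distance(z, z_q) / len(z)
    # In an autograd framework the decoder would receive z + stop_gradient(z_q - z),
    # so gradients pass straight through to the encoder.
    return index, z_q, commitment_weight * mse


codebook = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
index, z_q, commitment_loss = quantize([0.9, 1.2], codebook)
```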
- class world_models.vision.vq_layer.VectorQuantizerEMA(vocab_size=512, embedding_dim=512, commitment_weight=0.25, ema_decay=0.99, epsilon=1e-05)[source]#
Bases: Module
Vector Quantizer with Exponential Moving Average updates.
Uses EMA updates for the codebook instead of gradient-based updates, which leads to more stable training.
- Parameters:
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
ema_decay (float)
epsilon (float)
- class world_models.vision.video_tokenizer.VideoTokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, commitment_weight=0.25, use_ema=False, ema_decay=0.99)[source]#
Bases: Module
Video Tokenizer using VQ-VAE with Spatiotemporal Transformer.
This is a core component of Genie (Google DeepMind, 2024), used to compress raw video frames into discrete latent tokens that can be processed by downstream models like the LatentActionModel and DynamicsModel.
The tokenizer uses Vector Quantized Variational Autoencoder (VQ-VAE) objective to learn a discrete codebook of video representations. Unlike standard VQ-VAE, this uses a Spatiotemporal (ST) Transformer in both encoder and decoder to better capture temporal dynamics in videos.
Architecture
Patch Embedding: Convert (B, C, T, H, W) video to patch tokens
Encoder ST-Transformer: Process spatial-temporal patches
Vector Quantization: Discretize continuous embeddings to codebook entries
Decoder ST-Transformer: Reconstruct video from quantized tokens
Patch Unembedding: Convert tokens back to video frames
Key Features
Causal processing: Each frame’s encoding only uses previous frames
Discrete tokens: Enables autoregressive prediction with latent actions
Memory efficient: Uses ST-Transformer instead of full ViT to reduce complexity
Usage with Genie:
tokenizer = VideoTokenizer(
    num_frames=16,
    image_size=64,
    patch_size=4,
    vocab_size=1024,
    embedding_dim=32
)
reconstructed, indices, loss_dict = tokenizer(video_frames)
# For discrete token input to dynamics model:
token_embeddings = tokenizer.decode_indices(indices)
The tokenizer is trained with a VQ-VAE objective:
- Reconstruction loss: MSE between input and reconstructed video
- VQ (codebook) loss: Moves codebook embeddings toward the encoder outputs
- Commitment loss: Penalizes encoder outputs drifting from codebook
- Reference:
Genie: Generative Interactive Environments Bruce et al., Google DeepMind, 2024 - https://arxiv.org/abs/2402.15391
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
use_ema (bool)
ema_decay (float)
- encode(x)[source]#
Encode video to discrete tokens.
- Parameters:
x (Tensor) – Video tensor (B, C, T, H, W)
- Returns:
z_q: Quantized embeddings (B, T, H’, W’, embedding_dim)
indices: Token indices (B, T, H’, W’)
vq_loss: Dictionary with VQ loss components
- decode_indices(indices)[source]#
Decode token indices to embeddings for video frames.
- Parameters:
indices (Tensor) – Token indices (B, T, H’, W’) or (B, T, N) where N = H’ x W’
- Returns:
z_q: Quantized embeddings (B, T, H’, W’, embedding_dim)
- world_models.vision.video_tokenizer.create_video_tokenizer(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False)[source]#
Factory function to create a Video Tokenizer.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
use_ema (bool)
- Return type:
Blocks sub-module - Transformer blocks and attention mechanisms.
- Exported Components:
- Transformers:
STTransformer: Spatiotemporal Transformer for video processing
STSpatialAttention: Spatial attention layer
STTemporalAttention: Temporal attention layer
STTransformerBlock: Combined spatiotemporal transformer block
STBlock: Backwards-compatible alias for STTransformerBlock
- Attention:
MultiHeadSelfAttention: Multi-head self-attention
MultiHeadAttention: Backwards-compatible alias for MultiHeadSelfAttention
Attention: Attention mechanism
- Normalization:
RMSNorm: Root Mean Square Layer Normalization
AdaLNNormalization: Adaptive Layer Normalization
- class world_models.blocks.mhsa.MultiHeadSelfAttention(d, n_heads=2)[source]#
Bases: Module
Multi-head scaled dot-product self-attention over sequence tokens.
This module projects the input sequence into query/key/value heads, performs attention independently per head, and merges the heads back into the original feature dimension. It is used as a lightweight transformer attention block.
- Parameters:
d (int)
n_heads (int)
- class world_models.blocks.st_transformer.STSpatialAttention(dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)[source]#
Bases: Module
Spatial attention layer for spatiotemporal transformer.
Processes video tokens by attending over spatial positions (H*W) within each time step independently. Captures within-frame spatial relationships.
Input: (B, T, N, C) – B batches, T time steps, N spatial positions (H*W), C channels
Output: (B, T, N, C) – Same shape, spatially attended features
Architecture
QKV projection: Linear(dim, dim*3)
Reshape to multi-head attention format
Fused scaled dot-product attention (FlashAttention on supported GPUs)
Output projection
Applied to video tokens of shape (B, T, N, C) to capture within-frame spatial structure (e.g., object positions).
- Parameters:
dim (int)
num_heads (int)
qkv_bias (bool)
qk_scale (float | None)
attn_drop (float)
proj_drop (float)
- class world_models.blocks.st_transformer.STTemporalAttention(dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)[source]#
Bases: Module
Temporal attention layer with causal masking for spatiotemporal transformer.
Processes video tokens by attending over time steps (T) across all spatial positions. Uses causal masking to ensure each frame only attends to previous frames (important for autoregressive video generation).
Input: (B, T, N, C) – B batches, T time steps, N spatial positions, C channels
Output: (B, T, N, C) – Same shape, temporally attended features
Causal masking
Frame t can only attend to frames 0…t-1
Prevents information leakage from future frames
Essential for autoregressive video generation models
Applied after STSpatialAttention to model temporal dynamics in the Genie VideoTokenizer.
- Parameters:
dim (int)
num_heads (int)
qkv_bias (bool)
qk_scale (float | None)
attn_drop (float)
proj_drop (float)
- causal_mask: Tensor#
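A causal mask over T frames is a lower-triangular boolean matrix. The sketch below builds one as nested lists. Whether the library's `causal_mask` buffer lets a frame attend to itself is not stated in the source, so the hypothetical `include_current` flag exposes both variants; the default follows the common convention of including the current frame.

```python
def causal_mask(num_frames: int, include_current: bool = True) -> list:
    """mask[t][s] is True when target frame t may attend to source frame s."""
    offset = 0 if include_current else 1
    return [
        [src + offset <= tgt for src in range(num_frames)]
        for tgt in range(num_frames)
    ]


for row in causal_mask(4):
    print(["x" if allowed else "." for allowed in row])
```

With `include_current=False` the first frame has no allowed positions, which an attention implementation would need to handle explicitly to avoid an empty softmax.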
- class world_models.blocks.st_transformer.STMLP(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>, drop=0.0)[source]#
Bases: Module
MLP for ST-Transformer block.
- Parameters:
in_features (int)
hidden_features (int | None)
out_features (int | None)
act_layer (type[Module])
drop (float)
- class world_models.blocks.st_transformer.STTransformerBlock(dim, num_heads=8, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]#
Bases: Module
Combined spatiotemporal transformer block with interleaved attention.
A single block applies:
Spatial attention (within each time frame)
Temporal attention (across frames with causal mask)
MLP projection
The order is: x -> + SpatialAttn -> + TemporalAttn -> + MLP -> x
This interleaved design captures both spatial structure and temporal dynamics efficiently, used in Genie’s VideoTokenizer and DynamicsModel.
- Parameters:
dim (int) – Feature dimension (must match patch embedding dimension)
num_heads (int) – Number of attention heads
mlp_ratio (float) – MLP hidden dim = dim * mlp_ratio
drop (float) – Dropout rates
attn_drop (float) – Dropout rates
drop_path (float) – Stochastic depth rate for drop path regularization
norm_layer (type[Module]) – Normalization layer class (default: nn.LayerNorm)
qkv_bias (bool)
qk_scale (float | None)
act_layer (type[Module])
Usage in Genie:
# VideoTokenizer encoder (12 layers)
encoder = STTransformer(
    num_frames=16, num_patches_per_frame=256,
    dim=512, depth=12, num_heads=16
)
encoded = encoder(tokens)  # (B, T*N, C)
# Dynamics model decoder (24 layers)
decoder = STTransformer(
    num_frames=16, num_patches_per_frame=256,
    dim=1024, depth=24, num_heads=16
)
decoded = decoder(tokens)
- class world_models.blocks.st_transformer.DropPath(drop_prob=0.0)[source]#
Bases: Module
Drop paths (Stochastic Depth) per sample.
- Parameters:
drop_prob (float)
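Stochastic depth drops the whole residual branch for a sample with probability `drop_prob` and rescales surviving samples so the expected value is unchanged. The sketch below applies this to one sample represented as a list of floats. It follows the widely used formulation and is an assumption about the library's behaviour; `drop_path` is a hypothetical standalone function.

```python
import random


def drop_path(sample, drop_prob: float, training: bool, rng: random.Random):
    """Zero the entire sample with probability drop_prob, else scale by 1 / keep_prob."""
    if not training or drop_prob == 0.0:
        return list(sample)  # identity at inference time
    keep_prob = 1.0 - drop_prob
    if rng.random() >= keep_prob:
        return [0.0 for _ in sample]
    return [value / keep_prob for value in sample]


rng = random.Random(0)
branch_output = drop_path([1.0, 2.0], drop_prob=0.5, training=True, rng=rng)
```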
- class world_models.blocks.st_transformer.STTransformer(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, gradient_checkpointing=False)[source]#
Bases: Module
Spatiotemporal Transformer for video modeling.
Contains L spatiotemporal blocks with interleaved spatial and temporal attention.
- Parameters:
num_frames (int)
num_patches_per_frame (int)
dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
qk_scale (float | None)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
norm_layer (type[Module])
gradient_checkpointing (bool)
- world_models.blocks.st_transformer.create_st_transformer(num_frames=16, patch_size=4, img_size=64, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, gradient_checkpointing=False)[source]#
Factory function to create an ST-Transformer.
- Parameters:
num_frames (int)
patch_size (int)
img_size (int)
dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
gradient_checkpointing (bool)
- Return type:
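`create_st_transformer` takes `img_size` and `patch_size`, whereas `STTransformer` takes `num_patches_per_frame`. Presumably the factory derives one from the others. The sketch below assumes square images split into non-overlapping square patches, which reproduces the `STTransformer` default of 256 from the factory defaults. The helper name is hypothetical.

```python
def patches_per_frame(img_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    per_side = img_size // patch_size
    return per_side * per_side


n = patches_per_frame(img_size=64, patch_size=4)
print(n)       # 256 patches per frame
print(16 * n)  # 4096 tokens for a 16-frame clip
```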
Configuration objects#
Lazy config exports.
Configuration modules can have optional training dependencies, so the package initializer avoids importing every config eagerly.
- class world_models.configs.DreamerConfig(env_backend='dmc', env='walker-walk', env_instance=None, image_size=(64, 64), gym_render_mode='rgb_array', dmlab_action_repeat=4, dmlab_action_set=None, dmlab_observations=None, dmlab_config=None, dmlab_renderer='hardware', procgen_distribution_mode='easy', procgen_num_levels=0, procgen_start_level=None, mujoco_xml_path=None, mujoco_xml_string=None, mujoco_binary_path=None, mujoco_camera=None, mujoco_frame_skip=1, mujoco_reset_noise_scale=0.0, brax_backend='generalized', brax_jit=True, brax_auto_reset=False, brax_suppress_warp_warnings=True, unity_file_name=None, unity_behavior_name=None, unity_worker_id=0, unity_base_port=5005, unity_no_graphics=True, unity_time_scale=20.0, unity_quality_level=1, algo='Dreamerv1', exp_name='lr1e-3', train=True, evaluate=False, seed=1, no_gpu=False, max_episode_length=1000, buffer_size=800000, time_limit=1000, cnn_activation_function='relu', dense_activation_function='elu', obs_embed_size=1024, num_units=400, deter_size=200, stoch_size=30, action_repeat=2, action_noise=0.3, total_steps=5000000, seed_steps=5000, update_steps=100, collect_steps=1000, batch_size=50, train_seq_len=50, imagine_horizon=15, use_disc_model=False, free_nats=3.0, discount=0.99, td_lambda=0.95, kl_loss_coeff=1.0, kl_alpha=0.8, disc_loss_coeff=10.0, num_buckets=255, symlog_range=10.0, model_learning_rate=0.0006, actor_learning_rate=8e-05, value_learning_rate=8e-05, adam_epsilon=1e-07, grad_clip_norm=100.0, use_amp=True, test=False, test_interval=10000, test_episodes=10, scalar_freq=1000, log_video_freq=-1, max_videos_to_save=2, video_format='gif', video_fps=20, checkpoint_interval=10000, checkpoint_path='', restore=False, experience_replay='', render=False, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', log_dir='runs', logdir=None, data_dir=None, log_level='INFO', log_file=None, enable_tensorboard=False, enable_console_metrics=True, enable_jsonl=True, jsonl_filename='metrics.jsonl', 
log_system_stats_freq=1000, detect_anomaly=False)[source]#
Bases: SerializableConfigMixin
Configuration container for Dreamer training, evaluation, and environment setup.
This class centralizes environment backend selection (DMC/DMLab/Gym/MuJoCo/Robotics/Unity/Brax), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.
- Parameters:
env_backend (str)
env (str)
env_instance (Any)
image_size (tuple[int, int])
gym_render_mode (str)
dmlab_action_repeat (int)
dmlab_action_set (Any)
dmlab_observations (Any)
dmlab_config (Any)
dmlab_renderer (str)
procgen_distribution_mode (str)
procgen_num_levels (int)
procgen_start_level (Any)
mujoco_xml_path (Any)
mujoco_xml_string (Any)
mujoco_binary_path (Any)
mujoco_camera (Any)
mujoco_frame_skip (int)
mujoco_reset_noise_scale (float)
brax_backend (str)
brax_jit (bool)
brax_auto_reset (bool)
brax_suppress_warp_warnings (bool)
unity_file_name (Any)
unity_behavior_name (Any)
unity_worker_id (int)
unity_base_port (int)
unity_no_graphics (bool)
unity_time_scale (float)
unity_quality_level (int)
algo (str)
exp_name (str)
train (bool)
evaluate (bool)
seed (int)
no_gpu (bool)
max_episode_length (int)
buffer_size (int)
time_limit (int)
cnn_activation_function (str)
dense_activation_function (str)
obs_embed_size (int)
num_units (int)
deter_size (int)
stoch_size (int)
action_repeat (int)
action_noise (float)
total_steps (int)
seed_steps (int)
update_steps (int)
collect_steps (int)
batch_size (int)
train_seq_len (int)
imagine_horizon (int)
use_disc_model (bool)
free_nats (float)
discount (float)
td_lambda (float)
kl_loss_coeff (float)
kl_alpha (float)
disc_loss_coeff (float)
num_buckets (int)
symlog_range (float)
model_learning_rate (float)
actor_learning_rate (float)
value_learning_rate (float)
adam_epsilon (float)
grad_clip_norm (float)
use_amp (bool)
test (bool)
test_interval (int)
test_episodes (int)
scalar_freq (int)
log_video_freq (int)
max_videos_to_save (int)
video_format (str)
video_fps (int)
checkpoint_interval (int)
checkpoint_path (str)
restore (bool)
experience_replay (str)
render (bool)
enable_wandb (bool)
wandb_api_key (str)
wandb_project (str)
wandb_entity (str)
log_dir (str)
logdir (Any)
data_dir (Any)
log_level (str)
log_file (Any)
enable_tensorboard (bool)
enable_console_metrics (bool)
enable_jsonl (bool)
jsonl_filename (str)
log_system_stats_freq (int)
detect_anomaly (bool)
- env_backend: str = 'dmc'#
- env: str = 'walker-walk'#
- env_instance: Any = None#
- image_size: tuple[int, int] = (64, 64)#
- gym_render_mode: str = 'rgb_array'#
- dmlab_action_repeat: int = 4#
- dmlab_action_set: Any = None#
- dmlab_observations: Any = None#
- dmlab_config: Any = None#
- dmlab_renderer: str = 'hardware'#
- procgen_distribution_mode: str = 'easy'#
- procgen_num_levels: int = 0#
- procgen_start_level: Any = None#
- mujoco_xml_path: Any = None#
- mujoco_xml_string: Any = None#
- mujoco_binary_path: Any = None#
- mujoco_camera: Any = None#
- mujoco_frame_skip: int = 1#
- mujoco_reset_noise_scale: float = 0.0#
- brax_backend: str = 'generalized'#
- brax_jit: bool = True#
- brax_auto_reset: bool = False#
- brax_suppress_warp_warnings: bool = True#
- unity_file_name: Any = None#
- unity_behavior_name: Any = None#
- unity_worker_id: int = 0#
- unity_base_port: int = 5005#
- unity_no_graphics: bool = True#
- unity_time_scale: float = 20.0#
- unity_quality_level: int = 1#
- algo: str = 'Dreamerv1'#
- exp_name: str = 'lr1e-3'#
- train: bool = True#
- evaluate: bool = False#
- seed: int = 1#
- no_gpu: bool = False#
- max_episode_length: int = 1000#
- buffer_size: int = 800000#
- time_limit: int = 1000#
- cnn_activation_function: str = 'relu'#
- dense_activation_function: str = 'elu'#
- obs_embed_size: int = 1024#
- num_units: int = 400#
- deter_size: int = 200#
- stoch_size: int = 30#
- action_repeat: int = 2#
- action_noise: float = 0.3#
- total_steps: int = 5000000#
- seed_steps: int = 5000#
- update_steps: int = 100#
- collect_steps: int = 1000#
- batch_size: int = 50#
- train_seq_len: int = 50#
- imagine_horizon: int = 15#
- use_disc_model: bool = False#
- free_nats: float = 3.0#
- discount: float = 0.99#
- td_lambda: float = 0.95#
- kl_loss_coeff: float = 1.0#
- kl_alpha: float = 0.8#
- disc_loss_coeff: float = 10.0#
- num_buckets: int = 255#
- symlog_range: float = 10.0#
- model_learning_rate: float = 0.0006#
- actor_learning_rate: float = 8e-05#
- value_learning_rate: float = 8e-05#
- adam_epsilon: float = 1e-07#
- grad_clip_norm: float = 100.0#
- use_amp: bool = True#
- test: bool = False#
- test_interval: int = 10000#
- test_episodes: int = 10#
- scalar_freq: int = 1000#
- log_video_freq: int = -1#
- max_videos_to_save: int = 2#
- video_format: str = 'gif'#
- video_fps: int = 20#
- checkpoint_interval: int = 10000#
- checkpoint_path: str = ''#
- restore: bool = False#
- experience_replay: str = ''#
- render: bool = False#
- enable_wandb: bool = False#
- wandb_api_key: str = ''#
- wandb_project: str = 'torchwm'#
- wandb_entity: str = ''#
- log_dir: str = 'runs'#
- logdir: Any = None#
- data_dir: Any = None#
- log_level: str = 'INFO'#
- log_file: Any = None#
- enable_tensorboard: bool = False#
- enable_console_metrics: bool = True#
- enable_jsonl: bool = True#
- jsonl_filename: str = 'metrics.jsonl'#
- log_system_stats_freq: int = 1000#
- detect_anomaly: bool = False#
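`discount`, `td_lambda`, and `imagine_horizon` are the usual ingredients of the λ-return that Dreamer-style agents compute over imagined rollouts. The sketch below is the textbook recursion written in plain Python, not code taken from DreamerAgent. How the library combines these fields internally is an assumption, and the function name is hypothetical.

```python
def lambda_returns(rewards, next_values, discount=0.99, td_lambda=0.95):
    """R_t = r_t + discount * ((1 - td_lambda) * v_{t+1} + td_lambda * R_{t+1}).

    rewards[t] is the reward at imagined step t and next_values[t] is the
    value estimate of the state reached after step t.
    """
    returns = [0.0] * len(rewards)
    next_return = next_values[-1]  # bootstrap beyond the horizon
    for t in reversed(range(len(rewards))):
        blended = (1.0 - td_lambda) * next_values[t] + td_lambda * next_return
        returns[t] = rewards[t] + discount * blended
        next_return = returns[t]
    return returns


horizon = 15  # matches the imagine_horizon default
print(lambda_returns([1.0] * horizon, [0.0] * horizon)[0])
```

With `td_lambda=0` the recursion reduces to one-step TD targets, and with `td_lambda=1` it becomes the discounted Monte Carlo return with a bootstrap at the horizon.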
- class world_models.configs.JEPAConfig[source]#
Bases: SerializableConfigMixin
Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.
- classmethod from_dict(values)[source]#
Load flat field values or the nested trainer dictionary.
- Parameters:
values (Dict[str, Any])
- Return type:
- class world_models.configs.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#
Bases: SerializableConfigMixin
Default configuration values for Diffusion Transformer (DiT) training.
The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.
Field names use UPPER_CASE for backward compatibility with the original DiT codebase. Snake-case aliases are accepted via __getattr__ and get_dit_config().
- Parameters:
DATASET (str)
BATCH (int)
EPOCHS (int)
LR (float)
IMG_SIZE (int)
CHANNELS (int)
PATCH (int)
WIDTH (int)
DEPTH (int)
HEADS (int)
DROP (float)
BETA_START (float)
BETA_END (float)
TIMESTEPS (int)
EMA (bool)
EMA_DECAY (float)
WORKDIR (str)
ROOT_PATH (str)
- DATASET: str = 'CIFAR10'#
- BATCH: int = 128#
- EPOCHS: int = 3#
- LR: float = 0.0002#
- IMG_SIZE: int = 32#
- CHANNELS: int = 3#
- PATCH: int = 4#
- WIDTH: int = 384#
- DEPTH: int = 6#
- HEADS: int = 6#
- DROP: float = 0.1#
- BETA_START: float = 0.0001#
- BETA_END: float = 0.02#
- TIMESTEPS: int = 1000#
- EMA: bool = True#
- EMA_DECAY: float = 0.999#
- WORKDIR: str = './dit_demo'#
- ROOT_PATH: str = './data'#
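A few quantities follow directly from the DiT defaults above. The token count and per-head width are plain arithmetic. The noise schedule is a sketch that assumes a linear spacing between `BETA_START` and `BETA_END`, as in the original DDPM formulation, and this page does not confirm which schedule the trainer uses.

```python
IMG_SIZE, PATCH, WIDTH, HEADS = 32, 4, 384, 6
BETA_START, BETA_END, TIMESTEPS = 1e-4, 0.02, 1000

tokens = (IMG_SIZE // PATCH) ** 2  # patches per image
head_dim = WIDTH // HEADS          # channels per attention head


def linear_betas(start, end, steps):
    """Evenly spaced noise variances from `start` to `end` (assumed schedule)."""
    step = (end - start) / (steps - 1)
    return [start + i * step for i in range(steps)]


betas = linear_betas(BETA_START, BETA_END, TIMESTEPS)

# Cumulative signal retention after the final step: product of (1 - beta_t).
alpha_bar = 1.0
for beta in betas:
    alpha_bar *= 1.0 - beta

print(tokens, head_dim)  # 64 64
print(alpha_bar)         # close to zero: the final step is almost pure noise
```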
- world_models.configs.get_dit_config(**overrides)[source]#
Returns a DiTConfig instance with default values overridden by the provided keyword arguments.
Both UPPER_CASE and snake_case override keys are accepted.
- Example usage:
cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)
cfg = get_dit_config(batch=64, epochs=10, lr=1e-3)
- Parameters:
overrides (Any)
- Return type:
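One way to accept both UPPER_CASE and snake_case override keys is to normalize each key before constructing the config. The sketch below uses a hypothetical three-field stand-in (`ToyDiTConfig`) and a hypothetical helper (`toy_get_config`). It illustrates the idea and is not the implementation of `get_dit_config`.

```python
from dataclasses import dataclass, fields


@dataclass
class ToyDiTConfig:  # hypothetical stand-in with three of the documented fields
    BATCH: int = 128
    EPOCHS: int = 3
    LR: float = 0.0002


def toy_get_config(**overrides):
    """Upper-case every override key, then reject names that are not fields."""
    valid = {f.name for f in fields(ToyDiTConfig)}
    normalized = {}
    for key, value in overrides.items():
        name = key.upper()
        if name not in valid:
            raise TypeError(f"unknown config field: {key!r}")
        normalized[name] = value
    return ToyDiTConfig(**normalized)


print(toy_get_config(batch=64, LR=1e-3))
```

Rejecting unknown keys, rather than ignoring them, keeps a misspelled override from silently falling back to the default.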
- class world_models.configs.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', seed: int = 0, obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, data_loader_num_workers: int = 4, pin_memory: bool = True, persistent_workers: bool = True, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, use_amp: bool = True, device: str = <factory>, log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, operator_state_dim: int = 32, operator_action_dim: int = 4)[source]#
Bases: SerializableConfigMixin
- Parameters:
preset (str | None)
game (str)
seed (int)
obs_size (int)
frameskip (int)
max_noop (int)
terminate_on_life_loss (bool)
reward_clip (List[int])
num_conditioning_frames (int)
diffusion_channels (List[int])
diffusion_res_blocks (int)
diffusion_cond_dim (int)
sigma_data (float)
sigma_min (float)
sigma_max (float)
rho (int)
p_mean (float)
p_std (float)
sampling_method (str)
num_sampling_steps (int)
reward_channels (List[int])
reward_res_blocks (int)
reward_cond_dim (int)
reward_lstm_dim (int)
burn_in_length (int)
actor_channels (List[int])
actor_res_blocks (int)
actor_lstm_dim (int)
num_epochs (int)
training_steps_per_epoch (int)
batch_size (int)
environment_steps_per_epoch (int)
epsilon_greedy (float)
data_loader_num_workers (int)
pin_memory (bool)
persistent_workers (bool)
imagination_horizon (int)
discount_factor (float)
entropy_weight (float)
lambda_returns (float)
learning_rate (float)
adam_epsilon (float)
weight_decay_diffusion (float)
weight_decay_reward (float)
weight_decay_actor (float)
use_amp (bool)
device (str)
log_interval (int)
eval_interval (int)
save_interval (int)
operator_state_dim (int)
operator_action_dim (int)
- preset: str | None = None#
- game: str = 'Breakout-v5'#
- seed: int = 0#
- obs_size: int = 64#
- frameskip: int = 4#
- max_noop: int = 30#
- terminate_on_life_loss: bool = True#
- reward_clip: List[int]#
- num_conditioning_frames: int = 4#
- diffusion_channels: List[int]#
- diffusion_res_blocks: int = 2#
- diffusion_cond_dim: int = 256#
- sigma_data: float = 0.5#
- sigma_min: float = 0.002#
- sigma_max: float = 80.0#
- rho: int = 7#
- p_mean: float = -0.4#
- p_std: float = 1.2#
- sampling_method: str = 'euler'#
- num_sampling_steps: int = 3#
- reward_channels: List[int]#
- reward_res_blocks: int = 2#
- reward_cond_dim: int = 128#
- reward_lstm_dim: int = 512#
- burn_in_length: int = 4#
- actor_channels: List[int]#
- actor_res_blocks: int = 1#
- actor_lstm_dim: int = 512#
- num_epochs: int = 1000#
- training_steps_per_epoch: int = 400#
- batch_size: int = 32#
- environment_steps_per_epoch: int = 100#
- epsilon_greedy: float = 0.01#
- data_loader_num_workers: int = 4#
- pin_memory: bool = True#
- persistent_workers: bool = True#
- imagination_horizon: int = 15#
- discount_factor: float = 0.985#
- entropy_weight: float = 0.001#
- lambda_returns: float = 0.95#
- learning_rate: float = 0.0001#
- adam_epsilon: float = 1e-08#
- weight_decay_diffusion: float = 0.01#
- weight_decay_reward: float = 0.01#
- weight_decay_actor: float = 0.0#
- use_amp: bool = True#
- device: str#
- log_interval: int = 10#
- eval_interval: int = 50#
- save_interval: int = 100#
- operator_state_dim: int = 32#
- operator_action_dim: int = 4#
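The names `sigma_min`, `sigma_max`, and `rho` match the noise-level discretization used in EDM-style diffusion samplers. Assuming the library follows that formulation (not confirmed on this page), the noise levels visited by `num_sampling_steps` sampling steps could be computed as below. The function name is hypothetical, and some samplers append a final sigma of zero that is omitted here.

```python
def karras_sigmas(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7):
    """Noise levels interpolated linearly in sigma**(1/rho), from sigma_max down to sigma_min."""
    if num_steps == 1:
        return [sigma_max]
    hi = sigma_max ** (1.0 / rho)
    lo = sigma_min ** (1.0 / rho)
    return [
        (hi + i / (num_steps - 1) * (lo - hi)) ** rho
        for i in range(num_steps)
    ]


# Defaults from DiamondConfig: three sampling steps.
for sigma in karras_sigmas(3):
    print(sigma)
```

A larger `rho` concentrates more of the steps near `sigma_min`, which matters when only a handful of steps are available.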
- class world_models.configs.IRISConfig[source]#
Bases: SerializableConfigMixin
Configuration for IRIS (Imagination with auto-Regression over an Inner Speech).
Based on the paper “Transformers are Sample-Efficient World Models”. Implements a discrete autoencoder plus an autoregressive Transformer for sample-efficient RL.
- class world_models.configs.GenieConfig(num_frames=8, image_size=32, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=256, action_encoder_depth=4, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#
Bases: SerializableConfigMixin
Configuration for Genie model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
tokenizer_num_heads (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_num_heads (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
sample_temperature (float)
maskgit_steps (int)
- num_frames: int = 8#
- image_size: int = 32#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 256#
- tokenizer_decoder_dim: int = 512#
- tokenizer_encoder_depth: int = 4#
- tokenizer_decoder_depth: int = 8#
- tokenizer_num_heads: int = 8#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 256#
- action_encoder_depth: int = 4#
- action_num_heads: int = 8#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 4#
- learning_rate: float = 3e-05#
- weight_decay: float = 0.0001#
- warmup_steps: int = 5000#
- max_steps: int = 125000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- sample_temperature: float = 2.0#
- maskgit_steps: int = 25#
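`mask_prob_min`, `mask_prob_max`, and `maskgit_steps` suggest MaskGIT-style training and decoding. The sketch below assumes that the training mask ratio is drawn uniformly between the two bounds and that decoding unmasks tokens on a cosine schedule. Both assumptions come from the MaskGIT recipe, not from this page, and the helper names and the 256-token count are hypothetical.

```python
import math
import random


def sample_mask_ratio(rng, lo=0.5, hi=1.0):
    """Fraction of tokens to mask for one training example."""
    return rng.uniform(lo, hi)


def masked_after_step(step, total_steps, num_tokens):
    """Tokens still masked after `step` of `total_steps` under a cosine schedule."""
    ratio = math.cos(math.pi / 2 * step / total_steps)
    return int(math.floor(ratio * num_tokens))


rng = random.Random(0)
print(sample_mask_ratio(rng))

# Decoding with the maskgit_steps default of 25.
schedule = [masked_after_step(k, 25, 256) for k in range(26)]
print(schedule[0], schedule[-1])  # 256 0
```

The cosine schedule reveals few tokens in the early steps and many in the late ones, when more context is available.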
- class world_models.configs.GenieSmallConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=512, action_encoder_depth=8, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=2, learning_rate=0.0001, weight_decay=0.0001, warmup_steps=1000, max_steps=50000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#
Bases: SerializableConfigMixin
Small configuration for development/testing.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
tokenizer_num_heads (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_num_heads (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
sample_temperature (float)
maskgit_steps (int)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 256#
- tokenizer_decoder_dim: int = 512#
- tokenizer_encoder_depth: int = 4#
- tokenizer_decoder_depth: int = 8#
- tokenizer_num_heads: int = 8#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 512#
- action_encoder_depth: int = 8#
- action_num_heads: int = 8#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 2#
- learning_rate: float = 0.0001#
- weight_decay: float = 0.0001#
- warmup_steps: int = 1000#
- max_steps: int = 50000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- sample_temperature: float = 2.0#
- maskgit_steps: int = 25#
- class world_models.configs.STTransformerConfig(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases: SerializableConfigMixin
Configuration for Spatiotemporal Transformer.
- Parameters:
num_frames (int)
num_patches_per_frame (int)
dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- num_frames: int = 16#
- num_patches_per_frame: int = 256#
- dim: int = 768#
- depth: int = 12#
- num_heads: int = 12#
- mlp_ratio: float = 4.0#
- qkv_bias: bool = True#
- drop_rate: float = 0.0#
- attn_drop_rate: float = 0.0#
- drop_path_rate: float = 0.0#
- class world_models.configs.VideoTokenizerConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False, ema_decay=0.99, commitment_weight=0.25)[source]#
Bases: SerializableConfigMixin
Configuration for Video Tokenizer.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
use_ema (bool)
ema_decay (float)
commitment_weight (float)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- encoder_dim: int = 512#
- decoder_dim: int = 1024#
- encoder_depth: int = 12#
- decoder_depth: int = 20#
- num_heads: int = 16#
- patch_size: int = 4#
- vocab_size: int = 1024#
- embedding_dim: int = 32#
- use_ema: bool = False#
- ema_decay: float = 0.99#
- commitment_weight: float = 0.25#
- class world_models.configs.LatentActionModelConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=1024, encoder_depth=20, num_heads=16, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#
Bases: SerializableConfigMixin
Configuration for Latent Action Model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
encoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- encoder_dim: int = 1024#
- encoder_depth: int = 20#
- num_heads: int = 16#
- patch_size: int = 16#
- vocab_size: int = 8#
- embedding_dim: int = 32#
- commitment_weight: float = 1.0#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- class world_models.configs.DynamicsModelConfig(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases: SerializableConfigMixin
Configuration for Dynamics Model.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- num_frames: int = 16#
- image_size: int = 64#
- vocab_size: int = 1024#
- embedding_dim: int = 32#
- action_vocab_size: int = 8#
- dim: int = 5120#
- depth: int = 48#
- num_heads: int = 36#
- patch_size: int = 4#
- mlp_ratio: float = 4.0#
- qkv_bias: bool = True#
- drop_rate: float = 0.0#
- attn_drop_rate: float = 0.0#
- drop_path_rate: float = 0.0#
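Multi-head attention usually splits the model width evenly across heads. The check below applies that rule to the default widths and head counts documented above, and the helper name is hypothetical. The `DynamicsModelConfig` defaults (`dim=5120`, `num_heads=36`) leave a remainder of 8. This page does not say how the attention layers handle that, so it may be worth verifying before instantiating the default.

```python
def head_split(dim: int, num_heads: int) -> tuple[int, int]:
    """Return (channels per head, leftover channels)."""
    return divmod(dim, num_heads)


# (dim, num_heads) defaults taken from the config classes on this page.
defaults = {
    "STTransformerConfig": (768, 12),
    "VideoTokenizerConfig (encoder)": (512, 16),
    "LatentActionModelConfig": (1024, 16),
    "DynamicsModelConfig": (5120, 36),
}

for name, (dim, heads) in defaults.items():
    head_dim, remainder = head_split(dim, heads)
    status = "ok" if remainder == 0 else f"remainder {remainder}"
    print(f"{name}: head_dim={head_dim} ({status})")
```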
Configuration classes for World Models training.
- class world_models.configs.wm_config.WMVAEConfig(config_dict)[source]#
Bases: object
Configuration class for Variational Autoencoder (VAE) training.
This class manages all hyperparameters and settings for training a ConvVAE model on observation data. It provides validation and dictionary conversion utilities.
- Parameters:
config_dict (dict)
- height#
Height of input images (pixels).
- width#
Width of input images (pixels).
- device#
Device to train on (‘cpu’ or ‘cuda’).
- train_batch_size#
Number of samples per training batch.
- num_epochs#
Total number of training epochs.
- latent_size#
Dimensionality of the VAE latent space.
- data_dir#
Path to the dataset directory.
- learning_rate#
Initial learning rate for optimizer.
- logdir#
Directory for saving logs and checkpoints.
- noreload#
If True, skip loading existing checkpoints.
- nosamples#
If True, skip saving sample images during training.
- scheduler_patience#
Epochs to wait before reducing learning rate.
- scheduler_factor#
Multiplicative factor for learning rate reduction.
- early_stopping_patience#
Epochs to wait before early stopping.
- sample_interval#
Epoch interval for saving sample images.
- extra#
Dictionary for additional custom parameters.
Example
>>> config = WMVAEConfig({
...     'height': 64,
...     'width': 64,
...     'latent_size': 32,
...     'logdir': 'results',
... })
>>> config.latent_size
32
- class world_models.configs.wm_config.WMMDNRNNConfig(config_dict)[source]#
Bases: object
Configuration class for Mixture Density Recurrent Neural Network (MDRNN) training.
This class manages all hyperparameters and settings for training an MDRNN model on sequence data. It provides validation and dictionary conversion utilities.
- Parameters:
config_dict (dict)
- latent_size#
Dimensionality of the latent space from VAE.
- action_size#
Dimensionality of action space.
- hidden_size#
Number of hidden units in RNN.
- gmm_components#
Number of Gaussian mixture components.
- device#
Device to train on (‘cpu’ or ‘cuda’).
- batch_size#
Number of sequences per batch.
- seq_len#
Length of each sequence.
- num_epochs#
Total number of training epochs.
- data_dir#
Path to the dataset directory.
- learning_rate#
Initial learning rate for optimizer.
- logdir#
Directory for saving logs and checkpoints.
- noreload#
If True, skip loading existing checkpoints.
- include_reward#
If True, include reward prediction in loss.
- scheduler_patience#
Epochs to wait before reducing learning rate.
- scheduler_factor#
Multiplicative factor for learning rate reduction.
- early_stopping_patience#
Epochs to wait before early stopping.
- extra#
Dictionary for additional custom parameters.
Example
>>> config = WMMDNRNNConfig({
...     'latent_size': 32,
...     'action_size': 3,
...     'hidden_size': 256,
...     'gmm_components': 5,
... })
>>> config.hidden_size
256
- class world_models.configs.wm_config.WMControllerConfig(config_dict)[source]#
Bases: object
Configuration class for Controller training with CMA-ES.
This class manages hyperparameters for training a linear controller using Covariance Matrix Adaptation Evolution Strategy (CMA-ES).
- Parameters:
config_dict (dict)
- latent_size#
Dimensionality of latent state from VAE.
- hidden_size#
Dimensionality of RSSM hidden state.
- action_size#
Dimensionality of action space.
- logdir#
Directory for saving logs and checkpoints.
- n_samples#
Number of samples used to obtain return estimate.
- pop_size#
Population size for CMA-ES.
- target_return#
Stop once the return gets above this threshold.
- max_workers#
Maximum number of workers for parallel evaluation.
- display#
If True, show progress bars during training.
- time_limit#
Maximum steps per episode.
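The size of the search space CMA-ES optimizes follows from the controller being linear. The sketch below assumes the controller maps the concatenated latent and hidden state to actions with one weight matrix plus a bias. That input layout is an assumption, and the function name is hypothetical.

```python
def linear_controller_params(latent_size: int, hidden_size: int, action_size: int) -> int:
    """Weights plus biases of a single linear layer on [latent, hidden]."""
    inputs = latent_size + hidden_size
    return inputs * action_size + action_size


# Sizes used in the WMMDNRNNConfig example above.
print(linear_controller_params(latent_size=32, hidden_size=256, action_size=3))  # 867
```

A parameter count this small is what makes a derivative-free optimizer such as CMA-ES practical for the controller.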
- class world_models.configs.dreamer_config.DreamerConfig(env_backend='dmc', env='walker-walk', env_instance=None, image_size=(64, 64), gym_render_mode='rgb_array', dmlab_action_repeat=4, dmlab_action_set=None, dmlab_observations=None, dmlab_config=None, dmlab_renderer='hardware', procgen_distribution_mode='easy', procgen_num_levels=0, procgen_start_level=None, mujoco_xml_path=None, mujoco_xml_string=None, mujoco_binary_path=None, mujoco_camera=None, mujoco_frame_skip=1, mujoco_reset_noise_scale=0.0, brax_backend='generalized', brax_jit=True, brax_auto_reset=False, brax_suppress_warp_warnings=True, unity_file_name=None, unity_behavior_name=None, unity_worker_id=0, unity_base_port=5005, unity_no_graphics=True, unity_time_scale=20.0, unity_quality_level=1, algo='Dreamerv1', exp_name='lr1e-3', train=True, evaluate=False, seed=1, no_gpu=False, max_episode_length=1000, buffer_size=800000, time_limit=1000, cnn_activation_function='relu', dense_activation_function='elu', obs_embed_size=1024, num_units=400, deter_size=200, stoch_size=30, action_repeat=2, action_noise=0.3, total_steps=5000000, seed_steps=5000, update_steps=100, collect_steps=1000, batch_size=50, train_seq_len=50, imagine_horizon=15, use_disc_model=False, free_nats=3.0, discount=0.99, td_lambda=0.95, kl_loss_coeff=1.0, kl_alpha=0.8, disc_loss_coeff=10.0, num_buckets=255, symlog_range=10.0, model_learning_rate=0.0006, actor_learning_rate=8e-05, value_learning_rate=8e-05, adam_epsilon=1e-07, grad_clip_norm=100.0, use_amp=True, test=False, test_interval=10000, test_episodes=10, scalar_freq=1000, log_video_freq=-1, max_videos_to_save=2, video_format='gif', video_fps=20, checkpoint_interval=10000, checkpoint_path='', restore=False, experience_replay='', render=False, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', log_dir='runs', logdir=None, data_dir=None, log_level='INFO', log_file=None, enable_tensorboard=False, enable_console_metrics=True, enable_jsonl=True, 
jsonl_filename='metrics.jsonl', log_system_stats_freq=1000, detect_anomaly=False)[source]#
Bases: SerializableConfigMixin
Configuration container for Dreamer training, evaluation, and environment setup.
This class centralizes environment backend selection (DMC/DMLab/Gym/MuJoCo/Robotics/Unity/Brax), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.
- Parameters:
env_backend (str)
env (str)
env_instance (Any)
image_size (tuple[int, int])
gym_render_mode (str)
dmlab_action_repeat (int)
dmlab_action_set (Any)
dmlab_observations (Any)
dmlab_config (Any)
dmlab_renderer (str)
procgen_distribution_mode (str)
procgen_num_levels (int)
procgen_start_level (Any)
mujoco_xml_path (Any)
mujoco_xml_string (Any)
mujoco_binary_path (Any)
mujoco_camera (Any)
mujoco_frame_skip (int)
mujoco_reset_noise_scale (float)
brax_backend (str)
brax_jit (bool)
brax_auto_reset (bool)
brax_suppress_warp_warnings (bool)
unity_file_name (Any)
unity_behavior_name (Any)
unity_worker_id (int)
unity_base_port (int)
unity_no_graphics (bool)
unity_time_scale (float)
unity_quality_level (int)
algo (str)
exp_name (str)
train (bool)
evaluate (bool)
seed (int)
no_gpu (bool)
max_episode_length (int)
buffer_size (int)
time_limit (int)
cnn_activation_function (str)
dense_activation_function (str)
obs_embed_size (int)
num_units (int)
deter_size (int)
stoch_size (int)
action_repeat (int)
action_noise (float)
total_steps (int)
seed_steps (int)
update_steps (int)
collect_steps (int)
batch_size (int)
train_seq_len (int)
imagine_horizon (int)
use_disc_model (bool)
free_nats (float)
discount (float)
td_lambda (float)
kl_loss_coeff (float)
kl_alpha (float)
disc_loss_coeff (float)
num_buckets (int)
symlog_range (float)
model_learning_rate (float)
actor_learning_rate (float)
value_learning_rate (float)
adam_epsilon (float)
grad_clip_norm (float)
use_amp (bool)
test (bool)
test_interval (int)
test_episodes (int)
scalar_freq (int)
log_video_freq (int)
max_videos_to_save (int)
video_format (str)
video_fps (int)
checkpoint_interval (int)
checkpoint_path (str)
restore (bool)
experience_replay (str)
render (bool)
enable_wandb (bool)
wandb_api_key (str)
wandb_project (str)
wandb_entity (str)
log_dir (str)
logdir (Any)
data_dir (Any)
log_level (str)
log_file (Any)
enable_tensorboard (bool)
enable_console_metrics (bool)
enable_jsonl (bool)
jsonl_filename (str)
log_system_stats_freq (int)
detect_anomaly (bool)
- env_backend: str = 'dmc'#
- env: str = 'walker-walk'#
- env_instance: Any = None#
- image_size: tuple[int, int] = (64, 64)#
- gym_render_mode: str = 'rgb_array'#
- dmlab_action_repeat: int = 4#
- dmlab_action_set: Any = None#
- dmlab_observations: Any = None#
- dmlab_config: Any = None#
- dmlab_renderer: str = 'hardware'#
- procgen_distribution_mode: str = 'easy'#
- procgen_num_levels: int = 0#
- procgen_start_level: Any = None#
- mujoco_xml_path: Any = None#
- mujoco_xml_string: Any = None#
- mujoco_binary_path: Any = None#
- mujoco_camera: Any = None#
- mujoco_frame_skip: int = 1#
- mujoco_reset_noise_scale: float = 0.0#
- brax_backend: str = 'generalized'#
- brax_jit: bool = True#
- brax_auto_reset: bool = False#
- brax_suppress_warp_warnings: bool = True#
- unity_file_name: Any = None#
- unity_behavior_name: Any = None#
- unity_worker_id: int = 0#
- unity_base_port: int = 5005#
- unity_no_graphics: bool = True#
- unity_time_scale: float = 20.0#
- unity_quality_level: int = 1#
- algo: str = 'Dreamerv1'#
- exp_name: str = 'lr1e-3'#
- train: bool = True#
- evaluate: bool = False#
- seed: int = 1#
- no_gpu: bool = False#
- max_episode_length: int = 1000#
- buffer_size: int = 800000#
- time_limit: int = 1000#
- cnn_activation_function: str = 'relu'#
- dense_activation_function: str = 'elu'#
- obs_embed_size: int = 1024#
- num_units: int = 400#
- deter_size: int = 200#
- stoch_size: int = 30#
- action_repeat: int = 2#
- action_noise: float = 0.3#
- total_steps: int = 5000000#
- seed_steps: int = 5000#
- update_steps: int = 100#
- collect_steps: int = 1000#
- batch_size: int = 50#
- train_seq_len: int = 50#
- imagine_horizon: int = 15#
- use_disc_model: bool = False#
- free_nats: float = 3.0#
- discount: float = 0.99#
- td_lambda: float = 0.95#
- kl_loss_coeff: float = 1.0#
- kl_alpha: float = 0.8#
- disc_loss_coeff: float = 10.0#
- num_buckets: int = 255#
- symlog_range: float = 10.0#
- model_learning_rate: float = 0.0006#
- actor_learning_rate: float = 8e-05#
- value_learning_rate: float = 8e-05#
- adam_epsilon: float = 1e-07#
- grad_clip_norm: float = 100.0#
- use_amp: bool = True#
- test: bool = False#
- test_interval: int = 10000#
- test_episodes: int = 10#
- scalar_freq: int = 1000#
- log_video_freq: int = -1#
- max_videos_to_save: int = 2#
- video_format: str = 'gif'#
- video_fps: int = 20#
- checkpoint_interval: int = 10000#
- checkpoint_path: str = ''#
- restore: bool = False#
- experience_replay: str = ''#
- render: bool = False#
- enable_wandb: bool = False#
- wandb_api_key: str = ''#
- wandb_project: str = 'torchwm'#
- wandb_entity: str = ''#
- log_dir: str = 'runs'#
- logdir: Any = None#
- data_dir: Any = None#
- log_level: str = 'INFO'#
- log_file: Any = None#
- enable_tensorboard: bool = False#
- enable_console_metrics: bool = True#
- enable_jsonl: bool = True#
- jsonl_filename: str = 'metrics.jsonl'#
- log_system_stats_freq: int = 1000#
- detect_anomaly: bool = False#
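The `discount` and `td_lambda` fields are the usual ingredients of a λ-return. The following is a minimal sketch of that recursion in plain Python, assuming the standard definition. `lambda_returns` is a hypothetical helper, not part of the package, and the library's own computation may differ (for example by operating on batched tensors).

```python
def lambda_returns(rewards, values, discount=0.99, td_lambda=0.95):
    """Compute lambda-returns for a single trajectory.

    rewards has length T; values has length T + 1, where values[-1]
    is the bootstrap estimate for the state after the last step.
    """
    if len(values) != len(rewards) + 1:
        raise ValueError("values must have one more entry than rewards")
    returns = [0.0] * len(rewards)
    next_return = values[-1]  # bootstrap from the final value estimate
    for t in reversed(range(len(rewards))):
        # Blend the one-step value estimate with the longer lambda-return.
        blended = (1.0 - td_lambda) * values[t + 1] + td_lambda * next_return
        returns[t] = rewards[t] + discount * blended
        next_return = returns[t]
    return returns
```

With `td_lambda=0.0` this reduces to one-step TD targets; with `td_lambda=1.0` it becomes a discounted Monte Carlo return bootstrapped from the last value.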
- class world_models.configs.jepa_config.JEPAConfig[source]#
Bases:
SerializableConfigMixin
Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.
- classmethod from_dict(values)[source]#
Load flat field values or the nested trainer dictionary.
- Parameters:
values (Dict[str, Any])
- Return type:
- class world_models.configs.iris_config.IRISConfig[source]#
Bases:
SerializableConfigMixin
Configuration for IRIS (Imagination with auto-Regression over an Inner Speech).
Based on the paper “Transformers are Sample-Efficient World Models”. Implements a discrete autoencoder and an autoregressive Transformer for sample-efficient RL.
- class world_models.configs.genie_config.GenieConfig(num_frames=8, image_size=32, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=256, action_encoder_depth=4, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#
Bases:
SerializableConfigMixin
Configuration for Genie model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
tokenizer_num_heads (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_num_heads (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
sample_temperature (float)
maskgit_steps (int)
- num_frames: int = 8#
- image_size: int = 32#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 256#
- tokenizer_decoder_dim: int = 512#
- tokenizer_encoder_depth: int = 4#
- tokenizer_decoder_depth: int = 8#
- tokenizer_num_heads: int = 8#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 256#
- action_encoder_depth: int = 4#
- action_num_heads: int = 8#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 4#
- learning_rate: float = 3e-05#
- weight_decay: float = 0.0001#
- warmup_steps: int = 5000#
- max_steps: int = 125000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- sample_temperature: float = 2.0#
- maskgit_steps: int = 25#
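The `mask_prob_min` and `mask_prob_max` fields bound a masking rate. As a rough sketch, assuming the rate is drawn uniformly from that range for each sample and then applied per token as a Bernoulli mask, the sampling could look like the following. `sample_token_mask` is a hypothetical helper written for illustration; it is not an API of this package.

```python
import random


def sample_token_mask(num_tokens, mask_prob_min=0.5, mask_prob_max=1.0, rng=None):
    """Draw a masking rate, then mask each token independently at that rate."""
    rng = rng or random.Random()
    mask_prob = rng.uniform(mask_prob_min, mask_prob_max)
    # True marks a token that would be hidden from the dynamics model.
    mask = [rng.random() < mask_prob for _ in range(num_tokens)]
    return mask_prob, mask
```

Passing a seeded `random.Random` makes the draw reproducible, which is convenient in tests.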
- class world_models.configs.genie_config.GenieSmallConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=256, tokenizer_decoder_dim=512, tokenizer_encoder_depth=4, tokenizer_decoder_depth=8, tokenizer_num_heads=8, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=512, action_encoder_depth=8, action_num_heads=8, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=2, learning_rate=0.0001, weight_decay=0.0001, warmup_steps=1000, max_steps=50000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#
Bases:
SerializableConfigMixin
Small configuration for development/testing.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
tokenizer_num_heads (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_num_heads (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
sample_temperature (float)
maskgit_steps (int)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 256#
- tokenizer_decoder_dim: int = 512#
- tokenizer_encoder_depth: int = 4#
- tokenizer_decoder_depth: int = 8#
- tokenizer_num_heads: int = 8#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 512#
- action_encoder_depth: int = 8#
- action_num_heads: int = 8#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 2#
- learning_rate: float = 0.0001#
- weight_decay: float = 0.0001#
- warmup_steps: int = 1000#
- max_steps: int = 50000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- sample_temperature: float = 2.0#
- maskgit_steps: int = 25#
- class world_models.configs.genie_config.STTransformerConfig(num_frames=16, num_patches_per_frame=256, dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases:
SerializableConfigMixin
Configuration for Spatiotemporal Transformer.
- Parameters:
num_frames (int)
num_patches_per_frame (int)
dim (int)
depth (int)
num_heads (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- num_frames: int = 16#
- num_patches_per_frame: int = 256#
- dim: int = 768#
- depth: int = 12#
- num_heads: int = 12#
- mlp_ratio: float = 4.0#
- qkv_bias: bool = True#
- drop_rate: float = 0.0#
- attn_drop_rate: float = 0.0#
- drop_path_rate: float = 0.0#
- class world_models.configs.genie_config.VideoTokenizerConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=512, decoder_dim=1024, encoder_depth=12, decoder_depth=20, num_heads=16, patch_size=4, vocab_size=1024, embedding_dim=32, use_ema=False, ema_decay=0.99, commitment_weight=0.25)[source]#
Bases:
SerializableConfigMixin
Configuration for Video Tokenizer.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
decoder_dim (int)
encoder_depth (int)
decoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
use_ema (bool)
ema_decay (float)
commitment_weight (float)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- encoder_dim: int = 512#
- decoder_dim: int = 1024#
- encoder_depth: int = 12#
- decoder_depth: int = 20#
- num_heads: int = 16#
- patch_size: int = 4#
- vocab_size: int = 1024#
- embedding_dim: int = 32#
- use_ema: bool = False#
- ema_decay: float = 0.99#
- commitment_weight: float = 0.25#
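The `image_size` and `patch_size` defaults determine how many tokens each frame produces. A small worked example, assuming square frames cut into non-overlapping square patches (the helper name `tokens_per_clip` is hypothetical):

```python
def tokens_per_clip(num_frames=16, image_size=64, patch_size=4):
    """Return (patches per frame, tokens per clip) for square patches."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    patches_per_frame = patches_per_side ** 2
    return patches_per_frame, num_frames * patches_per_frame
```

With the defaults above this gives 256 patches per frame, which agrees with the `num_patches_per_frame=256` default of `STTransformerConfig`, and 4096 tokens for a 16-frame clip.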
- class world_models.configs.genie_config.LatentActionModelConfig(num_frames=16, image_size=64, in_channels=3, encoder_dim=1024, encoder_depth=20, num_heads=16, patch_size=16, vocab_size=8, embedding_dim=32, commitment_weight=1.0, action_pooling='mean', window_attention_heads=1)[source]#
Bases:
SerializableConfigMixin
Configuration for Latent Action Model.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
encoder_dim (int)
encoder_depth (int)
num_heads (int)
patch_size (int)
vocab_size (int)
embedding_dim (int)
commitment_weight (float)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- encoder_dim: int = 1024#
- encoder_depth: int = 20#
- num_heads: int = 16#
- patch_size: int = 16#
- vocab_size: int = 8#
- embedding_dim: int = 32#
- commitment_weight: float = 1.0#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- class world_models.configs.genie_config.DynamicsModelConfig(num_frames=16, image_size=64, vocab_size=1024, embedding_dim=32, action_vocab_size=8, dim=5120, depth=48, num_heads=36, patch_size=4, mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0)[source]#
Bases:
SerializableConfigMixin
Configuration for Dynamics Model.
- Parameters:
num_frames (int)
image_size (int)
vocab_size (int)
embedding_dim (int)
action_vocab_size (int)
dim (int)
depth (int)
num_heads (int)
patch_size (int)
mlp_ratio (float)
qkv_bias (bool)
drop_rate (float)
attn_drop_rate (float)
drop_path_rate (float)
- num_frames: int = 16#
- image_size: int = 64#
- vocab_size: int = 1024#
- embedding_dim: int = 32#
- action_vocab_size: int = 8#
- dim: int = 5120#
- depth: int = 48#
- num_heads: int = 36#
- patch_size: int = 4#
- mlp_ratio: float = 4.0#
- qkv_bias: bool = True#
- drop_rate: float = 0.0#
- attn_drop_rate: float = 0.0#
- drop_path_rate: float = 0.0#
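Standard multi-head attention splits `dim` evenly across `num_heads`. Assuming the transformer blocks here follow that convention (this reference does not confirm it), a quick check before instantiating a config could look like this. `head_dim` is a hypothetical helper.

```python
def head_dim(dim, num_heads):
    """Return the per-head width, or None if dim does not split evenly."""
    quotient, remainder = divmod(dim, num_heads)
    return quotient if remainder == 0 else None
```

For example, `dim=768, num_heads=12` gives a head width of 64, whereas the documented `DynamicsModelConfig` defaults `dim=5120, num_heads=36` leave a remainder of 8, so those two values may need adjusting together if an even split is required.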
- class world_models.configs.dit_config.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#
Bases:
SerializableConfigMixin
Default configuration values for Diffusion Transformer (DiT) training.
The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.
Field names use UPPER_CASE for backward compatibility with the original DiT codebase. Snake-case aliases are accepted via __getattr__ and get_dit_config().
- Parameters:
DATASET (str)
BATCH (int)
EPOCHS (int)
LR (float)
IMG_SIZE (int)
CHANNELS (int)
PATCH (int)
WIDTH (int)
DEPTH (int)
HEADS (int)
DROP (float)
BETA_START (float)
BETA_END (float)
TIMESTEPS (int)
EMA (bool)
EMA_DECAY (float)
WORKDIR (str)
ROOT_PATH (str)
- DATASET: str = 'CIFAR10'#
- BATCH: int = 128#
- EPOCHS: int = 3#
- LR: float = 0.0002#
- IMG_SIZE: int = 32#
- CHANNELS: int = 3#
- PATCH: int = 4#
- WIDTH: int = 384#
- DEPTH: int = 6#
- HEADS: int = 6#
- DROP: float = 0.1#
- BETA_START: float = 0.0001#
- BETA_END: float = 0.02#
- TIMESTEPS: int = 1000#
- EMA: bool = True#
- EMA_DECAY: float = 0.999#
- WORKDIR: str = './dit_demo'#
- ROOT_PATH: str = './data'#
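The snake-case aliasing described for `DiTConfig` can be illustrated with a stand-in dataclass. This is only a sketch of one way to get that behavior: `DemoDiTConfig` and `demo_get_config` are hypothetical names that carry just three of the documented fields, and the real implementation may work differently.

```python
from dataclasses import dataclass, fields


@dataclass
class DemoDiTConfig:
    """Hypothetical stand-in with a few of the documented DiTConfig fields."""
    BATCH: int = 128
    EPOCHS: int = 3
    LR: float = 0.0002

    def __getattr__(self, name):
        # Only reached when normal lookup fails, so cfg.batch falls back to cfg.BATCH.
        upper = name.upper()
        if upper != name and upper in {f.name for f in fields(type(self))}:
            return getattr(self, upper)
        raise AttributeError(name)


def demo_get_config(**overrides):
    """Build a config, accepting UPPER_CASE or snake_case override keys."""
    valid = {f.name for f in fields(DemoDiTConfig)}
    normalized = {}
    for key, value in overrides.items():
        upper = key.upper()
        if upper not in valid:
            raise TypeError(f"unknown config field: {key!r}")
        normalized[upper] = value
    return DemoDiTConfig(**normalized)
```

Normalizing keys once at construction keeps a single canonical field name per setting, while `__getattr__` covers read access through the alias.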
- world_models.configs.dit_config.get_dit_config(**overrides)[source]#
Returns a DiTConfig instance with default values overridden by the provided keyword arguments.
Both UPPER_CASE and snake_case override keys are accepted.
- Example usage:
cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)
cfg = get_dit_config(batch=64, epochs=10, lr=1e-3)
- Parameters:
overrides (Any)
- Return type:
- class world_models.configs.diamond_config.ModelPreset(diffusion_channels, diffusion_res_blocks, diffusion_cond_dim, reward_channels, reward_lstm_dim, actor_channels, actor_lstm_dim)[source]#
Bases:
SerializableConfigMixin
Model architecture preset for different hardware tiers.
- Parameters:
diffusion_channels (List[int])
diffusion_res_blocks (int)
diffusion_cond_dim (int)
reward_channels (List[int])
reward_lstm_dim (int)
actor_channels (List[int])
actor_lstm_dim (int)
- diffusion_channels: List[int]#
- diffusion_res_blocks: int#
- diffusion_cond_dim: int#
- reward_channels: List[int]#
- reward_lstm_dim: int#
- actor_channels: List[int]#
- actor_lstm_dim: int#
- class world_models.configs.diamond_config.DiamondConfig(preset: str | None = None, game: str = 'Breakout-v5', seed: int = 0, obs_size: int = 64, frameskip: int = 4, max_noop: int = 30, terminate_on_life_loss: bool = True, reward_clip: List[int] = <factory>, num_conditioning_frames: int = 4, diffusion_channels: List[int] = <factory>, diffusion_res_blocks: int = 2, diffusion_cond_dim: int = 256, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: int = 7, p_mean: float = -0.4, p_std: float = 1.2, sampling_method: str = 'euler', num_sampling_steps: int = 3, reward_channels: List[int] = <factory>, reward_res_blocks: int = 2, reward_cond_dim: int = 128, reward_lstm_dim: int = 512, burn_in_length: int = 4, actor_channels: List[int] = <factory>, actor_res_blocks: int = 1, actor_lstm_dim: int = 512, num_epochs: int = 1000, training_steps_per_epoch: int = 400, batch_size: int = 32, environment_steps_per_epoch: int = 100, epsilon_greedy: float = 0.01, data_loader_num_workers: int = 4, pin_memory: bool = True, persistent_workers: bool = True, imagination_horizon: int = 15, discount_factor: float = 0.985, entropy_weight: float = 0.001, lambda_returns: float = 0.95, learning_rate: float = 0.0001, adam_epsilon: float = 1e-08, weight_decay_diffusion: float = 0.01, weight_decay_reward: float = 0.01, weight_decay_actor: float = 0.0, use_amp: bool = True, device: str = <factory>, log_interval: int = 10, eval_interval: int = 50, save_interval: int = 100, operator_state_dim: int = 32, operator_action_dim: int = 4)[source]#
Bases:
SerializableConfigMixin
- Parameters:
preset (str | None)
game (str)
seed (int)
obs_size (int)
frameskip (int)
max_noop (int)
terminate_on_life_loss (bool)
reward_clip (List[int])
num_conditioning_frames (int)
diffusion_channels (List[int])
diffusion_res_blocks (int)
diffusion_cond_dim (int)
sigma_data (float)
sigma_min (float)
sigma_max (float)
rho (int)
p_mean (float)
p_std (float)
sampling_method (str)
num_sampling_steps (int)
reward_channels (List[int])
reward_res_blocks (int)
reward_cond_dim (int)
reward_lstm_dim (int)
burn_in_length (int)
actor_channels (List[int])
actor_res_blocks (int)
actor_lstm_dim (int)
num_epochs (int)
training_steps_per_epoch (int)
batch_size (int)
environment_steps_per_epoch (int)
epsilon_greedy (float)
data_loader_num_workers (int)
pin_memory (bool)
persistent_workers (bool)
imagination_horizon (int)
discount_factor (float)
entropy_weight (float)
lambda_returns (float)
learning_rate (float)
adam_epsilon (float)
weight_decay_diffusion (float)
weight_decay_reward (float)
weight_decay_actor (float)
use_amp (bool)
device (str)
log_interval (int)
eval_interval (int)
save_interval (int)
operator_state_dim (int)
operator_action_dim (int)
- preset: str | None = None#
- game: str = 'Breakout-v5'#
- seed: int = 0#
- obs_size: int = 64#
- frameskip: int = 4#
- max_noop: int = 30#
- terminate_on_life_loss: bool = True#
- reward_clip: List[int]#
- num_conditioning_frames: int = 4#
- diffusion_channels: List[int]#
- diffusion_res_blocks: int = 2#
- diffusion_cond_dim: int = 256#
- sigma_data: float = 0.5#
- sigma_min: float = 0.002#
- sigma_max: float = 80.0#
- rho: int = 7#
- p_mean: float = -0.4#
- p_std: float = 1.2#
- sampling_method: str = 'euler'#
- num_sampling_steps: int = 3#
- reward_channels: List[int]#
- reward_res_blocks: int = 2#
- reward_cond_dim: int = 128#
- reward_lstm_dim: int = 512#
- burn_in_length: int = 4#
- actor_channels: List[int]#
- actor_res_blocks: int = 1#
- actor_lstm_dim: int = 512#
- num_epochs: int = 1000#
- training_steps_per_epoch: int = 400#
- batch_size: int = 32#
- environment_steps_per_epoch: int = 100#
- epsilon_greedy: float = 0.01#
- data_loader_num_workers: int = 4#
- pin_memory: bool = True#
- persistent_workers: bool = True#
- imagination_horizon: int = 15#
- discount_factor: float = 0.985#
- entropy_weight: float = 0.001#
- lambda_returns: float = 0.95#
- learning_rate: float = 0.0001#
- adam_epsilon: float = 1e-08#
- weight_decay_diffusion: float = 0.01#
- weight_decay_reward: float = 0.01#
- weight_decay_actor: float = 0.0#
- use_amp: bool = True#
- device: str#
- log_interval: int = 10#
- eval_interval: int = 50#
- save_interval: int = 100#
- operator_state_dim: int = 32#
- operator_action_dim: int = 4#
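The `sigma_min`, `sigma_max`, `rho`, and `num_sampling_steps` fields resemble the noise-level parameterization of Karras et al.'s EDM sampler. Assuming that schedule is what these fields control (this reference does not say so explicitly), the noise levels visited during sampling can be sketched as follows. `karras_sigmas` is a hypothetical helper.

```python
def karras_sigmas(num_steps=3, sigma_min=0.002, sigma_max=80.0, rho=7):
    """Noise levels from sigma_max down to sigma_min, followed by a final 0."""
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    max_inv = sigma_max ** (1.0 / rho)
    min_inv = sigma_min ** (1.0 / rho)
    sigmas = []
    for i in range(num_steps):
        # Interpolate linearly in sigma^(1/rho) space, then map back.
        frac = i / (num_steps - 1) if num_steps > 1 else 0.0
        sigmas.append((max_inv + frac * (min_inv - max_inv)) ** rho)
    sigmas.append(0.0)  # the sampler ends at a noise-free sample
    return sigmas
```

A larger `rho` concentrates more of the steps near `sigma_min`, which matters when only a few sampling steps are used.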
Training entry points#
Complete World Model training pipeline for any Gym environment.
This script trains a complete World Model pipeline consisting of:
1. ConvVAE - Encodes observations into latent space
2. MDNRNN - Predicts future latent states given actions
3. Controller - Linear controller trained with CMA-ES
- Usage:
python train_world_model.py --env CarRacing-v2 --data_dir ./data --logdir ./results
python train_world_model.py --env BipedalWalker-v3 --action_size 4  # if env loading fails
The script will:
1. Generate rollout data (if not already present)
2. Train VAE
3. Train MDNRNN
4. Train Controller
- world_models.training.train_world_model.generate_rollouts(data_dir, env_name, num_rollouts=1000, seq_len=1000, num_workers=8)[source]#
Generate random rollouts from the specified environment.
- Parameters:
data_dir (str) – Directory to save rollout files
env_name (str) – Name of the Gym environment
num_rollouts (int) – Total number of rollouts to generate
seq_len (int) – Maximum length per rollout
num_workers (int) – Number of parallel workers
- Return type:
None
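One plausible way to share `num_rollouts` across `num_workers` is an even split with the remainder given to the first workers. The sketch below illustrates that arithmetic only; `split_rollouts` is a hypothetical helper, and the actual scheduling inside `generate_rollouts` is not documented here.

```python
def split_rollouts(num_rollouts=1000, num_workers=8):
    """Return how many rollouts each worker would generate."""
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")
    base, extra = divmod(num_rollouts, num_workers)
    # The first `extra` workers take one additional rollout each.
    return [base + (1 if i < extra else 0) for i in range(num_workers)]
```

With the documented defaults, each of the 8 workers would produce 125 rollouts.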
- world_models.training.train_world_model.run_training_pipeline(args, action_size)[source]#
Execute the complete World Model training pipeline.
- Parameters:
args (Any)
action_size (int)
- Return type:
None
- world_models.training.train_world_model.test_trained_model(logdir, env_name, action_size, num_episodes=5)[source]#
Test the trained world model with controller in the environment.
- Parameters:
logdir (str)
env_name (str)
action_size (int)
num_episodes (int)
- Return type:
None
Training script for Convolutional Variational Autoencoder (ConvVAE).
This module provides functions to train a ConvVAE model on observation data for world model learning.
- world_models.training.train_convvae.save_checkpoint(state, is_best, filename, best_filename)[source]#
Save model checkpoint.
- Parameters:
state (dict) – Dictionary containing model state to save.
is_best (bool) – If True, also save as best checkpoint.
filename (str) – Path to save checkpoint.
best_filename (str) – Path to save best checkpoint.
- Return type:
None
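The four arguments suggest the common "save current, copy to best" pattern. Here is a minimal sketch of that pattern, assuming `pickle` as a stand-in for the framework serializer (the real function most likely uses something like `torch.save`). `save_checkpoint_sketch` and the file names are hypothetical.

```python
import os
import pickle
import shutil
import tempfile


def save_checkpoint_sketch(state, is_best, filename, best_filename):
    """Write the current checkpoint and optionally copy it as the best one."""
    with open(filename, "wb") as fh:
        pickle.dump(state, fh)  # stand-in for a framework serializer
    if is_best:
        shutil.copyfile(filename, best_filename)


# Usage: save into a temporary directory and read the best copy back.
with tempfile.TemporaryDirectory() as tmp:
    current = os.path.join(tmp, "checkpoint.pkl")
    best = os.path.join(tmp, "best.pkl")
    save_checkpoint_sketch({"epoch": 3, "loss": 0.42}, True, current, best)
    saved_files = sorted(os.listdir(tmp))
    with open(best, "rb") as fh:
        restored = pickle.load(fh)
```

Writing the current file first and copying it avoids serializing the state twice.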
- world_models.training.train_convvae.test_epoch(model, test_loader, device, loss_fn)[source]#
Run one epoch of validation.
- Parameters:
model (ConvVAE) – The VAE model to evaluate.
test_loader (DataLoader) – DataLoader for test/validation data.
device (device) – Device to run evaluation on.
loss_fn (Any) – Loss function to use.
- Returns:
Average test loss for the epoch.
- Return type:
float
- world_models.training.train_convvae.train_epoch(epoch, model, optimizer, train_loader, device, train_dataset, loss_fn, use_amp=False, scaler=None)[source]#
Run one epoch of training.
- Parameters:
epoch (int) – Current epoch number.
model (Any) – The VAE model to train.
optimizer (Any) – Optimizer for training.
train_loader (Any) – DataLoader for training data.
device (Any) – Device to run training on.
train_dataset (Any) – Training dataset (used to load next buffer if applicable).
loss_fn (Any) – Loss function to use.
use_amp (bool) – Whether to use automatic mixed precision.
scaler (GradScaler | None) – GradScaler for mixed precision training.
- Return type:
float
- world_models.training.train_convvae.train_convae(config)[source]#
Train a Convolutional VAE model.
This function trains a ConvVAE on observation data using the provided configuration. It handles data loading, model initialization, training loop, checkpointing, and sample generation.
- Parameters:
config (WMVAEConfig) – WMVAEConfig object containing all training hyperparameters.
- Return type:
None
- The training process includes:
Loading pretrained VAE if available (unless noreload is True)
Training for specified number of epochs
Validating after each epoch
Learning rate scheduling with ReduceLROnPlateau
Early stopping based on validation loss
Checkpointing best and current models
Generating sample images at specified intervals
Example
>>> config = WMVAEConfig({
...     'height': 64,
...     'width': 64,
...     'latent_size': 32,
...     'num_epochs': 100,
...     'logdir': 'results',
... })
>>> train_convae(config)
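The early-stopping step listed above can be sketched as a small counter over validation losses. This is an illustrative version only: the `EarlyStopping` class, `patience`, and `min_delta` are hypothetical names, and the criterion used by `train_convae` may differ.

```python
class EarlyStopping:
    """Signal a stop after `patience` epochs without validation improvement."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

Calling `step` once per epoch after validation gives a loop a single boolean to break on.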
Training script for Mixture Density Recurrent Neural Network (MDRNN).
This module provides functions to train an MDRNN model for sequence prediction in world models. The MDRNN predicts future latent states using a Gaussian Mixture Model (GMM) based on current latent states and actions.
- world_models.training.train_mdn_rnn.precompute_latents(vae_config, mdrnn_config)[source]#
Pre-compute and save VAE latents to disk for memory-efficient RNN training.
This function encodes all observations using the VAE and saves the latent representations to disk. This allows RNN training without keeping the VAE in GPU memory.
- Parameters:
vae_config (WMVAEConfig) – WMVAEConfig for loading pretrained VAE.
mdrnn_config (WMMDNRNNConfig) – WMMDNRNNConfig containing latent_size and device settings.
- Return type:
None
- world_models.training.train_mdn_rnn.save_checkpoint(state, is_best, filename, best_filename)[source]#
Save model checkpoint.
- Parameters:
state (Any) – Dictionary containing model state to save.
is_best (bool) – If True, also save as best checkpoint.
filename (str) – Path to save checkpoint.
best_filename (str) – Path to save best checkpoint.
- Return type:
None
- world_models.training.train_mdn_rnn.to_latent(vae, obs, next_obs, device, red_size=64)[source]#
Transform observations to latent space using VAE encoder.
This function encodes observations into the latent space using the VAE’s encoder network. It applies the reparameterization trick to sample from the learned latent distribution.
- Parameters:
vae (ConvVAE) – Trained VAE model with encoder.
obs (Tensor) – Batch of current observations.
next_obs (Tensor) – Batch of next observations.
device (device) – Device to run encoding on.
red_size (int) – Target size for resizing images (default: 64).
- Returns:
Tuple of (latent_obs, latent_next_obs) tensors in latent space.
- Return type:
tuple[Tensor, Tensor]
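The reparameterization trick mentioned above draws a latent sample as the mean plus scaled unit Gaussian noise. A dependency-free sketch on Python lists, assuming the encoder outputs a mean and a log-variance per latent dimension (that parameterization is an assumption, and `reparameterize` is a hypothetical helper):

```python
import math
import random


def reparameterize(mu, logvar, rng=None):
    """Sample z = mu + sigma * eps with sigma = exp(0.5 * logvar), eps ~ N(0, 1)."""
    rng = rng or random.Random()
    return [
        m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
        for m, lv in zip(mu, logvar)
    ]
```

Because the randomness enters only through `eps`, gradients can flow through `mu` and `logvar` when the same expression is written with tensors.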
- world_models.training.train_mdn_rnn.get_loss(mdrnn, latent_obs, action, reward, terminal, latent_next_obs, include_reward, latent_size)[source]#
Compute MDRNN loss.
Computes the combined loss for the MDRNN model:
- GMM loss for next latent state prediction
- BCE loss for terminal state prediction
- MSE loss for reward prediction (if include_reward is True)
- Parameters:
mdrnn (MDRNN) – MDRNN model.
latent_obs (Tensor) – Current latent observations.
action (Tensor) – Actions taken.
reward (Tensor) – Rewards received.
terminal (Tensor) – Terminal state flags.
latent_next_obs (Tensor) – Next latent observations (target).
include_reward (bool) – Whether to include reward prediction in loss.
latent_size (int) – Size of latent space.
- Returns:
Dictionary containing gmm, bce, mse, and total loss values.
- Return type:
dict[str, Tensor]
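The GMM term is a negative log-likelihood under a mixture of Gaussians. The scalar sketch below shows the computation for a single one-dimensional target, using log-sum-exp for numerical stability. `gmm_nll` is a hypothetical helper; the library version operates on batched tensors and may reduce or normalize differently.

```python
import math


def gmm_nll(target, means, stds, log_weights):
    """Negative log-likelihood of a scalar under a 1-D Gaussian mixture."""
    log_terms = []
    for mu, sigma, log_pi in zip(means, stds, log_weights):
        log_normal = (
            -0.5 * ((target - mu) / sigma) ** 2
            - math.log(sigma)
            - 0.5 * math.log(2.0 * math.pi)
        )
        log_terms.append(log_pi + log_normal)
    # log-sum-exp: subtract the largest term before exponentiating.
    peak = max(log_terms)
    return -(peak + math.log(sum(math.exp(t - peak) for t in log_terms)))
```

With a single standard-normal component and a target of 0, the result is `0.5 * log(2 * pi)`.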
- world_models.training.train_mdn_rnn.data_pass(epoch, mdrnn, vae, train_loader, test_loader, optimizer, device, include_reward, test_every=10, epochs=1, use_amp=False, scaler=None, latent_size=32, max_seq_len=50, log_wandb=False, prev_val_loss=1000000.0, early_stop=None, lr_scheduler=None, batch_size=50, train=True)[source]#
Run one epoch of training or validation.
- Parameters:
epoch (int) – Current epoch number.
mdrnn (Any) – MDRNN model.
vae (Any) – VAE model for encoding observations (None if using precomputed latents).
train_loader (Any) – Training data loader.
test_loader (Any) – Test/validation data loader.
optimizer (Any) – Optimizer (used only for training).
device (Any) – Device to run on.
include_reward (bool) – Whether to include reward in loss.
latent_size (int) – Size of latent space.
batch_size (int) – Batch size.
train (bool) – If True, run training pass; otherwise run validation.
use_amp (bool) – If True, use automatic mixed precision.
scaler (Any) – GradScaler for mixed precision training.
test_every (int)
epochs (int)
max_seq_len (int)
log_wandb (bool)
prev_val_loss (float)
early_stop (Any)
lr_scheduler (Any)
- Returns:
Average loss for the epoch.
- Return type:
float
- world_models.training.train_mdn_rnn.train_mdn_rnn(vae_config, mdrnn_config, use_precomputed_latents=True, use_amp=True)[source]#
Train an MDRNN model.
This function trains an MDRNN on sequence data using the provided configurations. It loads a pretrained VAE for encoding observations into latent space, then trains the MDRNN to predict future latent states given current latent states and actions.
- Parameters:
vae_config (WMVAEConfig) – WMVAEConfig for loading pretrained VAE.
mdrnn_config (WMMDNRNNConfig) – WMMDNRNNConfig containing MDRNN training hyperparameters.
use_precomputed_latents (bool) – If True, use pre-encoded latents from disk.
use_amp (bool) – If True, use automatic mixed precision for memory efficiency.
- Return type:
None
- The training process includes:
Loading pretrained VAE from vae_config.logdir
Training for specified number of epochs
Validating after each epoch
Learning rate scheduling with ReduceLROnPlateau
Early stopping based on validation loss
Checkpointing best and current models
Example
>>> vae_config = WMVAEConfig({
...     'height': 64, 'width': 64, 'latent_size': 32, 'logdir': 'results'
... })
>>> mdrnn_config = WMMDNRNNConfig({
...     'latent_size': 32, 'action_size': 3, 'hidden_size': 256,
...     'gmm_components': 5, 'logdir': 'results'
... })
>>> train_mdn_rnn(vae_config, mdrnn_config)
Training a linear controller on latent + recurrent state with CMA-ES.
This module provides functions to train a linear controller using Covariance Matrix Adaptation Evolution Strategy (CMA-ES). The controller maps latent and hidden states to actions for the learned world model.
- Reference:
Ha & Schmidhuber (2018). Recurrent World Models Facilitate Policy Evolution. https://arxiv.org/abs/1809.01999
- world_models.training.train_controller.flatten_parameters(parameters)[source]#
- Parameters:
parameters (Any)
- Return type:
ndarray
- world_models.training.train_controller.load_parameters(params, controller)[source]#
- Parameters:
params (Any)
controller (Any)
- Return type:
None
- world_models.training.train_controller.slave_routine(p_queue, r_queue, e_queue, p_index, config, time_limit)[source]#
Worker process routine for parallel rollout evaluation.
- Parameters:
p_queue (Any) – Queue containing (s_id, parameters) to evaluate.
r_queue (Any) – Queue where to place results (s_id, reward).
e_queue (Any) – End queue - when non-empty, process terminates.
p_index (int) – Process index for GPU assignment.
config (Any) – Controller configuration (must include env_name and action_size).
time_limit (int) – Maximum steps per episode.
- Return type:
None
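The three queues describe a producer/consumer protocol: candidates go in through `p_queue`, scores come back through `r_queue`, and a non-empty `e_queue` tells workers to exit. The sketch below mimics that protocol with threads and `queue.Queue` as stand-ins for processes and multiprocessing queues. `worker`, `evaluate_all`, and `score_fn` are hypothetical names, and `score_fn` takes the place of an environment rollout.

```python
import queue
import threading


def worker(p_queue, r_queue, e_queue, score_fn):
    """Evaluate (s_id, params) items until the end queue becomes non-empty."""
    while e_queue.empty():
        try:
            s_id, params = p_queue.get(timeout=0.05)
        except queue.Empty:
            continue  # nothing to do yet; check the end queue again
        r_queue.put((s_id, score_fn(params)))


def evaluate_all(candidates, score_fn, num_workers=2):
    """Score every candidate in parallel and return scores in input order."""
    p_queue, r_queue, e_queue = queue.Queue(), queue.Queue(), queue.Queue()
    threads = [
        threading.Thread(target=worker, args=(p_queue, r_queue, e_queue, score_fn))
        for _ in range(num_workers)
    ]
    for thread in threads:
        thread.start()
    for s_id, params in enumerate(candidates):
        p_queue.put((s_id, params))
    # Results may arrive in any order, so key them by solution id.
    results = dict(r_queue.get() for _ in candidates)
    e_queue.put("EOP")  # any item signals the workers to stop
    for thread in threads:
        thread.join()
    return [results[i] for i in range(len(candidates))]
```

Tagging each job with `s_id` is what lets the caller match out-of-order results back to the solutions that produced them.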
- world_models.training.train_controller.evaluate(solutions, results, rollouts, p_queue, r_queue)[source]#
Evaluate current controller.
- Parameters:
solutions (Any)
results (Any)
rollouts (int)
p_queue (Any)
r_queue (Any)
- Return type:
Any
- world_models.training.train_controller.train_controller(config)[source]#
Train a linear controller using CMA-ES.
- Parameters:
config (WMControllerConfig) – WMControllerConfig containing training hyperparameters, including env_name and action_size.
- Return type:
None
- The training process includes:
Setting up parallel evaluation workers (each loads VAE + MDRNN)
Running CMA-ES optimization with parallel rollout evaluation
Evaluating and saving best controller checkpoint
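Because the controller is linear, the size of the CMA-ES search space follows directly from the input and action sizes. A short sketch, assuming a single affine layer with a bias over the concatenated latent and hidden state (both helper names are hypothetical):

```python
def controller_param_count(latent_size=32, hidden_size=256, action_size=3):
    """Number of weights plus biases in one affine layer over [latent, hidden]."""
    return (latent_size + hidden_size) * action_size + action_size


def linear_action(weights, bias, latent, hidden):
    """Compute action = W @ [latent, hidden] + b using plain Python lists."""
    features = list(latent) + list(hidden)
    return [
        sum(w * x for w, x in zip(row, features)) + b
        for row, b in zip(weights, bias)
    ]
```

With the sizes used in the MDRNN example above (`latent_size=32`, `hidden_size=256`, `action_size=3`), that is 867 parameters, small enough for an evolution strategy to optimize directly.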
- world_models.training.train_jepa.main(args=None, resume_preempt=False)[source]#
Run JEPA training using a CLI argv, nested dict, or JEPAConfig instance.
This entrypoint initializes distributed context, data pipeline, masking, models, optimizers/schedulers, checkpointing, and the full epoch loop.
- Parameters:
args (Any)
resume_preempt (bool)
- Return type:
Any
- world_models.training.train_jepa.sweep_train()[source]#
Function for WandB sweep agent.
- Return type:
None
- world_models.training.train_jepa.main_from_cli(argv=None)[source]#
Compose JEPA config from YAML/dot-list overrides and launch training.
- Parameters:
argv (list[str] | None)
- Return type:
Any
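Dot-list overrides address nested config keys with dotted paths such as `a.b=value`. The sketch below shows the idea on plain dictionaries. It is not the parser used by `main_from_cli`, and `apply_dotlist` plus the key names in the usage note are hypothetical.

```python
import ast


def apply_dotlist(config, overrides):
    """Apply 'a.b.c=value' strings to a nested dict, in place, and return it."""
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        try:
            value = ast.literal_eval(raw)  # numbers, booleans, lists, ...
        except (ValueError, SyntaxError):
            value = raw  # fall back to the raw string
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return config
```

For example, `apply_dotlist({"optimization": {"lr": 0.001}}, ["optimization.lr=0.0005"])` updates only the nested learning rate and leaves other keys untouched.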
- class world_models.training.train_iris.IRISTrainer(game='ALE/Pong-v5', device='cuda', seed=42, config=None)[source]#
Bases:
object
Training loop for IRIS on Atari 100k benchmark.
- Parameters:
game (str)
device (str)
seed (int)
config (IRISConfig | None)
- preprocess_frame(frame)[source]#
Preprocess frame: resize to 64x64 and normalize.
- Parameters:
frame (ndarray)
- Return type:
ndarray
- collect_experience(num_steps, epsilon=0.01)[source]#
Collect experience from environment.
- Parameters:
num_steps (int) – Number of steps to collect
epsilon (float) – Random action probability
- Returns:
Mean episode return
- Return type:
float
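The `epsilon` argument is described as a random action probability, which corresponds to epsilon-greedy selection. A minimal sketch, assuming a discrete action space indexed from 0 (`epsilon_greedy` is a hypothetical helper, not a method of `IRISTrainer`):

```python
import random


def epsilon_greedy(greedy_action, num_actions, epsilon=0.01, rng=None):
    """Return a uniformly random action with probability epsilon, else the greedy one."""
    rng = rng or random.Random()
    if rng.random() < epsilon:
        return rng.randrange(num_actions)
    return greedy_action
```

With the default `epsilon=0.01`, roughly one step in a hundred explores at random.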
- train_epoch(epoch)[source]#
Train for one epoch.
- Parameters:
epoch (int) – Current epoch number
- Returns:
Dictionary of metrics
- Return type:
dict
- get_epsilon(epoch)[source]#
Get exploration epsilon with decay.
- Parameters:
epoch (int)
- Return type:
float
- evaluate(num_episodes=100, render=False)[source]#
Evaluate agent performance.
- Parameters:
num_episodes (int) – Number of evaluation episodes
render (bool) – If True, also return video frames and per-step latent vectors
- Returns:
If render is False (default): dict with evaluation metrics. If render is True: tuple (episode_returns_array, videos_list, latents_array).
- Return type:
dict | tuple
- world_models.training.train_iris.main(argv=None)[source]#
Run IRIS training with YAML config files and Hydra dot-list overrides.
- Parameters:
argv (list[str] | None)
- Return type:
- class world_models.training.train_genie.GenieConfig(num_frames=16, image_size=64, in_channels=3, tokenizer_vocab_size=1024, tokenizer_embedding_dim=32, tokenizer_encoder_dim=512, tokenizer_decoder_dim=1024, tokenizer_encoder_depth=12, tokenizer_decoder_depth=20, action_vocab_size=8, action_embedding_dim=32, action_encoder_dim=1024, action_encoder_depth=20, action_pooling='mean', window_attention_heads=1, dynamics_dim=512, dynamics_depth=8, dynamics_num_heads=8, batch_size=4, learning_rate=3e-05, weight_decay=0.0001, warmup_steps=5000, max_steps=125000, mask_prob_min=0.5, mask_prob_max=1.0, sample_temperature=2.0, maskgit_steps=25)[source]#
Bases:
SerializableConfigMixin
Configuration for Genie training.
- Parameters:
num_frames (int)
image_size (int)
in_channels (int)
tokenizer_vocab_size (int)
tokenizer_embedding_dim (int)
tokenizer_encoder_dim (int)
tokenizer_decoder_dim (int)
tokenizer_encoder_depth (int)
tokenizer_decoder_depth (int)
action_vocab_size (int)
action_embedding_dim (int)
action_encoder_dim (int)
action_encoder_depth (int)
action_pooling (Literal['mean', 'windowed_attention'])
window_attention_heads (int)
dynamics_dim (int)
dynamics_depth (int)
dynamics_num_heads (int)
batch_size (int)
learning_rate (float)
weight_decay (float)
warmup_steps (int)
max_steps (int)
mask_prob_min (float)
mask_prob_max (float)
sample_temperature (float)
maskgit_steps (int)
- num_frames: int = 16#
- image_size: int = 64#
- in_channels: int = 3#
- tokenizer_vocab_size: int = 1024#
- tokenizer_embedding_dim: int = 32#
- tokenizer_encoder_dim: int = 512#
- tokenizer_decoder_dim: int = 1024#
- tokenizer_encoder_depth: int = 12#
- tokenizer_decoder_depth: int = 20#
- action_vocab_size: int = 8#
- action_embedding_dim: int = 32#
- action_encoder_dim: int = 1024#
- action_encoder_depth: int = 20#
- action_pooling: Literal['mean', 'windowed_attention'] = 'mean'#
- window_attention_heads: int = 1#
- dynamics_dim: int = 512#
- dynamics_depth: int = 8#
- dynamics_num_heads: int = 8#
- batch_size: int = 4#
- learning_rate: float = 3e-05#
- weight_decay: float = 0.0001#
- warmup_steps: int = 5000#
- max_steps: int = 125000#
- mask_prob_min: float = 0.5#
- mask_prob_max: float = 1.0#
- sample_temperature: float = 2.0#
- maskgit_steps: int = 25#
- class world_models.training.train_genie.VideoDataset(video_paths, num_frames=16, image_size=64)[source]#
Bases:
Dataset
Dataset for video data.
- Parameters:
video_paths (list)
num_frames (int)
image_size (int)
- class world_models.training.train_genie.GenieTrainer(model, config, device=None)[source]#
Bases:
object
Trainer for Genie model.
- Parameters:
model (Module)
config (GenieConfig)
device (device | None)
- train_step(batch)[source]#
Single training step.
- Parameters:
batch (Tensor) – (B, C, T, H, W) video batch
- Returns:
Dictionary of losses
- Return type:
Dict[str, Tensor | float | None]
- validate(val_batch)[source]#
Validation step.
- Parameters:
val_batch (Tensor) – (B, C, T, H, W) validation video batch
- Returns:
Dictionary of validation metrics
- Return type:
Dict[str, Tensor]
- train(train_dataloader, val_dataloader=None, num_steps=None, log_interval=100, val_interval=1000)[source]#
Full training loop.
- Parameters:
train_dataloader (DataLoader) – Training data loader
val_dataloader (DataLoader | None) – Validation data loader (optional)
num_steps (int | None) – Number of training steps (uses config.max_steps if None)
log_interval (int) – Logging frequency
val_interval (int) – Validation frequency
- Return type:
None
- world_models.training.train_genie.create_genie_trainer(config=None, device=None)[source]#
Factory function to create Genie trainer and model.
- Parameters:
config (GenieConfig | None)
device (device | None)
- Return type:
Tuple[GenieTrainer, Module]
- world_models.training.train_genie.main(argv=None)[source]#
Console entrypoint for Genie trainer setup.
The generic VideoDataset in this module is intentionally abstract, so this command provides a discoverable entrypoint for inspecting defaults and constructing a trainer. Use a concrete dataset script, such as scripts/train_genie_tinyworlds.py, for end-to-end data loading.
- Parameters:
argv (list[str] | None)
- Return type:
None
- world_models.training.train_planet.train(memory, rssm, optimizer, device, N=32, H=50, beta=1.0, grads=False)[source]#
Training implementation following “Learning Latent Dynamics for Planning from Pixels” (arXiv:1811.04551).
- (a.) The Standard Variational Bound Method, using only single-step predictions.
- Parameters:
memory (Any)
rssm (Any)
optimizer (Any)
device (device)
N (int)
H (int)
beta (float)
grads (bool)
- Return type:
dict
- world_models.training.train_planet.main()[source]#
Example PlaNet/RSSM training script with rollout collection and evaluation.
Builds environment/model/policy objects, iteratively trains on replayed episodes, and periodically saves videos and checkpoints.
- Return type:
None
- world_models.training.train_rssm.train_rssm(memory, model, optimizer, record_grads=True)[source]#
Train an RSSM on replayed trajectories for one optimization phase.
Samples batches from memory, computes reconstruction and KL objectives across rollout steps, and returns aggregated loss metrics.
- Parameters:
memory (Any)
model (Any)
optimizer (Any)
record_grads (bool)
- Return type:
dict
- world_models.training.train_rssm.evaluate(memory, model, path, eps)[source]#
Run one RSSM reconstruction/prediction evaluation and save visual outputs.
Decodes priors/posteriors for a sampled sequence and writes frame grids for qualitative inspection.
- Parameters:
memory (Any)
model (Any)
path (str)
eps (Any)
- Return type:
None
- world_models.training.train_rssm.main()[source]#
Standalone training loop for RSSM with generated replay fallback support.
Initializes environment/policy/memory, trains over episodes, logs metrics, and periodically evaluates and checkpoints the model.
- Return type:
None
- class world_models.training.train_diamond.DiamondAgent(config)[source]#
Bases:
object
DIAMOND: DIffusion As a Model Of eNvironment Dreams
RL agent trained entirely within a diffusion world model.
- Parameters:
config (DiamondConfig)
- classmethod from_config(config=None, **overrides)[source]#
Build a DIAMOND agent from a config object, dict, YAML file, or YAML string.
- Parameters:
config (DiamondConfig | dict | str | Path | None)
overrides (Any)
- Return type:
- classmethod from_pretrained(pretrained_model_name_or_path, *, config=None, checkpoint_filename=None, config_filename='config.yaml', repo_type=None, revision=None, **overrides)[source]#
Load a DIAMOND checkpoint from a local path/directory or HF Hub.
- Parameters:
pretrained_model_name_or_path (str | Path)
config (DiamondConfig | dict | str | Path | None)
checkpoint_filename (str | None)
config_filename (str)
repo_type (str | None)
revision (str | None)
overrides (Any)
- Return type:
- evaluate(num_episodes=1)[source]#
Evaluate the agent.
- Parameters:
num_episodes (int)
- Return type:
float
- save_checkpoint(path=None)[source]#
Save model checkpoint.
- Parameters:
path (str | PathLike | None) – Optional checkpoint destination. If path is None, the checkpoint is written to checkpoints/diamond/checkpoint.pt (the legacy default). A bare filename is written to checkpoints/diamond/<filename>. A path with a directory component, whether absolute or relative, is used as given.
- Return type:
None
- load_checkpoint(path=None)[source]#
Load model checkpoint.
- Parameters:
path (str | None) – Optional checkpoint path. If None, the default checkpoints/diamond/checkpoint.pt is loaded. A bare filename is looked up as checkpoints/diamond/<filename>; a path with directory components is used as given.
- Return type:
None
- world_models.training.train_diamond.train_diamond(game=None, seed=None, preset=None, device=None, config=None)[source]#
Train DIAMOND on a specific game or a composed experiment config.
- Parameters:
game (str | None)
seed (int | None)
preset (str | None)
device (str | None)
config (DiamondConfig | None)
- Return type:
None
- world_models.training.train_diamond.main(argv=None)[source]#
Compose DIAMOND config from YAML/dot-list overrides and launch training.
- Parameters:
argv (list[str] | None)
- Return type:
Any
- class world_models.training.rl_harness.ActorCritic(obs_shape, action_dim, hidden_dim=256)[source]#
Bases:
Module
Simple actor-critic network for RL harness.
- Parameters:
obs_shape (tuple)
action_dim (int)
hidden_dim (int)
- class world_models.training.rl_harness.PPOTrainer(vec_env, device='cpu', lr=0.0003, gamma=0.99, gae_lambda=0.95, clip_ratio=0.2, num_epochs=10, batch_size=64, max_grad_norm=0.5, entropy_coeff=0.01, value_coeff=0.5)[source]#
Bases:
object
Simple PPO trainer for testing vectorized environments.
- Parameters:
vec_env (TorchVectorizedEnv)
device (str)
lr (float)
gamma (float)
gae_lambda (float)
clip_ratio (float)
num_epochs (int)
batch_size (int)
max_grad_norm (float)
entropy_coeff (float)
value_coeff (float)
- collect_trajectories(num_steps)[source]#
Collect trajectories using the vectorized environment.
- Parameters:
num_steps (int)
- Return type:
Dict[str, Tensor]
- compute_gae(rewards, values, dones)[source]#
Compute Generalized Advantage Estimation.
- Parameters:
rewards (Tensor)
values (Tensor)
dones (Tensor)
- Return type:
Tensor
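Generalized Advantage Estimation can be sketched in a few lines. The following is a minimal illustration on plain Python lists, not the PPOTrainer implementation: the function name gae_advantages is hypothetical, and it assumes that values carries one extra bootstrap entry after the last step, which the reference above does not state. The defaults mirror the gamma and gae_lambda defaults of PPOTrainer.

```python
def gae_advantages(rewards, values, dones, gamma=0.99, gae_lambda=0.95):
    """Hypothetical helper: GAE on Python lists.

    Assumes len(values) == len(rewards) + 1, where the final entry is the
    bootstrap value of the state after the last step.
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step can reuse the advantage of the next step.
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - float(dones[t])
        # One-step TD error; the bootstrap term is dropped at episode ends.
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        # Exponentially weighted sum of TD errors, reset at episode ends.
        running = delta + gamma * gae_lambda * not_done * running
        advantages[t] = running
    return advantages
```

Setting gae_lambda to 1.0 recovers discounted Monte Carlo advantages, while 0.0 reduces to the one-step TD error.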
Memory, controllers, and inference operators#
- class world_models.memory.dreamer_memory.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]#
Bases:
object
Fixed-size replay buffer for Dreamer with image observations and transitions.
Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world-model training.
Key Features
Ring buffer with fixed capacity (FIFO eviction when full)
Stores raw uint8 images to save memory
Samples sequences (not single transitions) for temporal modeling
Validates sampled sequences don’t span episode boundaries
Memory Layout
observations: (capacity, C, H, W) uint8 images
actions: (capacity, action_dim) float32
rewards: (capacity,) float32
terminals: (capacity,) float32 (1.0 = terminal, 0.0 = continue)
Sampling Process
Random start index (avoiding episode boundaries)
Collect sequence of length seq_len with wraparound
Validate no terminal in middle of sequence
Return batch of sequences
Usage with Dreamer:
buffer = ReplayBuffer(
    size=100000,            # Max transitions to store
    obs_shape=(3, 64, 64),  # RGB images
    action_size=6,          # Continuous action dim
    seq_len=50,             # Sequence length for training
    batch_size=50,          # Parallel sequences per batch
)
# Add transitions during interaction
buffer.add(obs, action, reward, done)
# Sample batch for world model training
obs_batch, action_batch, reward_batch, term_batch = buffer.sample()
Memory Efficiency
Uses uint8 for images (1 byte per pixel vs 4 for float32)
Sequences share observations (overlapping windows)
Configurable capacity based on available system memory
Note
The buffer stores observations as {“image”: …} dicts but returns just the image arrays for training efficiency.
- Parameters:
size (int)
obs_shape (Tuple[int, ...])
action_size (int)
seq_len (int)
batch_size (int)
- add(obs, ac, rew, done)[source]#
Add a transition to the buffer.
- Parameters:
obs (dict) – Observation dict with ‘image’ key containing the observation
ac (ndarray) – Action taken, shape (action_size,)
rew (float) – Reward received, scalar
done (float) – Terminal flag, 1.0 if episode ended, 0.0 otherwise
- Return type:
None
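The sampling process described for ReplayBuffer (random start, wraparound, no terminal in the middle of a sequence) can be sketched with a toy ring buffer. This is an illustration under assumptions, not the library code: TinySequenceBuffer, sample_indices, and the rejection-sampling loop are hypothetical, and only observations and terminal flags are stored to keep the example short.

```python
import random


class TinySequenceBuffer:
    """Hypothetical stand-in for illustration; not the library's ReplayBuffer."""

    def __init__(self, size, seq_len):
        self.size, self.seq_len = size, seq_len
        self.obs = [None] * size
        self.terminals = [0.0] * size
        self.idx = 0       # next write position
        self.full = False  # True once the ring has wrapped

    def add(self, obs, done):
        self.obs[self.idx] = obs
        self.terminals[self.idx] = float(done)
        self.idx = (self.idx + 1) % self.size
        self.full = self.full or self.idx == 0

    def _valid(self, idxs):
        # The slot at self.idx holds the oldest entry, so a window that reaches
        # it after its first position would join newest and oldest data.
        crosses_seam = self.idx in idxs[1:]
        # A terminal may only appear on the final step of the window.
        terminal_inside = any(self.terminals[i] for i in idxs[:-1])
        return not crosses_seam and not terminal_inside

    def sample_indices(self, rng, max_tries=1000):
        count = self.size if self.full else self.idx
        if count < self.seq_len:
            raise ValueError("not enough transitions stored")
        for _ in range(max_tries):
            if self.full:
                start = rng.randrange(self.size)
            else:
                start = rng.randrange(self.idx - self.seq_len + 1)
            # Wrap around the end of the ring.
            idxs = [(start + k) % self.size for k in range(self.seq_len)]
            if self._valid(idxs):
                return idxs
        raise RuntimeError("no valid window found")


buf = TinySequenceBuffer(size=8, seq_len=3)
for step in range(10):
    buf.add(step, done=(step == 4))
window = buf.sample_indices(random.Random(0))
```

Rejecting invalid windows and resampling keeps the logic simple; a production buffer might instead precompute valid start positions.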
- class world_models.memory.dreamer_memory.Memory(capacity=10000)[source]#
Bases:
object
Simple deque-based memory for storing transitions.
Used by PlaNet for online planning. Stores recent transitions and provides random sampling for policy updates.
- Parameters:
capacity (int) – Maximum number of transitions to store
Usage:
import random
memory = Memory(capacity=10000)
memory.append(obs, action, reward, done, info)
batch = random.sample(memory, k=32)
- class world_models.memory.dreamer_memory.Episode(observation, action=None, reward=None, terminal=None, info=None)[source]#
Bases:
object
Stores a single episode for PlaNet’s imagination and planning.
An episode is a sequence of (observation, action, reward) tuples collected during environment interaction. Episodes are used for computing returns and training value functions.
- Parameters:
observation (Any) – Initial observation
action (Any) – First action (optional)
reward (Any) – Initial reward (optional)
terminal (Any)
info (Any) – Additional info dict (optional)
Usage:
episode = Episode(obs, info=info)
episode.append(action, obs, reward, done, info)
episodes = [episode for _ in range(num_episodes)]
# Use with Planet agent for planning
imag_state, imag_reward, imag_action = planet.imagine(episodes)
- class world_models.memory.planet_memory.Episode(postprocess_fn=None)[source]#
Bases:
object
Records the agent’s interaction with the environment for a single episode.
Stores observations, actions, rewards, and terminal flags during a single trajectory. At termination, converts all lists to numpy arrays for efficient batch processing.
- x#
Observations collected during the episode.
- Type:
list or np.ndarray
- u#
Actions taken.
- Type:
list or np.ndarray
- r#
Rewards received.
- Type:
list or np.ndarray
- t#
Terminal flags (0.0 = continue, 1.0 = terminal).
- Type:
list or np.ndarray
- info#
Additional episode metadata.
- Type:
dict
- Parameters:
postprocess_fn (callable, optional) – Function to apply to observations before storing (e.g., normalization). Default: identity function.
Example:
episode = Episode()
episode.append(obs, action, reward, False)
episode.append(obs, action, reward, True)
episode.terminate(final_obs)
print(episode.x.shape)  # Now a numpy array
- property size: int#
- class world_models.memory.planet_memory.Memory(size=None)[source]#
Bases:
deque
Episode-based replay memory for PlaNet/RSSM training.
Stores episodes as variable-length trajectories and supports sampling sub-sequences for training. Implements a ring-buffer style eviction when capacity is reached.
Stores complete episodes as lists of transitions
Samples contiguous sub-sequences for sequence models
Supports time-major formatting (time-first) for RNN input
Memory usage estimation to prevent OOM errors
- Parameters:
size (int, optional) – Maximum number of episodes to store. If None, deque grows without limit (useful for unpickling).
- episodes#
Collection of Episode objects.
- Type:
deque
- eps_lengths#
Length of each episode.
- Type:
deque
- size#
Total number of transitions across all episodes.
- Type:
property
Example:
memory = Memory(size=100)
memory.append([episode1, episode2])
observations, actions, rewards, terminals, lengths = memory.sample(batch_size=32, tracelen=50)
- property size: int#
- sample(batch_size, tracelen=1, time_first=False)[source]#
Sample random sub-sequences from stored episodes.
Randomly selects episodes and starting positions to create batches of contiguous sequences for training sequence models.
- Parameters:
batch_size (int) – Number of sequences to sample.
tracelen (int) – Length of each sequence (default: 1).
time_first (bool) – If True, returns tensors with time dimension first (T, B, …) instead of batch first (B, T, …).
- Returns:
- (observations, actions, rewards, terminals, lengths)
observations: (batch, tracelen+1, *obs_shape) or (tracelen+1, batch, …)
actions: (batch, tracelen, action_dim) or (tracelen, batch, …)
rewards: (batch, tracelen) or (tracelen, batch)
terminals: (batch, tracelen) or (tracelen, batch)
lengths: (batch,) original episode lengths for each sample
- Return type:
tuple
- Raises:
ValueError – If memory is empty or no episodes meet minimum length.
MemoryError – If estimated memory usage exceeds 200 MiB threshold.
- class world_models.memory.iris_memory.IRISReplayBuffer(size, obs_shape, action_size, seq_len=20, batch_size=64)[source]#
Bases:
object
Replay buffer for IRIS (Imagination with auto-Regression over an Inner Speech) training.
Stores (observation, action, reward, terminal) tuples in a ring buffer and supports sampling contiguous sequences for world model training.
- Features:
Ring buffer with fixed capacity (FIFO eviction when full)
Stores uint8 images for memory efficiency
Samples sequences with validation to avoid episode boundaries
Supports sequence sampling for temporal learning
- Memory Layout:
observations: (capacity, C, H, W) uint8
actions: (capacity, action_size) float32
rewards: (capacity,) float32
terminals: (capacity,) float32
- Parameters:
size (int) – Maximum number of transitions to store.
obs_shape (tuple) – Shape of observations as (C, H, W).
action_size (int) – Dimension of actions.
seq_len (int) – Length of sequences to sample (default: 20).
batch_size (int) – Number of sequences per batch (default: 64).
- size#
Buffer capacity.
- Type:
int
- obs_shape#
Observation shape.
- Type:
tuple
- action_size#
Action dimension.
- Type:
int
- seq_len#
Sequence length.
- Type:
int
- batch_size#
Batch size.
- Type:
int
- steps#
Total transitions added.
- Type:
int
- episodes#
Number of episode terminations observed.
- Type:
int
- add(obs, action, reward, terminal)[source]#
Add a transition to the buffer.
- Parameters:
obs (ndarray) – Observation array with shape (C, H, W).
action (ndarray) – Action array with shape (action_size,).
reward (float) – Scalar reward value.
terminal (bool) – Boolean indicating if episode terminated.
- Return type:
None
- sample_sequence(seq_len=None)[source]#
Sample a batch of sequences for world model training.
- Parameters:
seq_len (int | None)
- Returns:
observations: (batch_size, seq_len+1, C, H, W)
actions: (batch_size, seq_len, action_size)
rewards: (batch_size, seq_len)
terminals: (batch_size, seq_len)
- sample_single()[source]#
Sample a single transition for online updates.
- Return type:
Tuple[ndarray, ndarray, float, float]
- property buffer_capacity: int#
Returns the total capacity of the buffer.
- class world_models.memory.iris_memory.IRISOnPolicyBuffer(max_steps=1000)[source]#
Bases:
object
On-policy buffer for collecting trajectories during environment interaction.
Used to store the current episode data before adding to the main replay buffer. Unlike the main replay buffer, this collects trajectories in a list-based structure that’s cleared after each episode.
- Useful for:
Collecting complete episode trajectories
Storing data before batch processing
Temporary storage during environment interaction
- Parameters:
max_steps (int) – Maximum number of steps to store (default: 1000).
- max_steps#
Maximum buffer capacity.
- Type:
int
- observations#
List of observations.
- Type:
list
- actions#
List of actions.
- Type:
list
- rewards#
List of rewards.
- Type:
list
- terminals#
List of terminal flags.
- Type:
list
RSSM-based policy for model-predictive control.
This module provides the RSSMPolicy class that implements model-predictive control using the RSSM (Recurrent State Space Model) latent dynamics model. The policy uses a Cross-Entropy Method (CEM) for planning actions in latent space.
- Reference:
Ha & Schmidhuber (2018). Recurrent World Models Facilitate Policy Evolution. https://arxiv.org/abs/1809.01999
- class world_models.controller.rssm_policy.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]#
Bases:
object
Model-predictive controller using the Cross-Entropy Method (CEM) with an RSSM.
Plans by optimizing a sequence of future actions in the RSSM’s latent space: it samples candidate action sequences, rolls them forward in latent space, scores them by predicted return, and refits a Gaussian proposal to the top-performing candidates.
- Algorithm:
Initialize Gaussian distribution over action sequences
Sample N candidate action sequences
Rollout each sequence in RSSM latent space
Score by predicted cumulative rewards
Keep top K candidates, fit Gaussian to them
Repeat for T iterations
Execute first action from best sequence
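The steps above can be sketched on a toy problem. This is a minimal illustration, not the RSSMPolicy implementation: rollout_return is a hypothetical scalar stand-in for a latent rollout (the state is one number, each action shifts it, and the reward is the negative squared state), and cem_plan is a hypothetical name. Only the structure of the loop follows the algorithm listed above.

```python
import random
import statistics


def rollout_return(x0, actions):
    # Toy stand-in for an RSSM latent rollout (assumption): each action shifts
    # a scalar state, and the reward at every step is -x**2.
    x, total = x0, 0.0
    for a in actions:
        x += a
        total -= x * x
    return total


def cem_plan(x0, horizon=5, num_candidates=200, num_iterations=8,
             top_candidates=20, seed=0):
    rng = random.Random(seed)
    # Step 1: independent Gaussian proposal per time step.
    mean = [0.0] * horizon
    std = [1.0] * horizon
    best = list(mean)
    for _ in range(num_iterations):  # Step 6: repeat for T iterations
        # Step 2: sample N candidate action sequences.
        candidates = [
            [rng.gauss(mean[t], std[t]) for t in range(horizon)]
            for _ in range(num_candidates)
        ]
        # Steps 3-4: roll out each candidate and rank by predicted return.
        candidates.sort(key=lambda acts: rollout_return(x0, acts), reverse=True)
        # Step 5: keep the top K and refit the Gaussian to them.
        elites = candidates[:top_candidates]
        for t in range(horizon):
            column = [acts[t] for acts in elites]
            mean[t] = statistics.fmean(column)
            std[t] = statistics.pstdev(column) + 1e-6  # avoid a zero-width proposal
        best = elites[0]
    # Step 7: execute only the first action of the best sequence.
    return best[0], best


first_action, plan = cem_plan(2.0)
```

Returning the first action of the top-scoring sequence matches step 7 as written; returning the first entry of the refitted mean is another common choice.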
- Parameters:
model (Any)
planning_horizon (int)
num_candidates (int)
num_iterations (int)
top_candidates (int)
device (device | str)
- rssm#
The RSSM world model.
- N#
Number of candidate action sequences to sample.
- K#
Number of top candidates to use for updating the proposal.
- T#
Number of CEM iterations per planning step.
- H#
Planning horizon (number of future steps to consider).
- d#
Action dimensionality.
- device#
Device to run computations on.
- state_size#
Hidden state dimensionality.
- latent_size#
Latent state dimensionality.
Example
>>> policy = RSSMPolicy(
...     model=rssm,
...     planning_horizon=12,
...     num_candidates=1000,
...     num_iterations=5,
...     top_candidates=100,
...     device='cuda'
... )
>>> policy.reset()
>>> action = policy.poll(observation)
- reset()[source]#
Reset the policy state.
Initializes the hidden state, latent state, and action to zeros. Should be called at the beginning of each episode.
- Return type:
None
- poll(observation, explore=False)[source]#
Get action for given observation.
- Parameters:
observation (Tensor) – Current observation tensor of shape (channels, height, width).
explore (bool) – If True, add exploration noise to the selected action.
- Returns:
Action tensor of shape (1, action_size).
- Return type:
Tensor
- class world_models.controller.iris_policy.IRISActor(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Actor network for IRIS (Imagination with auto-Regression over an Inner Speech) policy.
Takes reconstructed frames as input and outputs action logits for policy control. Uses a CNN feature extractor followed by an LSTM for temporal processing. Supports a burn-in mechanism for initializing the hidden state with context frames.
- Architecture:
CNN: Extracts features from input frames (3x64x64 -> 512)
LSTM: Processes temporal sequences with configurable layers
Linear: Maps hidden states to action logits
- Parameters:
action_size (int) – Number of discrete actions.
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
- action_size#
Number of discrete actions.
- Type:
int
- hidden_size#
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
- forward(frames, hidden_state=None, burn_in_frames=None)[source]#
Forward pass through actor.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W) or (B, C, H, W)
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple for LSTM state
burn_in_frames (Tensor | None) – Frames to use for initializing hidden state
- Returns:
action_logits: Action logits (B, T, action_size) or (B, action_size)
hidden_state: Updated (h, c) tuple
Initialize LSTM hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- get_action(frame, temperature=1.0, deterministic=False)[source]#
Get action from a single frame.
- Parameters:
frame (Tensor) – Single frame (B, C, H, W)
temperature (float) – Softmax temperature (higher = more random)
deterministic (bool) – If True, return argmax; else sample
- Returns:
action: Selected action indices (B,)
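The effect of the temperature and deterministic arguments can be sketched for a single set of logits. This is an illustrative stand-in on Python lists, not the IRISActor code: softmax_with_temperature and select_action are hypothetical names, and the real method operates on batched tensors.

```python
import math
import random


def softmax_with_temperature(logits, temperature=1.0):
    # Dividing by the temperature flattens (T > 1) or sharpens (T < 1)
    # the resulting distribution.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def select_action(logits, temperature=1.0, deterministic=False, rng=random):
    if deterministic:
        # Greedy choice: index of the largest logit.
        return max(range(len(logits)), key=lambda i: logits[i])
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [0.1, 2.0, -1.0]
greedy = select_action(logits, deterministic=True)
sampled = select_action(logits, temperature=2.0, rng=random.Random(0))
```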
- class world_models.controller.iris_policy.IRISCritic(hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Critic network for IRIS value estimation.
Estimates the value function for given frame sequences. Shares the CNN feature extractor and LSTM backbone with the actor for efficiency, but has a separate value head for estimating expected cumulative rewards.
- Architecture:
CNN: Shared feature extractor with actor (3x64x64 -> 512)
LSTM: Temporal processing with same architecture as actor
Linear: Maps hidden states to scalar values
- Parameters:
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
- hidden_size#
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
- Returns:
values: Value estimates with shape (B, T).
hidden_state: Updated LSTM hidden state (h, c) tuple.
- Parameters:
hidden_size (int)
num_layers (int)
frame_shape (Tuple[int, int, int])
- forward(frames, hidden_state=None)[source]#
Forward pass through critic.
- Parameters:
frames (Tensor) – Input frames (B, T, C, H, W)
hidden_state (Tuple[Tensor, Tensor] | None) – Optional (h, c) tuple
- Returns:
values: Value estimates (B, T)
hidden_state: Updated (h, c) tuple
Initialize LSTM hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
Tuple[Tensor, Tensor]
- class world_models.controller.iris_policy.CNNFeatureExtractor(frame_shape=(3, 64, 64), output_size=512)[source]#
Bases:
Module
CNN feature extractor shared between actor and critic networks.
Processes input frames through four stride-2 convolutional layers, each followed by a ReLU (the first maps 3 -> 32 channels), and then a linear projection to output_size, producing fixed-size feature vectors.
- Architecture:
Conv layers: 32 -> 64 -> 128 -> 256 channels
Each conv has stride=2 for spatial downsampling
Final linear layer projects to desired output dimension
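The spatial size after the four stride-2 layers can be worked out with the standard convolution output formula. The kernel size and padding are not given in this reference, so kernel=4 and padding=1 below are assumptions chosen so that each layer halves the resolution; feature_map_shape is a hypothetical helper, not part of the package.

```python
def conv_out(size, kernel, stride, padding):
    # Standard output-size formula for a convolution without dilation.
    return (size + 2 * padding - kernel) // stride + 1


def feature_map_shape(frame_shape=(3, 64, 64), channels=(32, 64, 128, 256),
                      kernel=4, stride=2, padding=1):
    # kernel and padding are assumed values, not taken from the library.
    c, h, w = frame_shape
    for out_channels in channels:
        h = conv_out(h, kernel, stride, padding)
        w = conv_out(w, kernel, stride, padding)
        c = out_channels
    return c, h, w


c, h, w = feature_map_shape()
flattened = c * h * w  # input width of the final linear projection
```

Under these assumptions a 64x64 frame shrinks as 64 -> 32 -> 16 -> 8 -> 4, so the linear layer would see 256 * 4 * 4 = 4096 features before projecting to output_size.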
- Parameters:
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
output_size (int) – Size of output feature vector (default: 512).
- frame_shape#
Input frame shape.
- Type:
tuple
- output_size#
Output feature dimension.
- Type:
int
- Returns:
features: Feature vectors with shape (B, output_size).
- Parameters:
frame_shape (Tuple[int, int, int])
output_size (int)
- class world_models.controller.iris_policy.IRISPolicy(action_size, hidden_size=512, num_layers=4, frame_shape=(3, 64, 64))[source]#
Bases:
Module
Combined policy module for IRIS (Imagination with auto-Regression over an Inner Speech).
Provides a unified interface for actor-only or actor-critic policies. Used in the IRIS algorithm where the actor generates actions from reconstructed frames and the critic estimates value functions for training.
- Parameters:
action_size (int) – Number of discrete actions.
hidden_size (int) – LSTM hidden state size (default: 512).
num_layers (int) – Number of LSTM layers (default: 4).
frame_shape (tuple) – Shape of input frames as (C, H, W) (default: (3, 64, 64)).
- hidden_size#
LSTM hidden state size.
- Type:
int
- num_layers#
Number of LSTM layers.
- Type:
int
- frame_shape#
Input frame shape.
- Type:
tuple
Example
>>> policy = IRISPolicy(
...     action_size=18,
...     hidden_size=512,
...     num_layers=4,
...     frame_shape=(3, 64, 64)
... )
>>> action = policy.act(frame, temperature=1.0, deterministic=False)
- forward(frames)[source]#
Get action logits from frames.
- Parameters:
frames (Tensor)
- Return type:
Tensor
- act(frame, temperature=1.0, deterministic=False)[source]#
Sample action from policy.
- Parameters:
frame (Tensor)
temperature (float)
deterministic (bool)
- Return type:
Tensor
Initialize hidden state.
- Parameters:
batch_size (int)
device (device)
- Return type:
tuple[Tensor, Tensor]
Rollout generation utilities for World Models.
This module provides the RolloutGenerator class for collecting episode experience using trained policies in environments.
- class world_models.controller.rollout_generator.RolloutGenerator(env, device, policy=None, max_episode_steps=None, episode_gen=None, name='', enable_streaming_video=False, streaming_video_path=None, streaming_video_fps=20, streaming_video_format='mp4')[source]#
Bases:
object
Generator for collecting environment rollouts.
This class handles environment interactions and rollout collection, supporting both random and policy-based action selection.
- Parameters:
env (Any)
device (device | str)
policy (Any)
max_episode_steps (int | None)
episode_gen (Any)
name (str)
enable_streaming_video (bool)
streaming_video_path (str | None)
streaming_video_fps (int)
streaming_video_format (str)
- env#
The environment to interact with.
- device#
Device to run computations on.
- policy#
The policy to use for action selection (optional).
- episode_gen#
Factory for creating episode objects.
- name#
Name identifier for the generator.
- max_episode_steps#
Maximum steps per episode.
Example
>>> generator = RolloutGenerator(
...     env=env,
...     device='cuda',
...     policy=policy,
...     max_episode_steps=1000
... )
>>> episode = generator.rollout_once()
- rollout_once(random_policy=False, explore=False)[source]#
Perform a single rollout of the environment.
- Parameters:
random_policy (bool) – If True, use random actions instead of policy.
explore (bool) – If True, add exploration noise to policy actions.
- Returns:
Episode object containing the rollout experience.
- Return type:
- rollout_n(n=1, random_policy=False)[source]#
Perform multiple rollouts.
- Parameters:
n (int) – Number of rollouts to perform.
random_policy (bool) – If True, use random actions.
- Returns:
List of Episode objects.
- Return type:
list
- class world_models.inference.operators.OperatorABC(*, device=None)[source]
Bases:
Module, ABC
Structured base class for inference operators.
Operators use a consistent pipeline:
preprocess converts raw inputs into tensors.
forward performs model/operator-specific tensor computation.
postprocess formats the final output mapping.
Subclasses may also declare input_specs and output_specs to validate required tensor keys, shapes, and dtypes. OperatorABC inherits from torch.nn.Module, so operators support to(device), train(), and eval() just like model modules.
- Parameters:
device (torch.device | str | None)
- input_specs: Mapping[str, TensorSpec] = {}
- output_specs: Mapping[str, TensorSpec] = {}
- abstractmethod preprocess(inputs)[source]
Convert raw inputs into a tensor mapping ready for forward.
- Parameters:
inputs (Any)
- Return type:
dict[str, Tensor]
- forward(inputs)[source]
Run tensor computation for this operator.
Preprocessing-only operators can rely on this identity implementation. Operators that wrap a model should override this method.
- Parameters:
inputs (dict[str, Tensor])
- Return type:
dict[str, Tensor]
- postprocess(outputs)[source]
Format validated forward outputs for consumers.
- Parameters:
outputs (dict[str, Tensor])
- Return type:
dict[str, Tensor]
- process(inputs)[source]
Process raw inputs through preprocess, forward, and postprocess stages.
- Parameters:
inputs (Any)
- Return type:
dict[str, Tensor]
- batch(inputs)[source]
Preprocess a sequence of inputs and stack matching tensor keys.
- Parameters:
inputs (Sequence[Any])
- Return type:
dict[str, Tensor]
- to(*args, **kwargs)[source]
Move module parameters/buffers and remember the target tensor device.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
- classmethod validate_mapping(values, specs, *, label)[source]
Validate tensor keys, shapes, and dtypes against optional specs.
- Parameters:
values (Mapping[str, Tensor])
specs (Mapping[str, TensorSpec])
label (str)
- Return type:
None
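The preprocess, forward, postprocess pipeline can be sketched without torch. The classes below are hypothetical stand-ins that use plain lists where the real operators use tensors; they only mirror the method names documented for OperatorABC and make no claim about its internals, such as how batch stacks tensors.

```python
from abc import ABC, abstractmethod


class SketchOperator(ABC):
    """Hypothetical stand-in that mirrors the documented method names."""

    @abstractmethod
    def preprocess(self, inputs):
        """Convert raw inputs into a mapping ready for forward."""

    def forward(self, inputs):
        # Identity by default, as described for preprocessing-only operators.
        return inputs

    def postprocess(self, outputs):
        return outputs

    def process(self, inputs):
        # Run the three stages in order.
        return self.postprocess(self.forward(self.preprocess(inputs)))

    def batch(self, inputs):
        # Preprocess each item, then group values under matching keys.
        processed = [self.preprocess(item) for item in inputs]
        return {key: [p[key] for p in processed] for key in processed[0]}


class ScaleOperator(SketchOperator):
    def preprocess(self, inputs):
        # Map 8-bit pixel values into [0, 1].
        return {"image": [pixel / 255.0 for pixel in inputs["image"]]}


single = ScaleOperator().process({"image": [0, 255]})
stacked = ScaleOperator().batch([{"image": [0]}, {"image": [255]}])
```

Because forward defaults to the identity, a subclass only has to implement preprocess to be usable end to end.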
- class world_models.inference.operators.TensorSpec(shape=None, dtype=None, required=True)[source]
Bases:
object
Optional tensor contract used to validate operator inputs or outputs.
- Parameters:
shape (tuple[int | None, ...] | None) – Expected shape. Use None as a wildcard for dimensions that may vary, such as batch size.
dtype (dtype | None) – Expected tensor dtype.
required (bool) – Whether the key must be present in the mapping being validated.
- shape: tuple[int | None, ...] | None = None
- dtype: dtype | None = None
- required: bool = True
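How a None wildcard in shape might be matched can be sketched as follows. ShapeSpec, shape_matches, and validate_shapes are hypothetical names working on plain shape tuples; the exception types are assumptions, since this reference does not say what validate_mapping raises, and dtype checks are omitted.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class ShapeSpec:
    """Hypothetical, simplified analogue of TensorSpec (no dtype field)."""
    shape: Optional[Tuple[Optional[int], ...]] = None
    required: bool = True


def shape_matches(actual, expected):
    if expected is None:  # no shape constraint at all
        return True
    if len(actual) != len(expected):
        return False
    # None in the expected shape accepts any size for that dimension.
    return all(e is None or a == e for a, e in zip(actual, expected))


def validate_shapes(shapes, specs, label):
    for key, spec in specs.items():
        if key not in shapes:
            if spec.required:
                # Assumed error type; the library may raise something else.
                raise KeyError(f"{label}: missing required key {key!r}")
            continue
        if not shape_matches(shapes[key], spec.shape):
            raise ValueError(f"{label}: {key!r} has shape {shapes[key]}")


specs = {
    "image": ShapeSpec(shape=(None, 3, 64, 64)),  # batch size may vary
    "mask": ShapeSpec(required=False),
}
validate_shapes({"image": (8, 3, 64, 64)}, specs, label="inputs")
```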
- class world_models.inference.operators.DreamerOperator(image_size=64, action_dim=6)[source]
Bases:
OperatorABC
Operator for Dreamer model preprocessing: normalizes observations and encodes actions.
- Parameters:
image_size (int)
action_dim (int)
- preprocess(inputs)[source]
Process Dreamer inputs: image observation and action.
Expected inputs: {‘image’: PIL.Image or tensor, ‘action’: tensor or list}
- Parameters:
inputs (Dict[str, Any])
- Return type:
Dict[str, Tensor]
- class world_models.inference.operators.JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)[source]
Bases: OperatorABC
Operator for JEPA model preprocessing: handles image/video masking and patch processing.
- Parameters:
image_size (int)
patch_size (int)
mask_ratio (float)
- preprocess(inputs)[source]
Process JEPA inputs: images with masking.
Expected inputs: {‘images’: list of PIL Images or tensors, ‘mask’: optional tensor}
- Parameters:
inputs (Dict[str, Any])
- Return type:
Dict[str, Tensor]
- class world_models.inference.operators.IrisOperator(seq_length=512, vocab_size=32000)[source]
Bases: OperatorABC
Operator for Iris transformer model: formats sequences and embeddings.
- Parameters:
seq_length (int)
vocab_size (int)
- preprocess(inputs)[source]
Process Iris inputs: token sequences and optional embeddings.
Expected inputs: {‘tokens’: list of ints or tensor, ‘embeddings’: optional tensor}
- Parameters:
inputs (Dict[str, Any])
- Return type:
Dict[str, Tensor]
- class world_models.inference.operators.PlaNetOperator(state_dim=32, action_dim=4)[source]
Bases: OperatorABC
Operator for PlaNet model preprocessing: encodes environment states and transitions.
- Parameters:
state_dim (int)
action_dim (int)
- preprocess(inputs)[source]
Process PlaNet inputs: state observations and actions.
Expected inputs: {‘obs’: tensor or image, ‘action’: tensor, ‘reward’: float, ‘done’: bool}
- Parameters:
inputs (Dict[str, Any])
- Return type:
Dict[str, Tensor]
- world_models.inference.operators.get_operator(name, **kwargs)[source]
Factory function to get inference operators by name.
- Parameters:
name (str) – One of ‘dreamer’, ‘jepa’, ‘iris’, ‘planet’
**kwargs (Any) – Operator-specific configuration
- Returns:
Configured OperatorABC instance
- Return type:
Example
>>> op = get_operator('dreamer', image_size=64, action_dim=6)
>>> processed = op.process({'image': image, 'action': action})
- class world_models.inference.operators.base.TensorSpec(shape=None, dtype=None, required=True)[source]#
Bases: object
Optional tensor contract used to validate operator inputs or outputs.
- Parameters:
shape (tuple[int | None, ...] | None) – Expected shape. Use None as a wildcard for dimensions that may vary, such as batch size.
dtype (dtype | None) – Expected tensor dtype.
required (bool) – Whether the key must be present in the mapping being validated.
- shape: tuple[int | None, ...] | None = None#
- dtype: dtype | None = None#
- required: bool = True#
- class world_models.inference.operators.base.OperatorABC(*, device=None)[source]#
Bases: Module, ABC
Structured base class for inference operators.
Operators use a consistent pipeline:
preprocess converts raw inputs into tensors.
forward performs model/operator-specific tensor computation.
postprocess formats the final output mapping.
Subclasses may also declare input_specs and output_specs to validate required tensor keys, shapes, and dtypes. OperatorABC inherits from torch.nn.Module, so operators support to(device), train(), and eval() just like model modules.
- Parameters:
device (torch.device | str | None)
- input_specs: Mapping[str, TensorSpec] = {}#
- output_specs: Mapping[str, TensorSpec] = {}#
- abstractmethod preprocess(inputs)[source]#
Convert raw inputs into a tensor mapping ready for forward.
- Parameters:
inputs (Any)
- Return type:
dict[str, Tensor]
- forward(inputs)[source]#
Run tensor computation for this operator.
Preprocessing-only operators can rely on this identity implementation. Operators that wrap a model should override this method.
- Parameters:
inputs (dict[str, Tensor])
- Return type:
dict[str, Tensor]
- postprocess(outputs)[source]#
Format validated forward outputs for consumers.
- Parameters:
outputs (dict[str, Tensor])
- Return type:
dict[str, Tensor]
- process(inputs)[source]#
Process raw inputs through preprocess, forward, and postprocess stages.
- Parameters:
inputs (Any)
- Return type:
dict[str, Tensor]
- batch(inputs)[source]#
Preprocess a sequence of inputs and stack matching tensor keys.
- Parameters:
inputs (Sequence[Any])
- Return type:
dict[str, Tensor]
- to(*args, **kwargs)[source]#
Move module parameters/buffers and remember the target tensor device.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
- classmethod validate_mapping(values, specs, *, label)[source]#
Validate tensor keys, shapes, and dtypes against optional specs.
- Parameters:
values (Mapping[str, Tensor])
specs (Mapping[str, TensorSpec])
label (str)
- Return type:
None
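The three-stage flow described for OperatorABC can be sketched without any dependencies. The classes below are hypothetical stand-ins: they use plain lists instead of tensors, do not inherit from torch.nn.Module, and skip the input_specs/output_specs validation, so they illustrate only the order of the stages.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class PipelineSketch(ABC):
    """Stand-in for the three-stage operator flow, using lists instead of tensors."""

    @abstractmethod
    def preprocess(self, inputs: Any) -> Dict[str, List[float]]:
        """Convert raw inputs into a keyed mapping."""

    def forward(self, inputs: Dict[str, List[float]]) -> Dict[str, List[float]]:
        # Identity by default, as described for preprocessing-only operators.
        return inputs

    def postprocess(self, outputs: Dict[str, List[float]]) -> Dict[str, List[float]]:
        # Identity by default; subclasses may reformat the mapping here.
        return outputs

    def process(self, inputs: Any) -> Dict[str, List[float]]:
        # Run the stages in order: preprocess, then forward, then postprocess.
        return self.postprocess(self.forward(self.preprocess(inputs)))


class ScalePixels(PipelineSketch):
    """Hypothetical operator that rescales 0-255 pixel values to [0, 1]."""

    def preprocess(self, inputs: Any) -> Dict[str, List[float]]:
        return {"image": [value / 255.0 for value in inputs["image"]]}


result = ScalePixels().process({"image": [0, 255]})
```

Because ScalePixels only overrides preprocess, the default forward and postprocess pass the mapping through unchanged.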
- class world_models.inference.operators.dreamer_operator.DreamerOperator(image_size=64, action_dim=6)[source]#
Bases: OperatorABC
Operator for Dreamer model preprocessing: normalizes observations and encodes actions.
- Parameters:
image_size (int)
action_dim (int)
- class world_models.inference.operators.planet_operator.PlaNetOperator(state_dim=32, action_dim=4)[source]#
Bases: OperatorABC
Operator for PlaNet model preprocessing: encodes environment states and transitions.
- Parameters:
state_dim (int)
action_dim (int)
- class world_models.inference.operators.iris_operator.IrisOperator(seq_length=512, vocab_size=32000)[source]#
Bases: OperatorABC
Operator for Iris transformer model: formats sequences and embeddings.
- Parameters:
seq_length (int)
vocab_size (int)
- class world_models.inference.operators.jepa_operator.JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)[source]#
Bases: OperatorABC
Operator for JEPA model preprocessing: handles image/video masking and patch processing.
- Parameters:
image_size (int)
patch_size (int)
mask_ratio (float)
Datasets, environments, and transforms#
Environment adapters#
The environment APIs below mirror the dedicated environment guide pages: DMC, DeepMind Lab, Gym/Gymnasium, Atari/ALE, Procgen, MuJoCo, Unity ML-Agents, and vectorization utilities. DIAMOND-style Atari support is intentionally not listed as an environment adapter because it is Atari preprocessing rather than a separate environment family.
- world_models.envs.make_atari_env(env_id, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, **kwargs)[source]#
Create any Atari environment from the Arcade Learning Environment (ALE).
- Parameters:
env_id (str) – The id of the Atari environment to create.
obs_type (str) – The type of observation to return (“rgb” or “ram”).
frameskip (int) – The number of frames to skip between actions.
repeat_action_probability (float) – The probability of repeating the last action.
full_action_space (bool) – Whether to use the full action space.
max_episode_steps (Optional[int]) – Maximum number of steps per episode.
**kwargs – Additional keyword arguments for environment configuration.
- Returns:
The created Atari environment.
- Return type:
gym.Env
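The effect of repeat_action_probability (sticky actions) can be illustrated with a simplified simulation. This is a sketch under stated assumptions, not ALE's implementation: apply_sticky_actions is a hypothetical helper, it applies stickiness once per agent step rather than per emulator frame, and it assumes the episode starts from action 0.

```python
import random
from typing import List


def apply_sticky_actions(
    requested: List[int], repeat_probability: float, seed: int = 0
) -> List[int]:
    """Return the actions actually executed when a request may be replaced by the previous action."""
    rng = random.Random(seed)
    executed: List[int] = []
    previous = 0  # assumption: the episode starts from action 0 (NOOP)
    for action in requested:
        if rng.random() < repeat_probability:
            action = previous  # the new request is ignored this step
        executed.append(action)
        previous = action
    return executed


never_sticky = apply_sticky_actions([1, 2, 3], repeat_probability=0.0)
always_sticky = apply_sticky_actions([1, 2, 3], repeat_probability=1.0)
```

With a probability of 0.0 every requested action is executed as given; with 1.0 the initial action is repeated for the whole sequence. The default of 0.25 sits between these extremes and makes the environment stochastic.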
- world_models.envs.list_available_atari_envs()[source]#
Get a list of all available Atari environments in the Arcade Learning Environment (ALE).
- Returns:
List of available Atari environment IDs.
- Return type:
list[str]
- world_models.envs.make_atari_vector_env(game, num_envs, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, seed=None, **kwargs)[source]#
Create vectorized Atari environments using ALE’s native AtariVectorEnv.
- Parameters:
game (str) – The name of the Atari game (e.g., “pong”, “breakout”).
num_envs (int) – Number of parallel environments.
obs_type (str) – The type of observation to return (“rgb” or “ram”).
frameskip (int) – The number of frames to skip between actions.
repeat_action_probability (float) – The probability of repeating the last action.
full_action_space (bool) – Whether to use the full action space.
max_episode_steps (Optional[int]) – Maximum number of steps per episode.
seed (Optional[int]) – Random seed for reproducibility.
**kwargs – Additional keyword arguments for environment configuration.
- Returns:
The vectorized Atari environment.
- Return type:
AtariVectorEnv
- class world_models.envs.MuJoCoImageEnv(xml_path=None, *, xml_string=None, binary_path=None, assets=None, seed=0, size=(64, 64), camera=None, reward_fn=None, terminal_fn=None, frame_skip=1, reset_noise_scale=0.0, default_control_range=(-1.0, 1.0))[source]#
Bases: object
Native MuJoCo environment adapter for pixel-based world-model training.
The adapter uses the low-level mujoco Python package directly: models are compiled from MJCF XML strings/files or MJB binaries via mujoco.MjModel; simulation state lives in mujoco.MjData; actions are written to data.ctrl; and images are produced with mujoco.Renderer. Observations follow TorchWM’s Dreamer-style contract: {"image": uint8[C, H, W]}.
Native MuJoCo models do not define task rewards or episode termination by themselves, so callers can supply reward_fn and terminal_fn callbacks. By default, rewards are 0.0 and episodes terminate only through external wrappers such as TimeLimit.
- Parameters:
xml_path (str | Path | None)
xml_string (str | None)
binary_path (str | Path | None)
assets (dict[str, bytes] | None)
seed (int)
size (tuple[int, int])
camera (str | int | None)
reward_fn (RewardFn | None)
terminal_fn (TerminalFn | None)
frame_skip (int)
reset_noise_scale (float)
default_control_range (tuple[float, float])
- property observation_space: Dict#
- property action_space: Box#
- world_models.envs.make_mujoco_env(model=None, *, backend='auto', seed=0, size=(64, 64), render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#
Single factory that creates a MuJoCo image environment from either a Gymnasium task id or an MJCF/MJB model.
- Parameters:
model (str | Path | None) – Either a Gymnasium MuJoCo task id such as "Humanoid-v4", an MJCF XML path/string, or an MJB binary path.
backend (str) – "auto" infers native vs Gymnasium task mode. Use "native" for MJCF/MJB, "gymnasium" for task ids, or "robotics" for Gymnasium Robotics registrations.
seed (int) – Seed forwarded to the image wrapper.
size (tuple[int, int]) – Target (height, width) image size.
render_mode (str) – Render mode used for Gymnasium MuJoCo task ids.
gym_kwargs (dict[str, Any] | None) – Optional keyword arguments forwarded to gymnasium.make in task-id mode. Extra **kwargs are also forwarded there.
**kwargs (Any) – Native MuJoCoImageEnv options for MJCF/MJB mode, or environment-constructor options for Gymnasium task-id mode.
- Returns:
A TorchWM image environment returning {"image": uint8[C, H, W]}.
- Return type:
- world_models.envs.make_mujoco_env_from_config(args, size)[source]#
Build a MuJoCo image environment from a DreamerConfig-like object.
- Parameters:
args (Any)
size (tuple[int, int])
- Return type:
Any
- world_models.envs.list_gymnasium_robotics_envs()[source]#
List all Gymnasium Robotics ids registered by the installed package.
Returns an empty list when the optional dependency is not installed. When it is installed, the list is derived from Gymnasium’s registry rather than a hand-maintained subset, so newly added Robotics environments are exposed automatically.
- Return type:
list[str]
- world_models.envs.make_robotics_env(env, *, seed=0, size=(64, 64), render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#
Create a TorchWM image wrapper for a Gymnasium Robotics environment.
- Parameters:
env (str) – Any environment id registered by gymnasium-robotics.
seed (int) – Seed forwarded to GymImageEnv.
size (tuple[int, int]) – Target (height, width) image size.
render_mode (str) – Render mode forwarded to gymnasium.make.
gym_kwargs (dict[str, Any] | None) – Optional keyword arguments forwarded to gymnasium.make.
**kwargs (Any) – Additional keyword arguments forwarded to gymnasium.make.
- Returns:
A GymImageEnv that emits {"image": uint8[C, H, W]} observations.
- Return type:
- world_models.envs.register_gymnasium_robotics_envs()[source]#
Import Gymnasium Robotics so its environments are registered with Gymnasium.
Gymnasium moved legacy MuJoCo v2/v3 task registrations into the external gymnasium-robotics package. Current Gymnasium Robotics versions register environments during import, while older plugin-style installations may rely on gymnasium.register_envs; this helper supports both paths.
- Return type:
Any
- class world_models.envs.GymImageEnv(env, seed=0, size=(64, 64), render_mode='rgb_array')[source]#
Bases: object
Gym-like environment wrapper that always returns image observations.
This wrapper normalizes diverse environment interfaces to return consistent image-based observations suitable for pixel-based world models like Dreamer.
- Features:
Supports environment IDs (string) and pre-built environment objects.
Synthesizes RGB images from vector observations for pixel-based training.
Exposes continuous action spaces mapped to [-1, 1] range.
Converts discrete actions to one-hot vectors.
Returns observations as dict {“image”: (C, H, W)} with uint8 values.
- Parameters:
env (Any) – Either a string environment ID (e.g., “Pendulum-v1”) or a pre-built gym environment instance.
seed (int) – Random seed for environment reset (default: 0).
size (tuple) – Target image size as (height, width) (default: (64, 64)).
render_mode (str) – Render mode for environment (default: “rgb_array”).
- observation_space#
Dict space with “image” key containing (C, H, W) Box.
- action_space#
Box space with actions in [-1, 1] range.
- max_episode_steps#
Maximum steps per episode (default: 1000).
- property observation_space: Dict#
- property action_space: Box#
- property max_episode_steps: int#
- world_models.envs.make_gym_env(env, **kwargs)[source]#
Create a GymImageEnv wrapper for generic Gym/Gymnasium environments.
- Parameters:
env (Any) – Either a string environment ID (e.g., “Pendulum-v1”) or a pre-built gym environment instance.
**kwargs (Any) – Additional keyword arguments passed to GymImageEnv, including:
seed (int): Random seed for environment (default: 0)
size (tuple): Target image size as (height, width) (default: (64, 64))
render_mode (str): Render mode for environment (default: “rgb_array”)
- Returns:
A wrapper that always returns image observations in the format {“image”: (C, H, W)} suitable for pixel-based world models.
- Return type:
- class world_models.envs.WorldModelEnv(world_model, observation_space, action_space, *, initial_observation=None, initial_state=None, reset_fn=None, transition_fn=None, reward_fn=None, terminal_fn=None, render_fn=None, action_transform_fn=None, max_episode_steps=None, render_mode=None, device=None, torch_actions=True, seed=None)[source]#
Bases: Env
Expose a trained world model through the Gymnasium Env API.
WorldModelEnv keeps the current latent/model state and advances it with a transition callable or with a compatible method on world_model. The wrapper returns Gymnasium-style (obs, info) from reset and (obs, reward, terminated, truncated, info) from step, making learned model rollouts pluggable into RL libraries such as Stable-Baselines3, TorchRL, and CleanRL.
- Parameters:
world_model (Any) – Trained model or lightweight adapter object used for simulated dynamics.
observation_space (gym.Space) – Gymnasium observation space emitted by the wrapper.
action_space (gym.Space) – Gymnasium action space accepted by the wrapper.
initial_observation (Any | None) – Optional observation returned when no reset callable provides one. Defaults to observation_space.sample().
initial_state (Any | None) – Optional latent/model state used at reset.
reset_fn (ResetFn | None) – Optional callable for resetting model state. Accepted return forms are obs, (obs, info), (state, obs), (state, obs, info), or a mapping with state/observation.
transition_fn (TransitionFn | None) – Optional callable for one model step. If omitted, the wrapper tries common methods on world_model: env_step, step, predict_step, predict, imagine_step, transition, then __call__.
reward_fn (RewardFn | None) – Optional callable used when the transition output omits a reward.
terminal_fn (TerminalFn | None) – Optional callable used when the transition output omits a termination flag.
render_fn (RenderFn | None) – Optional callable used by render.
action_transform_fn (ActionTransformFn | None) – Optional callable that converts library actions into the format expected by the world model.
max_episode_steps (int | None) – Optional time limit. Reaching it sets truncated.
render_mode (str | None) – Optional Gymnasium render mode. rgb_array is supported by default when observations contain image-like data.
device (Any | None) – Device used for tensor actions when torch_actions=True.
torch_actions (bool) – Convert actions to torch.Tensor before model calls.
seed (int | None) – Optional RNG seed for observation/action spaces and NumPy.
- metadata = {'render_fps': 30, 'render_modes': ['rgb_array']}#
- property state: Any#
Current latent/model state tracked by the wrapper.
- reset(*, seed=None, options=None)[source]#
Reset the simulated rollout and return (observation, info).
- Parameters:
seed (int | None)
options (dict[str, Any] | None)
- Return type:
tuple[Any, dict[str, Any]]
- step(action)[source]#
Roll the learned model forward for one simulated environment step.
- Parameters:
action (Any)
- Return type:
tuple[Any, float, bool, bool, dict[str, Any]]
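The fallback order documented for transition_fn can be sketched as a simple attribute lookup. The helper and toy model below are hypothetical: the (state, action) call signature and the TypeError on failure are assumptions made for illustration, and the real wrapper may resolve or invoke the method differently.

```python
from typing import Any, Callable

# Lookup order as listed in the transition_fn parameter description.
CANDIDATE_METHODS = (
    "env_step", "step", "predict_step", "predict", "imagine_step", "transition",
)


def resolve_transition(world_model: Any) -> Callable[..., Any]:
    """Pick the first matching method on `world_model`, else fall back to calling it directly."""
    for name in CANDIDATE_METHODS:
        method = getattr(world_model, name, None)
        if callable(method):
            return method
    if callable(world_model):  # final fallback: __call__
        return world_model
    raise TypeError("world_model exposes no usable transition method")


class ToyModel:
    """Hypothetical model with two candidates; `step` precedes `predict` in the lookup order."""

    def predict(self, state: int, action: int) -> int:
        return state - action

    def step(self, state: int, action: int) -> int:
        return state + action


next_state = resolve_transition(ToyModel())(10, 2)
```

Since step appears before predict in the documented order, the toy model advances with step. Passing an explicit transition_fn avoids relying on this lookup at all.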
- world_models.envs.make_world_model_env(world_model, **kwargs)[source]#
Create a WorldModelEnv from a trained model and spaces.
- Parameters:
world_model (Any)
kwargs (Any)
- Return type:
- class world_models.envs.ProcgenImageEnv(env, seed=0, size=(64, 64), distribution_mode='easy', num_levels=0, start_level=None, action_n=15, **procgen_kwargs)[source]#
Bases: object
Adapt Procgen’s vector API to TorchWM’s single-env image interface.
The upstream procgen.ProcgenEnv API is vectorized, so this wrapper builds a one-environment vector and unwraps the leading batch dimension. Actions are exposed as a continuous one-hot-like Box[-1, 1] with one element per discrete Procgen action, matching TorchWM’s other discrete image adapters.
- Parameters:
env (str)
seed (int)
size (tuple[int, int])
distribution_mode (str)
num_levels (int)
start_level (int | None)
action_n (int)
procgen_kwargs (Any)
- property observation_space: Dict#
- property action_space: _ProcgenActionSpace#
- property max_episode_steps: int#
- step(action)[source]#
- Parameters:
action (Any)
- Return type:
tuple[dict[str, ndarray[tuple[Any, …], dtype[uint8]]], float, bool, dict[str, Any]]
- world_models.envs.list_procgen_envs()[source]#
Return the Procgen game names understood by ProcgenImageEnv.
- Return type:
list[str]
- world_models.envs.make_procgen_env(env, **kwargs)[source]#
Create a single-environment Procgen adapter.
- Parameters:
env (str) – Procgen game name or Gym-style id.
**kwargs (Any) – Options forwarded to ProcgenImageEnv.
- Returns:
TorchWM-compatible image wrapper exposing {"image": (3, H, W) uint8} observations and one-hot-like actions.
- Return type:
- world_models.envs.normalize_procgen_env_name(env)[source]#
Normalize Procgen Gym ids and shorthand names to Procgen game names.
Accepted forms include "coinrun", "procgen-coinrun-v0", and "procgen:procgen-coinrun-v0".
- Parameters:
env (str)
- Return type:
str
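The three accepted forms can be reduced to a game name with plain string handling. The function below is a hypothetical re-implementation for illustration only; lowercasing, whitespace stripping, and handling only the -v0 suffix are assumptions, and the library function may accept more variants.

```python
def normalize_name_sketch(env: str) -> str:
    """Reduce "procgen:procgen-<game>-v0" or "procgen-<game>-v0" to "<game>"."""
    name = env.strip().lower()
    if ":" in name:  # drop a "procgen:" namespace prefix
        name = name.split(":", 1)[1]
    if name.startswith("procgen-"):  # drop the Gym-style id prefix
        name = name[len("procgen-"):]
    if name.endswith("-v0"):  # drop the version suffix
        name = name[: -len("-v0")]
    return name


forms = ["coinrun", "procgen-coinrun-v0", "procgen:procgen-coinrun-v0"]
normalized = [normalize_name_sketch(form) for form in forms]
```

All three inputs collapse to the same game name, "coinrun".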
- class world_models.envs.UnityMLAgentsEnv(file_name, behavior_name=None, seed=0, size=(64, 64), worker_id=0, base_port=5005, no_graphics=True, time_scale=20.0, quality_level=1, max_episode_steps=1000)[source]#
Bases: object
Gym-like wrapper for Unity ML-Agents environments.
Provides a unified interface for Unity-based environments, converting observations to image format compatible with pixel-based world models.
- Features:
Supports single-agent control with continuous action spaces.
Returns observations as {“image”: (C, H, W)} with uint8 values.
Normalizes actions to [-1, 1] range.
Includes rendered frames in observations for visual policies.
- Parameters:
file_name (str) – Path to the Unity environment binary.
behavior_name (str, optional) – Name of the behavior to use. If None, uses the first available behavior.
seed (int) – Random seed for environment (default: 0).
size (tuple) – Target image size as (height, width) (default: (64, 64)).
worker_id (int) – Worker ID for multi-environment setup (default: 0).
base_port (int) – Base port for Unity environment communication (default: 5005).
no_graphics (bool) – Disable graphics rendering for faster simulation (default: True).
time_scale (float) – Simulation time scale multiplier (default: 20.0).
quality_level (int) – Graphics quality level 0-5 (default: 1).
max_episode_steps (int) – Maximum steps per episode (default: 1000).
- observation_space#
Dict space with “image” key containing (3, H, W) Box.
- action_space#
Box space with actions in [-1, 1] range.
- max_episode_steps#
Maximum steps per episode.
- Raises:
ValueError – If no behaviors found or action space is not continuous.
RuntimeError – If no agents available after reset.
- Parameters:
file_name (str)
behavior_name (str | None)
seed (int)
size (tuple[int, int])
worker_id (int)
base_port (int)
no_graphics (bool)
time_scale (float)
quality_level (int)
max_episode_steps (int)
- property observation_space: Dict#
- property action_space: Box#
- property max_episode_steps: int#
- world_models.envs.make_unity_mlagents_env(env_id=None, **kwargs)[source]#
Create a Unity ML-Agents environment wrapper.
Factory function that instantiates a UnityMLAgentsEnv with the provided keyword arguments. Suitable for integrating Unity-based environments with Dreamer-style world model pipelines.
- Parameters:
env_id (str | None)
**kwargs (Any) – Keyword arguments passed to UnityMLAgentsEnv, including:
file_name (str): Path to the Unity environment binary.
behavior_name (str, optional): Name of the behavior to use.
seed (int): Random seed (default: 0).
size (tuple): Image size as (height, width) (default: (64, 64)).
worker_id (int): Worker ID for multi-environment setup (default: 0).
base_port (int): Base port for communication (default: 5005).
no_graphics (bool): Disable graphics rendering (default: True).
time_scale (float): Simulation time scale (default: 20.0).
quality_level (int): Graphics quality level (default: 1).
max_episode_steps (int): Max steps per episode (default: 1000).
- Returns:
A Gym-compatible wrapper for Unity environments.
- Return type:
- class world_models.envs.DeepMindControlEnv(name, seed, size=(64, 64), camera=None)[source]#
Bases: object
Gym-style adapter for DeepMind Control Suite tasks.
The wrapper exposes DMC observations and actions through Gym spaces and adds a rendered RGB image to each observation dict so image-based world model pipelines can train consistently across backends.
- Features:
Parses domain-task names (e.g., “cheetah-run” -> domain=”cheetah”, task=”run”)
Automatically handles special cases like “cup” -> “ball_in_cup”
Renders RGB images at configurable resolution
Returns observations as dict with both state vectors and images
- Parameters:
name (str) – Environment name in format “domain-task” (e.g., “cheetah-run”).
seed (int) – Random seed for environment initialization.
size (tuple) – Target image size as (height, width) (default: (64, 64)).
camera (int, optional) – Camera ID for rendering. Defaults to 0 for most domains, 2 for quadruped.
- observation_space#
Dict space with state keys and “image”.
- Type:
gym.spaces.Dict
- action_space#
Continuous action space from DMC spec.
- Type:
gym.spaces.Box
Example
>>> env = DeepMindControlEnv("cheetah-run", seed=0, size=(64, 64))
>>> obs = env.reset()
>>> print(obs.keys())  # dict_keys(['position', 'velocity', 'image'])
- property observation_space: Dict#
- property action_space: Box#
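The name handling listed under Features can be sketched as a split plus an alias lookup. This is an illustrative helper, not the adapter's code: parse_dmc_name and DOMAIN_ALIASES are hypothetical names, and only the "cup" alias mentioned above is modelled.

```python
from typing import Tuple

# Assumption: only the alias named in the Features list is modelled here.
DOMAIN_ALIASES = {"cup": "ball_in_cup"}


def parse_dmc_name(name: str) -> Tuple[str, str]:
    """Split "domain-task" at the first hyphen and expand known domain aliases."""
    domain, task = name.split("-", 1)
    return DOMAIN_ALIASES.get(domain, domain), task


cheetah = parse_dmc_name("cheetah-run")
cup = parse_dmc_name("cup-catch")
```

Splitting at the first hyphen only keeps the domain unambiguous, since DMC domain names use underscores rather than hyphens.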
- class world_models.envs.BraxImageEnv(env, seed=0, size=(64, 64), backend=None, episode_length=None, auto_reset=False, jit=True, suppress_warp_warnings=True, **env_kwargs)[source]#
Bases: object
Gym-like adapter for training TorchWM world models on Brax tasks.
Brax environments are functional JAX environments: reset consumes a PRNG key and returns a state, while step consumes the previous state plus an action and returns the next state. This adapter stores the Brax state between calls and converts state observations into image observations compatible with pixel-based TorchWM agents such as Dreamer.
If a Brax renderer is not available, vector observations are rendered as deterministic feature-band images so training code can still consume a pixel stream. The original vector observation is also exposed through info["vector_observation"] after step for diagnostics.
- Parameters:
env (str | Any)
seed (int)
size (tuple[int, int])
backend (str | None)
episode_length (int | None)
auto_reset (bool)
jit (bool)
suppress_warp_warnings (bool)
env_kwargs (Any)
- property observation_space: Space#
- property action_space: Space#
- property max_episode_steps: int#
- world_models.envs.make_brax_env(env, **kwargs)[source]#
Create a TorchWM image wrapper for Brax environments.
- Parameters:
env (str | Any) – Brax environment name (for example, "ant") or a pre-built Brax environment object exposing reset(rng) and step(state, action).
**kwargs (Any) – Additional keyword arguments passed to BraxImageEnv.
- Returns:
A Gym-like wrapper that returns {"image": (C, H, W)} observations and exposes continuous actions in the Brax [-1, 1] range.
- Return type:
- class world_models.envs.DMLabEnv(level, seed=0, size=(64, 64), action_repeat=4, action_set=None, observations=None, config=None, renderer='hardware', **lab_kwargs)[source]#
Bases: object
Gym-style adapter for DeepMind Lab 3D environments.
The native deepmind_lab API exposes RGB observations as HWC arrays and expects a seven-element integer action vector. This adapter presents a TorchWM-friendly image observation dict and a Box action space containing a one-hot vector in [-1, 1] so it composes with Dreamer’s normalization wrappers.
- Parameters:
level (str)
seed (int)
size (tuple[int, int])
action_repeat (int)
action_set (Sequence[Sequence[int]] | np.ndarray | None)
observations (Sequence[str] | None)
config (dict[str, Any] | None)
renderer (str)
lab_kwargs (Any)
- property observation_space: Dict#
- property action_space: _OneHotActionSpace#
- property max_episode_steps: int#
- world_models.envs.make_dmlab_env(level, **kwargs)[source]#
Create a DeepMind Lab environment adapter for TorchWM.
- Parameters:
level (str) – DeepMind Lab level name, for example "rooms_collect_good_objects_train".
**kwargs (Any) – Additional keyword arguments passed to DMLabEnv.
- Returns:
A Gym-like wrapper returning {"image": (C, H, W)} uint8 observations and normalized one-hot discrete actions.
- Return type:
- class world_models.envs.BSuiteImageEnv(bsuite_id, seed=0, size=(64, 64), env=None)[source]#
Bases: object
Gym-like wrapper for DeepMind BSuite dm_env environments.
BSuite tasks expose compact dm_env observations and mostly discrete actions. This adapter presents a Gym/Gymnasium-style API with image observations under {"image": (C, H, W)} so TorchWM’s pixel-based world models can train and evaluate on BSuite diagnostic tasks without requiring the base environment to implement rendering.
- Parameters:
bsuite_id (str)
seed (int)
size (tuple[int, int])
env (Any | None)
- property observation_space: Dict#
- property action_space: Space#
- property max_episode_steps: int#
- world_models.envs.make_bsuite_env(bsuite_id, **kwargs)[source]#
Create a Dreamer-compatible image wrapper around a BSuite task.
- Parameters:
bsuite_id (str)
kwargs (Any)
- Return type:
- world_models.envs.list_available_bsuite_ids()[source]#
Return the installed BSuite sweep ids, or examples if BSuite is absent.
- Return type:
list[str]
- class world_models.envs.TimeLimit(env, duration)[source]#
Bases: object
Terminate episodes after a fixed number of wrapper steps.
If the wrapped environment does not provide a discount flag at timeout, the wrapper injects a default discount of 1.0 for downstream learners.
- Parameters:
env (Any)
duration (int)
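The timeout behavior can be sketched with a toy environment. Everything below is hypothetical: the classes are illustrative stand-ins, and the four-tuple step return and the placement of the discount in info are assumptions rather than the wrapper's confirmed interface.

```python
from typing import Any, Dict, Tuple


class CountingEnv:
    """Hypothetical environment that never terminates on its own."""

    def reset(self) -> Dict[str, int]:
        return {"obs": 0}

    def step(self, action: Any) -> Tuple[Dict[str, int], float, bool, Dict[str, Any]]:
        return {"obs": 0}, 1.0, False, {}


class TimeLimitSketch:
    """End the episode once `duration` steps have been taken since reset."""

    def __init__(self, env: Any, duration: int) -> None:
        self._env = env
        self._duration = duration
        self._step = 0

    def reset(self) -> Any:
        self._step = 0
        return self._env.reset()

    def step(self, action: Any) -> Tuple[Any, float, bool, Dict[str, Any]]:
        obs, reward, done, info = self._env.step(action)
        self._step += 1
        if self._step >= self._duration:
            done = True
            # Assumption: a timeout is not a true terminal state, so the
            # default discount stays at 1.0 unless the env already set one.
            info.setdefault("discount", 1.0)
        return obs, reward, done, info


limited = TimeLimitSketch(CountingEnv(), duration=3)
limited.reset()
results = [limited.step(None) for _ in range(3)]
dones = [result[2] for result in results]
final_info = results[-1][3]
```

The third step reports done and carries a discount of 1.0, which lets a downstream learner tell a timeout apart from a failure state.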
- class world_models.envs.ActionRepeat(env, amount)[source]#
Bases: object
Repeat each action for a fixed number of environment steps.
Rewards are accumulated and the loop stops early if the environment terminates, mirroring common action-repeat behavior in world model papers.
- Parameters:
env (Any)
amount (int)
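Reward accumulation with early stopping can be sketched as follows. The classes are hypothetical stand-ins, and the four-tuple step return is an assumption about the wrapped interface.

```python
from typing import Any, Dict, Tuple


class ShortEpisodeEnv:
    """Hypothetical environment that pays 1.0 per step and terminates on the third step."""

    def __init__(self) -> None:
        self.steps = 0

    def step(self, action: Any) -> Tuple[int, float, bool, Dict[str, Any]]:
        self.steps += 1
        return self.steps, 1.0, self.steps >= 3, {}


class ActionRepeatSketch:
    """Apply one action up to `amount` times and sum the rewards."""

    def __init__(self, env: Any, amount: int) -> None:
        self._env = env
        self._amount = amount

    def step(self, action: Any) -> Tuple[Any, float, bool, Dict[str, Any]]:
        total_reward = 0.0
        done = False
        obs, info = None, {}
        for _ in range(self._amount):
            obs, reward, done, info = self._env.step(action)
            total_reward += reward
            if done:  # stop early so no step is taken after termination
                break
        return obs, total_reward, done, info


base_env = ShortEpisodeEnv()
repeat_obs, repeat_reward, repeat_done, _ = ActionRepeatSketch(base_env, amount=4).step(0)
```

Although four repeats were requested, the toy environment terminates on its third step, so only three rewards are summed.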
- class world_models.envs.NormalizeActions(env)[source]#
Bases: object
Expose a normalized [-1, 1] action space for bounded continuous controls.
Incoming normalized actions are mapped back to the wrapped environment action bounds before stepping the environment.
- Parameters:
env (Any)
- property action_space: Box#
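Mapping a normalized action back to the original bounds is typically a linear rescale, low + (a + 1) / 2 * (high - low). The helper below is a hypothetical sketch that assumes this linear form and finite bounds; it is not taken from the wrapper's source.

```python
from typing import List, Sequence


def denormalize_action(
    normalized: Sequence[float], low: Sequence[float], high: Sequence[float]
) -> List[float]:
    """Map each component from [-1, 1] to [low, high] with a linear rescale."""
    return [
        lo + (value + 1.0) / 2.0 * (hi - lo)
        for value, lo, hi in zip(normalized, low, high)
    ]


# A two-dimensional control bounded by [0, 10] and [-2, 2].
restored = denormalize_action([-1.0, 0.5], low=[0.0, -2.0], high=[10.0, 2.0])
```

The first component sits at the lower bound (0.0), and the second lands three quarters of the way through its range (1.0).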
- class world_models.envs.ObsDict(env, key='obs')[source]#
Bases: object
Convert scalar/array observations into a dictionary observation format.
This harmonizes outputs for code paths that expect keyed observations (for example {“image”: …} style world model inputs).
- Parameters:
env (Any)
key (str)
- property observation_space: Dict#
- property action_space: Any#
- class world_models.envs.OneHotAction(env)[source]#
Bases: object
Wrap discrete-action environments to accept one-hot action vectors.
The wrapper validates one-hot inputs and converts them to integer action indices before forwarding to the underlying environment.
- Parameters:
env (Any)
- property action_space: Box#
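The validate-then-convert step can be sketched in a few lines. one_hot_to_index is a hypothetical helper; the exact validation rule and the ValueError are assumptions about how invalid vectors are rejected.

```python
from typing import Sequence


def one_hot_to_index(action: Sequence[float]) -> int:
    """Validate that `action` is one-hot and return the index of its single 1."""
    values = list(action)
    ones = [i for i, value in enumerate(values) if value == 1]
    zeros = sum(1 for value in values if value == 0)
    # Exactly one entry must be 1 and every other entry must be 0.
    if len(ones) != 1 or zeros != len(values) - 1:
        raise ValueError(f"Invalid one-hot action: {values}")
    return ones[0]


index = one_hot_to_index([0.0, 0.0, 1.0, 0.0])
```

Vectors such as [0.5, 0.5] or [1, 1, 0] are rejected by this check instead of being silently resolved with an argmax.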
- class world_models.envs.RewardObs(env)[source]#
Bases: object
Augment observations with the latest scalar reward under obs[“reward”].
Useful for agents that consume reward as part of the observation stream during model learning or recurrent policy inference.
- Parameters:
env (Any)
- property observation_space: Dict#
- class world_models.envs.ResizeImage(env, size=(64, 64))[source]#
Bases: object
Resize image-like observation entries to a target spatial size.
The wrapper discovers image keys from env.obs_space, applies nearest neighbor resizing, and updates the advertised observation space shapes.
- Parameters:
env (Any)
size (tuple[int, int])
- property obs_space: dict[str, Any]#
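Nearest-neighbor resizing copies, for each target pixel, the closest source pixel. The sketch below works on a single-channel image stored as nested lists; resize_nearest is a hypothetical helper, and its floor-based index mapping is an assumption, so a library-backed resize may round differently.

```python
from typing import List

Image = List[List[int]]  # a single-channel image as rows of pixel values


def resize_nearest(image: Image, height: int, width: int) -> Image:
    """Resize by copying, for each target pixel, the corresponding source pixel."""
    src_h, src_w = len(image), len(image[0])
    return [
        [image[y * src_h // height][x * src_w // width] for x in range(width)]
        for y in range(height)
    ]


small = [[1, 2],
         [3, 4]]
upscaled = resize_nearest(small, 4, 4)
```

Each source pixel becomes a 2x2 block in upscaled. No new pixel values are introduced, which is why this method suits uint8 observations that should not be blurred.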
- class world_models.envs.RenderImage(env, key='image')[source]#
Bases: object
Inject RGB renders from env.render(“rgb_array”) into observations.
This is useful when the base environment returns non-image observations but a rendered camera view is needed for world-model training.
- Parameters:
env (Any)
key (str)
- property obs_space: dict[str, Any]#
- class world_models.envs.SelectAction(env, key)[source]#
Bases: Wrapper
Gym wrapper for dictionary actions that forwards a selected key only.
This enables integration with policies that emit action dicts while the environment expects a single tensor/array action payload.
- Parameters:
env (Any)
key (str)
- world_models.envs.make_env(env_id, **kwargs)[source]#
Compatibility helper: create an environment by delegating to package-specific factories when available, falling back to gym.make.
This preserves older callers that expect make_env to exist.
- Parameters:
env_id (str)
kwargs (Any)
- Return type:
Any
- class world_models.envs.dmc.DeepMindControlEnv(name, seed, size=(64, 64), camera=None)[source]#
Bases: object
Gym-style adapter for DeepMind Control Suite tasks.
The wrapper exposes DMC observations and actions through Gym spaces and adds a rendered RGB image to each observation dict so image-based world model pipelines can train consistently across backends.
- Features:
Parses domain-task names (e.g., “cheetah-run” -> domain=”cheetah”, task=”run”)
Automatically handles special cases like “cup” -> “ball_in_cup”
Renders RGB images at configurable resolution
Returns observations as dict with both state vectors and images
- Parameters:
name (str) – Environment name in format “domain-task” (e.g., “cheetah-run”).
seed (int) – Random seed for environment initialization.
size (tuple) – Target image size as (height, width) (default: (64, 64)).
camera (int, optional) – Camera ID for rendering. Defaults to 0 for most domains, 2 for quadruped.
- observation_space#
Dict space with state keys and “image”.
- Type:
gym.spaces.Dict
- action_space#
Continuous action space from DMC spec.
- Type:
gym.spaces.Box
Example
>>> env = DeepMindControlEnv("cheetah-run", seed=0, size=(64, 64))
>>> obs = env.reset()
>>> print(obs.keys())  # dict_keys(['position', 'velocity', 'image'])
- property observation_space: Dict#
- property action_space: Box#
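The name handling listed under Features for DeepMindControlEnv can be sketched as follows. This is a minimal illustration, not the adapter's actual code: `parse_dmc_name` and `default_camera` are hypothetical helper names, and only the rules stated above ("domain-task" splitting, the "cup" alias, and the quadruped camera default) are assumed.

```python
from __future__ import annotations


def parse_dmc_name(name: str) -> tuple[str, str]:
    """Split a "domain-task" string into (domain, task).

    Hypothetical helper; the real adapter may parse names differently.
    """
    domain, task = name.split("-", 1)
    # Special case named in the docs: "cup" is shorthand for "ball_in_cup".
    if domain == "cup":
        domain = "ball_in_cup"
    return domain, task


def default_camera(domain: str) -> int:
    """Default camera ID per the docs: 2 for quadruped, 0 otherwise."""
    return 2 if domain == "quadruped" else 0


domain, task = parse_dmc_name("cheetah-run")
print(domain, task, default_camera(domain))
```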
- world_models.envs.dmlab.make_dmlab_env(level, **kwargs)[source]#
Create a DeepMind Lab environment adapter for TorchWM.
- Parameters:
level (str) – DeepMind Lab level name, for example "rooms_collect_good_objects_train".
**kwargs (Any) – Additional keyword arguments passed to DMLabEnv.
- Returns:
A Gym-like wrapper returning {"image": (C, H, W)} uint8 observations and normalized one-hot discrete actions.
- Return type:
- class world_models.envs.dmlab.DMLabEnv(level, seed=0, size=(64, 64), action_repeat=4, action_set=None, observations=None, config=None, renderer='hardware', **lab_kwargs)[source]#
Bases: object
Gym-style adapter for DeepMind Lab 3D environments.
The native deepmind_lab API exposes RGB observations as HWC arrays and expects a seven-element integer action vector. This adapter presents a TorchWM-friendly image observation dict and a Box action space containing a one-hot vector in [-1, 1] so it composes with Dreamer’s normalization wrappers.
- Parameters:
level (str)
seed (int)
size (tuple[int, int])
action_repeat (int)
action_set (Sequence[Sequence[int]] | np.ndarray | None)
observations (Sequence[str] | None)
config (dict[str, Any] | None)
renderer (str)
lab_kwargs (Any)
- property observation_space: Dict#
- property action_space: _OneHotActionSpace#
- property max_episode_steps: int#
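The one-hot to native action translation described for DMLabEnv might look roughly like the sketch below. Everything here is an assumption for illustration: `ACTION_SET` is a made-up four-action table, `to_native_action` is a hypothetical helper, and the seven slots are assumed to follow DeepMind Lab's usual order (look left/right, look down/up, strafe, move back/forward, fire, jump, crouch).

```python
# Hypothetical discrete action table: each row is one seven-element
# integer vector in the assumed DeepMind Lab slot order.
ACTION_SET = [
    (0, 0, 0, 1, 0, 0, 0),    # move forward
    (0, 0, 0, -1, 0, 0, 0),   # move backward
    (-20, 0, 0, 0, 0, 0, 0),  # look left
    (20, 0, 0, 0, 0, 0, 0),   # look right
]


def to_native_action(one_hot, action_set=ACTION_SET):
    """Pick the row whose index holds the largest entry of `one_hot`.

    Taking the argmax works for both 0/1 and -1/+1 encodings.
    """
    index = max(range(len(one_hot)), key=lambda i: one_hot[i])
    return list(action_set[index])


print(to_native_action([0, 0, 1, 0]))
```

Using an argmax rather than an exact equality check keeps the sketch tolerant of actions that were rescaled by a normalization wrapper.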
- world_models.envs.gym_env.make_gym_env(env, **kwargs)[source]#
Create a GymImageEnv wrapper for generic Gym/Gymnasium environments.
- Parameters:
env (Any) – Either a string environment ID (e.g., “Pendulum-v1”) or a pre-built gym environment instance.
**kwargs (Any) – Additional keyword arguments passed to GymImageEnv, including:
- seed (int): Random seed for environment (default: 0)
- size (tuple): Target image size as (height, width) (default: (64, 64))
- render_mode (str): Render mode for environment (default: "rgb_array")
- Returns:
A wrapper that always returns image observations in the format {"image": (C, H, W)}, suitable for pixel-based world models.
- Return type:
- class world_models.envs.gym_env.GymImageEnv(env, seed=0, size=(64, 64), render_mode='rgb_array')[source]#
Bases: object
Gym-like environment wrapper that always returns image observations.
This wrapper normalizes diverse environment interfaces to return consistent image-based observations suitable for pixel-based world models like Dreamer.
- Features:
Supports environment IDs (string) and pre-built environment objects.
Synthesizes RGB images from vector observations for pixel-based training.
Exposes continuous action spaces mapped to [-1, 1] range.
Converts discrete actions to one-hot vectors.
Returns observations as dict {“image”: (C, H, W)} with uint8 values.
- Parameters:
env (Any) – Either a string environment ID (e.g., “Pendulum-v1”) or a pre-built gym environment instance.
seed (int) – Random seed for environment reset (default: 0).
size (tuple) – Target image size as (height, width) (default: (64, 64)).
render_mode (str) – Render mode for environment (default: “rgb_array”).
- observation_space#
Dict space with “image” key containing (C, H, W) Box.
- action_space#
Box space with actions in [-1, 1] range.
- max_episode_steps#
Maximum steps per episode (default: 1000).
- property observation_space: Dict#
- property action_space: Box#
- property max_episode_steps: int#
- world_models.envs.ale_atari_env.make_atari_env(env_id, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, **kwargs)[source]#
Create any Atari environment from the Arcade Learning Environment (ALE).
- Parameters:
env_id (str) – The id of the Atari environment to create.
obs_type (str) – The type of observation to return (“rgb” or “ram”).
frameskip (int) – The number of frames to skip between actions.
repeat_action_probability (float) – The probability of repeating the last action.
full_action_space (bool) – Whether to use the full action space.
max_episode_steps (Optional[int]) – Maximum number of steps per episode.
**kwargs – Additional keyword arguments for environment configuration.
- Returns:
The created Atari environment.
- Return type:
gym.Env
- world_models.envs.ale_atari_env.list_available_atari_envs()[source]#
Get a list of all available Atari environments in the Arcade Learning Environment (ALE).
- Returns:
List of available Atari environment IDs.
- Return type:
list[str]
- world_models.envs.ale_atari_vector_env.make_atari_vector_env(game, num_envs, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, seed=None, **kwargs)[source]#
Create vectorized Atari environments using ALE’s native AtariVectorEnv.
- Parameters:
game (str) – The name of the Atari game (e.g., “pong”, “breakout”).
num_envs (int) – Number of parallel environments.
obs_type (str) – The type of observation to return (“rgb” or “ram”).
frameskip (int) – The number of frames to skip between actions.
repeat_action_probability (float) – The probability of repeating the last action.
full_action_space (bool) – Whether to use the full action space.
max_episode_steps (Optional[int]) – Maximum number of steps per episode.
seed (Optional[int]) – Random seed for reproducibility.
**kwargs – Additional keyword arguments for environment configuration.
- Returns:
The vectorized Atari environment.
- Return type:
AtariVectorEnv
Procgen environment adapter for TorchWM image-based agents.
- world_models.envs.procgen_env.list_procgen_envs()[source]#
Return the Procgen game names understood by ProcgenImageEnv.
- Return type:
list[str]
- world_models.envs.procgen_env.normalize_procgen_env_name(env)[source]#
Normalize Procgen Gym ids and shorthand names to Procgen game names.
Accepted forms include "coinrun", "procgen-coinrun-v0", and "procgen:procgen-coinrun-v0".
- Parameters:
env (str)
- Return type:
str
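A name normalizer covering the three accepted forms above could be sketched as follows. This is an illustration under stated assumptions, not the library's implementation: `normalize_procgen_name` is a hypothetical stand-in, and it only strips an optional "procgen:" namespace, an optional "procgen-" prefix, and an optional "-v<digits>" suffix.

```python
def normalize_procgen_name(env: str) -> str:
    """Reduce a Procgen id such as "procgen:procgen-coinrun-v0" to "coinrun"."""
    # Drop an optional "procgen:" namespace prefix.
    name = env.split(":", 1)[-1]
    # Drop an optional "procgen-" id prefix.
    if name.startswith("procgen-"):
        name = name[len("procgen-"):]
    # Drop an optional "-v<digits>" version suffix.
    base, sep, version = name.rpartition("-v")
    if sep and version.isdigit():
        name = base
    return name


for raw in ("coinrun", "procgen-coinrun-v0", "procgen:procgen-coinrun-v0"):
    print(raw, "->", normalize_procgen_name(raw))
```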
- world_models.envs.procgen_env.make_procgen_env(env, **kwargs)[source]#
Create a single-environment Procgen adapter.
- Parameters:
env (str) – Procgen game name or Gym-style id.
**kwargs (Any) – Options forwarded to ProcgenImageEnv.
- Returns:
TorchWM-compatible image wrapper exposing {"image": (3, H, W) uint8} observations and one-hot-like actions.
- Return type:
- class world_models.envs.procgen_env.ProcgenImageEnv(env, seed=0, size=(64, 64), distribution_mode='easy', num_levels=0, start_level=None, action_n=15, **procgen_kwargs)[source]#
Bases: object
Adapt Procgen’s vector API to TorchWM’s single-env image interface.
The upstream procgen.ProcgenEnv API is vectorized, so this wrapper builds a one-environment vector and unwraps the leading batch dimension. Actions are exposed as a continuous one-hot-like Box[-1, 1] with one element per discrete Procgen action, matching TorchWM’s other discrete image adapters.
- Parameters:
env (str)
seed (int)
size (tuple[int, int])
distribution_mode (str)
num_levels (int)
start_level (int | None)
action_n (int)
procgen_kwargs (Any)
- property observation_space: Dict#
- property action_space: _ProcgenActionSpace#
- property max_episode_steps: int#
- step(action)[source]#
- Parameters:
action (Any)
- Return type:
tuple[dict[str, ndarray[tuple[Any, …], dtype[uint8]]], float, bool, dict[str, Any]]
- world_models.envs.mujoco_env.make_mujoco_env_from_config(args, size)[source]#
Build a MuJoCo image environment from a DreamerConfig-like object.
- Parameters:
args (Any)
size (tuple[int, int])
- Return type:
Any
- class world_models.envs.mujoco_env.MuJoCoImageEnv(xml_path=None, *, xml_string=None, binary_path=None, assets=None, seed=0, size=(64, 64), camera=None, reward_fn=None, terminal_fn=None, frame_skip=1, reset_noise_scale=0.0, default_control_range=(-1.0, 1.0))[source]#
Bases: object
Native MuJoCo environment adapter for pixel-based world-model training.
The adapter uses the low-level mujoco Python package directly: models are compiled from MJCF XML strings/files or MJB binaries via mujoco.MjModel; simulation state lives in mujoco.MjData; actions are written to data.ctrl; and images are produced with mujoco.Renderer. Observations follow TorchWM’s Dreamer-style contract: {"image": uint8[C, H, W]}.
Native MuJoCo models do not define task rewards or episode termination by themselves, so callers can supply reward_fn and terminal_fn callbacks. By default, rewards are 0.0 and episodes terminate only through external wrappers such as TimeLimit.
- Parameters:
xml_path (str | Path | None)
xml_string (str | None)
binary_path (str | Path | None)
assets (dict[str, bytes] | None)
seed (int)
size (tuple[int, int])
camera (str | int | None)
reward_fn (RewardFn | None)
terminal_fn (TerminalFn | None)
frame_skip (int)
reset_noise_scale (float)
default_control_range (tuple[float, float])
- property observation_space: Dict#
- property action_space: Box#
- world_models.envs.mujoco_env.make_mujoco_env(model=None, *, backend='auto', seed=0, size=(64, 64), render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#
Single factory for MuJoCo image environments, covering Gymnasium task ids and native MJCF/MJB models.
- Parameters:
model (str | Path | None) – Either a Gymnasium MuJoCo task id such as "Humanoid-v4", an MJCF XML path/string, or an MJB binary path.
backend (str) – "auto" infers native vs Gymnasium task mode. Use "native" for MJCF/MJB, "gymnasium" for task ids, or "robotics" for Gymnasium Robotics registrations.
seed (int) – Seed forwarded to the image wrapper.
size (tuple[int, int]) – Target (height, width) image size.
render_mode (str) – Render mode used for Gymnasium MuJoCo task ids.
gym_kwargs (dict[str, Any] | None) – Optional keyword arguments forwarded to gymnasium.make in task-id mode. Extra **kwargs are also forwarded there.
**kwargs (Any) – Native MuJoCoImageEnv options for MJCF/MJB mode, or environment-constructor options for Gymnasium task-id mode.
- Returns:
A TorchWM image environment returning {"image": uint8[C, H, W]}.
- Return type:
- world_models.envs.robotics_env.is_moved_mujoco_error(exc)[source]#
Return whether Gymnasium reported the v2/v3 MuJoCo move.
- Parameters:
exc (BaseException)
- Return type:
bool
- world_models.envs.robotics_env.register_gymnasium_robotics_envs()[source]#
Import Gymnasium Robotics so its environments are registered with Gymnasium.
Gymnasium moved legacy MuJoCo v2/v3 task registrations into the external gymnasium-robotics package. Current Gymnasium Robotics versions register environments during import, while older plugin-style installations may rely on gymnasium.register_envs; this helper supports both paths.
- Return type:
Any
- world_models.envs.robotics_env.list_gymnasium_robotics_envs()[source]#
List all Gymnasium Robotics ids registered by the installed package.
Returns an empty list when the optional dependency is not installed. When it is installed, the list is derived from Gymnasium’s registry rather than a hand-maintained subset, so newly added Robotics environments are exposed automatically.
- Return type:
list[str]
- world_models.envs.robotics_env.make_gymnasium_env_with_robotics_fallback(env, *, render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#
Create a Gymnasium env and retry after Robotics registration if needed.
- Parameters:
env (str)
render_mode (str)
gym_kwargs (dict[str, Any] | None)
kwargs (Any)
- Return type:
Any
- world_models.envs.robotics_env.make_robotics_env(env, *, seed=0, size=(64, 64), render_mode='rgb_array', gym_kwargs=None, **kwargs)[source]#
Create a TorchWM image wrapper for a Gymnasium Robotics environment.
- Parameters:
env (str) – Any environment id registered by gymnasium-robotics.
seed (int) – Seed forwarded to GymImageEnv.
size (tuple[int, int]) – Target (height, width) image size.
render_mode (str) – Render mode forwarded to gymnasium.make.
gym_kwargs (dict[str, Any] | None) – Optional keyword arguments forwarded to gymnasium.make.
**kwargs (Any) – Additional keyword arguments forwarded to gymnasium.make.
- Returns:
A GymImageEnv that emits {"image": uint8[C, H, W]} observations.
- Return type:
- world_models.envs.unity_env.make_unity_mlagents_env(env_id=None, **kwargs)[source]#
Create a Unity ML-Agents environment wrapper.
Factory function that instantiates a UnityMLAgentsEnv with the provided keyword arguments. Suitable for integrating Unity-based environments with Dreamer-style world model pipelines.
- Parameters:
env_id (str | None)
**kwargs (Any) – Keyword arguments passed to UnityMLAgentsEnv, including:
- file_name (str): Path to the Unity environment binary.
- behavior_name (str, optional): Name of the behavior to use.
- seed (int): Random seed (default: 0).
- size (tuple): Image size as (height, width) (default: (64, 64)).
- worker_id (int): Worker ID for multi-environment setup (default: 0).
- base_port (int): Base port for communication (default: 5005).
- no_graphics (bool): Disable graphics rendering (default: True).
- time_scale (float): Simulation time scale (default: 20.0).
- quality_level (int): Graphics quality level (default: 1).
- max_episode_steps (int): Max steps per episode (default: 1000).
- Returns:
A Gym-compatible wrapper for Unity environments.
- Return type:
- class world_models.envs.unity_env.UnityMLAgentsEnv(file_name, behavior_name=None, seed=0, size=(64, 64), worker_id=0, base_port=5005, no_graphics=True, time_scale=20.0, quality_level=1, max_episode_steps=1000)[source]#
Bases: object
Gym-like wrapper for Unity ML-Agents environments.
Provides a unified interface for Unity-based environments, converting observations to image format compatible with pixel-based world models.
- Features:
Supports single-agent control with continuous action spaces.
Returns observations as {“image”: (C, H, W)} with uint8 values.
Normalizes actions to [-1, 1] range.
Includes rendered frames in observations for visual policies.
- Parameters:
file_name (str) – Path to the Unity environment binary.
behavior_name (str, optional) – Name of the behavior to use. If None, uses the first available behavior.
seed (int) – Random seed for environment (default: 0).
size (tuple) – Target image size as (height, width) (default: (64, 64)).
worker_id (int) – Worker ID for multi-environment setup (default: 0).
base_port (int) – Base port for Unity environment communication (default: 5005).
no_graphics (bool) – Disable graphics rendering for faster simulation (default: True).
time_scale (float) – Simulation time scale multiplier (default: 20.0).
quality_level (int) – Graphics quality level 0-5 (default: 1).
max_episode_steps (int) – Maximum steps per episode (default: 1000).
- observation_space#
Dict space with “image” key containing (3, H, W) Box.
- action_space#
Box space with actions in [-1, 1] range.
- max_episode_steps#
Maximum steps per episode.
- Raises:
ValueError – If no behaviors found or action space is not continuous.
RuntimeError – If no agents available after reset.
- Parameters:
file_name (str)
behavior_name (str | None)
seed (int)
size (tuple[int, int])
worker_id (int)
base_port (int)
no_graphics (bool)
time_scale (float)
quality_level (int)
max_episode_steps (int)
- property observation_space: Dict#
- property action_space: Box#
- property max_episode_steps: int#
- class world_models.envs.vector_env.SimWorker(worker_id, env_factory, num_envs, command_queue, result_queue, seed=None)[source]#
Bases: Process
Worker process that manages a batch of environment instances. Handles batched stepping for parallel rollouts.
- Parameters:
worker_id (int)
env_factory (Callable)
num_envs (int)
command_queue (Queue)
result_queue (Queue)
seed (Optional[int])
- class world_models.envs.vector_env.VectorizedEnv(env_factory, num_workers=2, envs_per_worker=4, seed=None)[source]#
Bases: ABC
Abstract base class for vectorized environments. Manages multiple worker processes for parallel simulation.
- Parameters:
env_factory (Callable)
num_workers (int)
envs_per_worker (int)
seed (Optional[int])
- class world_models.envs.vector_env.TorchVectorizedEnv(*args, **kwargs)[source]#
Bases: VectorizedEnv
TorchWM-compatible vectorized environment. Returns batched tensors suitable for PyTorch training.
- Parameters:
args (Any)
kwargs (Any)
- class world_models.envs.wrappers.TimeLimit(env, duration)[source]#
Bases: object
Terminate episodes after a fixed number of wrapper steps.
If the wrapped environment does not provide a discount flag at timeout, the wrapper injects a default discount of 1.0 for downstream learners.
- Parameters:
env (Any)
duration (int)
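The timeout behavior described for TimeLimit can be illustrated with a small stand-alone sketch. It is not the library's wrapper: `TimeLimitSketch` and `CountingEnv` are hypothetical, the classic four-tuple `step` return is assumed, and the discount is assumed to travel in the `info` dict.

```python
class TimeLimitSketch:
    """End an episode after `duration` wrapper steps (illustrative only)."""

    def __init__(self, env, duration):
        self._env = env
        self._duration = duration
        self._step = 0

    def reset(self):
        self._step = 0
        return self._env.reset()

    def step(self, action):
        obs, reward, done, info = self._env.step(action)
        self._step += 1
        if self._step >= self._duration:
            done = True
            # A timeout is not a true terminal state, so keep the discount
            # at 1.0 unless the environment already provided one.
            info.setdefault("discount", 1.0)
        return obs, reward, done, info


class CountingEnv:
    """Toy environment that never terminates on its own."""

    def reset(self):
        return 0

    def step(self, action):
        return 0, 1.0, False, {}


env = TimeLimitSketch(CountingEnv(), duration=3)
env.reset()
results = [env.step(None) for _ in range(3)]
print([done for _, _, done, _ in results])
```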
- class world_models.envs.wrappers.ActionRepeat(env, amount)[source]#
Bases: object
Repeat each action for a fixed number of environment steps.
Rewards are accumulated and the loop stops early if the environment terminates, mirroring common action-repeat behavior in world model papers.
- Parameters:
env (Any)
amount (int)
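A minimal sketch of the accumulate-and-stop-early loop described for ActionRepeat is shown below. `ActionRepeatSketch` and `ShortEpisodeEnv` are hypothetical names, and the four-tuple `step` return is an assumption made for the example.

```python
class ActionRepeatSketch:
    """Apply one action for up to `amount` inner steps (illustrative only)."""

    def __init__(self, env, amount):
        self._env = env
        self._amount = amount

    def reset(self):
        return self._env.reset()

    def step(self, action):
        total_reward = 0.0
        done = False
        obs, info = None, {}
        for _ in range(self._amount):
            obs, reward, done, info = self._env.step(action)
            total_reward += reward  # rewards are summed across repeats
            if done:
                break  # stop early when the episode terminates
        return obs, total_reward, done, info


class ShortEpisodeEnv:
    """Toy env: reward 1.0 per step, terminates on the third step."""

    def __init__(self):
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= 3, {}


env = ActionRepeatSketch(ShortEpisodeEnv(), amount=2)
env.reset()
first = env.step(0)   # two inner steps
second = env.step(0)  # one inner step, then the episode ends
print(first[:3], second[:3])
```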
- class world_models.envs.wrappers.NormalizeActions(env)[source]#
Bases: object
Expose a normalized [-1, 1] action space for bounded continuous controls.
Incoming normalized actions are mapped back to the wrapped environment action bounds before stepping the environment.
- Parameters:
env (Any)
- property action_space: Box#
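The mapping from normalized actions back to the wrapped environment's bounds is an affine rescale. The sketch below uses plain lists and a hypothetical `denormalize_action` helper; it assumes every dimension has finite bounds, which the real wrapper may handle differently.

```python
def denormalize_action(action, low, high):
    """Map each component from [-1, 1] to [low_i, high_i].

    -1 maps to low_i, +1 maps to high_i, and 0 maps to the midpoint.
    """
    return [
        lo + (a + 1.0) / 2.0 * (hi - lo)
        for a, lo, hi in zip(action, low, high)
    ]


# Example: a torque bounded to [-2, 2].
print(denormalize_action([0.5], [-2.0], [2.0]))
```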
- class world_models.envs.wrappers.ObsDict(env, key='obs')[source]#
Bases: object
Convert scalar/array observations into a dictionary observation format.
This harmonizes outputs for code paths that expect keyed observations (for example {“image”: …} style world model inputs).
- Parameters:
env (Any)
key (str)
- property observation_space: Dict#
- property action_space: Any#
- class world_models.envs.wrappers.OneHotAction(env)[source]#
Bases: object
Wrap discrete-action environments to accept one-hot action vectors.
The wrapper validates one-hot inputs and converts them to integer action indices before forwarding to the underlying environment.
- Parameters:
env (Any)
- property action_space: Box#
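The validate-then-convert step described for OneHotAction might be sketched like this. `one_hot_to_index` is a hypothetical helper, and the strict rule used here (exactly one entry equal to 1, all others equal to 0) is an assumption about what counts as valid.

```python
def one_hot_to_index(action):
    """Return the index of the single 1 in `action`, or raise ValueError."""
    values = list(action)
    ones = [i for i, v in enumerate(values) if v == 1]
    zeros = sum(1 for v in values if v == 0)
    if len(ones) != 1 or zeros != len(values) - 1:
        raise ValueError(f"Invalid one-hot action: {values}")
    return ones[0]


print(one_hot_to_index([0, 0, 1, 0]))
```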
- class world_models.envs.wrappers.RewardObs(env)[source]#
Bases: object
Augment observations with the latest scalar reward under obs["reward"].
Useful for agents that consume reward as part of the observation stream during model learning or recurrent policy inference.
- Parameters:
env (Any)
- property observation_space: Dict#
- class world_models.envs.wrappers.ResizeImage(env, size=(64, 64))[source]#
Bases: object
Resize image-like observation entries to a target spatial size.
The wrapper discovers image keys from env.obs_space, applies nearest neighbor resizing, and updates the advertised observation space shapes.
- Parameters:
env (Any)
size (tuple[int, int])
- property obs_space: dict[str, Any]#
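Nearest-neighbor resizing, as mentioned for ResizeImage, picks one source pixel per target pixel with no blending. The sketch below works on nested lists to stay dependency-free; `resize_nearest` is a hypothetical helper, and the real wrapper presumably operates on arrays and may use a different index-rounding rule.

```python
def resize_nearest(image, size):
    """Nearest-neighbor resize of an H x W grid to (height, width)."""
    src_h, src_w = len(image), len(image[0])
    dst_h, dst_w = size
    # For each target row/column, pick the source index it falls into.
    rows = [min(src_h - 1, (y * src_h) // dst_h) for y in range(dst_h)]
    cols = [min(src_w - 1, (x * src_w) // dst_w) for x in range(dst_w)]
    return [[image[r][c] for c in cols] for r in rows]


small = [[1, 2], [3, 4]]
large = resize_nearest(small, (4, 4))
for row in large:
    print(row)
```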
- class world_models.envs.wrappers.RenderImage(env, key='image')[source]#
Bases: object
Inject RGB renders from env.render("rgb_array") into observations.
This is useful when the base environment returns non-image observations but a rendered camera view is needed for world-model training.
- Parameters:
env (Any)
key (str)
- property obs_space: dict[str, Any]#
- class world_models.envs.wrappers.UUID(env)[source]#
Bases: Wrapper
Gym wrapper that tracks a unique run identifier per environment reset.
The ID combines timestamp and UUID and can be used to tag episodes or artifacts generated during data collection.
- Parameters:
env (Any)
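An identifier that combines a timestamp with a UUID, as described for the UUID wrapper, could be built as in the sketch below. The exact format is not documented here, so the `YYYYMMDDTHHMMSS-<hex>` layout and the `make_run_id` helper are assumptions for illustration.

```python
import re
import time
import uuid


def make_run_id():
    """Build an ID like "20240101T120000-<32 hex chars>" (assumed format)."""
    timestamp = time.strftime("%Y%m%dT%H%M%S", time.gmtime())
    return f"{timestamp}-{uuid.uuid4().hex}"


run_id = make_run_id()
print(run_id)
```

Putting the timestamp first makes IDs sort chronologically, while the UUID part keeps them unique even when two resets happen within the same second.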
- class world_models.envs.wrappers.SelectAction(env, key)[source]#
Bases: Wrapper
Gym wrapper for dictionary actions that forwards a selected key only.
This enables integration with policies that emit action dicts while the environment expects a single tensor/array action payload.
- Parameters:
env (Any)
key (str)
Atari preprocessing helpers#
These helpers wrap Atari environments for specific training recipes. They are not separate environment families.
- class world_models.envs.diamond_atari.DiamondAtariWrapper(env, frameskip=4, max_noop=30, terminate_on_life_loss=True, reward_clip=True, resize=(64, 64))[source]#
Bases: Wrapper
Atari wrapper for DIAMOND following the paper specifications:
- frameskip: number of frames to skip (default 4)
- max_noop: maximum number of noop actions at reset (default 30)
- terminate_on_life_loss: terminate episode when life is lost (default True)
- reward_clip: clip rewards to [-1, 0, 1] (default True)
- resize: resize observations to specified size (default 64x64)
- Parameters:
env (Env)
frameskip (int)
max_noop (int)
terminate_on_life_loss (bool)
reward_clip (bool)
resize (Tuple[int, int] | None)
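Two of the DiamondAtariWrapper options, reward clipping and life-loss termination, reduce to one-line checks. The sketch below is illustrative only: it assumes "clip rewards to [-1, 0, 1]" means taking the sign of the reward, and `clip_reward` and `life_lost` are hypothetical helper names.

```python
def clip_reward(reward):
    """Reduce a raw score change to -1, 0, or +1 (sign clipping, assumed)."""
    return (reward > 0) - (reward < 0)


def life_lost(previous_lives, current_lives):
    """True when the life counter dropped between two steps."""
    return current_lives < previous_lives


print(clip_reward(7.5), clip_reward(-0.1), clip_reward(0.0))
print(life_lost(3, 2))
```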
- world_models.envs.diamond_atari.make_diamond_atari_env(game, frameskip=4, max_noop=30, terminate_on_life_loss=True, reward_clip=True, resize=(64, 64), seed=None)[source]#
Create a DIAMOND-compatible Atari environment.
- Parameters:
game (str) – Atari game name (e.g., “Breakout-v5”)
frameskip (int) – Number of frames to skip between actions
max_noop (int) – Maximum number of noop actions at reset
terminate_on_life_loss (bool) – Whether to terminate on life loss
reward_clip (bool) – Whether to clip rewards to [-1, 0, 1]
resize (Tuple[int, int]) – Target size for observations
seed (int | None) – Random seed
- Returns:
Configured Atari environment
- Return type:
Datasets and transforms#
Data generation and dataset classes for World Models.
This module provides utilities for generating rollout data from environments and PyTorch dataset classes for loading observation sequences.
- class world_models.datasets.wm_dataset.RolloutDataset(root, transform, train=True, buffer_size=1000, num_test_files=600)[source]#
Bases: Dataset
PyTorch Dataset for loading rollout data.
This dataset loads pre-collected rollout trajectories from disk, providing a buffer-based mechanism for efficient data loading. It supports train/test splits and custom transforms.
- Parameters:
root (str)
transform (Compose)
train (bool)
buffer_size (int)
num_test_files (int)
- root#
Root directory containing rollout .npz files.
- transform#
Albumentations transform to apply to observations.
- train#
If True, use training split; otherwise use test split.
- buffer_size#
Maximum number of files to keep in memory.
- num_test_files#
Number of files to use for test set.
Example
>>> transform = transforms.Compose([transforms.ToTensor()])
>>> dataset = RolloutDataset(
...     root='data/carracing',
...     transform=transform,
...     train=True,
...     buffer_size=100,
... )
>>> obs, action, reward, terminal = dataset[0]
- class world_models.datasets.wm_dataset.ObservationDataset(root, transform, train=True, buffer_size=1000, num_test_files=600)[source]#
Bases: RolloutDataset
Dataset for single observation samples (not sequences).
This dataset extends RolloutDataset to provide individual observations rather than sequences, suitable for VAE training.
Example
>>> dataset = ObservationDataset(
...     root='data/carracing',
...     transform=transform,
...     train=True,
... )
>>> obs = dataset[0]
- Parameters:
root (str)
transform (Compose)
train (bool)
buffer_size (int)
num_test_files (int)
- class world_models.datasets.wm_dataset.SequenceDataset(root, transform, train, buffer_size, num_test_files, seq_len)[source]#
Bases: RolloutDataset
Dataset for sequential rollout data.
This dataset provides sequences of observations, actions, rewards, and terminal flags suitable for training recurrent models like MDRNN.
- Parameters:
root (str)
transform (Compose)
train (bool)
buffer_size (int)
num_test_files (int)
seq_len (int)
- seq_len#
Length of sequences to return.
Example
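>>> dataset = SequenceDataset(
...     root='data/carracing',
...     transform=transform,
...     train=True,
...     buffer_size=1000,    # required by the signature; value mirrors the RolloutDataset default
...     num_test_files=600,  # required by the signature; value mirrors the RolloutDataset default
...     seq_len=32,
... )
>>> obs, action, reward, terminal, next_obs = dataset[0]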
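To make the five-tuple returned by SequenceDataset concrete, the sketch below cuts one window out of a toy rollout. It is only an illustration of the shape of the data: `sequence_window` is a hypothetical helper, and pairing each observation with the one a step later as `next_obs` is an assumption about how the dataset aligns its outputs.

```python
def sequence_window(observations, actions, rewards, terminals, start, seq_len):
    """Slice a length-`seq_len` window plus the one-step-shifted observations."""
    end = start + seq_len
    obs = observations[start:end]
    next_obs = observations[start + 1:end + 1]  # same window, shifted by one
    return obs, actions[start:end], rewards[start:end], terminals[start:end], next_obs


# Toy rollout: six observations and five transitions between them.
observations = list(range(6))
actions = list("abcde")
rewards = [0.0] * 5
terminals = [False] * 4 + [True]

window = sequence_window(observations, actions, rewards, terminals, start=1, seq_len=3)
print(window)
```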
- class world_models.datasets.wm_dataset.LatentSequenceDataset(latents_arr, actions, rewards, terminals, train, buffer_size, num_test_files, seq_len)[source]#
Bases: Dataset
Dataset for pre-computed latent sequences.
This dataset uses pre-encoded latent representations instead of raw images, which significantly reduces memory usage during RNN training.
- Parameters:
latents_arr (ndarray)
actions (ndarray)
rewards (ndarray)
terminals (ndarray)
train (bool)
buffer_size (int)
num_test_files (int)
seq_len (int)
- class world_models.datasets.video_datasets.DatasetConfig(num_frames=16, image_size=64, batch_size=4, num_workers=4, pin_memory=True, shuffle=True)[source]#
Bases: object
Base configuration for datasets.
- Parameters:
num_frames (int)
image_size (int)
batch_size (int)
num_workers (int)
pin_memory (bool)
shuffle (bool)
- num_frames: int = 16#
- image_size: int = 64#
- batch_size: int = 4#
- num_workers: int = 4#
- pin_memory: bool = True#
- shuffle: bool = True#
- class world_models.datasets.video_datasets.VideoDatasetBase(data_source, num_frames=16, image_size=64, transform=None, normalize=True)[source]#
Bases: Dataset
Base class for video datasets.
All video datasets should inherit from this class and implement the _load_video method.
- Parameters:
data_source (str | Path | List[str] | List[Path])
num_frames (int)
image_size (int)
transform (Callable | None)
normalize (bool)
- data_source: str | Path | List[str] | List[Path]#
- video_paths: Sequence[Path | int]#
- class world_models.datasets.video_datasets.VideoFolderDataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, extensions=('.mp4', '.avi', '.mkv', '.webm', '.mov'), recursive=True)[source]#
Bases: VideoDatasetBase
Dataset that loads videos from a folder.
Supports common video formats: .mp4, .avi, .mkv, .webm, .mov
Usage:
dataset = VideoFolderDataset(
    data_source="/path/to/videos",
    num_frames=16,
    image_size=64,
)
- Parameters:
data_source (str | Path | List[str] | List[Path])
num_frames (int)
image_size (int)
transform (Callable | None)
normalize (bool)
extensions (Tuple[str, ...])
recursive (bool)
- class world_models.datasets.video_datasets.ImageFolderDataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, extensions=('.jpg', '.jpeg', '.png', '.bmp'), image_sort_key=None)[source]#
Bases: VideoDatasetBase
Dataset that loads image sequences from folders.
Each subfolder is treated as a video sequence.
Usage:
dataset = ImageFolderDataset(
    data_source="/path/to/images",
    num_frames=16,
    image_size=64,
)
- Parameters:
data_source (str | Path | List[str] | List[Path])
num_frames (int)
image_size (int)
transform (Callable | None)
normalize (bool)
extensions (Tuple[str, ...])
image_sort_key (Callable | None)
- class world_models.datasets.video_datasets.NumPyDataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, key=None)[source]#
Bases: VideoDatasetBase
Dataset that loads videos from numpy files.
Supports .npy and .npz files.
Usage:
dataset = NumPyDataset(
    data_source="/path/to/videos.npy",
    num_frames=16,
    image_size=64,
)
- Parameters:
data_source (str | Path)
num_frames (int)
image_size (int)
transform (Callable | None)
normalize (bool)
key (str | None)
- class world_models.datasets.video_datasets.RLEnvironmentDataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, obs_key='observations')[source]#
Bases: VideoDatasetBase
Dataset for RL environment recordings.
Loads trajectories stored as:
- .npz files with ‘observations’ and ‘actions’ keys
- Directory with episode folders
Usage:
dataset = RLEnvironmentDataset(
    data_source="/path/to/rl_episodes",
    num_frames=16,
    image_size=64,
)
- Parameters:
data_source (str | Path)
num_frames (int)
image_size (int)
transform (Callable | None)
normalize (bool)
obs_key (str)
- class world_models.datasets.video_datasets.HDF5Dataset(data_source, num_frames=16, image_size=64, transform=None, normalize=True, key='videos', memmap=False)[source]#
Bases: VideoDatasetBase
Dataset that loads videos from HDF5 files.
Supports pre-processed video datasets stored in HDF5 format. Expected structure: HDF5 file with ‘videos’ dataset of shape (N, T, H, W, C) or (N, T, C, H, W).
Usage:
dataset = HDF5Dataset(
    data_source="/path/to/videos.h5",
    num_frames=16,
    image_size=64,
)
- Parameters:
data_source (str | Path)
num_frames (int)
image_size (int)
transform (Callable | None)
normalize (bool)
key (str)
memmap (bool)
- world_models.datasets.video_datasets.create_video_dataloader(dataset_type, data_source, num_frames=16, image_size=64, batch_size=4, num_workers=4, shuffle=True, pin_memory=True, **kwargs)[source]#
Factory function to create video dataloaders.
- Parameters:
dataset_type (str) – Type of dataset (“video_folder”, “image_folder”, “numpy”, “rl”)
data_source (str | Path | List[str]) – Path or list of paths to data
num_frames (int) – Number of frames per video
image_size (int) – Target image size (height and width)
batch_size (int) – Batch size for dataloader
num_workers (int) – Number of workers for data loading
shuffle (bool) – Whether to shuffle data
pin_memory (bool) – Whether to pin memory for faster GPU transfer
**kwargs (Any) – Additional arguments for specific dataset types
- Returns:
Tuple of (dataset, dataloader)
- Return type:
Tuple[Dataset, DataLoader]
Usage:
dataset, loader = create_video_dataloader(
    dataset_type="video_folder",
    data_source="/path/to/videos",
    num_frames=16,
    image_size=64,
    batch_size=4,
)
- class world_models.datasets.video_datasets.VideoDatasetConfig(num_frames=16, image_size=64, batch_size=4, num_workers=4, pin_memory=True, shuffle=True, dataset_type='video_folder', data_source='', extensions=('.mp4', '.avi', '.mkv'), recursive=True, obs_key='observations')[source]#
Bases: DatasetConfig
Configuration for video datasets.
- Parameters:
num_frames (int)
image_size (int)
batch_size (int)
num_workers (int)
pin_memory (bool)
shuffle (bool)
dataset_type (str)
data_source (str)
extensions (Tuple[str, ...])
recursive (bool)
obs_key (str)
- dataset_type: str = 'video_folder'#
- data_source: str = ''#
- extensions: Tuple[str, ...] = ('.mp4', '.avi', '.mkv')#
- recursive: bool = True#
- obs_key: str = 'observations'#
- world_models.datasets.video_datasets.create_video_dataset_from_config(config)[source]#
Create video dataset and dataloader from config.
- Parameters:
config (VideoDatasetConfig)
- Return type:
Tuple[Dataset, DataLoader]
TinyWorlds Dataset Loaders
Loads pre-processed video datasets from HuggingFace for training Genie-style world models. Based on: AlmondGod/tinyworlds
Available datasets:
- PICO_DOOM: Minimal Doom gameplay
- PONG: Classic Pong
- ZELDA: Zelda Ocarina of Time (2D)
- POLE_POSITION: Racing game
- SONIC: Sonic the Hedgehog
- class world_models.datasets.tinyworlds.TinyWorldsConfig(dataset_name='SONIC', num_frames=16, image_size=64, batch_size=4, num_workers=4, cache_dir=None, split='train')[source]#
Bases: object
Configuration for TinyWorlds datasets.
- Parameters:
dataset_name (str)
num_frames (int)
image_size (int)
batch_size (int)
num_workers (int)
cache_dir (str | None)
split (str)
- dataset_name: str = 'SONIC'#
- num_frames: int = 16#
- image_size: int = 64#
- batch_size: int = 4#
- num_workers: int = 4#
- cache_dir: str | None = None#
- split: str = 'train'#
- class world_models.datasets.tinyworlds.TinyWorldsDataset(dataset_name='SONIC', num_frames=16, image_size=64, split='train', cache_dir=None, download=True, data_file=None)[source]#
Bases: Dataset
Dataset for TinyWorlds game video data.
Loads pre-processed frames from HuggingFace datasets repository.
- Parameters:
dataset_name (str)
num_frames (int)
image_size (int)
split (str)
cache_dir (str | None)
download (bool)
data_file (str | None)
- DATASET_CONFIGS = {'PICO_DOOM': {'description': 'Minimal Doom gameplay', 'filename': 'picodoom_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}, 'POLE_POSITION': {'description': 'Racing game', 'filename': 'pole_position_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}, 'PONG': {'description': 'Classic Pong', 'filename': 'pong_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}, 'SONIC': {'description': 'Sonic the Hedgehog', 'filename': 'sonic_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}, 'ZELDA': {'description': 'Zelda Ocarina of Time (2D)', 'filename': 'zelda_frames.h5', 'repo_id': 'AlmondGod/tinyworlds'}}#
- class world_models.datasets.tinyworlds.TinyWorldsDataLoader[source]#
Bases: object
Factory class for creating TinyWorlds dataloaders.
- DATASET_NAMES = ['PICO_DOOM', 'PONG', 'ZELDA', 'POLE_POSITION', 'SONIC']#
- static create_dataloader(dataset_name='SONIC', num_frames=16, image_size=64, batch_size=4, num_workers=4, shuffle=True, cache_dir=None, download=True, data_file=None)[source]#
- Parameters:
dataset_name (str)
num_frames (int)
image_size (int)
batch_size (int)
num_workers (int)
shuffle (bool)
cache_dir (str | None)
download (bool)
data_file (str | None)
- Return type:
Tuple[TinyWorldsDataset, DataLoader]
- world_models.datasets.tinyworlds.create_tinyworlds_dataloader(dataset_name='SONIC', num_frames=16, image_size=64, batch_size=4, num_workers=4, shuffle=True, cache_dir=None, download=True, data_file=None)[source]#
- Parameters:
dataset_name (str)
num_frames (int)
image_size (int)
batch_size (int)
num_workers (int)
shuffle (bool)
cache_dir (str | None)
download (bool)
data_file (str | None)
- Return type:
Tuple[TinyWorldsDataset, DataLoader]
- world_models.datasets.tinyworlds.download_all_datasets(cache_dir=None)[source]#
Download all available TinyWorlds datasets.
- Parameters:
cache_dir (str | None) – Directory to cache downloaded datasets
- Returns:
Dictionary mapping dataset names to local file paths
- Return type:
Dict[str, str | None]
- class world_models.datasets.diamond_dataset.ReplayBuffer(capacity=1000, obs_shape=(64, 64, 3), action_dim=1, device='cpu')[source]#
Bases: object
Replay buffer for storing environment interactions. Stores (observation, action, reward, done, next_observation) tuples.
- Parameters:
capacity (int)
obs_shape (Tuple[int, int, int])
action_dim (int)
device (str)
- add(obs, action, reward, done, next_obs)[source]#
Add a transition to the buffer.
- Parameters:
obs (ndarray)
action (int)
reward (float)
done (bool)
next_obs (ndarray)
- Return type:
None
- sample(batch_size)[source]#
Sample a random batch of transitions.
- Parameters:
batch_size (int)
- Return type:
Dict[str, Tensor]
- sample_sequence(batch_size, sequence_length, burn_in=0)[source]#
Sample a sequence of transitions for training.
- Parameters:
batch_size (int) – Number of sequences to sample
sequence_length (int) – Total sequence length (burn_in + horizon)
burn_in (int) – Number of initial frames to use for conditioning
- Returns:
Dictionary with tensors of shape (batch_size, sequence_length, …)
- Return type:
Dict[str, Tensor]
- is_ready(min_size)[source]#
Check if buffer has enough samples.
- Parameters:
min_size (int)
- Return type:
bool
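As a rough illustration of what sequence sampling from a replay buffer involves, the sketch below keeps transitions in a fixed-capacity FIFO store and draws contiguous windows from it. It is a plain-Python stand-in, not the ReplayBuffer implementation: the class name TinyTransitionBuffer, its list-based storage, and the absence of episode-boundary and burn-in handling are all assumptions made for brevity.

```python
import random
from typing import Any, Dict, List


class TinyTransitionBuffer:
    """Hypothetical stand-in for a replay buffer (not the library class)."""

    def __init__(self, capacity: int = 1000, seed: int = 0) -> None:
        self.capacity = capacity
        self.storage: List[Dict[str, Any]] = []
        self.rng = random.Random(seed)

    def add(self, obs, action, reward, done, next_obs) -> None:
        # Append the newest transition and evict the oldest when full.
        self.storage.append(
            {"obs": obs, "action": action, "reward": reward,
             "done": done, "next_obs": next_obs}
        )
        if len(self.storage) > self.capacity:
            self.storage.pop(0)

    def is_ready(self, min_size: int) -> bool:
        return len(self.storage) >= min_size

    def sample_sequence(self, batch_size: int, sequence_length: int):
        # Each sample is a contiguous window of `sequence_length` transitions.
        if len(self.storage) < sequence_length:
            raise ValueError("not enough transitions for one sequence")
        max_start = len(self.storage) - sequence_length
        batch = []
        for _ in range(batch_size):
            start = self.rng.randint(0, max_start)
            batch.append(self.storage[start:start + sequence_length])
        return batch
```

A training loop would typically call is_ready before sampling, so that the first batches are not drawn from a nearly empty buffer.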
- class world_models.datasets.diamond_dataset.SequenceDataset(replay_buffer, sequence_length=5, burn_in=4)[source]#
Bases: Dataset
PyTorch Dataset for sampling sequences from the replay buffer. Used for training the diffusion world model.
- Parameters:
replay_buffer (ReplayBuffer)
sequence_length (int)
burn_in (int)
- world_models.datasets.cifar10.make_cifar10(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, drop_last=True, train=True, download=False)[source]#
Create CIFAR-10 dataset and distributed dataloader.
Factory function that creates a CIFAR-10 dataset with the provided transforms and returns a tuple of (dataset, dataloader, sampler) for use in JEPA or diffusion training pipelines.
- Parameters:
transform (Any) – Transforms to apply to images (e.g., RandomCrop, ColorJitter).
batch_size (int) – Number of samples per batch.
collator (callable, optional) – Custom collate function for batching (e.g., mask collator for JEPA).
pin_mem (bool) – Whether to pin memory for faster GPU transfer (default: True).
num_workers (int) – Number of data loading workers (default: 8).
world_size (int) – Number of distributed processes (default: 1).
rank (int) – Rank of current process in distributed setting (default: 0).
root_path (str, optional) – Path to store/load CIFAR-10 data.
drop_last (bool) – Whether to drop incomplete final batch (default: True).
train (bool) – Whether to load train or test split (default: True).
download (bool) – Whether to download dataset if not present (default: False).
- Returns:
- (dataset, dataloader, sampler)
dataset: torchvision.datasets.CIFAR10 instance
dataloader: torch.utils.data.DataLoader with distributed sampling
sampler: torch.utils.data.distributed.DistributedSampler
- Return type:
tuple
Example
>>> transform = make_transforms(crop_size=224)
>>> dataset, loader, sampler = make_cifar10(
...     transform=transform,
...     batch_size=256,
...     root_path="./data",
...     download=True
... )
- world_models.datasets.imagenet1k.make_imagenet1k(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, training=True, copy_data=False, drop_last=True, subset_file=None)[source]#
Build an ImageNet-1K dataset and dataloader with distributed sampling.
Factory function that creates an ImageNet dataset and returns a tuple of (dataset, dataloader, sampler) for use in JEPA or other self-supervised training pipelines.
- Supports:
Optional data staging from network storage to local scratch
Subset filtering via text file listing allowed image IDs
Distributed sampling for multi-GPU training
- Parameters:
transform (Any) – Transforms to apply to images.
batch_size (int) – Number of samples per batch.
collator (callable, optional) – Custom collate function (e.g., mask collator).
pin_mem (bool) – Whether to pin memory for GPU transfer (default: True).
num_workers (int) – Number of data loading workers (default: 8).
world_size (int) – Number of distributed processes (default: 1).
rank (int) – Rank of current process (default: 0).
root_path (str, optional) – Root path containing ImageNet data.
image_folder (str, optional) – Subfolder containing ImageNet data.
training (bool) – Load train or validation split (default: True).
copy_data (bool) – Copy data locally for faster loading (default: False).
drop_last (bool) – Drop incomplete final batch (default: True).
subset_file (str, optional) – Path to file listing allowed image IDs.
- Returns:
- (dataset, dataloader, sampler)
dataset: ImageNet dataset instance
dataloader: DataLoader with distributed sampling
sampler: DistributedSampler instance
- Return type:
tuple
- class world_models.datasets.imagenet1k.ImageNet(root, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', transform=None, train=True, job_id=None, local_rank=None, copy_data=True, index_targets=False)[source]#
Bases: ImageFolder
ImageNet dataset wrapper with optional local copy/extract workflow.
Extends torchvision.datasets.ImageFolder to support data staging from network storage to local scratch space for faster multi-process training on cluster environments (e.g., SLURM).
- Features:
Optional data copying from network storage to local /scratch
Extracts tar archives automatically on first access
Supports train/validation splits
Optional target indexing for balanced sampling
- Parameters:
root (str)
image_folder (str)
tar_file (str)
transform (Any)
train (bool)
job_id (str | None)
local_rank (int | None)
copy_data (bool)
index_targets (bool)
- class world_models.datasets.imagenet1k.ImageNetSubset(dataset, subset_file)[source]#
Bases: object
View over an ImageNet dataset filtered by an explicit image-id list.
The subset file contains target image names; only matching samples are kept while preserving transforms and label mapping from the base dataset.
- Parameters:
dataset (Any)
subset_file (str)
- filter_dataset_(subset_file)[source]#
Filter self.dataset to a subset
- Parameters:
subset_file (str)
- Return type:
None
- property classes: Any#
- world_models.datasets.imagenet1k.copy_imgnt_locally(root, suffix, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', job_id=None, local_rank=None)[source]#
Copy and extract ImageNet archives to per-job local scratch storage.
In SLURM environments this reduces network filesystem pressure by unpacking once per job and synchronizing worker processes with a signal file.
- Parameters:
root (str)
suffix (str)
image_folder (str)
tar_file (str)
job_id (str | None)
local_rank (int | None)
- Return type:
str | None
- world_models.datasets.imagenet1k.make_imagefolder(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, drop_last=True, val_split=None)[source]#
Create an ImageFolder dataset loader for custom folder-structured datasets.
Supports optional train/validation split and distributed sampling, making it a drop-in replacement for ImageNet loaders in training scripts.
- Parameters:
transform (Any)
batch_size (int)
collator (Any)
pin_mem (bool)
num_workers (int)
world_size (int)
rank (int)
root_path (str | None)
image_folder (str | None)
drop_last (bool)
val_split (float | None)
- Return type:
Tuple[Dataset, DataLoader, DistributedSampler]
PyTorch Dataset for the NuPlan autonomous driving dataset.
Requires nuplan-devkit and a local copy of the NuPlan dataset.
Download from https://www.nuplan.org/nuplan and set NUPLAN_DATA_ROOT
to the extracted path (default: ~/nuplan/dataset).
- class world_models.datasets.nuplan.NuPlanSample(scenario_name, map_raster, ego_past, ego_future, agents_past, agents_future, agents_mask, agent_types, planning_target)[source]#
Bases: object
A single training sample from the NuPlan dataset.
- Parameters:
scenario_name (str)
map_raster (Tensor)
ego_past (Tensor)
ego_future (Tensor)
agents_past (Tensor)
agents_future (Tensor)
agents_mask (Tensor)
agent_types (Tensor)
planning_target (Tensor)
- scenario_name: str#
- map_raster: Tensor#
- ego_past: Tensor#
- ego_future: Tensor#
- agents_past: Tensor#
- agents_future: Tensor#
- agents_mask: Tensor#
- agent_types: Tensor#
- planning_target: Tensor#
- class world_models.datasets.nuplan.NuPlanDataset(data_root=None, map_root=None, split='train', db_files=None, map_version='nuplan-maps-v1.0', planning_horizon=80, past_horizon=20, map_extent=(100.0, 100.0), map_resolution=0.1, max_agents=32, limit_scenarios=None)[source]#
Bases: Dataset[NuPlanSample]
PyTorch Dataset over NuPlan scenarios for world model training.
Each sample contains rasterised map tiles, ego and agent history, and future planning targets at 10 Hz.
- Parameters:
data_root (str | Path | None) – Path to the NuPlan dataset root. Defaults to $NUPLAN_DATA_ROOT.
map_root (str | Path | None) – Path to NuPlan map data. Defaults to $NUPLAN_MAP_ROOT.
split (str) – "train", "val", or "test". The mini split is used automatically when data_root / "mini" exists.
db_files (list[str] | None) – Explicit list of .db files. When None the builder auto-discovers files under data_root / split.
map_version (str) – Map version string, e.g. "nuplan-maps-v1.0".
planning_horizon (int) – Number of future steps at 10 Hz (default 80 = 8 s).
past_horizon (int) – Number of past steps at 10 Hz (default 20 = 2 s).
map_extent (Tuple[float, float]) – Raster crop half-extent in metres (width, height).
map_resolution (float) – Metres per pixel for the raster.
max_agents (int) – Maximum agents per sample; fewer are zero-padded.
limit_scenarios (int | None) – Cap on total scenarios (useful for prototyping).
- world_models.datasets.nuplan.make_nuplan_dataloader(data_root=None, split='train', batch_size=32, num_workers=4, **dataset_kwargs)[source]#
Create a NuPlan DataLoader.
- Parameters:
data_root (str | Path | None) – Root of the NuPlan dataset (default: $NUPLAN_DATA_ROOT).
split (str) – Dataset split.
batch_size (int) – Batch size.
num_workers (int) – Worker count for the DataLoader.
**dataset_kwargs (Any) – Extra arguments forwarded to NuPlanDataset.
- Return type:
(dataset, dataloader)
- world_models.transforms.image.make_transforms(crop_size=224, crop_scale=(0.3, 1.0), color_jitter=1.0, horizontal_flip=False, color_distortion=False, gaussian_blur=False, normalization=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))[source]#
Compose image augmentations and normalization for vision model training.
Supports random crops, optional flip/color distortion/blur, and returns a torchvision.transforms.Compose pipeline.
- Parameters:
crop_size (int)
crop_scale (tuple[float, float])
color_jitter (float)
horizontal_flip (bool)
color_distortion (bool)
gaussian_blur (bool)
normalization (tuple[tuple[float, ...], tuple[float, ...]])
- Return type:
Any
Masking and JEPA helpers#
Masks sub-module - Masking strategies for JEPA and masked training.
This package provides various masking collator classes for generating encoder/predictor masks during masked representation learning.
- Usage:
from world_models.masks import MaskCollator, DefaultCollator
collator = MaskCollator(input_size=(64, 64), patch_size=8)
- world_models.masks.MultiblockMaskCollator#
alias of
MaskCollator
- world_models.masks.RandomMaskCollator#
alias of
MaskCollator
- class world_models.masks.DefaultCollator[source]#
Bases: object
Simple collator that returns batch data and no masking metadata.
This is used when training code expects the JEPA-style collator return shape (batch, masks_enc, masks_pred) but masking is disabled.
- class world_models.masks.default.DefaultCollator[source]#
Bases: object
Simple collator that returns batch data and no masking metadata.
This is used when training code expects the JEPA-style collator return shape (batch, masks_enc, masks_pred) but masking is disabled.
- class world_models.masks.multiblock.MaskCollator(input_size=(224, 224), patch_size=16, enc_mask_scale=(0.2, 0.8), pred_mask_scale=(0.5, 1.0), aspect_ratio=(0.3, 3.0), nenc=1, npred=2, min_keep=4, allow_overlap=False)[source]#
Bases: object
Generate multi-block encoder and predictor masks for JEPA training.
For each sample, this collator samples predictor target blocks and context encoder blocks (optionally non-overlapping), then returns masked patch indices aligned across the batch.
- Parameters:
input_size (tuple[int, int])
patch_size (int)
enc_mask_scale (tuple[float, float])
pred_mask_scale (tuple[float, float])
aspect_ratio (tuple[float, float])
nenc (int)
npred (int)
min_keep (int)
allow_overlap (bool)
- class world_models.masks.random.MaskCollator(ratio=(0.4, 0.6), input_size=(224, 224), patch_size=16)[source]#
Bases: object
Generate random context/prediction patch splits for masked training.
A random permutation of patch indices is sampled per image; a configurable fraction is assigned to context and the remainder to prediction targets.
- Parameters:
ratio (tuple)
input_size (tuple)
patch_size (int)
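The random split described above can be sketched in a few lines of plain Python. This is an illustration of the idea only: the function name random_patch_split and the single context_fraction argument are hypothetical, whereas the library's collator takes a ratio tuple whose exact sampling semantics are not documented here.

```python
import random
from typing import List, Optional, Tuple


def random_patch_split(
    input_size: Tuple[int, int] = (224, 224),
    patch_size: int = 16,
    context_fraction: float = 0.5,
    seed: Optional[int] = None,
) -> Tuple[List[int], List[int]]:
    """Split patch indices into (context, prediction) sets at random."""
    height, width = input_size
    num_patches = (height // patch_size) * (width // patch_size)
    # Sample a random permutation of all patch indices.
    order = list(range(num_patches))
    random.Random(seed).shuffle(order)
    # The first part becomes context, the remainder prediction targets.
    num_context = int(num_patches * context_fraction)
    return sorted(order[:num_context]), sorted(order[num_context:])
```

With the default 224x224 input and 16-pixel patches there are 14 x 14 = 196 patches, so a context fraction of 0.5 leaves 98 indices on each side.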
- world_models.helpers.jepa_helper.load_checkpoint(device, r_path, encoder, predictor, target_encoder, opt, scaler)[source]#
Load JEPA training state from disk into model and optimizer objects.
Restores encoder, predictor, optional target encoder, optimizer state, and optional AMP scaler, returning the resumed epoch for training restart.
- Parameters:
device (device)
r_path (str)
encoder (Module)
predictor (Module)
target_encoder (Module | None)
opt (Optimizer)
scaler (Any | None)
- Return type:
tuple
- world_models.helpers.jepa_helper.init_model(device, patch_size=16, model_name='vit_base', crop_size=224, pred_depth=6, pred_emb_dim=384)[source]#
Initialize JEPA encoder and predictor modules with ViT backbones.
Applies truncated-normal parameter initialization, moves modules to the requested device, and returns (encoder, predictor).
- Parameters:
device (device)
patch_size (int)
model_name (str)
crop_size (int)
pred_depth (int)
pred_emb_dim (int)
- Return type:
tuple
- world_models.helpers.jepa_helper.init_opt(encoder, predictor, iterations_per_epoch, start_lr, ref_lr, warmup, num_epochs, wd=1e-06, final_wd=1e-06, final_lr=0.0, use_bfloat16=False, ipe_scale=1.25)[source]#
Build optimizer, AMP scaler, LR scheduler, and weight-decay scheduler for JEPA.
Parameters are grouped to exclude bias/norm tensors from weight decay, matching typical transformer training best practices.
- Parameters:
encoder (Module)
predictor (Module)
iterations_per_epoch (int)
start_lr (float)
ref_lr (float)
warmup (float)
num_epochs (int)
wd (float)
final_wd (float)
final_lr (float)
use_bfloat16 (bool)
ipe_scale (float)
- Return type:
tuple
Benchmarks and reports#
Benchmarks sub-module - Benchmark runners and adapters for world models.
This package provides tools for running standardized evaluations of world models (Dreamer, IRIS, DIAMOND) across multiple seeds and computing aggregate metrics.
- Usage:
from world_models.benchmarks import BenchmarkRunner, DiamondAdapter
runner = BenchmarkRunner(adapter_cls=DiamondAdapter, …)
- class world_models.benchmarks.runner.BenchmarkRunner(adapter_cls, out_dir='results')[source]#
Bases: object
Run evaluations for adapters across seeds and export results.
- Usage:
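# env_spec describes the environment to evaluate (e.g. the Breakout-v5 game); its format depends on the adapter.
runner = BenchmarkRunner(adapter_cls=adapters.DiamondAdapter)
results = runner.run(env_spec=env_spec, seeds=[0, 1], num_episodes=5)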
- Parameters:
adapter_cls (Callable[..., adapters.BaseAdapter])
out_dir (str)
- run(env_spec=None, seeds=None, num_episodes=5, checkpoint=None, extra_kwargs=None)[source]#
Run benchmark.
Returns a results dict with per-seed episode returns and computed metrics.
- Parameters:
env_spec (Any | None)
seeds (List[int] | None)
num_episodes (int)
checkpoint (str | None)
extra_kwargs (Dict[str, Any] | None)
- Return type:
Dict[str, Any]
- class world_models.benchmarks.runner.MultiAgentBenchmarkRunner(adapter_classes, out_dir='results')[source]#
Bases: object
Run evaluations for multiple adapters on the same environment.
- Usage:
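# env_spec is a dict describing the environment (e.g. the Breakout-v5 game); its keys depend on the adapters.
runner = MultiAgentBenchmarkRunner(adapter_classes=[adapters.DiamondAdapter, adapters.IRISAdapter])
results = runner.run_all(env_spec, seeds=[0, 1], num_episodes=5)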
- Parameters:
adapter_classes (List[type[adapters.BaseAdapter]])
out_dir (str)
- run_all(env_spec, seeds=None, num_episodes=5, checkpoints=None, extra_kwargs=None, train_epochs=None)[source]#
Run benchmarks for all adapters on the same environment.
Returns a results dict with results for each adapter.
- Parameters:
env_spec (Dict[str, Any])
seeds (List[int] | None)
num_episodes (int)
checkpoints (Dict[str, str] | None)
extra_kwargs (Dict[str, Any] | None)
train_epochs (int | None)
- Return type:
Dict[str, Any]
- class world_models.benchmarks.adapters.BaseAdapter(env_spec=None, seed=0, **kwargs)[source]#
Bases: object
- Parameters:
env_spec (Any | None)
seed (int)
kwargs (Any)
- class world_models.benchmarks.adapters.DiamondAdapter(env_spec=None, seed=0, **kwargs)[source]#
Bases: BaseAdapter
- Parameters:
env_spec (Any | None)
seed (int)
kwargs (Any)
- class world_models.benchmarks.adapters.IRISAdapter(env_spec=None, seed=0, **kwargs)[source]#
Bases: BaseAdapter
- Parameters:
env_spec (Any | None)
seed (int)
kwargs (Any)
- class world_models.benchmarks.adapters.DreamerAdapter(env_spec=None, seed=0, **kwargs)[source]#
Bases: BaseAdapter
- Parameters:
env_spec (Any | None)
seed (int)
kwargs (Any)
- class world_models.benchmarks.adapters.DreamerV1Adapter(env_spec=None, seed=0, **kwargs)[source]#
Bases: DreamerAdapter
- Parameters:
env_spec (Any | None)
seed (int)
kwargs (Any)
- class world_models.benchmarks.adapters.DreamerV2Adapter(env_spec=None, seed=0, **kwargs)[source]#
Bases: DreamerAdapter
- Parameters:
env_spec (Any | None)
seed (int)
kwargs (Any)
- world_models.benchmarks.metrics.compute_aggregate_metrics(per_seed_means)[source]#
- Parameters:
per_seed_means (Iterable[float])
- Return type:
Dict[str, float]
- world_models.benchmarks.metrics.bootstrap_ci(values, num_samples=1000, alpha=0.05)[source]#
Compute a simple bootstrap (1 - alpha) confidence interval for the mean.
- Parameters:
values (List[float])
num_samples (int)
alpha (float)
- Return type:
tuple[float, float]
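For readers unfamiliar with the percentile bootstrap, the following is a minimal stdlib-only sketch of the technique. It is not the library's code: the function name bootstrap_mean_ci, the seed argument, and the index-based percentile lookup are assumptions chosen to keep the example short and reproducible.

```python
import random
from typing import List, Tuple


def bootstrap_mean_ci(
    values: List[float],
    num_samples: int = 1000,
    alpha: float = 0.05,
    seed: int = 0,
) -> Tuple[float, float]:
    """Percentile-bootstrap (1 - alpha) confidence interval for the mean."""
    rng = random.Random(seed)
    n = len(values)
    # Resample with replacement and record the mean of each resample.
    means = sorted(
        sum(rng.choices(values, k=n)) / n for _ in range(num_samples)
    )
    # Read the alpha/2 and 1 - alpha/2 percentiles off the sorted means.
    lower_idx = int((alpha / 2) * (num_samples - 1))
    upper_idx = int((1 - alpha / 2) * (num_samples - 1))
    return means[lower_idx], means[upper_idx]
```

With only a handful of seeds the interval is wide, which is one reason benchmark reports usually state the number of seeds alongside the interval.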
- world_models.benchmarks.metrics.iqm_of_array(values)[source]#
Compute the Interquartile Mean (IQM) of an array of values.
IQM is the mean of values that lie between the 25th and 75th percentiles (inclusive). This is a robust central tendency measure used in RL benchmark reporting.
- Parameters:
values (Iterable[float])
- Return type:
float
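The definition above can be illustrated with a small stdlib-only sketch. The helper names, the linear-interpolation percentile, and the fallback used when no sample falls inside the quartile range are assumptions; the library function may resolve these details differently.

```python
from typing import Iterable, List


def _percentile(sorted_vals: List[float], q: float) -> float:
    """Percentile with linear interpolation between neighbouring samples."""
    pos = q * (len(sorted_vals) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    frac = pos - lo
    return sorted_vals[lo] * (1 - frac) + sorted_vals[hi] * frac


def interquartile_mean(values: Iterable[float]) -> float:
    """Mean of the samples lying between the 25th and 75th percentiles."""
    data = sorted(float(v) for v in values)
    if not data:
        raise ValueError("values must not be empty")
    q25 = _percentile(data, 0.25)
    q75 = _percentile(data, 0.75)
    kept = [v for v in data if q25 <= v <= q75]
    if not kept:
        # Assumption: with very few samples nothing may fall in the
        # range, so fall back to the midpoint of the two quartiles.
        return (q25 + q75) / 2
    return sum(kept) / len(kept)
```

For example, the outlier in [1, 2, 3, 4, 100] is discarded and the result is 3.0, whereas the plain mean would be 22.0.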
- world_models.benchmarks.metrics.bootstrap_iqm_ci(values, num_samples=1000, alpha=0.05)[source]#
Bootstrap a confidence interval for the IQM.
Returns (lower, upper) percentiles of the bootstrap IQM distribution.
- Parameters:
values (List[float])
num_samples (int)
alpha (float)
- Return type:
tuple[float, float]
- world_models.benchmarks.reporting.export_csv(results, path)[source]#
- Parameters:
results (Dict[str, Any])
path (str)
- Return type:
None
Utilities#
Loss functions for World Models training.
This module provides loss functions for training VAE and other world model components.
- world_models.losses.convae_loss.conv_vae_loss_fn(reconst, x, mu, logsigma)[source]#
Compute the ConvVAE loss function.
The loss combines:
1. Reconstruction loss (MSE) between input and reconstructed images
2. KL divergence between the learned latent distribution and the prior (standard normal)
The total loss is the sum of the two terms: reconstruction loss + KLD.
- Parameters:
reconst (Tensor) – Reconstructed images from the VAE decoder.
x (Tensor) – Original input images.
mu (Tensor) – Mean of the latent distribution.
logsigma (Tensor) – Log variance of the latent distribution.
- Returns:
Scalar tensor containing the total VAE loss.
- Return type:
Tensor
Example
>>> recon_x, mu, logsigma = vae(images)
>>> loss = conv_vae_loss_fn(recon_x, images, mu, logsigma)
>>> loss.backward()
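For reference, the KL term for a diagonal Gaussian posterior against a standard normal prior has the standard closed form below, summed over latent dimensions j. How logsigma maps onto it is an assumption: the parameter name suggests \(\log \sigma\) while the description says log variance, and if it holds \(\log \sigma\) then \(\log \sigma_j^2 = 2 \cdot \text{logsigma}_j\).

\[D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big) = -\frac{1}{2} \sum_j \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)\]

\[\mathcal{L} = \mathcal{L}_{\text{recon}} + D_{\mathrm{KL}}\]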
Gaussian Mixture Model (GMM) loss for MDRNN training.
This module provides the GMM loss function used in the Mixture Density Recurrent Neural Network (MDRNN) for world model training.
- world_models.losses.gmm_loss.gmm_loss(latent_next_obs, mus, sigmas, logpi, reduce=True)[source]#
Compute the negative log-likelihood of a batch under a Gaussian Mixture Model.
This function computes minus the log probability of the batch under the GMM model described by mus, sigmas, and pi:
\[p(x) = \sum_k \pi_k \cdot \mathcal{N}(x \mid \mu_k, \sigma_k)\]
This is the loss function used in the MDRNN paper for predicting the next latent state.
- Parameters:
latent_next_obs (Tensor) – (bs1, bs2, …, fs) Tensor containing the batch of target data.
mus (Tensor) – (bs1, bs2, …, gs, fs) Tensor of mixture means.
sigmas (Tensor) – (bs1, bs2, …, gs, fs) Tensor of mixture standard deviations.
logpi (Tensor) – (bs1, bs2, …, gs) Tensor of log mixture weights (log pi_k).
reduce (bool) – If True, mean over batch dimensions; otherwise return per-sample loss.
- Returns:
If reduce is True: scalar tensor with the mean negative log-likelihood. If reduce is False: tensor with per-sample negative log-likelihoods.
- Return type:
Tensor
- Reference:
Ha & Schmidhuber (2018). Recurrent World Models Facilitate Policy Evolution.
Example
>>> batch = torch.randn(32, 10, 10)
>>> mus = torch.randn(32, 10, 5, 10)
>>> sigmas = torch.randn(32, 10, 5, 10).exp()
>>> logpi = torch.randn(32, 10, 5).log_softmax(dim=-1)
>>> loss = gmm_loss(batch, mus, sigmas, logpi)
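To make the mixture formula concrete, here is a scalar, stdlib-only sketch of the negative log-likelihood for a single sample, using the log-sum-exp trick for numerical stability. It assumes diagonal Gaussians (independent feature dimensions) and is not the library's tensor implementation; the name gmm_nll_single is hypothetical.

```python
import math
from typing import List


def gmm_nll_single(
    x: List[float],
    mus: List[List[float]],
    sigmas: List[List[float]],
    logpi: List[float],
) -> float:
    """Negative log-likelihood of one sample under a diagonal GMM."""
    component_logs = []
    for mu_k, sigma_k, logpi_k in zip(mus, sigmas, logpi):
        # Log density of x under component k, summed over features.
        log_prob = sum(
            -0.5 * math.log(2 * math.pi) - math.log(s) - 0.5 * ((xi - m) / s) ** 2
            for xi, m, s in zip(x, mu_k, sigma_k)
        )
        component_logs.append(logpi_k + log_prob)
    # log-sum-exp over components, shifted by the max for stability.
    max_log = max(component_logs)
    log_sum = max_log + math.log(
        sum(math.exp(v - max_log) for v in component_logs)
    )
    return -log_sum
```

Working in log space avoids underflow when every component assigns the sample a very small density, which is common early in training.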
Training utilities for World Models.
This module provides utility classes for training neural networks including early stopping and learning rate scheduling.
- class world_models.utils.train_utils.EarlyStopping(mode='min', patience=10, threshold=0.0001, threshold_mode='rel')[source]#
Bases: object
Early stopping handler to stop training when validation metric stops improving.
This class monitors a validation metric and stops training when no improvement is seen for a specified number of epochs (patience). This helps prevent overfitting and reduces unnecessary computation.
- Parameters:
mode (str) – One of ‘min’ or ‘max’. In ‘min’ mode, training stops when the metric stops decreasing; in ‘max’ mode, when it stops increasing.
patience (int) – Number of epochs with no improvement after which to stop training.
threshold (float) – Minimum change to qualify as an improvement.
threshold_mode (str) – One of ‘rel’ or ‘abs’. In ‘rel’ mode, dynamic threshold is relative to best value; in ‘abs’ mode, it’s absolute.
- stop#
Property that returns True if training should stop.
Example
>>> early_stopping = EarlyStopping(mode='min', patience=10)
>>> for epoch in range(100):
...     val_loss = validate()
...     early_stopping.step(val_loss)
...     if early_stopping.stop:
...         print(f"Stopped at epoch {epoch}")
...         break
- step(metrics, epoch=None)[source]#
Update early stopping state with new metric value.
- Parameters:
metrics (float) – Current epoch’s metric value.
epoch (int | None) – Current epoch number. If None, auto-increments from last epoch.
- Return type:
None
- property stop: bool#
True if training should stop due to no improvement.
- Type:
bool
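The interaction between mode, threshold, and threshold_mode is easiest to see in code. The sketch below is a hypothetical re-implementation of the general patience technique, not this class: the name PatienceTracker, the exact comparison rules (modelled on the convention PyTorch uses for plateau detection), and the choice to stop once the bad-epoch count reaches patience are all assumptions.

```python
class PatienceTracker:
    """Illustrative patience-based stopping rule (not the library class)."""

    def __init__(self, mode: str = "min", patience: int = 10,
                 threshold: float = 1e-4, threshold_mode: str = "rel") -> None:
        self.mode = mode
        self.patience = patience
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = float("inf") if mode == "min" else float("-inf")
        self.num_bad_epochs = 0

    def _is_better(self, value: float) -> bool:
        if self.threshold_mode == "rel":
            # Relative: must beat the best value by a fraction of it.
            if self.mode == "min":
                return value < self.best * (1.0 - self.threshold)
            return value > self.best * (1.0 + self.threshold)
        # Absolute: must beat the best value by a fixed margin.
        if self.mode == "min":
            return value < self.best - self.threshold
        return value > self.best + self.threshold

    def step(self, value: float) -> None:
        if self._is_better(value):
            self.best = value
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

    @property
    def stop(self) -> bool:
        return self.num_bad_epochs >= self.patience
```

A relative threshold scales with the metric, so it suits losses that shrink by orders of magnitude; an absolute threshold is more predictable for metrics on a fixed scale such as accuracy.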
- class world_models.utils.train_utils.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, threshold=0.0001, threshold_mode='rel', min_lr=0, eps=1e-08)[source]#
Bases: object
Reduce learning rate when a metric stops improving.
This scheduler reduces the learning rate by a factor when a validation metric stops improving for a specified number of epochs. This helps models converge better by reducing the step size as they approach optimal weights.
- Parameters:
optimizer (Optimizer) – The PyTorch optimizer to adjust.
mode (str) – One of ‘min’ or ‘max’. In ‘min’ mode, lr is reduced when metric stops decreasing; in ‘max’ mode, when it stops increasing.
factor (float) – Factor by which to reduce the learning rate.
patience (int) – Number of epochs with no improvement after which to reduce lr.
threshold (float) – Minimum change to qualify as an improvement.
threshold_mode (str) – One of ‘rel’ or ‘abs’.
min_lr (float) – Minimum learning rate to reduce to.
eps (float) – Minimum decay for lr.
- lr#
Current learning rates for each parameter group.
Example
>>> optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
>>> scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
>>> for epoch in range(100):
...     train_loss = train()
...     val_loss = validate()
...     scheduler.step(val_loss)
- step(metrics, epoch=None)[source]#
Update learning rate based on metric value.
- Parameters:
metrics (float) – Current epoch’s metric value.
epoch (int | None) – Current epoch number. If None, auto-increments from last epoch.
- Return type:
None
- property lr: list#
Current learning rates for each parameter group.
- Type:
list
- world_models.utils.dreamer_utils.symlog(x)[source]#
Symmetric log transform used by Dreamer V2 for reward/value targets.
Defined as sign(x) * log(1 + |x|). This compresses large positive or negative values into a range that is easier to predict with a categorical distribution over a bounded set of buckets.
- Parameters:
x (Tensor)
- Return type:
Tensor
- world_models.utils.dreamer_utils.symexp(x)[source]#
Inverse of symlog(). Defined as sign(x) * (exp(|x|) - 1).
- Parameters:
x (Tensor)
- Return type:
Tensor
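The two formulas can be checked on plain floats. This scalar sketch mirrors the definitions given above using only the math module; the library functions operate on tensors, and the names symlog_scalar and symexp_scalar are hypothetical.

```python
import math


def symlog_scalar(x: float) -> float:
    """sign(x) * log(1 + |x|), computed with log1p for accuracy near zero."""
    return math.copysign(math.log1p(abs(x)), x)


def symexp_scalar(x: float) -> float:
    """sign(x) * (exp(|x|) - 1), the inverse of symlog_scalar."""
    return math.copysign(math.expm1(abs(x)), x)
```

Because the transform is odd and monotonic, a round trip through symlog_scalar and symexp_scalar recovers the input up to floating-point error, for positive and negative values alike.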
- class world_models.utils.dreamer_utils.TwoHotEncoder(num_buckets=255, symlog_range=10.0)[source]#
Bases: object
Two-hot encoding for symlog targets (Dreamer V2 reward/value heads).
A target value is softly assigned to the two nearest buckets on a uniform grid spanning [-symlog_range, symlog_range]. The categorical logits produced by a network can then be decoded back into a real value by computing the expected bucket center.
- Parameters:
num_buckets (int) – Number of buckets in the categorical distribution.
symlog_range (float) – Maximum absolute value (in symlog space) covered by the grid. Values outside the range are clipped to the boundary buckets.
- register_buffers()[source]#
Allocate the bucket-center buffer on CPU. Use to() to move.
- Return type:
None
- encode(target)[source]#
Two-hot encode a real-valued target into soft bucket probabilities.
- Parameters:
target (Tensor) – Tensor of arbitrary shape containing real-valued targets.
- Returns:
Tensor with an extra final dimension of size num_buckets containing the soft two-hot distribution. The encoding assumes the target is already in symlog space, matching Dreamer V2.
- Return type:
Tensor
- decode(logits)[source]#
Decode categorical logits into the expected real-valued prediction.
The logits are first softmaxed and then combined with the bucket centers. The output is passed through symexp() to invert the symlog transform.
- Parameters:
logits (Tensor) – Tensor with a final dimension of num_buckets.
- Returns:
Tensor with the same shape as logits minus the last dimension.
- Return type:
Tensor
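A list-based sketch may help show why two-hot encoding is lossless inside the grid: the expected bucket center of the encoded distribution equals the (clipped) target. The helper names below are hypothetical, the code works on single floats rather than tensors, and it stays in symlog space throughout, so the softmax and symexp steps of decode are left out.

```python
from typing import List


def bucket_centers(num_buckets: int = 255, symlog_range: float = 10.0) -> List[float]:
    """Uniform grid of bucket centers on [-symlog_range, symlog_range]."""
    step = 2 * symlog_range / (num_buckets - 1)
    return [-symlog_range + i * step for i in range(num_buckets)]


def two_hot(value: float, centers: List[float]) -> List[float]:
    """Spread probability mass over the two buckets nearest to value."""
    clipped = min(max(value, centers[0]), centers[-1])
    step = centers[1] - centers[0]
    # Index of the bucket at or below the value, kept inside the grid.
    lower = min(int((clipped - centers[0]) / step), len(centers) - 2)
    upper_weight = (clipped - centers[lower]) / step
    upper_weight = min(max(upper_weight, 0.0), 1.0)  # guard against rounding
    probs = [0.0] * len(centers)
    probs[lower] = 1.0 - upper_weight
    probs[lower + 1] = upper_weight
    return probs


def expected_value(probs: List[float], centers: List[float]) -> float:
    """Decode a bucket distribution as its expected bucket center."""
    return sum(p * c for p, c in zip(probs, centers))
```

On a five-bucket grid spanning [-2, 2], the value 0.25 puts weight 0.75 on the center at 0 and 0.25 on the center at 1, and decoding returns 0.25 again.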
- world_models.utils.dreamer_utils.get_parameters(modules)[source]#
Given an iterable of torch modules, return a list of their parameters.
- Parameters:
modules (Iterable[Module])
- Return type:
list[Parameter]
- class world_models.utils.dreamer_utils.FreezeParameters(modules)[source]#
Bases: object
Context manager that temporarily disables gradients for given modules.
Useful during imagination or target-network forward passes where gradients through certain components should be blocked for speed and correctness.
- Parameters:
modules (Iterable[Module])
- class world_models.utils.dreamer_utils.Logger(log_dir, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', video_format='gif', video_fps=20, enable_tensorboard=False, enable_console=True, enable_jsonl=True, jsonl_filename='metrics.jsonl')[source]#
Bases: object
Experiment logger for scalars and GIF rollouts using WandB.
Provides helpers to write scalar metrics, dump pickle snapshots, and save video previews during Dreamer training/evaluation.
- Parameters:
log_dir (str)
enable_wandb (bool)
wandb_api_key (str)
wandb_project (str)
wandb_entity (str)
video_format (str)
video_fps (int)
enable_tensorboard (bool)
enable_console (bool)
enable_jsonl (bool)
jsonl_filename (str)
- log_scalar(scalar, name, step_)[source]#
- Parameters:
scalar (Any)
name (str)
step_ (int)
- Return type:
None
- log_scalars(scalar_dict, step)[source]#
- Parameters:
scalar_dict (dict[str, Any])
step (int)
- Return type:
None
- log_videos(videos, step, max_videos_to_save=1, fps=None, video_title='video')[source]#
- Parameters:
videos (Any)
step (int)
max_videos_to_save (int)
fps (int | None)
video_title (str)
- Return type:
None
- world_models.utils.dreamer_utils.compute_return(rewards, values, discounts, td_lam, last_value)[source]#
Compute TD(lambda) returns from imagined rewards, values, and discounts.
Implements backward recursion used by Dreamer actor/value objectives.
- Parameters:
rewards (Tensor)
values (Tensor)
discounts (Tensor)
td_lam (float)
last_value (Tensor)
- Return type:
Tensor
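The backward recursion can be written out on plain Python lists. This is a sketch of the standard TD(lambda) return under one common indexing convention (values[t] is the value at step t and last_value bootstraps the final step); the library's tensor layout and indexing may differ, and the name lambda_returns is hypothetical.

```python
from typing import List


def lambda_returns(
    rewards: List[float],
    values: List[float],
    discounts: List[float],
    td_lam: float,
    last_value: float,
) -> List[float]:
    """TD(lambda) returns computed by backward recursion over time."""
    # Value of the state reached after each step; the final step
    # bootstraps from last_value.
    next_values = values[1:] + [last_value]
    returns = [0.0] * len(rewards)
    running = last_value
    for t in reversed(range(len(rewards))):
        # Blend the one-step bootstrap with the longer return so far.
        running = rewards[t] + discounts[t] * (
            (1 - td_lam) * next_values[t] + td_lam * running
        )
        returns[t] = running
    return returns
```

Setting td_lam to 0 yields one-step TD targets, while td_lam equal to 1 yields the discounted Monte Carlo return bootstrapped from last_value; values in between trade bias against variance.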
- world_models.utils.jepa_utils.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2, b=2.0)[source]#
Initialize a tensor in-place from a truncated normal distribution.
Values are sampled from N(mean, std) and clipped to [a, b].
- Parameters:
tensor (Tensor)
mean (float)
std (float)
a (float)
b (float)
- Return type:
Tensor
- world_models.utils.jepa_utils.repeat_interleave_batch(x, B, repeat)[source]#
Repeat each batch chunk multiple times while preserving chunk ordering.
Used in JEPA masking code to align context and target token batches.
- Parameters:
x (Tensor)
B (int)
repeat (int)
- Return type:
Tensor
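The reordering is easier to see on a small list than on tensors. The sketch below applies the same idea to a Python list: the input is cut into chunks of B items and each chunk is emitted repeat times before moving on to the next. The function name is hypothetical and the library version operates on tensors.

```python
from typing import List, TypeVar

T = TypeVar("T")


def repeat_interleave_chunks(items: List[T], batch_size: int, repeat: int) -> List[T]:
    """Repeat each batch-sized chunk `repeat` times, keeping chunk order."""
    num_chunks = len(items) // batch_size
    out: List[T] = []
    for i in range(num_chunks):
        chunk = items[i * batch_size:(i + 1) * batch_size]
        for _ in range(repeat):
            out.extend(chunk)
    return out
```

With items a, b, c, d, a batch size of 2, and repeat set to 2, the result is a, b, a, b, c, d, c, d.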
- class world_models.utils.jepa_utils.WarmupCosineSchedule(optimizer, warmup_steps, start_lr, ref_lr, T_max, last_epoch=-1, final_lr=0.0)[source]#
Bases: object
Learning-rate schedule with linear warmup followed by cosine decay.
Updates optimizer parameter-group LRs on each call to step().
- Parameters:
optimizer (Optimizer)
warmup_steps (int)
start_lr (float)
ref_lr (float)
T_max (int)
last_epoch (int)
final_lr (float)
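The shape of such a schedule can be sketched as a pure function of the step index. This is an illustration of linear warmup followed by cosine decay, not the class's code: the function name, the total_steps argument, and the assumption that the cosine phase spans the steps remaining after warmup are all choices made for this example.

```python
import math


def warmup_cosine_lr(
    step: int,
    warmup_steps: int,
    start_lr: float,
    ref_lr: float,
    total_steps: int,
    final_lr: float = 0.0,
) -> float:
    """Learning rate at `step` for linear warmup then cosine decay."""
    if step < warmup_steps:
        # Linear ramp from start_lr up to ref_lr.
        progress = step / max(1, warmup_steps)
        return start_lr + progress * (ref_lr - start_lr)
    # Cosine decay from ref_lr down to final_lr over the remaining steps.
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return final_lr + (ref_lr - final_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
```

The rate starts at start_lr, peaks at ref_lr when warmup ends, and settles at final_lr once total_steps is reached.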
- class world_models.utils.jepa_utils.CosineWDSchedule(optimizer, ref_wd, T_max, final_wd=0.0)[source]#
Bases: object
Cosine scheduler for optimizer weight decay values.
Skips parameter groups flagged with WD_exclude to keep bias/norm decay at zero.
- Parameters:
optimizer (Optimizer)
ref_wd (float)
T_max (int)
final_wd (float)
- world_models.utils.jepa_utils.gpu_timer(closure, log_timings=True)[source]#
Measure CUDA execution time for a closure and return (result, elapsed_ms).
Falls back to -1 elapsed time when CUDA timing is unavailable.
- Parameters:
closure (Any)
log_timings (bool)
- Return type:
Tuple[Any, float]
- class world_models.utils.jepa_utils.CSVLogger(fname, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', *argv)[source]#
Bases: object
Lightweight CSV logger with per-column printf-style formatting and WandB support.
- Parameters:
fname (str)
enable_wandb (bool)
wandb_api_key (str)
wandb_project (str)
wandb_entity (str)
argv (Any)
- class world_models.utils.jepa_utils.AverageMeter[source]#
Bases: object
Track running statistics (val, avg, min, max, sum, count) for metrics.
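The tracked fields suggest a meter along the following lines. This is a sketch only: the class name AverageMeterSketch and the update(val, n) method signature are assumptions, since the reference lists the statistics but not the methods.

```python
class AverageMeterSketch:
    """Keeps the latest value plus running avg, min, max, sum, and count."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.min = float("inf")
        self.max = float("-inf")
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # `n` is the number of samples that `val` summarizes (assumed semantics).
        self.val = val
        self.min = min(self.min, val)
        self.max = max(self.max, val)
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = AverageMeterSketch()
for loss in (2.0, 4.0):
    meter.update(loss)
```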
- world_models.utils.jepa_utils.grad_logger(named_params)[source]#
Aggregate gradient norm statistics over model parameters for logging.
Also exposes first/last qkv-layer gradient norms when available.
- Parameters:
named_params (Any)
- Return type:
- world_models.utils.jepa_utils.init_distributed(port=40112, rank_and_world_size=(None, None))[source]#
Initialize torch distributed process groups when environment supports it.
Returns (world_size, rank) and gracefully falls back to single-process mode.
- Parameters:
port (int)
rank_and_world_size (tuple)
- Return type:
Tuple[int, int]
- class world_models.utils.jepa_utils.AllGather(*args, **kwargs)[source]#
Bases: Function
Autograd-aware all-gather operation across distributed workers.
Forward concatenates worker tensors; backward reduces and slices gradients.
- class world_models.utils.jepa_utils.AllReduceSum(*args, **kwargs)[source]#
Bases: Function
Autograd function that sums tensors across distributed workers in forward pass.
- class world_models.utils.jepa_utils.AllReduce(*args, **kwargs)[source]#
Bases: Function
Autograd function that all-reduces and averages tensors across workers.
Used to synchronize scalar losses for consistent distributed logging/training.
- world_models.utils.data_utils.create_efficient_dataloader(dataset, batch_size, num_workers=None, pin_memory=True, prefetch_factor=2, persistent_workers=True)[source]#
Create a memory-efficient and fast DataLoader.
- Parameters:
dataset (Dataset)
batch_size (int)
num_workers (int | None)
pin_memory (bool)
prefetch_factor (int)
persistent_workers (bool)
- Return type:
DataLoader
- world_models.utils.data_utils.prefetch_iterator(iterator, buffer_size=3)[source]#
Add prefetching to any iterator.
- Parameters:
iterator (Iterator)
buffer_size (int)
- Return type:
Iterator
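Prefetching an iterator is commonly done with a background thread that fills a bounded queue. The sketch below shows that pattern with the standard library; it is not taken from the library's source, and the name prefetch_sketch is hypothetical.

```python
import queue
import threading


def prefetch_sketch(iterator, buffer_size=3):
    """Yield items from `iterator` while a worker thread reads ahead."""
    buffer = queue.Queue(maxsize=buffer_size)
    sentinel = object()  # marks the end of the stream

    def producer():
        try:
            for item in iterator:
                buffer.put(item)  # blocks while the buffer is full
        finally:
            # Always signal completion, even if the source iterator raised.
            buffer.put(sentinel)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buffer.get()
        if item is sentinel:
            return
        yield item


batches = list(prefetch_sketch(iter(range(5)), buffer_size=2))
```

This simplified version does not re-raise producer errors in the consumer, and a consumer that stops early leaves the daemon thread blocked until the process exits.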
- world_models.utils.jit_utils.jit_compile_function(func)[source]#
JIT compile a function for performance.
- Parameters:
func (Callable)
- Return type:
Callable
- world_models.utils.jit_utils.jit_compile_module(module)[source]#
JIT compile a PyTorch module.
- Parameters:
module (Module)
- Return type:
Module
- world_models.utils.memory_utils.apply_gradient_checkpointing(model, checkpoint_ratio=0.5)[source]#
Apply gradient checkpointing to reduce memory usage during training.
- Parameters:
model (Module)
checkpoint_ratio (float)
- Return type:
None
- world_models.utils.memory_utils.enable_mixed_precision(model, scaler=None)[source]#
Enable mixed precision training.
- Parameters:
model (Module)
scaler (GradScaler | None)
- Return type:
GradScaler
- world_models.utils.memory_utils.optimize_memory_efficient_ops()[source]#
Configure PyTorch for memory-efficient operations.
- Return type:
None
Logging, metrics, and numerical-safety helpers for torchwm.
- world_models.utils.logging_utils.get_package_logger(name=None)[source]#
Return a logger under the world_models package namespace.
- Parameters:
name (str | None)
- Return type:
Logger
- world_models.utils.logging_utils.setup_logging(name='world_models', level='INFO', log_file=None, fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s')[source]#
Set up structured package logging with optional file output.
- Parameters:
name (str) – Logger name to configure. Defaults to the package logger.
level (str | int) – Logging level name or numeric value.
log_file (str | None) – Optional file path for a file handler.
fmt (str) – logging.Formatter format string.
- Return type:
Logger
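For readers unfamiliar with the standard logging module, a helper with this signature could be built roughly as follows. This is a standard-library sketch, not the package's implementation; the name setup_logging_sketch is hypothetical and handler details in the real function may differ.

```python
import logging

DEFAULT_FMT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"


def setup_logging_sketch(name="world_models", level="INFO", log_file=None, fmt=DEFAULT_FMT):
    """Configure a named logger with a console handler and optional file handler."""
    logger = logging.getLogger(name)
    logger.setLevel(level)  # accepts a level name or a numeric value
    formatter = logging.Formatter(fmt)
    handlers = [logging.StreamHandler()]
    if log_file is not None:
        handlers.append(logging.FileHandler(log_file))
    for handler in handlers:
        handler.setFormatter(formatter)
        # Note: calling this twice for the same name adds duplicate handlers.
        logger.addHandler(handler)
    return logger


logger = setup_logging_sketch(name="sketch_logger", level="DEBUG")
logger.debug("logging configured")
```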
- class world_models.utils.logging_utils.MetricsLogger(log_dir, *, logger=None, enable_console=True, enable_jsonl=True, jsonl_filename='metrics.jsonl', enable_tensorboard=False, enable_wandb=False, wandb_api_key='', wandb_project='torchwm', wandb_entity='', run_name=None)[source]#
Bases: object
Fan-out metric logger for console, JSONL, TensorBoard, and W&B.
JSONL output is enabled by default because it is dependency-free and easy to reload for offline plots. TensorBoard and W&B are optional and activated only when requested and available/configured.
- Parameters:
log_dir (str)
logger (logging.Logger | None)
enable_console (bool)
enable_jsonl (bool)
jsonl_filename (str)
enable_tensorboard (bool)
enable_wandb (bool)
wandb_api_key (str)
wandb_project (str)
wandb_entity (str)
run_name (str | None)
- log(metrics, step, prefix=None)[source]#
Log scalar metrics to every enabled sink.
- Parameters:
metrics (Mapping[str, Any])
step (int)
prefix (str | None)
- Return type:
dict[str, Any]
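The JSONL sink is described as dependency-free and easy to reload, which can be illustrated with the standard library alone. In the sketch below the record layout (a "step" key, and a "/" separator between prefix and metric name) is an assumption, as are the helper names log_jsonl and load_jsonl; the actual file format written by MetricsLogger may differ.

```python
import json
import os
import tempfile


def log_jsonl(path, metrics, step, prefix=None):
    """Append one JSON object per call; returns the record that was written."""
    record = {"step": step}
    for key, value in metrics.items():
        name = f"{prefix}/{key}" if prefix else key
        record[name] = float(value)
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")
    return record


def load_jsonl(path):
    """Reload every logged record for offline plotting."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "metrics.jsonl")
    log_jsonl(path, {"loss": 1.5}, step=0, prefix="train")
    log_jsonl(path, {"loss": 1.25}, step=1, prefix="train")
    records = load_jsonl(path)
```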
- world_models.utils.logging_utils.collect_system_stats(device=None)[source]#
Collect CPU/GPU memory and CUDA utilization counters when available.
- Parameters:
device (device | str | None)
- Return type:
dict[str, float]
- world_models.utils.logging_utils.assert_finite_values(value, name='value')[source]#
Raise FloatingPointError if any tensor contains NaN or Inf.
- Parameters:
value (Any)
name (str)
- Return type:
Any
- world_models.utils.logging_utils.assert_finite(fn)[source]#
Decorator that validates tensor outputs from loss functions are finite.
- Parameters:
fn (Any)
- Return type:
Any
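The check-then-decorate pattern behind these two helpers can be sketched with Python floats standing in for tensors. The names assert_finite_values_sketch and assert_finite_sketch are hypothetical, and the traversal of dicts, lists, and tuples is an assumption about which containers are inspected.

```python
import functools
import math


def assert_finite_values_sketch(value, name="value"):
    """Raise FloatingPointError if any float inside `value` is NaN or Inf."""
    if isinstance(value, dict):
        for key, item in value.items():
            assert_finite_values_sketch(item, name=f"{name}[{key!r}]")
    elif isinstance(value, (list, tuple)):
        for index, item in enumerate(value):
            assert_finite_values_sketch(item, name=f"{name}[{index}]")
    elif isinstance(value, float) and not math.isfinite(value):
        raise FloatingPointError(f"{name} is not finite: {value}")
    return value


def assert_finite_sketch(fn):
    """Decorator that validates the wrapped function's return value."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return assert_finite_values_sketch(fn(*args, **kwargs), name=fn.__name__)
    return wrapper


@assert_finite_sketch
def good_loss():
    return {"loss": 0.25}


@assert_finite_sketch
def bad_loss():
    return {"loss": float("nan")}
```

Calling good_loss() returns its dictionary unchanged, while bad_loss() raises FloatingPointError naming the offending entry.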
- world_models.utils.utils.load_yml_config(path)[source]#
- Parameters:
path (str)
- Return type:
AttrDict | None
- world_models.utils.utils.to_tensor_obs(image)[source]#
Converts the input NumPy image to a channel-first 64x64 torch image tensor.
- Parameters:
image (ndarray)
- Return type:
Tensor
- world_models.utils.utils.postprocess_img(image, depth)[source]#
Postprocess an image observation for storage. Converts a float32 numpy array in [-0.5, 0.5] to a uint8 numpy array in [0, 255].
- Parameters:
image (ndarray)
depth (int)
- Return type:
ndarray
- world_models.utils.utils.preprocess_img(image, depth)[source]#
Preprocesses an observation in place, mapping a float32 Tensor from [0, 255] to [-0.5, 0.5]. Also adds noise to the observations.
- Parameters:
image (Tensor)
depth (int)
- Return type:
None
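The value ranges given for postprocess_img and preprocess_img can be illustrated on a single pixel. The sketch below assumes the bit-depth quantization used in PlaNet-style pipelines, where depth is the number of bits kept per channel; whether this library applies exactly these formulas is an assumption, and the noise term is omitted. The scalar helper names are hypothetical.

```python
import math


def preprocess_pixel(value, depth):
    """Map an 8-bit intensity in [0, 255] to roughly [-0.5, 0.5)."""
    # Quantize to `depth` bits, then scale to [0, 1) and center on zero.
    quantized = math.floor(value / 2 ** (8 - depth))
    return quantized / 2 ** depth - 0.5


def postprocess_pixel(value, depth):
    """Map a centered float back to an 8-bit intensity in [0, 255]."""
    restored = math.floor((value + 0.5) * 2 ** depth) * 2 ** (8 - depth)
    return int(min(max(restored, 0), 255))


centered = preprocess_pixel(255, depth=5)
stored = postprocess_pixel(centered, depth=5)
```

Because of the quantization step the round trip is lossy: with depth=5, an input of 255 comes back as 248.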
- world_models.utils.utils.bottle(func, *tensors)[source]#
Evaluates func, which operates on inputs of shape N x D, on inputs of shape N x T x D.
- Parameters:
func (Any)
tensors (Tensor)
- Return type:
Tensor
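The usual way to apply an N x D function to N x T x D inputs is to merge the first two dimensions, call the function, and split them again. The sketch below shows that reshaping on nested Python lists; the name bottle_sketch is hypothetical, and it is an assumption that the library follows the same merge-and-split approach on tensors.

```python
def bottle_sketch(func, *sequences):
    """Apply a row-wise func to inputs shaped [N][T][D] via flattening."""
    n = len(sequences[0])
    t = len(sequences[0][0])
    # Merge the batch (N) and time (T) dimensions into one list of N*T rows.
    flat_inputs = [[row for batch in seq for row in batch] for seq in sequences]
    flat_output = func(*flat_inputs)
    # Split the N*T output rows back into N groups of T.
    return [flat_output[i * t:(i + 1) * t] for i in range(n)]


def row_sums(rows):
    return [[sum(row)] for row in rows]


x = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # N=2, T=2, D=2
y = bottle_sketch(row_sums, x)
```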
- world_models.utils.utils.get_combined_params(*models)[source]#
Returns the combined parameter list of all the models given as input.
- Parameters:
models (Any)
- Return type:
list
- world_models.utils.utils.save_video(frames, path, name)[source]#
Saves a video containing frames.
- Accepts frames in either:
(T, C, H, W) float in [0,1]
(T, H, W, C) float in [0,1]
Produces {path}/{name}.mp4 and a debug PNG {path}/{name}_debug_frame.png with per-channel statistics printed to stdout.
- Parameters:
frames (Any)
path (str)
name (str)
- Return type:
str
- world_models.utils.utils.combine_videos(video_dir, output_name='combined.mp4', pattern='vid_*.mp4', fps=25, resize=True)[source]#
Combine all videos matching pattern in video_dir into a single MP4 file. Returns the output filepath (string).
Example
combine_videos("results/planet", output_name="all_training.mp4")
- Parameters:
video_dir (str)
output_name (str)
pattern (str)
fps (int)
resize (bool)
- Return type:
str
- world_models.utils.utils.ensure_results_dir_exists(results_dir)[source]#
Simple helper to validate a results directory exists. Raises FileNotFoundError if not present.
- Parameters:
results_dir (str)
- Return type:
None
- world_models.utils.utils.save_frames(target, pred_prior, pred_posterior, name, n_rows=5)[source]#
Save side-by-side target, prior-prediction, and posterior-prediction frames.
The function accepts tensors with optional time dimension and writes a PNG grid to {name}.png. Spatial sizes are aligned per timestep before concatenation and values are normalized to [0, 1] when needed.
- Parameters:
target (Tensor)
pred_prior (Tensor)
pred_posterior (Tensor)
name (str)
n_rows (int)
- Return type:
None
- world_models.utils.utils.get_mask(tensor, lengths)[source]#
Build a batch-first validity mask from sequence lengths.
tensor may be a tensor/array with shape (N, T, ...) or (N,). The returned mask marks valid timesteps with ones up to each element in lengths and preserves device/dtype conventions from the input.
- Parameters:
tensor (Any)
lengths (Any)
- Return type:
Tensor
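A batch-first validity mask of this kind can be pictured with nested lists. The sketch below takes the number of timesteps directly instead of a reference tensor, which is a simplification; the name get_mask_sketch is hypothetical.

```python
def get_mask_sketch(num_steps, lengths):
    """Return an N x T mask with ones for the first lengths[i] steps of row i."""
    return [
        [1.0 if t < length else 0.0 for t in range(num_steps)]
        for length in lengths
    ]


mask = get_mask_sketch(num_steps=4, lengths=[2, 4])
# mask == [[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]
```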
- world_models.utils.utils.load_memory(path, device, *, trusted=False)[source]#
Loads an experience replay buffer.
Pickle can execute arbitrary code during unrestricted deserialization, so user-supplied replay buffers are always loaded with a restricted unpickler that only allows the replay buffer classes and numpy containers required by historical buffers. The trusted argument is retained for backwards compatibility, but it no longer enables unrestricted pickle loading.
Converts legacy list/.data formats into the current Memory(episodes) object.
- Parameters:
path (str)
device (device)
trusted (bool)
- Return type:
Any
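The restricted-unpickler approach mentioned above follows a pattern from the standard pickle documentation: subclass pickle.Unpickler and override find_class so that only allow-listed globals can be resolved. The sketch below shows the pattern with a placeholder allow-list; the entries in ALLOWED_GLOBALS and the helper names are hypothetical and are not the library's actual list of replay buffer and numpy classes.

```python
import io
import pickle

# Hypothetical allow-list; a real one would name the specific classes needed.
ALLOWED_GLOBALS = {
    ("collections", "OrderedDict"),
    ("builtins", "set"),
}


class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Called whenever the pickle stream references a global by name.
        if (module, name) in ALLOWED_GLOBALS:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data):
    """Deserialize bytes while refusing any global outside the allow-list."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


episodes = restricted_loads(pickle.dumps({"rewards": [1.0, 0.5]}))
```

Plain containers such as dicts, lists, and numbers load without consulting find_class, whereas a stream that references any other callable is rejected with UnpicklingError.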
- world_models.utils.utils.flatten_dict(data, sep='.', prefix='')[source]#
Flattens a nested dict into a single-level dict.
Example
{'a': 2, 'b': {'c': 20}} -> {'a': 2, 'b.c': 20}
- Parameters:
data (dict)
sep (str)
prefix (str)
- Return type:
dict
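The documented example can be reproduced with a short recursive function. This is a sketch consistent with the example above rather than the library's source; the name flatten_dict_sketch is hypothetical, and treating prefix as the key path accumulated so far is an assumption.

```python
def flatten_dict_sketch(data, sep=".", prefix=""):
    """Flatten nested dicts, joining nested keys with `sep`."""
    flat = {}
    for key, value in data.items():
        full_key = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            # Recurse, carrying the joined key as the new prefix.
            flat.update(flatten_dict_sketch(value, sep=sep, prefix=full_key))
        else:
            flat[full_key] = value
    return flat


flat = flatten_dict_sketch({"a": 2, "b": {"c": 20}})
# flat == {"a": 2, "b.c": 20}
```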
- world_models.utils.utils.normalize_frames_for_saving(frames)[source]#
Ensure frames are in shape (T, H, W, 3) with float values in [0,1]. Handles inputs in (T, C, H, W) or (T, H, W, C), repeats single-channel -> RGB, drops alpha if present, and maps [-0.5,0.5] -> [0,1] when detected.
- Parameters:
frames (Any)
- Return type:
ndarray
- class world_models.utils.utils.TensorBoardMetrics(path)[source]#
Bases: object
Plots and (optionally) stores metrics for an experiment.
- Parameters:
path (str)
- world_models.utils.utils.apply_model(model, inputs, ignore_dim=None)[source]#
Placeholder helper for generic model application across input structures.
Currently not implemented; kept as an extension hook for future utility code.
- Parameters:
model (Any)
inputs (Any)
ignore_dim (Any)
- Return type:
None
- world_models.utils.utils.plot_metrics(metrics, path, prefix)[source]#
Render and save line plots for each metric series in a dictionary.
- Parameters:
metrics (dict)
path (str)
prefix (str)
- Return type:
None
- world_models.utils.utils.lineplot(xs, ys, title, path='', xaxis='episode')[source]#
Create a Plotly line plot for scalar, dict, or ensemble-series data.
Supports uncertainty-band plotting when ys is a 2D array.
- Parameters:
xs (ndarray | list)
ys (Any)
title (str)
path (str)
xaxis (str)
- Return type:
None
- class world_models.utils.utils.TorchImageEnvWrapper(env, bit_depth, observation_shape=None, act_rep=2)[source]#
Bases: object
Torch environment wrapper that wraps a gym env and exchanges actions and observations as Tensors. Observations are returned in image form.
- Parameters:
env (Any)
bit_depth (int)
observation_shape (Any)
act_rep (int)
- property observation_size: tuple[int, int, int]#
- property action_size: int#
- property max_episode_steps: int#
Return environment max episode steps (compatible with TimeLimit/spec).
- world_models.utils.utils.apply_masks(x, masks)[source]#
Gather token subsets from patch sequences using index masks.
Each mask selects token positions from x; selected groups are concatenated along the batch dimension.
- Parameters:
x (Tensor)
masks (list[Tensor])
- Return type:
Tensor
- world_models.utils.utils.visualize_latent_tsne(latents, labels=None, save_path=None, perplexity=30)[source]#
Visualize latent representations using t-SNE.
- Parameters:
latents (Tensor | ndarray) – torch.Tensor of shape (N, D) or numpy array
labels (ndarray | None) – optional list or array of labels for coloring
save_path (str | None) – path to save the plot (HTML for plotly)
perplexity (int) – t-SNE perplexity parameter
- Return type:
Figure
- world_models.utils.utils.visualize_latent_umap(latents, labels=None, save_path=None, n_neighbors=15)[source]#
Visualize latent representations using UMAP.
- Parameters:
latents (Tensor | ndarray) – torch.Tensor of shape (N, D) or numpy array
labels (ndarray | None) – optional list or array of labels for coloring
save_path (str | None) – path to save the plot (HTML for plotly)
n_neighbors (int) – UMAP n_neighbors parameter
- Return type:
Figure
- class world_models.utils.utils.StreamingVideoWriter(path, fps=20, frame_shape=None, format='mp4')[source]#
Bases: object
Streaming video writer that saves frames to disk in real time.
- Parameters:
path (str) – output video file path
fps (int) – frames per second
frame_shape (Any) – (height, width) of frames
format (str) – ‘mp4’ or ‘avi’