API Reference#

This reference is generated from source docstrings and grouped by subsystem.

Core Public APIs#

class world_models.configs.DreamerConfig[source]#

Bases: object

Configuration container for Dreamer training, evaluation, and environment setup.

This class centralizes environment backend selection (DMC/Gym/Unity), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.

class world_models.configs.JEPAConfig[source]#

Bases: object

Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.

to_dict()[source]#
Return type:

Dict[str, Dict[str, Any]]

class world_models.configs.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#

Bases: object

Default configuration values for Diffusion Transformer (DiT) training.

The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.

Parameters:
  • DATASET (str)

  • BATCH (int)

  • EPOCHS (int)

  • LR (float)

  • IMG_SIZE (int)

  • CHANNELS (int)

  • PATCH (int)

  • WIDTH (int)

  • DEPTH (int)

  • HEADS (int)

  • DROP (float)

  • BETA_START (float)

  • BETA_END (float)

  • TIMESTEPS (int)

  • EMA (bool)

  • EMA_DECAY (float)

  • WORKDIR (str)

  • ROOT_PATH (str)

DATASET: str = 'CIFAR10'#
BATCH: int = 128#
EPOCHS: int = 3#
LR: float = 0.0002#
IMG_SIZE: int = 32#
CHANNELS: int = 3#
PATCH: int = 4#
WIDTH: int = 384#
DEPTH: int = 6#
HEADS: int = 6#
DROP: float = 0.1#
BETA_START: float = 0.0001#
BETA_END: float = 0.02#
TIMESTEPS: int = 1000#
EMA: bool = True#
EMA_DECAY: float = 0.999#
WORKDIR: str = './dit_demo'#
ROOT_PATH: str = './data'#
world_models.configs.get_dit_config(**overrides)[source]#

Returns a DiTConfig instance with default values overridden by the provided keyword arguments.

Example usage:

cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)

world_models.envs.make_atari_env(env_id, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, **kwargs)[source]#

Create any Atari environment from the Arcade Learning Environment (ALE).

Parameters:
  • env_id (str) – The id of the Atari environment to create.

  • obs_type (str) – The type of observation to return (“rgb” or “ram”).

  • frameskip (int) – The number of frames to skip between actions.

  • repeat_action_probability (float) – The probability of repeating the last action.

  • full_action_space (bool) – Whether to use the full action space.

  • max_episode_steps (Optional[int]) – Maximum number of steps per episode.

  • **kwargs – Additional keyword arguments for environment configuration.

Returns:

The created Atari environment.

Return type:

gym.Env

world_models.envs.list_available_atari_envs()[source]#

Get a list of all available Atari environments in the Arcade Learning Environment (ALE).

Returns:

List of available Atari environment IDs.

Return type:

list[str]

world_models.envs.make_atari_vector_env(game, num_envs, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, seed=None, **kwargs)[source]#

Create vectorized Atari environments using ALE’s native AtariVectorEnv.

Parameters:
  • game (str) – The name of the Atari game (e.g., “pong”, “breakout”).

  • num_envs (int) – Number of parallel environments.

  • obs_type (str) – The type of observation to return (“rgb” or “ram”).

  • frameskip (int) – The number of frames to skip between actions.

  • repeat_action_probability (float) – The probability of repeating the last action.

  • full_action_space (bool) – Whether to use the full action space.

  • max_episode_steps (Optional[int]) – Maximum number of steps per episode.

  • seed (Optional[int]) – Random seed for reproducibility.

  • **kwargs – Additional keyword arguments for environment configuration.

Returns:

The vectorized Atari environment.

Return type:

AtariVectorEnv

world_models.envs.make_humanoid_env(version='v4', xml_file='humanoid.xml', forward_reward_weight=1.25, ctrl_cost_weight=0.1, contact_cost_weight=5e-07, healthy_reward=5.0, terminate_when_unhealthy=True, healthy_z_range=(1.0, 2.0), reset_noise_scale=0.01, exclude_current_positions_from_observation=True, include_cinert_in_observation=True, include_cvel_in_observation=True, include_qfrc_actuator_in_observation=True, include_cfrc_ext_in_observation=True)[source]#

Create a Humanoid environment with customizable parameters.

Parameters:
  • version (str) – The version of the Humanoid environment (e.g., “v4”).

  • xml_file (str) – The XML file defining the Humanoid model.

  • forward_reward_weight (float) – Weight for the forward reward.

  • ctrl_cost_weight (float) – Weight for the control cost.

  • contact_cost_weight (float) – Weight for the contact cost.

  • healthy_reward (float) – Reward for being in a healthy state.

  • terminate_when_unhealthy (bool) – Whether to terminate the episode when unhealthy.

  • healthy_z_range (Tuple[float, float]) – The range of z-values considered healthy.

  • reset_noise_scale (float) – Scale of noise added during environment reset.

  • exclude_current_positions_from_observation (bool) – Whether to exclude current positions from observations.

  • include_cinert_in_observation (bool) – Whether to include inertia in observations.

  • include_cvel_in_observation (bool) – Whether to include velocity in observations.

  • include_qfrc_actuator_in_observation (bool) – Whether to include actuator forces in observations.

  • include_cfrc_ext_in_observation (bool) – Whether to include external forces in observations.

Returns:

The created Humanoid environment.

Return type:

gym.Env

world_models.envs.make_half_cheetah_env(version='v4', forward_reward_weight=0.1, reset_noise_scale=0.1, exclude_current_positions_from_observation=True, render_mode='rgb_array')[source]#

Create a HalfCheetah environment with customizable parameters.

Parameters:
  • version (str) – The version of the HalfCheetah environment (e.g., “v4”).

  • forward_reward_weight (float) – Weight for the forward reward.

  • reset_noise_scale (float) – Scale of noise added during environment reset.

  • exclude_current_positions_from_observation (bool) – Whether to exclude current positions from observations.

  • render_mode (str) – The render mode for the environment.

Returns:

The created HalfCheetah environment.

Return type:

gym.Env

class world_models.envs.GymImageEnv(env, seed=0, size=(64, 64), render_mode='rgb_array')[source]#

Bases: object

Gym-like environment wrapper that always returns image observations.

  • Supports env IDs (string) and prebuilt env objects.

  • For vector observations, it synthesizes an RGB image so pixel-based world models can still train.

  • For discrete actions, it exposes a vector action space and maps by argmax.

property observation_space#
property action_space#
property max_episode_steps#
reset()[source]#
step(action)[source]#
render(*args, **kwargs)[source]#
close()[source]#
world_models.envs.make_gym_env(env, **kwargs)[source]#

Factory helper for generic Gym/Gymnasium environments.

class world_models.envs.UnityMLAgentsEnv(file_name, behavior_name=None, seed=0, size=(64, 64), worker_id=0, base_port=5005, no_graphics=True, time_scale=20.0, quality_level=1, max_episode_steps=1000)[source]#

Bases: object

Gym-like wrapper for a Unity ML-Agents environment.

Notes:

  • Supports single-agent control.

  • Supports continuous action spaces.

  • Returns channel-first uint8 images in obs[“image”] for Dreamer-style pipelines.

property observation_space#
property action_space#
property max_episode_steps#
reset()[source]#
step(action)[source]#
render(*args, **kwargs)[source]#
close()[source]#
world_models.envs.make_unity_mlagents_env(**kwargs)[source]#

Factory helper for Unity ML-Agents environments.

class world_models.envs.DeepMindControlEnv(name, seed, size=(64, 64), camera=None)[source]#

Bases: object

Gym-style adapter for DeepMind Control Suite tasks.

The wrapper exposes DMC observations and actions through Gym spaces and adds a rendered RGB image to each observation dict so image-based world model pipelines can train consistently across backends.

property observation_space#
property action_space#
step(action)[source]#
reset()[source]#
render(*args, **kwargs)[source]#
class world_models.envs.TimeLimit(env, duration)[source]#

Bases: object

Terminate episodes after a fixed number of wrapper steps.

If the wrapped environment does not provide a discount flag at timeout, the wrapper injects a default discount of 1.0 for downstream learners.

step(action)[source]#
reset()[source]#
class world_models.envs.ActionRepeat(env, amount)[source]#

Bases: object

Repeat each action for a fixed number of environment steps.

Rewards are accumulated and the loop stops early if the environment terminates, mirroring common action-repeat behavior in world model papers.

step(action)[source]#
class world_models.envs.NormalizeActions(env)[source]#

Bases: object

Expose a normalized [-1, 1] action space for bounded continuous controls.

Incoming normalized actions are mapped back to the wrapped environment action bounds before stepping the environment.

property action_space#
step(action)[source]#
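
The rescaling NormalizeActions performs is a simple affine map from [-1, 1] onto the wrapped environment's bounds. A minimal NumPy sketch (helper name is illustrative, not the wrapper's actual code):

```python
import numpy as np

def denormalize_action(norm_action, low, high):
    # Map a normalized action in [-1, 1] back onto the wrapped
    # environment's [low, high] bounds before stepping.
    norm_action = np.asarray(norm_action, dtype=np.float64)
    return (norm_action + 1.0) / 2.0 * (high - low) + low

# 0 maps to the midpoint of each bound pair; +/-1 map to the bounds themselves.
a = denormalize_action([0.0, 1.0, -1.0],
                       low=np.array([0.0, -2.0, 5.0]),
                       high=np.array([10.0, 2.0, 7.0]))
```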
class world_models.envs.ObsDict(env, key='obs')[source]#

Bases: object

Convert scalar/array observations into a dictionary observation format.

This harmonizes outputs for code paths that expect keyed observations (for example {“image”: …} style world model inputs).

property observation_space#
property action_space#
step(action)[source]#
reset()[source]#
class world_models.envs.OneHotAction(env)[source]#

Bases: object

Wrap discrete-action environments to accept one-hot action vectors.

The wrapper validates one-hot inputs and converts them to integer action indices before forwarding to the underlying environment.

property action_space#
step(action)[source]#
reset()[source]#
class world_models.envs.RewardObs(env)[source]#

Bases: object

Augment observations with the latest scalar reward under obs[“reward”].

Useful for agents that consume reward as part of the observation stream during model learning or recurrent policy inference.

property observation_space#
step(action)[source]#
reset()[source]#
class world_models.envs.ResizeImage(env, size=(64, 64))[source]#

Bases: object

Resize image-like observation entries to a target spatial size.

The wrapper discovers image keys from env.obs_space, applies nearest neighbor resizing, and updates the advertised observation space shapes.

property obs_space#
step(action)[source]#
reset()[source]#
class world_models.envs.RenderImage(env, key='image')[source]#

Bases: object

Inject RGB renders from env.render(“rgb_array”) into observations.

This is useful when the base environment returns non-image observations but a rendered camera view is needed for world-model training.

property obs_space#
step(action)[source]#
reset()[source]#
class world_models.envs.SelectAction(env, key)[source]#

Bases: Wrapper

Gym wrapper for dictionary actions that forwards a selected key only.

This enables integration with policies that emit action dicts while the environment expects a single tensor/array action payload.

step(action)[source]#

Dreamer#

world_models.models.dreamer.make_env(args)[source]#

Construct a Dreamer-compatible environment from DreamerConfig options.

Supports DMC, Gym/Gymnasium, and Unity ML-Agents backends and applies the standard wrapper stack: action repeat, action normalization, and time limit.

world_models.models.dreamer.preprocess_obs(obs)[source]#

Convert raw uint8 image observations to Dreamer float input space.

Images are scaled from [0, 255] to roughly [-0.5, 0.5], matching the normalization expected by Dreamer encoders.
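
A minimal sketch of this scaling, assuming NumPy arrays (the pipeline may operate on torch tensors):

```python
import numpy as np

def preprocess_obs(obs):
    # Scale uint8 pixels from [0, 255] to roughly [-0.5, 0.5].
    return obs.astype(np.float32) / 255.0 - 0.5

out = preprocess_obs(np.array([0, 128, 255], dtype=np.uint8))
```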

class world_models.models.dreamer.Dreamer(args, obs_shape, action_size, device, restore=False)[source]#

Bases: object

Core Dreamer training system combining world model, actor, and value nets.

This class owns model construction, replay sampling, imagination rollouts, loss computation, optimization steps, evaluation loops, and checkpoint I/O.

world_model_loss(obs, acs, rews, nonterms)[source]#
actor_loss()[source]#
value_loss()[source]#
train_one_batch()[source]#
act_with_world_model(obs, prev_state, prev_action, explore=False)[source]#
act_and_collect_data(env, collect_steps)[source]#
evaluate(env, eval_episodes, render=False)[source]#
collect_random_episodes(env, seed_steps)[source]#
save(save_path)[source]#
restore_checkpoint(ckpt_path)[source]#
class world_models.models.dreamer.DreamerAgent(config=None, **kwargs)[source]#

Bases: object

High-level user API for running Dreamer experiments end to end.

It builds environments from config, initializes seeds and logging, instantiates Dreamer, and exposes simple train() / evaluate() methods.

train(total_steps=None)[source]#
evaluate()[source]#
class world_models.models.dreamer_rssm.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)[source]#

Bases: Module

Recurrent State-Space Model used by Dreamer latent dynamics learning.

The RSSM maintains deterministic recurrent state and stochastic latent state, and provides transition/posterior updates plus rollout helpers.

init_state(batch_size, device)[source]#
get_dist(mean, std)[source]#
observe_step(prev_state, prev_action, obs_embed, nonterm=1.0)[source]#
imagine_step(prev_state, prev_action, nonterm=1.0)[source]#
observe_rollout(obs_embed, actions, nonterms, prev_state, horizon)[source]#
imagine_rollout(actor, prev_state, horizon)[source]#
stack_states(states, dim=0)[source]#
detach_state(state)[source]#
seq_to_batch(state)[source]#
class world_models.vision.dreamer_encoder.ConvEncoder(input_shape, embed_size, activation, depth=32)[source]#

Bases: Module

Convolutional observation encoder used by Dreamer world models.

Encodes image observations into compact embeddings consumed by the RSSM posterior update network.

forward(inputs)[source]#
class world_models.vision.dreamer_decoder.TanhBijector[source]#

Bases: Transform

Bijective tanh transform used to squash Gaussian actions to [-1, 1].

Provides stable inverse and log-determinant Jacobian computations for transformed-action distributions.

property sign#
atanh(x)[source]#
log_abs_det_jacobian(x, y)[source]#
class world_models.vision.dreamer_decoder.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)[source]#

Bases: Module

Convolutional Dreamer decoder from latent features to image distributions.

Maps concatenated stochastic/deterministic states into transposed-conv outputs and returns a factorized Normal distribution over pixels.

forward(features)[source]#
class world_models.vision.dreamer_decoder.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist)[source]#

Bases: Module

MLP decoder for reward/value/discount prediction from latent features.

The output distribution type is configurable (normal, binary, or raw tensor).

forward(features)[source]#
class world_models.vision.dreamer_decoder.SampleDist(dist, samples=100)[source]#

Bases: object

Distribution wrapper that estimates statistics via Monte Carlo sampling.

Provides approximated mean, mode, and entropy helpers for transformed distributions where analytic forms may be inconvenient.

property name#
mean()[source]#
mode()[source]#
entropy()[source]#
sample()[source]#
class world_models.vision.dreamer_decoder.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)[source]#

Bases: Module

Dreamer actor head producing squashed continuous actions from latent features.

Outputs a transformed Gaussian policy with optional deterministic mode and utility for additive exploration noise.

forward(features, deter=False)[source]#
add_exploration(action, action_noise=0.3)[source]#

JEPA and ViT#

class world_models.models.jepa_agent.JEPAAgent(config=None, **kwargs)[source]#

Bases: object

Convenience interface for configuring and launching JEPA training runs.

Accepts a JEPAConfig plus keyword overrides, prepares output folders, and delegates execution to the JEPA training entrypoint.

Parameters:

config (JEPAConfig | None)

train()[source]#
world_models.training.train_jepa.main(args, resume_preempt=False)[source]#

Run JEPA training using a nested config dict or JEPAConfig instance.

This entrypoint initializes distributed context, data pipeline, masking, models, optimizers/schedulers, checkpointing, and the full epoch loop.

world_models.models.vit.get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)[source]#

Generate fixed 2D sine/cosine positional embeddings on a square patch grid.

Returns NumPy embeddings used to initialize non-trainable transformer position encodings, with optional prepended class-token embedding.

world_models.models.vit.get_2d_sincos_pos_embed_from_grid(embed_dim, grid)[source]#

Build 2D sine/cosine embeddings from precomputed meshgrid coordinates.

The final embedding concatenates independent encodings for vertical and horizontal coordinates.

world_models.models.vit.get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)[source]#

Generate 1D sine/cosine positional embeddings for integer positions.

Useful for sequence-style positional encoding and as a building block for 2D embedding construction.

world_models.models.vit.get_1d_sincos_pos_embed_from_grid(embed_dim, pos)[source]#

Generate 1D sine/cosine positional embeddings from explicit positions.

Positions are projected onto a log-frequency basis and encoded with sine and cosine components.

world_models.models.vit.drop_path(x, drop_prob=0.0, training=False)[source]#

Apply stochastic depth (DropPath) regularization to residual branches.

Randomly drops entire residual paths per sample during training and scales the surviving activations to preserve expected magnitude.

Parameters:
  • drop_prob (float)

  • training (bool)
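
A NumPy sketch of the stochastic-depth rule described above (the library version operates on torch tensors):

```python
import numpy as np

def drop_path(x, drop_prob=0.0, training=False, rng=None):
    # Identity at inference time or when no drop probability is set.
    if drop_prob == 0.0 or not training:
        return x
    rng = rng or np.random.default_rng(0)
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over the remaining dims.
    mask = rng.random((x.shape[0],) + (1,) * (x.ndim - 1)) < keep_prob
    # Rescale survivors so the expected activation magnitude is preserved.
    return x * mask / keep_prob

x = np.ones((8, 3))
y = drop_path(x, drop_prob=0.5, training=True)
```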

class world_models.models.vit.DropPath(drop_prob=None)[source]#

Bases: Module

Module wrapper around the functional drop_path stochastic depth utility.

forward(x)[source]#
class world_models.models.vit.MLP(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>, drop=0.0)[source]#

Bases: Module

Two-layer feed-forward network used inside transformer blocks.

Applies linear projection, activation, dropout, and output projection in the standard Vision Transformer MLP pattern.

forward(x)[source]#
class world_models.models.vit.Attention(dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)[source]#

Bases: Module

Multi-head self-attention block for token sequences.

Computes QKV projections, scaled dot-product attention, and output projection with configurable dropout.

forward(x)[source]#
class world_models.models.vit.Block(dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]#

Bases: Module

Transformer encoder block combining attention and MLP residual branches.

Each branch uses pre-normalization and optional stochastic depth.

forward(x)[source]#
class world_models.models.vit.PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)[source]#

Bases: Module

Image-to-patch embedding layer.

forward(x)[source]#
class world_models.models.vit.ConvEmbed(channels, strides, img_size=224, in_chans=3, batch_norm=True)[source]#

Bases: Module

3x3 convolutional stem for ViT, following the ViT-C models.

forward(x)[source]#
class world_models.models.vit.VisionTransformerPredictor(num_patches, embed_dim=768, predictor_embed_dim=384, depth=6, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, init_std=0.02, **kwargs)[source]#

Bases: Module

Vision Transformer predictor for JEPA.

fix_init_weight()[source]#
forward(x, masks_x, masks)[source]#
class world_models.models.vit.VisionTransformer(img_size=[224], patch_size=16, in_chans=3, embed_dim=768, predictor_embed_dim=384, depth=12, predictor_depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, init_std=0.02, **kwargs)[source]#

Bases: Module

Vision Transformer

fix_init_weight()[source]#
forward(x, masks=None)[source]#
interpolate_pos_encoding(x, pos_embed)[source]#
world_models.models.vit.vit_predictor(**kwargs)[source]#

Factory for a JEPA predictor transformer with sensible defaults.

world_models.models.vit.vit_tiny(patch_size=16, **kwargs)[source]#

Factory for a tiny Vision Transformer encoder backbone.

world_models.models.vit.vit_small(patch_size=16, **kwargs)[source]#

Factory for a small Vision Transformer encoder backbone.

world_models.models.vit.vit_base(patch_size=16, **kwargs)[source]#

Factory for a base Vision Transformer encoder backbone.

world_models.models.vit.vit_large(patch_size=16, **kwargs)[source]#

Factory for a large Vision Transformer encoder backbone.

world_models.models.vit.vit_huge(patch_size=16, **kwargs)[source]#

Factory for a huge Vision Transformer encoder backbone.

world_models.models.vit.vit_giant(patch_size=16, **kwargs)[source]#

Factory for a giant Vision Transformer encoder backbone.

class world_models.masks.multiblock.MaskCollator(input_size=(224, 224), patch_size=16, enc_mask_scale=(0.2, 0.8), pred_mask_scale=(0.5, 1.0), aspect_ratio=(0.3, 3.0), nenc=1, npred=2, min_keep=4, allow_overlap=False)[source]#

Bases: object

Generate multi-block encoder and predictor masks for JEPA training.

For each sample, this collator samples predictor target blocks and context encoder blocks (optionally non-overlapping), then returns masked patch indices aligned across the batch.

step()[source]#
class world_models.masks.random.MaskCollator(ratio=(0.4, 0.6), input_size=(224, 224), patch_size=16)[source]#

Bases: object

Generate random context/prediction patch splits for masked training.

A random permutation of patch indices is sampled per image; a configurable fraction is assigned to context and the remainder to prediction targets.

step()[source]#
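
The split can be sketched in a few lines; names here are illustrative, not the collator's actual API:

```python
import numpy as np

def random_context_target_split(num_patches, context_ratio, rng):
    # Permute all patch indices, then split by the context ratio:
    # the first fraction becomes context, the remainder prediction targets.
    perm = rng.permutation(num_patches)
    n_ctx = int(num_patches * context_ratio)
    return perm[:n_ctx], perm[n_ctx:]

# A 224x224 image with 16x16 patches -> 14 * 14 = 196 patches.
ctx, tgt = random_context_target_split(196, context_ratio=0.5,
                                       rng=np.random.default_rng(0))
```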

Diffusion#

class world_models.models.diffusion.DDPM.DDPM(timesteps, beta_start, beta_end, device)[source]#

Bases: object

Utility class implementing forward and reverse DDPM diffusion steps.

Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).

q_sample(x_start, t, noise=None)[source]#
p_sample(model, x_t, t)[source]#
sample(model, n, img_size, channels)[source]#
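
The forward-noising step has a closed form under a linear beta schedule (matching the DiTConfig defaults). An illustrative NumPy version, not the class's torch implementation:

```python
import numpy as np

def alpha_bar(timesteps, beta_start, beta_end):
    # Linear beta schedule; alpha_bar_t is the cumulative product of (1 - beta).
    betas = np.linspace(beta_start, beta_end, timesteps)
    return np.cumprod(1.0 - betas)

def q_sample(x_start, t, a_bar, noise):
    # Closed-form noising: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps.
    return np.sqrt(a_bar[t]) * x_start + np.sqrt(1.0 - a_bar[t]) * noise

a_bar = alpha_bar(timesteps=1000, beta_start=1e-4, beta_end=0.02)
# At the final timestep, almost none of the clean signal survives.
x_t = q_sample(np.ones(4), t=999, a_bar=a_bar, noise=np.zeros(4))
```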
world_models.models.diffusion.DiT.sinusoidal_time_embedding(timesteps, dim)[source]#

Create sinusoidal timestep embeddings for diffusion conditioning.

Embeddings are scaled relative to configured diffusion timesteps and are consumed by the DiT conditioning MLP.

class world_models.models.diffusion.DiT.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]#

Bases: Module

Patchify an image into a sequence of learnable patch tokens.

A strided convolution performs patch extraction and projects each patch to the transformer hidden dimension with additive positional embeddings.

forward(x)[source]#
class world_models.models.diffusion.DiT.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]#

Bases: Module

Reconstruct image-like tensors from patch-token sequences.

The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.

forward(x)[source]#
class world_models.models.diffusion.DiT.TransformerBlock(d_model, n_heads, mlp_ratio, drop, t_dim)[source]#

Bases: Module

Conditioned transformer block used inside the DiT backbone.

Each block applies adaptive layer-normalized self-attention and MLP residual updates conditioned on timestep embeddings.

forward(x, t_emb)[source]#
class world_models.models.diffusion.DiT.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]#

Bases: Module

Diffusion Transformer model for image denoising and generation.

The module maps noisy images and timesteps to predicted noise residuals and also provides a classmethod training entrypoint for common datasets.

forward(x, t)[source]#
classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]#

Datasets and Transforms#

world_models.datasets.cifar10.make_cifar10(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, drop_last=True, train=True, download=False)[source]#

Create CIFAR-10 dataset objects and a distributed-capable dataloader.

Returns the dataset, sampler, and loader configured with the provided transform/collator so callers can plug the loader directly into JEPA or diffusion training loops.

world_models.datasets.imagenet1k.make_imagenet1k(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, training=True, copy_data=False, drop_last=True, subset_file=None)[source]#

Build an ImageNet-1K dataset and dataloader with distributed sampling support.

This helper optionally restricts data to a subset file and returns the (dataset, dataloader, sampler) tuple used by training scripts.

class world_models.datasets.imagenet1k.ImageNet(root, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', transform=None, train=True, job_id=None, local_rank=None, copy_data=True, index_targets=False)[source]#

Bases: ImageFolder

ImageNet dataset wrapper with optional local copy/extract workflow.

The class extends torchvision.datasets.ImageFolder and can stage data from shared storage into local scratch space for faster multi-process training on cluster environments.

class world_models.datasets.imagenet1k.ImageNetSubset(dataset, subset_file)[source]#

Bases: object

View over an ImageNet dataset filtered by an explicit image-id list.

The subset file contains target image names; only matching samples are kept while preserving transforms and label mapping from the base dataset.

filter_dataset_(subset_file)[source]#

Filter self.dataset down to the samples listed in subset_file.

property classes#
world_models.datasets.imagenet1k.copy_imgnt_locally(root, suffix, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', job_id=None, local_rank=None)[source]#

Copy and extract ImageNet archives to per-job local scratch storage.

In SLURM environments this reduces network filesystem pressure by unpacking once per job and synchronizing worker processes with a signal file.

world_models.datasets.imagenet1k.make_imagefolder(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, drop_last=True, val_split=None)[source]#

Create an ImageFolder dataset loader for custom folder-structured datasets.

Supports optional train/validation split and distributed sampling, making it a drop-in replacement for ImageNet loaders in training scripts.

Parameters:

val_split (float | None)

world_models.transforms.transforms.make_transforms(crop_size=224, crop_scale=(0.3, 1.0), color_jitter=1.0, horizontal_flip=False, color_distortion=False, gaussian_blur=False, normalization=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))[source]#

Compose image augmentations and normalization for vision model training.

Supports random crops, optional flip/color distortion/blur, and returns a torchvision.transforms.Compose pipeline.

class world_models.transforms.transforms.GaussianBlur(p=0.5, radius_min=0.1, radius_max=2.0)[source]#

Bases: object

Probabilistic Gaussian blur augmentation for PIL images.

Applies blur with random radius in a configurable range when sampled.

Memory and Controllers#

class world_models.memory.dreamer_memory.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]#

Bases: object

Fixed-size replay buffer for Dreamer image observations and transitions.

Stores (observation image, action, reward, terminal) tuples and supports sampling contiguous sequences used by world-model unroll training.

add(obs, ac, rew, done)[source]#
sample()[source]#
class world_models.memory.planet_memory.Episode(postprocess_fn=None)[source]#

Bases: object

Records the agent’s interaction with the environment for a single episode. At termination, it converts all the data to NumPy arrays.

property size#
append(obs, act, reward, terminal)[source]#
terminate(obs)[source]#
class world_models.memory.planet_memory.Memory(size=None)[source]#

Bases: deque

Episode-based replay memory for PlaNet/RSSM training.

Episodes are stored as variable-length trajectories and sampled as sub-sequences with optional time-major formatting for sequence models.

property size#
append(episodes)[source]#
Parameters:

episodes (list[Episode])

sample(batch_size, tracelen=1, time_first=False)[source]#
class world_models.controller.rssm_policy.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]#

Bases: object

Model-predictive controller that plans actions with the RSSM latent model.

The policy uses a Cross-Entropy Method style loop: it samples candidate action sequences, rolls them forward in latent space, scores predicted returns, and refits a Gaussian proposal to top-performing candidates.

reset()[source]#
poll(observation, explore=False)[source]#
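
The CEM loop described above can be sketched against a toy scoring function standing in for the latent-rollout return (names and defaults here are illustrative):

```python
import numpy as np

def cem_plan(score_fn, horizon, action_dim, num_candidates=500,
             num_iterations=10, top_k=50, rng=None):
    rng = rng or np.random.default_rng(0)
    mean = np.zeros((horizon, action_dim))
    std = np.ones((horizon, action_dim))
    for _ in range(num_iterations):
        # Sample candidate action sequences from the current Gaussian proposal.
        cands = mean + std * rng.standard_normal((num_candidates, horizon, action_dim))
        # Score each sequence (the real policy scores imagined latent returns).
        scores = np.array([score_fn(c) for c in cands])
        # Refit the proposal to the top-performing candidates.
        elites = cands[np.argsort(scores)[-top_k:]]
        mean, std = elites.mean(axis=0), elites.std(axis=0)
    return mean[0]  # execute only the first planned action

# Toy score: prefer action sequences close to 0.5 everywhere.
best = cem_plan(lambda seq: -np.sum((seq - 0.5) ** 2), horizon=6, action_dim=2)
```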

Utilities#

world_models.utils.dreamer_utils.get_parameters(modules)[source]#

Given an iterable of torch modules, return a flat list of their parameters.

Parameters:

modules (Iterable[Module])

class world_models.utils.dreamer_utils.FreezeParameters(modules)[source]#

Bases: object

Context manager that temporarily disables gradients for given modules.

Useful during imagination or target-network forward passes where gradients through certain components should be blocked for speed and correctness.

Parameters:

modules (Iterable[Module])

class world_models.utils.dreamer_utils.Logger(log_dir, n_logged_samples=10, summary_writer=None)[source]#

Bases: object

Experiment logger for scalars and GIF rollouts using TensorBoardX.

Provides helpers to write scalar metrics, dump pickle snapshots, and save video previews during Dreamer training/evaluation.

log_scalar(scalar, name, step_)[source]#
log_scalars(scalar_dict, step)[source]#
log_videos(videos, step, max_videos_to_save=1, fps=20, video_title='video')[source]#
dump_scalars_to_pickle(metrics, step, log_title=None)[source]#
flush()[source]#
world_models.utils.dreamer_utils.compute_return(rewards, values, discounts, td_lam, last_value)[source]#

Compute TD(lambda) returns from imagined rewards, values, and discounts.

Implements backward recursion used by Dreamer actor/value objectives.
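
A NumPy sketch of the recursion; the exact indexing convention for values (whether values[t] means V_t or V_{t+1}) is an assumption of this illustration:

```python
import numpy as np

def compute_return(rewards, values, discounts, td_lam, last_value):
    # Backward recursion:
    #   R_t = r_t + gamma_t * ((1 - lam) * V_{t+1} + lam * R_{t+1}),
    # bootstrapped from last_value at the end of the imagined horizon.
    returns = np.zeros_like(rewards)
    next_return = last_value
    for t in reversed(range(len(rewards))):
        returns[t] = rewards[t] + discounts[t] * (
            (1.0 - td_lam) * values[t] + td_lam * next_return)
        next_return = returns[t]
    return returns

# With td_lam = 0 this reduces to one-step TD targets: r + gamma * V.
td0 = compute_return(np.ones(3), np.full(3, 10.0), np.full(3, 0.99),
                     td_lam=0.0, last_value=10.0)
```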

world_models.utils.jepa_utils.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2, b=2.0)[source]#

Initialize a tensor in-place from a truncated normal distribution.

Values are sampled from N(mean, std) and clipped to [a, b].
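
One way to realize a truncated normal is rejection sampling; this NumPy sketch is illustrative and need not match the library's in-place torch mechanism:

```python
import numpy as np

def trunc_normal(shape, mean=0.0, std=1.0, a=-2.0, b=2.0, rng=None):
    # Draw from N(mean, std^2) and resample any value outside [a, b]
    # until everything lies inside the truncation bounds.
    rng = rng or np.random.default_rng(0)
    out = rng.normal(mean, std, size=shape)
    bad = (out < a) | (out > b)
    while bad.any():
        out[bad] = rng.normal(mean, std, size=int(bad.sum()))
        bad = (out < a) | (out > b)
    return out

w = trunc_normal((256, 256), std=0.02)
```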

world_models.utils.jepa_utils.repeat_interleave_batch(x, B, repeat)[source]#

Repeat each batch chunk multiple times while preserving chunk ordering.

Used in JEPA masking code to align context and target token batches.
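A list-based sketch of the chunk-repeat behavior (the actual implementation operates on torch tensors along the batch dimension):

```python
def repeat_interleave_batch(x, B, repeat):
    # Split x into consecutive chunks of size B, repeat each chunk
    # `repeat` times, and concatenate while preserving chunk order.
    n_chunks = len(x) // B
    out = []
    for i in range(n_chunks):
        chunk = x[i * B:(i + 1) * B]
        for _ in range(repeat):
            out.extend(chunk)
    return out
```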

class world_models.utils.jepa_utils.WarmupCosineSchedule(optimizer, warmup_steps, start_lr, ref_lr, T_max, last_epoch=-1, final_lr=0.0)[source]#

Bases: object

Learning-rate schedule with linear warmup followed by cosine decay.

Updates optimizer parameter-group LRs on each call to step().

step()[source]#
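The per-step learning rate implied by this schedule can be sketched as a pure function. This is a simplified illustration; the parameter names mirror the constructor above, but the class itself mutates optimizer parameter groups rather than returning a value:

```python
import math

def warmup_cosine_lr(step, warmup_steps, start_lr, ref_lr, T_max, final_lr=0.0):
    if step < warmup_steps:
        # Linear warmup from start_lr to ref_lr
        frac = step / max(1, warmup_steps)
        return start_lr + frac * (ref_lr - start_lr)
    # Cosine decay from ref_lr down to final_lr over the remaining steps
    progress = (step - warmup_steps) / max(1, T_max - warmup_steps)
    progress = min(1.0, progress)
    return final_lr + 0.5 * (ref_lr - final_lr) * (1 + math.cos(math.pi * progress))
```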
class world_models.utils.jepa_utils.CosineWDSchedule(optimizer, ref_wd, T_max, final_wd=0.0)[source]#

Bases: object

Cosine scheduler for optimizer weight decay values.

Skips parameter groups flagged with WD_exclude to keep bias/norm decay at zero.

step()[source]#
world_models.utils.jepa_utils.gpu_timer(closure, log_timings=True)[source]#

Measure CUDA execution time for a closure and return (result, elapsed_ms).

Falls back to -1 elapsed time when CUDA timing is unavailable.

class world_models.utils.jepa_utils.CSVLogger(fname, *argv)[source]#

Bases: object

Lightweight CSV logger with per-column printf-style formatting.

log(*argv)[source]#
class world_models.utils.jepa_utils.AverageMeter[source]#

Bases: object

Track running statistics (val, avg, min, max, sum, count) for metrics.

reset()[source]#
update(val, n=1)[source]#
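A minimal sketch of such a meter:

```python
class AverageMeter:
    # Track running statistics (val, avg, min, max, sum, count) for a metric.
    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0
        self.min, self.max = float('inf'), float('-inf')

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.min = min(self.min, val)
        self.max = max(self.max, val)
```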
world_models.utils.jepa_utils.grad_logger(named_params)[source]#

Aggregate gradient norm statistics over model parameters for logging.

Also exposes first/last qkv-layer gradient norms when available.

world_models.utils.jepa_utils.init_distributed(port=40112, rank_and_world_size=(None, None))[source]#

Initialize torch distributed process groups when environment supports it.

Returns (world_size, rank) and gracefully falls back to single-process mode.

class world_models.utils.jepa_utils.AllGather(*args, **kwargs)[source]#

Bases: Function

Autograd-aware all-gather operation across distributed workers.

Forward concatenates worker tensors; backward reduces and slices gradients.

static forward(ctx, x)[source]#
static backward(ctx, grads)[source]#
class world_models.utils.jepa_utils.AllReduceSum(*args, **kwargs)[source]#

Bases: Function

Autograd function that sums tensors across distributed workers in forward pass.

static forward(ctx, x)[source]#
static backward(ctx, grads)[source]#
class world_models.utils.jepa_utils.AllReduce(*args, **kwargs)[source]#

Bases: Function

Autograd function that all-reduces and averages tensors across workers.

Used to synchronize scalar losses for consistent distributed logging/training.

static forward(ctx, x)[source]#
static backward(ctx, grads)[source]#
class world_models.utils.utils.AttrDict[source]#

Bases: dict

world_models.utils.utils.load_yml_config(path)[source]#
world_models.utils.utils.to_tensor_obs(image)[source]#

Converts an input NumPy image to a channel-first 64x64 torch image tensor.

world_models.utils.utils.postprocess_img(image, depth)[source]#

Postprocess an image observation for storage, mapping a float32 NumPy array in [-0.5, 0.5] to a uint8 NumPy array in [0, 255].

world_models.utils.utils.preprocess_img(image, depth)[source]#

Preprocesses an observation in place, mapping a float32 tensor from [0, 255] to [-0.5, 0.5]. Note that this also adds a small amount of noise to the observations.

world_models.utils.utils.bottle(func, *tensors)[source]#

Evaluates a function that operates on inputs of shape (N, D) against inputs of shape (N, T, D) by flattening the time dimension.
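A NumPy sketch of this flatten-apply-restore pattern (the library version operates on torch tensors):

```python
import numpy as np

def bottle(func, *tensors):
    # Flatten (N, T, ...) inputs to (N*T, ...), apply func, then
    # restore the leading (N, T) dimensions on the output.
    n, t = tensors[0].shape[:2]
    flat = [x.reshape(n * t, *x.shape[2:]) for x in tensors]
    out = func(*flat)
    return out.reshape(n, t, *out.shape[1:])
```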

world_models.utils.utils.get_combined_params(*models)[source]#

Returns the combined parameter list of all the models given as input.

world_models.utils.utils.save_video(frames, path, name)[source]#

Saves a video containing frames.

Accepts frames in either:
  • (T, C, H, W) float in [0,1]

  • (T, H, W, C) float in [0,1]

Produces {path}/{name}.mp4 and a debug PNG {path}/{name}_debug_frame.png with per-channel statistics printed to stdout.

world_models.utils.utils.combine_videos(video_dir, output_name='combined.mp4', pattern='vid_*.mp4', fps=25, resize=True)[source]#

Combine all videos matching pattern in video_dir into a single MP4 file. Returns the output filepath (string).

Example

combine_videos("results/planet", output_name="all_training.mp4")

world_models.utils.utils.ensure_results_dir_exists(results_dir)[source]#

Simple helper that validates that a results directory exists. Raises FileNotFoundError if it is not present.

world_models.utils.utils.save_frames(target, pred_prior, pred_posterior, name, n_rows=5)[source]#

Save side-by-side target, prior-prediction, and posterior-prediction frames.

The function accepts tensors with optional time dimension and writes a PNG grid to {name}.png. Spatial sizes are aligned per timestep before concatenation and values are normalized to [0, 1] when needed.

world_models.utils.utils.get_mask(tensor, lengths)[source]#

Build a batch-first validity mask from sequence lengths.

tensor may be a tensor/array with shape (N, T, ...) or (N,). The returned mask marks valid timesteps with ones up to each element in lengths and preserves device/dtype conventions from the input.
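A simplified NumPy sketch of building such a mask from lengths (the library version additionally mirrors the input tensor's device and dtype; the helper name and signature here are hypothetical):

```python
import numpy as np

def length_mask(lengths, T):
    # mask[i, t] = 1.0 for t < lengths[i], else 0.0; shape (N, T)
    lengths = np.asarray(lengths)
    return (np.arange(T)[None, :] < lengths[:, None]).astype(np.float32)
```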

world_models.utils.utils.load_memory(path, device)[source]#

Loads an experience replay buffer (backwards-compatible with older pickle formats). Converts legacy list/.data formats into the current Memory(episodes) object.

world_models.utils.utils.flatten_dict(data, sep='.', prefix='')[source]#

Flattens a nested dict into a single-level dict.

Example

{'a': 2, 'b': {'c': 20}} -> {'a': 2, 'b.c': 20}
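A recursive sketch of this behavior:

```python
def flatten_dict(data, sep='.', prefix=''):
    # Recursively flatten nested dicts, joining key paths with `sep`.
    out = {}
    for k, v in data.items():
        key = f"{prefix}{sep}{k}" if prefix else str(k)
        if isinstance(v, dict):
            out.update(flatten_dict(v, sep=sep, prefix=key))
        else:
            out[key] = v
    return out
```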

world_models.utils.utils.normalize_frames_for_saving(frames)[source]#

Ensure frames are in shape (T, H, W, 3) with float values in [0,1]. Handles inputs in (T, C, H, W) or (T, H, W, C), repeats single-channel -> RGB, drops alpha if present, and maps [-0.5,0.5] -> [0,1] when detected.

class world_models.utils.utils.TensorBoardMetrics(path)[source]#

Bases: object

Plots and (optionally) stores metrics for an experiment.

assign_type(key, val)[source]#
update(metrics)[source]#
Parameters:

metrics (dict)

world_models.utils.utils.apply_model(model, inputs, ignore_dim=None)[source]#

Placeholder helper for generic model application across input structures.

Currently not implemented; kept as an extension hook for future utility code.

world_models.utils.utils.plot_metrics(metrics, path, prefix)[source]#

Render and save line plots for each metric series in a dictionary.

world_models.utils.utils.lineplot(xs, ys, title, path='', xaxis='episode')[source]#

Create a Plotly line plot for scalar, dict, or ensemble-series data.

Supports uncertainty-band plotting when ys is a 2D array.

class world_models.utils.utils.TorchImageEnvWrapper(env, bit_depth, observation_shape=None, act_rep=2)[source]#

Bases: object

Wraps a Gym environment so that interactions use torch tensors and observations are returned in image form.

reset()[source]#
step(u)[source]#
render()[source]#
close()[source]#
property observation_size#
property action_size#
sample_random_action()[source]#
property max_episode_steps#

Return environment max episode steps (compatible with TimeLimit/spec).

world_models.utils.utils.apply_masks(x, masks)[source]#

Gather token subsets from patch sequences using index masks.

Each mask selects token positions from x; selected groups are concatenated along the batch dimension.
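A NumPy sketch of this gather-and-concatenate behavior (the library version operates on torch tensors):

```python
import numpy as np

def apply_masks(x, masks):
    # x: (B, N, D) patch tokens; each mask: (B, K) integer indices of kept tokens.
    # Gather the kept tokens per mask and concatenate results along the batch dim.
    outs = []
    for m in masks:
        idx = np.repeat(m[:, :, None], x.shape[-1], axis=-1)  # (B, K, D)
        outs.append(np.take_along_axis(x, idx, axis=1))
    return np.concatenate(outs, axis=0)
```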