API Reference#
This reference is generated from source docstrings and grouped by subsystem.
Core Public APIs#
- class world_models.configs.DreamerConfig[source]#
Bases: object
Configuration container for Dreamer training, evaluation, and environment setup.
This class centralizes environment backend selection (DMC/Gym/Unity), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.
- class world_models.configs.JEPAConfig[source]#
Bases: object
Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.
- class world_models.configs.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]#
Bases: object
Default configuration values for Diffusion Transformer (DiT) training.
The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.
- Parameters:
DATASET (str)
BATCH (int)
EPOCHS (int)
LR (float)
IMG_SIZE (int)
CHANNELS (int)
PATCH (int)
WIDTH (int)
DEPTH (int)
HEADS (int)
DROP (float)
BETA_START (float)
BETA_END (float)
TIMESTEPS (int)
EMA (bool)
EMA_DECAY (float)
WORKDIR (str)
ROOT_PATH (str)
- DATASET: str = 'CIFAR10'#
- BATCH: int = 128#
- EPOCHS: int = 3#
- LR: float = 0.0002#
- IMG_SIZE: int = 32#
- CHANNELS: int = 3#
- PATCH: int = 4#
- WIDTH: int = 384#
- DEPTH: int = 6#
- HEADS: int = 6#
- DROP: float = 0.1#
- BETA_START: float = 0.0001#
- BETA_END: float = 0.02#
- TIMESTEPS: int = 1000#
- EMA: bool = True#
- EMA_DECAY: float = 0.999#
- WORKDIR: str = './dit_demo'#
- ROOT_PATH: str = './data'#
- world_models.configs.get_dit_config(**overrides)[source]#
Returns a DiTConfig instance with default values overridden by the provided keyword arguments.
- Example usage:
cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)
- world_models.envs.make_atari_env(env_id, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, **kwargs)[source]#
Create any Atari environment from the Arcade Learning Environment (ALE).
- Parameters:
env_id (str) – The id of the Atari environment to create.
obs_type (str) – The type of observation to return (“rgb” or “ram”).
frameskip (int) – The number of frames to skip between actions.
repeat_action_probability (float) – The probability of repeating the last action.
full_action_space (bool) – Whether to use the full action space.
max_episode_steps (Optional[int]) – Maximum number of steps per episode.
**kwargs – Additional keyword arguments for environment configuration.
- Returns:
The created Atari environment.
- Return type:
gym.Env
- world_models.envs.list_available_atari_envs()[source]#
Get a list of all available Atari environments in the Arcade Learning Environment (ALE).
- Returns:
List of available Atari environment IDs.
- Return type:
list[str]
- world_models.envs.make_atari_vector_env(game, num_envs, obs_type='rgb', frameskip=4, repeat_action_probability=0.25, full_action_space=False, max_episode_steps=None, seed=None, **kwargs)[source]#
Create vectorized Atari environments using ALE’s native AtariVectorEnv.
- Parameters:
game (str) – The name of the Atari game (e.g., “pong”, “breakout”).
num_envs (int) – Number of parallel environments.
obs_type (str) – The type of observation to return (“rgb” or “ram”).
frameskip (int) – The number of frames to skip between actions.
repeat_action_probability (float) – The probability of repeating the last action.
full_action_space (bool) – Whether to use the full action space.
max_episode_steps (Optional[int]) – Maximum number of steps per episode.
seed (Optional[int]) – Random seed for reproducibility.
**kwargs – Additional keyword arguments for environment configuration.
- Returns:
The vectorized Atari environment.
- Return type:
AtariVectorEnv
- world_models.envs.make_humanoid_env(version='v4', xml_file='humanoid.xml', forward_reward_weight=1.25, ctrl_cost_weight=0.1, contact_cost_weight=5e-07, healthy_reward=5.0, terminate_when_unhealthy=True, healthy_z_range=(1.0, 2.0), reset_noise_scale=0.01, exclude_current_positions_from_observation=True, include_cinert_in_observation=True, include_cvel_in_observation=True, include_qfrc_actuator_in_observation=True, include_cfrc_ext_in_observation=True)[source]#
Create a Humanoid environment with customizable parameters.
- Parameters:
version (str) – The version of the Humanoid environment (e.g., “v4”).
xml_file (str) – The XML file defining the Humanoid model.
forward_reward_weight (float) – Weight for the forward reward.
ctrl_cost_weight (float) – Weight for the control cost.
contact_cost_weight (float) – Weight for the contact cost.
healthy_reward (float) – Reward for being in a healthy state.
terminate_when_unhealthy (bool) – Whether to terminate the episode when unhealthy.
healthy_z_range (Tuple[float, float]) – The range of z-values considered healthy.
reset_noise_scale (float) – Scale of noise added during environment reset.
exclude_current_positions_from_observation (bool) – Whether to exclude current positions from observations.
include_cinert_in_observation (bool) – Whether to include inertia in observations.
include_cvel_in_observation (bool) – Whether to include velocity in observations.
include_qfrc_actuator_in_observation (bool) – Whether to include actuator forces in observations.
include_cfrc_ext_in_observation (bool) – Whether to include external forces in observations.
- Returns:
The created Humanoid environment.
- Return type:
gym.Env
- world_models.envs.make_half_cheetah_env(version='v4', forward_reward_weight=0.1, reset_noise_scale=0.1, exclude_current_positions_from_observation=True, render_mode='rgb_array')[source]#
Create a HalfCheetah environment with customizable parameters.
- Parameters:
version (str) – The version of the HalfCheetah environment (e.g., “v4”).
forward_reward_weight (float) – Weight for the forward reward.
reset_noise_scale (float) – Scale of noise added during environment reset.
exclude_current_positions_from_observation (bool) – Whether to exclude current positions from observations.
render_mode (str) – The render mode for the environment.
- Returns:
The created HalfCheetah environment.
- Return type:
gym.Env
- class world_models.envs.GymImageEnv(env, seed=0, size=(64, 64), render_mode='rgb_array')[source]#
Bases: object
Gym-like environment wrapper that always returns image observations.
Supports env IDs (string) and prebuilt env objects.
For vector observations, it synthesizes an RGB image so pixel-based world models can still train.
For discrete actions, it exposes a vector action space and maps by argmax.
- property observation_space#
- property action_space#
- property max_episode_steps#
- world_models.envs.make_gym_env(env, **kwargs)[source]#
Factory helper for generic Gym/Gymnasium environments.
- class world_models.envs.UnityMLAgentsEnv(file_name, behavior_name=None, seed=0, size=(64, 64), worker_id=0, base_port=5005, no_graphics=True, time_scale=20.0, quality_level=1, max_episode_steps=1000)[source]#
Bases: object
Gym-like wrapper for a Unity ML-Agents environment.
Notes:
- Supports single-agent control.
- Supports continuous action spaces.
- Returns channel-first uint8 images in obs["image"] for Dreamer-style pipelines.
- property observation_space#
- property action_space#
- property max_episode_steps#
- world_models.envs.make_unity_mlagents_env(**kwargs)[source]#
Factory helper for Unity ML-Agents environments.
- class world_models.envs.DeepMindControlEnv(name, seed, size=(64, 64), camera=None)[source]#
Bases: object
Gym-style adapter for DeepMind Control Suite tasks.
The wrapper exposes DMC observations and actions through Gym spaces and adds a rendered RGB image to each observation dict so image-based world model pipelines can train consistently across backends.
- property observation_space#
- property action_space#
- class world_models.envs.TimeLimit(env, duration)[source]#
Bases: object
Terminate episodes after a fixed number of wrapper steps.
If the wrapped environment does not provide a discount flag at timeout, the wrapper injects a default discount of 1.0 for downstream learners.
- class world_models.envs.ActionRepeat(env, amount)[source]#
Bases: object
Repeat each action for a fixed number of environment steps.
Rewards are accumulated and the loop stops early if the environment terminates, mirroring common action-repeat behavior in world model papers.
- class world_models.envs.NormalizeActions(env)[source]#
Bases: object
Expose a normalized [-1, 1] action space for bounded continuous controls.
Incoming normalized actions are mapped back to the wrapped environment action bounds before stepping the environment.
- property action_space#
- class world_models.envs.ObsDict(env, key='obs')[source]#
Bases: object
Convert scalar/array observations into a dictionary observation format.
This harmonizes outputs for code paths that expect keyed observations (for example {"image": …} style world model inputs).
- property observation_space#
- property action_space#
- class world_models.envs.OneHotAction(env)[source]#
Bases: object
Wrap discrete-action environments to accept one-hot action vectors.
The wrapper validates one-hot inputs and converts them to integer action indices before forwarding to the underlying environment.
- property action_space#
- class world_models.envs.RewardObs(env)[source]#
Bases: object
Augment observations with the latest scalar reward under obs["reward"].
Useful for agents that consume reward as part of the observation stream during model learning or recurrent policy inference.
- property observation_space#
- class world_models.envs.ResizeImage(env, size=(64, 64))[source]#
Bases: object
Resize image-like observation entries to a target spatial size.
The wrapper discovers image keys from env.obs_space, applies nearest neighbor resizing, and updates the advertised observation space shapes.
- property obs_space#
Dreamer#
- world_models.models.dreamer.make_env(args)[source]#
Construct a Dreamer-compatible environment from DreamerConfig options.
Supports DMC, Gym/Gymnasium, and Unity ML-Agents backends and applies the standard wrapper stack: action repeat, action normalization, and time limit.
- world_models.models.dreamer.preprocess_obs(obs)[source]#
Convert raw uint8 image observations to Dreamer float input space.
Images are scaled from [0, 255] to roughly [-0.5, 0.5], matching the normalization expected by Dreamer encoders.
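The scaling can be sketched in a few lines; the function name here is illustrative, not the library's:

```python
import numpy as np

def preprocess_obs_sketch(obs):
    # Map uint8 pixels from [0, 255] into roughly [-0.5, 0.5],
    # the input range expected by Dreamer's convolutional encoder.
    return obs.astype(np.float32) / 255.0 - 0.5

frame = np.zeros((64, 64, 3), dtype=np.uint8)
frame[0, 0, 0] = 255
out = preprocess_obs_sketch(frame)  # values now span [-0.5, 0.5]
```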
- class world_models.models.dreamer.Dreamer(args, obs_shape, action_size, device, restore=False)[source]#
Bases: object
Core Dreamer training system combining world model, actor, and value nets.
This class owns model construction, replay sampling, imagination rollouts, loss computation, optimization steps, evaluation loops, and checkpoint I/O.
- class world_models.models.dreamer.DreamerAgent(config=None, **kwargs)[source]#
Bases: object
High-level user API for running Dreamer experiments end to end.
It builds environments from config, initializes seeds and logging, instantiates Dreamer, and exposes simple train() / evaluate() methods.
- class world_models.models.dreamer_rssm.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)[source]#
Bases: Module
Recurrent State-Space Model used by Dreamer latent dynamics learning.
The RSSM maintains deterministic recurrent state and stochastic latent state, and provides transition/posterior updates plus rollout helpers.
- class world_models.vision.dreamer_encoder.ConvEncoder(input_shape, embed_size, activation, depth=32)[source]#
Bases: Module
Convolutional observation encoder used by Dreamer world models.
Encodes image observations into compact embeddings consumed by the RSSM posterior update network.
- class world_models.vision.dreamer_decoder.TanhBijector[source]#
Bases: Transform
Bijective tanh transform used to squash Gaussian actions to [-1, 1].
Provides stable inverse and log-determinant Jacobian computations for transformed-action distributions.
- property sign#
- class world_models.vision.dreamer_decoder.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)[source]#
Bases: Module
Convolutional Dreamer decoder from latent features to image distributions.
Maps concatenated stochastic/deterministic states into transposed-conv outputs and returns a factorized Normal distribution over pixels.
- class world_models.vision.dreamer_decoder.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist)[source]#
Bases: Module
MLP decoder for reward/value/discount prediction from latent features.
The output distribution type is configurable (normal, binary, or raw tensor).
- class world_models.vision.dreamer_decoder.SampleDist(dist, samples=100)[source]#
Bases: object
Distribution wrapper that estimates statistics via Monte Carlo sampling.
Provides approximated mean, mode, and entropy helpers for transformed distributions where analytic forms may be inconvenient.
- property name#
- class world_models.vision.dreamer_decoder.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)[source]#
Bases: Module
Dreamer actor head producing squashed continuous actions from latent features.
Outputs a transformed Gaussian policy with optional deterministic mode and utility for additive exploration noise.
JEPA and ViT#
- class world_models.models.jepa_agent.JEPAAgent(config=None, **kwargs)[source]#
Bases: object
Convenience interface for configuring and launching JEPA training runs.
Accepts a JEPAConfig plus keyword overrides, prepares output folders, and delegates execution to the JEPA training entrypoint.
- Parameters:
config (JEPAConfig | None)
- world_models.training.train_jepa.main(args, resume_preempt=False)[source]#
Run JEPA training using a nested config dict or JEPAConfig instance.
This entrypoint initializes distributed context, data pipeline, masking, models, optimizers/schedulers, checkpointing, and the full epoch loop.
- world_models.models.vit.get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)[source]#
Generate fixed 2D sine/cosine positional embeddings on a square patch grid.
Returns NumPy embeddings used to initialize non-trainable transformer position encodings, with optional prepended class-token embedding.
- world_models.models.vit.get_2d_sincos_pos_embed_from_grid(embed_dim, grid)[source]#
Build 2D sine/cosine embeddings from precomputed meshgrid coordinates.
The final embedding concatenates independent encodings for vertical and horizontal coordinates.
- world_models.models.vit.get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)[source]#
Generate 1D sine/cosine positional embeddings for integer positions.
Useful for sequence-style positional encoding and as a building block for 2D embedding construction.
- world_models.models.vit.get_1d_sincos_pos_embed_from_grid(embed_dim, pos)[source]#
Generate 1D sine/cosine positional embeddings from explicit positions.
Positions are projected onto a log-frequency basis and encoded with sine and cosine components.
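The construction can be sketched with NumPy; this is an illustrative re-derivation of the idea, not the module's exact code:

```python
import numpy as np

def sincos_embed_1d(embed_dim, pos):
    # Log-spaced frequency basis; half the channels take the sine,
    # half the cosine, of position * frequency (embed_dim must be even).
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2) / (embed_dim / 2.0))
    angles = np.einsum('p,d->pd', pos.astype(np.float64), omega)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

emb = sincos_embed_1d(8, np.arange(4))  # 4 positions, 8 channels
```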
- world_models.models.vit.drop_path(x, drop_prob=0.0, training=False)[source]#
Apply stochastic depth (DropPath) regularization to residual branches.
Randomly drops entire residual paths per sample during training and scales the surviving activations to preserve expected magnitude.
- Parameters:
drop_prob (float)
training (bool)
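A minimal sketch of stochastic depth in NumPy, assuming per-sample Bernoulli dropping (the function name is illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

def drop_path_sketch(x, drop_prob=0.0, training=False):
    # Identity outside training or when the drop probability is zero.
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample drops the entire residual branch.
    mask = rng.random((x.shape[0],) + (1,) * (x.ndim - 1)) < keep_prob
    # Scale survivors by 1/keep_prob to preserve the expected magnitude.
    return x * mask / keep_prob

x = np.ones((4, 3))
y = drop_path_sketch(x, drop_prob=0.5, training=True)
```

Each surviving row is scaled to 2.0 so the expectation over samples stays 1.0.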
- class world_models.models.vit.DropPath(drop_prob=None)[source]#
Bases: Module
Module wrapper around the functional drop_path stochastic depth utility.
- class world_models.models.vit.MLP(in_features, hidden_features=None, out_features=None, act_layer=<class 'torch.nn.modules.activation.GELU'>, drop=0.0)[source]#
Bases: Module
Two-layer feed-forward network used inside transformer blocks.
Applies linear projection, activation, dropout, and output projection in the standard Vision Transformer MLP pattern.
- class world_models.models.vit.Attention(dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0)[source]#
Bases: Module
Multi-head self-attention block for token sequences.
Computes QKV projections, scaled dot-product attention, and output projection with configurable dropout.
- class world_models.models.vit.Block(dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=<class 'torch.nn.modules.activation.GELU'>, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>)[source]#
Bases: Module
Transformer encoder block combining attention and MLP residual branches.
Each branch uses pre-normalization and optional stochastic depth.
- class world_models.models.vit.PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)[source]#
Bases: Module
Image to Patch Embedding
- class world_models.models.vit.ConvEmbed(channels, strides, img_size=224, in_chans=3, batch_norm=True)[source]#
Bases: Module
3x3 Convolution stems for ViT following ViTC models
- class world_models.models.vit.VisionTransformerPredictor(num_patches, embed_dim=768, predictor_embed_dim=384, depth=6, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, init_std=0.02, **kwargs)[source]#
Bases: Module
Vision Transformer predictor.
- class world_models.models.vit.VisionTransformer(img_size=[224], patch_size=16, in_chans=3, embed_dim=768, predictor_embed_dim=384, depth=12, predictor_depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, init_std=0.02, **kwargs)[source]#
Bases: Module
Vision Transformer
- world_models.models.vit.vit_predictor(**kwargs)[source]#
Factory for a JEPA predictor transformer with sensible defaults.
- world_models.models.vit.vit_tiny(patch_size=16, **kwargs)[source]#
Factory for a tiny Vision Transformer encoder backbone.
- world_models.models.vit.vit_small(patch_size=16, **kwargs)[source]#
Factory for a small Vision Transformer encoder backbone.
- world_models.models.vit.vit_base(patch_size=16, **kwargs)[source]#
Factory for a base Vision Transformer encoder backbone.
- world_models.models.vit.vit_large(patch_size=16, **kwargs)[source]#
Factory for a large Vision Transformer encoder backbone.
- world_models.models.vit.vit_huge(patch_size=16, **kwargs)[source]#
Factory for a huge Vision Transformer encoder backbone.
- world_models.models.vit.vit_giant(patch_size=16, **kwargs)[source]#
Factory for a giant Vision Transformer encoder backbone.
- class world_models.masks.multiblock.MaskCollator(input_size=(224, 224), patch_size=16, enc_mask_scale=(0.2, 0.8), pred_mask_scale=(0.5, 1.0), aspect_ratio=(0.3, 3.0), nenc=1, npred=2, min_keep=4, allow_overlap=False)[source]#
Bases: object
Generate multi-block encoder and predictor masks for JEPA training.
For each sample, this collator samples predictor target blocks and context encoder blocks (optionally non-overlapping), then returns masked patch indices aligned across the batch.
- class world_models.masks.random.MaskCollator(ratio=(0.4, 0.6), input_size=(224, 224), patch_size=16)[source]#
Bases: object
Generate random context/prediction patch splits for masked training.
A random permutation of patch indices is sampled per image; a configurable fraction is assigned to context and the remainder to prediction targets.
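The split can be sketched as follows; the function name and the 50% ratio are illustrative, not the collator's exact code:

```python
import numpy as np

def random_context_pred_split(num_patches, pred_ratio, rng):
    # Shuffle all patch indices, assign a fraction to prediction
    # targets and the remainder to the context encoder.
    perm = rng.permutation(num_patches)
    n_pred = int(num_patches * pred_ratio)
    return perm[n_pred:], perm[:n_pred]  # (context, prediction)

# 196 patches = a 224x224 image with 16x16 patches.
ctx, pred = random_context_pred_split(196, 0.5, np.random.default_rng(0))
```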
Diffusion#
- class world_models.models.diffusion.DDPM.DDPM(timesteps, beta_start, beta_end, device)[source]#
Bases: object
Utility class implementing forward and reverse DDPM diffusion steps.
Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).
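The forward-noising step can be sketched with NumPy; the schedule values are illustrative stand-ins for the class's precomputed buffers:

```python
import numpy as np

# Linear beta schedule and cumulative alpha products (illustrative values).
timesteps, beta_start, beta_end = 1000, 1e-4, 0.02
betas = np.linspace(beta_start, beta_end, timesteps)
alpha_bar = np.cumprod(1.0 - betas)

def q_sample(x0, t, noise):
    # Closed-form forward process:
    # x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * noise

x0 = np.ones((8, 8))
x_t = q_sample(x0, 500, np.zeros((8, 8)))  # noiseless case shrinks toward 0
```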
- world_models.models.diffusion.DiT.sinusoidal_time_embedding(timesteps, dim)[source]#
Create sinusoidal timestep embeddings for diffusion conditioning.
Embeddings are scaled relative to configured diffusion timesteps and are consumed by the DiT conditioning MLP.
- class world_models.models.diffusion.DiT.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]#
Bases: Module
Patchify an image into a sequence of learnable patch tokens.
A strided convolution performs patch extraction and projects each patch to the transformer hidden dimension with additive positional embeddings.
- class world_models.models.diffusion.DiT.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]#
Bases: Module
Reconstruct image-like tensors from patch-token sequences.
The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.
- class world_models.models.diffusion.DiT.TransformerBlock(d_model, n_heads, mlp_ratio, drop, t_dim)[source]#
Bases: Module
Conditioned transformer block used inside the DiT backbone.
Each block applies adaptive layer-normalized self-attention and MLP residual updates conditioned on timestep embeddings.
- class world_models.models.diffusion.DiT.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]#
Bases: Module
Diffusion Transformer model for image denoising and generation.
The module maps noisy images and timesteps to predicted noise residuals and also provides a classmethod training entrypoint for common datasets.
- classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]#
Datasets and Transforms#
- world_models.datasets.cifar10.make_cifar10(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, drop_last=True, train=True, download=False)[source]#
Create CIFAR-10 dataset objects and a distributed-capable dataloader.
Returns the dataset, sampler, and loader configured with the provided transform/collator so callers can plug the loader directly into JEPA or diffusion training loops.
- world_models.datasets.imagenet1k.make_imagenet1k(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, training=True, copy_data=False, drop_last=True, subset_file=None)[source]#
Build an ImageNet-1K dataset and dataloader with distributed sampling support.
This helper optionally restricts data to a subset file and returns the (dataset, dataloader, sampler) tuple used by training scripts.
- class world_models.datasets.imagenet1k.ImageNet(root, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', transform=None, train=True, job_id=None, local_rank=None, copy_data=True, index_targets=False)[source]#
Bases: ImageFolder
ImageNet dataset wrapper with optional local copy/extract workflow.
The class extends torchvision.datasets.ImageFolder and can stage data from shared storage into local scratch space for faster multi-process training on cluster environments.
- class world_models.datasets.imagenet1k.ImageNetSubset(dataset, subset_file)[source]#
Bases: object
View over an ImageNet dataset filtered by an explicit image-id list.
The subset file contains target image names; only matching samples are kept while preserving transforms and label mapping from the base dataset.
- property classes#
- world_models.datasets.imagenet1k.copy_imgnt_locally(root, suffix, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', job_id=None, local_rank=None)[source]#
Copy and extract ImageNet archives to per-job local scratch storage.
In SLURM environments this reduces network filesystem pressure by unpacking once per job and synchronizing worker processes with a signal file.
- world_models.datasets.imagenet1k.make_imagefolder(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, drop_last=True, val_split=None)[source]#
Create an ImageFolder dataset loader for custom folder-structured datasets.
Supports optional train/validation split and distributed sampling, making it a drop-in replacement for ImageNet loaders in training scripts.
- Parameters:
val_split (float | None)
- world_models.transforms.transforms.make_transforms(crop_size=224, crop_scale=(0.3, 1.0), color_jitter=1.0, horizontal_flip=False, color_distortion=False, gaussian_blur=False, normalization=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))[source]#
Compose image augmentations and normalization for vision model training.
Supports random crops, optional flip/color distortion/blur, and returns a torchvision.transforms.Compose pipeline.
Memory and Controllers#
- class world_models.memory.dreamer_memory.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]#
Bases: object
Fixed-size replay buffer for Dreamer image observations and transitions.
Stores (observation image, action, reward, terminal) tuples and supports sampling contiguous sequences used by world-model unroll training.
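The contiguous-sequence sampling can be sketched as index arithmetic; the helper name is illustrative and ignores buffer wrap-around:

```python
import numpy as np

def sample_sequence_indices(num_stored, seq_len, batch_size, rng):
    # Draw random start offsets so every window of seq_len consecutive
    # transitions lies entirely inside the stored data.
    starts = rng.integers(0, num_stored - seq_len + 1, size=batch_size)
    return starts[:, None] + np.arange(seq_len)  # (batch_size, seq_len)

idx = sample_sequence_indices(1000, 50, 32, np.random.default_rng(0))
```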
- class world_models.memory.planet_memory.Episode(postprocess_fn=None)[source]#
Bases: object
Records the agent's interaction with the environment for a single episode. At termination, it converts all the data to NumPy arrays.
- property size#
- class world_models.memory.planet_memory.Memory(size=None)[source]#
Bases: deque
Episode-based replay memory for PlaNet/RSSM training.
Episodes are stored as variable-length trajectories and sampled as sub-sequences with optional time-major formatting for sequence models.
- property size#
- class world_models.controller.rssm_policy.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]#
Bases: object
Model-predictive controller that plans actions with the RSSM latent model.
The policy uses a Cross-Entropy Method style loop: it samples candidate action sequences, rolls them forward in latent space, scores predicted returns, and refits a Gaussian proposal to top-performing candidates.
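The CEM loop can be sketched as below. This is a generic illustration where a toy `score_fn` replaces the latent rollout; all names and defaults are assumptions, not the class's API:

```python
import numpy as np

def cem_plan(score_fn, horizon, action_dim, num_candidates=100,
             num_iterations=6, top_k=10, rng=None):
    # score_fn stands in for the predicted return of an imagined rollout.
    if rng is None:
        rng = np.random.default_rng(0)
    mean = np.zeros((horizon, action_dim))
    std = np.ones((horizon, action_dim))
    for _ in range(num_iterations):
        # Sample candidate action sequences from the current proposal.
        noise = rng.standard_normal((num_candidates, horizon, action_dim))
        candidates = mean + std * noise
        scores = np.array([score_fn(c) for c in candidates])
        # Refit the Gaussian proposal to the top-scoring candidates.
        elites = candidates[np.argsort(scores)[-top_k:]]
        mean, std = elites.mean(axis=0), elites.std(axis=0)
    return mean[0]  # execute only the first planned action

# Toy objective: prefer action 0.5 at every step.
best = cem_plan(lambda seq: -np.sum((seq - 0.5) ** 2), horizon=5, action_dim=1)
```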
Utilities#
- world_models.utils.dreamer_utils.get_parameters(modules)[source]#
Given an iterable of torch modules, return a list of their parameters.
- Parameters:
modules (Iterable[Module])
- class world_models.utils.dreamer_utils.FreezeParameters(modules)[source]#
Bases: object
Context manager that temporarily disables gradients for given modules.
Useful during imagination or target-network forward passes where gradients through certain components should be blocked for speed and correctness.
- Parameters:
modules (Iterable[Module])
- class world_models.utils.dreamer_utils.Logger(log_dir, n_logged_samples=10, summary_writer=None)[source]#
Bases: object
Experiment logger for scalars and GIF rollouts using TensorBoardX.
Provides helpers to write scalar metrics, dump pickle snapshots, and save video previews during Dreamer training/evaluation.
- world_models.utils.dreamer_utils.compute_return(rewards, values, discounts, td_lam, last_value)[source]#
Compute TD(lambda) returns from imagined rewards, values, and discounts.
Implements backward recursion used by Dreamer actor/value objectives.
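The recursion can be sketched as follows; the signature mirrors the documented arguments but the body is an illustrative re-derivation:

```python
import numpy as np

def lambda_returns(rewards, values, discounts, td_lam, last_value):
    # values[t] estimates V(s_t); last_value bootstraps beyond the horizon.
    next_values = np.append(values[1:], last_value)
    # One-step targets, mixed backwards with bootstrapped lambda-returns:
    # G_t = r_t + gamma_t * ((1 - lam) * V(s_{t+1}) + lam * G_{t+1})
    targets = rewards + discounts * next_values * (1.0 - td_lam)
    out = np.zeros_like(rewards)
    ret = last_value
    for t in reversed(range(len(rewards))):
        ret = targets[t] + discounts[t] * td_lam * ret
        out[t] = ret
    return out

# With td_lam = 1 this reduces to the discounted Monte Carlo return.
g = lambda_returns(np.ones(3), np.zeros(3), np.full(3, 0.9), 1.0, 0.0)
```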
- world_models.utils.jepa_utils.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2, b=2.0)[source]#
Initialize a tensor in-place from a truncated normal distribution.
Values are sampled from N(mean, std) and clipped to [a, b].
- world_models.utils.jepa_utils.repeat_interleave_batch(x, B, repeat)[source]#
Repeat each batch chunk multiple times while preserving chunk ordering.
Used in JEPA masking code to align context and target token batches.
- class world_models.utils.jepa_utils.WarmupCosineSchedule(optimizer, warmup_steps, start_lr, ref_lr, T_max, last_epoch=-1, final_lr=0.0)[source]#
Bases: object
Learning-rate schedule with linear warmup followed by cosine decay.
Updates optimizer parameter-group LRs on each call to step().
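The schedule shape can be sketched as a pure function of the step count; this is an illustrative formula, not the class's exact code:

```python
import math

def warmup_cosine_lr(step, warmup_steps, start_lr, ref_lr, final_lr, t_max):
    # Linear ramp from start_lr to ref_lr during warmup...
    if step < warmup_steps:
        return start_lr + (ref_lr - start_lr) * step / warmup_steps
    # ...then cosine decay from ref_lr down to final_lr.
    progress = (step - warmup_steps) / max(1, t_max - warmup_steps)
    return final_lr + 0.5 * (ref_lr - final_lr) * (1.0 + math.cos(math.pi * progress))

lrs = [warmup_cosine_lr(s, 100, 0.0, 1e-3, 0.0, 1000) for s in range(1000)]
```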
- class world_models.utils.jepa_utils.CosineWDSchedule(optimizer, ref_wd, T_max, final_wd=0.0)[source]#
Bases: object
Cosine scheduler for optimizer weight decay values.
Skips parameter groups flagged with WD_exclude to keep bias/norm decay at zero.
- world_models.utils.jepa_utils.gpu_timer(closure, log_timings=True)[source]#
Measure CUDA execution time for a closure and return (result, elapsed_ms).
Falls back to -1 elapsed time when CUDA timing is unavailable.
- class world_models.utils.jepa_utils.CSVLogger(fname, *argv)[source]#
Bases: object
Lightweight CSV logger with per-column printf-style formatting.
- class world_models.utils.jepa_utils.AverageMeter[source]#
Bases: object
Track running statistics (val, avg, min, max, sum, count) for metrics.
- world_models.utils.jepa_utils.grad_logger(named_params)[source]#
Aggregate gradient norm statistics over model parameters for logging.
Also exposes first/last qkv-layer gradient norms when available.
- world_models.utils.jepa_utils.init_distributed(port=40112, rank_and_world_size=(None, None))[source]#
Initialize torch distributed process groups when environment supports it.
Returns (world_size, rank) and gracefully falls back to single-process mode.
- class world_models.utils.jepa_utils.AllGather(*args, **kwargs)[source]#
Bases: Function
Autograd-aware all-gather operation across distributed workers.
Forward concatenates worker tensors; backward reduces and slices gradients.
- class world_models.utils.jepa_utils.AllReduceSum(*args, **kwargs)[source]#
Bases: Function
Autograd function that sums tensors across distributed workers in the forward pass.
- class world_models.utils.jepa_utils.AllReduce(*args, **kwargs)[source]#
Bases: Function
Autograd function that all-reduces and averages tensors across workers.
Used to synchronize scalar losses for consistent distributed logging/training.
- world_models.utils.utils.to_tensor_obs(image)[source]#
Convert an input NumPy image to a channel-first 64x64 torch image tensor.
- world_models.utils.utils.postprocess_img(image, depth)[source]#
Postprocess an image observation for storage: from a float32 NumPy array in [-0.5, 0.5] to a uint8 NumPy array in [0, 255].
- world_models.utils.utils.preprocess_img(image, depth)[source]#
Preprocess an observation in place: from a float32 tensor in [0, 255] to [-0.5, 0.5]. Also adds a small amount of noise to the observations.
- world_models.utils.utils.bottle(func, *tensors)[source]#
Evaluate a function that operates on inputs of shape (N, D) with tensors of shape (N, T, D), by flattening the batch and time dimensions and restoring them afterwards.
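The reshape trick can be sketched with NumPy; the helper name is illustrative and handles a single input tensor rather than the variadic signature:

```python
import numpy as np

def bottle_sketch(func, x):
    # Merge batch and time dims, apply func, then restore them.
    n, t = x.shape[:2]
    y = func(x.reshape(n * t, *x.shape[2:]))
    return y.reshape(n, t, *y.shape[1:])

x = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
y = bottle_sketch(lambda v: v * 2.0, x)
```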
- world_models.utils.utils.get_combined_params(*models)[source]#
Returns the combined parameter list of all the models given as input.
- world_models.utils.utils.save_video(frames, path, name)[source]#
Saves a video containing frames.
- Accepts frames in either:
(T, C, H, W) float in [0,1]
(T, H, W, C) float in [0,1]
Produces {path}/{name}.mp4 and a debug PNG {path}/{name}_debug_frame.png with per-channel statistics printed to stdout.
- world_models.utils.utils.combine_videos(video_dir, output_name='combined.mp4', pattern='vid_*.mp4', fps=25, resize=True)[source]#
Combine all videos matching pattern in video_dir into a single MP4 file. Returns the output filepath (string).
Example
combine_videos("results/planet", output_name="all_training.mp4")
- world_models.utils.utils.ensure_results_dir_exists(results_dir)[source]#
Simple helper to validate a results directory exists. Raises FileNotFoundError if not present.
- world_models.utils.utils.save_frames(target, pred_prior, pred_posterior, name, n_rows=5)[source]#
Save side-by-side target, prior-prediction, and posterior-prediction frames.
The function accepts tensors with an optional time dimension and writes a PNG grid to {name}.png. Spatial sizes are aligned per timestep before concatenation and values are normalized to [0, 1] when needed.
- world_models.utils.utils.get_mask(tensor, lengths)[source]#
Build a batch-first validity mask from sequence lengths.
tensor may be a tensor/array with shape (N, T, ...) or (N,). The returned mask marks valid timesteps with ones up to each element in lengths and preserves device/dtype conventions from the input.
- world_models.utils.utils.load_memory(path, device)[source]#
Loads an experience replay buffer (backwards-compatible with older pickle formats). Converts legacy list/.data formats into the current Memory(episodes) object.
- world_models.utils.utils.flatten_dict(data, sep='.', prefix='')[source]#
Flattens a nested dict into a single-level dict.
Example
{'a': 2, 'b': {'c': 20}} -> {'a': 2, 'b.c': 20}
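A minimal recursive implementation of this behavior (the name is illustrative, not the library's):

```python
def flatten_dict_sketch(data, sep='.', prefix=''):
    # Recurse into nested dicts, joining key paths with sep.
    out = {}
    for key, value in data.items():
        full_key = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            out.update(flatten_dict_sketch(value, sep=sep, prefix=full_key))
        else:
            out[full_key] = value
    return out

flat = flatten_dict_sketch({'a': 2, 'b': {'c': 20}})
```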
- world_models.utils.utils.normalize_frames_for_saving(frames)[source]#
Ensure frames are in shape (T, H, W, 3) with float values in [0,1]. Handles inputs in (T, C, H, W) or (T, H, W, C), repeats single-channel -> RGB, drops alpha if present, and maps [-0.5,0.5] -> [0,1] when detected.
- class world_models.utils.utils.TensorBoardMetrics(path)[source]#
Bases: object
Plots and (optionally) stores metrics for an experiment.
- world_models.utils.utils.apply_model(model, inputs, ignore_dim=None)[source]#
Placeholder helper for generic model application across input structures.
Currently not implemented; kept as an extension hook for future utility code.
- world_models.utils.utils.plot_metrics(metrics, path, prefix)[source]#
Render and save line plots for each metric series in a dictionary.
- world_models.utils.utils.lineplot(xs, ys, title, path='', xaxis='episode')[source]#
Create a Plotly line plot for scalar, dict, or ensemble-series data.
Supports uncertainty-band plotting when ys is a 2D array.
- class world_models.utils.utils.TorchImageEnvWrapper(env, bit_depth, observation_shape=None, act_rep=2)[source]#
Bases: object
Torch environment wrapper around a Gym env that makes interactions using tensors and returns observations in image form.
- property observation_size#
- property action_size#
- property max_episode_steps#
Return environment max episode steps (compatible with TimeLimit/spec).