API Reference
This reference is generated from source docstrings and grouped by subsystem.
Core Public APIs
- class world_models.configs.DreamerConfig[source]
Bases: object
Configuration container for Dreamer training, evaluation, and environment setup.
This class centralizes environment backend selection (DMC/Gym/Unity), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.
- class world_models.configs.JEPAConfig[source]
Bases: object
Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.
- class world_models.configs.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]
Bases: object
Default configuration values for Diffusion Transformer (DiT) training.
The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.
- Parameters:
DATASET (str)
BATCH (int)
EPOCHS (int)
LR (float)
IMG_SIZE (int)
CHANNELS (int)
PATCH (int)
WIDTH (int)
DEPTH (int)
HEADS (int)
DROP (float)
BETA_START (float)
BETA_END (float)
TIMESTEPS (int)
EMA (bool)
EMA_DECAY (float)
WORKDIR (str)
ROOT_PATH (str)
- DATASET: str = 'CIFAR10'
- BATCH: int = 128
- EPOCHS: int = 3
- LR: float = 0.0002
- IMG_SIZE: int = 32
- CHANNELS: int = 3
- PATCH: int = 4
- WIDTH: int = 384
- DEPTH: int = 6
- HEADS: int = 6
- DROP: float = 0.1
- BETA_START: float = 0.0001
- BETA_END: float = 0.02
- TIMESTEPS: int = 1000
- EMA: bool = True
- EMA_DECAY: float = 0.999
- WORKDIR: str = './dit_demo'
- ROOT_PATH: str = './data'
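A minimal construction sketch, assuming the config accepts keyword overrides as implied by its signature; the override values below are illustrative only, and unspecified fields keep the defaults listed above:
>>> from world_models.configs import DiTConfig
>>> cfg = DiTConfig(DATASET='CIFAR10', BATCH=64, EPOCHS=5, WORKDIR='./dit_demo')
>>> cfg.TIMESTEPS
1000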
Dreamer
- class world_models.models.dreamer_rssm.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)[source]
Bases: Module
Recurrent State-Space Model used by Dreamer for latent dynamics learning.
The RSSM maintains deterministic recurrent state and stochastic latent state, and provides transition/posterior updates plus rollout helpers.
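A construction sketch following the signature above; the sizes are illustrative, and passing torch.nn.ELU for activation is an assumption about the expected argument type:
>>> import torch.nn as nn
>>> from world_models.models.dreamer_rssm import RSSM
>>> rssm = RSSM(action_size=6, stoch_size=30, deter_size=200,
...             hidden_size=200, obs_embed_size=1024, activation=nn.ELU)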
- class world_models.vision.dreamer_encoder.ConvEncoder(input_shape, embed_size, activation, depth=32)[source]
Bases: Module
Convolutional observation encoder used by Dreamer world models.
Encodes image observations into compact embeddings consumed by the RSSM posterior update network.
- class world_models.vision.dreamer_decoder.TanhBijector[source]
Bases: Transform
Bijective tanh transform used to squash Gaussian actions to [-1, 1].
Provides stable inverse and log-determinant Jacobian computations for transformed-action distributions.
- property sign
- class world_models.vision.dreamer_decoder.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)[source]
Bases: Module
Convolutional Dreamer decoder from latent features to image distributions.
Maps concatenated stochastic/deterministic states into transposed-conv outputs and returns a factorized Normal distribution over pixels.
- class world_models.vision.dreamer_decoder.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist)[source]
Bases: Module
MLP decoder for reward/value/discount prediction from latent features.
The output distribution type is configurable (normal, binary, or raw tensor).
- class world_models.vision.dreamer_decoder.SampleDist(dist, samples=100)[source]
Bases: object
Distribution wrapper that estimates statistics via Monte Carlo sampling.
Provides approximated mean, mode, and entropy helpers for transformed distributions where analytic forms may be inconvenient.
- property name
- class world_models.vision.dreamer_decoder.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)[source]
Bases: Module
Dreamer actor head producing squashed continuous actions from latent features.
Outputs a transformed Gaussian policy with optional deterministic mode and utility for additive exploration noise.
JEPA and ViT
- class world_models.masks.multiblock.MaskCollator(input_size=(224, 224), patch_size=16, enc_mask_scale=(0.2, 0.8), pred_mask_scale=(0.5, 1.0), aspect_ratio=(0.3, 3.0), nenc=1, npred=2, min_keep=4, allow_overlap=False)[source]
Bases: object
Generate multi-block encoder and predictor masks for JEPA training.
For each sample, this collator samples predictor target blocks and context encoder blocks (optionally non-overlapping), then returns masked patch indices aligned across the batch.
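A sketch of wiring the collator into a standard DataLoader as its collate_fn (the route implied by the collator= arguments of the dataset helpers below); the toy tensor dataset is only a stand-in for a real image dataset:
>>> import torch
>>> from torch.utils.data import DataLoader, TensorDataset
>>> from world_models.masks.multiblock import MaskCollator
>>> collator = MaskCollator(input_size=(224, 224), patch_size=16, nenc=1, npred=2)
>>> images = torch.randn(8, 3, 224, 224)   # stand-in image tensors
>>> loader = DataLoader(TensorDataset(images), batch_size=4, collate_fn=collator)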
- class world_models.masks.random.MaskCollator(ratio=(0.4, 0.6), input_size=(224, 224), patch_size=16)[source]
Bases: object
Generate random context/prediction patch splits for masked training.
A random permutation of patch indices is sampled per image; a configurable fraction is assigned to context and the remainder to prediction targets.
Diffusion
- class world_models.models.diffusion.DDPM.DDPM(timesteps, beta_start, beta_end, device)[source]
Bases: object
Utility class implementing forward and reverse DDPM diffusion steps.
Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).
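A usage sketch; the constructor arguments follow the signature above, while the q_sample argument order (clean images, then integer timesteps) is an assumption that should be verified against the source:
>>> import torch
>>> from world_models.models.diffusion.DDPM import DDPM
>>> ddpm = DDPM(timesteps=1000, beta_start=1e-4, beta_end=0.02, device='cpu')
>>> x0 = torch.randn(4, 3, 32, 32)        # clean training images
>>> t = torch.randint(0, 1000, (4,))      # per-sample diffusion timesteps
>>> x_t = ddpm.q_sample(x0, t)            # assumed argument order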
- world_models.models.diffusion.DiT.sinusoidal_time_embedding(timesteps, dim)[source]
Create sinusoidal timestep embeddings for diffusion conditioning.
Embeddings are scaled relative to configured diffusion timesteps and are consumed by the DiT conditioning MLP.
- class world_models.models.diffusion.DiT.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]
Bases: Module
Patchify an image into a sequence of learnable patch tokens.
A strided convolution performs patch extraction and projects each patch to the transformer hidden dimension with additive positional embeddings.
- class world_models.models.diffusion.DiT.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]
Bases: Module
Reconstruct image-like tensors from patch-token sequences.
The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.
- class world_models.models.diffusion.DiT.TransformerBlock(d_model, n_heads, mlp_ratio, drop, t_dim)[source]
Bases: Module
Conditioned transformer block used inside the DiT backbone.
Each block applies adaptive layer-normalized self-attention and MLP residual updates conditioned on timestep embeddings.
- class world_models.models.diffusion.DiT.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]
Bases: Module
Diffusion Transformer model for image denoising and generation.
The module maps noisy images and timesteps to predicted noise residuals and also provides a classmethod training entrypoint for common datasets.
- classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]
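A sketch of calling the built-in training entrypoint; it assumes dataset accepts the same string identifiers as DiTConfig.DATASET, and the small values are chosen only to keep the run short:
>>> from world_models.models.diffusion.DiT import DiT
>>> DiT.train(epochs=1, dataset='CIFAR10', batch_size=64,
...           img_size=32, patch=4, workdir='./dit_demo', root_path='./data')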
Datasets and Transforms
- world_models.datasets.cifar10.make_cifar10(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, drop_last=True, train=True, download=False)[source]
Create CIFAR-10 dataset objects and a distributed-capable dataloader.
Returns the dataset, sampler, and loader configured with the provided transform/collator so callers can plug the loader directly into JEPA or diffusion training loops.
- world_models.datasets.imagenet1k.make_imagenet1k(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, training=True, copy_data=False, drop_last=True, subset_file=None)[source]
Build an ImageNet-1K dataset and dataloader with distributed sampling support.
This helper optionally restricts data to a subset file and returns the (dataset, dataloader, sampler) tuple used by training scripts.
- class world_models.datasets.imagenet1k.ImageNet(root, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', transform=None, train=True, job_id=None, local_rank=None, copy_data=True, index_targets=False)[source]
Bases: ImageFolder
ImageNet dataset wrapper with optional local copy/extract workflow.
The class extends torchvision.datasets.ImageFolder and can stage data from shared storage into local scratch space for faster multi-process training on cluster environments.
- class world_models.datasets.imagenet1k.ImageNetSubset(dataset, subset_file)[source]
Bases: object
View over an ImageNet dataset filtered by an explicit image-id list.
The subset file contains target image names; only matching samples are kept while preserving transforms and label mapping from the base dataset.
- property classes
- world_models.datasets.imagenet1k.copy_imgnt_locally(root, suffix, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', job_id=None, local_rank=None)[source]
Copy and extract ImageNet archives to per-job local scratch storage.
In SLURM environments this reduces network filesystem pressure by unpacking once per job and synchronizing worker processes with a signal file.
- world_models.datasets.imagenet1k.make_imagefolder(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, drop_last=True, val_split=None)[source]
Create an ImageFolder dataset loader for custom folder-structured datasets.
Supports optional train/validation split and distributed sampling, making it a drop-in replacement for ImageNet loaders in training scripts.
- Parameters:
val_split (float | None)
- world_models.transforms.transforms.make_transforms(crop_size=224, crop_scale=(0.3, 1.0), color_jitter=1.0, horizontal_flip=False, color_distortion=False, gaussian_blur=False, normalization=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))[source]
Compose image augmentations and normalization for vision model training.
Supports random crops, optional flip/color distortion/blur, and returns a torchvision.transforms.Compose pipeline.
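A sketch combining make_transforms with the CIFAR-10 helper above; the (dataset, sampler, loader) unpacking order follows the wording of the make_cifar10 docstring and should be verified against the source:
>>> from world_models.transforms.transforms import make_transforms
>>> from world_models.datasets.cifar10 import make_cifar10
>>> transform = make_transforms(crop_size=32, crop_scale=(0.3, 1.0), horizontal_flip=True)
>>> dataset, sampler, loader = make_cifar10(
...     transform=transform, batch_size=128, num_workers=2,
...     root_path='./data', train=True, download=True)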
Memory and Controllers
- class world_models.memory.dreamer_memory.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]
Bases: object
Fixed-size replay buffer for Dreamer image observations and transitions.
Stores (observation image, action, reward, terminal) tuples and supports sampling contiguous sequences used by world-model unroll training.
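A construction sketch with illustrative sizes (a 100k-step buffer of 64x64 RGB observations, sampled as 50-step sequences); the insertion/sampling method names are not reproduced here and should be taken from the class itself:
>>> from world_models.memory.dreamer_memory import ReplayBuffer
>>> buffer = ReplayBuffer(size=100000, obs_shape=(3, 64, 64),
...                       action_size=6, seq_len=50, batch_size=50)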
- class world_models.memory.planet_memory.Episode(postprocess_fn=None)[source]
Bases: object
Records the agent's interaction with the environment for a single episode; at termination, all collected data is converted to NumPy arrays.
- property size
- class world_models.memory.planet_memory.Memory(size=None)[source]
Bases: deque
Episode-based replay memory for PlaNet/RSSM training.
Episodes are stored as variable-length trajectories and sampled as sub-sequences with optional time-major formatting for sequence models.
- property size
RSSM-based policy for model-predictive control.
This module provides the RSSMPolicy class that implements model-predictive control using the RSSM (Recurrent State Space Model) latent dynamics model. The policy uses a Cross-Entropy Method (CEM) for planning actions in latent space.
- Reference:
Ha & Schmidhuber (2018). Recurrent World Models Facilitate Policy Evolution. https://arxiv.org/abs/1809.01999
- class world_models.controller.rssm_policy.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]
Bases: object
Model-predictive controller that plans actions with the RSSM latent model.
The policy uses a Cross-Entropy Method style loop: it samples candidate action sequences, rolls them forward in latent space, scores predicted returns, and refits a Gaussian proposal to top-performing candidates.
- Parameters:
planning_horizon (int)
num_candidates (int)
num_iterations (int)
top_candidates (int)
device (str)
- rssm
The RSSM world model.
- N
Number of candidate action sequences to sample.
- K
Number of top candidates to use for updating the proposal.
- T
Number of CEM iterations per planning step.
- H
Planning horizon (number of future steps to consider).
- d
Action dimensionality.
- device
Device to run computations on.
- state_size
Hidden state dimensionality.
- latent_size
Latent state dimensionality.
Example
>>> policy = RSSMPolicy(
...     model=rssm,
...     planning_horizon=12,
...     num_candidates=1000,
...     num_iterations=5,
...     top_candidates=100,
...     device='cuda'
... )
>>> policy.reset()
>>> action = policy.poll(observation)
- reset()[source]
Reset the policy state.
Initializes the hidden state, latent state, and action to zeros. Should be called at the beginning of each episode.
- poll(observation, explore=False)[source]
Get action for given observation.
- Parameters:
observation (Tensor) – Current observation tensor of shape (channels, height, width).
explore (bool) – If True, add exploration noise to the selected action.
- Returns:
Action tensor of shape (1, action_size).
- Return type:
Tensor
Utilities
- world_models.utils.dreamer_utils.get_parameters(modules)[source]
Given an iterable of torch modules, return a combined list of their parameters.
- Parameters:
modules (Iterable[Module])
- class world_models.utils.dreamer_utils.FreezeParameters(modules)[source]
Bases: object
Context manager that temporarily disables gradients for given modules.
Useful during imagination or target-network forward passes where gradients through certain components should be blocked for speed and correctness.
- Parameters:
modules (Iterable[Module])
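A sketch of the typical pattern for these two utilities, using small stand-in modules in place of real Dreamer components:
>>> import torch.nn as nn
>>> from world_models.utils.dreamer_utils import FreezeParameters, get_parameters
>>> world_model_modules = [nn.Linear(8, 8), nn.Linear(8, 1)]   # stand-ins for RSSM/decoder modules
>>> with FreezeParameters(world_model_modules):
...     pass   # forward passes in here do not accumulate gradients for the frozen modules
>>> params = get_parameters(world_model_modules)               # combined parameter list for an optimizer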
- class world_models.utils.dreamer_utils.Logger(log_dir, n_logged_samples=10, summary_writer=None)[source]
Bases: object
Experiment logger for scalars and GIF rollouts using TensorBoardX.
Provides helpers to write scalar metrics, dump pickle snapshots, and save video previews during Dreamer training/evaluation.
- world_models.utils.dreamer_utils.compute_return(rewards, values, discounts, td_lam, last_value)[source]
Compute TD(lambda) returns from imagined rewards, values, and discounts.
Implements backward recursion used by Dreamer actor/value objectives.
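For reference, the TD(lambda) recursion usually meant here (following the Dreamer paper; stated as an assumption about this implementation) runs backward over the imagined horizon and is initialized from last_value:
V^{\lambda}_t = r_t + \gamma_t \left[ (1 - \lambda)\, v_{t+1} + \lambda\, V^{\lambda}_{t+1} \right]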
- world_models.utils.jepa_utils.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2, b=2.0)[source]
Initialize a tensor in-place from a truncated normal distribution.
Values are sampled from N(mean, std) and clipped to [a, b].
- world_models.utils.jepa_utils.repeat_interleave_batch(x, B, repeat)[source]
Repeat each batch chunk multiple times while preserving chunk ordering.
Used in JEPA masking code to align context and target token batches.
- class world_models.utils.jepa_utils.WarmupCosineSchedule(optimizer, warmup_steps, start_lr, ref_lr, T_max, last_epoch=-1, final_lr=0.0)[source]
Bases: object
Learning-rate schedule with linear warmup followed by cosine decay.
Updates optimizer parameter-group LRs on each call to step().
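A sketch following the constructor signature above, with illustrative step counts and learning rates:
>>> import torch
>>> from world_models.utils.jepa_utils import WarmupCosineSchedule
>>> model = torch.nn.Linear(10, 10)
>>> optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
>>> scheduler = WarmupCosineSchedule(optimizer, warmup_steps=1000, start_lr=1e-6,
...                                  ref_lr=1e-3, T_max=100000, final_lr=1e-6)
>>> for _ in range(3):
...     scheduler.step()   # adjusts each parameter group's lr in place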
- class world_models.utils.jepa_utils.CosineWDSchedule(optimizer, ref_wd, T_max, final_wd=0.0)[source]
Bases: object
Cosine scheduler for optimizer weight-decay values.
Skips parameter groups flagged with WD_exclude to keep bias/norm decay at zero.
- world_models.utils.jepa_utils.gpu_timer(closure, log_timings=True)[source]
Measure CUDA execution time for a closure and return (result, elapsed_ms).
Falls back to -1 elapsed time when CUDA timing is unavailable.
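A minimal usage sketch based on the return contract described above:
>>> import torch
>>> from world_models.utils.jepa_utils import gpu_timer
>>> x = torch.randn(256, 256)
>>> result, elapsed_ms = gpu_timer(lambda: x @ x)   # elapsed_ms is -1 when CUDA timing is unavailable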
- class world_models.utils.jepa_utils.CSVLogger(fname, *argv)[source]
Bases: object
Lightweight CSV logger with per-column printf-style formatting.
- class world_models.utils.jepa_utils.AverageMeter[source]
Bases: object
Track running statistics (val, avg, min, max, sum, count) for metrics.
- world_models.utils.jepa_utils.grad_logger(named_params)[source]
Aggregate gradient norm statistics over model parameters for logging.
Also exposes first/last qkv-layer gradient norms when available.
- world_models.utils.jepa_utils.init_distributed(port=40112, rank_and_world_size=(None, None))[source]
Initialize torch distributed process groups when environment supports it.
Returns (world_size, rank) and gracefully falls back to single-process mode.
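A minimal call sketch based on the documented return value:
>>> from world_models.utils.jepa_utils import init_distributed
>>> world_size, rank = init_distributed(port=40112)   # falls back to single-process values when no distributed environment is configured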
- class world_models.utils.jepa_utils.AllGather(*args, **kwargs)[source]
Bases: Function
Autograd-aware all-gather operation across distributed workers.
Forward concatenates worker tensors; backward reduces and slices gradients.
- class world_models.utils.jepa_utils.AllReduceSum(*args, **kwargs)[source]
Bases: Function
Autograd function that sums tensors across distributed workers in the forward pass.