API Reference

This reference is generated from source docstrings and grouped by subsystem.

Core Public APIs

class world_models.configs.DreamerConfig[source]

Bases: object

Configuration container for Dreamer training, evaluation, and environment setup.

This class centralizes environment backend selection (DMC/Gym/Unity), model dimensions, replay and optimization settings, logging cadence, and checkpoint options consumed by DreamerAgent.

class world_models.configs.JEPAConfig[source]

Bases: object

Minimal configuration container for JEPA training. Converts to the nested dict expected by train_jepa.main.

to_dict()[source]
Return type:

Dict[str, Dict[str, Any]]

class world_models.configs.DiTConfig(DATASET='CIFAR10', BATCH=128, EPOCHS=3, LR=0.0002, IMG_SIZE=32, CHANNELS=3, PATCH=4, WIDTH=384, DEPTH=6, HEADS=6, DROP=0.1, BETA_START=0.0001, BETA_END=0.02, TIMESTEPS=1000, EMA=True, EMA_DECAY=0.999, WORKDIR='./dit_demo', ROOT_PATH='./data')[source]

Bases: object

Default configuration values for Diffusion Transformer (DiT) training.

The fields define dataset selection, model architecture, diffusion schedule, optimization hyperparameters, and output paths used by the built-in training entrypoints.

Parameters:
  • DATASET (str)

  • BATCH (int)

  • EPOCHS (int)

  • LR (float)

  • IMG_SIZE (int)

  • CHANNELS (int)

  • PATCH (int)

  • WIDTH (int)

  • DEPTH (int)

  • HEADS (int)

  • DROP (float)

  • BETA_START (float)

  • BETA_END (float)

  • TIMESTEPS (int)

  • EMA (bool)

  • EMA_DECAY (float)

  • WORKDIR (str)

  • ROOT_PATH (str)

DATASET: str = 'CIFAR10'
BATCH: int = 128
EPOCHS: int = 3
LR: float = 0.0002
IMG_SIZE: int = 32
CHANNELS: int = 3
PATCH: int = 4
WIDTH: int = 384
DEPTH: int = 6
HEADS: int = 6
DROP: float = 0.1
BETA_START: float = 0.0001
BETA_END: float = 0.02
TIMESTEPS: int = 1000
EMA: bool = True
EMA_DECAY: float = 0.999
WORKDIR: str = './dit_demo'
ROOT_PATH: str = './data'
world_models.configs.get_dit_config(**overrides)[source]

Returns a DiTConfig instance with default values overridden by the provided keyword arguments.

Example usage:

cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)

Dreamer

class world_models.models.dreamer_rssm.RSSM(action_size, stoch_size, deter_size, hidden_size, obs_embed_size, activation)[source]

Bases: Module

Recurrent State-Space Model used by Dreamer latent dynamics learning.

The RSSM maintains deterministic recurrent state and stochastic latent state, and provides transition/posterior updates plus rollout helpers.

init_state(batch_size, device)[source]
get_dist(mean, std)[source]
observe_step(prev_state, prev_action, obs_embed, nonterm=1.0)[source]
imagine_step(prev_state, prev_action, nonterm=1.0)[source]
observe_rollout(obs_embed, actions, nonterms, prev_state, horizon)[source]
imagine_rollout(actor, prev_state, horizon)[source]
stack_states(states, dim=0)[source]
detach_state(state)[source]
seq_to_batch(state)[source]
class world_models.vision.dreamer_encoder.ConvEncoder(input_shape, embed_size, activation, depth=32)[source]

Bases: Module

Convolutional observation encoder used by Dreamer world models.

Encodes image observations into compact embeddings consumed by the RSSM posterior update network.

forward(inputs)[source]
class world_models.vision.dreamer_decoder.TanhBijector[source]

Bases: Transform

Bijective tanh transform used to squash Gaussian actions to [-1, 1].

Provides stable inverse and log-determinant Jacobian computations for transformed-action distributions.

property sign
atanh(x)[source]
log_abs_det_jacobian(x, y)[source]
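The two helpers above rely on standard identities. A minimal plain-Python sketch (illustrative, not the library code):

```python
import math

def atanh(x):
    """Inverse of tanh, valid for x in (-1, 1)."""
    return 0.5 * math.log((1.0 + x) / (1.0 - x))

def log_abs_det_jacobian(x):
    """log |d tanh(x)/dx| = log(1 - tanh(x)^2), written in the numerically
    stable form 2 * (log 2 - x - softplus(-2x))."""
    return 2.0 * (math.log(2.0) - x - math.log1p(math.exp(-2.0 * x)))

y = math.tanh(0.3)
```

The stable form avoids evaluating `1 - tanh(x)**2` directly, which underflows for large |x|.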
class world_models.vision.dreamer_decoder.ConvDecoder(stoch_size, deter_size, output_shape, activation, depth=32)[source]

Bases: Module

Convolutional Dreamer decoder from latent features to image distributions.

Maps concatenated stochastic/deterministic states into transposed-conv outputs and returns a factorized Normal distribution over pixels.

forward(features)[source]
class world_models.vision.dreamer_decoder.DenseDecoder(stoch_size, deter_size, output_shape, n_layers, units, activation, dist)[source]

Bases: Module

MLP decoder for reward/value/discount prediction from latent features.

The output distribution type is configurable (normal, binary, or raw tensor).

forward(features)[source]
class world_models.vision.dreamer_decoder.SampleDist(dist, samples=100)[source]

Bases: object

Distribution wrapper that estimates statistics via Monte Carlo sampling.

Provides approximated mean, mode, and entropy helpers for transformed distributions where analytic forms may be inconvenient.

property name
mean()[source]
mode()[source]
entropy()[source]
sample()[source]
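The core idea can be sketched in a few lines of plain Python (illustrative only; the class itself wraps a torch distribution):

```python
import math
import random

def mc_mean(sample_fn, n=100):
    """Monte Carlo estimate of a distribution's mean from n samples."""
    return sum(sample_fn() for _ in range(n)) / n

random.seed(0)
# Tanh-squashed Gaussian: the analytic mean is inconvenient, but sampling is easy.
est = mc_mean(lambda: math.tanh(random.gauss(0.0, 1.0)), n=2000)
```

The true mean here is 0 by symmetry, so the estimate should land close to it.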
class world_models.vision.dreamer_decoder.ActionDecoder(action_size, stoch_size, deter_size, n_layers, units, activation, min_std=0.0001, init_std=5, mean_scale=5)[source]

Bases: Module

Dreamer actor head producing squashed continuous actions from latent features.

Outputs a transformed Gaussian policy with optional deterministic mode and utility for additive exploration noise.

forward(features, deter=False)[source]
add_exploration(action, action_noise=0.3)[source]

JEPA and ViT

class world_models.masks.multiblock.MaskCollator(input_size=(224, 224), patch_size=16, enc_mask_scale=(0.2, 0.8), pred_mask_scale=(0.5, 1.0), aspect_ratio=(0.3, 3.0), nenc=1, npred=2, min_keep=4, allow_overlap=False)[source]

Bases: object

Generate multi-block encoder and predictor masks for JEPA training.

For each sample, this collator samples predictor target blocks and context encoder blocks (optionally non-overlapping), then returns masked patch indices aligned across the batch.

step()[source]
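A plain-Python sketch of how one rectangular block of patch indices might be sampled from a grid, given scale and aspect-ratio ranges (the sampling details and defaults here are illustrative, not the library's exact logic):

```python
import math
import random

def sample_block(grid_h, grid_w, scale=(0.2, 0.8), aspect=(0.3, 3.0), rng=random):
    """Sample one rectangular block of patch indices from a grid_h x grid_w grid."""
    area = grid_h * grid_w * rng.uniform(*scale)  # target block area, in patches
    # Log-uniform aspect ratio h/w.
    ar = math.exp(rng.uniform(math.log(aspect[0]), math.log(aspect[1])))
    h = max(1, min(grid_h, int(round(math.sqrt(area * ar)))))
    w = max(1, min(grid_w, int(round(math.sqrt(area / ar)))))
    top = rng.randint(0, grid_h - h)
    left = rng.randint(0, grid_w - w)
    return [(top + i) * grid_w + (left + j) for i in range(h) for j in range(w)]

random.seed(0)
block = sample_block(14, 14)  # 224/16 = 14 patches per side
```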
class world_models.masks.random.MaskCollator(ratio=(0.4, 0.6), input_size=(224, 224), patch_size=16)[source]

Bases: object

Generate random context/prediction patch splits for masked training.

A random permutation of patch indices is sampled per image; a configurable fraction is assigned to context and the remainder to prediction targets.

step()[source]
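The context/prediction split can be sketched as a seeded permutation of patch indices (a minimal stand-in for the collator's per-image logic):

```python
import random

def random_split(num_patches, ratio=0.5, rng=random):
    """Split patch indices into context and prediction sets via a random permutation."""
    idx = list(range(num_patches))
    rng.shuffle(idx)
    n_ctx = int(num_patches * ratio)
    return sorted(idx[:n_ctx]), sorted(idx[n_ctx:])

random.seed(0)
ctx, pred = random_split(196, ratio=0.5)  # 14x14 patch grid
```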

Diffusion

class world_models.models.diffusion.DDPM.DDPM(timesteps, beta_start, beta_end, device)[source]

Bases: object

Utility class implementing forward and reverse DDPM diffusion steps.

Precomputes diffusion schedule terms and exposes helpers for noising training inputs (q_sample) and iterative denoising sampling (sample).

q_sample(x_start, t, noise=None)[source]
p_sample(model, x_t, t)[source]
sample(model, n, img_size, channels)[source]
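The forward-noising math behind q_sample can be sketched without any framework, assuming the linear beta schedule implied by the defaults above (the library's exact endpoint conventions may differ slightly):

```python
import math
import random

def make_schedule(timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule and cumulative-product alpha_bar terms."""
    betas = [beta_start + (beta_end - beta_start) * t / (timesteps - 1)
             for t in range(timesteps)]
    alpha_bar, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        alpha_bar.append(prod)
    return alpha_bar

def q_sample(x0, t, alpha_bar, noise):
    """Forward noising: x_t = sqrt(ab_t) * x0 + sqrt(1 - ab_t) * eps."""
    ab = alpha_bar[t]
    return math.sqrt(ab) * x0 + math.sqrt(1.0 - ab) * noise

alpha_bar = make_schedule()
random.seed(0)
xt = q_sample(1.0, 500, alpha_bar, random.gauss(0.0, 1.0))
```

As t grows, alpha_bar shrinks toward zero and x_t approaches pure noise.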
world_models.models.diffusion.DiT.sinusoidal_time_embedding(timesteps, dim)[source]

Create sinusoidal timestep embeddings for diffusion conditioning.

Embeddings are scaled relative to configured diffusion timesteps and are consumed by the DiT conditioning MLP.
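A minimal stand-alone version of the classic transformer-style embedding (the max_period constant and exact scaling are assumptions; dim is taken to be even):

```python
import math

def sinusoidal_time_embedding(t, dim, max_period=10000.0):
    """Sin/cos embedding of a scalar timestep into a dim-length vector."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]

emb = sinusoidal_time_embedding(0, 8)
```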

class world_models.models.diffusion.DiT.PatchEmbed(img_size, patch_size, in_channels, embed_dim)[source]

Bases: Module

Patchify an image into a sequence of learnable patch tokens.

A strided convolution performs patch extraction and projects each patch to the transformer hidden dimension with additive positional embeddings.

forward(x)[source]
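Token count follows directly from the shapes: an img_size x img_size input with patch size p yields (img_size / p)^2 tokens. A framework-free sketch of the patchify step (the module itself uses a strided convolution, which fuses extraction and projection):

```python
def patchify(img, img_size, patch_size):
    """Split a row-major img_size x img_size grid (flat list) into
    non-overlapping patch_size x patch_size patches."""
    n = img_size // patch_size
    patches = []
    for pi in range(n):
        for pj in range(n):
            patch = [img[(pi * patch_size + r) * img_size + pj * patch_size + c]
                     for r in range(patch_size) for c in range(patch_size)]
            patches.append(patch)
    return patches

tokens = patchify(list(range(32 * 32)), 32, 4)  # CIFAR-scale defaults: 64 tokens
```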
class world_models.models.diffusion.DiT.PatchUnEmbed(img_size, patch_size, embed_dim, out_channels)[source]

Bases: Module

Reconstruct image-like tensors from patch-token sequences.

The inverse of PatchEmbed, this module reshapes token sequences into grids and uses transposed convolution to decode spatial outputs.

forward(x)[source]
class world_models.models.diffusion.DiT.TransformerBlock(d_model, n_heads, mlp_ratio, drop, t_dim)[source]

Bases: Module

Conditioned transformer block used inside the DiT backbone.

Each block applies adaptive layer-normalized self-attention and MLP residual updates conditioned on timestep embeddings.

forward(x, t_emb)[source]
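Adaptive layer-norm conditioning typically predicts a per-channel shift and scale from the timestep embedding and applies them after normalization. A sketch of that modulation step (whether this block follows the exact adaLN-Zero recipe of the DiT paper is an assumption):

```python
def modulate(x, shift, scale):
    """adaLN-style modulation: elementwise x * (1 + scale) + shift."""
    return [xi * (1.0 + s) + sh for xi, s, sh in zip(x, scale, shift)]

out = modulate([1.0, 2.0], shift=[0.5, -0.5], scale=[0.0, 1.0])
```

With zero shift and scale the modulation is the identity, which is why adaLN-Zero initializes the conditioning MLP to zeros.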
class world_models.models.diffusion.DiT.DiT(img_size, patch_size, in_channels, d_model, depth, heads, drop=0.0, t_dim=256)[source]

Bases: Module

Diffusion Transformer model for image denoising and generation.

The module maps noisy images and timesteps to predicted noise residuals and also provides a classmethod training entrypoint for common datasets.

forward(x, t)[source]
classmethod train(epochs, dataset, batch_size=128, lr=0.0002, img_size=32, channels=3, patch=4, width=384, depth=6, heads=6, drop=0.1, timesteps=1000, beta_start=0.0001, beta_end=0.02, ema=True, ema_decay=0.999, workdir='./dit_demo', root_path='./data', image_folder=None, crop_size=224, download=True, copy_data=False, subset_file=None, val_split=None)[source]

Datasets and Transforms

world_models.datasets.cifar10.make_cifar10(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, drop_last=True, train=True, download=False)[source]

Create CIFAR-10 dataset objects and a distributed-capable dataloader.

Returns the dataset, sampler, and loader configured with the provided transform/collator so callers can plug the loader directly into JEPA or diffusion training loops.

world_models.datasets.imagenet1k.make_imagenet1k(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, training=True, copy_data=False, drop_last=True, subset_file=None)[source]

Build an ImageNet-1K dataset and dataloader with distributed sampling support.

This helper optionally restricts data to a subset file and returns the (dataset, dataloader, sampler) tuple used by training scripts.

class world_models.datasets.imagenet1k.ImageNet(root, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', transform=None, train=True, job_id=None, local_rank=None, copy_data=True, index_targets=False)[source]

Bases: ImageFolder

ImageNet dataset wrapper with optional local copy/extract workflow.

The class extends torchvision.datasets.ImageFolder and can stage data from shared storage into local scratch space for faster multi-process training on cluster environments.

class world_models.datasets.imagenet1k.ImageNetSubset(dataset, subset_file)[source]

Bases: object

View over an ImageNet dataset filtered by an explicit image-id list.

The subset file contains target image names; only matching samples are kept while preserving transforms and label mapping from the base dataset.

filter_dataset_(subset_file)[source]

Filter self.dataset down to the samples named in subset_file.

property classes
world_models.datasets.imagenet1k.copy_imgnt_locally(root, suffix, image_folder='imagenet_full_size/061417/', tar_file='imagenet_full_size-061417.tar.gz', job_id=None, local_rank=None)[source]

Copy and extract ImageNet archives to per-job local scratch storage.

In SLURM environments this reduces network filesystem pressure by unpacking once per job and synchronizing worker processes with a signal file.

world_models.datasets.imagenet1k.make_imagefolder(transform, batch_size, collator=None, pin_mem=True, num_workers=8, world_size=1, rank=0, root_path=None, image_folder=None, drop_last=True, val_split=None)[source]

Create an ImageFolder dataset loader for custom folder-structured datasets.

Supports optional train/validation split and distributed sampling, making it a drop-in replacement for ImageNet loaders in training scripts.

Parameters:

val_split (float | None)

world_models.transforms.transforms.make_transforms(crop_size=224, crop_scale=(0.3, 1.0), color_jitter=1.0, horizontal_flip=False, color_distortion=False, gaussian_blur=False, normalization=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))[source]

Compose image augmentations and normalization for vision model training.

Supports random crops, optional flip/color distortion/blur, and returns a torchvision.transforms.Compose pipeline.

class world_models.transforms.transforms.GaussianBlur(p=0.5, radius_min=0.1, radius_max=2.0)[source]

Bases: object

Probabilistic Gaussian blur augmentation for PIL images.

Applies blur with random radius in a configurable range when sampled.

Memory and Controllers

class world_models.memory.dreamer_memory.ReplayBuffer(size, obs_shape, action_size, seq_len, batch_size)[source]

Bases: object

Fixed-size replay buffer for Dreamer image observations and transitions.

Stores (observation image, action, reward, terminal) tuples and supports sampling contiguous sequences used by world-model unroll training.

add(obs, ac, rew, done)[source]
sample()[source]
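Contiguous-sequence sampling from a ring buffer can be sketched as follows (a toy: the real buffer stores per-field arrays, masks episode boundaries via nonterminals, and samples full batches):

```python
import random

class SeqReplayBuffer:
    """Minimal ring buffer that samples contiguous transition sequences."""
    def __init__(self, size, seq_len):
        self.size, self.seq_len = size, seq_len
        self.data, self.idx, self.full = [None] * size, 0, False

    def add(self, transition):
        self.data[self.idx] = transition
        self.idx = (self.idx + 1) % self.size
        self.full = self.full or self.idx == 0

    def sample(self, rng=random):
        n = self.size if self.full else self.idx  # assumes n >= seq_len
        start = rng.randint(0, n - self.seq_len)
        return self.data[start:start + self.seq_len]

buf = SeqReplayBuffer(size=100, seq_len=8)
for t in range(50):
    buf.add((t, 0.0, False))  # (obs, reward, done) stand-in
random.seed(0)
seq = buf.sample()
```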
class world_models.memory.planet_memory.Episode(postprocess_fn=None)[source]

Bases: object

Records the agent's interaction with the environment for a single episode. At termination, it converts all stored data to NumPy arrays.

property size
append(obs, act, reward, terminal)[source]
terminate(obs)[source]
class world_models.memory.planet_memory.Memory(size=None)[source]

Bases: deque

Episode-based replay memory for PlaNet/RSSM training.

Episodes are stored as variable-length trajectories and sampled as sub-sequences with optional time-major formatting for sequence models.

property size
append(episodes)[source]
Parameters:

episodes (list[Episode])

sample(batch_size, tracelen=1, time_first=False)[source]

RSSM-based policy for model-predictive control.

This module provides the RSSMPolicy class, which implements model-predictive control on top of the RSSM (Recurrent State-Space Model) latent dynamics model. The policy plans actions in latent space using the Cross-Entropy Method (CEM).

Reference:

Ha & Schmidhuber (2018). Recurrent World Models Facilitate Policy Evolution. https://arxiv.org/abs/1809.01999

class world_models.controller.rssm_policy.RSSMPolicy(model, planning_horizon, num_candidates, num_iterations, top_candidates, device)[source]

Bases: object

Model-predictive controller that plans actions with the RSSM latent model.

The policy uses a Cross-Entropy Method style loop: it samples candidate action sequences, rolls them forward in latent space, scores predicted returns, and refits a Gaussian proposal to top-performing candidates.

Parameters:
  • planning_horizon (int)

  • num_candidates (int)

  • num_iterations (int)

  • top_candidates (int)

  • device (str)

rssm

The RSSM world model.

N

Number of candidate action sequences to sample.

K

Number of top candidates to use for updating the proposal.

T

Number of CEM iterations per planning step.

H

Planning horizon (number of future steps to consider).

d

Action dimensionality.

device

Device to run computations on.

state_size

Hidden state dimensionality.

latent_size

Latent state dimensionality.

Example

>>> policy = RSSMPolicy(
...     model=rssm,
...     planning_horizon=12,
...     num_candidates=1000,
...     num_iterations=5,
...     top_candidates=100,
...     device='cuda'
... )
>>> policy.reset()
>>> action = policy.poll(observation)
reset()[source]

Reset the policy state.

Initializes the hidden state, latent state, and action to zeros. Should be called at the beginning of each episode.

poll(observation, explore=False)[source]

Get action for given observation.

Parameters:
  • observation (Tensor) – Current observation tensor of shape (channels, height, width).

  • explore (bool) – If True, add exploration noise to the selected action.

Returns:

Action tensor of shape (1, action_size).

Return type:

Tensor
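The CEM loop described above can be sketched end to end on a toy scoring function standing in for imagined latent returns (names and defaults here are illustrative, not the library's):

```python
import random
import statistics

def cem_plan(score_fn, horizon, num_candidates=100, num_iterations=5,
             top_candidates=10, rng=random):
    """Cross-Entropy Method over scalar action sequences of length `horizon`:
    sample candidates, keep the top K, refit the Gaussian proposal, repeat."""
    mu = [0.0] * horizon
    sigma = [1.0] * horizon
    for _ in range(num_iterations):
        cands = [[rng.gauss(m, s) for m, s in zip(mu, sigma)]
                 for _ in range(num_candidates)]
        elites = sorted(cands, key=score_fn, reverse=True)[:top_candidates]
        mu = [statistics.mean(col) for col in zip(*elites)]
        sigma = [statistics.stdev(col) for col in zip(*elites)]
    return mu[0]  # MPC: execute only the first action, then replan

random.seed(0)
# Toy "return" model: the best sequences keep every action near 0.5.
first_action = cem_plan(lambda a: -sum((x - 0.5) ** 2 for x in a), horizon=5)
```

In the real policy, score_fn is replaced by rolling candidates forward through the RSSM and summing predicted rewards.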

Utilities

world_models.utils.dreamer_utils.get_parameters(modules)[source]

Given an iterable of torch modules, return a flat list of their parameters.

Parameters:

modules (Iterable[Module])

class world_models.utils.dreamer_utils.FreezeParameters(modules)[source]

Bases: object

Context manager that temporarily disables gradients for given modules.

Useful during imagination or target-network forward passes where gradients through certain components should be blocked for speed and correctness.

Parameters:

modules (Iterable[Module])

class world_models.utils.dreamer_utils.Logger(log_dir, n_logged_samples=10, summary_writer=None)[source]

Bases: object

Experiment logger for scalars and GIF rollouts using TensorBoardX.

Provides helpers to write scalar metrics, dump pickle snapshots, and save video previews during Dreamer training/evaluation.

log_scalar(scalar, name, step_)[source]
log_scalars(scalar_dict, step)[source]
log_videos(videos, step, max_videos_to_save=1, fps=20, video_title='video')[source]
dump_scalars_to_pickle(metrics, step, log_title=None)[source]
flush()[source]
world_models.utils.dreamer_utils.compute_return(rewards, values, discounts, td_lam, last_value)[source]

Compute TD(lambda) returns from imagined rewards, values, and discounts.

Implements backward recursion used by Dreamer actor/value objectives.
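The backward recursion is R_t = r_t + gamma_t * ((1 - lambda) * V_{t+1} + lambda * R_{t+1}), bootstrapped with last_value. A scalar sketch (the library version operates on tensors):

```python
def compute_return(rewards, values, discounts, td_lam, last_value):
    """TD(lambda) returns via backward recursion over a trajectory."""
    returns = [0.0] * len(rewards)
    next_ret, next_val = last_value, last_value
    for t in reversed(range(len(rewards))):
        returns[t] = rewards[t] + discounts[t] * (
            (1.0 - td_lam) * next_val + td_lam * next_ret)
        next_ret, next_val = returns[t], values[t]
    return returns

# lambda = 1 reduces to plain discounted Monte Carlo returns.
rets = compute_return([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.9, 0.9, 0.9],
                      td_lam=1.0, last_value=0.0)
```

At lambda = 0 the recursion collapses to one-step TD targets r_t + gamma_t * V_{t+1}.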

world_models.utils.jepa_utils.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2, b=2.0)[source]

Initialize a tensor in-place from a truncated normal distribution.

Values are drawn from N(mean, std) truncated to [a, b]; samples never fall outside the interval (truncation, not clipping).
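The target distribution can be sketched with rejection sampling (the library version fills a tensor in place, typically via an inverse-CDF method rather than rejection):

```python
import random

def trunc_normal(mean=0.0, std=1.0, a=-2.0, b=2.0, rng=random):
    """Draw one sample from N(mean, std) truncated to [a, b] by rejection."""
    while True:
        x = rng.gauss(mean, std)
        if a <= x <= b:
            return x

random.seed(0)
samples = [trunc_normal() for _ in range(200)]
```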

world_models.utils.jepa_utils.repeat_interleave_batch(x, B, repeat)[source]

Repeat each batch chunk multiple times while preserving chunk ordering.

Used in JEPA masking code to align context and target token batches.
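A list-based sketch of the chunk-repeat semantics (the real function operates on batched tensors along dimension 0):

```python
def repeat_interleave_batch(x, B, repeat):
    """Repeat each consecutive chunk of B items `repeat` times,
    preserving chunk order."""
    out = []
    for i in range(len(x) // B):
        out.extend(x[i * B:(i + 1) * B] * repeat)
    return out

y = repeat_interleave_batch([1, 2, 3, 4], B=2, repeat=2)
```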

class world_models.utils.jepa_utils.WarmupCosineSchedule(optimizer, warmup_steps, start_lr, ref_lr, T_max, last_epoch=-1, final_lr=0.0)[source]

Bases: object

Learning-rate schedule with linear warmup followed by cosine decay.

Updates optimizer parameter-group LRs on each call to step().

step()[source]
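The shape of the schedule, as a stand-alone function (how T_max maps onto the cosine phase is an assumption; the class itself mutates optimizer parameter-group LRs on each step()):

```python
import math

def warmup_cosine_lr(step, warmup_steps, start_lr, ref_lr, T_max, final_lr=0.0):
    """Linear warmup from start_lr to ref_lr, then cosine decay to final_lr."""
    if step < warmup_steps:
        return start_lr + (ref_lr - start_lr) * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, T_max - warmup_steps))
    return final_lr + 0.5 * (ref_lr - final_lr) * (1.0 + math.cos(math.pi * progress))

lrs = [warmup_cosine_lr(s, 100, 1e-5, 1e-3, 1000) for s in range(1001)]
```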
class world_models.utils.jepa_utils.CosineWDSchedule(optimizer, ref_wd, T_max, final_wd=0.0)[source]

Bases: object

Cosine scheduler for optimizer weight decay values.

Skips parameter groups flagged with WD_exclude to keep bias/norm decay at zero.

step()[source]
world_models.utils.jepa_utils.gpu_timer(closure, log_timings=True)[source]

Measure CUDA execution time for a closure and return (result, elapsed_ms).

Falls back to -1 elapsed time when CUDA timing is unavailable.

class world_models.utils.jepa_utils.CSVLogger(fname, *argv)[source]

Bases: object

Lightweight CSV logger with per-column printf-style formatting.

log(*argv)[source]
class world_models.utils.jepa_utils.AverageMeter[source]

Bases: object

Track running statistics (val, avg, min, max, sum, count) for metrics.

reset()[source]
update(val, n=1)[source]
world_models.utils.jepa_utils.grad_logger(named_params)[source]

Aggregate gradient norm statistics over model parameters for logging.

Also exposes first/last qkv-layer gradient norms when available.

world_models.utils.jepa_utils.init_distributed(port=40112, rank_and_world_size=(None, None))[source]

Initialize torch distributed process groups when environment supports it.

Returns (world_size, rank) and gracefully falls back to single-process mode.

class world_models.utils.jepa_utils.AllGather(*args, **kwargs)[source]

Bases: Function

Autograd-aware all-gather operation across distributed workers.

Forward concatenates worker tensors; backward reduces and slices gradients.

static forward(ctx, x)[source]
static backward(ctx, grads)[source]
class world_models.utils.jepa_utils.AllReduceSum(*args, **kwargs)[source]

Bases: Function

Autograd function that sums tensors across distributed workers in the forward pass.

static forward(ctx, x)[source]
static backward(ctx, grads)[source]
class world_models.utils.jepa_utils.AllReduce(*args, **kwargs)[source]

Bases: Function

Autograd function that all-reduces and averages tensors across workers.

Used to synchronize scalar losses for consistent distributed logging/training.

static forward(ctx, x)[source]
static backward(ctx, grads)[source]