Getting Started#
Installation#
Install from PyPI:
pip install torchwm
Install from source:
git clone https://github.com/ParamThakkar123/torchwm.git
cd torchwm
pip install -e .
For development and tests:
pip install -e ".[dev]"
Logging with Weights & Biases and TensorBoard#
TorchWM supports logging experiment results to Weights & Biases (WandB) and TensorBoard.
Weights & Biases#
To use WandB logging, you must provide an API key, because anonymous logins are no longer supported.
Get your WandB API key from wandb.ai.
Set the key in your config:
cfg.enable_wandb = True
cfg.wandb_api_key = "your-api-key-here"
cfg.wandb_project = "torchwm"
cfg.wandb_entity = "your-entity"
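Hard-coding the key in a config file makes it easy to commit by accident. The following is a minimal sketch, assuming the key is exported as WANDB_API_KEY (the environment variable the wandb client itself reads). The helper name configure_wandb is hypothetical, and a SimpleNamespace stands in for the real TorchWM config object:

```python
import os
from types import SimpleNamespace


def configure_wandb(cfg, project="torchwm", entity=None, env=None):
    """Fill the WandB fields on cfg from the environment (hypothetical helper)."""
    env = os.environ if env is None else env
    key = env.get("WANDB_API_KEY")
    if not key:
        # No key available: leave WandB disabled instead of failing at login.
        cfg.enable_wandb = False
        return cfg
    cfg.enable_wandb = True
    cfg.wandb_api_key = key
    cfg.wandb_project = project
    if entity is not None:
        cfg.wandb_entity = entity
    return cfg


# SimpleNamespace is a stand-in for the real config object (assumption).
cfg = configure_wandb(SimpleNamespace(), entity="your-entity")
```

With this pattern the same script runs with WandB disabled on machines where no key is set.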
TensorBoard#
Enable TensorBoard logging:
cfg.enable_tensorboard = True
cfg.log_dir = "runs"
Logs will be saved to the specified directory and can be viewed with tensorboard --logdir runs.
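If every experiment writes to the same directory, their curves overlap in TensorBoard. One common convention, sketched here under the assumption that cfg.log_dir accepts any path string, is to give each run its own timestamped subdirectory; make_run_dir is a hypothetical helper, not part of TorchWM:

```python
from datetime import datetime
from pathlib import Path


def make_run_dir(base="runs", name="dreamer", now=None):
    """Return a per-run log path such as runs/dreamer-20240102-030405."""
    stamp = (now or datetime.now()).strftime("%Y%m%d-%H%M%S")
    return Path(base) / f"{name}-{stamp}"


log_dir = make_run_dir()
# Assign the result to the config field shown above, for example:
# cfg.log_dir = str(log_dir)
```

Pointing tensorboard --logdir runs at the parent directory then lists each run separately.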
Quick Start: Friendly API#
The recommended entry point for common workflows is the top-level torchwm
package. It mirrors the TorchWM implementation package but adds short factory
helpers for discovery, model creation, environment creation, and operators.
import torchwm
print(torchwm.list_models())
print(torchwm.list_env_backends())
agent = torchwm.create_model("dreamer", env="walker-walk", total_steps=1_000_000)
env = torchwm.make_env("CartPole-v1", backend="gym")
op = torchwm.get_operator("dreamer", image_size=64, action_dim=6)
You can still import the underlying research components directly from torchwm
when you need lower-level control:
from torchwm import DreamerAgent, DreamerConfig
cfg = DreamerConfig()
cfg.env = "walker-walk"
agent = DreamerAgent(cfg)
Quick Start: Dreamer#
TorchWM implements multiple world model algorithms, each with its own detailed documentation:
| Algorithm | Description | Quick Start |
|---|---|---|
| Dreamer | Model-based RL with latent dynamics | |
| JEPA | Self-supervised visual representations | |
| IRIS | Sample-efficient RL with Transformers | |
| DiT | Diffusion models with Transformers | |
Quick Start: Inference with Operators#
TorchWM now includes standardized operators for preprocessing inputs during inference, making it easy to deploy models consistently.
What are Operators?#
Operators handle input preprocessing: normalizing images, encoding actions, tokenizing text, and generating masks. Each model has a dedicated operator that ensures inputs are in the correct format.
Basic Usage#
import torchwm
# Create operator with config parameters
op = torchwm.get_operator(
    "dreamer",
    image_size=64,  # Image size for Dreamer
    action_dim=6,   # Action dimension
)
# Process raw inputs
raw_inputs = {
    'image': your_pil_image_or_tensor,  # PIL Image or torch.Tensor
    'action': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],  # Action as list
}
# Get processed tensors
processed = op.process(raw_inputs)
# Returns: {'obs': tensor(B, 3, 64, 64), 'action': tensor(B, 6)}
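Because the action is passed as a plain list, a length mismatch with action_dim is an easy mistake to make. The sketch below checks the list before it reaches the operator. check_action is a hypothetical helper, and whether the operator performs a similar check itself is not stated here:

```python
import math


def check_action(action, action_dim):
    """Validate a raw action list against the expected dimension."""
    values = [float(a) for a in action]
    if len(values) != action_dim:
        raise ValueError(
            f"expected {action_dim} action values, got {len(values)}"
        )
    if not all(math.isfinite(v) for v in values):
        raise ValueError("action contains NaN or infinite values")
    return values


action = check_action([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], action_dim=6)
```

Running the check on the raw list gives a clearer error than a shape failure deeper in the model.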
Available Operators#
| Operator | Model | Purpose | Key Parameters |
|---|---|---|---|
| | Dreamer | Image/action preprocessing | |
| | JEPA | Image masking and patching | |
| | IRIS | Sequence tokenization | |
| | PlaNet | State/action transitions | |
JEPA Example (Self-Supervised)#
from torchwm import JEPAOperator
op = JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)
inputs = {'images': [image1, image2]}
result = op(inputs)
# result['images']: stacked normalized images
# result['mask']: random mask for self-supervised learning
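To make the parameters above concrete: a 224-pixel image split into 16-pixel patches gives 14 x 14 = 196 patches, and a mask ratio of 0.75 hides 147 of them. The sketch below illustrates that arithmetic with uniform random sampling. This is an assumption for illustration only; the masking strategy JEPAOperator actually uses is not described here, and random_patch_mask is a hypothetical name:

```python
import random


def random_patch_mask(image_size=224, patch_size=16, mask_ratio=0.75, seed=None):
    """Return one boolean per patch; True means the patch is masked."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    num_patches = per_side * per_side
    num_masked = int(num_patches * mask_ratio)
    rng = random.Random(seed)
    masked = set(rng.sample(range(num_patches), num_masked))
    return [i in masked for i in range(num_patches)]


mask = random_patch_mask(seed=0)
print(len(mask), sum(mask))  # 196 147
```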
IRIS Example (Sequence Processing)#
from torchwm import IrisOperator
op = IrisOperator(seq_length=512, vocab_size=32000)
inputs = {'tokens': [101, 2054, 2003, 102]} # Token sequence
result = op(inputs)
# result['input_ids']: padded token tensor
# result['attention_mask']: attention mask
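The relationship between the padded token tensor and the attention mask can be sketched in plain Python. This assumes right-padding with a pad id of 0 and truncation at seq_length; both are illustrative assumptions, and pad_tokens is a hypothetical helper rather than the IrisOperator implementation:

```python
def pad_tokens(tokens, seq_length, pad_id=0):
    """Right-pad (or truncate) tokens and build the matching attention mask."""
    tokens = list(tokens)[:seq_length]
    n_pad = seq_length - len(tokens)
    return {
        "input_ids": tokens + [pad_id] * n_pad,
        # 1 marks a real token, 0 marks padding.
        "attention_mask": [1] * len(tokens) + [0] * n_pad,
    }


out = pad_tokens([101, 2054, 2003, 102], seq_length=8)
# out["input_ids"]      -> [101, 2054, 2003, 102, 0, 0, 0, 0]
# out["attention_mask"] -> [1, 1, 1, 1, 0, 0, 0, 0]
```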
Integration with Configs#
Operators can reuse matching config fields, but operator-only parameters such as action dimensions should be supplied from the target environment:
import torchwm
cfg = torchwm.create_config("dreamer")
op = torchwm.get_operator(
    "dreamer",
    image_size=cfg.image_size[0],
    action_dim=6,
)
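Rather than hard-coding action_dim, it can be read from the environment's action space. The sketch below relies only on the attributes that Gym/Gymnasium spaces expose (n for Discrete, shape for Box) and uses SimpleNamespace stand-ins so that it runs without an environment installed. action_dim_from_space is a hypothetical helper, and treating a discrete space as a one-hot vector of width n is an assumption:

```python
import math
from types import SimpleNamespace


def action_dim_from_space(space):
    """Infer a flat action dimension from a Gym-style action space."""
    n = getattr(space, "n", None)
    if n is not None:
        # Discrete space: assume a one-hot encoding of width n.
        return int(n)
    shape = getattr(space, "shape", None)
    if shape:
        # Box space: flatten all dimensions into one.
        return math.prod(int(s) for s in shape)
    raise TypeError(f"cannot infer action dimension from {space!r}")


# Stand-ins for env.action_space (assumption: real spaces expose the same fields).
box_like = SimpleNamespace(shape=(6,))
discrete_like = SimpleNamespace(n=4, shape=())
print(action_dim_from_space(box_like), action_dim_from_space(discrete_like))  # 6 4
```

With a real environment, the argument would be env.action_space.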
Utilities#
For most applications, use torchwm.get_operator() for preprocessing. Advanced utility functions remain available to package internals.
import torchwm
op = torchwm.get_operator("jepa", image_size=224, patch_size=16, mask_ratio=0.75)
processed = op.process({"images": [pil_image]})
Train a complete world model pipeline (VAE + MDNRNN + Controller) on any Gym environment:
# Train on CarRacing
python -m world_models.training.train_world_model --env CarRacing-v2
# Train on Pendulum
python -m world_models.training.train_world_model --env Pendulum-v1
# Test trained model
python -m world_models.training.train_world_model --env CarRacing-v2 --test
# Specify action size manually for environments with missing dependencies
python -m world_models.training.train_world_model --env BipedalWalker-v3 --action_size 4
Dreamer supports multiple backends through DreamerConfig.env_backend; the
top-level torchwm.make_env() helper uses the same backend names for standalone
environment creation:
dmc: DeepMind Control Suite tasks (for example walker-walk)
dmlab: DeepMind Lab 3D navigation tasks (for example rooms_collect_good_objects_train)
gym: Gym/Gymnasium environment IDs or an existing environment instance
mujoco: Gymnasium MuJoCo task ids or native MJCF/MJB models
robotics: any id registered by the installed Gymnasium Robotics package
procgen: Procgen benchmark games such as coinrun and heist
brax: JAX/Brax continuous-control environments
unity_mlagents: Unity ML-Agents executable environments
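A mistyped backend name is cheaper to catch before any environment is constructed. The sketch below validates a name against the backends listed above. The tuple is copied from this page, check_backend is a hypothetical helper, and at runtime torchwm.list_env_backends() remains the authoritative source:

```python
import difflib

# Backend names copied from the list above (may lag behind the installed version).
KNOWN_BACKENDS = (
    "dmc", "dmlab", "gym", "mujoco",
    "robotics", "procgen", "brax", "unity_mlagents",
)


def check_backend(name):
    """Normalize a backend name, or raise with a closest-match suggestion."""
    key = name.strip().lower()
    if key in KNOWN_BACKENDS:
        return key
    close = difflib.get_close_matches(key, KNOWN_BACKENDS, n=1)
    hint = f" (did you mean {close[0]!r}?)" if close else ""
    raise ValueError(f"unknown env backend {name!r}{hint}")


backend = check_backend(" Gym ")  # normalized to "gym"
```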
Typical Training Flow#
Choose an algorithm (Dreamer, JEPA, IRIS, or DiT)
Create a config object for that algorithm
Override dataset/environment and optimization fields
Instantiate the corresponding agent
Call train() and monitor logs/checkpoints
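The override step is where typos tend to slip in, because setting a misspelled attribute on a config object silently creates a new field. The sketch below applies overrides strictly. ExampleConfig is a stand-in for an algorithm config such as DreamerConfig, its fields other than env are hypothetical, and apply_overrides is not part of TorchWM:

```python
from dataclasses import dataclass


@dataclass
class ExampleConfig:
    """Stand-in config; field names other than env are illustrative only."""
    env: str = "walker-walk"
    total_steps: int = 1_000_000
    learning_rate: float = 3e-4


def apply_overrides(cfg, **overrides):
    """Set existing fields only, so a misspelled name fails loudly."""
    for name, value in overrides.items():
        if not hasattr(cfg, name):
            raise AttributeError(f"unknown config field: {name}")
        setattr(cfg, name, value)
    return cfg


cfg = apply_overrides(ExampleConfig(), env="cheetah-run", total_steps=500_000)
# With the real library, the next steps would be to build the agent from cfg
# and call train(), as in the list above.
```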
For complete API details, see API Reference.