# Using Operators for Inference

This guide explains how to use TorchWM's inference operators for standardized input preprocessing.
## Overview

Operators provide a consistent interface for preprocessing inputs before feeding them to world models. Each model has a dedicated operator that handles:

- Image normalization and resizing
- Action encoding and formatting
- Sequence tokenization and padding
- Mask generation for self-supervised tasks
## Base Operator Class

All operators inherit from `OperatorABC`:

```python
from world_models.inference.operators.base import OperatorABC

class MyOperator(OperatorABC):
    def process(self, inputs):
        # Your preprocessing logic goes here; return a dict of
        # standardized tensors ready for the model.
        processed_tensors = ...
        return processed_tensors
```
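To illustrate the pattern in isolation, here is a toy sketch. The stand-in base class and the `UppercaseOperator` below are illustrative only and not part of the TorchWM API; the real `OperatorABC` lives in `world_models.inference.operators.base`.

```python
from abc import ABC, abstractmethod

# Illustrative stand-in mirroring the shape of the real base class:
# a single abstract process() hook that subclasses implement.
class OperatorABC(ABC):
    @abstractmethod
    def process(self, inputs):
        ...

class UppercaseOperator(OperatorABC):
    """Toy operator: 'preprocesses' string inputs by upper-casing them."""
    def process(self, inputs):
        return {key: value.upper() for key, value in inputs.items()}

op = UppercaseOperator()
print(op.process({'text': 'hello'}))  # {'text': 'HELLO'}
```

The key design point is that every operator exposes the same `process(inputs) -> outputs` dict-in, dict-out contract, so callers can swap operators without changing their calling code.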
## Dreamer Operator

For the Dreamer model's image and action processing:

```python
from world_models.inference.operators import DreamerOperator
from PIL import Image
import torch

op = DreamerOperator(image_size=64, action_dim=6)

# Process a single image and action
image = Image.open('obs.png')
action = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
result = op.process({'image': image, 'action': action})

print(result['obs'].shape)     # torch.Size([1, 3, 64, 64])
print(result['action'].shape)  # torch.Size([1, 6])

# Tensor inputs are also accepted
obs_tensor = torch.randn(3, 64, 64)
result = op.process({'image': obs_tensor, 'action': torch.tensor(action)})
```
## JEPA Operator

For JEPA's self-supervised image processing with masking:

```python
from world_models.inference.operators import JEPAOperator
from PIL import Image
import torch

op = JEPAOperator(image_size=224, patch_size=16, mask_ratio=0.75)

# Process a batch of images
images = [Image.open(f'img_{i}.jpg') for i in range(4)]
result = op.process({'images': images})

print(result['images'].shape)  # torch.Size([4, 3, 224, 224])
print(result['mask'].shape)    # torch.Size([196]) - flattened patch mask

# Supply a custom mask instead of the random one
custom_mask = torch.ones(196)  # 14x14 patches = 196
result = op.process({'images': images, 'mask': custom_mask})
```
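To make the `mask_ratio` semantics concrete, here is a plain-Python sketch of how a random flat patch mask over 14x14 = 196 patches could be built. This is a hypothetical helper for illustration, not the actual `JEPAOperator` internals.

```python
import random

def make_patch_mask(num_patches, mask_ratio, seed=None):
    """Build a flat 0/1 mask; 1 marks a patch hidden from the encoder."""
    rng = random.Random(seed)
    num_masked = int(num_patches * mask_ratio)  # e.g. 196 * 0.75 = 147
    masked_indices = rng.sample(range(num_patches), num_masked)
    mask = [0] * num_patches
    for i in masked_indices:
        mask[i] = 1
    return mask

mask = make_patch_mask(196, 0.75, seed=0)
print(len(mask), sum(mask))  # 196 147
```

With `mask_ratio=0.75`, three quarters of the patches are hidden; the model's job is to predict representations for the masked patches from the visible ones.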
## IRIS Operator

For IRIS's sequence processing:

```python
from world_models.inference.operators import IrisOperator
import torch

op = IrisOperator(seq_length=512, vocab_size=32000)

# Process a token sequence
tokens = [101, 7592, 1010, 2088, 102]  # Example tokens
result = op.process({'tokens': tokens})

print(result['input_ids'].shape)       # torch.Size([1, 512])
print(result['attention_mask'].shape)  # torch.Size([1, 512])

# Process with precomputed embeddings
embeddings = torch.randn(1, 768)
result = op.process({'tokens': tokens, 'embeddings': embeddings})
```
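The `input_ids`/`attention_mask` shapes above follow the usual pad-to-length convention: real tokens get mask value 1, padding gets 0. A minimal sketch of that convention (a hypothetical helper, not the `IrisOperator` implementation):

```python
def pad_tokens(tokens, seq_length, pad_id=0):
    """Truncate/pad a token list to seq_length and build the matching mask."""
    tokens = tokens[:seq_length]
    attention_mask = [1] * len(tokens) + [0] * (seq_length - len(tokens))
    input_ids = tokens + [pad_id] * (seq_length - len(tokens))
    return input_ids, attention_mask

input_ids, attention_mask = pad_tokens([101, 7592, 1010, 2088, 102], 512)
print(len(input_ids), sum(attention_mask))  # 512 5
```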
## PlaNet Operator

For PlaNet's environment state processing:

```python
from world_models.inference.operators import PlaNetOperator
import torch

op = PlaNetOperator(state_dim=32, action_dim=4)

# Process transition data
inputs = {
    'obs': torch.randn(32),  # State vector
    'action': [0.1, 0.2, 0.3, 0.4],
    'reward': 1.0,
    'done': False,
}
result = op.process(inputs)

print(result['obs'].shape)     # torch.Size([1, 32])
print(result['action'].shape)  # torch.Size([1, 4])
print(result['reward'].shape)  # torch.Size([1])
print(result['done'].shape)    # torch.Size([1])
```
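Conceptually, the shapes above come from wrapping every field in a leading batch dimension: scalars become shape `[1]` and vectors shape `[1, dim]`. A rough stand-alone sketch of that shaping in plain Python (assumed behavior for illustration, not the actual implementation):

```python
def add_batch_dim(transition):
    """Give every field a leading batch dimension; scalars become length-1 lists."""
    batched = {}
    for key, value in transition.items():
        if isinstance(value, (bool, int, float)):
            batched[key] = [float(value)]             # scalar -> shape [1]
        else:
            batched[key] = [list(map(float, value))]  # vector -> shape [1, dim]
    return batched

batched = add_batch_dim({'action': [0.1, 0.2, 0.3, 0.4], 'reward': 1.0, 'done': False})
print(batched['reward'], batched['done'])  # [1.0] [0.0]
```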
## Configuration Integration

Operators work seamlessly with config classes:

```python
from world_models.configs import DreamerConfig, JEPAConfig, IRISConfig
from world_models.inference.operators import DreamerOperator, JEPAOperator, IrisOperator

# Dreamer
dreamer_cfg = DreamerConfig()
dreamer_op = DreamerOperator(
    image_size=dreamer_cfg.operator_image_size,
    action_dim=dreamer_cfg.operator_action_dim,
)

# JEPA
jepa_cfg = JEPAConfig()
jepa_op = JEPAOperator(
    image_size=jepa_cfg.operator_image_size,
    patch_size=jepa_cfg.operator_patch_size,
    mask_ratio=jepa_cfg.operator_mask_ratio,
)

# IRIS
iris_cfg = IRISConfig()
iris_op = IrisOperator(
    seq_length=iris_cfg.operator_seq_length,
    vocab_size=iris_cfg.operator_vocab_size,
)
```
## Utilities

Common preprocessing functions live in `world_models.inference.operators.utils`:

```python
from world_models.inference.operators.utils import (
    normalize_image,
    tokenize_text,
    resize_image,
)

# Image processing
normalized = normalize_image(pil_image, size=(224, 224))
resized = resize_image(tensor_image, size=(64, 64))

# Text processing
tokens = tokenize_text("Hello world", max_length=512)
```
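Image normalization conventionally means scaling raw 0-255 pixel values to [0, 1] and then standardizing with per-channel statistics. Whether `normalize_image` uses exactly this scheme is an assumption here; the arithmetic itself, per pixel, looks like:

```python
def normalize_pixel(value, mean, std):
    """Scale a 0-255 pixel to [0, 1], then standardize with a channel mean/std."""
    return (value / 255.0 - mean) / std

# ImageNet red-channel statistics (mean=0.485, std=0.229)
normalized = normalize_pixel(128, mean=0.485, std=0.229)
print(round(normalized, 4))  # 0.0741
```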
## Error Handling

Operators validate inputs and provide helpful error messages:

```python
try:
    result = op.process(invalid_inputs)
except ValueError as e:
    print(f"Preprocessing error: {e}")
```
## Best Practices

- **Use with configs**: Always instantiate operators from config parameters for consistency.
- **Batch processing**: Operators handle both single inputs and batches.
- **Device placement**: Processed tensors live on the CPU; move them to the GPU as needed.
- **Type checking**: Operators accept various input types (PIL images, tensors, lists) and standardize them.
- **Performance**: Operators use optimized transforms for fast preprocessing.
## Future: Pipelines

Operators are designed to plug into future inference pipelines, providing the preprocessing layer in a standardized Pipeline + Operator + Task architecture.