Source code for world_models.inference.operators.base

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn as nn


@dataclass(frozen=True)
class TensorSpec:
    """Optional tensor contract used to validate operator inputs or outputs.

    Every field is optional: a field left at ``None`` (``shape``/``dtype``)
    means that aspect of the tensor is not constrained by this spec.

    Args:
        shape: Expected shape, one entry per dimension. Use ``None`` as a
            wildcard for dimensions that may vary, such as batch size. A
            ``shape`` of ``None`` (the default) disables shape validation.
        dtype: Expected tensor dtype, or ``None`` to accept any dtype.
        required: Whether the key must be present in the mapping being
            validated.
    """

    shape: tuple[int | None, ...] | None = None
    dtype: torch.dtype | None = None
    required: bool = True
class OperatorABC(nn.Module, ABC):
    """Structured base class for inference operators.

    Operators use a consistent pipeline:

    1. ``preprocess`` converts raw inputs into tensors.
    2. ``forward`` performs model/operator-specific tensor computation.
    3. ``postprocess`` formats the final output mapping.

    Subclasses may also declare ``input_specs`` and ``output_specs`` to
    validate required tensor keys, shapes, and dtypes.

    ``OperatorABC`` inherits from ``torch.nn.Module``, so operators support
    ``to(device)``, ``cpu()``, ``cuda()``, ``train()``, and ``eval()`` just
    like model modules. The operator additionally remembers its target device
    in ``self.device`` so that preprocessed tensors can be moved there before
    ``forward`` runs.

    Args:
        device: Device that preprocessed tensors are moved to. Defaults to
            CPU when omitted.
    """

    # NOTE: these defaults are class-level dicts shared by every subclass that
    # does not override them; subclasses should assign their own mapping
    # rather than mutate these in place.
    input_specs: Mapping[str, TensorSpec] = {}
    output_specs: Mapping[str, TensorSpec] = {}

    def __init__(self, *, device: torch.device | str | None = None) -> None:
        super().__init__()
        self.device = (
            torch.device(device) if device is not None else torch.device("cpu")
        )

    @abstractmethod
    def preprocess(self, inputs: Any) -> dict[str, torch.Tensor]:
        """Convert raw inputs into a tensor mapping ready for ``forward``."""

    def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Run tensor computation for this operator.

        Preprocessing-only operators can rely on this identity implementation.
        Operators that wrap a model should override this method.
        """
        return inputs

    def postprocess(self, outputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Format validated forward outputs for consumers."""
        return outputs

    def process(self, inputs: Any) -> dict[str, torch.Tensor]:
        """Process raw inputs through preprocess, forward, and postprocess stages.

        The preprocessed mapping is validated against ``input_specs`` and then
        moved to ``self.device``; the ``forward`` result is validated against
        ``output_specs`` before being handed to ``postprocess``.

        Raises:
            TypeError: If a validated mapping is not a mapping, or a spec'd
                value is not a tensor or has the wrong dtype.
            ValueError: If a required key is missing or a shape mismatches.
        """
        preprocessed = self.preprocess(inputs)
        self.validate_mapping(
            preprocessed, self.input_specs, label="preprocessed input"
        )
        preprocessed = self._move_tensors(preprocessed)
        outputs = self.forward(preprocessed)
        self.validate_mapping(outputs, self.output_specs, label="operator output")
        return self.postprocess(outputs)

    def batch(self, inputs: Sequence[Any]) -> dict[str, torch.Tensor]:
        """Run each input through ``process`` and stack the results per key.

        Args:
            inputs: Non-empty sequence of raw inputs, each accepted by
                ``process``.

        Returns:
            A mapping with the same keys as the per-item outputs, where each
            value is the ``torch.stack`` of that key across all items.

        Raises:
            ValueError: If ``inputs`` is empty or the per-item outputs do not
                all share the same key set.
        """
        if not inputs:
            raise ValueError("Cannot batch an empty input sequence")
        processed = [self.process(item) for item in inputs]
        keys = processed[0].keys()
        for index, item in enumerate(processed[1:], start=1):
            if item.keys() != keys:
                raise ValueError(
                    f"Batched operator outputs must share keys; item 0 has {sorted(keys)} "
                    f"but item {index} has {sorted(item.keys())}"
                )
        return {key: torch.stack([item[key] for item in processed]) for key in keys}

    def to(self, *args: Any, **kwargs: Any) -> "OperatorABC":
        """Move module parameters/buffers and remember the target tensor device."""
        module = super().to(*args, **kwargs)
        device = self._device_from_to_args(*args, **kwargs)
        if device is not None:
            self.device = device
        return module

    def cpu(self) -> "OperatorABC":
        """Move module parameters/buffers to CPU and remember that device.

        ``nn.Module.cpu`` does not route through ``to``, so without this
        override ``self.device`` would keep pointing at the previous device.
        """
        module = super().cpu()
        self.device = torch.device("cpu")
        return module

    def cuda(self, device: int | torch.device | None = None) -> "OperatorABC":
        """Move module parameters/buffers to CUDA and remember that device.

        ``nn.Module.cuda`` does not route through ``to``, so without this
        override ``self.device`` would keep pointing at the previous device.

        Args:
            device: Optional CUDA device index or device; when omitted the
                remembered device is the index-less ``cuda`` device.
        """
        module = super().cuda(device)
        if device is None:
            self.device = torch.device("cuda")
        elif isinstance(device, int):
            self.device = torch.device("cuda", device)
        else:
            self.device = torch.device(device)
        return module

    def __call__(self, inputs: Any) -> dict[str, torch.Tensor]:
        # Calling the operator runs the whole pipeline, not just ``forward``.
        return self.process(inputs)

    @classmethod
    def validate_mapping(
        cls,
        values: Mapping[str, torch.Tensor],
        specs: Mapping[str, TensorSpec],
        *,
        label: str,
    ) -> None:
        """Validate tensor keys, shapes, and dtypes against optional specs.

        Only keys named in ``specs`` are checked; extra keys in ``values``
        are ignored.

        Args:
            values: Mapping of tensor names to tensors to validate.
            specs: Expected contract per key.
            label: Human-readable name of ``values`` used in error messages.

        Raises:
            TypeError: If ``values`` is not a mapping, a spec'd value is not a
                ``torch.Tensor``, or its dtype differs from the spec.
            ValueError: If a required key is missing or a shape mismatches.
        """
        if not isinstance(values, Mapping):
            raise TypeError(f"{label} must be a mapping of tensor names to tensors")
        for key, spec in specs.items():
            if key not in values:
                if spec.required:
                    raise ValueError(f"Missing required {label} key: {key!r}")
                # Optional key that is absent: nothing to validate.
                continue
            value = values[key]
            if not isinstance(value, torch.Tensor):
                raise TypeError(f"{label} key {key!r} must be a torch.Tensor")
            if spec.dtype is not None and value.dtype != spec.dtype:
                raise TypeError(
                    f"{label} key {key!r} must have dtype {spec.dtype}, got {value.dtype}"
                )
            if spec.shape is not None:
                cls._validate_shape(key, value, spec.shape, label=label)

    @staticmethod
    def _validate_shape(
        key: str,
        value: torch.Tensor,
        expected: tuple[int | None, ...],
        *,
        label: str,
    ) -> None:
        """Check rank and every non-wildcard dimension of ``value``.

        Raises:
            ValueError: If the number of dims differs from ``expected`` or a
                dimension with a non-``None`` expectation does not match.
        """
        if value.dim() != len(expected):
            raise ValueError(
                f"{label} key {key!r} must have {len(expected)} dims, got {value.dim()}"
            )
        for dim_index, (actual, expected_dim) in enumerate(zip(value.shape, expected)):
            # ``None`` in the spec is a wildcard and matches any size.
            if expected_dim is not None and actual != expected_dim:
                raise ValueError(
                    f"{label} key {key!r} dim {dim_index} must be {expected_dim}, got {actual}"
                )

    def _move_tensors(self, values: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Return a new dict with every value moved to ``self.device``."""
        return {key: value.to(self.device) for key, value in values.items()}

    @staticmethod
    def _device_from_to_args(*args: Any, **kwargs: Any) -> torch.device | None:
        """Extract the target device from ``nn.Module.to``-style arguments.

        Returns:
            The device named by the ``device`` keyword, by the first
            positional device/str/int that ``torch.device`` accepts, or by a
            positional tensor; ``None`` when no device is specified (for
            example a dtype-only call).
        """
        if "device" in kwargs and kwargs["device"] is not None:
            return torch.device(kwargs["device"])
        for arg in args:
            # ``bool`` is an ``int`` subclass; a positional ``non_blocking``
            # flag must never be interpreted as a device index.
            if isinstance(arg, bool):
                continue
            if isinstance(arg, (torch.device, str, int)):
                try:
                    return torch.device(arg)
                except (TypeError, RuntimeError):
                    continue
            if isinstance(arg, torch.Tensor):
                return arg.device
        return None