Source code for world_models.inference.operators

from .base import OperatorABC
from .dreamer_operator import DreamerOperator
from .jepa_operator import JEPAOperator
from .iris_operator import IrisOperator
from .planet_operator import PlaNetOperator


def get_operator(name: str, **kwargs):
    """Build an inference operator selected by a case-insensitive name.

    Args:
        name: Operator identifier: one of ``'dreamer'``, ``'jepa'``,
            ``'iris'`` or ``'planet'``. Letter case is ignored
            (``'Dreamer'`` and ``'DREAMER'`` both work).
        **kwargs: Operator-specific configuration, forwarded unchanged to
            the selected operator class's constructor.

    Returns:
        A new instance of the operator class registered under ``name``
        (presumably an ``OperatorABC`` subclass -- confirm in each operator
        module).

    Raises:
        ValueError: If ``name`` does not match any registered operator. The
            message lists the accepted names.

    Example::

        op = get_operator('dreamer', image_size=64, action_dim=6)
        processed = op.process({'image': image, 'action': action})
    """
    # Lower-case key -> operator class. The factory lower-cases ``name``
    # before lookup, so keys here must stay lower-case.
    registry = {
        "dreamer": DreamerOperator,
        "jepa": JEPAOperator,
        "iris": IrisOperator,
        "planet": PlaNetOperator,
    }

    operator_cls = registry.get(name.lower())
    if operator_cls is None:
        # Report the caller's original spelling, not the normalised key.
        raise ValueError(
            f"Unknown operator {name!r}. Available: {list(registry.keys())}"
        )
    return operator_cls(**kwargs)
# Names exported by ``from world_models.inference.operators import *``:
# the base class from ``.base``, the operator classes imported from this
# package's submodules, and the ``get_operator`` factory.
__all__ = [
    "OperatorABC",
    "DreamerOperator",
    "JEPAOperator",
    "IrisOperator",
    "PlaNetOperator",
    "get_operator",
]