Source code for world_models.utils.jit_utils

import warnings
from typing import Callable

import torch


def jit_compile_function(func: Callable) -> Callable:
    """JIT compile a function for performance.

    Attempts ``torch.jit.script(func)``. If scripting raises, the original
    ``func`` is returned unchanged so callers can keep running in eager mode.

    Args:
        func: The callable to script.

    Returns:
        The scripted callable on success, otherwise ``func`` itself.

    Warns:
        RuntimeWarning: When scripting fails. The warning is attributed to the
            caller's line and carries the underlying exception text.
    """
    try:
        return torch.jit.script(func)
    except Exception as exc:  # scripting can fail in many ways; fallback is deliberate
        # A warning (not print) lets callers filter, log, or escalate the
        # fallback instead of having it silently written to stdout.
        warnings.warn(f"JIT compilation failed: {exc}", RuntimeWarning, stacklevel=2)
        return func
def jit_compile_module(module: torch.nn.Module) -> torch.nn.Module:
    """JIT compile a PyTorch module.

    Attempts ``torch.jit.script(module)``. If scripting raises, the original
    ``module`` is returned unchanged so callers can keep running in eager mode.

    Args:
        module: The ``torch.nn.Module`` instance to script.

    Returns:
        The scripted module on success, otherwise ``module`` itself.

    Warns:
        RuntimeWarning: When scripting fails. The warning is attributed to the
            caller's line and carries the underlying exception text.
    """
    try:
        return torch.jit.script(module)
    except Exception as exc:  # scripting can fail in many ways; fallback is deliberate
        # A warning (not print) lets callers filter, log, or escalate the
        # fallback instead of having it silently written to stdout.
        warnings.warn(f"JIT compilation failed: {exc}", RuntimeWarning, stacklevel=2)
        return module