Inference Guide#

This guide covers how to use trained TorchWM models for inference and deployment.

Overview#

TorchWM provides standardized inference through operators; pipelines are planned for a future release. In application code, prefer the top-level torchwm.get_operator() factory: it keeps examples short and avoids deep imports.

Loading Trained Models#

from torchwm import DreamerAgent

# Load the agent from a saved checkpoint path
agent = DreamerAgent.from_pretrained("path/to/checkpoint")
# Switch to evaluation mode before running inference
# (presumably torch.nn.Module.eval(), i.e. dropout off and norm statistics frozen -- confirm)
agent.eval()

Using Operators for Preprocessing#

See "Using Operators for Inference" for detailed operator usage.

Basic Inference#

Dreamer#

import torch
import torchwm
from torchwm import DreamerAgent

# Build the Dreamer preprocessing operator and load the trained agent
op = torchwm.get_operator("dreamer", image_size=64, action_dim=6)
agent = DreamerAgent.from_pretrained("dreamer_checkpoint")

# Single step prediction
# Random stand-ins for real data: one 3x64x64 tensor (presumably channels-first -- confirm)
# and a six-element action matching action_dim=6 above
obs = torch.randn(3, 64, 64)
action = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])

# Gradients are not needed for inference, so skip autograd tracking
with torch.no_grad():
    # The operator takes the raw inputs as a dict and returns what the agent consumes
    processed = op({'image': obs, 'action': action})
    next_obs, reward = agent.predict(processed)

JEPA#

import torch
import torchwm
from torchwm import JEPAAgent

# Build the JEPA preprocessing operator and load the trained agent
op = torchwm.get_operator("jepa", image_size=224, patch_size=16, mask_ratio=0.75)
agent = JEPAAgent.from_pretrained("jepa_checkpoint")

# Representation learning
# Eight random 3x224x224 tensors stand in for real images (size matches image_size=224)
images = [torch.randn(3, 224, 224) for _ in range(8)]
processed = op({'images': images})

# Gradients are not needed for inference, so skip autograd tracking
with torch.no_grad():
    representations = agent.encode(processed)

Rollout and Imagination#

Generate imagined trajectories:

# Dreamer imagination
import torch
from torchwm import DreamerAgent

agent = DreamerAgent.from_pretrained("dreamer_checkpoint")

# Random stand-in for a real starting observation
initial_obs = torch.randn(3, 64, 64)
# Rollout length passed to imagine_rollout
horizon = 10

# Gradients are not needed for inference, so skip autograd tracking
with torch.no_grad():
    imagined_trajectory = agent.imagine_rollout(initial_obs, horizon)
# Returns dict with imagined observations, actions, rewards

Batch Inference#

Process multiple inputs efficiently:

# `op` and `agent` are not created in this snippet; presumably they are the Dreamer
# operator and agent from the Basic Inference example -- confirm.
batch_size = 32
# Random stand-ins for a batch of observations and actions; the leading dimension is the batch
obs_batch = torch.randn(batch_size, 3, 64, 64)
action_batch = torch.randn(batch_size, 6)

processed = op({'image': obs_batch, 'action': action_batch})

# Gradients are not needed for inference, so skip autograd tracking
with torch.no_grad():
    predictions = agent.predict_batch(processed)

GPU Acceleration#

Move to GPU for faster inference:

# Use CUDA when it is available, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The agent and its inputs must be on the same device
agent = agent.to(device)
# NOTE(review): assumes every value in `processed` has a .to() method (e.g. is a tensor);
# confirm for operators whose output contains lists or other non-tensor values.
processed = {k: v.to(device) for k, v in processed.items()}

with torch.no_grad():
    output = agent.predict(processed)

Real-time Inference#

For interactive applications:

import torch
import torchwm
from torchwm import DreamerAgent

class InferenceServer:
    """Keep a loaded Dreamer agent and its operator around for repeated predictions.

    Args:
        checkpoint: Checkpoint identifier passed to DreamerAgent.from_pretrained.
        image_size: Image size passed to the Dreamer operator.
        action_dim: Action dimensionality passed to the Dreamer operator.
    """

    def __init__(self, checkpoint="checkpoint", image_size=64, action_dim=6):
        # Load once at construction so each predict() call only pays for
        # preprocessing and the model call.
        self.agent = DreamerAgent.from_pretrained(checkpoint).eval()
        self.op = torchwm.get_operator("dreamer", image_size=image_size, action_dim=action_dim)

    def predict(self, obs, action):
        """Preprocess one observation/action pair and return the agent's prediction."""
        processed = self.op({'image': obs, 'action': action})
        # Gradients are not needed for inference, so skip autograd tracking
        with torch.no_grad():
            return self.agent.predict(processed)

server = InferenceServer()

Performance Optimization#

JIT Compilation#

import torch

# Compile the agent with TorchScript.
# NOTE(review): scripting only succeeds if the agent's code is TorchScript-compatible -- confirm.
agent = torch.jit.script(agent)

Memory Efficient Inference#

import torch

# inference_mode() disables autograd tracking like no_grad(), with less bookkeeping;
# tensors created inside it cannot later be used in autograd.
with torch.inference_mode():
    output = agent.predict(processed)

Exporting Models#

Export to ONNX or TorchScript:

# TorchScript
# NOTE(review): scripting only succeeds if the agent's code is TorchScript-compatible -- confirm.
scripted = torch.jit.script(agent)
torch.jit.save(scripted, "model.pt")

# ONNX
# The operator output for a batch of one image/action pair serves as the example input
dummy_input = op({'image': torch.randn(1, 3, 64, 64), 'action': torch.randn(1, 6)})
# NOTE(review): torch.onnx.export documents `args` as a tensor or a tuple, and a dict placed
# last in that tuple is read as keyword arguments. Confirm how the agent's forward()
# expects to receive this input before relying on this call.
torch.onnx.export(agent, dummy_input, "model.onnx")

Integration Examples#

With Gym Environments#

import torchwm
from torchwm import DreamerAgent

env = torchwm.make_env("Pendulum-v1", backend="gym")
agent = DreamerAgent.from_pretrained("pendulum_checkpoint")
op = torchwm.get_operator("dreamer", image_size=64, action_dim=env.action_space.shape[0])

obs, _ = env.reset()
done = False

while not done:
    action = agent.act(obs)  # Get action from agent
    # step() returns (obs, reward, terminated, truncated, info), and an episode ends
    # when either flag is set. Pendulum-v1 has no terminal state and only ever
    # truncates at its time limit, so checking `terminated` alone would loop forever.
    obs, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated

With Custom Environments#

import torch
import torchwm
from torchwm import DreamerAgent

class CustomEnv:
    def reset(self):
        # Your reset logic; return the first observation of the episode
        return obs

    def step(self, action):
        # Your environment logic
        return obs, reward, done

env = CustomEnv()
agent = DreamerAgent.from_pretrained("custom_checkpoint")

# Set these to match your environment and the checkpoint's training setup
action_dim = 6
op = torchwm.get_operator("dreamer", image_size=64, action_dim=action_dim)

for episode in range(10):
    obs = env.reset()
    total_reward = 0
    # No action has been taken yet at the start of an episode, so feed the
    # operator a zero action for the first prediction.
    action = torch.zeros(action_dim)

    while True:
        processed = op({'image': obs, 'action': action})
        # Gradients are not needed for inference, so skip autograd tracking
        with torch.no_grad():
            next_obs_pred, reward_pred = agent.predict(processed)

        # Use predictions for planning/control
        action = agent.plan(obs, next_obs_pred, reward_pred)
        obs, reward, done = env.step(action)
        total_reward += reward

        if done:
            break

    print(f"Episode {episode}: {total_reward}")

Troubleshooting#

Memory Issues#

  • Use smaller batch sizes

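  • Enable gradient checkpointing (only helps when gradients are computed, e.g. during fine-tuning; it has no effect under torch.no_grad() or torch.inference_mode())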

  • Clear cache: torch.cuda.empty_cache()

Speed Issues#

  • Move to GPU

  • Use JIT compilation

  • Batch inputs when possible

Accuracy Issues#

  • Ensure proper preprocessing with operators

  • Check model loading

  • Verify input shapes match training