Inference Guide#
This guide covers how to use trained TorchWM models for inference and deployment.
Overview#
TorchWM provides standardized inference through operators, with higher-level pipelines planned for a future release.
Loading Trained Models#
from world_models.models import DreamerAgent

# Load from checkpoint
agent = DreamerAgent.from_pretrained("path/to/checkpoint")
agent.eval()  # switch dropout / batch-norm layers to inference behavior
Using Operators for Preprocessing#
See Using Operators for Inference for detailed operator usage.
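As a minimal sketch (the dict keys and tensor shapes here simply mirror the Dreamer example below), an operator turns raw inputs into model-ready tensors:

import torch
from world_models.inference.operators import DreamerOperator

op = DreamerOperator()
raw = {'image': torch.randn(3, 64, 64), 'action': torch.randn(6)}
processed = op(raw)  # preprocessed tensors, ready to pass to agent.predict(...)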
Basic Inference#
Dreamer#
import torch
from world_models.models import DreamerAgent
from world_models.inference.operators import DreamerOperator

op = DreamerOperator()
agent = DreamerAgent.from_pretrained("dreamer_checkpoint")

# Single-step prediction
obs = torch.randn(3, 64, 64)
action = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])

with torch.no_grad():
    processed = op({'image': obs, 'action': action})
    next_obs, reward = agent.predict(processed)
JEPA#
import torch
from world_models.models import JEPAAgent
from world_models.inference.operators import JEPAOperator

op = JEPAOperator()
agent = JEPAAgent.from_pretrained("jepa_checkpoint")

# Representation learning
images = [torch.randn(3, 224, 224) for _ in range(8)]
processed = op({'images': images})

with torch.no_grad():
    representations = agent.encode(processed)
Rollout and Imagination#
Generate imagined trajectories:
# Dreamer imagination
import torch
from world_models.models import DreamerAgent

agent = DreamerAgent.from_pretrained("dreamer_checkpoint")
initial_obs = torch.randn(3, 64, 64)
horizon = 10

imagined_trajectory = agent.imagine_rollout(initial_obs, horizon)
# Returns a dict with imagined observations, actions, and rewards
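The exact keys of the returned dict depend on the model; as an illustrative sketch (key names such as 'rewards' are assumptions here, not a documented contract), a consumer might inspect it like this:

# Hypothetical keys shown for illustration; check your model's output for the real ones
for key, value in imagined_trajectory.items():
    print(key, getattr(value, "shape", type(value)))

# e.g. summing imagined rewards over the horizon, assuming a 'rewards' entry exists
if "rewards" in imagined_trajectory:
    total_imagined_reward = imagined_trajectory["rewards"].sum()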
Batch Inference#
Process multiple inputs efficiently:
batch_size = 32
obs_batch = torch.randn(batch_size, 3, 64, 64)
action_batch = torch.randn(batch_size, 6)

processed = op({'image': obs_batch, 'action': action_batch})
with torch.no_grad():
    predictions = agent.predict_batch(processed)
GPU Acceleration#
Move to GPU for faster inference:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = agent.to(device)
processed = {k: v.to(device) for k, v in processed.items()}

with torch.no_grad():
    output = agent.predict(processed)
Real-time Inference#
For interactive applications:
class InferenceServer:
    def __init__(self):
        self.agent = DreamerAgent.from_pretrained("checkpoint").eval()
        self.op = DreamerOperator()

    def predict(self, obs, action):
        processed = self.op({'image': obs, 'action': action})
        with torch.no_grad():
            return self.agent.predict(processed)

server = InferenceServer()
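A quick usage sketch (the tensor shapes below just mirror the Dreamer example above and are assumptions about your model's inputs):

obs = torch.randn(3, 64, 64)   # single observation
action = torch.randn(6)        # previous action
next_obs_pred, reward_pred = server.predict(obs, action)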
Performance Optimization#
JIT Compilation#
from world_models.utils.jit_utils import jit_compile_module
agent = jit_compile_module(agent)
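JIT-compiled modules are commonly optimized over their first few invocations, so a short warm-up before latency-sensitive measurements is a reasonable precaution (a sketch reusing the Dreamer inputs from earlier; whether it applies here depends on how jit_compile_module works under the hood):

with torch.no_grad():
    for _ in range(3):  # warm-up passes before timing real requests
        agent.predict(op({'image': torch.randn(3, 64, 64), 'action': torch.randn(6)}))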
Memory-Efficient Inference#
from world_models.utils.memory_utils import optimize_memory_efficient_ops
optimize_memory_efficient_ops()
Exporting Models#
Export to ONNX or TorchScript:
# TorchScript
scripted = torch.jit.script(agent)
torch.jit.save(scripted, "model.pt")

# ONNX
dummy_input = op({'image': torch.randn(1, 3, 64, 64), 'action': torch.randn(1, 6)})
# The trailing empty dict keeps torch.onnx.export from treating the input dict as keyword arguments
torch.onnx.export(agent, (dummy_input, {}), "model.onnx")
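To sanity-check the TorchScript export, the saved artifact can be loaded back with torch.jit.load and run without the original Python class definitions (a sketch; the input dict simply mirrors the dummy input above):

loaded = torch.jit.load("model.pt")
loaded.eval()

sample = op({'image': torch.randn(1, 3, 64, 64), 'action': torch.randn(1, 6)})
with torch.no_grad():
    output = loaded(sample)  # runs the scripted forward pass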
Integration Examples#
With Gym Environments#
import gymnasium as gym

env = gym.make("Pendulum-v1")
agent = DreamerAgent.from_pretrained("pendulum_checkpoint")
op = DreamerOperator()

obs, _ = env.reset()
done = False
while not done:
    action = agent.act(obs)  # Get action from agent
    obs, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
With Custom Environments#
class CustomEnv:
    def reset(self):
        # Your environment logic: return the initial observation
        return obs

    def step(self, action):
        # Your environment logic
        return obs, reward, done

env = CustomEnv()
agent = DreamerAgent.from_pretrained("custom_checkpoint")
op = DreamerOperator()

for episode in range(10):
    obs = env.reset()
    action = torch.zeros(6)  # placeholder previous action (6-dim, matching the earlier examples)
    total_reward = 0
    while True:
        processed = op({'image': obs, 'action': action})
        with torch.no_grad():
            next_obs_pred, reward_pred = agent.predict(processed)
        # Use predictions for planning/control
        action = agent.plan(obs, next_obs_pred, reward_pred)
        obs, reward, done = env.step(action)
        total_reward += reward
        if done:
            break
    print(f"Episode {episode}: {total_reward}")
Troubleshooting#
Memory Issues#
Use smaller batch sizes, or process large input sets in chunks (see the sketch below this list)
Run inference under torch.no_grad() (or torch.inference_mode()) so activations are not retained for autograd
Clear the CUDA cache between large runs:
torch.cuda.empty_cache()
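A minimal chunking sketch for large input sets (reusing the op and agent from the batch inference example above; the sizes are arbitrary):

chunk_size = 8
all_obs = torch.randn(128, 3, 64, 64)
all_actions = torch.randn(128, 6)

outputs = []
with torch.no_grad():
    for start in range(0, all_obs.shape[0], chunk_size):
        chunk = op({'image': all_obs[start:start + chunk_size],
                    'action': all_actions[start:start + chunk_size]})
        outputs.append(agent.predict_batch(chunk))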
Speed Issues#
Move to GPU
Use JIT compilation
Batch inputs when possible
Accuracy Issues#
Ensure proper preprocessing with operators
Check model loading
Verify input shapes match training
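A quick way to catch shape and preprocessing mismatches is to inspect the operator output before calling the model (a sketch; the expected resolution below is taken from the Dreamer examples in this guide and may differ for your checkpoint):

processed = op({'image': torch.randn(3, 64, 64), 'action': torch.randn(6)})

for key, value in processed.items():
    print(key, tuple(value.shape), value.dtype)

# Example assertion, assuming the operator returns an 'image' tensor shaped like the training data
assert processed['image'].shape[-2:] == (64, 64), "image resolution differs from training"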