Inference Guide#
This guide covers how to use trained TorchWM models for inference and deployment.
Overview#
TorchWM standardizes inference through operators; pipelines are planned but not yet available.
For application code, prefer the top-level torchwm.get_operator() factory; it
keeps examples short and avoids deep imports.
Loading Trained Models#
from torchwm import DreamerAgent
# Load from checkpoint
agent = DreamerAgent.from_pretrained("path/to/checkpoint")
agent.eval()
Using Operators for Preprocessing#
See Using Operators for Inference for detailed operator usage.
Basic Inference#
Dreamer#
import torch
import torchwm
from torchwm import DreamerAgent
op = torchwm.get_operator("dreamer", image_size=64, action_dim=6)
agent = DreamerAgent.from_pretrained("dreamer_checkpoint")
# Single step prediction
obs = torch.randn(3, 64, 64)
action = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
with torch.no_grad():
    processed = op({'image': obs, 'action': action})
    next_obs, reward = agent.predict(processed)
JEPA#
import torch
import torchwm
from torchwm import JEPAAgent
op = torchwm.get_operator("jepa", image_size=224, patch_size=16, mask_ratio=0.75)
agent = JEPAAgent.from_pretrained("jepa_checkpoint")
# Representation learning
images = [torch.randn(3, 224, 224) for _ in range(8)]
processed = op({'images': images})
with torch.no_grad():
    representations = agent.encode(processed)
Rollout and Imagination#
Generate imagined trajectories:
# Dreamer imagination
import torch
from torchwm import DreamerAgent
agent = DreamerAgent.from_pretrained("dreamer_checkpoint")
initial_obs = torch.randn(3, 64, 64)
horizon = 10
imagined_trajectory = agent.imagine_rollout(initial_obs, horizon)
# Returns dict with imagined observations, actions, rewards
Batch Inference#
Process multiple inputs efficiently:
batch_size = 32
obs_batch = torch.randn(batch_size, 3, 64, 64)
action_batch = torch.randn(batch_size, 6)
processed = op({'image': obs_batch, 'action': action_batch})
with torch.no_grad():
    predictions = agent.predict_batch(processed)
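When the full input set does not fit in memory at once, one option is to split it along the batch dimension and run each piece separately. The helper below is a minimal, framework-agnostic sketch: `iter_chunks` is a hypothetical name, not part of TorchWM, and it relies only on `len()` and slicing, which both Python lists and PyTorch tensors support along their first dimension.

```python
from typing import Iterator, Sequence


def iter_chunks(items: Sequence, chunk_size: int) -> Iterator[Sequence]:
    """Yield consecutive slices of at most chunk_size elements.

    Hypothetical helper, not a TorchWM API. Works on anything that
    supports len() and slicing along its first dimension.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    for start in range(0, len(items), chunk_size):
        yield items[start:start + chunk_size]


# Stand-in data: 70 sample indices split into chunks of 32.
samples = list(range(70))
chunk_lengths = [len(chunk) for chunk in iter_chunks(samples, 32)]
print(chunk_lengths)  # [32, 32, 6]
```

With tensors, each chunk could be passed through `op` and `agent.predict_batch` as in the snippet above, and the per-chunk outputs concatenated afterwards.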
GPU Acceleration#
Move the agent and its inputs to the GPU for faster inference:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = agent.to(device)
processed = {k: v.to(device) for k, v in processed.items()}
with torch.no_grad():
    output = agent.predict(processed)
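The dictionary comprehension above assumes every value is a tensor. If the operator output contains nested containers or non-tensor metadata, a recursive helper is one way to handle it. This is a sketch under that assumption: `move_to_device` is a hypothetical name, and it duck-types on a `.to()` method so it can be shown here with a stand-in class instead of a real `torch.Tensor`.

```python
from typing import Any


def move_to_device(batch: Any, device: Any) -> Any:
    """Recursively call .to(device) on every object that has a .to method.

    Hypothetical helper, not a TorchWM API. Dicts, lists and tuples are
    rebuilt with converted contents; other values are returned unchanged.
    """
    if hasattr(batch, "to") and callable(batch.to):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: move_to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(value, device) for value in batch)
    return batch


class FakeTensor:
    """Stand-in for torch.Tensor that only records its device."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


batch = {"image": FakeTensor(), "meta": {"id": 7, "frames": [FakeTensor(), FakeTensor()]}}
moved = move_to_device(batch, "cuda")
print(moved["image"].device, moved["meta"]["id"])  # cuda 7
```

With PyTorch installed, the same function could be called as `move_to_device(processed, device)` in place of the comprehension.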
Real-time Inference#
For interactive applications:
import torch
import torchwm
from torchwm import DreamerAgent
class InferenceServer:
    def __init__(self):
        self.agent = DreamerAgent.from_pretrained("checkpoint").eval()
        self.op = torchwm.get_operator("dreamer", image_size=64, action_dim=6)
    def predict(self, obs, action):
        processed = self.op({'image': obs, 'action': action})
        with torch.no_grad():
            return self.agent.predict(processed)
server = InferenceServer()
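For interactive use it helps to know the per-call latency before wiring the server into a control loop. The sketch below times any callable with `time.perf_counter`; `measure_latency` and the stand-in `fake_predict` are hypothetical and take the place of `server.predict`. On CUDA, kernels run asynchronously, so a real measurement would also need `torch.cuda.synchronize()` before reading the clock.

```python
import statistics
import time
from typing import Callable, Dict


def measure_latency(fn: Callable[[], object], warmup: int = 3, runs: int = 20) -> Dict[str, float]:
    """Time repeated calls to fn and return latency statistics in milliseconds."""
    for _ in range(warmup):  # warm-up calls are not timed
        fn()
    timings_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings_ms.append((time.perf_counter() - start) * 1000.0)
    return {
        "mean_ms": statistics.mean(timings_ms),
        "median_ms": statistics.median(timings_ms),
        "max_ms": max(timings_ms),
    }


def fake_predict() -> int:
    """Stand-in workload; replace with a call such as server.predict(obs, action)."""
    return sum(i * i for i in range(10_000))


stats = measure_latency(fake_predict)
print({name: round(value, 3) for name, value in stats.items()})
```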
Performance Optimization#
JIT Compilation#
import torch
agent = torch.jit.script(agent)
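`torch.jit.script` compiles Python source and can reject dynamic constructs; `torch.jit.trace` records the operations run on an example input instead. Whether a given TorchWM agent scripts or traces cleanly is not stated here, so the sketch below uses a small stand-in `nn.Sequential` rather than a real agent and checks that the traced module reproduces the eager output.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in network; a real TorchWM module would take its place.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4)).eval()
example = torch.randn(2, 8)

# Tracing records the operations executed for this example input.
traced = torch.jit.trace(model, example)

with torch.no_grad():
    eager_out = model(example)
    traced_out = traced(example)

# Compare eager and traced results before relying on the compiled module.
outputs_match = torch.allclose(eager_out, traced_out, atol=1e-6)
print(outputs_match)
```

Tracing does not capture data-dependent control flow, so a comparison like this is worth repeating on inputs that exercise different branches.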
Memory Efficient Inference#
import torch
with torch.inference_mode():
    output = agent.predict(processed)
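`torch.inference_mode()` is a stricter variant of `torch.no_grad()`: tensors created inside it are flagged as inference tensors and cannot later take part in autograd. The sketch below contrasts the two on a stand-in `nn.Linear` layer (not a TorchWM model).

```python
import torch
from torch import nn

layer = nn.Linear(4, 2).eval()  # stand-in for a trained model
x = torch.randn(3, 4)

with torch.no_grad():
    out_no_grad = layer(x)

with torch.inference_mode():
    out_inference = layer(x)

# Neither output tracks gradients, but only the second is an inference tensor.
print(out_no_grad.requires_grad, out_no_grad.is_inference())      # False False
print(out_inference.requires_grad, out_inference.is_inference())  # False True
```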
Exporting Models#
TorchWM installs a deployment-oriented export() method once on torch.nn.Module, so every model class in the library can be exported with the same API. High-level wrapper agents such as Dreamer and PlaNet use the same exporter for their contained modules:
model.export("model.onnx", format="onnx", example_inputs=example_inputs)
agent.export("agent_actor.onnx", format="onnx")
The method supports three formats:
format="onnx" writes an ONNX graph for ONNX Runtime, TensorRT conversion, or other production runtimes.
format="torchscript" (aliases: "jit", "ts") writes a TorchScript .pt file.
format="tensorrt" (alias: "trt") compiles with the optional torch-tensorrt package and writes a serialized TorchScript TensorRT module.
Dreamer exports its deterministic actor by default. The exported Dreamer actor
accepts concatenated latent features with shape [batch, stoch_size + deter_size]
and returns actions:
import torch
from torchwm import DreamerAgent
agent = DreamerAgent(env="cartpole_balance")
agent.export("dreamer_actor.onnx", format="onnx")
agent.export("dreamer_actor.pt", format="torchscript")
features = torch.zeros(1, agent.args.stoch_size + agent.args.deter_size)
agent.export(
    "dreamer_actor_dynamic.onnx",
    format="onnx",
    example_inputs=features,
    input_names=["features"],
    output_names=["actions"],
    dynamic_axes={"features": {0: "batch"}, "actions": {0: "batch"}},
)
Export individual components by passing target when the agent provides more
than one deployable module:
agent.export("dreamer_encoder.onnx", format="onnx", target="obs_encoder")
agent.export("dreamer_reward.pt", format="torchscript", target="reward_model")
For any lower-level torch.nn.Module model, pass example_inputs explicitly if TorchWM cannot infer a safe default:
import torch
import torchwm
genie = torchwm.create_model("genie-small", image_size=32)
video = torch.randn(1, 3, genie.num_frames, genie.image_size, genie.image_size)
genie.export("genie_small.onnx", format="onnx", example_inputs=video)
vit = torchwm.VisionTransformer(img_size=[224])
images = torch.randn(1, 3, 224, 224)
vit.export("vit.onnx", format="onnx", example_inputs=images)
Agents that contain multiple deployable modules accept either short target names such as "obs_encoder" or fully qualified paths such as "dreamer.obs_encoder". JEPA exports a ViT encoder target by default, while lower-level JEPA VisionTransformer modules can be exported directly like any other torch.nn.Module.
TensorRT export requires torch-tensorrt in the deployment environment:
agent.export("dreamer_actor_trt.pt", format="tensorrt")
Integration Examples#
With Gym Environments#
import torchwm
from torchwm import DreamerAgent
env = torchwm.make_env("Pendulum-v1", backend="gym")
agent = DreamerAgent.from_pretrained("pendulum_checkpoint")
op = torchwm.get_operator("dreamer", image_size=64, action_dim=env.action_space.shape[0])
obs, _ = env.reset()
done = False
while not done:
    action = agent.act(obs)  # Get action from agent
    obs, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated  # Stop on termination or time-limit truncation
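The five-value `step` return used above follows the Gymnasium convention of separate `terminated` and `truncated` flags; an episode is over when either is set. The sketch below illustrates this with a hypothetical `CountdownEnv` that never terminates and is cut off by a step limit. It stands in for a real environment and does not use TorchWM.

```python
from typing import Dict, Tuple


class CountdownEnv:
    """Hypothetical stand-in env using the Gymnasium-style five-value step."""

    def __init__(self, max_steps: int = 5) -> None:
        self.max_steps = max_steps
        self.steps = 0

    def reset(self) -> Tuple[int, Dict]:
        self.steps = 0
        return 0, {}

    def step(self, action: int) -> Tuple[int, float, bool, bool, Dict]:
        self.steps += 1
        terminated = False                        # no terminal state in this toy task
        truncated = self.steps >= self.max_steps  # time limit reached
        return self.steps, 1.0, terminated, truncated, {}


env = CountdownEnv(max_steps=5)
obs, _ = env.reset()
done = False
total_reward = 0.0
while not done:
    action = 0  # a real loop would query the agent here
    obs, reward, terminated, truncated, _ = env.step(action)
    total_reward += reward
    done = terminated or truncated  # checking only `terminated` would loop forever here

print(obs, total_reward)  # 5 5.0
```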
With Custom Environments#
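import torch
import torchwm
from torchwm import DreamerAgent
class CustomEnv:
    def reset(self):
        # Your reset logic; returns the initial observation
        return obs
    def step(self, action):
        # Your environment logic
        return obs, reward, done
env = CustomEnv()
agent = DreamerAgent.from_pretrained("custom_checkpoint")
op = torchwm.get_operator("dreamer", image_size=64, action_dim=6)
for episode in range(10):
    obs = env.reset()
    # No previous action exists at episode start; a zero action with action_dim=6 is assumed
    action = torch.zeros(6)
    total_reward = 0
    while True:
        processed = op({'image': obs, 'action': action})
        with torch.no_grad():
            next_obs_pred, reward_pred = agent.predict(processed)
        # Use predictions for planning/control
        action = agent.plan(obs, next_obs_pred, reward_pred)
        obs, reward, done = env.step(action)
        total_reward += reward
        if done:
            break
    print(f"Episode {episode}: {total_reward}")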
Troubleshooting#
Memory Issues#
Use smaller batch sizes
Run inference under torch.no_grad() or torch.inference_mode() so no autograd state is kept
Clear cache:
torch.cuda.empty_cache()
Speed Issues#
Move to GPU
Use JIT compilation
Batch inputs when possible
Accuracy Issues#
Ensure proper preprocessing with operators
Check model loading
Verify input shapes match training
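Shape mismatches are among the easier accuracy problems to rule out before calling the model. The following sketch compares a shape tuple against an expected pattern, with `None` as a wildcard for dimensions such as batch size. `check_shape` is a hypothetical helper, not a TorchWM function; with PyTorch it could be called as `check_shape(tuple(obs.shape), (None, 3, 64, 64))`.

```python
from typing import Optional, Sequence


def check_shape(actual: Sequence[int], expected: Sequence[Optional[int]]) -> None:
    """Raise ValueError if actual does not match expected (None = any size)."""
    if len(actual) != len(expected):
        raise ValueError(f"expected {len(expected)} dims, got {len(actual)}: {tuple(actual)}")
    for axis, (got, want) in enumerate(zip(actual, expected)):
        if want is not None and got != want:
            raise ValueError(f"axis {axis}: expected {want}, got {got}")


# A batch of 64x64 RGB observations passes; a 32x32 batch is rejected.
check_shape((32, 3, 64, 64), (None, 3, 64, 64))
try:
    check_shape((32, 3, 32, 32), (None, 3, 64, 64))
    mismatch_message = ""
except ValueError as err:
    mismatch_message = str(err)
print(mismatch_message)  # axis 2: expected 64, got 32
```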