Evaluation Guide#
TorchWM provides a general-purpose evaluation package for measuring the quality of generative world models. It implements the three metrics used in the DIAMOND paper (Alonso et al., NeurIPS 2024, Appendix M):
FID (Fréchet Inception Distance) — perceptual quality of individual generated frames
FVD (Fréchet Video Distance) — temporal coherence of generated video clips
LPIPS (Learned Perceptual Image Patch Similarity) — per-frame perceptual similarity
The metrics are model-agnostic: they work with any image or video generation pipeline, not just DIAMOND.
Quick Start#
Python API#
import torch
from evals import FID, FVD, LPIPS
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load your real and generated data
real_images = torch.rand(100, 3, 64, 64) # [N, C, H, W] in [0, 1]
gen_images = torch.rand(100, 3, 64, 64)
# FID: compares distributions of real vs generated frames
fid = FID(device=device)
print(f"FID: {fid(real_images, gen_images):.2f}")
# LPIPS: pairwise perceptual similarity
lpips = LPIPS(device=device)
print(f"LPIPS: {lpips(real_images, gen_images):.4f}")
# FVD: requires video tensors [B, C, T, H, W]
real_videos = torch.rand(16, 3, 20, 64, 64) # values in [0, 1]
gen_videos = torch.rand(16, 3, 20, 64, 64)
fvd = FVD(device=device)
print(f"FVD: {fvd(real_videos, gen_videos):.2f}")
CLI#
# Evaluate a DIAMOND world model
torchwm eval --model diamond --checkpoint path/to/model.pt --game Breakout-v5
# Specify which metrics to compute
torchwm eval --model diamond --checkpoint model.pt --metrics fid,lpips
# Record real and generated videos
torchwm eval --model diamond --checkpoint model.pt --record eval_video.mp4
# Full control
torchwm eval --model diamond --checkpoint model.pt \
--game Breakout-v5 \
--num-videos 512 \
--trajectory-length 20 \
--batch-size 32 \
--seed 42 \
--metrics fid,fvd,lpips \
--record eval_video.mp4
Interactive Play#
# Watch the agent play in the real environment
torchwm play --model diamond --checkpoint path/to/model.pt --game Breakout-v5
# Switch to dream mode (TAB) to watch the agent inside its imagination
# Press arrow keys / WASD to override the agent's actions
torchwm play --model diamond --checkpoint model.pt
# Record gameplay video
torchwm play --model diamond --checkpoint model.pt --record gameplay.mp4
FID: Fréchet Inception Distance#
Definition#
FID measures the distance between the distribution of real images and the distribution of generated images in the feature space of a pretrained InceptionV3 network. Lower scores indicate that the two distributions are more similar.
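\[\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{\frac{1}{2}}\right)\]
where: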
\(\mu_r, \Sigma_r\) are the mean and covariance of Inception features for real images
\(\mu_g, \Sigma_g\) are the mean and covariance of Inception features for generated images
\((\Sigma_r \Sigma_g)^{\frac{1}{2}}\) is the matrix square root of the product
Implementation#
Feature extraction: images are resized to 299x299 and passed through InceptionV3 truncated at the Mixed_7c layer, producing 2048-dimensional feature vectors.
Statistics: mean and covariance are computed from all feature vectors.
Distance: the Fréchet distance between the two Gaussians is computed using scipy.linalg.sqrtm with a small diagonal regularizer (\(\epsilon = 10^{-6}\)) for numerical stability.
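The statistics and distance steps can be sketched with NumPy and SciPy. This is a minimal illustration, not the code in evals/fid.py: the function name `frechet_distance` is hypothetical, the inputs are assumed to be already-extracted feature matrices of shape [N, D], and adding the regularizer to the diagonal of both covariances is only one plausible reading of the description above.

```python
import numpy as np
from scipy import linalg


def frechet_distance(feats_r, feats_g, eps=1e-6):
    """Frechet distance between Gaussians fitted to two [N, D] feature sets.

    Hypothetical helper for illustration; eps is added to the diagonal of
    both covariances (an assumption about how the regularizer is applied).
    """
    mu_r = feats_r.mean(axis=0)
    mu_g = feats_g.mean(axis=0)

    # rowvar=False: rows are samples, columns are feature dimensions.
    offset = eps * np.eye(feats_r.shape[1])
    sigma_r = np.cov(feats_r, rowvar=False) + offset
    sigma_g = np.cov(feats_g, rowvar=False) + offset

    # Matrix square root of the covariance product.
    covmean = linalg.sqrtm(sigma_r @ sigma_g)
    if np.iscomplexobj(covmean):
        # Tiny imaginary parts can appear from numerical error.
        covmean = covmean.real

    diff = mu_r - mu_g
    return float(
        diff @ diff
        + np.trace(sigma_r)
        + np.trace(sigma_g)
        - 2.0 * np.trace(covmean)
    )
```

Two identical feature sets give a distance close to zero, and shifting every feature by a constant vector adds the squared norm of that shift, which is a quick sanity check for any implementation.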
Usage#
from evals import FID
fid = FID(device=device, batch_size=64)
# real_images, gen_images: [N, C, H, W] float tensors in [0, 1]
score = fid(real_images, gen_images)
Interpretation#
| FID Score | Quality |
|---|---|
| 0 | Perfect (identical distributions) |
| < 10 | Excellent: nearly indistinguishable |
| 10–50 | Good: visible differences but similar structure |
| 50–200 | Moderate: clear distribution mismatch |
| > 200 | Poor: very different distributions |
Caveat: FID is sensitive to the number of samples. Always compare scores computed with the same number of images.
FVD: Fréchet Video Distance#
Definition#
FVD extends the Fréchet distance to video by comparing feature distributions from a video recognition backbone (R3D-18 pretrained on Kinetics-400). It captures both per-frame quality and temporal dynamics.
The formula is identical to FID, but the features come from a 3D convolutional network that processes spatiotemporal volumes.
Implementation#
Clip sampling: input videos of arbitrary length are split into 16-frame clips (shorter videos are padded by repeating).
Feature extraction: each clip is resized to 112x112, normalized with Kinetics statistics, and passed through R3D-18 truncated before the final FC layer, producing 512-dimensional feature vectors.
Statistics and distance: computed identically to FID.
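The clip-sampling step can be sketched as follows. This is an illustration under stated assumptions, not the code in evals/fvd.py: the helper name `split_into_clips` is hypothetical, "padded by repeating" is assumed to mean repeating the last frame, and a trailing partial clip is assumed to be padded the same way rather than dropped.

```python
import torch


def split_into_clips(videos, clip_length=16):
    """Split [N, C, T, H, W] videos into [N * K, C, clip_length, H, W] clips.

    Hypothetical helper: pads the time axis by repeating the last frame
    (an assumption) so that T becomes a multiple of clip_length.
    """
    n, c, t, h, w = videos.shape
    remainder = t % clip_length
    if remainder:
        pad = clip_length - remainder
        last_frame = videos[:, :, -1:].expand(n, c, pad, h, w)
        videos = torch.cat([videos, last_frame], dim=2)

    num_clips = videos.shape[2] // clip_length
    # Break the time axis into (num_clips, clip_length), then fold the
    # clip index into the batch dimension.
    clips = videos.reshape(n, c, num_clips, clip_length, h, w)
    clips = clips.permute(0, 2, 1, 3, 4, 5)
    return clips.reshape(n * num_clips, c, clip_length, h, w)
```

Under these assumptions a 20-frame video yields two clips: frames 0 to 15, and frames 16 to 19 followed by twelve copies of frame 19.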
Usage#
from evals import FVD
fvd = FVD(device=device, batch_size=16, clip_length=16)
# real_videos, gen_videos: [N, C, T, H, W] float tensors in [0, 1]
score = fvd(real_videos, gen_videos)
Interpretation#
| FVD Score | Quality |
|---|---|
| 0 | Perfect |
| < 50 | Excellent temporal coherence |
| 50–300 | Good |
| 300–1000 | Moderate |
| > 1000 | Poor |
Caveat: FVD requires at least as many clips as feature dimensions (512) for a stable covariance estimate. With very small sample sizes, the matrix square root may become numerically unstable.
LPIPS: Learned Perceptual Image Patch Similarity#
Definition#
LPIPS measures the perceptual distance between pairs of images using deep features from a pretrained VGG16 network. Unlike FID (which compares distributions), LPIPS compares individual image pairs — it is a reference metric.
Lower scores indicate higher perceptual similarity. LPIPS correlates better with human judgment than pixel-level metrics like MSE or PSNR.
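\[\mathrm{LPIPS}(x, y) = \sum_l \frac{1}{H_l W_l} \sum_{h, w} \left\lVert w_l \odot \left(f^l_{hw}(x) - f^l_{hw}(y)\right) \right\rVert_2^2\]
where \(f^l(x)\) are features from layer \(l\) of VGG16 with spatial size \(H_l \times W_l\), and \(w_l\) are learned per-channel weights (uniform in our implementation).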
Implementation#
Feature extraction: images are passed through VGG16-bn, extracting features from layers relu1_2, relu2_2, relu3_3, and relu4_3.
Normalization: inputs are scaled from [0, 1] to [-1, 1] (VGG16 training range).
Distance: for each layer, the squared L2 difference is spatially averaged and summed across channels. Layer scores are averaged to produce the final score.
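The distance step can be sketched on pre-extracted feature maps. This is an illustration only, not the code in evals/lpips.py: the function name `layerwise_perceptual_distance` is hypothetical, the inputs are assumed to be lists of [N, C, H, W] feature tensors (one per layer), and the code follows the description above literally with uniform weights.

```python
import torch


def layerwise_perceptual_distance(feats_a, feats_b):
    """Per-pair distance from two lists of [N, C, H, W] feature maps.

    Hypothetical helper: for each layer, square the difference, average
    over spatial positions, sum over channels; then average over layers.
    Returns a tensor of shape [N].
    """
    layer_scores = []
    for fa, fb in zip(feats_a, feats_b):
        sq_diff = (fa - fb) ** 2                 # [N, C, H, W]
        per_channel = sq_diff.mean(dim=(2, 3))   # spatial average -> [N, C]
        layer_scores.append(per_channel.sum(dim=1))  # channel sum -> [N]
    return torch.stack(layer_scores, dim=0).mean(dim=0)
```

Averaging the returned per-pair values gives a single score for a batch; whether the library reports the mean or something else is not stated in the text above.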
Usage#
from evals import LPIPS
lpips = LPIPS(device=device, batch_size=64)
# images_a, images_b: [N, C, H, W] float tensors in [0, 1]
# Must have the same N — this is pairwise
score = lpips(images_a, images_b)
Interpretation#
| LPIPS Score | Perceptual Similarity |
|---|---|
| 0 | Identical images |
| < 0.05 | Very similar |
| 0.05–0.2 | Noticeable differences |
| 0.2–0.5 | Clearly different |
| > 0.5 | Very different |
Python API Reference#
evals Package#
evals/
__init__.py # Exports FID, FVD, LPIPS
fid.py # Fréchet Inception Distance
fvd.py # Fréchet Video Distance
lpips.py # Learned Perceptual Image Patch Similarity
diamond_utils.py # DIAMOND-specific trajectory generation utilities
Classes#
| Class | Metric | Input Shape | Feature Backbone | Feature Dim |
|---|---|---|---|---|
| FID | Fréchet Inception Distance | [N, C, H, W] | InceptionV3 (Mixed_7c) | 2048 |
| FVD | Fréchet Video Distance | [N, C, T, H, W] | R3D-18 (avgpool) | 512 |
| LPIPS | Learned Perceptual Similarity | [N, C, H, W] | VGG16-bn (4 layers) | n/a |
Common Parameters#
device (torch.device): CPU or CUDA device for feature extraction.
batch_size (int): number of images/clips processed at once. Lower it if GPU memory is limited.
FVD-specific#
clip_length (int, default 16): number of frames per video clip. R3D-18 was trained on 16-frame clips.
Model-Specific Evaluation Scripts#
DIAMOND (scripts/eval_diamond.py)#
The DIAMOND eval script provides a full evaluation pipeline:
Load checkpoint: reconstructs the diffusion model, EDM preconditioner, and Euler sampler from the saved checkpoint.
Collect real trajectories: runs a random-policy agent in the Atari environment to collect ground-truth frame sequences.
Generate imagined trajectories: conditioned on real frames + actions, the diffusion model autoregressively generates future frames using the EDM sampler.
Compute metrics: runs FID, FVD, and LPIPS comparing real vs generated frames.
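The autoregressive generation step can be illustrated with a generic rollout loop. This is a sketch under stated assumptions, not the logic of scripts/eval_diamond.py: `rollout` and `predict_next` are hypothetical names, `predict_next` stands in for the diffusion sampler, and the context length of 4 frames is an illustrative default.

```python
import numpy as np


def rollout(predict_next, context_frames, actions, context_len=4):
    """Generate one frame per action, feeding each output back as context.

    predict_next(context, action) is a hypothetical stand-in for the
    sampler: it receives the last `context_len` frames stacked along
    axis 0 and returns the next frame.
    """
    frames = list(context_frames[-context_len:])  # seed with real frames
    generated = []
    for action in actions:
        context = np.stack(frames[-context_len:])
        next_frame = predict_next(context, action)
        generated.append(next_frame)
        frames.append(next_frame)  # the model now conditions on its own output
    return np.stack(generated)
```

Because each generated frame re-enters the context window, errors can compound over the trajectory, which is why metrics are computed on whole rollouts and not on single-step predictions.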
Extending to Other Models#
To add evaluation for a new model:
Create scripts/eval_<model>.py with a run_eval() function matching the signature the TorchWM CLI expects.
Register it in EVAL_MODULES in tools/cli.py:
EVAL_MODULES = {
"diamond": "scripts.eval_diamond",
"my_model": "scripts.eval_my_model",
}
The torchwm eval command will then accept --model my_model.
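One way a registry like this can be resolved at runtime is with importlib. The sketch below is an assumption about the dispatch mechanism, not the contents of tools/cli.py: `load_run_eval` is a hypothetical helper, and the arguments that run_eval() takes are not shown here because the source does not specify them.

```python
import importlib

# Mirrors the registry shown above; module paths are resolved lazily.
EVAL_MODULES = {
    "diamond": "scripts.eval_diamond",
    "my_model": "scripts.eval_my_model",
}


def load_run_eval(model_name, registry=EVAL_MODULES):
    """Return the run_eval callable registered for `model_name`.

    Hypothetical helper: raises ValueError for unknown model names so
    the caller gets a clear message instead of a KeyError.
    """
    if model_name not in registry:
        known = ", ".join(sorted(registry))
        raise ValueError(f"Unknown model '{model_name}'. Known models: {known}")
    module = importlib.import_module(registry[model_name])
    return getattr(module, "run_eval")
```

Importing lazily means a missing optional dependency in one model's script does not break evaluation of the others.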
Trajectory Collection Utilities#
The evals.diamond_utils module provides helper functions for DIAMOND evaluation:
generate_trajectories(): autoregressively generates frames from a diffusion model conditioned on real trajectory data.
collect_real_trajectories_from_env(): collects ground-truth trajectories from an Atari environment using random actions.
These are DIAMOND-specific but serve as a template for other model types.
Video Recording#
Both torchwm eval and torchwm play support the --record flag to save videos.
Eval Recording#
torchwm eval --model diamond --checkpoint model.pt --record results.mp4
This saves two files:
results_real.mp4: real environment trajectories
results_gen.mp4: model-generated trajectories
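The naming pattern for the two output files can be reproduced with pathlib, which is handy when a script needs to locate the recordings afterwards. The helper name `paired_record_paths` is hypothetical, and the sketch only mirrors the naming shown above; it is not taken from the TorchWM source.

```python
from pathlib import Path


def paired_record_paths(record_path):
    """Derive the real/generated video paths from the --record argument.

    Hypothetical helper: inserts '_real' and '_gen' before the suffix.
    """
    path = Path(record_path)
    real = path.with_name(f"{path.stem}_real{path.suffix}")
    gen = path.with_name(f"{path.stem}_gen{path.suffix}")
    return real, gen
```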
Play Recording#
torchwm play --model diamond --checkpoint model.pt --record gameplay.mp4 --record-fps 20
Records the gameplay window in real time. The recording includes the HUD overlay (mode indicator, reward, action).
Tips and Best Practices#
Stable Metric Computation#
Use enough samples: FID and FVD estimate covariances which require \(N \gg D\) (samples much greater than feature dimension).
FID: at least 10,000 images for stable estimates
FVD: at least 1,000 clips (each clip = 16 frames)
LPIPS: pairwise, so stable with fewer samples
Match resolution: all metrics resize inputs internally, but extreme mismatches between training and evaluation resolution can bias results.
Use the same random seed when comparing different models to control for sampling noise:
torchwm eval --model diamond --checkpoint model_a.pt --seed 42
torchwm eval --model diamond --checkpoint model_b.pt --seed 42
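The sample-size guidance under "Use enough samples" follows from a basic property of sample covariances: with N samples in D dimensions, the estimated covariance has rank at most N - 1, so it is singular whenever N is not larger than D. The sketch below checks this numerically on random Gaussian features; `covariance_rank` is a hypothetical helper and the dimensions are kept small for speed.

```python
import numpy as np


def covariance_rank(num_samples, feature_dim, seed=0):
    """Rank of the sample covariance of random [N, D] Gaussian features."""
    rng = np.random.default_rng(seed)
    feats = rng.normal(size=(num_samples, feature_dim))
    cov = np.cov(feats, rowvar=False)  # [D, D]
    return int(np.linalg.matrix_rank(cov))


# With fewer samples than dimensions the covariance is rank-deficient;
# with many more samples it reaches full rank.
few = covariance_rank(num_samples=10, feature_dim=32)
many = covariance_rank(num_samples=200, feature_dim=32)
```

A rank-deficient covariance makes the matrix square root in the Fréchet distance ill-conditioned, which is the instability the FID and FVD sample-count recommendations are meant to avoid.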
Choosing Metrics#
| If you want to measure… | Use |
|---|---|
| Perceptual quality of generated frames | FID |
| Temporal coherence of videos | FVD |
| Per-frame accuracy vs a ground-truth video | LPIPS |
| All three (DIAMOND paper standard) | FID + FVD + LPIPS |
Interpreting Combined Results#
Low FID + high FVD: frames look good individually but lack temporal consistency (flickering, jittering).
High FID + low FVD: frames are distorted but coherent across time (e.g., all frames are blurry in the same way).
Low FID + low FVD + low LPIPS: high-quality world model with good dynamics.