# JEPA: Joint Embedding Predictive Architecture
JEPA is a self-supervised learning method that learns visual representations by predicting future representations in latent space, without relying on generative modeling.
Based on the paper: *Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture* (I-JEPA; Assran et al., 2023)
## Key Idea
Instead of predicting pixels (like autoencoders) or reconstructing images, JEPA predicts latent representations of future frames from past frames:
- **Encoder**: encodes the current frame into latent space
- **Predictor**: predicts the future latent representation
- **Loss**: MSE between the predicted and actual future latents
This approach avoids the complexity of pixel-level generation while learning rich representations.
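The whole objective can be sketched end to end with toy linear stand-ins for the networks. Everything here (names, shapes, linear "encoders") is illustrative, not this repo's API:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins for the real networks: an encoder maps a "frame" vector to a
# latent, and a predictor maps the context latent to a predicted future latent.
W_enc = rng.normal(size=(8, 16))     # encoder weights: 16-dim frame -> 8-dim latent
W_pred = rng.normal(size=(8, 8))     # predictor weights: latent -> latent

def encode(frame):
    return W_enc @ frame

def predict(z_context):
    return W_pred @ z_context

frame_t  = rng.normal(size=16)       # current frame (context)
frame_tk = rng.normal(size=16)       # future frame (t + k)

z_target = encode(frame_tk)          # target latent (gradients would be stopped here)
z_pred   = predict(encode(frame_t))  # predicted future latent from the context

# The loss lives entirely in latent space -- no pixels are reconstructed.
loss = np.mean((z_pred - z_target) ** 2)
```

In the real model the target encoder is frozen (or an EMA copy), so the loss only trains the context encoder and predictor.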
## Architecture
```
      Frame t                          Frame t+k
         │                                 │
         ▼                                 ▼
  ┌─────────────┐                   ┌─────────────┐
  │   Context   │                   │   Target    │
  │   Encoder   │                   │   Encoder   │  (frozen / stop-grad)
  └──────┬──────┘                   └──────┬──────┘
         │                                 │
         ▼                                 │
  ┌─────────────┐                          ▼
  │  Predictor  │ ──►  z'_t+k            z_t+k
  └─────────────┘         │                │
                          ▼                ▼
                    Loss = ||z'_t+k - z_t+k||²
```
## Components
### 1. Encoder (Target Encoder)
Encodes input images into latent representations. Two variants available:
- **Spatial Encoder**: processes full images, preserving spatial structure
- **Temporal Encoder**: processes video sequences, capturing motion
### 2. Predictor
Predicts future latent representations from current encoding:
- Uses masked prediction (similar to masked autoencoders)
- Can use spatial/temporal masking strategies
- Architecture: ViT-style transformer with a masked prediction head
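A minimal sketch of how a masked-prediction input might be assembled: visible positions keep their patch embeddings, masked positions get a shared mask token, and every position keeps its positional embedding. The `mask_token`, shapes, and index choices are hypothetical, not this repo's actual code:

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, dim = 16, 8                          # toy sizes (e.g. a 4x4 patch grid)

patch_emb  = rng.normal(size=(num_patches, dim))  # context-encoder output, one row per patch
pos_emb    = rng.normal(size=(num_patches, dim))  # positional embeddings
mask_token = np.zeros(dim)                        # learned mask token (zeros here)

masked  = np.array([2, 3, 6, 7])                  # positions the predictor must fill in
context = np.setdiff1d(np.arange(num_patches), masked)

# Predictor input: patch embedding + position for visible patches; the shared
# mask token + position at every masked location, so the predictor knows
# *where* to predict but not *what* is there.
pred_in = patch_emb + pos_emb
pred_in[masked] = mask_token + pos_emb[masked]
```

The transformer then attends over all positions and its outputs at the masked locations are compared against the target encoder's latents there.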
### 3. Loss Functions
**Main Loss**: MSE between the predicted and target representations:

L = ||predict(context) - target||²
**Additional Losses** (configurable):

- Variance/invariance regularization
- Cross-view consistency for multi-frame input
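The combined objective can be sketched as MSE plus a VICReg-style variance hinge that keeps the latents from collapsing to a constant. The function name, `var_weight` knob, and hinge threshold are assumptions for illustration, not this repo's exact API:

```python
import numpy as np

def jepa_loss(z_pred, z_target, var_weight=1.0, eps=1e-4):
    """MSE prediction loss plus a variance term that pushes the per-dimension
    std of the predictions above 1, discouraging representation collapse."""
    mse = np.mean((z_pred - z_target) ** 2)
    std = np.sqrt(z_pred.var(axis=0) + eps)        # per-dimension std over the batch
    var_reg = np.mean(np.maximum(0.0, 1.0 - std))  # hinge: only penalize std < 1
    return mse + var_weight * var_reg

rng = np.random.default_rng(0)
z_target = rng.normal(size=(64, 8))                # toy batch of target latents

# A collapsed predictor (constant output) matches nothing *and* pays the
# variance penalty; copying the target gives near-zero loss.
collapsed = jepa_loss(np.zeros((64, 8)), z_target)
perfect   = jepa_loss(z_target, z_target)
```

Without the variance term, always predicting the batch mean would be a trivial low-MSE solution; the hinge removes that shortcut.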
## Training
```python
from world_models.models import JEPAAgent
from world_models.configs import JEPAConfig

cfg = JEPAConfig()
cfg.dataset = "imagefolder"   # dataset format
cfg.root_path = "./data"      # dataset root directory
cfg.image_folder = "train"    # subfolder containing training images
cfg.epochs = 100

agent = JEPAAgent(cfg)
agent.train()
```
## Key Hyperparameters
| Parameter | Default | Description |
|---|---|---|
|  | 224 | Input image size |
|  | 16 | ViT patch size |
|  | 768 | Embedding dimension |
|  | 12 | Transformer layers |
|  | 12 | Attention heads |
|  | 0.75 | Fraction of patches to mask |
|  | 1 | Frames to predict ahead |
|  | 64 | Training batch size |
|  | 1e-4 | Learning rate |
## Masking Strategies
JEPA supports multiple masking approaches:
- **Block masking**: random rectangular regions
- **Random masking**: random individual patches
- **Temporal masking**: mask future frames (for video)
```
Original         Block Mask       Random Mask      Temporal
┌───┬───┐        ┌───┬───┐        ┌───┬───┐        ┌───┬───┐
│ A │ B │        │ A │ B │        │▓▓▓│ B │        │ A │ B │
├───┼───┤        ├───┼───┤        ├───┼───┤        ├───┼───┤
│ C │ D │        │▓▓▓│▓▓▓│        │ C │▓▓▓│        │▓▓▓│▓▓▓│
└───┴───┘        └───┴───┘        └───┴───┘        └───┴───┘
                 Predict C, D     Predict A, D     Predict C, D
                 from A, B        from B, C        from A, B

▓ = masked patches
```
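Block masking can be sketched as sampling a random rectangle on the patch grid and returning the flat indices of the covered patches. This is a toy helper with hypothetical parameters; real implementations typically sample several blocks with random scales and aspect ratios:

```python
import numpy as np

def block_mask(grid=14, h=6, w=6, rng=None):
    """Pick a random h x w rectangle on a grid x grid patch grid and return
    the flat indices of the patches it covers (row-major order)."""
    rng = rng or np.random.default_rng()
    top  = rng.integers(0, grid - h + 1)   # rectangle must fit inside the grid
    left = rng.integers(0, grid - w + 1)
    rows, cols = np.meshgrid(np.arange(top, top + h),
                             np.arange(left, left + w), indexing="ij")
    return (rows * grid + cols).ravel()    # flatten (row, col) -> patch index

# For a 224x224 image with 16x16 patches, grid = 14 and 6x6 masks 36 of 196
# patches (~18% per block; stacking a few blocks reaches ratios like 0.75).
idx = block_mask(rng=np.random.default_rng(0))
```

Random masking, by contrast, would just be `rng.choice(grid * grid, size=k, replace=False)`, which scatters the masked patches instead of keeping them contiguous.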
## Uses for Learned Representations
JEPA representations can be used for:
- **Downstream tasks**: fine-tune for classification/detection
- **World models**: use as an encoder for model-based RL
- **Planning**: predict future states for MPC/trajectory optimization
- **Representation learning**: pre-train for transfer
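As a sketch of the planning use case, assuming an action-conditioned latent predictor (which plain JEPA does not provide out of the box; here it is a toy linear model), random-shooting MPC in latent space might look like:

```python
import numpy as np

rng = np.random.default_rng(0)

def predictor(z, a):
    """Hypothetical action-conditioned latent dynamics; a trained JEPA-style
    predictor would go here. Linear toy model for illustration only."""
    return 0.9 * z + 0.1 * a

def plan(z0, z_goal, horizon=5, n_candidates=256):
    """Random-shooting MPC: sample action sequences, roll the predictor
    forward in latent space, keep the sequence whose final latent is
    closest to the goal latent."""
    best_cost, best_seq = np.inf, None
    for _ in range(n_candidates):
        seq = rng.normal(size=(horizon, z0.size))  # candidate action sequence
        z = z0
        for a in seq:
            z = predictor(z, a)                    # imagined rollout, no pixels
        cost = np.sum((z - z_goal) ** 2)
        if cost < best_cost:
            best_cost, best_seq = cost, seq
    return best_seq, best_cost

z0, z_goal = np.zeros(4), np.ones(4)
seq, cost = plan(z0, z_goal)
```

Because rollouts happen in latent space, each candidate is cheap to evaluate compared to pixel-space generative rollouts.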
## Comparison to Other Methods
| Method | What it predicts | Approach |
|---|---|---|
| Autoencoder | Pixels | Reconstruction |
| VAE | Pixels | Generative |
| MAE | Pixels | Masked modeling |
| JEPA | Latents | Predictive coding |
| IRIS | Tokens | Transformer dynamics |
## References
Assran, M., Duval, Q., Misra, I., Bojanowski, P., Vincent, P., Rabbat, M., LeCun, Y., & Ballas, N. (2023). Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture. CVPR 2023.