JEPA: Joint Embedding Predictive Architecture#

JEPA is a self-supervised learning method that learns visual representations by predicting representations in abstract latent space, without relying on generative modeling or hand-crafted data augmentations.

Based on the paper: Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture (Assran et al., 2023)

Overview#

I-JEPA learns visual representations without:

  • Hand-crafted data augmentations (color jitter, grayscale, etc.)

  • Negative examples (contrastive learning)

  • Pixel-level reconstruction (autoencoders, MAE)

Instead, it predicts the latent representation of one image region from another region using a Vision Transformer (ViT) backbone. The predictor operates in embedding space, not pixel space, which forces the model to learn semantically meaningful features.

graph TD
    A["Input image x"] --> B["Context encoder f_θ"]
    A --> C["Target encoder f_θ̄ (EMA)"]
    B --> D["Context patches (masked)"]
    C --> E["Target patches"]
    D --> F["Predictor g_φ"]
    E --> G["Target representation sg(y_target)"]
    F --> H["Predicted representation ŷ"]
    H --> I["L2 loss"]
    G --> I
    I --> J["sg: stop-gradient through target encoder"]

Architecture#

High-level diagram#

[Figure: JEPA architecture. Top path: current frame → encoder → predictor token → predicted representation → MSE loss. Bottom path: future frame → frozen encoder → target representation → MSE loss.]

Vision Transformer (ViT)#

The backbone encoder in world_models.models.vit is a Vision Transformer following the standard ViT architecture with JEPA-specific modifications.

Patch embedding:

The input image x ∈ ℝ^{3×H×W} is split into non-overlapping patches of size P × P, producing N = (H/P) × (W/P) patches. Each patch is flattened and linearly projected to embed_dim (D):

\[\text{patches} \in \mathbb{R}^{N \times (3 \cdot P^2)} \to \text{tokens} \in \mathbb{R}^{N \times D}\]
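
As a quick check of the shapes above, the following sketch computes N and the flattened patch size for a given image and patch size. The helper name `patch_grid` is illustrative and is not part of TorchWM.

```python
def patch_grid(height, width, patch_size, channels=3):
    """Return (num_patches, flat_patch_dim) for non-overlapping P x P patches."""
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    num_patches = (height // patch_size) * (width // patch_size)
    flat_patch_dim = channels * patch_size ** 2
    return num_patches, flat_patch_dim

# Defaults from the config reference: 224 x 224 input, 16 x 16 patches.
print(patch_grid(224, 224, 16))  # (196, 768)
```

With P = 16 and three channels, the flattened patch size (768) happens to equal the default embed_dim, so the projection is a 768 → 768 linear map in that configuration.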

Transformer blocks:

Each block consists of:

  1. LayerNorm → Multi-Head Self-Attention → residual

  2. LayerNorm → MLP (GELU, 4× hidden) → residual

  3. DropPath (stochastic depth) regularization during training

Key architectural details:

  • No class token — all patch tokens are used

  • Pre-normalization (LayerNorm before attention and MLP)

  • Fixed sin-cos positional embeddings (not learned)
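
A minimal sketch of fixed 2D sin-cos positional embeddings, written with the standard library only. The channel layout (half of the channels for the row, half for the column, sines before cosines) is an assumption borrowed from common ViT implementations; the function names are hypothetical and the layout used in world_models.models.vit may differ.

```python
import math

def sincos_1d(dim, positions):
    """Fixed 1D sin-cos embedding: one vector of length `dim` per position."""
    assert dim % 2 == 0
    half = dim // 2
    # Geometric ladder of frequencies, as in the original Transformer.
    omega = [1.0 / (10000 ** (i / half)) for i in range(half)]
    return [
        [math.sin(p * w) for w in omega] + [math.cos(p * w) for w in omega]
        for p in positions
    ]

def sincos_2d(dim, grid_h, grid_w):
    """2D embedding: half the channels encode the row, half the column."""
    assert dim % 4 == 0
    rows = sincos_1d(dim // 2, range(grid_h))
    cols = sincos_1d(dim // 2, range(grid_w))
    # Row-major order, matching the flattened patch sequence.
    return [rows[r] + cols[c] for r in range(grid_h) for c in range(grid_w)]

pos = sincos_2d(dim=768, grid_h=14, grid_w=14)
print(len(pos), len(pos[0]))  # 196 768
```

Because the table is computed rather than learned, it adds no trainable parameters and can be regenerated for a different grid size.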

Target Encoder (EMA)#

The target encoder f_θ̄ has the same architecture as the context encoder f_θ, but its weights are an exponential moving average (EMA) of the context encoder’s weights:

\[\bar{θ} \leftarrow m \cdot \bar{θ} + (1 - m) \cdot θ\]

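where m is the momentum coefficient (default: a cosine schedule from 0.996 to 1.0). The target encoder is never updated by backpropagation: its outputs pass through a stop-gradient, and only the EMA rule above changes its weights.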
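
The update rule and momentum schedule can be sketched in plain Python as follows. The function names are hypothetical, weights are represented as flat lists of floats for illustration, and whether TorchWM advances the schedule per step or per epoch is an assumption not confirmed here.

```python
import math

def ema_momentum(step, total_steps, m_start=0.996, m_end=1.0):
    """Cosine ramp of the momentum from m_start (step 0) to m_end (last step)."""
    progress = min(step, total_steps) / total_steps
    return m_end - (m_end - m_start) * 0.5 * (1.0 + math.cos(math.pi * progress))

def ema_update(target, online, m):
    """In-place EMA: target <- m * target + (1 - m) * online."""
    for i, (t, o) in enumerate(zip(target, online)):
        target[i] = m * t + (1.0 - m) * o

target = [0.0, 0.0]
online = [1.0, -1.0]
ema_update(target, online, ema_momentum(0, 1000))
print(target)  # each entry moves 0.4% of the way toward the online weights
```

As m approaches 1.0 late in training, the target encoder changes less and less per update, which stabilizes the regression targets.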

Predictor#

The predictor g_φ is a smaller transformer (default 6 layers, 384 dim) that predicts target patch representations from context patch representations.

Key design:

  • Lighter than the encoder: fewer layers, smaller hidden dim

  • Positional embeddings for all patches: the predictor knows which target patches to predict

  • Mask tokens for target positions: learnable embeddings substituted for masked patches
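
The way the predictor input combines context tokens, mask tokens, and positional embeddings can be illustrated with a toy example. This is a sketch under assumptions: vectors are plain lists, `build_predictor_input` is a hypothetical helper, and the actual TorchWM predictor may order or project the tokens differently.

```python
def add(u, v):
    """Element-wise sum of two equal-length vectors."""
    return [a + b for a, b in zip(u, v)]

def build_predictor_input(context_tokens, context_idx, target_idx,
                          pos_embed, mask_token):
    """Concatenate context tokens and mask tokens, each tagged with its position."""
    # Context tokens: encoder output plus the embedding of where the patch sits.
    ctx = [add(tok, pos_embed[i]) for tok, i in zip(context_tokens, context_idx)]
    # Target slots: the same learnable mask token, distinguished only by position.
    tgt = [add(mask_token, pos_embed[j]) for j in target_idx]
    return ctx + tgt

# Toy example: 4 patch positions, 2-dimensional embeddings.
pos_embed = [[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [0.1, 0.1]]
mask_token = [1.0, 1.0]
context_tokens = [[0.5, 0.5], [0.2, 0.8]]  # encoder outputs for patches 0 and 1
seq = build_predictor_input(context_tokens, [0, 1], [2, 3], pos_embed, mask_token)
print(len(seq))  # 4
```

The predictor's outputs at the mask-token positions are the predictions that get compared with the target encoder's representations of the same patches.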

Masking#

I-JEPA uses multi-block masking: random rectangular blocks are masked rather than individual patches.

config.num_enc_masks = 1           # Number of context blocks
config.enc_mask_scale = (0.15, 0.2)   # Context covers 15-20% of image
config.num_pred_masks = 4          # Number of target blocks
config.pred_mask_scale = (0.15, 0.2)  # Each target is 15-20%
config.aspect_ratio = (0.75, 1.5)     # Block aspect ratio range

The predictor sees the context patches and must predict the representation of each target block’s patches. With 4 target blocks and context covering ~15-20%, most of the image must be predicted from a small visible region.
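
A minimal sketch of block sampling on a 14 × 14 patch grid, using the scale and aspect-ratio ranges from the config above. The helper names are hypothetical. Removing target patches from the context block follows the I-JEPA paper and is assumed, not confirmed, for TorchWM.

```python
import math
import random

def sample_block(grid_h, grid_w, scale, aspect_ratio, rng):
    """Sample one rectangle on the patch grid; return its flat patch indices."""
    area = rng.uniform(*scale) * grid_h * grid_w      # block area in patches
    ratio = rng.uniform(*aspect_ratio)                # height / width
    h = min(grid_h, max(1, round(math.sqrt(area * ratio))))
    w = min(grid_w, max(1, round(math.sqrt(area / ratio))))
    top = rng.randint(0, grid_h - h)                  # randint bounds are inclusive
    left = rng.randint(0, grid_w - w)
    return {r * grid_w + c
            for r in range(top, top + h)
            for c in range(left, left + w)}

def sample_masks(grid_h=14, grid_w=14, num_pred_masks=4,
                 pred_mask_scale=(0.15, 0.2), enc_mask_scale=(0.15, 0.2),
                 aspect_ratio=(0.75, 1.5), seed=0):
    """Return (context_indices, list_of_target_index_sets)."""
    rng = random.Random(seed)
    targets = [sample_block(grid_h, grid_w, pred_mask_scale, aspect_ratio, rng)
               for _ in range(num_pred_masks)]
    context = sample_block(grid_h, grid_w, enc_mask_scale, aspect_ratio, rng)
    # Assumption: patches inside any target block are hidden from the context.
    context -= set().union(*targets)
    return context, targets

context, targets = sample_masks()
print(len(context), [len(t) for t in targets])
```

Target blocks may overlap each other. If the context ends up empty after the overlap is removed, a real implementation would need to resample; this sketch does not.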

Training#

Loss Function#

The I-JEPA loss is the squared L2 distance between predicted and target representations, averaged over the masked (target) patches:

\[\mathcal{L}_{\text{JEPA}} = \frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \left\| g_φ(f_θ(x)_i + \text{mask\_token}, \text{pos}_i) - \text{sg}(f_{\bar{θ}}(x)_i) \right\|_2^2\]
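
As a worked example of the formula, the sketch below computes the loss for two masked patches with 2-dimensional representations. `jepa_loss` is an illustrative name. In an autograd framework the target would additionally be detached from the graph (for example with `detach()` in PyTorch) to implement the stop-gradient.

```python
def jepa_loss(predicted, target):
    """Mean over masked patches of the squared L2 distance between vectors."""
    assert predicted and len(predicted) == len(target)
    total = 0.0
    for p, t in zip(predicted, target):
        total += sum((a - b) ** 2 for a, b in zip(p, t))
    return total / len(predicted)

pred = [[1.0, 2.0], [0.0, 0.0]]
tgt = [[1.0, 0.0], [3.0, 4.0]]
# Patch 1: 0^2 + 2^2 = 4; patch 2: 3^2 + 4^2 = 25; mean = 29 / 2.
print(jepa_loss(pred, tgt))  # 14.5
```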

Optimization#

\[\begin{split}\begin{aligned} \text{Context encoder: } & θ \leftarrow \text{optimizer}(θ, \nabla_θ \mathcal{L}) \\ \text{Predictor: } & φ \leftarrow \text{optimizer}(φ, \nabla_φ \mathcal{L}) \\ \text{Target encoder: } & \bar{θ} \leftarrow m \cdot \bar{θ} + (1 - m) \cdot θ \end{aligned}\end{split}\]

Learning Rate Schedule#

Three-phase schedule:

  1. Warmup (0 → warmup_epochs): Linear increase from 0 to lr

  2. Cosine decay (warmup_epochs → epochs): Cosine annealing from lr to min_lr

  3. Constant: After epochs, remains at min_lr
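
The three phases can be written as a single function of the epoch. This is a sketch using the defaults from the config reference; `lr_at` is a hypothetical name, and the library may step the schedule per iteration, not per epoch.

```python
import math

def lr_at(epoch, lr=1.5e-4, min_lr=1e-6, warmup_epochs=40, epochs=100):
    """Learning rate at a (possibly fractional) epoch under the three-phase schedule."""
    if epoch < warmup_epochs:
        # Phase 1: linear warmup from 0 to lr.
        return lr * epoch / warmup_epochs
    if epoch >= epochs:
        # Phase 3: hold at the floor.
        return min_lr
    # Phase 2: cosine annealing from lr down to min_lr.
    progress = (epoch - warmup_epochs) / (epochs - warmup_epochs)
    return min_lr + 0.5 * (lr - min_lr) * (1.0 + math.cos(math.pi * progress))

for e in (0, 20, 40, 70, 100, 120):
    print(e, lr_at(e))
```

The rate peaks at exactly lr when epoch equals warmup_epochs, and the cosine phase reaches its midpoint value, (lr + min_lr) / 2, halfway between warmup_epochs and epochs.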

Usage in TorchWM#

Quick start#

import torchwm

agent = torchwm.create_model(
    "jepa",
    dataset="imagenet1k",
    batch_size=64,
    epochs=100,
)
agent.train()

Using config directly#

from torchwm import JEPAAgent, JEPAConfig

cfg = JEPAConfig()
cfg.dataset = "imagenet1k"
cfg.root_path = "/data/imagenet"
cfg.image_folder = "train"
cfg.batch_size = 64
cfg.epochs = 100
cfg.lr = 1.5e-4

agent = JEPAAgent(cfg)
agent.train()

Data pipeline#

cfg.dataset = "imagenet1k"     # ImageNet-1K (requires download)
cfg.root_path = "/data/imagenet"

# Or use a generic image folder:
cfg.dataset = "imagefolder"
cfg.root_path = "./my_dataset"
cfg.image_folder = "train"

# Or CIFAR-10 for testing:
cfg.dataset = "cifar10"
cfg.download = True

Note

JEPA does not rely on heavy augmentation the way contrastive methods do. The core learning signal comes from the masked prediction task, not from image distortion.

CLI#

torchwm train jepa --dataset imagenet1k --epochs 100 --batch_size 64

Config Reference#

| Parameter | Default | Description |
|---|---|---|
| image_size | 224 | Input image size |
| patch_size | 16 | ViT patch size |
| embed_dim | 768 | Embedding dimension |
| num_layers | 12 | Transformer layers |
| num_heads | 12 | Attention heads |
| mask_ratio | 0.75 | Fraction of patches to mask |
| predict_horizon | 1 | Frames to predict ahead |
| batch_size | 64 | Training batch size |
| lr | 1.5e-4 | Peak learning rate |
| min_lr | 1e-6 | Minimum learning rate |
| warmup_epochs | 40 | Linear warmup from 0 to lr |
| weight_decay | 0.05 | AdamW weight decay |
| clip_grad | 1.0 | Gradient clipping norm |
| epochs | 100 | Total training epochs |
| accum_iter | 1 | Gradient accumulation steps |
| ema | (0.996, 1.0) | EMA momentum range (cosine schedule) |

Inference and Downstream Tasks#

cfg.eval = True
cfg.read_checkpoint = "./output/checkpoint.pth"

Linear probing protocol#

| Method | Top-1 Accuracy (ViT-B/16) |
|---|---|
| I-JEPA | 72.4% |
| MAE | 68.5% |
| iBOT | 74.7% |
| DINOv2 | 78.3% |

I-JEPA vs V-JEPA#

| Aspect | I-JEPA (Image) | V-JEPA (Video) |
|---|---|---|
| Input | Single image | Video clip |
| Masking | Spatial block masking | Spatio-temporal tube masking |
| Task | Predict masked patch latents | Predict future frame latents |
| Predictor | Transformer | Spatio-temporal transformer |

Common Pitfalls#

Predictor collapse#

The predictor outputs a constant regardless of input.

Fixes:

  • Ensure the EMA momentum starts close to 1.0 (default: 0.996)

  • Verify predictor output variance is non-zero

Representation collapse#

All patches map to nearly identical representations.

Fixes:

  • Use multi-block masking (not random patch masking)

  • Check the feature covariance matrix
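
One way to run the covariance check is sketched below in plain Python. The helper names and the tolerance are assumptions chosen for illustration. A fuller diagnostic would also inspect the eigenvalue spectrum of the matrix, which this sketch does not do.

```python
def covariance(features):
    """Sample covariance matrix (D x D) of N feature vectors of length D."""
    n, d = len(features), len(features[0])
    mean = [sum(f[k] for f in features) / n for k in range(d)]
    centered = [[f[k] - mean[k] for k in range(d)] for f in features]
    return [
        [sum(c[i] * c[j] for c in centered) / (n - 1) for j in range(d)]
        for i in range(d)
    ]

def looks_collapsed(features, tol=1e-6):
    """Flag collapse when every feature dimension has near-zero variance."""
    cov = covariance(features)
    return all(cov[k][k] < tol for k in range(len(cov)))

healthy = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.5]]
collapsed = [[0.3, 0.7], [0.3, 0.7], [0.3, 0.7]]
print(looks_collapsed(healthy), looks_collapsed(collapsed))  # False True
```

In practice the rows of `features` would be patch representations gathered from the context encoder over a batch of images.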

Memory usage#

ViT-B/16 at 224×224 resolution produces 196 patch tokens per image. A batch size of 64 requires roughly 16 GB of GPU memory.

Tips:

  • Set gradient_checkpointing = True

  • Reduce batch_size and increase accum_iter

Slow convergence#

JEPA requires a long warmup (40 epochs by default) and many total epochs (100–300).

Tips:

  • Use the cosine schedule for EMA momentum

  • Expect 48+ hours on 4× GPUs for ViT-B/16 at 100 epochs

Comparison to Other Methods#

| Method | What it predicts | Approach |
|---|---|---|
| Autoencoder | Pixels | Reconstruction |
| VAE | Pixels | Generative |
| MAE | Pixels | Masked modeling |
| JEPA | Latents | Predictive coding |
| IRIS | Tokens | Transformer dynamics |

References#

  • Assran, M., Duval, Q., Misra, I., Bojanowski, P., Vincent, P., Rabbat, M., LeCun, Y., & Ballas, N. (2023). Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture. CVPR 2023. arXiv:2301.08243.

  • Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021.