# JEPA: Joint Embedding Predictive Architecture
JEPA is a self-supervised learning method that learns visual representations by predicting
representations in abstract latent space, without relying on generative modeling or
hand-crafted data augmentations.
Based on the paper [Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture](https://arxiv.org/abs/2301.08243) (Assran et al., 2023), which introduces I-JEPA, the image-based JEPA.
```{contents} Contents
:depth: 3
```
## Overview
I-JEPA learns visual representations **without**:
- Hand-crafted data augmentations (color jitter, grayscale, etc.)
- Negative examples (contrastive learning)
- Pixel-level reconstruction (autoencoders, MAE)
Instead, it predicts the **latent representation** of one image region from
another region using a Vision Transformer (ViT) backbone. The predictor operates
in embedding space, not pixel space, which forces the model to learn
semantically meaningful features.
```{mermaid}
graph TD
A["Input image x"] --> B["Context encoder f_θ"]
A --> C["Target encoder f_θ̄ (EMA)"]
B --> D["Context patches (masked)"]
C --> E["Target patches"]
D --> F["Predictor g_φ"]
E --> G["Target representation sg(y_target)"]
F --> H["Predicted representation ŷ"]
H --> I["L2 loss"]
G --> I
I --> J["sg: stop-gradient through target encoder"]
```
## Architecture
### High-level diagram
*Figure: JEPA architecture.* Two paths feed an MSE loss. Prediction path: current frame encoder → predictor token → predicted representation. Target path: future frame frozen encoder → target representation.
### Vision Transformer (ViT)
The backbone encoder in `world_models.models.vit` is a Vision Transformer
following the standard ViT architecture with JEPA-specific modifications.
**Patch embedding:**
The input image `x ∈ ℝ^{3×H×W}` is split into patches of size `P × P`,
producing `N = (H/P) × (W/P)` patches. Each patch is linearly projected to
`embed_dim`:
```{math}
\text{patches} \in \mathbb{R}^{N \times (3 \cdot P^2)} \to
\text{tokens} \in \mathbb{R}^{N \times D}
```
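The shape bookkeeping above can be checked with a short NumPy sketch. It is an illustration only, not the code in `world_models.models.vit`: the function name `patchify` and the random `projection` matrix (a stand-in for the learned linear layer) are assumptions made for this example.

```python
import numpy as np


def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split a (C, H, W) image into N flattened patches of length C * P * P."""
    c, h, w = image.shape
    p = patch_size
    if h % p or w % p:
        raise ValueError("H and W must be divisible by the patch size")
    gh, gw = h // p, w // p
    # (C, gh, P, gw, P) -> (gh, gw, C, P, P) -> (N, C * P * P)
    patches = image.reshape(c, gh, p, gw, p).transpose(1, 3, 0, 2, 4)
    return patches.reshape(gh * gw, c * p * p)


rng = np.random.default_rng(0)
image = rng.standard_normal((3, 224, 224))
patches = patchify(image, patch_size=16)                      # (196, 768)

# Hypothetical stand-in for the learned projection to embed_dim = 768.
projection = rng.standard_normal((3 * 16 * 16, 768)) * 0.02
tokens = patches @ projection                                 # (196, 768)
```

With `H = W = 224` and `P = 16` this gives `N = 14 × 14 = 196` patches, each flattened to `3 · 16² = 768` values before projection.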
**Transformer blocks:**
Each block applies two pre-norm sublayers, each wrapped in a residual connection:
1. **LayerNorm** → Multi-Head Self-Attention → residual
2. **LayerNorm** → MLP (GELU, 4× hidden) → residual

**DropPath** (stochastic depth) is applied to the residual branches as regularization during training.
**Key architectural details:**
- No class token — all patch tokens are used
- Pre-normalization (LayerNorm before attention and MLP)
- Fixed sin-cos positional embeddings (not learned)
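
To make "fixed sin-cos positional embeddings" concrete, here is a minimal NumPy sketch of a 2D sin-cos table. It follows the common construction in which half of the channels encode the row index and half the column index; the exact channel layout used by the library is an assumption, and the function names are hypothetical.

```python
import numpy as np


def sincos_1d(embed_dim: int, positions: np.ndarray) -> np.ndarray:
    """Encode 1D positions as [sin | cos] features of width embed_dim."""
    assert embed_dim % 2 == 0
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2) / (embed_dim / 2))
    angles = np.outer(positions.astype(np.float64), omega)            # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)   # (M, D)


def sincos_2d(embed_dim: int, grid_size: int) -> np.ndarray:
    """Fixed (non-learned) position table of shape (grid_size**2, embed_dim)."""
    assert embed_dim % 4 == 0
    rows, cols = np.meshgrid(
        np.arange(grid_size), np.arange(grid_size), indexing="ij"
    )
    emb_rows = sincos_1d(embed_dim // 2, rows.reshape(-1))
    emb_cols = sincos_1d(embed_dim // 2, cols.reshape(-1))
    return np.concatenate([emb_rows, emb_cols], axis=1)


pos_embed = sincos_2d(embed_dim=768, grid_size=14)   # one row per patch token
```

Because the table is a deterministic function of the grid, it needs no gradient and can be recomputed for a different grid size.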
### Target Encoder (EMA)
The target encoder `f_{\bar{θ}}` has the same architecture as the context
encoder `f_θ` but its weights are an **exponential moving average** (EMA) of
the context encoder's weights:
```{math}
\bar{θ} \leftarrow m \cdot \bar{θ} + (1 - m) \cdot θ
```
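where `m` is the momentum coefficient (default: cosine schedule from 0.996 to
1.0). No gradients flow into the target encoder: its outputs are treated as
constants (stop-gradient), and its weights change only through the EMA update.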
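
The update rule and the momentum schedule can be sketched in plain Python. This is a minimal illustration, assuming a cosine ramp from 0.996 to 1.0 as stated above; the function names are hypothetical, and plain floats stand in for weight tensors.

```python
import math


def ema_momentum(step: int, total_steps: int,
                 start: float = 0.996, end: float = 1.0) -> float:
    """Cosine ramp of the momentum m from `start` to `end` over training."""
    progress = min(max(step / max(total_steps, 1), 0.0), 1.0)
    return end - (end - start) * 0.5 * (1.0 + math.cos(math.pi * progress))


def ema_update(target: dict, context: dict, m: float) -> None:
    """In-place update: target <- m * target + (1 - m) * context."""
    for name, value in context.items():
        target[name] = m * target[name] + (1.0 - m) * value


# Floats stand in for parameter tensors in this sketch.
context_weights = {"w": 1.0, "b": -2.0}
target_weights = {"w": 0.0, "b": 0.0}
ema_update(target_weights, context_weights, m=ema_momentum(0, 1000))
```

As `m` approaches 1.0 the target encoder changes more and more slowly, so late in training the targets are nearly stationary.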
### Predictor
The predictor `g_φ` is a smaller transformer (default 6 layers, 384 dim) that
predicts target patch representations from context patch representations.
Key design:
- **Lighter than the encoder**: fewer layers, smaller hidden dim
- **Positional embeddings for all patches**: the predictor knows which target patches to predict
- **Mask tokens for target positions**: learnable embeddings substituted for masked patches
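
The way the predictor's input sequence is assembled can be sketched as follows. This is an illustration under stated assumptions, not the library's implementation: `build_predictor_input` is a hypothetical helper, the context tokens are assumed to be already projected to the predictor width (384), and a random table stands in for the positional embeddings.

```python
import numpy as np


def build_predictor_input(context_tokens, context_idx, target_idx,
                          pos_embed, mask_token):
    """Stack position-tagged context tokens and mask tokens into one sequence."""
    ctx = context_tokens + pos_embed[context_idx]       # (Nc, D): visible content
    tgt = mask_token[None, :] + pos_embed[target_idx]   # (Nt, D): positions to predict
    return np.concatenate([ctx, tgt], axis=0)           # (Nc + Nt, D)


rng = np.random.default_rng(0)
dim, num_patches = 384, 196
pos_embed = rng.standard_normal((num_patches, dim))     # stand-in position table
mask_token = np.zeros(dim)                              # learnable in a real model
context_idx = np.array([0, 1, 14, 15])
target_idx = np.array([100, 101, 114])
context_tokens = rng.standard_normal((len(context_idx), dim))

sequence = build_predictor_input(
    context_tokens, context_idx, target_idx, pos_embed, mask_token
)
```

The mask tokens carry no image content, only a shared embedding plus a position, so the predictor must infer each target representation from the context tokens.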
### Masking
I-JEPA uses **multi-block masking**: random rectangular blocks are masked
rather than individual patches.
```python
config.num_enc_masks = 1 # Number of context blocks
config.enc_mask_scale = (0.15, 0.2) # Context covers 15-20% of image
config.num_pred_masks = 4 # Number of target blocks
config.pred_mask_scale = (0.15, 0.2) # Each target is 15-20%
config.aspect_ratio = (0.75, 1.5) # Block aspect ratio range
```
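The predictor sees the context patches and must predict the representation of
each target block's patches. With the settings above, the context block covers
roughly 15-20% of the image and each of the 4 target blocks covers another
15-20%, so most of the image must be predicted from a small visible region.
For comparison, the I-JEPA paper samples the context block at a scale of 0.85
to 1.0 and then removes any patches that overlap the target blocks.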
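
A minimal sketch of multi-block mask sampling on a 14 × 14 patch grid, using the scale and aspect-ratio ranges from the config snippet above. The helper names are hypothetical, and two details are assumptions of this sketch rather than documented behavior: the same aspect-ratio range is used for context and target blocks, and target patches are removed from the context block.

```python
import math
import random


def sample_block(grid: int, scale: tuple, aspect_ratio: tuple,
                 rng: random.Random) -> set:
    """Sample one rectangular block of patch indices on a grid x grid layout."""
    area = rng.uniform(*scale) * grid * grid
    ratio = rng.uniform(*aspect_ratio)                 # height / width
    h = min(grid, max(1, round(math.sqrt(area * ratio))))
    w = min(grid, max(1, round(math.sqrt(area / ratio))))
    top = rng.randint(0, grid - h)                     # inclusive bounds
    left = rng.randint(0, grid - w)
    return {r * grid + c
            for r in range(top, top + h)
            for c in range(left, left + w)}


def sample_masks(grid=14, enc_scale=(0.15, 0.2), pred_scale=(0.15, 0.2),
                 aspect_ratio=(0.75, 1.5), num_pred_masks=4, seed=0):
    """Return (context indices, list of target index sets)."""
    rng = random.Random(seed)
    for _ in range(100):
        targets = [sample_block(grid, pred_scale, aspect_ratio, rng)
                   for _ in range(num_pred_masks)]
        # Drop target patches so the context never reveals what must be predicted.
        context = sample_block(grid, enc_scale, aspect_ratio, rng)
        context -= set().union(*targets)
        if context:                                    # resample if nothing is left
            return context, targets
    raise RuntimeError("could not sample a non-empty context block")


context, targets = sample_masks()
```

Sampling whole rectangles rather than scattered patches means each target is a contiguous region, which cannot be filled in by copying immediate neighbors.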
## Training
### Loss Function
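The I-JEPA loss is the squared L2 distance between predicted and target
representations, averaged over the masked (target) patches `M`:
```{math}
\mathcal{L}_{\text{JEPA}} =
\frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}}
\left\| \hat{y}_i - \text{sg}\left(f_{\bar{θ}}(x)_i\right) \right\|_2^2,
\qquad
\hat{y}_i = g_φ\left(f_θ(x_{\text{ctx}}),\ \text{mask\_token} + \text{pos}_i\right)
```
Here `x_ctx` denotes the visible context patches, `ŷ_i` is the prediction for
target patch `i`, and `sg` is the stop-gradient applied to the target encoder's output.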
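A minimal NumPy sketch of this loss, assuming the predicted and target embeddings for the masked patches have already been gathered into two arrays of equal shape. The function name `jepa_loss` is hypothetical. NumPy has no autograd, so the stop-gradient is implicit here; in an autograd framework the target array would be detached before this call.

```python
import numpy as np


def jepa_loss(predicted: np.ndarray, target: np.ndarray) -> float:
    """Mean over masked patches of the squared L2 distance between embeddings."""
    if predicted.shape != target.shape:
        raise ValueError("predicted and target must have the same shape")
    per_patch = np.sum((predicted - target) ** 2, axis=-1)   # one value per patch
    return float(per_patch.mean())


predicted = np.array([[1.0, 2.0], [0.0, 0.0]])
target = np.array([[1.0, 0.0], [3.0, 4.0]])
loss = jepa_loss(predicted, target)   # (4 + 25) / 2 = 14.5
```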
### Optimization
```{math}
\begin{aligned}
\text{Context encoder: } & θ \leftarrow \text{optimizer}(θ, \nabla_θ \mathcal{L}) \\
\text{Predictor: } & φ \leftarrow \text{optimizer}(φ, \nabla_φ \mathcal{L}) \\
\text{Target encoder: } & \bar{θ} \leftarrow m \cdot \bar{θ} + (1 - m) \cdot θ
\end{aligned}
```
### Learning Rate Schedule
Three-phase schedule:
1. **Warmup** (0 → `warmup_epochs`): Linear increase from 0 to `lr`
2. **Cosine decay** (`warmup_epochs` → `epochs`): Cosine annealing to `min_lr`
3. **Constant**: After `epochs`, remains at `min_lr`
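
The three phases can be written as a single function of the epoch. This is a sketch using the defaults from the config reference; the function name `learning_rate` is hypothetical, and whether the library steps the schedule per epoch or per iteration is not specified here.

```python
import math


def learning_rate(epoch: float, lr: float = 1.5e-4, min_lr: float = 1e-6,
                  warmup_epochs: int = 40, epochs: int = 100) -> float:
    """Linear warmup, then cosine decay to min_lr, then constant at min_lr."""
    if epoch < warmup_epochs:
        return lr * epoch / warmup_epochs              # phase 1: warmup
    if epoch >= epochs:
        return min_lr                                  # phase 3: constant
    progress = (epoch - warmup_epochs) / (epochs - warmup_epochs)
    return min_lr + (lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


schedule = [learning_rate(e) for e in range(0, 121, 20)]
```

The peak value `lr` is reached exactly at `warmup_epochs`, and halfway through the decay the rate is the average of `lr` and `min_lr`.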
## Usage in TorchWM
### Quick start
```python
import torchwm

agent = torchwm.create_model(
    "jepa",
    dataset="imagenet1k",
    batch_size=64,
    epochs=100,
)
agent.train()
```
### Using config directly
```python
from torchwm import JEPAAgent, JEPAConfig
cfg = JEPAConfig()
cfg.dataset = "imagenet1k"
cfg.root_path = "/data/imagenet"
cfg.image_folder = "train"
cfg.batch_size = 64
cfg.epochs = 100
cfg.lr = 1.5e-4
agent = JEPAAgent(cfg)
agent.train()
```
### Data pipeline
```python
cfg.dataset = "imagenet1k" # ImageNet-1K (requires download)
cfg.root_path = "/data/imagenet"
# Or use a generic image folder:
cfg.dataset = "imagefolder"
cfg.root_path = "./my_dataset"
cfg.image_folder = "train"
# Or CIFAR-10 for testing:
cfg.dataset = "cifar10"
cfg.download = True
```
```{note}
JEPA does not rely on heavy augmentation the way contrastive methods do. The
core learning signal comes from the masked prediction task, not from image distortion.
```
### CLI
```bash
torchwm train jepa --dataset imagenet1k --epochs 100 --batch_size 64
```
## Config Reference
| Parameter | Default | Description |
|-----------|---------|-------------|
| `image_size` | 224 | Input image size |
| `patch_size` | 16 | ViT patch size |
| `embed_dim` | 768 | Embedding dimension |
| `num_layers` | 12 | Transformer layers |
| `num_heads` | 12 | Attention heads |
| `mask_ratio` | 0.75 | Fraction of patches to mask |
| `predict_horizon` | 1 | Frames to predict ahead |
| `batch_size` | 64 | Training batch size |
| `lr` | 1.5e-4 | Peak learning rate |
| `min_lr` | 1e-6 | Minimum learning rate |
| `warmup_epochs` | 40 | Number of epochs of linear warmup from 0 to `lr` |
| `weight_decay` | 0.05 | AdamW weight decay |
| `clip_grad` | 1.0 | Gradient clipping norm |
| `epochs` | 100 | Total training epochs |
| `accum_iter` | 1 | Gradient accumulation steps |
| `ema` | (0.996, 1.0) | EMA momentum range (cosine schedule) |
## Inference and Downstream Tasks
```python
cfg.eval = True
cfg.read_checkpoint = "./output/checkpoint.pth"
```
### Linear probing protocol
| Method | Top-1 Accuracy (ViT-B/16) |
|--------|---------------------------|
| I-JEPA | 72.4% |
| MAE | 68.5% |
| iBOT | 74.7% |
| DINOv2 | 78.3% |
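
In linear probing, the pretrained encoder is frozen and only a linear classifier is trained on its output features. The sketch below illustrates the idea with NumPy on synthetic data: the two Gaussian clusters are a toy stand-in for encoder features, and `train_linear_probe` is a hypothetical helper, not a TorchWM function.

```python
import numpy as np


def train_linear_probe(features, labels, num_classes, lr=0.1, steps=200):
    """Fit a softmax classifier on frozen features with plain gradient descent."""
    n, d = features.shape
    weights = np.zeros((d, num_classes))
    bias = np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]
    for _ in range(steps):
        logits = features @ weights + bias
        logits -= logits.max(axis=1, keepdims=True)     # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n                     # d(cross-entropy)/d(logits)
        weights -= lr * features.T @ grad
        bias -= lr * grad.sum(axis=0)
    return weights, bias


# Toy stand-in for frozen encoder outputs: two separated 16-dim clusters.
rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=200)
features = rng.standard_normal((200, 16)) + 3.0 * labels[:, None]
features = (features - features.mean(axis=0)) / features.std(axis=0)

weights, bias = train_linear_probe(features, labels, num_classes=2)
accuracy = float(((features @ weights + bias).argmax(axis=1) == labels).mean())
```

Only `weights` and `bias` are updated; the features never change, which is what makes the resulting accuracy a measure of representation quality.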
## I-JEPA vs V-JEPA
| Aspect | I-JEPA (Image) | V-JEPA (Video) |
|--------|----------------|-----------------|
| Input | Single image | Video clip |
| Masking | Spatial block masking | Spatio-temporal tube masking |
| Task | Predict masked patch latents | Predict latents of masked spatio-temporal regions |
| Predictor | Transformer | Spatio-temporal transformer |
## Common Pitfalls
### Predictor collapse
The predictor outputs a constant regardless of input.
**Fixes:**
- Ensure the EMA momentum starts close to 1.0 (default: 0.996)
- Verify predictor output variance is non-zero
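
One way to check the predictor's output variance is to measure the per-dimension standard deviation across a batch. This is a minimal sketch: the helper names and the `1e-4` threshold are assumptions chosen for illustration, not library defaults.

```python
import numpy as np


def output_std(predictions: np.ndarray) -> float:
    """Mean per-dimension standard deviation across all predicted tokens."""
    flat = predictions.reshape(-1, predictions.shape[-1])
    return float(flat.std(axis=0).mean())


def is_collapsed(predictions: np.ndarray, threshold: float = 1e-4) -> bool:
    """Flag outputs that barely vary from one input to the next."""
    return output_std(predictions) < threshold


rng = np.random.default_rng(0)
healthy = rng.standard_normal((8, 49, 384))     # (batch, target tokens, dim)
constant = np.ones((8, 49, 384)) * 0.3          # same vector for every input
```

Logging `output_std` during training makes a slow drift toward zero visible before the loss itself looks suspicious.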
### Representation collapse
All patches map to nearly identical representations.
**Fixes:**
- Use multi-block masking (not random patch masking)
- Check the feature covariance matrix
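
A compact summary of the feature covariance matrix is its effective rank, computed from the entropy of the normalized eigenvalues. The sketch below is one possible diagnostic, not a library function: a value near the embedding dimension means variance is spread across many directions, while a value near 1 indicates collapse onto a single direction.

```python
import numpy as np


def effective_rank(features: np.ndarray) -> float:
    """Entropy-based effective rank of the feature covariance matrix."""
    flat = features.reshape(-1, features.shape[-1])
    cov = np.cov(flat, rowvar=False)
    eigvals = np.clip(np.linalg.eigvalsh(cov), 0.0, None)
    total = eigvals.sum()
    if total <= 0:
        return 0.0                                   # no variance at all
    p = eigvals / total
    p = p[p > 0]
    return float(np.exp(-(p * np.log(p)).sum()))


rng = np.random.default_rng(0)
diverse = rng.standard_normal((1000, 8))                       # full-rank features
direction = rng.standard_normal(8)
collapsed = np.outer(rng.standard_normal(1000), direction)     # rank-1 features
```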
### Memory usage
ViT-B/16 on 224×224 inputs produces 14 × 14 = 196 patch tokens per image. A batch size of 64 requires roughly 16 GB of GPU memory.
**Tips:**
- Enable `gradient_checkpointing = True`
- Reduce `batch_size` and increase `accum_iter`
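
The arithmetic behind these tips can be sketched as follows. The helper names are hypothetical, and the sketch assumes that gradient accumulation keeps the effective batch size at `batch_size × accum_iter`.

```python
def num_patch_tokens(image_size: int, patch_size: int) -> int:
    """Tokens per image for a square input and square patches."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    return (image_size // patch_size) ** 2


def accum_iter_for(effective_batch: int, micro_batch: int) -> int:
    """Accumulation steps needed to keep the effective batch size unchanged."""
    if effective_batch % micro_batch:
        raise ValueError("effective_batch must be a multiple of micro_batch")
    return effective_batch // micro_batch


tokens = num_patch_tokens(224, 16)     # 14 * 14 = 196
steps = accum_iter_for(64, 16)         # 4 micro-batches of 16 per optimizer step
```

Under that assumption, dropping `batch_size` from 64 to 16 with `accum_iter = 4` keeps the effective batch size at 64 while holding only a quarter of the activations in memory at once.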
### Slow convergence
JEPA requires a long warmup (40 epochs) and many total epochs (100 to 300).
**Tips:**
- Use the cosine schedule for EMA momentum
- Expect 48+ hours on 4× GPUs for ViT-B/16 at 100 epochs
## Comparison to Other Methods
| Method | What it predicts | Approach |
|--------|------------------|----------|
| Autoencoder | Pixels | Reconstruction |
| VAE | Pixels | Generative |
| MAE | Pixels | Masked modeling |
| JEPA | Latents | Predictive coding |
| IRIS | Tokens | Transformer dynamics |
## References
- Assran, M., Duval, Q., Misra, I., Bojanowski, P., Vincent, P., Rabbat, M., LeCun, Y., & Ballas, N. (2023). Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture. *CVPR 2023.* arXiv:2301.08243.
- Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. *ICLR 2021.*