DiT: Diffusion Transformer#

DiT (Diffusion Transformer) applies the transformer architecture to diffusion models for image generation, replacing the U-Net backbone common in earlier diffusion models. It treats denoising as a sequence modeling problem: the noisy image becomes a sequence of patch tokens, and the transformer predicts the noise at each timestep.

Based on the paper: Scalable Diffusion Models with Transformers (Peebles & Xie, 2023)

Key Idea#

Instead of using CNNs (like U-Net) for diffusion, DiT uses a Vision Transformer (ViT) architecture:

  1. Tokenization: Split image into patches, convert to tokens

  2. Diffusion as sequence: Treat noise prediction as sequence modeling

  3. Transformer: Process token sequence to predict added noise

  4. DDPM: Generate images by iteratively denoising

Architecture#

┌─────────────────────────────────────────────────────────────────────┐
│                      DiT Architecture                                │
│                                                                      │
│  Input: x_t (noisy image at timestep t)                             │
│                                                                      │
│  ┌─────┐    ┌─────────────────────────────────────────────────────┐  │
│  │Patch│    │      Linear Embedding                               │  │
│  │Embed│ ──►│  x -> patch tokens + t_emb + c_emb                 │  │
│  └─────┘    └─────────────────────────────────────────────────────┘  │
│                          │                                           │
│                          ▼                                           │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │                   Transformer Blocks                            │  │
│  │                                                                 │  │
│  │   ┌─────────────┐    ┌─────────────┐    ┌─────────────┐        │  │
│  │   │   Block 1   │───►│   Block 2   │───►│   Block N   │        │  │
│  │   │ (self-att)  │    │ (self-att)  │    │ (self-att)  │        │  │
│  │   │  + MLP      │    │  + MLP      │    │  + MLP      │        │  │
│  │   └─────────────┘    └─────────────┘    └─────────────┘        │  │
│  │                                                                 │  │
│  └────────────────────────────────────────────────────────────────┘  │
│                          │                                           │
│                          ▼                                           │
│  ┌────────────────────────────────────────────────────────────────┐  │
│  │                   Output Head                                    │  │
│  │            Linear → predicted noise ε                           │  │
│  └────────────────────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────────────────┘

Components#

1. Patch Embedding#

Converts input image into a sequence of tokens:

  • Image of size (C, H, W) split into (H/P × W/P) patches

  • Each patch linearly embedded to dimension D

  • Add positional embeddings

x ∈ ℝ^{C×H×W} → tokens ∈ ℝ^{(H/P × W/P) × D}
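A minimal NumPy sketch of this step (the `patchify` helper and the random embedding matrix are illustrative; in the model the embedding is a learned linear layer, and shapes follow the defaults below: 32×32 image, P=4, D=384):

```python
import numpy as np

def patchify(x: np.ndarray, P: int) -> np.ndarray:
    """Split a (C, H, W) image into (H/P * W/P) flattened patches of length C*P*P."""
    C, H, W = x.shape
    assert H % P == 0 and W % P == 0
    # (C, H/P, P, W/P, P) -> (H/P, W/P, C, P, P) -> (num_patches, C*P*P)
    x = x.reshape(C, H // P, P, W // P, P)
    x = x.transpose(1, 3, 0, 2, 4)
    return x.reshape((H // P) * (W // P), C * P * P)

rng = np.random.default_rng(0)
img = rng.standard_normal((3, 32, 32))
tokens = patchify(img, P=4)              # (64, 48): 8x8 grid of 4x4x3 patches
W_embed = rng.standard_normal((48, 384)) # stand-in for the learned projection
embedded = tokens @ W_embed              # (64, 384): 64 tokens of width D=384
```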

2. Timestep & Condition Embeddings#

  • Timestep: sinusoidal positional encoding for diffusion timestep t

  • Class label: Optional conditioning for class-conditional generation

  • Both added to each token
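The timestep embedding can be sketched as a standard transformer-style sinusoidal encoding (the function name and `max_period` constant are illustrative):

```python
import numpy as np

def timestep_embedding(t: float, dim: int, max_period: float = 10000.0) -> np.ndarray:
    """Sinusoidal embedding of a scalar diffusion timestep t into a dim-vector."""
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to 1/max_period
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = t * freqs
    return np.concatenate([np.cos(args), np.sin(args)])

emb = timestep_embedding(500, dim=384)  # (384,) vector, broadcast to every token
```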

3. DiT Blocks#

Transformer blocks with:

  • Self-attention: Capture relationships between patches

  • MLP: Process each patch independently

  • Adaptive Layer Norm (AdaLN): Condition on timestep/class

Variants:

  • AdaLN: an MLP on the conditioning vector regresses per-channel shift and scale for each layer norm

  • AdaLN-Zero: additionally regresses a per-block gate initialized to zero, so every residual block starts as the identity (best-performing variant in the paper)

  • AdaLN-Single: one shared set of modulation parameters across all blocks (more parameter-efficient; introduced in follow-up work, not the original DiT paper)
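A minimal sketch of AdaLN modulation in NumPy (single-layer modulation "MLP", illustrative names; the zero initialization mirrors the spirit of the paper's AdaLN-Zero variant, under which modulation starts out as a plain layer norm):

```python
import numpy as np

def layernorm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Normalize each token over its channel dimension (no learned affine)."""
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def adaln_modulate(x, cond, W, b):
    """AdaLN: regress per-channel shift/scale from the conditioning vector.

    x: (N, D) tokens; cond: (D,) timestep+class embedding; W, b map cond -> 2*D.
    """
    shift, scale = np.split(cond @ W + b, 2)
    return layernorm(x) * (1.0 + scale) + shift

rng = np.random.default_rng(0)
D = 8
x = rng.standard_normal((4, D))
cond = rng.standard_normal(D)
W = np.zeros((D, 2 * D))  # zero-init: modulation starts as identity on LN(x)
b = np.zeros(2 * D)
out = adaln_modulate(x, cond, W, b)
```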

4. Output Head#

A final linear layer predicts the noise ε with the same shape as the input:

ε_pred = Linear(tokens) → ℝ^{C×H×W}
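After the linear projection, the per-patch predictions are rearranged back into image shape. A NumPy sketch of this "unpatchify" step (name and shapes illustrative, inverse of the patch embedding above):

```python
import numpy as np

def unpatchify(tokens: np.ndarray, C: int, H: int, W: int, P: int) -> np.ndarray:
    """Inverse of patchify: (H/P * W/P, C*P*P) tokens -> (C, H, W) array."""
    x = tokens.reshape(H // P, W // P, C, P, P)
    x = x.transpose(2, 0, 3, 1, 4)  # -> (C, H/P, P, W/P, P)
    return x.reshape(C, H, W)

# Stand-in for the output head's per-token noise predictions
eps_tokens = np.random.default_rng(0).standard_normal((64, 48))
eps_pred = unpatchify(eps_tokens, C=3, H=32, W=32, P=4)  # (3, 32, 32)
```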

Training#

from world_models.configs import DiTConfig, get_dit_config

cfg = get_dit_config(
    DATASET="CIFAR10",
    BATCH=128,
    EPOCHS=100,
    IMG_SIZE=32,
    WIDTH=384,
    DEPTH=6,
)

# Training uses DDPM noise scheduling
# Forward: q(x_t | x_0) = N(√ᾱ_t x_0, (1-ᾱ_t)I)
# Reverse: p(x_{t-1} | x_t) = N(μ_θ(x_t,t), σ_t²I)
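The forward process in the comments above can be sketched in NumPy (the model call is stubbed out; schedule constants follow the hyperparameter table below):

```python
import numpy as np

rng = np.random.default_rng(0)

# Linear beta schedule: BETA_START=1e-4, BETA_END=0.02, TIMESTEPS=1000
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)  # cumulative product, shrinks toward 0

def q_sample(x0, t, noise):
    """Forward process: x_t = sqrt(abar_t)*x0 + sqrt(1 - abar_t)*eps."""
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * noise

# One training step: inject noise at a random timestep, regress it with MSE
x0 = rng.standard_normal((3, 32, 32))
t = int(rng.integers(0, T))
eps = rng.standard_normal(x0.shape)
x_t = q_sample(x0, t, eps)
eps_pred = np.zeros_like(x_t)           # stand-in for model(x_t, t)
loss = np.mean((eps_pred - eps) ** 2)   # DDPM "simple" objective
```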

Key Hyperparameters#

| Parameter | Default | Description |
| --- | --- | --- |
| IMG_SIZE | 32 | Input image size |
| PATCH | 4 | Patch size |
| WIDTH | 384 | Transformer embedding dim |
| DEPTH | 6 | Number of blocks |
| HEADS | 6 | Attention heads |
| TIMESTEPS | 1000 | Diffusion steps |
| BETA_START | 1e-4 | Noise schedule start |
| BETA_END | 0.02 | Noise schedule end |

Sampling (Generation)#

# Start from random noise
x_T ~ N(0, I)

# Iteratively denoise (DDPM ancestral sampling)
for t in reversed(range(T)):
    ε = model(x_t, t)                                 # Predict noise
    x_{t-1} = (x_t - (1-α_t)/√(1-ᾱ_t) · ε) / √α_t     # Posterior mean μ_θ
    if t > 0:
        x_{t-1} += σ_t · z,  z ~ N(0, I)              # Add sampling noise

# Output: x_0 (generated image)
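A runnable NumPy version of the sampling loop above, with the trained model stubbed out (σ_t = √β_t is one standard choice for the reverse-process variance):

```python
import numpy as np

rng = np.random.default_rng(0)
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)

def model(x_t, t):
    """Stand-in for the trained DiT noise predictor."""
    return np.zeros_like(x_t)

x = rng.standard_normal((3, 32, 32))  # x_T ~ N(0, I)
for t in reversed(range(T)):
    eps = model(x, t)
    coef = (1.0 - alphas[t]) / np.sqrt(1.0 - alpha_bar[t])
    x = (x - coef * eps) / np.sqrt(alphas[t])       # posterior mean mu_theta
    if t > 0:                                       # no noise at the final step
        x = x + np.sqrt(betas[t]) * rng.standard_normal(x.shape)
```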

Classifier-Free Guidance#

For conditional generation, use classifier-free guidance:

ε_cond = (1+w)·ε_model(x_t, c) - w·ε_model(x_t, ∅)

where w is the guidance weight (typically 1–10); setting w = 0 recovers the plain conditional prediction ε_model(x_t, c).
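The guidance formula is a one-liner; a small sketch (function name illustrative) that combines the conditional and unconditional predictions, each obtained from a separate forward pass of the same model:

```python
import numpy as np

def guided_eps(eps_cond: np.ndarray, eps_uncond: np.ndarray, w: float) -> np.ndarray:
    """Classifier-free guidance: push the conditional prediction away from the
    unconditional one by guidance weight w."""
    return (1.0 + w) * eps_cond - w * eps_uncond

rng = np.random.default_rng(0)
e_c = rng.standard_normal((3, 32, 32))  # stand-in for eps_model(x_t, c)
e_u = rng.standard_normal((3, 32, 32))  # stand-in for eps_model(x_t, null)
e = guided_eps(e_c, e_u, w=4.0)
```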

Comparison to CNN-Based Diffusion#

| Aspect | U-Net (DDPM) | DiT (Transformer) |
| --- | --- | --- |
| Architecture | CNN with skip connections | ViT |
| Global attention | Limited | Full |
| Scalability | Medium | High |
| Quality | Good | Slightly better |
| Compute | Efficient | Higher |

Applications#

  1. Image generation: High-quality unconditional/class-conditional

  2. Image editing: Inpainting, outpainting

  3. Video generation: Temporal extension

  4. World models: For model-based RL (as observation model)

References#

  • Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. ICCV 2023.