DiT: Diffusion Transformer#

DiT (Diffusion Transformer) applies the transformer architecture to diffusion models for image generation. It treats diffusion as a sequence modeling problem where the transformer predicts noise at each timestep.

Based on the paper: Scalable Diffusion Models with Transformers (Peebles & Xie, 2023)

Key Idea#

Instead of using CNNs (like U-Net) for diffusion, DiT uses a Vision Transformer (ViT) architecture:

  1. Tokenization: Split image into patches, convert to tokens

  2. Diffusion as sequence: Treat noise prediction as sequence modeling

  3. Transformer: Process token sequence to predict added noise

  4. DDPM: Generate images by iteratively denoising

Architecture#

*Figure 1: DiT architecture overview from the DiT paper (Peebles & Xie, 2023). Shows the transformer-based diffusion model with patch embedding, timestep conditioning, and noise prediction.*

Diffusion Transformer data flow: noisy image input → patch and position embeddings → timestep and condition tokens → DiT transformer blocks → predicted noise output

Components#

1. Patch Embedding#

Converts input image into a sequence of tokens:

  • Image of size (C, H, W) split into (H/P × W/P) patches

  • Each patch linearly embedded to dimension D

  • Add positional embeddings

\[\begin{split}\begin{aligned} \mathbf{x} &\in \mathbb{R}^{C \times H \times W} \\ &\rightarrow \mathbf{tokens} \in \mathbb{R}^{(H/P \times W/P) \times D} \end{aligned}\end{split}\]
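The patch embedding step can be sketched as follows. This is a minimal illustration in NumPy rather than the library's own implementation; the function names `patchify` and `embed_patches`, the random projection weights, and the all-zero positional embedding are assumptions made for the example.

```python
import numpy as np

def patchify(x, patch):
    """Split a (C, H, W) image into (H/P * W/P) flattened patches of length C*P*P."""
    c, h, w = x.shape
    assert h % patch == 0 and w % patch == 0, "H and W must be divisible by P"
    gh, gw = h // patch, w // patch
    # (C, gh, P, gw, P) -> (gh, gw, C, P, P) -> (gh*gw, C*P*P)
    x = x.reshape(c, gh, patch, gw, patch)
    x = x.transpose(1, 3, 0, 2, 4)
    return x.reshape(gh * gw, c * patch * patch)

def embed_patches(x, patch, weight, pos_embed):
    """Linearly project each flattened patch to dimension D, then add positional embeddings."""
    tokens = patchify(x, patch) @ weight  # (N, C*P*P) @ (C*P*P, D) -> (N, D)
    return tokens + pos_embed             # pos_embed has shape (N, D)

rng = np.random.default_rng(0)
image = rng.standard_normal((3, 32, 32))                 # CIFAR-10-sized input
weight = rng.standard_normal((3 * 4 * 4, 384)) * 0.02    # hypothetical projection weights
pos_embed = np.zeros((64, 384))                          # placeholder positional embedding
tokens = embed_patches(image, 4, weight, pos_embed)
print(tokens.shape)  # (64, 384)
```

With P = 4 on a 32 × 32 image, the sequence has (32/4) × (32/4) = 64 tokens, each of dimension D = 384.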

2. Timestep & Condition Embeddings#

  • Timestep: sinusoidal positional encoding for diffusion timestep t

  • Class label: Optional conditioning for class-conditional generation

  • Both added to each token
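A sinusoidal timestep embedding, combined with a class embedding and added to every token, might look like the sketch below. The `max_period` value, the cos/sin ordering, and the `class_table` lookup are assumptions for illustration and may differ from what the library does.

```python
import numpy as np

def timestep_embedding(t, dim, max_period=10000.0):
    """Sinusoidal embedding of timesteps t (shape (B,)) into vectors of shape (B, dim)."""
    assert dim % 2 == 0, "this sketch assumes an even embedding dimension"
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to roughly 1/max_period
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.cos(args), np.sin(args)], axis=1)

rng = np.random.default_rng(0)
t_emb = timestep_embedding(np.array([0, 500, 999]), 384)   # (3, 384)
class_table = rng.standard_normal((10, 384)) * 0.02        # hypothetical learned class table
cond = t_emb + class_table[np.array([3, 3, 7])]            # one conditioning vector per sample

tokens = np.zeros((3, 64, 384))                            # (B, N, D) patch tokens
tokens = tokens + cond[:, None, :]                         # broadcast: same vector added to each token
```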

3. DiT Blocks#

Transformer blocks with:

  • Self-attention: Capture relationships between patches

  • MLP: Process each patch independently

  • Adaptive Layer Norm (AdaLN): Condition on timestep/class

Variants:

  • AdaLN: Adaptive Layer Norm

  • AdaLN-Single: More parameter-efficient

  • AdaLN-Diagonal: Condition variance too
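To make the AdaLN idea concrete, here is a minimal NumPy sketch of the basic variant: the tokens are normalized without learned affine parameters, and the scale and shift are instead regressed from the conditioning vector. The `(1 + scale)` convention and the names `adaln`, `w`, and `b` are assumptions for illustration; the other variants listed above are not covered.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """LayerNorm over the last axis, without learned scale or bias."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def adaln(x, cond, w, b):
    """Adaptive LayerNorm.

    x: (N, D) tokens, cond: (D,) timestep/class embedding,
    w: (D, 2D) and b: (2D,) map cond to a per-channel shift and scale.
    """
    shift, scale = np.split(cond @ w + b, 2)
    return layer_norm(x) * (1.0 + scale) + shift

rng = np.random.default_rng(0)
D = 8
x = rng.standard_normal((4, D))
cond = rng.standard_normal(D)

# With zero weights the modulation is the identity, so this is plain LayerNorm
plain = adaln(x, cond, np.zeros((D, 2 * D)), np.zeros(2 * D))
# With non-zero weights the output depends on the conditioning vector
modulated = adaln(x, cond, rng.standard_normal((D, 2 * D)) * 0.1, np.zeros(2 * D))
```

In a block, a modulation like this would typically be applied before the self-attention and MLP sublayers, so the same tokens are processed differently at different timesteps or for different classes.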

4. Output Head#

The final layer predicts the noise (ε), with the same dimensions as the input:

\[\hat{\epsilon} = \mathrm{Linear}(\mathbf{tokens}) \rightarrow \mathbb{R}^{C \times H \times W}\]
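Going from tokens back to an image-shaped tensor requires reversing the patch split. The sketch below projects each token to C·P·P values and rearranges them into (C, H, W). The helper names `unpatchify` and `output_head` and the random weights are hypothetical.

```python
import numpy as np

def unpatchify(patches, c, h, w, patch):
    """Rearrange (H/P * W/P, C*P*P) flattened patches back into a (C, H, W) array."""
    gh, gw = h // patch, w // patch
    x = patches.reshape(gh, gw, c, patch, patch)
    x = x.transpose(2, 0, 3, 1, 4)  # -> (C, gh, P, gw, P)
    return x.reshape(c, h, w)

def output_head(tokens, weight, bias, c, h, w, patch):
    """Linear projection of each token to a flattened patch, then unpatchify."""
    patches = tokens @ weight + bias  # (N, D) -> (N, C*P*P)
    return unpatchify(patches, c, h, w, patch)

rng = np.random.default_rng(0)
tokens = rng.standard_normal((64, 384))
weight = rng.standard_normal((384, 3 * 4 * 4)) * 0.02  # hypothetical head weights
bias = np.zeros(3 * 4 * 4)
eps_hat = output_head(tokens, weight, bias, c=3, h=32, w=32, patch=4)
print(eps_hat.shape)  # (3, 32, 32)
```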

Training#

from torchwm import DiTConfig, get_dit_config

cfg = get_dit_config(
    DATASET="CIFAR10",
    BATCH=128,
    EPOCHS=100,
    IMG_SIZE=32,
    WIDTH=384,
    DEPTH=6,
)

Training uses DDPM noise scheduling, with the forward (noising) process q and the learned reverse process p_θ:
\[q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}\left(\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0, (1 - \bar{\alpha}_t)\mathbf{I}\right)\]
\[p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t) = \mathcal{N}\left(\mu_\theta(\mathbf{x}_t, t), \sigma_t^2\mathbf{I}\right)\]
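The forward process has a closed form, so a training step can noise a clean image to any timestep directly. The sketch below assumes a linear beta schedule between the BETA_START and BETA_END defaults and the simplified noise-prediction MSE objective of Ho et al. (2020). Whether the library trains with exactly this schedule and loss is an assumption, and `dummy_model` stands in for the DiT network.

```python
import numpy as np

def linear_beta_schedule(timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances beta_1..beta_T."""
    return np.linspace(beta_start, beta_end, timesteps)

def q_sample(x0, t, alpha_bar, noise):
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(abar_t) x_0, (1 - abar_t) I) for a batch."""
    a = alpha_bar[t].reshape(-1, 1, 1, 1)  # broadcast over (C, H, W)
    return np.sqrt(a) * x0 + np.sqrt(1.0 - a) * noise

def noise_prediction_loss(model, x0, alpha_bar, rng):
    """MSE between the true noise and the model's prediction at random timesteps."""
    t = rng.integers(0, len(alpha_bar), size=x0.shape[0])
    noise = rng.standard_normal(x0.shape)
    x_t = q_sample(x0, t, alpha_bar, noise)
    return np.mean((model(x_t, t) - noise) ** 2)

betas = linear_beta_schedule()
alpha_bar = np.cumprod(1.0 - betas)            # abar_t = prod_{s<=t} (1 - beta_s)

rng = np.random.default_rng(0)
x0 = rng.standard_normal((8, 3, 32, 32))       # stand-in batch of images
dummy_model = lambda x_t, t: np.zeros_like(x_t)  # placeholder for the DiT network
loss = noise_prediction_loss(dummy_model, x0, alpha_bar, rng)
```

A model that always predicts zero gets a loss close to 1 here, which is the variance of the injected noise; a trained network should do better than that baseline.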

Key Hyperparameters#

| Parameter  | Default | Description               |
|------------|---------|---------------------------|
| IMG_SIZE   | 32      | Input image size          |
| PATCH      | 4       | Patch size                |
| WIDTH      | 384     | Transformer embedding dim |
| DEPTH      | 6       | Number of blocks          |
| HEADS      | 6       | Attention heads           |
| TIMESTEPS  | 1000    | Diffusion steps           |
| BETA_START | 1e-4    | Noise schedule start      |
| BETA_END   | 0.02    | Noise schedule end        |
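A few derived quantities follow directly from these defaults. The short check below works them out. The channel count of 3 is an assumption (RGB images such as CIFAR-10), and the divisibility checks reflect the usual requirements of patch splitting and multi-head attention rather than anything the library is known to enforce.

```python
img_size, patch, width, heads = 32, 4, 384, 6
channels = 3  # assumption: RGB input

assert img_size % patch == 0, "image size must be divisible by patch size"
assert width % heads == 0, "embedding dim must be divisible by number of heads"

num_tokens = (img_size // patch) ** 2   # tokens per image
patch_dim = channels * patch * patch    # values in one flattened patch
head_dim = width // heads               # dimensions per attention head
print(num_tokens, patch_dim, head_dim)  # 64 48 64
```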

Sampling (Generation)#

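# Start from random noise
x = randn(shape)  # x_T ~ N(0, I)
# Iteratively denoise
for t in reversed(range(T)):
    eps = model(x, t)  # Predict noise
    # Posterior mean: remove the predicted noise, then rescale
    mean = (x - beta[t] / sqrt(1 - alpha_bar[t]) * eps) / sqrt(alpha[t])
    # Add fresh noise on every step except the last
    x = mean + sigma[t] * randn(shape) if t > 0 else mean

# Output: x (generated image x_0)
\[\mathbf{x}_T \sim \mathcal{N}(0, \mathbf{I})\]
\[\epsilon = \mathrm{model}(\mathbf{x}_t, t)\]
\[\mathbf{x}_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \cdot \epsilon\right) + \sigma_t \mathbf{z}, \quad \mathbf{z} \sim \mathcal{N}(0, \mathbf{I})\]

Here \(\alpha_t = 1 - \beta_t\), and \(\mathbf{z}\) is set to zero on the final step.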
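For reference, a self-contained sketch of the standard DDPM ancestral sampling update (Ho et al., 2020) is given here. It assumes a linear beta schedule and the common choice σ_t² = β_t. The `oracle` function is a toy stand-in for a trained network, not part of the library: it is the exact noise predictor for a dataset whose only image is all zeros, so the sampler should return values near zero.

```python
import numpy as np

def ddpm_sample(model, shape, betas, rng):
    """Ancestral sampling: start at x_T ~ N(0, I) and denoise down to x_0."""
    alphas = 1.0 - betas
    alpha_bar = np.cumprod(alphas)
    x = rng.standard_normal(shape)
    for t in reversed(range(len(betas))):
        eps = model(x, t)  # predicted noise
        # Posterior mean: subtract the scaled noise estimate, then rescale
        mean = (x - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps) / np.sqrt(alphas[t])
        if t > 0:
            # sigma_t^2 = beta_t (one common choice)
            x = mean + np.sqrt(betas[t]) * rng.standard_normal(shape)
        else:
            x = mean  # no noise is added on the final step
    return x

betas = np.linspace(1e-4, 0.02, 1000)
alpha_bar = np.cumprod(1.0 - betas)

# Toy oracle: if every training image is zero, then x_t = sqrt(1 - abar_t) * eps,
# so the true noise is x_t / sqrt(1 - abar_t).
oracle = lambda x, t: x / np.sqrt(1.0 - alpha_bar[t])

sample = ddpm_sample(oracle, (1, 3, 32, 32), betas, np.random.default_rng(0))
```

Simply subtracting the predicted noise without the 1/√α_t rescaling and the per-step noise injection does not follow the DDPM reverse process, which is why the update is written in terms of the posterior mean.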

Classifier-Free Guidance#

For conditional generation, use classifier-free guidance:

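\[\begin{split}\begin{aligned} \epsilon_\mathrm{cond} &= (1 + w) \cdot \epsilon_\mathrm{model}(\mathbf{x}_t, t, c) \\ &\quad - w \cdot \epsilon_\mathrm{model}(\mathbf{x}_t, t, \emptyset) \end{aligned}\end{split}\]

where w is the guidance weight (typically 1-10), c is the class condition, and ∅ is the null condition used for the unconditional prediction. Setting w = 0 recovers the plain conditional prediction; larger w pushes samples further toward the condition.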
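The guidance formula amounts to two forward passes per step, combined linearly. The sketch below assumes a hypothetical model signature `model(x_t, t, cond)` in which `cond=None` plays the role of the null condition ∅; the toy model exists only to make the arithmetic visible.

```python
import numpy as np

def guided_noise(model, x_t, t, cond, w, null_cond=None):
    """Classifier-free guidance: (1 + w) * eps(x_t, c) - w * eps(x_t, null)."""
    eps_cond = model(x_t, t, cond)         # conditional prediction
    eps_uncond = model(x_t, t, null_cond)  # unconditional prediction
    return (1.0 + w) * eps_cond - w * eps_uncond

def toy_model(x_t, t, cond):
    """Toy stand-in: predicts 1 everywhere when conditioned, 0 when unconditioned."""
    return np.full_like(x_t, 0.0 if cond is None else 1.0)

x_t = np.zeros((1, 3, 32, 32))
eps = guided_noise(toy_model, x_t, t=10, cond=3, w=2.0)
print(eps.mean())  # 3.0
```

With w = 2 the guided prediction is extrapolated past the conditional one (3 rather than 1 in this toy case), which is the effect that strengthens adherence to the class label.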

Comparison to CNN-Based Diffusion#

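| Aspect           | U-Net (DDPM)              | DiT (Transformer) |
|------------------|---------------------------|-------------------|
| Architecture     | CNN with skip connections | ViT               |
| Global attention | Limited                   | Full              |
| Scalability      | Medium                    | High              |
| Quality          | Good                      | Slightly better   |
| Compute          | Efficient                 | Higher            |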

Applications#

  1. Image generation: High-quality unconditional or class-conditional synthesis

  2. Image editing: Inpainting, outpainting

  3. Video generation: Temporal extension

  4. World models: For model-based RL (as observation model)

References#

  • Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers.