DiT: Diffusion Transformer and Diffusion Models

This page covers the diffusion-based models in TorchWM: DDPM for image generation, DiT for scalable transformer-based diffusion, and DIAMOND for diffusion world models in reinforcement learning.

Based on the papers listed under References at the end of this page.

Overview

Diffusion models learn to generate data by reversing a gradual noising process. In TorchWM, diffusion models serve two purposes:

  1. Image/video generation (DDPM, DiT): High-quality unconditional or conditional sample generation.

  2. World models for RL (DIAMOND): Diffusion-based dynamics models that predict future observations.

```mermaid
graph LR
    A["Data x₀"] --> B["Forward: q(x_t|x₀)"]
    B --> C["..."]
    C --> D["x_T ~ N(0, I)"]
    D --> E["Reverse: p_θ(x_{t-1}|x_t)"]
    E --> F["..."]
    F --> G["Generated x₀"]
```

DDPM: Denoising Diffusion Probabilistic Models

Forward Process

The forward (diffusion) process gradually adds Gaussian noise to data over T timesteps according to a fixed variance schedule:

\[q(x_t | x_{t-1}) = \mathcal{N}\left(\sqrt{1 - \beta_t}\, x_{t-1},\; \beta_t I\right)\]

We can sample x_t directly from x_0:

\[q(x_t | x_0) = \mathcal{N}\left(\sqrt{\bar{\alpha}_t}\, x_0,\; (1 - \bar{\alpha}_t) I\right)\]

where \(\alpha_t = 1 - \beta_t\) and \(\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s\). Using the reparameterization trick:

\[x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]
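Because of this closed form, training never has to simulate the chain step by step. The sketch below is a minimal NumPy illustration, assuming the linear β schedule implied by the TIMESTEPS, BETA_START and BETA_END defaults in the DiTConfig reference; linear_schedule and q_sample are illustrative names, not TorchWM functions.

```python
import numpy as np

def linear_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """Return (betas, alpha_bar) for a linear variance schedule (illustrative helper)."""
    betas = np.linspace(beta_start, beta_end, T)
    alpha_bar = np.cumprod(1.0 - betas)      # alpha_bar[t] = prod_{s<=t} (1 - beta_s)
    return betas, alpha_bar

def q_sample(x0, t, alpha_bar, rng):
    """Draw x_t ~ q(x_t | x_0) in one shot via the reparameterization trick."""
    eps = rng.standard_normal(x0.shape)
    x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps
    return x_t, eps

betas, alpha_bar = linear_schedule()
rng = np.random.default_rng(0)
x_t, eps = q_sample(np.ones((3, 32, 32)), t=999, alpha_bar=alpha_bar, rng=rng)
# alpha_bar[0] = 0.9999 and alpha_bar[-1] is roughly 4e-5, so x_T is almost pure noise
print(alpha_bar[0], alpha_bar[-1])
```

Indices here are 0-based, so alpha_bar[t] corresponds to \(\bar{\alpha}_{t+1}\) in the equations.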

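Reverse Process

The reverse process learns to denoise. Starting from pure noise x_T ~ N(0, I), each step samples

\[p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\left( \mu_\theta(x_t, t),\; \sigma_t^2 I \right)\]

With the noise-prediction parameterization used in the training objective, the mean is computed from the predicted noise \(\epsilon_\theta\):

\[\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(x_t, t) \right)\]

and a common choice for the variance is \(\sigma_t^2 = \beta_t\).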
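As a numerical sanity check on the noise-prediction form of the reverse mean, \(\mu_\theta = (x_t - \beta_t / \sqrt{1 - \bar{\alpha}_t}\, \epsilon_\theta) / \sqrt{\alpha_t}\), the NumPy sketch below confirms that when the predicted noise equals the true noise, this mean matches the mean of the exact forward-process posterior q(x_{t-1} | x_t, x_0) from Ho et al. (2020). The linear schedule and the function names are illustrative assumptions.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)           # linear schedule (assumed BETA_START, BETA_END)
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)

def mu_from_eps(x_t, eps, t):
    """Reverse-process mean written in terms of the (predicted) noise."""
    return (x_t - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps) / np.sqrt(alphas[t])

def posterior_mean(x0, x_t, t):
    """Mean of q(x_{t-1} | x_t, x_0) for 0-based t >= 1."""
    ab_prev = alpha_bar[t - 1]
    c0 = np.sqrt(ab_prev) * betas[t] / (1.0 - alpha_bar[t])
    ct = np.sqrt(alphas[t]) * (1.0 - ab_prev) / (1.0 - alpha_bar[t])
    return c0 * x0 + ct * x_t

rng = np.random.default_rng(0)
x0, eps, t = rng.standard_normal(8), rng.standard_normal(8), 500
x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps
print(np.allclose(mu_from_eps(x_t, eps, t), posterior_mean(x0, x_t, t)))  # True
```

The identity holds because \(\bar{\alpha}_t = \alpha_t \bar{\alpha}_{t-1}\) and \(\beta_t = 1 - \alpha_t\); substituting \(x_0 = (x_t - \sqrt{1 - \bar{\alpha}_t}\,\epsilon) / \sqrt{\bar{\alpha}_t}\) into the posterior mean yields the noise-prediction form.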

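Training Objective

The simplified DDPM loss trains the model to predict the noise ε at each timestep:

\[\mathcal{L}_{\text{DDPM}} = \mathbb{E}_{t, x_0, \epsilon} \left[ \left\| \epsilon - \epsilon_\theta\left(\sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon,\; t\right) \right\|^2 \right]\]

One training step, assuming x0 is a batch of clean images of shape (B, C, H, W), alpha_bar is a length-T tensor, and model(x_t, t) returns the predicted noise:

```python
import torch
import torch.nn.functional as F

t = torch.randint(0, T, (x0.shape[0],), device=x0.device)  # one random timestep per sample
eps = torch.randn_like(x0)                                 # target noise
a = alpha_bar.to(x0.device)[t].view(-1, 1, 1, 1)           # broadcast over C, H, W
xt = a.sqrt() * x0 + (1 - a).sqrt() * eps                  # sample q(x_t | x_0) in closed form
eps_pred = model(xt, t)                                    # predict the noise
loss = F.mse_loss(eps_pred, eps)                           # simplified noise-prediction loss
loss.backward()
```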

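Sampling

Ancestral sampling runs the learned reverse process from t = T down to t = 1 (0-based indices in code), drawing x_{t-1} from p_θ(x_{t-1} | x_t) at each step:

```python
import torch

@torch.no_grad()
def ddpm_sample(model, shape, betas):                        # betas: float tensor of shape (T,)
    alphas = 1.0 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)
    x = torch.randn(shape)                                   # x_T ~ N(0, I)
    for t in reversed(range(len(betas))):
        t_batch = torch.full((shape[0],), t, dtype=torch.long)
        eps_pred = model(x, t_batch)                         # predict noise
        x = (x - betas[t] / (1 - alpha_bar[t]).sqrt() * eps_pred) / alphas[t].sqrt()  # mean mu_theta
        if t > 0:
            x = x + betas[t].sqrt() * torch.randn_like(x)    # add sigma_t * z with sigma_t^2 = beta_t
    return x                                                 # generated sample
```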

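DiT: Diffusion Transformer

DiT replaces the U-Net backbone with a Vision Transformer (ViT) that operates on image patches. Self-attention gives every patch a global receptive field, and Peebles & Xie (2023) report that sample quality improves steadily as transformer compute grows (through more depth, more width, or more tokens).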

Architecture

DiT Architecture

```mermaid
graph TD
    A["Noisy image x_t"] --> B["Patchify: (C,H,W) → (N,D)"]
    C["Timestep t"] --> D["Timestep embedding"]
    D --> E["AdaLN modulation"]
    B --> F["DiT Block × depth"]
    E --> F
    F --> G["..."]
    G --> H["Output head: (N,D) → (C,H,W)"]
    H --> I["Predicted noise ε_θ"]
```

Patch Embedding

The input x_t ∈ ℝ^{C×H×W} is split into non-overlapping P×P patches, and each patch is linearly embedded into D dimensions:

\[x_t \in \mathbb{R}^{C \times H \times W} \;\to\; \text{tokens} \in \mathbb{R}^{(H/P \cdot W/P) \times D}\]

For a 32×32 CIFAR-10 image with P = 4, this gives (32/4)·(32/4) = 64 tokens of dimension D.
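A minimal PyTorch sketch of patchify and its inverse, assuming square patches and the WIDTH = 384 default from the DiTConfig reference. A convolution with kernel size and stride both equal to P applies the same linear map to every patch, which is a common way to implement the embedding; PatchEmbed and unpatchify are illustrative names, not TorchWM classes.

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split (B, C, H, W) images into P x P patches and project each to D dims."""
    def __init__(self, in_ch=3, patch=4, dim=384):
        super().__init__()
        # kernel_size == stride == P: one linear projection per non-overlapping patch
        self.proj = nn.Conv2d(in_ch, dim, kernel_size=patch, stride=patch)

    def forward(self, x):
        x = self.proj(x)                        # (B, D, H/P, W/P)
        return x.flatten(2).transpose(1, 2)     # (B, N, D) with N = (H/P) * (W/P)

def unpatchify(tokens, patch, channels, height, width):
    """Output-head reshape: (B, N, P*P*C) -> (B, C, H, W)."""
    b = tokens.shape[0]
    gh, gw = height // patch, width // patch
    x = tokens.reshape(b, gh, gw, patch, patch, channels)
    x = x.permute(0, 5, 1, 3, 2, 4)             # (B, C, gh, P, gw, P)
    return x.reshape(b, channels, height, width)

tokens = PatchEmbed()(torch.randn(2, 3, 32, 32))
print(tokens.shape)  # torch.Size([2, 64, 384])
```

The output head would first project each D-dimensional token to P·P·C values with a linear layer before calling unpatchify.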

Timestep Conditioning (AdaLN)

DiT uses Adaptive Layer Normalization (AdaLN) to condition on the diffusion timestep. Instead of learning fixed LayerNorm parameters, each block regresses its scale γ and shift β from the timestep embedding (this β is unrelated to the noise schedule β_t):

\[\text{AdaLN}(h, t) = \gamma(t) \cdot \text{LayerNorm}(h) + \beta(t)\]
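A minimal PyTorch sketch of AdaLN together with a sinusoidal timestep embedding. Writing γ(t) = 1 + scale(t) and zero-initializing the modulation layer so each block starts as a plain LayerNorm follows the reference DiT code, but it is an assumption here, not necessarily how TorchWM does it; timestep_embedding and AdaLN are illustrative names.

```python
import math
import torch
import torch.nn as nn

def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal embedding of integer timesteps t (B,) -> (B, dim), dim even."""
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

class AdaLN(nn.Module):
    """gamma(t) * LayerNorm(h) + beta(t), with gamma and beta regressed from the embedding."""
    def __init__(self, dim, emb_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)  # affine params come from t
        self.to_scale_shift = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, 2 * dim))
        # Zero init: gamma = 1 and beta = 0 at the start of training
        nn.init.zeros_(self.to_scale_shift[1].weight)
        nn.init.zeros_(self.to_scale_shift[1].bias)

    def forward(self, h, emb):
        scale, shift = self.to_scale_shift(emb).chunk(2, dim=-1)   # each (B, D)
        gamma, beta = 1 + scale, shift
        return gamma[:, None, :] * self.norm(h) + beta[:, None, :]  # broadcast over tokens

h = torch.randn(2, 64, 384)                       # (B, N, D) tokens
emb = timestep_embedding(torch.tensor([0, 999]), 256)
out = AdaLN(384, 256)(h, emb)
```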

| Variant | Scheme | Parameters | Speed |
|---|---|---|---|
| AdaLN | γ, β per block | Low | Default |
| AdaLN-Single | Shared γ, β | Very low | Fastest |
| AdaLN-Diagonal | γ, β + diagonal scaling | Medium | Medium |

DiT Block

\[\begin{split}\begin{aligned} h &\leftarrow h + \text{MSA}(\text{AdaLN}(h, t)) \\ h &\leftarrow h + \text{MLP}(\text{AdaLN}(h, t)) \end{aligned}\end{split}\]

DiT blocks use no cross-attention: the model is typically unconditional or class-conditional, and a class label, like the timestep, is injected only through AdaLN.
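A minimal PyTorch sketch of one block that follows the two update equations literally, with a separate (γ, β) pair for each AdaLN. The defaults mirror WIDTH = 384 and HEADS = 6 from the DiTConfig reference; the MLP ratio of 4 is an assumption, and the per-branch gating of the DiT paper's adaLN-Zero variant is omitted to match the equations. DiTBlock is an illustrative name.

```python
import torch
import torch.nn as nn

class DiTBlock(nn.Module):
    """h <- h + MSA(AdaLN(h, t)); h <- h + MLP(AdaLN(h, t))."""
    def __init__(self, dim=384, heads=6, emb_dim=384, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim), nn.GELU(), nn.Linear(mlp_ratio * dim, dim)
        )
        # Predict (gamma1, beta1, gamma2, beta2) from the conditioning embedding
        self.ada = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, 4 * dim))

    def forward(self, h, emb):
        g1, b1, g2, b2 = self.ada(emb)[:, None, :].chunk(4, dim=-1)  # each (B, 1, D)
        x = (1 + g1) * self.norm1(h) + b1                 # AdaLN before attention
        h = h + self.attn(x, x, x, need_weights=False)[0]  # global self-attention over tokens
        x = (1 + g2) * self.norm2(h) + b2                 # AdaLN before the MLP
        return h + self.mlp(x)

block = DiTBlock()
out = block(torch.randn(2, 64, 384), torch.randn(2, 384))   # tokens, timestep embedding
```

Stacking DEPTH such blocks between the patch embedding and the output head gives the full noise-prediction network.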

Training

```python
from torchwm import DiTConfig, get_dit_config

# Override DiTConfig defaults (see Config Reference)
cfg = get_dit_config(
    DATASET="CIFAR10",
    BATCH=128,
    EPOCHS=100,
    IMG_SIZE=32,
    WIDTH=384,
    DEPTH=6,
)
```

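Sampling (Generation)

DiT only replaces the noise-prediction network, so generation uses the same DDPM reverse process. The loop below assumes betas, alphas and alpha_bar are 1-D tensors of length T and model is the trained DiT:

```python
import torch
with torch.no_grad():                                    # no gradients needed for sampling
    x = torch.randn(shape)                               # start from pure noise x_T
    for t in reversed(range(T)):
        eps = model(x, torch.full((shape[0],), t))       # predicted noise
        x = (x - betas[t] / (1 - alpha_bar[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            x = x + betas[t].sqrt() * torch.randn_like(x)  # sigma_t^2 = beta_t
```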

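Classifier-Free Guidance

For conditional generation, classifier-free guidance combines the conditional and unconditional noise predictions of the same network:

\[\epsilon_\mathrm{cond} = (1 + w) \cdot \epsilon_\mathrm{model}(\mathbf{x}_t, c) - w \cdot \epsilon_\mathrm{model}(\mathbf{x}_t, \emptyset)\]

where w is the guidance weight (typically 1–10). Setting w = 0 recovers the plain conditional prediction; larger w trades sample diversity for closer adherence to the condition c.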
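A minimal sketch of the guidance rule. The model(x_t, t, c) signature and the use of None as the null condition ∅ are assumptions for illustration; in practice the network is trained with the condition randomly replaced by the null token so that one set of weights provides both predictions, and the two forward passes are often batched together.

```python
import numpy as np

def cfg_eps(model, x_t, t, cond, null_cond, w):
    """Classifier-free guidance: (1 + w) * eps(x_t, c) - w * eps(x_t, null)."""
    eps_cond = model(x_t, t, cond)          # conditional prediction
    eps_uncond = model(x_t, t, null_cond)   # unconditional prediction (condition dropped)
    return (1.0 + w) * eps_cond - w * eps_uncond

# Toy stand-in for a trained network, for illustration only
toy_model = lambda x, t, c: x * (1.0 if c is None else 2.0)
x = np.ones(4)
print(cfg_eps(toy_model, x, t=10, cond=1, null_cond=None, w=3.0))  # [5. 5. 5. 5.]
```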

DIAMOND: Diffusion World Models

DIAMOND applies diffusion models to world modeling for reinforcement learning. Instead of predicting latent states (Dreamer) or discrete tokens (IRIS), it predicts future observations using a diffusion denoising process.

Architecture

```mermaid
graph TD
    A["Past 4 frames"] --> B["Cond. encoder"]
    C["Action a_t"] --> B
    B --> D["Conditioning embedding c"]
    E["Noisy next frame x_t + ε"] --> F["Diffusion UNet"]
    D --> F
    F --> G["Denoised next frame x̂_t"]
    G --> H["Reward/Termination model"]
    H --> I["r̂_t, γ̂_t"]
    G --> J["Actor-Critic"]
```

Key Components

| Component | File | Description |
|---|---|---|
| DDPM | diffusion/DDPM.py | Basic DDPM model and scheduler |
| DiT | diffusion/DiT.py | Diffusion Transformer |
| DiffusionUNet | diffusion/diamond_diffusion.py | Conditional UNet for DIAMOND |
| EDMPreconditioner | diffusion/diamond_diffusion.py | EDM-style noise preconditioning |
| EulerSampler | diffusion/diamond_diffusion.py | Fast ODE-based sampling |
| RewardTerminationModel | diffusion/reward_termination.py | Reward and termination prediction |
| ActorCriticNetwork | diffusion/actor_critic.py | Policy and value in imagination |

EDM Preconditioning

DIAMOND uses the EDM formulation of Karras et al. (2022), which wraps the raw network F_θ in noise-dependent scalings:

\[D_\theta(x, \sigma) = c_{\text{skip}}(\sigma)\, x + c_{\text{out}}(\sigma)\, F_\theta\big(c_{\text{in}}(\sigma)\, x,\; c_{\text{noise}}(\sigma)\big)\]

The coefficients are chosen so that the input to F_θ and its effective training target both have unit variance at every noise level σ, which keeps the network well scaled from nearly clean inputs to almost pure noise.
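A minimal NumPy sketch of the EDM coefficients as given by Karras et al. (2022), with their default σ_data = 0.5. Whether DIAMOND or TorchWM's EDMPreconditioner uses the same σ_data is an assumption, and the F(x, c_noise, cond) signature is hypothetical.

```python
import numpy as np

SIGMA_DATA = 0.5  # EDM default; assumed here, DIAMOND/TorchWM may differ

def edm_coefficients(sigma, sigma_data=SIGMA_DATA):
    """c_skip, c_out, c_in, c_noise from Karras et al. (2022)."""
    s2 = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / s2            # how much of the noisy input to pass through
    c_out = sigma * sigma_data / np.sqrt(s2)  # scale of the network's correction
    c_in = 1.0 / np.sqrt(s2)                 # normalizes the input to unit variance
    c_noise = 0.25 * np.log(sigma)           # noise-level conditioning fed to the network
    return c_skip, c_out, c_in, c_noise

def denoise(F, x, sigma, cond=None):
    """D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise, cond)."""
    c_skip, c_out, c_in, c_noise = edm_coefficients(sigma)
    return c_skip * x + c_out * F(c_in * x, c_noise, cond)
```

For data with standard deviation σ_data, the noisy input x = y + σn has variance σ_data² + σ², so c_in·x has unit variance at any σ; as σ approaches 0, c_skip approaches 1 and c_out approaches 0, so D returns the input almost unchanged.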

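Sampling

DIAMOND uses Euler sampling with very few steps (num_sampling_steps, default 3) for fast inference. Each step integrates the probability-flow ODE from one noise level σ to the next using the denoiser D_θ:

```python
import torch

@torch.no_grad()
def euler_sample(model, cond, shape, sigmas):
    # sigmas: decreasing noise levels from sigma_max down to 0
    x = torch.randn(shape) * sigmas[0]                    # start at sigma_max
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        denoised = model(x, sigma, cond)                  # D_theta(x, sigma)
        d = (x - denoised) / sigma                        # dx/dsigma of the probability-flow ODE
        x = x + d * (sigma_next - sigma)                  # Euler step to the next noise level
    return x                                              # equals the last denoised estimate when sigmas[-1] == 0
```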
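The Euler loop needs a decreasing list of noise levels. A minimal NumPy sketch of the EDM (Karras et al., 2022) schedule, using that paper's defaults σ_min = 0.002, σ_max = 80 and ρ = 7; DIAMOND's actual values are an assumption here. The check at the end uses an idealized denoiser that always returns the single data point y, in which case Euler sampling recovers y exactly at σ = 0.

```python
import numpy as np

def karras_sigmas(n_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """EDM noise levels from sigma_max down to sigma_min, followed by a final 0."""
    ramp = np.linspace(0.0, 1.0, n_steps)
    min_inv, max_inv = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    sigmas = (max_inv + ramp * (min_inv - max_inv)) ** rho   # spaced uniformly in sigma^(1/rho)
    return np.append(sigmas, 0.0)

def euler_sample(denoise, x, sigmas):
    """Euler integration of dx/dsigma = (x - D(x, sigma)) / sigma."""
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoise(x, sigma)) / sigma
        x = x + d * (sigma_next - sigma)
    return x

sigmas = karras_sigmas(3)                  # three denoiser calls, as in the DIAMOND default
y = np.full(5, 0.3)                        # idealized data point
x_start = y + sigmas[0] * np.random.default_rng(0).standard_normal(5)
sample = euler_sample(lambda x, sigma: y, x_start, sigmas)
```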

Usage in TorchWM

DiT quick start

```python
from torchwm import DiTConfig, get_dit_config

cfg = get_dit_config(DATASET="CIFAR10", BATCH=128, EPOCHS=100, IMG_SIZE=32)
```

DIAMOND quick start

```python
from world_models.configs.diamond_config import DiamondConfig

cfg = DiamondConfig(preset="small")  # small, medium, large
cfg.game = "Breakout-v5"
cfg.obs_size = 64
```

DIAMOND CLI

```shell
torchwm train diamond --config world_models/configs/experiments/diamond.yaml \
    preset=small seed=1
```

Or directly:

```shell
python -m world_models.training.train_diamond --game Breakout-v5 --preset small
```

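DIAMOND Training Loop

Schematically, each epoch alternates real data collection, world-model training and imagination training (the helper functions below are placeholders):

```python
obs, _ = env.reset()
for epoch in range(num_epochs):
    # 1. Collect real experience (assuming a Gymnasium-style env.step)
    for step in range(environment_steps_per_epoch):
        action = select_action(obs)
        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        buffer.add(obs, action, reward, done, next_obs)
        obs = env.reset()[0] if done else next_obs

    # 2. Train diffusion world model
    batch = buffer.sample(batch_size)
    loss = train_diffusion_step(batch)

    # 3. Train reward/termination model
    loss_r = train_reward_model(batch)

    # 4. Train actor-critic in imagination, starting from real context frames
    img_obs, cond, hidden = start_imagination(batch)      # hypothetical helper
    trajectory = []
    for step in range(imagination_horizon):
        action = actor(img_obs, hidden)
        img_obs = diffusion_model(img_obs, action, cond)  # sample the next frame
        reward = reward_model(img_obs)
        hidden = update_hidden(hidden, action, img_obs)
        trajectory.append((img_obs, action, reward))
    actor_loss, critic_loss = compute_ac_loss(trajectory)
```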
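compute_ac_loss is left abstract above. One common ingredient of imagination-based actor-critics is the λ-return target for the critic; the NumPy sketch below assumes DIAMOND uses λ-returns in the style of IRIS/Dreamer agents. γ = 0.985 matches the discount_factor default in the DiamondConfig reference, while λ = 0.95 and the function name are assumptions.

```python
import numpy as np

def lambda_returns(rewards, dones, values, gamma=0.985, lam=0.95):
    """Lambda-returns for an imagined trajectory of length H.

    rewards, dones: shape (H,); values: shape (H + 1,), where values[H]
    is the critic's bootstrap value for the final imagined state.
    """
    H = len(rewards)
    returns = np.zeros(H)
    next_return = values[H]
    for t in reversed(range(H)):
        not_done = 1.0 - dones[t]
        # G_t = r_t + gamma * (1 - d_t) * ((1 - lam) * V(s_{t+1}) + lam * G_{t+1})
        next_return = rewards[t] + gamma * not_done * (
            (1 - lam) * values[t + 1] + lam * next_return
        )
        returns[t] = next_return
    return returns

print(lambda_returns(np.ones(3), np.zeros(3), np.zeros(4), gamma=0.5, lam=1.0))  # [1.75 1.5  1.  ]
```

With λ = 1 this reduces to the discounted sum of imagined rewards plus the bootstrapped final value; with λ = 0 it is the one-step TD target.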

Config Reference

DiTConfig

| Field | Default | Description |
|---|---|---|
| IMG_SIZE | 32 | Image resolution |
| PATCH | 4 | Patch size for tokenization |
| WIDTH | 384 | Transformer embedding dimension |
| DEPTH | 6 | Number of DiT blocks |
| HEADS | 6 | Number of attention heads |
| BATCH | 128 | Training batch size |
| EPOCHS | 100 | Training epochs |
| TIMESTEPS | 1000 | Diffusion timesteps |
| BETA_START | 1e-4 | Noise schedule start |
| BETA_END | 0.02 | Noise schedule end |

DiamondConfig

| Field | Default | Description |
|---|---|---|
| game | Breakout-v5 | Atari game name |
| obs_size | 64 | Frame resolution |
| diffusion_channels | [64,64,64,64] | UNet channels per resolution level |
| diffusion_cond_dim | 256 | Conditioning embedding dimension |
| num_sampling_steps | 3 | Euler sampling steps |
| imagination_horizon | 15 | Actor-critic rollout length |
| discount_factor | 0.985 | Discount factor γ |
| learning_rate | 1e-4 | Learning rate |
| num_epochs | 1000 | Total training epochs |

Sampling Methods

| Method | Steps | Quality | Speed |
|---|---|---|---|
| DDPM | 1000 | Best | Slow |
| DDIM | 50–100 | Good | Fast |
| Euler | 3–10 | Fair for RL | Very fast |
| DPM-Solver | 10–20 | Excellent | Fast |

Comparison: DiT vs CNN-Based Diffusion

| Aspect | U-Net (DDPM) | DiT (Transformer) |
|---|---|---|
| Architecture | CNN with skip connections | ViT |
| Global attention | Limited | Full |
| Scalability | Medium | High |
| Quality | Good | Slightly better |
| Compute | Efficient | Higher |

Common Pitfalls

Slow sampling

DDPM sampling with 1000 steps needs 1000 network evaluations per generated sample, which is too slow inside an RL training loop.

Fixes:

  • Use Euler sampler with 3–10 steps (DIAMOND default)

  • Use DDIM for higher quality with ~50 steps
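A minimal deterministic DDIM sampler (η = 0, Song et al., 2021) in NumPy. It reuses a noise-prediction model trained with the DDPM objective but visits only num_steps of the T timesteps. The eps_model(x, t) signature and the evenly spaced timestep subset are illustrative assumptions, not TorchWM's sampler.

```python
import numpy as np

def ddim_sample(eps_model, x_T, alpha_bar, num_steps=50):
    """Deterministic DDIM (eta = 0) over an evenly spaced subset of timesteps."""
    T = len(alpha_bar)
    timesteps = np.linspace(T - 1, 0, num_steps).round().astype(int)
    x = x_T
    for i, t in enumerate(timesteps):
        eps = eps_model(x, t)
        # Estimate x_0 from the current sample and predicted noise
        x0_pred = (x - np.sqrt(1 - alpha_bar[t]) * eps) / np.sqrt(alpha_bar[t])
        if i + 1 < len(timesteps):
            ab_next = alpha_bar[timesteps[i + 1]]
            # Re-noise x0_pred to the next kept timestep, reusing the same eps
            x = np.sqrt(ab_next) * x0_pred + np.sqrt(1 - ab_next) * eps
        else:
            x = x0_pred                       # final step returns the clean estimate
    return x

alpha_bar = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))
y = np.array([0.5, -1.0, 2.0])                # idealized single data point
oracle = lambda x, t: (x - np.sqrt(alpha_bar[t]) * y) / np.sqrt(1 - alpha_bar[t])
sample = ddim_sample(oracle, np.random.default_rng(0).standard_normal(3), alpha_bar)
```

With an oracle noise predictor for a single data point, every x_0 estimate equals y, so the sampler returns y after 50 model calls instead of 1000.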

Training instability

Diffusion UNets can diverge to NaN losses during training.

Fixes:

  • Enable gradient clipping

  • Use EDM preconditioning
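A minimal PyTorch training step with gradient-norm clipping. The small MLP stands in for the diffusion UNet, and max_grad_norm = 1.0 is an illustrative value, not a TorchWM default; skipping the optimizer step when the loss or gradient norm is non-finite is an extra safeguard, not something the list above prescribes.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 64), nn.SiLU(), nn.Linear(64, 16))  # stand-in for the UNet
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

def train_step(x, target, max_grad_norm=1.0):
    loss = nn.functional.mse_loss(model(x), target)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    # Rescale all gradients so their global L2 norm is at most max_grad_norm;
    # returns the total norm measured before clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    if torch.isfinite(loss) and torch.isfinite(grad_norm):
        opt.step()                      # skip the update if anything is NaN or inf
    return loss.item(), grad_norm.item()

loss, norm = train_step(torch.randn(8, 16), 1e3 * torch.randn(8, 16))
```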

Blurry generations

Too few sampling steps can produce blurry results.

Fixes:

  • Increase num_sampling_steps

  • Add classifier-free guidance

  • Train with more diffusion timesteps

References

  • Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. NeurIPS 2020.

  • Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. ICCV 2023.

  • Alonso, E., et al. (2024). Diffusion for World Modeling: Visual Details Matter in Atari (DIAMOND). NeurIPS 2024.

  • Song, J., Meng, C., & Ermon, S. (2021). Denoising Diffusion Implicit Models. ICLR 2021.

  • Karras, T., et al. (2022). Elucidating the Design Space of Diffusion-Based Generative Models. NeurIPS 2022.