# DiT: Diffusion Transformer and Diffusion Models
This page covers the diffusion-based models in TorchWM: DDPM for image generation,
DiT for scalable transformer-based diffusion, and DIAMOND for diffusion world models
in reinforcement learning.
Based on the following papers:
- [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) (Ho et al., 2020)
- [Scalable Diffusion Models with Transformers](https://arxiv.org/abs/2212.09748) (Peebles & Xie, 2023)
- [Diffusion for World Modeling: Visual Details Matter in Atari](https://arxiv.org/abs/2405.12399) (Alonso et al., 2024), which introduces DIAMOND (DIffusion As a Model Of eNvironment Dreams)
```{contents} Contents
:depth: 3
```
## Overview
Diffusion models learn to generate data by reversing a gradual noising process.
In TorchWM, diffusion models serve two purposes:
1. **Image/video generation** (DDPM, DiT): High-quality unconditional or conditional sample generation.
2. **World models for RL** (DIAMOND): Diffusion-based dynamics models that predict future observations.
```{mermaid}
graph LR
A["Data x₀"] --> B["Forward: q(x_t|x₀)"]
B --> C["..."]
C --> D["x_T ~ N(0, I)"]
D --> E["Reverse: p_θ(x_{t-1}|x_t)"]
E --> F["..."]
F --> G["Generated x₀"]
```
## DDPM: Denoising Diffusion Probabilistic Models
### Forward Process
The forward (diffusion) process gradually adds Gaussian noise to data over `T`
timesteps according to a fixed variance schedule:
```{math}
q(x_t | x_{t-1}) = \mathcal{N}\left(\sqrt{1 - \beta_t}\, x_{t-1},\; \beta_t I\right)
```
We can sample `x_t` directly from `x_0`:
```{math}
q(x_t | x_0) = \mathcal{N}\left(\sqrt{\bar{\alpha}_t}\, x_0,\; (1 - \bar{\alpha}_t) I\right)
```
where `α_t = 1 - β_t` and `ᾱ_t = ∏_{s=1}^{t} α_s`. Using the
reparameterization trick:
```{math}
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \quad
\epsilon \sim \mathcal{N}(0, I)
```
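The closed-form sampling above can be sketched in a few lines of PyTorch. This is a minimal illustration, not TorchWM's implementation: the helper names `make_linear_schedule` and `q_sample` are hypothetical, and the linear schedule reuses the `BETA_START`, `BETA_END`, and `TIMESTEPS` defaults listed in the config reference.

```python
import torch


def make_linear_schedule(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Return (betas, alphas, alpha_bar) for a linear variance schedule."""
    betas = torch.linspace(beta_start, beta_end, num_steps)
    alphas = 1.0 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)  # running product of alpha_s
    return betas, alphas, alpha_bar


def q_sample(x0, t, alpha_bar, eps=None):
    """Draw x_t ~ q(x_t | x_0) for a batch of 0-indexed integer timesteps t."""
    if eps is None:
        eps = torch.randn_like(x0)
    # Reshape per-sample coefficients so they broadcast over (C, H, W).
    signal = alpha_bar[t].sqrt().view(-1, 1, 1, 1)
    noise = (1.0 - alpha_bar[t]).sqrt().view(-1, 1, 1, 1)
    return signal * x0 + noise * eps, eps


betas, alphas, alpha_bar = make_linear_schedule()
x0 = torch.randn(8, 3, 32, 32)       # placeholder batch standing in for images
t = torch.randint(0, 1000, (8,))     # one timestep per sample
xt, eps = q_sample(x0, t, alpha_bar)
```

Because `alpha_bar` decays towards zero, `xt` at the last timestep is almost pure Gaussian noise, which is what lets sampling start from `N(0, I)`.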
### Reverse Process
The reverse process learns to denoise. Starting from pure noise `x_T ∼ N(0, I)`, each step samples `x_{t-1}` from a learned Gaussian:
```{math}
p_θ(x_{t-1} | x_t) = \mathcal{N}\left(
\mu_θ(x_t, t),\; \sigma_t^2 I
\right)
```
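In Ho et al. (2020) the mean is not predicted directly. It is expressed through a noise-prediction network `ε_θ(x_t, t)`, and the variance is held fixed. A sketch of that parameterization in the symbols used above, where `σ_t² = β_t` is one of the two fixed choices the paper considers:

```{math}
\mu_θ(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t -
\frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_θ(x_t, t) \right),
\qquad \sigma_t^2 = \beta_t
```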
### Training Objective
The simplified DDPM loss trains the model to predict the noise `ε` at each timestep:
```{math}
\mathcal{L}_{\text{DDPM}} = \mathbb{E}_{t, x_0, \epsilon}
\left[ \left\| \epsilon - \epsilon_θ\left(\sqrt{\bar{\alpha}_t}\, x_0 +
\sqrt{1 - \bar{\alpha}_t}\, \epsilon,\; t\right) \right\|^2 \right]
```
```python
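import torch
import torch.nn.functional as F

# Assumes `model`, `batch`, `T`, and a 1-D tensor `alpha_bar` of length T are defined.
x0 = batch                                        # clean images, shape (B, C, H, W)
t = torch.randint(0, T, (x0.shape[0],))           # one random timestep per sample
eps = torch.randn_like(x0)                        # Gaussian noise
a = alpha_bar[t].sqrt().view(-1, 1, 1, 1)         # sqrt(ᾱ_t), broadcastable
s = (1 - alpha_bar[t]).sqrt().view(-1, 1, 1, 1)   # sqrt(1 - ᾱ_t)
xt = a * x0 + s * eps                             # sample from q(x_t | x_0)
eps_pred = model(xt, t)                           # predict the noise
loss = F.mse_loss(eps_pred, eps)                  # simple noise-prediction loss
loss.backward()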
```
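For a version of this step that runs end to end, the sketch below swaps in a tiny convolutional network for the real backbone. `ToyDenoiser` and `ddpm_loss` are hypothetical names introduced only for illustration, and the random batch stands in for real images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)


class ToyDenoiser(nn.Module):
    """Hypothetical stand-in for a U-Net or DiT: feeds t as an extra input channel."""

    def __init__(self, channels=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels + 1, 32, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(32, channels, 3, padding=1),
        )

    def forward(self, x, t):
        # Broadcast the normalized timestep to a constant (B, 1, H, W) feature map.
        t_map = (t.float() / T).view(-1, 1, 1, 1).expand(-1, 1, *x.shape[2:])
        return self.net(torch.cat([x, t_map], dim=1))


def ddpm_loss(model, x0):
    """Simplified DDPM objective: MSE between true and predicted noise."""
    t = torch.randint(0, T, (x0.shape[0],))
    eps = torch.randn_like(x0)
    a = alpha_bar[t].sqrt().view(-1, 1, 1, 1)
    s = (1.0 - alpha_bar[t]).sqrt().view(-1, 1, 1, 1)
    return F.mse_loss(model(a * x0 + s * eps, t), eps)


model = ToyDenoiser()
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
x0 = torch.randn(16, 3, 32, 32)  # placeholder batch
loss = ddpm_loss(model, x0)
opt.zero_grad()
loss.backward()
opt.step()
```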
### Sampling
```python
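import torch

# Assumes `model`, `shape`, `T`, and 1-D tensors `beta`, `alpha`, `alpha_bar`, `sigma` are defined.
x = torch.randn(shape)                            # start from pure noise x_T
for t in reversed(range(T)):
    eps_pred = model(x, t)                        # predict noise
    # Posterior mean: (x_t - β_t / sqrt(1 - ᾱ_t) · ε_θ) / sqrt(α_t)
    x = (x - beta[t] / torch.sqrt(1 - alpha_bar[t]) * eps_pred) / torch.sqrt(alpha[t])
    if t > 0:
        x = x + sigma[t] * torch.randn_like(x)    # no noise is added at the final step
# x now holds the generated image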
```
## DiT: Diffusion Transformer
DiT replaces the U-Net backbone with a Vision Transformer for noise prediction,
providing better scalability and global context.
### Architecture
```{mermaid}
graph TD
A["Noisy image x_t"] --> B["Patchify: (C,H,W) → (N,D)"]
C["Timestep t"] --> D["Timestep embedding"]
D --> E["AdaLN modulation"]
B --> F["DiT Block × depth"]
E --> F
F --> G["..."]
G --> H["Output head: (N,D) → (C,H,W)"]
H --> I["Predicted noise ε_θ"]
```
### Patch Embedding
The input `x_t ∈ ℝ^{C×H×W}` is split into patches of size `P` and linearly embedded:
```{math}
x_t \in \mathbb{R}^{C \times H \times W}
\;\to\; \text{tokens} \in \mathbb{R}^{(H/P \cdot W/P) \times D}
```
For a 32×32 CIFAR-10 image with `P=4`: 64 tokens of dimension `D`.
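One common way to implement patchify plus linear embedding is a strided convolution whose kernel size and stride both equal `P`. The sketch below assumes that approach; the `PatchEmbed` class is illustrative and not necessarily how TorchWM's `DiT` does it. `D = 384` matches the `WIDTH` default.

```python
import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    """Split an image into non-overlapping P x P patches and embed each one."""

    def __init__(self, in_channels=3, patch=4, dim=384):
        super().__init__()
        # A conv with kernel == stride == P is a per-patch linear projection.
        self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch, stride=patch)

    def forward(self, x):
        x = self.proj(x)                      # (B, D, H/P, W/P)
        return x.flatten(2).transpose(1, 2)   # (B, N, D) with N = (H/P) * (W/P)


tokens = PatchEmbed()(torch.randn(2, 3, 32, 32))
print(tokens.shape)  # torch.Size([2, 64, 384])
```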
### Timestep Conditioning (AdaLN)
DiT uses **Adaptive Layer Normalization** (AdaLN) to condition on the diffusion
timestep. The timestep embedding predicts the scale `γ` and shift `β` for each
block's layer norm:
```{math}
\text{AdaLN}(h, t) = γ(t) \cdot \text{LayerNorm}(h) + β(t)
```
| Variant | Scheme | Extra parameters | Relative speed |
|---------|--------|------------------|----------------|
| **AdaLN** | γ, β per block | Low | Baseline |
| **AdaLN-Single** | Shared γ, β | Very low | Fastest |
| **AdaLN-Diagonal** | γ, β + diagonal scaling | Medium | Medium |
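The per-block AdaLN variant can be sketched as a layer norm without learned affine parameters, followed by a scale and shift regressed from the timestep embedding. The `AdaLN` module below is a hypothetical illustration of the formula above, not TorchWM's class. Some implementations use `1 + γ` as the scale so that zero-initialized modulation starts as plain layer norm.

```python
import torch
import torch.nn as nn


class AdaLN(nn.Module):
    """LayerNorm whose scale γ and shift β are predicted from a conditioning vector."""

    def __init__(self, dim, cond_dim):
        super().__init__()
        # The affine parameters come from the condition, so the norm has none itself.
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.to_scale_shift = nn.Linear(cond_dim, 2 * dim)

    def forward(self, h, t_emb):
        gamma, beta = self.to_scale_shift(t_emb).chunk(2, dim=-1)  # each (B, D)
        # Unsqueeze so one (γ, β) pair per sample is shared across all N tokens.
        return gamma.unsqueeze(1) * self.norm(h) + beta.unsqueeze(1)


ada = AdaLN(dim=384, cond_dim=384)
h = torch.randn(2, 64, 384)      # (B, N, D) tokens
t_emb = torch.randn(2, 384)      # (B, cond_dim) timestep embedding
out = ada(h, t_emb)
```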
### DiT Block
```{math}
\begin{aligned}
h &\leftarrow h + \text{MSA}(\text{AdaLN}(h, t)) \\
h &\leftarrow h + \text{MLP}(\text{AdaLN}(h, t))
\end{aligned}
```
DiT blocks use no cross-attention: the model is typically unconditional or class-conditional, with all conditioning entering through AdaLN.
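The two residual updates above map onto a compact PyTorch module. This is a sketch under assumptions: `DiTBlock` is an illustrative name, the MLP ratio of 4 and GELU activation are conventional ViT choices rather than confirmed TorchWM defaults, and the gating used by the adaLN-Zero variant in Peebles & Xie is omitted.

```python
import torch
import torch.nn as nn


class DiTBlock(nn.Module):
    """Pre-norm transformer block with AdaLN conditioning on both sub-layers."""

    def __init__(self, dim=384, heads=6, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * dim, dim),
        )
        # One projection yields (γ1, β1) for attention and (γ2, β2) for the MLP.
        self.modulation = nn.Linear(dim, 4 * dim)

    def forward(self, h, t_emb):
        g1, b1, g2, b2 = self.modulation(t_emb).unsqueeze(1).chunk(4, dim=-1)
        a = g1 * self.norm1(h) + b1                        # AdaLN(h, t)
        h = h + self.attn(a, a, a, need_weights=False)[0]  # h <- h + MSA(...)
        m = g2 * self.norm2(h) + b2                        # AdaLN(h, t)
        return h + self.mlp(m)                             # h <- h + MLP(...)


block = DiTBlock()
tokens = torch.randn(2, 64, 384)   # (B, N, D)
t_emb = torch.randn(2, 384)        # (B, D) timestep embedding
out = block(tokens, t_emb)
```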
### Training
```python :class: thebe
from torchwm import DiTConfig, get_dit_config

cfg = get_dit_config(
    DATASET="CIFAR10",
    BATCH=128,
    EPOCHS=100,
    IMG_SIZE=32,
    WIDTH=384,
    DEPTH=6,
)
```
### Sampling (Generation)
```python :class: thebe
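# Start from random noise and iteratively denoise.
# Assumes `model`, `shape`, `T`, and schedule tensors `beta`, `alpha`, `alpha_bar` are defined.
x = torch.randn(shape)
for t in reversed(range(T)):
    eps = model(x, t)  # predict noise
    x = (x - beta[t] / torch.sqrt(1 - alpha_bar[t]) * eps) / torch.sqrt(alpha[t])
    if t > 0:
        x = x + torch.sqrt(beta[t]) * torch.randn_like(x)  # σ_t² = β_t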
```
### Classifier-Free Guidance
For conditional generation:
```{math}
\epsilon_\mathrm{cond} = (1 + w) \cdot \epsilon_\mathrm{model}(\mathbf{x}_t, c)
- w \cdot \epsilon_\mathrm{model}(\mathbf{x}_t, \emptyset)
```
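where `w ≥ 0` is the guidance weight (typically 1–10) and `∅` denotes the null (dropped) condition. Setting `w = 0` recovers the purely conditional prediction.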
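In code, guidance amounts to two forward passes and a linear combination. The sketch below assumes a model callable with the hypothetical signature `model(x_t, t, cond)`. The function name `cfg_noise` and the toy model are illustrative only.

```python
import torch


def cfg_noise(model, x_t, t, cond, null_cond, w):
    """Combine conditional and unconditional noise predictions with weight w."""
    eps_cond = model(x_t, t, cond)         # prediction given the condition c
    eps_uncond = model(x_t, t, null_cond)  # prediction given the null condition
    return (1.0 + w) * eps_cond - w * eps_uncond


def toy_model(x, t, c):
    """Toy stand-in that returns the conditioning value at every pixel."""
    return torch.full_like(x, float(c))


x_t = torch.randn(2, 3, 8, 8)
guided = cfg_noise(toy_model, x_t, t=10, cond=1.0, null_cond=0.0, w=2.0)
unguided = cfg_noise(toy_model, x_t, t=10, cond=1.0, null_cond=0.0, w=0.0)
```

With the toy model, `guided` extrapolates past the conditional value (here to 3.0), which is the intended effect: guidance pushes the prediction further away from the unconditional one.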
## DIAMOND: Diffusion World Models
DIAMOND applies diffusion models to world modeling for reinforcement learning.
Instead of predicting latent states (Dreamer) or discrete tokens (IRIS), it
predicts future **observations** using a diffusion denoising process.
### Architecture
```{mermaid}
graph TD
A["Past 4 frames"] --> B["Cond. encoder"]
C["Action a_t"] --> B
B --> D["Conditioning embedding c"]
E["Noisy next frame x_t + ε"] --> F["Diffusion UNet"]
D --> F
F --> G["Denoised next frame x̂_t"]
G --> H["Reward/Termination model"]
H --> I["r̂_t, γ̂_t"]
G --> J["Actor-Critic"]
```
### Key Components
| Component | File | Description |
|-----------|------|-------------|
| `DDPM` | `diffusion/DDPM.py` | Basic DDPM model and scheduler |
| `DiT` | `diffusion/DiT.py` | Diffusion Transformer |
| `DiffusionUNet` | `diffusion/diamond_diffusion.py` | Conditional UNet for DIAMOND |
| `EDMPreconditioner` | `diffusion/diamond_diffusion.py` | EDM-style noise preconditioning |
| `EulerSampler` | `diffusion/diamond_diffusion.py` | Fast ODE-based sampling |
| `RewardTerminationModel` | `diffusion/reward_termination.py` | Reward and termination prediction |
| `ActorCriticNetwork` | `diffusion/actor_critic.py` | Policy and value in imagination |
### EDM Preconditioning
DIAMOND uses the EDM formulation of Karras et al. (2022):
```{math}
D_θ(x, σ) = c_{\text{skip}}(σ) \cdot x + c_{\text{out}}(σ) \cdot
F_θ(c_{\text{in}}(σ) \cdot x, c_{\text{noise}}(σ))
```
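Here `F_θ` is the raw network and `D_θ` the resulting denoiser. The coefficient `c_in` rescales the
input to unit variance, while `c_skip` and `c_out` blend the noisy input with the network output so
that the training target stays well scaled across all noise levels.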
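The coefficients have closed forms in Karras et al. (2022). The sketch below uses those formulas with `σ_data = 0.5`, the value from that paper. Whether TorchWM's `EDMPreconditioner` uses the same default is an assumption, and the helper names `edm_coefficients` and `denoise` are hypothetical.

```python
import torch


def edm_coefficients(sigma, sigma_data=0.5):
    """Return (c_skip, c_out, c_in, c_noise) for a tensor of noise levels sigma."""
    total = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / total
    c_out = sigma * sigma_data / total.sqrt()
    c_in = 1.0 / total.sqrt()      # brings the noisy input to unit variance
    c_noise = sigma.log() / 4.0    # noise-level input fed to the network
    return c_skip, c_out, c_in, c_noise


def denoise(raw_net, x, sigma, sigma_data=0.5):
    """D(x, σ) = c_skip · x + c_out · F(c_in · x, c_noise) for x of shape (B, C, H, W)."""
    c_skip, c_out, c_in, c_noise = edm_coefficients(sigma, sigma_data)
    c_skip, c_out, c_in = (c.view(-1, 1, 1, 1) for c in (c_skip, c_out, c_in))
    return c_skip * x + c_out * raw_net(c_in * x, c_noise)


sigma = torch.tensor([0.5, 1.0])
x = torch.randn(2, 3, 8, 8)
# A network that outputs zeros isolates the skip path: D(x, σ) = c_skip · x.
out = denoise(lambda z, noise_level: torch.zeros_like(z), x, sigma)
```

At small `σ` the skip term dominates and the network only predicts a small correction. At large `σ` the network output dominates and effectively predicts the clean signal.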
### Sampling
DIAMOND uses **Euler sampling** with very few steps (default 3) for fast inference:
```python
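# Assumes `model`, `shape`, `cond`, and a decreasing 1-D tensor `sigmas` ending in 0.
x = torch.randn(shape) * sigmas[0]                      # sigmas[0] is sigma_max
for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
    denoised = model(x, sigma, cond)                    # D_θ(x, σ)
    d = (x - denoised) / sigma                          # ODE derivative dx/dσ
    x = x + d * (sigma_next - sigma)                    # Euler step
# x is the generated next frame (equal to the last `denoised` when sigmas[-1] == 0)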
```
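A self-contained version of this sampler needs a noise-level schedule. The sketch below uses the schedule from Karras et al. (2022). The values `sigma_min=2e-3`, `sigma_max=5.0`, and `rho=7.0` are assumptions for illustration, not confirmed TorchWM defaults, and `karras_sigmas` and `euler_sample` are hypothetical names.

```python
import torch


def karras_sigmas(num_steps, sigma_min=2e-3, sigma_max=5.0, rho=7.0):
    """Decreasing noise levels from sigma_max to sigma_min, with a final 0 appended."""
    ramp = torch.linspace(0, 1, num_steps)
    min_inv, max_inv = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    sigmas = (max_inv + ramp * (min_inv - max_inv)) ** rho
    return torch.cat([sigmas, torch.zeros(1)])


@torch.no_grad()
def euler_sample(denoiser, shape, cond=None, num_steps=3):
    """Integrate the probability-flow ODE with first-order Euler steps."""
    sigmas = karras_sigmas(num_steps)
    x = torch.randn(shape) * sigmas[0]
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        denoised = denoiser(x, sigma, cond)
        d = (x - denoised) / sigma          # derivative dx/dσ
        x = x + d * (sigma_next - sigma)    # step to the next noise level
    return x


def toy_denoiser(x, sigma, cond):
    """Toy denoiser that always predicts a constant image."""
    return torch.full_like(x, 0.25)


sample = euler_sample(toy_denoiser, (2, 3, 8, 8))
```

Because the last noise level is 0, the final Euler step lands exactly on the denoiser's last prediction, so returning `x` or the last `denoised` is equivalent.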
## Usage in TorchWM
### DiT quick start
```python
from torchwm import DiTConfig, get_dit_config
cfg = get_dit_config(DATASET="CIFAR10", BATCH=128, EPOCHS=100, IMG_SIZE=32)
```
### DIAMOND quick start
```python
from world_models.configs.diamond_config import DiamondConfig
cfg = DiamondConfig(preset="small") # small, medium, large
cfg.game = "Breakout-v5"
cfg.obs_size = 64
```
### DIAMOND CLI
```bash
torchwm train diamond --config world_models/configs/experiments/diamond.yaml \
preset=small seed=1
```
Or directly:
```bash
python -m world_models.training.train_diamond --game Breakout-v5 --preset small
```
### DIAMOND Training Loop
```
for epoch in range(num_epochs):
    # 1. Collect real experience
    for step in range(environment_steps_per_epoch):
        action = select_action(obs)
        next_obs, reward, done = env.step(action)
        buffer.add(obs, action, reward, done, next_obs)
        obs = next_obs
    # 2. Train diffusion world model
    batch = buffer.sample(batch_size)
    loss = train_diffusion_step(batch)
    # 3. Train reward/termination model
    loss_r = train_reward_model(batch)
    # 4. Train actor-critic in imagination
    for step in range(imagination_horizon):
        action = actor(obs, hidden)
        next_obs = diffusion_model(obs, action, cond)
        reward = reward_model(next_obs)
        hidden = update_hidden(hidden, action, next_obs)
    actor_loss, critic_loss = compute_ac_loss(imagination_trajectory)
```
## Config Reference
### DiTConfig
| Field | Default | Description |
|-------|---------|-------------|
| `IMG_SIZE` | 32 | Image resolution |
| `PATCH` | 4 | Patch size for tokenization |
| `WIDTH` | 384 | Transformer embedding dimension |
| `DEPTH` | 6 | Number of DiT blocks |
| `HEADS` | 6 | Number of attention heads |
| `BATCH` | 128 | Training batch size |
| `EPOCHS` | 100 | Training epochs |
| `TIMESTEPS` | 1000 | Diffusion timesteps |
| `BETA_START` | 1e-4 | Noise schedule start |
| `BETA_END` | 0.02 | Noise schedule end |
### DiamondConfig
| Field | Default | Description |
|-------|---------|-------------|
| `game` | `Breakout-v5` | Atari game name |
| `obs_size` | 64 | Frame resolution |
| `diffusion_channels` | [64,64,64,64] | UNet channel multipliers |
| `diffusion_cond_dim` | 256 | Conditioning embedding dim |
| `num_sampling_steps` | 3 | Euler sampling steps |
| `imagination_horizon` | 15 | Actor-critic rollout length |
| `discount_factor` | 0.985 | Discount factor γ |
| `learning_rate` | 1e-4 | Learning rate |
| `num_epochs` | 1000 | Total training epochs |
## Sampling Methods
| Method | Steps | Quality | Speed |
|--------|-------|---------|-------|
| DDPM | 1000 | Best | Slow |
| DDIM | 50–100 | Good | Fast |
| Euler | 3–10 | Fair for RL | Very fast |
| DPM-Solver | 10–20 | Excellent | Fast |
## Comparison: DiT vs CNN-Based Diffusion
| Aspect | U-Net (DDPM) | DiT (Transformer) |
|--------|-------------|-------------------|
| Architecture | CNN with skip connections | ViT |
| Global attention | Limited | Full |
| Scalability | Medium | High |
| Quality | Good | Slightly better |
| Compute | Efficient | Higher |
## Common Pitfalls
### Slow sampling
Ancestral DDPM sampling needs one network evaluation per timestep (1000 by default), which is too slow inside an RL training loop.
**Fixes:**
- Use Euler sampler with 3–10 steps (DIAMOND default)
- Use DDIM for higher quality with ~50 steps
### Training instability
Diffusion UNets can produce NaN losses or gradients during training.
**Fixes:**
- Enable gradient clipping
- Use EDM preconditioning
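Gradient clipping is a one-line addition between `backward()` and `step()`. The sketch below uses PyTorch's `torch.nn.utils.clip_grad_norm_`. The linear model, the inflated loss, and `max_norm=1.0` are placeholders chosen for illustration, not TorchWM settings.

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # placeholder for the diffusion network
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Deliberately inflated loss to produce large gradients.
loss = model(torch.randn(4, 8)).pow(2).mean() * 1e6
opt.zero_grad()
loss.backward()

# Rescales all gradients in place so their global L2 norm is at most max_norm.
# The return value is the total norm measured before clipping.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Skip the update entirely if the gradients were already non-finite.
if torch.isfinite(grad_norm):
    opt.step()
```

Logging `grad_norm` over time is a cheap way to spot instability before it turns into NaNs.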
### Blurry generations
Too few sampling steps can produce blurry results.
**Fixes:**
- Increase `num_sampling_steps`
- Add classifier-free guidance
- Train with more diffusion timesteps
## References
- Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. *NeurIPS 2020.*
- Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. *ICCV 2023.*
- Alonso, E., et al. (2024). Diffusion for World Modeling: Visual Details Matter in Atari. *NeurIPS 2024.*
- Song, J., Meng, C., & Ermon, S. (2021). Denoising Diffusion Implicit Models. *ICLR 2021.*
- Karras, T., et al. (2022). Elucidating the Design Space of Diffusion-Based Generative Models. *NeurIPS 2022.*