# IRIS: Transformers for Sample-Efficient World Models
IRIS (Imagination with auto-Regression over an Inner Speech) is the agent introduced in the paper
"Transformers are Sample-Efficient World Models" (Micheli et al., 2023); this page documents its implementation in TorchWM.
```{contents} Contents
:depth: 3
```
## Overview
IRIS reaches a **mean human-normalized score of 1.046 on Atari 100k with only ~2 hours of gameplay** (100k environment steps)
by learning its policy entirely in the imagination of a world model. Training repeats three steps:
1. **Train world model** from real interactions
2. **Generate imagined trajectories** in the latent space
3. **Train policy** purely on imagined data
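The "~2 hours" figure follows from the Atari 100k protocol, in which each environment step repeats the chosen action for 4 frames (frame skip) and the emulator runs at 60 frames per second:

```{math}
100{,}000 \text{ steps} \times 4 \ \tfrac{\text{frames}}{\text{step}} \div 60 \ \tfrac{\text{frames}}{\text{s}} \approx 6{,}667 \text{ s} \approx 1.85 \text{ h}
```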
## Architecture
### High-level diagram
IRIS has three components:

- **Discrete autoencoder**: CNN encoder (64×64 frames) → VQ-VAE (512-entry vocabulary, 16 tokens per frame) → transposed-CNN decoder
- **Autoregressive Transformer**: latent tokens → action token → next latent tokens → reward and termination heads
- **Actor-critic in imagination**: actor (CNN + LSTM) and critic (CNN + LSTM)
### VQ-VAE: Discrete Autoencoder
IRIS uses a Vector Quantized Variational Autoencoder (VQ-VAE), the same family of tokenizer used by Genie, to convert
continuous visual observations into sequences of discrete tokens.
```{mermaid}
graph LR
A["Image x"] --> B["CNN Encoder"]
B --> C["Continuous z_e(x)"]
C --> D["Vector Quantization"]
E["Codebook {e_k}"] --> D
D --> F["Discrete indices + z_q(x)"]
F --> G["CNN Decoder"]
G --> H["Reconstructed x̂"]
```
**Quantization:**
The encoder output `z_e(x)` is mapped, independently at each spatial position, to the nearest codebook vector:
```{math}
z_q(x) = e_k, \quad \text{where } k = \arg\min_j \|z_e(x) - e_j\|_2
```
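A minimal NumPy sketch of the nearest-codebook lookup, assuming the encoder output has already been flattened to one vector per spatial position (shape `(N, D)`). The function name `quantize` is hypothetical and not part of TorchWM.

```python
import numpy as np


def quantize(z_e: np.ndarray, codebook: np.ndarray):
    """Map each row of z_e (N, D) to its nearest codebook row (K, D).

    Returns the indices k (N,) and the quantized vectors z_q = codebook[k] (N, D).
    """
    # Squared L2 distance ||z - e||^2 = ||z||^2 - 2 z.e + ||e||^2, shape (N, K)
    dists = (
        np.sum(z_e**2, axis=1, keepdims=True)
        - 2.0 * z_e @ codebook.T
        + np.sum(codebook**2, axis=1)
    )
    # argmin of the squared distance equals argmin of the L2 distance
    k = np.argmin(dists, axis=1)
    return k, codebook[k]


# Example: 16 positions (a 4x4 grid), 512 codes of dimension 64
rng = np.random.default_rng(0)
codebook = rng.normal(size=(512, 64))
# Encoder outputs that sit very close to codes 3, 7, 7 and 42 (repeated 4 times)
z_e = codebook[[3, 7, 7, 42] * 4] + 0.01 * rng.normal(size=(16, 64))
indices, z_q = quantize(z_e, codebook)
print(indices[:4].tolist())  # [3, 7, 7, 42]
```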
**VQ-VAE Loss:**
```{math}
\mathcal{L}_{\text{VQ}} =
\underbrace{\|\hat{x} - x\|^2}_{\text{reconstruction}}
+ \underbrace{\|\text{sg}[z_e(x)] - e_k\|^2}_{\text{codebook loss}}
+ \beta \cdot \underbrace{\|z_e(x) - \text{sg}[e_k]\|^2}_{\text{commitment loss}}
```
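Because `sg[·]` (stop-gradient) only blocks gradients, the codebook and commitment terms have the same value in the forward pass; they differ only in which parameters they update. A NumPy sketch of the forward value, assuming mean-squared errors and β = 0.25 (the value used by van den Oord et al., 2017; whether TorchWM uses it is an assumption). `vq_loss_value` is a hypothetical helper.

```python
import numpy as np


def vq_loss_value(x, x_hat, z_e, z_q, beta=0.25):
    """Forward value of L_VQ with mean-squared errors.

    In an autodiff framework the codebook term uses sg[z_e] (updates only the
    codebook) and the commitment term uses sg[e_k] (updates only the encoder).
    NumPy has no gradients, so both terms evaluate to the same number here.
    """
    reconstruction = np.mean((x_hat - x) ** 2)
    codebook_term = np.mean((z_e - z_q) ** 2)    # ||sg[z_e] - e_k||^2
    commitment_term = np.mean((z_e - z_q) ** 2)  # ||z_e - sg[e_k]||^2
    return reconstruction + codebook_term + beta * commitment_term


x = np.zeros((3, 4, 4))
x_hat = np.full((3, 4, 4), 0.1)  # squared reconstruction error 0.01 per pixel
z_e = np.ones((16, 8))
z_q = np.full((16, 8), 0.5)      # squared quantization error 0.25
print(round(vq_loss_value(x, x_hat, z_e, z_q), 4))  # 0.01 + 0.25 + 0.25 * 0.25 = 0.3225
```

In PyTorch, `sg[·]` corresponds to `.detach()`, and gradients are usually passed through the non-differentiable lookup with the straight-through estimator `z_q = z_e + (z_q - z_e).detach()`.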
IRIS uses EMA (Exponential Moving Average) for codebook updates instead
of the codebook loss, producing more stable training.
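With EMA updates, each code is moved toward the average of the encoder outputs recently assigned to it. One common form, from Appendix A.1 of van den Oord et al. (2017), keeps an EMA of assignment counts `N_i` and of summed assigned vectors `m_i`, and sets `e_i = m_i / N_i`. The sketch below is a hypothetical NumPy helper, not TorchWM's implementation; the decay of 0.99 follows that paper, while the Laplace smoothing of the counts and the initialization are common implementation choices assumed here.

```python
import numpy as np


class EMACodebook:
    """Hypothetical EMA codebook in the style of van den Oord et al. (2017, App. A.1)."""

    def __init__(self, num_codes, dim, decay=0.99, eps=1e-5, seed=0):
        rng = np.random.default_rng(seed)
        self.embeddings = rng.normal(size=(num_codes, dim))
        # Start with N_i = 1 and m_i = e_i so that m_i / N_i reproduces the initial codes
        self.cluster_size = np.ones(num_codes)
        self.embed_sum = self.embeddings.copy()
        self.decay, self.eps = decay, eps

    def update(self, z_e, indices):
        """z_e: (N, D) encoder outputs; indices: (N,) nearest-code id for each row."""
        num_codes = len(self.embeddings)
        one_hot = np.eye(num_codes)[indices]  # (N, K) assignment matrix
        counts = one_hot.sum(axis=0)          # n_i: vectors assigned to code i
        sums = one_hot.T @ z_e                # sum of the vectors assigned to code i
        self.cluster_size = self.decay * self.cluster_size + (1 - self.decay) * counts
        self.embed_sum = self.decay * self.embed_sum + (1 - self.decay) * sums
        # Laplace smoothing keeps N_i > 0 for codes that receive no assignments
        total = self.cluster_size.sum()
        smoothed = (self.cluster_size + self.eps) / (total + num_codes * self.eps) * total
        self.embeddings = self.embed_sum / smoothed[:, None]


cb = EMACodebook(num_codes=4, dim=2)
z = np.array([[1.0, -1.0], [2.0, -1.0], [3.0, -1.0]])
for _ in range(2000):
    cb.update(z, np.array([0, 0, 0]))  # all three vectors assigned to code 0
print(np.round(cb.embeddings[0], 3))   # approaches the mean of z, i.e. [2, -1]
```

No gradient flows into the codebook in this scheme, which is why the codebook term of the loss can be dropped; the commitment term is still needed to keep the encoder close to the codes.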
### Discrete Autoencoder Architecture
The encoder maps a 64×64 RGB frame to **16 tokens** from a **512-entry** codebook:
```
Input: (3, 64, 64)
└─ Conv2D(3, 64, 4, stride 2, padding 1) → (64, 32, 32)
└─ ResBlock(64, 64)
└─ Conv2D(64, 64, 4, stride 2, padding 1) → (64, 16, 16)
└─ ResBlock(64, 64)
└─ Conv2D(64, 64, 4, stride 2, padding 1) → (64, 8, 8)
└─ ResBlock(64, 64)
└─ Conv2D(64, 64, 4, stride 2, padding 1) → (64, 4, 4)
└─ VQ layer → 4 × 4 grid = (16,) discrete indices
Output: 16 token indices (each ∈ {0, ..., 511})
```
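Each spatial size follows the standard convolution output formula `floor((size + 2 * padding - kernel) / stride) + 1`. A quick plain-Python check (the helper names are hypothetical) shows that four kernel-4, stride-2 convolutions turn a 64×64 frame into the 4×4 grid needed for 16 tokens only with padding 1; without padding the grid shrinks to 2×2.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a 2D convolution (square input, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_grid(size: int = 64, num_downsamples: int = 4, padding: int = 1) -> list:
    """Spatial sizes after each Conv2D(kernel=4, stride=2) in the encoder stack."""
    sizes = []
    for _ in range(num_downsamples):
        size = conv_out(size, kernel=4, stride=2, padding=padding)
        sizes.append(size)
    return sizes


print(encoder_grid(padding=1))  # [32, 16, 8, 4] -> 4 * 4 = 16 tokens
print(encoder_grid(padding=0))  # [31, 14, 6, 2] -> 2 * 2 = 4 tokens
```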
### Transformer World Model
The transformer is a GPT-style autoregressive model:
```
Params:
- vocab_size: 512 (observation tokens) + action_size (action tokens)
- embed_dim: 256
- num_layers: 10
- num_heads: 4
- seq_length: 20 timesteps × (16 observation tokens + 1 action token) = 340 tokens
Architecture:
Token Embedding → Positional Embedding → Transformer Blocks → LM Head (next observation token)
                                                           → Reward and termination heads
```
**Input sequence format** (per timestep):
```
[zₜ_0, zₜ_1, ..., zₜ_15, aₜ] → [zₜ₊₁_0, zₜ₊₁_1, ..., zₜ₊₁_15]
reward r̂ₜ and termination d̂ₜ are predicted from the output at the aₜ position
```
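A plain-Python sketch of how timesteps are flattened into one sequence, assuming the formulation of Micheli et al. (2023): each timestep contributes 16 observation tokens followed by 1 action token, and the reward and termination heads read the transformer output at each action position. The helper `interleave` and the `(kind, id)` encoding are hypothetical illustrations.

```python
def interleave(frame_tokens, actions):
    """Flatten per-timestep observation tokens and actions into one sequence.

    frame_tokens: list of T lists, each with 16 token ids in [0, 512)
    actions: list of T action ids
    Returns (kind, id) pairs; `kind` selects the embedding table ("obs" or "act"),
    so in this sketch action ids do not need to be offset past 512.
    """
    sequence = []
    for tokens, action in zip(frame_tokens, actions):
        sequence.extend(("obs", t) for t in tokens)
        sequence.append(("act", action))
    return sequence


T, K = 20, 16
frames = [[(t * K + k) % 512 for k in range(K)] for t in range(T)]
actions = [t % 4 for t in range(T)]
seq = interleave(frames, actions)
print(len(seq))  # 20 * (16 + 1) = 340

# Positions whose outputs feed the reward and termination heads
head_positions = [i for i, (kind, _) in enumerate(seq) if kind == "act"]
print(head_positions[:3])  # [16, 33, 50]
```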
### Actor-Critic
- **CNN + LSTM**: Processes reconstructed frames
- **λ-returns**: Balances bias and variance in value estimation
- **REINFORCE**: Policy gradient with baseline
- **Entropy bonus**: Maintains exploration
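The λ-return targets for the critic can be computed backwards over an imagined trajectory. A minimal sketch, assuming the recursion used by Micheli et al. (2023), where `d_t` is the predicted termination flag and the value of the last imagined state bootstraps the return; the defaults γ = 0.995 and λ = 0.95 come from the Key Hyperparameters list. `lambda_returns` is a hypothetical helper.

```python
def lambda_returns(rewards, values, dones, gamma=0.995, lam=0.95):
    """Lambda-returns for one imagined trajectory of length H.

    rewards[t] and dones[t] for t = 0..H-1; values has length H + 1 so that
    values[H] bootstraps the final step. Implements
    R_t = r_t + gamma * (1 - d_t) * ((1 - lam) * V_{t+1} + lam * R_{t+1}),
    with R_H = V_H.
    """
    horizon = len(rewards)
    assert len(values) == horizon + 1 and len(dones) == horizon
    returns = [0.0] * horizon
    next_return = values[horizon]  # R_H = V(x_H)
    for t in reversed(range(horizon)):
        # Blend the one-step bootstrap V_{t+1} with the longer return R_{t+1}
        bootstrap = (1 - lam) * values[t + 1] + lam * next_return
        next_return = rewards[t] + gamma * (1.0 - dones[t]) * bootstrap
        returns[t] = next_return
    return returns


print(lambda_returns([1.0, 1.0], [0.0, 0.0, 10.0], [0.0, 0.0], gamma=0.5, lam=1.0))  # [4.0, 6.0]
print(lambda_returns([1.0, 1.0], [0.0, 0.0, 10.0], [0.0, 0.0], gamma=0.5, lam=0.0))  # [1.0, 6.0]
```

With λ = 1 the target is the discounted sum of imagined rewards bootstrapped only at the horizon (low bias, high variance); with λ = 0 it reduces to one-step TD targets (low variance, more reliance on the critic).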
### Imagination Rollout
```python
# Imagine H steps: decode tokens to a frame, act, then sample the next frame's tokens.
# `tokens` starts from the encoding of a real frame; `hidden_state` is the actor's LSTM state.
# Object and method names are illustrative.
for h in range(imagination_horizon):
    frame = autoencoder.decode(tokens)                 # decode current tokens to pixels
    action, hidden_state = actor(frame, hidden_state)  # CNN + LSTM policy step
    reward = transformer.reward_head(tokens, action)   # predicted reward for (tokens, action)
    tokens = transformer.generate(tokens, action)      # sample next 16 tokens autoregressively
```
## Training
### Staged training schedule
| Component | Start Epoch | Description |
|-----------|-------------|-------------|
| Autoencoder | 5 | Learn frame compression first |
| Transformer | 15 | Learn dynamics once tokens are good |
| Actor-Critic | 35 | Learn policy in imagination |
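The staggered schedule in the table above can be expressed as a simple gate. This is a hypothetical sketch: it assumes each component keeps training once it has started and that the start epoch is inclusive (`epoch >= start`); the exact boundary convention in TorchWM is not confirmed here.

```python
# Start epochs from the table above; the inclusive ">=" boundary is an assumption.
START_EPOCH = {"autoencoder": 5, "transformer": 15, "actor_critic": 35}


def active_components(epoch: int) -> list:
    """Components that receive gradient updates at this epoch."""
    return [name for name, start in START_EPOCH.items() if epoch >= start]


for epoch in (1, 5, 20, 35):
    print(epoch, active_components(epoch))
```

Delaying the transformer until the autoencoder produces stable tokens, and the actor-critic until the transformer can imagine plausible dynamics, avoids fitting each stage to the noisy outputs of an untrained predecessor.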
### Key Hyperparameters
- **Frame size**: 64x64
- **Tokens per frame**: 16 (from 512 vocabulary)
- **Transformer sequence length**: 20 timesteps
- **Imagination horizon**: 20 steps
- **Discount (γ)**: 0.995
- **λ for λ-return**: 0.95
## Usage in TorchWM
### Quick start
```python
import torch
import torchwm
agent = torchwm.create_model(
"iris",
action_size=4,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
```
### Using config directly
```python
from torchwm import IRISConfig
config = IRISConfig()
# Autoencoder
config.vocab_size = 512
config.tokens_per_frame = 16
# Transformer
config.transformer_layers = 10
config.transformer_embed_dim = 256
# Training
config.total_epochs = 600
config.env_steps_per_epoch = 200
config.env = "ALE/Pong-v5"
```
### CLI
```bash
torchwm train iris --env ALE/Pong-v5 --device cuda
```
To run the training module directly, for example from custom research code:
```bash
python -m world_models.training.train_iris --game "ALE/Pong-v5"
```
## Config Reference
```python
from world_models.configs.iris_config import IRISConfig
config = IRISConfig()
# Autoencoder
config.vocab_size = 512 # Codebook size
config.tokens_per_frame = 16 # Tokens per frame
config.token_embedding_dim = 512
config.encoder_channels = 64
# Transformer
config.transformer_layers = 10
config.transformer_embed_dim = 256
config.transformer_heads = 4
config.transformer_timesteps = 20
# Training schedule
config.start_autoencoder_after = 5
config.start_transformer_after = 15
config.start_actor_critic_after = 35
config.total_epochs = 600
# Atari 100k
config.atari_100k = True
config.env = "ALE/Pong-v5"
config.max_env_steps = 100000
```
## Benchmark Results
Atari 100k results as reported by Micheli et al. (2023) (HNS = human-normalized score):
| Metric | IRIS | SPR | DrQ | CURL | SimPLe |
|--------|------|-----|-----|------|--------|
| Mean HNS | **1.046** | 0.616 | 0.465 | 0.261 | 0.332 |
| Superhuman games | **10/26** | 6/26 | 3/26 | 2/26 | 1/26 |
## Common Pitfalls
### Codebook collapse
Symptom: most codebook entries go unused, so frames are encoded with only a small subset of the 512 tokens.
**Fixes:**
- Use EMA codebook updates (default in IRIS)
- Lower commitment loss weight `β`
- Add codebook reset: re-initialize unused codes
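Codebook usage is cheap to monitor. The NumPy sketch below uses hypothetical helpers, not TorchWM functions: one reports the fraction of codes used in a batch and the perplexity of the code distribution (close to 512 when codes are used evenly, close to 1 under collapse); the other re-initializes dead codes from randomly chosen encoder outputs, one common form of codebook reset.

```python
import numpy as np


def codebook_usage(indices: np.ndarray, num_codes: int = 512):
    """Fraction of codes used and perplexity of the code distribution in a batch."""
    counts = np.bincount(indices.ravel(), minlength=num_codes)
    probs = counts / counts.sum()
    nonzero = probs[probs > 0]
    perplexity = float(np.exp(-np.sum(nonzero * np.log(nonzero))))
    return float(np.mean(counts > 0)), perplexity


def reset_dead_codes(codebook, indices, z_e, rng):
    """Replace codes with zero assignments by random encoder outputs from the batch."""
    counts = np.bincount(indices.ravel(), minlength=len(codebook))
    dead = np.flatnonzero(counts == 0)
    if dead.size:
        picks = rng.integers(0, len(z_e), size=dead.size)
        codebook = codebook.copy()
        codebook[dead] = z_e[picks]
    return codebook


rng = np.random.default_rng(0)
collapsed = np.tile(np.arange(4), 64)  # 256 assignments using only 4 of 512 codes
usage, ppl = codebook_usage(collapsed)
print(usage, round(ppl, 3))            # usage 4/512, perplexity about 4
new_codebook = reset_dead_codes(np.zeros((512, 8)), collapsed, np.ones((256, 8)), rng)
```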
### Transformer memory
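The context holds 20 timesteps × (16 observation tokens + 1 action token) = 340 tokens, and self-attention memory grows quadratically with context length.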
**Fixes:**
- Use gradient checkpointing
- Reduce context length
### Slow autoregressive generation
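Autoregressive generation predicts one token per transformer forward pass, so imagining a single frame takes 16 sequential steps and an imagination horizon of H steps takes 16·H.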
**Fixes:**
- Use KV caching for transformer inference
- Reduce the number of imagination steps
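KV caching works because, under causal attention, the keys and values of earlier tokens never change, so each new token only needs its own projections instead of reprojecting the whole prefix. The single-head NumPy sketch below is hypothetical and far simpler than a multi-layer transformer; it checks that cached, token-by-token attention matches full recomputation.

```python
import numpy as np


def attention(q, k, v):
    """Scaled dot-product attention for one query q (D,) over keys k (T, D), values v (T, D)."""
    scores = k @ q / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max())  # numerically stable softmax
    weights /= weights.sum()
    return weights @ v


class KVCache:
    """Stores keys/values of past tokens so each step projects only the new token."""

    def __init__(self):
        self.keys, self.values = [], []

    def step(self, x, w_q, w_k, w_v):
        self.keys.append(x @ w_k)
        self.values.append(x @ w_v)
        return attention(x @ w_q, np.stack(self.keys), np.stack(self.values))


def full_causal_attention(xs, w_q, w_k, w_v):
    """Reproject the whole prefix at every step (no cache), for comparison."""
    out = []
    for t in range(len(xs)):
        prefix = xs[: t + 1]
        out.append(attention(xs[t] @ w_q, prefix @ w_k, prefix @ w_v))
    return np.stack(out)


rng = np.random.default_rng(0)
d = 8
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))
xs = rng.normal(size=(17, d))  # e.g. 16 observation tokens + 1 action token
cache = KVCache()
cached = np.stack([cache.step(x, w_q, w_k, w_v) for x in xs])
print(np.allclose(cached, full_causal_attention(xs, w_q, w_k, w_v)))  # True
```

The number of sequential steps is unchanged; caching only removes the repeated projection of the prefix, which is why reducing the horizon remains the other lever.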
## References
- Micheli, V., Alonso, E., & Fleuret, F. (2023). Transformers are Sample-Efficient World Models. *ICLR 2023.*
- van den Oord, A., Vinyals, O., & Kavukcuoglu, K. (2017). Neural Discrete Representation Learning. *NeurIPS 2017.*