# IRIS: Transformers for Sample-Efficient World Models
IRIS (Imagination with auto-Regression over an Inner Speech) is an implementation of the paper
"Transformers are Sample-Efficient World Models" (Micheli et al., 2023).
## Key Idea
IRIS achieves **human-level performance on Atari with only ~2 hours of gameplay** (the 100k environment steps of the Atari 100k benchmark)
by training its policy entirely in the imagination of a learned world model:
1. **Train world model** from real interactions
2. **Generate imagined trajectories** in the latent space
3. **Train policy** purely on imagined data
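The three steps alternate once per epoch. The skeleton below is a framework-free sketch of that loop, written under stated assumptions: `iris_training_loop` and its three callables are hypothetical names and are not part of TorchWM. Warm-start delays and per-epoch step counts are left out.

```python
from typing import Callable, List

Trajectory = list  # placeholder type for one collected segment of real experience


def iris_training_loop(
    num_epochs: int,
    collect: Callable[[], Trajectory],
    train_world_model: Callable[[List[Trajectory]], None],
    train_policy_in_imagination: Callable[[List[Trajectory]], None],
) -> List[Trajectory]:
    """Run the three IRIS phases once per epoch and return the replay buffer."""
    replay_buffer: List[Trajectory] = []
    for _ in range(num_epochs):
        # 1. Real interaction: the current policy acts in the environment.
        replay_buffer.append(collect())
        # 2. Fit the world model (autoencoder + Transformer) on real data only.
        train_world_model(replay_buffer)
        # 3. Update the actor-critic on imagined rollouts: after the starting
        #    frame, every step is produced by the world model, not the environment.
        train_policy_in_imagination(replay_buffer)
    return replay_buffer


# Usage with trivial stubs that only record what they were given.
log = []
buffer = iris_training_loop(
    num_epochs=2,
    collect=lambda: ["transition"],
    train_world_model=lambda buf: log.append(("world_model", len(buf))),
    train_policy_in_imagination=lambda buf: log.append(("actor_critic", len(buf))),
)
print(log)  # [('world_model', 1), ('actor_critic', 1), ('world_model', 2), ('actor_critic', 2)]
```

The key design point is that real environment steps are used only in phase 1. Everything the actor-critic learns comes from the world model.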
## Architecture
- **Discrete autoencoder**: CNN encoder (64x64 frames) → VQ-VAE quantizer (512-token vocabulary, 16 tokens per frame) → transposed-CNN decoder
- **Autoregressive Transformer**: latent tokens → action token → next latent tokens, with reward and termination heads
- **Actor-critic in imagination**: actor (CNN + LSTM) and critic (CNN + LSTM)
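The Transformer reads frames and actions as one flat token stream. The sketch below shows that layout, assuming each timestep contributes its 16 frame tokens followed by one action token. Tagging tokens as `"obs"` or `"act"` (so the two kinds can use separate embedding tables) and the helper name `interleave` are illustrative choices, not TorchWM's API.

```python
from typing import List, Sequence, Tuple

Token = Tuple[str, int]  # ("obs", codebook index) or ("act", action id)


def interleave(
    frame_tokens: Sequence[Sequence[int]],
    actions: Sequence[int],
    tokens_per_frame: int = 16,
    vocab_size: int = 512,
) -> List[Token]:
    """Flatten per-timestep frame tokens and actions into [z_1, a_1, z_2, a_2, ...]."""
    if len(frame_tokens) != len(actions):
        raise ValueError("need exactly one action per frame")
    sequence: List[Token] = []
    for tokens, action in zip(frame_tokens, actions):
        if len(tokens) != tokens_per_frame:
            raise ValueError(f"expected {tokens_per_frame} tokens per frame")
        if any(not 0 <= t < vocab_size for t in tokens):
            raise ValueError("frame token outside the codebook vocabulary")
        sequence.extend(("obs", t) for t in tokens)
        sequence.append(("act", action))
    return sequence


# Two timesteps -> 2 * (16 + 1) = 34 tokens.
seq = interleave([[3] * 16, [7] * 16], actions=[0, 2])
print(len(seq), seq[16], seq[-1])  # 34 ('act', 0) ('act', 2)
```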
## Key Components
### 1. Discrete Autoencoder (VQVAE)
- **Encoder**: 4-layer CNN with residual blocks and self-attention
- **Quantization**: Vector Quantizer with EMA updates (512 codebook size)
- **16 tokens per frame**: A 4x4 grid of codebook indices keeps the Transformer's sequence length short
- **Loss**: L1 reconstruction + commitment loss
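Quantization and the EMA codebook update can be sketched in plain Python. This is a minimal illustration of the standard VQ-VAE technique with EMA codebook updates, not TorchWM's implementation. The function names, `beta=0.25`, `decay=0.99`, and the initialization (running cluster sizes of 1 and running sums equal to the codebook) are assumptions. Real implementations vectorize this and pass gradients through the quantizer with a straight-through estimator.

```python
from typing import List, Tuple

Vector = List[float]


def squared_distance(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(latents: List[Vector], codebook: List[Vector]) -> Tuple[List[int], List[Vector]]:
    """Replace each encoder output by its nearest codebook entry; the index is the token."""
    indices = [
        min(range(len(codebook)), key=lambda k: squared_distance(z, codebook[k]))
        for z in latents
    ]
    return indices, [list(codebook[k]) for k in indices]


def commitment_loss(latents: List[Vector], quantized: List[Vector], beta: float = 0.25) -> float:
    """beta * mean ||z - sg(e)||^2: pulls encoder outputs toward their chosen codes."""
    total = sum(squared_distance(z, e) for z, e in zip(latents, quantized))
    return beta * total / len(latents)


def ema_update(
    codebook: List[Vector],
    cluster_size: List[float],
    embed_sum: List[Vector],
    latents: List[Vector],
    indices: List[int],
    decay: float = 0.99,
    eps: float = 1e-5,
) -> None:
    """In place: move each code toward the running mean of the latents assigned to it."""
    dim = len(codebook[0])
    counts = [0.0] * len(codebook)
    sums = [[0.0] * dim for _ in codebook]
    for z, k in zip(latents, indices):
        counts[k] += 1.0
        for d in range(dim):
            sums[k][d] += z[d]
    for k in range(len(codebook)):
        cluster_size[k] = decay * cluster_size[k] + (1.0 - decay) * counts[k]
        embed_sum[k] = [decay * s + (1.0 - decay) * n for s, n in zip(embed_sum[k], sums[k])]
        if cluster_size[k] > eps:  # skip codes whose running count has decayed to ~0
            codebook[k] = [s / cluster_size[k] for s in embed_sum[k]]


# Usage: a 2-entry codebook of 2-D codes.
codebook = [[0.0, 0.0], [10.0, 10.0]]
cluster_size = [1.0, 1.0]
embed_sum = [list(e) for e in codebook]
latents = [[1.0, 1.0], [3.0, 3.0], [9.0, 9.0]]
indices, quantized = quantize(latents, codebook)  # indices == [0, 0, 1]
loss = commitment_loss(latents, quantized)
ema_update(codebook, cluster_size, embed_sum, latents, indices)
```

With EMA updates, the codebook is not trained by gradient descent, so the remaining gradient-based terms are the reconstruction loss and the commitment loss listed above.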
### 2. Transformer World Model
- **GPT-style architecture**: 10 layers, 4 attention heads, 256 embedding dim
- **Autoregressive**: Predicts next tokens one-by-one
- **Heads**: Token prediction, reward prediction, termination prediction
- **Self-supervised training**: Cross-entropy on token sequences
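Generating one imagined timestep amounts to appending the action token and then sampling the 16 next-frame tokens one at a time, each conditioned on everything before it. In this sketch the Transformer is abstracted as two hypothetical callables, `next_obs_logits` and `reward_and_done`. Their names, the `(kind, id)` token representation, and reading reward and termination once the action token is in context are assumptions made for illustration.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

Token = Tuple[str, int]  # ("obs", codebook index) or ("act", action id)


def sample_from_logits(logits: Sequence[float], rng: random.Random) -> int:
    """Draw one index from softmax(logits)."""
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]  # unnormalized softmax weights
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


def imagine_step(
    context: List[Token],
    action: int,
    next_obs_logits: Callable[[List[Token]], Sequence[float]],
    reward_and_done: Callable[[List[Token]], Tuple[float, bool]],
    rng: random.Random,
    tokens_per_frame: int = 16,
) -> Tuple[List[int], float, bool]:
    """Append an action, then generate the next frame's tokens one at a time."""
    context = context + [("act", action)]  # copy: the caller's list is not mutated
    reward, done = reward_and_done(context)  # read out once the action is in context
    frame: List[int] = []
    for _ in range(tokens_per_frame):
        token = sample_from_logits(next_obs_logits(context), rng)
        frame.append(token)
        context.append(("obs", token))  # each new token conditions the next one
    return frame, reward, done


def toy_logits(ctx: List[Token]) -> List[float]:
    # Hypothetical stand-in for the Transformer: always favours token len(ctx) % 4.
    return [0.0 if k == len(ctx) % 4 else -1e9 for k in range(4)]


start = [("obs", 0)] * 16
frame, reward, done = imagine_step(
    start, action=1, next_obs_logits=toy_logits,
    reward_and_done=lambda ctx: (0.0, False), rng=random.Random(0),
)
print(frame[:4], len(start))  # [1, 2, 3, 0] 16
```

Generating token by token is what makes imagination sequential: producing one frame costs 16 forward passes in this naive form. Caching attention keys and values avoids recomputing the prefix on each pass.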
### 3. Actor-Critic
- **CNN + LSTM**: Processes reconstructed frames
- **λ-returns**: Balances bias and variance in value estimation
- **REINFORCE**: Policy gradient with the critic's value estimate as baseline
- **Entropy bonus**: Maintains exploration
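The λ-return target and the three loss terms can be written out for one imagined trajectory. In this sketch, plain Python floats stand in for tensors, so the stop-gradient on the advantage appears only as a comment. The function names and the `entropy_weight` default are illustrative assumptions. `gamma=0.995` and `lam=0.95` follow the hyperparameters listed on this page.

```python
from typing import List, Sequence, Tuple


def lambda_returns(
    rewards: Sequence[float],
    values: Sequence[float],
    dones: Sequence[bool],
    gamma: float = 0.995,
    lam: float = 0.95,
) -> List[float]:
    """G_t = r_t + gamma * (1 - d_t) * ((1 - lam) * V_{t+1} + lam * G_{t+1}), with G_H = V_H.

    rewards and dones have length H; values has length H + 1 (the last entry bootstraps).
    """
    horizon = len(rewards)
    if len(values) != horizon + 1 or len(dones) != horizon:
        raise ValueError("expected len(values) == len(rewards) + 1 == len(dones) + 1")
    returns = [0.0] * horizon
    next_return = values[horizon]
    for t in reversed(range(horizon)):
        not_done = 0.0 if dones[t] else 1.0
        returns[t] = rewards[t] + gamma * not_done * (
            (1.0 - lam) * values[t + 1] + lam * next_return
        )
        next_return = returns[t]
    return returns


def actor_critic_losses(
    log_probs: Sequence[float],  # log pi(a_t | x_t) of the actions taken
    entropies: Sequence[float],  # policy entropy at each step
    values: Sequence[float],     # V(x_t) for t = 0..H-1
    returns: Sequence[float],    # lambda-returns G_t
    entropy_weight: float = 0.001,
) -> Tuple[float, float, float]:
    """REINFORCE with the critic as baseline, an entropy bonus, and a value regression loss."""
    n = len(returns)
    # With autodiff, detach the advantages so the policy loss does not update the critic.
    advantages = [g - v for g, v in zip(returns, values)]
    policy_loss = -sum(lp * a for lp, a in zip(log_probs, advantages)) / n
    entropy_loss = -entropy_weight * sum(entropies) / n  # minimizing this raises entropy
    value_loss = sum((v - g) ** 2 for v, g in zip(values, returns)) / n
    return policy_loss, entropy_loss, value_loss


returns = lambda_returns([1.0, 0.0], [0.0, 0.0, 10.0], [False, False], gamma=1.0, lam=1.0)
print(returns)  # [11.0, 10.0]
```

Setting `lam=0.0` recovers one-step TD targets `r_t + gamma * V(x_{t+1})`. Setting `lam=1.0` gives Monte Carlo returns bootstrapped only at the horizon. The value 0.95 sits between these two extremes.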
## Training
```bash
# For full IRIS runs, prefer the TorchWM CLI.
# The CLI wires the trainer, Atari environment, replay buffers, and checkpoints.
torchwm train iris --env ALE/Pong-v5 --device cuda
```
For custom research code, build an agent directly from the public namespace:
```python :class: thebe
import torch
import torchwm
agent = torchwm.create_model(
    "iris",
    action_size=4,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
```
### Warm-start Delays
| Component | Start Epoch | Description |
|-----------|-------------|-------------|
| Autoencoder | 5 | Learn frame compression first |
| Transformer | 25 | Learn dynamics once tokens are good |
| Actor-Critic | 50 | Learn policy in imagination |
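The delays act as a gate on which components are updated in a given epoch. The sketch below assumes a component trains from its start epoch onward. Whether that boundary is inclusive is an implementation detail this page does not specify. `components_to_train` is a hypothetical helper, not a TorchWM function.

```python
from typing import Dict, List

# Start epochs taken from the warm-start table.
WARM_START: Dict[str, int] = {"autoencoder": 5, "transformer": 25, "actor_critic": 50}


def components_to_train(epoch: int, schedule: Dict[str, int] = WARM_START) -> List[str]:
    """Return the components updated at `epoch` (start epoch assumed inclusive)."""
    return [name for name, start in schedule.items() if epoch >= start]


for epoch in (1, 5, 25, 50):
    print(epoch, components_to_train(epoch))
# 1 []
# 5 ['autoencoder']
# 25 ['autoencoder', 'transformer']
# 50 ['autoencoder', 'transformer', 'actor_critic']
```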
### Key Hyperparameters
- **Frame size**: 64x64
- **Tokens per frame**: 16 (from a 512-token vocabulary)
- **Transformer sequence length**: 20 timesteps
- **Imagination horizon**: 20 steps
- **Discount (γ)**: 0.995
- **λ for λ-return**: 0.95
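Several quantities follow from these values. Each timestep contributes 16 frame tokens plus one action token, so a 20-timestep window is 340 tokens. A 16-token frame corresponds to a 4x4 grid over the 64x64 image. The dataclass below is a hypothetical container used only to make this arithmetic explicit. It is not TorchWM's `IRISConfig`, and its square-grid assumption is stated in the code.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class IRISHyperparameters:
    """Hypothetical container for the listed values (not TorchWM's IRISConfig)."""

    frame_size: int = 64
    tokens_per_frame: int = 16
    vocab_size: int = 512
    sequence_timesteps: int = 20
    imagination_horizon: int = 20
    gamma: float = 0.995
    lam: float = 0.95

    @property
    def tokens_per_timestep(self) -> int:
        return self.tokens_per_frame + 1  # frame tokens plus one action token

    @property
    def context_tokens(self) -> int:
        return self.sequence_timesteps * self.tokens_per_timestep

    @property
    def downsampling_factor(self) -> int:
        side = math.isqrt(self.tokens_per_frame)  # assumes a square token grid
        if side * side != self.tokens_per_frame:
            raise ValueError("tokens_per_frame is not a perfect square")
        return self.frame_size // side  # pixels per token along each axis

    @property
    def effective_discount_horizon(self) -> float:
        return 1.0 / (1.0 - self.gamma)  # common rule of thumb for how far gamma looks ahead


hp = IRISHyperparameters()
print(hp.context_tokens, hp.downsampling_factor, round(hp.effective_discount_horizon))
# 340 16 200
```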
## Benchmark Results
Atari 100k results as reported by Micheli et al. (2023). HNS is the human-normalized score.

| Metric | IRIS | SPR | DrQ | CURL | SimPLe |
|--------|------|-----|-----|------|--------|
| Mean HNS | **1.046** | 0.616 | 0.465 | 0.261 | 0.332 |
| Superhuman games | **10/26** | 6/26 | 3/26 | 2/26 | 1/26 |
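Both rows derive from per-game human-normalized scores. The sketch below uses the standard definition, HNS = (agent - random) / (human - random), and averages it over games. Counting a game as superhuman when its HNS exceeds 1 is an assumption based on the row label. The game names and scores are made up and only exercise the arithmetic.

```python
from typing import Dict, Tuple


def human_normalized_score(agent: float, random_score: float, human: float) -> float:
    """HNS = (agent - random) / (human - random); 1.0 means human-level on that game."""
    return (agent - random_score) / (human - random_score)


def summarize(scores: Dict[str, Tuple[float, float, float]]) -> Tuple[float, int]:
    """Return (mean HNS, number of games with HNS above 1) from (agent, random, human) triples."""
    hns = [human_normalized_score(*triple) for triple in scores.values()]
    mean_hns = sum(hns) / len(hns)
    superhuman = sum(1 for value in hns if value > 1.0)
    return mean_hns, superhuman


# Made-up scores for two hypothetical games.
toy = {"GameA": (15.0, 0.0, 10.0), "GameB": (5.0, 0.0, 10.0)}
print(summarize(toy))  # (1.0, 1)
```

The example also shows why the mean can exceed 1.0 while most games stay below human level: one strong game can outweigh several weaker ones.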
## Usage
### Training
```bash
# Single game
python -m world_models.training.train_iris --game "ALE/Pong-v5"
# Benchmark
python -m world_models.benchmarks.cli --agent iris --game "ALE/Pong-v5" \
--checkpoint path/to/iris.pt --seeds 5
```
### Configuration
```python :class: thebe
from torchwm import IRISConfig
config = IRISConfig()
# Autoencoder
config.vocab_size = 512
config.tokens_per_frame = 16
# Transformer
config.transformer_layers = 10
config.transformer_embed_dim = 256
# Training
config.total_epochs = 600
config.env_steps_per_epoch = 200
```
## References
- Micheli, Vincent, Eloi Alonso, and François Fleuret. "Transformers are Sample-Efficient World Models." ICLR 2023.