# IRIS: Transformers for Sample-Efficient World Models
IRIS (Imagination with auto-Regression over an Inner Speech) is an implementation of the paper “Transformers are Sample-Efficient World Models” (Micheli et al., 2023).
## Key Idea

IRIS reaches human-level performance on the Atari 100k benchmark, i.e. with only ~2 hours of gameplay (100k environment steps), by training its policy entirely inside a learned world model:

1. Train the world model from real interactions.
2. Generate imagined trajectories in its latent space.
3. Train the policy purely on this imagined data.
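The three phases above alternate every epoch. The following is a minimal toy sketch of that loop; `ToyWorldModel`, `collect_experience`, and `improve_policy` are illustrative stand-ins, not the repository's actual API:

```python
import random

def collect_experience(policy, num_steps):
    """Phase 1: act in the real environment (stubbed as random transitions)."""
    return [(random.random(), policy(None), random.random())
            for _ in range(num_steps)]

class ToyWorldModel:
    """Stand-in world model: learns only the mean reward."""
    def fit(self, buffer):
        self.mean_reward = sum(r for _, _, r in buffer) / len(buffer)

    def imagine(self, policy, horizon):
        """Phase 2: roll out imagined trajectories, no env interaction."""
        return [(None, policy(None), self.mean_reward) for _ in range(horizon)]

def improve_policy(policy, imagined):
    """Phase 3: policy update from imagined data only (no-op here)."""
    return policy

policy = lambda obs: 0            # trivial placeholder policy
buffer = []
for epoch in range(3):
    buffer += collect_experience(policy, num_steps=10)  # real steps
    model = ToyWorldModel()
    model.fit(buffer)                                   # world-model training
    imagined = model.imagine(policy, horizon=20)        # imagination rollout
    policy = improve_policy(policy, imagined)           # actor-critic update
```

The key property is that only phase 1 consumes real environment steps; the policy itself never trains on real data.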
## Architecture

```
┌─────────────────────────────────────────────────────────────────┐
│                      Discrete Autoencoder                       │
│   ┌─────────┐      ┌────────────┐      ┌─────────────┐          │
│   │ Encoder │ ───► │   VQ-VAE   │ ───► │   Decoder   │          │
│   │  (CNN)  │      │ (512 vocab)│      │ (Transposed │          │
│   │  64x64  │      │  16 tokens │      │    CNN)     │          │
│   └─────────┘      └────────────┘      └─────────────┘          │
└─────────────────────────────────────────────────────────────────┘
                                │
                                ▼
┌─────────────────────────────────────────────────────────────────┐
│                   Autoregressive Transformer                    │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │ z_t (16 tokens) ──► a_t ──► z_{t+1} (16 tokens) ──► ...  │   │
│  │        ↓         ↓                  ↓                    │   │
│  │     Reward  Termination          Reward...               │   │
│  └──────────────────────────────────────────────────────────┘   │
│        GPT-style: 10 layers, 4 heads, 256 embedding dim         │
└─────────────────────────────────────────────────────────────────┘
                                │
                                ▼
┌─────────────────────────────────────────────────────────────────┐
│                   Actor-Critic in Imagination                   │
│      ┌─────────────┐            ┌──────────────────┐            │
│      │    Actor    │            │      Critic      │            │
│      │ (CNN+LSTM)  │            │    (CNN+LSTM)    │            │
│      │             │            │                  │            │
│      │  λ-return   │            │     MSE loss     │            │
│      │  REINFORCE  │            │                  │            │
│      └─────────────┘            └──────────────────┘            │
└─────────────────────────────────────────────────────────────────┘
```
## Key Components

### 1. Discrete Autoencoder (VQ-VAE)

- **Encoder**: 4-layer CNN with residual blocks and self-attention
- **Quantization**: vector quantizer with EMA codebook updates (512 codes)
- **16 tokens per frame**: keeps sequence lengths short for Transformer efficiency
- **Loss**: L1 reconstruction + commitment loss
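The quantization step replaces each encoder output vector with its nearest codebook entry, yielding 16 discrete token indices per frame. A minimal NumPy sketch (EMA codebook updates and the straight-through gradient are omitted; the 64-dim embedding size here is an assumption for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
codebook = rng.normal(size=(512, 64))   # 512 codes, 64-dim embeddings
z_e = rng.normal(size=(16, 64))         # encoder output: 16 vectors per frame

# Nearest code by squared Euclidean distance -> discrete token indices
d = ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)  # (16, 512)
tokens = d.argmin(axis=1)               # (16,) integer ids in [0, 512)
z_q = codebook[tokens]                  # quantized embeddings, (16, 64)

# Commitment loss pulls the encoder output toward its chosen codes
commitment = float(((z_e - z_q) ** 2).mean())
```

These `tokens` are what the Transformer world model consumes in place of raw pixels.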
### 2. Transformer World Model

- **GPT-style architecture**: 10 layers, 4 attention heads, 256 embedding dim
- **Autoregressive**: predicts next tokens one by one
- **Heads**: token prediction, reward prediction, termination prediction
- **Self-supervised training**: cross-entropy on token sequences
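The Transformer operates on a single interleaved sequence of frame tokens and actions, (z_t^1 … z_t^16, a_t, z_{t+1}^1 …). A sketch of that layout, assuming actions are kept disjoint from observation tokens by offsetting their ids past the token vocabulary (an illustrative scheme, not necessarily the repository's exact one):

```python
TOKENS_PER_FRAME = 16
VOCAB = 512

def interleave(frame_tokens, actions):
    """frame_tokens: T lists of 16 token ids; actions: T action ids."""
    seq = []
    for z, a in zip(frame_tokens, actions):
        seq.extend(z)            # 16 observation tokens for this frame
        seq.append(VOCAB + a)    # action id, offset past the token vocabulary
    return seq

frames = [[i % VOCAB for i in range(TOKENS_PER_FRAME)] for _ in range(20)]
actions = [3] * 20
seq = interleave(frames, actions)
# 20 timesteps x (16 tokens + 1 action) = 340 positions
```

With the 20-timestep training sequences listed below, each training example is therefore 340 positions long, which is why compressing frames to 16 tokens matters for Transformer cost.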
### 3. Actor-Critic

- **CNN + LSTM**: processes reconstructed frames
- **λ-returns**: balance bias and variance in value estimation
- **REINFORCE**: policy gradient with baseline
- **Entropy bonus**: maintains exploration
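The λ-return target can be computed in one backward pass over an imagined rollout via G_t = r_t + γ[(1−λ)V(s_{t+1}) + λG_{t+1}], bootstrapping from the critic's value at the horizon. A self-contained sketch (function name illustrative; termination handling omitted for brevity):

```python
def lambda_returns(rewards, values, gamma=0.995, lam=0.95):
    """rewards: length-H list; values: length H+1 (last entry bootstraps)."""
    G = values[-1]                       # bootstrap from V at the horizon
    out = []
    for t in reversed(range(len(rewards))):
        # Blend the one-step TD target with the longer-horizon return
        G = rewards[t] + gamma * ((1 - lam) * values[t + 1] + lam * G)
        out.append(G)
    return out[::-1]
```

At λ = 0 this reduces to the one-step TD target r_t + γV(s_{t+1}) (low variance, high bias); at λ = 1 it is the full Monte-Carlo return (high variance, low bias). The critic regresses on these targets with an MSE loss, and the actor uses them in its REINFORCE advantage.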
## Training

```python
from world_models.training.train_iris import IRISTrainer

trainer = IRISTrainer(
    game="ALE/Pong-v5",
    device="cuda",
    seed=42,
)
trainer.train(total_epochs=600)
```
## Warm-start Delays

| Component    | Start Epoch | Description                         |
|--------------|-------------|-------------------------------------|
| Autoencoder  | 5           | Learn frame compression first       |
| Transformer  | 25          | Learn dynamics once tokens are good |
| Actor-Critic | 50          | Learn policy in imagination         |
## Key Hyperparameters

- Frame size: 64x64
- Tokens per frame: 16 (vocabulary of 512)
- Transformer sequence length: 20 timesteps
- Imagination horizon: 20 steps
- Discount (γ): 0.995
- λ for λ-returns: 0.95
## Benchmark Results

| Metric           | IRIS (ours) | SPR   | DrQ   | CURL  | SimPLe |
|------------------|-------------|-------|-------|-------|--------|
| Mean HNS         | 1.046       | 0.616 | 0.465 | 0.261 | 0.332  |
| Superhuman games | 10/26       | 6/26  | 3/26  | 2/26  | 1/26   |

HNS = human-normalized score on the 26-game Atari 100k suite.
## Usage

### Training

```bash
# Single game
python -m world_models.training.train_iris --game "ALE/Pong-v5"

# Benchmark
python -m benchmarks.atari_100k --device cuda --num_seeds 5
```
### Configuration

```python
from world_models.configs.iris_config import IRISConfig

config = IRISConfig()

# Autoencoder
config.vocab_size = 512
config.tokens_per_frame = 16

# Transformer
config.transformer_layers = 10
config.transformer_embed_dim = 256

# Training
config.total_epochs = 600
config.env_steps_per_epoch = 200
```
## References

- Micheli, Vincent, Eloi Alonso, and François Fleuret. "Transformers are Sample-Efficient World Models." ICLR 2023.