# IRIS: Transformers for Sample-Efficient World Models

IRIS (Imagination with auto-Regression over an Inner Speech) is an implementation of the paper “Transformers are Sample-Efficient World Models” (Micheli et al., 2023).

## Key Idea

IRIS achieves human-level performance on the Atari 100k benchmark with only ~2 hours of gameplay (100k environment steps) by training its policy entirely inside the imagination of a learned world model:

  1. Train world model from real interactions

  2. Generate imagined trajectories in the latent space

  3. Train policy purely on imagined data
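The three-phase loop above can be sketched as follows. `StubWorldModel` and `policy` are hypothetical stand-ins for illustration, not the real classes from this repository:

```python
# Toy sketch of the three-phase IRIS loop with stub components.
import random

class StubWorldModel:
    def update(self, replay):
        """Fit the autoencoder + transformer on real transitions."""
        pass

    def imagine(self, policy, start_obs, horizon):
        """Roll out `horizon` imagined steps from a real start state."""
        return [(start_obs, policy(start_obs), random.random())
                for _ in range(horizon)]

def policy(obs):
    return random.choice([0, 1])  # stub actor: random action

replay, wm = [], StubWorldModel()
for epoch in range(3):
    replay.append(("obs", policy("obs"), 0.0))  # 1. collect real experience
    wm.update(replay)                           # 2. fit world model on real data
    imagined = wm.imagine(policy, replay[-1][0], horizon=20)
    # 3. the actor-critic update would consume only `imagined` here

print(len(imagined))  # one rollout of 20 imagined steps
```

The key property is that only step 1 touches the environment; the policy gradient in step 3 never sees a real frame.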

## Architecture

```
┌─────────────────────────────────────────────────────────────────┐
│                        Discrete Autoencoder                      │
│  ┌─────────┐    ┌────────────┐    ┌─────────────┐               │
│  │ Encoder │ -> │   VQVAE    │ -> │   Decoder   │               │
│  │ (CNN)   │    │ (512 vocab)│    │ (Transposed │               │
│  │ 64x64   │    │  16 tokens │    │   CNN)      │               │
│  └─────────┘    └────────────┘    └─────────────┘               │
└─────────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────────┐
│                    Autoregressive Transformer                    │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │  z_t (16 tokens) ──► a_t ──► z_{t+1} (16 tokens)         │   │
│  │         ↓              ↓              ↓                  │   │
│  │      Reward        Termination     Reward...             │   │
│  └──────────────────────────────────────────────────────────┘   │
│  GPT-style: 10 layers, 4 heads, 256 embedding dim              │
└─────────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────────┐
│                     Actor-Critic in Imagination                  │
│  ┌─────────────┐    ┌──────────────────┐                       │
│  │   Actor     │    │     Critic       │                       │
│  │ (CNN+LSTM)  │    │  (CNN+LSTM)      │                       │
│  │             │    │                  │                       │
│  │ λ-return    │    │   MSE loss       │                       │
│  │ REINFORCE   │    │                  │                       │
│  └─────────────┘    └──────────────────┘                       │
└─────────────────────────────────────────────────────────────────┘
```

## Key Components

### 1. Discrete Autoencoder (VQ-VAE)

  • Encoder: 4-layer CNN with residual blocks and self-attention

  • Quantization: Vector Quantizer with EMA updates (512 codebook size)

  • 16 tokens per frame: Reduces sequence length for Transformer efficiency

  • Loss: L1 reconstruction + commitment loss
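The quantization step can be sketched with a nearest-codebook lookup. This is a NumPy illustration using the sizes listed above (512 codes, 16 tokens per frame); the embedding dim of 256 and the random data are assumptions for the example:

```python
import numpy as np

# Codebook: 512 code vectors (the token vocabulary), assumed embed dim 256.
rng = np.random.default_rng(0)
codebook = rng.normal(size=(512, 256))

def quantize(z):
    """Map each of the 16 encoder outputs (16 x 256) to its nearest code."""
    # Squared distance between every latent and every codebook entry.
    d = ((z[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)  # (16, 512)
    tokens = d.argmin(axis=1)           # discrete token ids in [0, 512)
    return tokens, codebook[tokens]     # ids + quantized latents

z = rng.normal(size=(16, 256))          # one frame -> 16 latent vectors
tokens, z_q = quantize(z)
print(tokens.shape, z_q.shape)          # (16,) (16, 256)
```

In training, the codebook entries are moved toward the encoder outputs via EMA updates rather than a gradient, and the commitment loss keeps the encoder close to its assigned codes.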

### 2. Transformer World Model

  • GPT-style architecture: 10 layers, 4 attention heads, 256 embedding dim

  • Autoregressive: Predicts next tokens one-by-one

  • Heads: Token prediction, reward prediction, termination prediction

  • Self-supervised training: Cross-entropy on token sequences
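Per the architecture diagram, each timestep contributes its frame tokens followed by an action token, and the transformer predicts every next token. A small sketch of that interleaved layout (illustrative, not the repository's actual tokenization code):

```python
# Sketch: how frame tokens and actions interleave into one sequence
# for the autoregressive transformer.
TOKENS_PER_FRAME = 16

def interleave(frames, actions):
    """frames: list of 16-token lists; actions: list of action ids."""
    seq = []
    for z, a in zip(frames, actions):
        seq.extend(z)   # 16 observation tokens for timestep t
        seq.append(a)   # 1 action token
    return seq          # target for next-token cross-entropy

frames = [[1] * TOKENS_PER_FRAME, [2] * TOKENS_PER_FRAME]
seq = interleave(frames, actions=[0, 1])
print(len(seq))  # 2 * (16 + 1) = 34
```

Training is then ordinary next-token prediction over `seq`, with extra heads reading out reward and termination at each timestep boundary.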

### 3. Actor-Critic

  • CNN + LSTM: Processes reconstructed frames

  • λ-returns: Balances bias and variance in value estimation

  • REINFORCE: Policy gradient with baseline

  • Entropy bonus: Maintains exploration
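The λ-return blends n-step returns via the standard recursion G_t = r_t + γ[(1 − λ)V(s_{t+1}) + λG_{t+1}], bootstrapping with the critic at the horizon. A minimal sketch, with γ and λ set to the values listed under Key Hyperparameters:

```python
# Recursive lambda-return over an imagined rollout.
def lambda_returns(rewards, values, gamma=0.995, lam=0.95):
    """values[t] = V(s_t); one extra bootstrap value at the end."""
    G = values[-1]                  # bootstrap with the critic at the horizon
    out = []
    for t in reversed(range(len(rewards))):
        G = rewards[t] + gamma * ((1 - lam) * values[t + 1] + lam * G)
        out.append(G)
    return out[::-1]

returns = lambda_returns([1.0, 0.0, 1.0], [0.5, 0.4, 0.3, 0.2])
print(returns)
```

λ = 0 recovers the one-step TD target (low variance, high bias); λ = 1 recovers the Monte Carlo return (high variance, low bias); 0.95 sits near the usual sweet spot.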

## Training

```python
from world_models.training.train_iris import IRISTrainer
from world_models.configs.iris_config import IRISConfig

trainer = IRISTrainer(
    game="ALE/Pong-v5",
    device="cuda",
    seed=42,
)

trainer.train(total_epochs=600)
```

## Warm-start Delays

| Component    | Start Epoch | Description                         |
|--------------|-------------|-------------------------------------|
| Autoencoder  | 5           | Learn frame compression first       |
| Transformer  | 25          | Learn dynamics once tokens are good |
| Actor-Critic | 50          | Learn policy in imagination         |
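The staging in this table amounts to a simple epoch gate per component. A sketch, using the thresholds above (the dict and helper are hypothetical, not the trainer's actual API):

```python
# Epoch-gated warm start: each component only updates once its epoch arrives.
START_EPOCH = {"autoencoder": 5, "transformer": 25, "actor_critic": 50}

def components_to_update(epoch):
    return [name for name, start in START_EPOCH.items() if epoch >= start]

print(components_to_update(10))  # only the autoencoder
print(components_to_update(60))  # all three components
```

The ordering matters: the transformer only sees tokens once the autoencoder has stabilized, and the actor-critic only imagines once the dynamics are usable.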

## Key Hyperparameters

  • Frame size: 64x64

  • Tokens per frame: 16 (drawn from a 512-entry vocabulary)

  • Transformer sequence length: 20 timesteps

  • Imagination horizon: 20 steps

  • Discount (γ): 0.995

  • λ for λ-return: 0.95
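These settings fix the transformer's effective context size: assuming one action token per timestep (as in the architecture diagram), a 20-timestep window holds 20 × (16 + 1) tokens:

```python
tokens_per_frame = 16
seq_timesteps = 20
context = seq_timesteps * (tokens_per_frame + 1)  # +1 action token per step
print(context)  # 340 tokens of transformer context
```

Keeping frames at 16 tokens is what makes this budget tractable; raw 64x64 pixels would be orders of magnitude longer.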

## Benchmark Results

| Metric                            | IRIS (ours) | SPR   | DrQ   | CURL  | SimPLe |
|-----------------------------------|-------------|-------|-------|-------|--------|
| Mean HNS (human-normalized score) | 1.046       | 0.616 | 0.465 | 0.261 | 0.332  |
| Superhuman games                  | 10/26       | 6/26  | 3/26  | 2/26  | 1/26   |

## Usage

### Training

```shell
# Single game
python -m world_models.training.train_iris --game "ALE/Pong-v5"

# Benchmark
python -m benchmarks.atari_100k --device cuda --num_seeds 5
```

### Configuration

```python
from world_models.configs.iris_config import IRISConfig

config = IRISConfig()

# Autoencoder
config.vocab_size = 512
config.tokens_per_frame = 16

# Transformer
config.transformer_layers = 10
config.transformer_embed_dim = 256

# Training
config.total_epochs = 600
config.env_steps_per_epoch = 200
```

## References

  • Micheli, Vincent, Eloi Alonso, and François Fleuret. “Transformers are Sample-Efficient World Models.” ICLR 2023.