IRIS: Transformers for Sample-Efficient World Models

IRIS (Imagination with auto-Regression over an Inner Speech) is an implementation of the paper “Transformers are Sample-Efficient World Models” (Micheli et al., 2023).

Key Idea

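IRIS reaches a mean human-normalized score of 1.046 on the Atari 100k benchmark, which limits each game to 100k agent steps (about 2 hours of real-time gameplay). It does this by training its policy entirely in the imagination of a learned world model:

  1. Train the world model on real interactions collected by the current policy

  2. Generate imagined trajectories autoregressively in the world model's discrete token space

  3. Train the policy purely on imagined data, then repeat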
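The loop above can be sketched in plain Python with stub callables. The function and parameter names (`iris_style_loop`, `collect_real`, `train_world_model`, `imagine`, `train_policy`) are hypothetical placeholders, not TorchWM's API. The sketch only shows the data flow: the world model is fit on real transitions, while the policy only ever receives imagined ones.

```python
from typing import Callable, List, Tuple

# Hypothetical transition record: (observation, action, reward, done).
Transition = Tuple[object, int, float, bool]


def iris_style_loop(
    num_epochs: int,
    collect_real: Callable[[], List[Transition]],
    train_world_model: Callable[[List[Transition]], None],
    imagine: Callable[[List[Transition]], List[Transition]],
    train_policy: Callable[[List[Transition]], None],
) -> List[Transition]:
    """Alternate real data collection, world-model fitting, and policy learning in imagination."""
    replay: List[Transition] = []
    for _ in range(num_epochs):
        # Act in the real environment with the current policy.
        replay.extend(collect_real())
        # Step 1: fit the world model (autoencoder + Transformer) on real transitions only.
        train_world_model(replay)
        # Step 2: roll out imagined trajectories, seeded from real frames.
        imagined = imagine(replay)
        # Step 3: update the policy on imagined data; it never sees `replay` directly.
        train_policy(imagined)
    return replay
```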

Architecture

Discrete Autoencoder

64x64 frame → CNN encoder → VQ-VAE quantizer (512-token vocabulary, 16 tokens per frame) → transposed-CNN decoder

Autoregressive Transformer

Latent tokens + action token → Transformer → next latent tokens, plus reward and termination heads

Actor-Critic in Imagination

Actor: CNN + LSTM; Critic: CNN + LSTM

Key Components

1. Discrete Autoencoder (VQ-VAE)

  • Encoder: 4-layer CNN with residual blocks and self-attention

  • Quantization: Vector quantizer with EMA codebook updates (512-entry codebook)

  • 16 tokens per frame: Each frame becomes a 4x4 grid of tokens, keeping the Transformer's input sequence short

  • Loss: L1 reconstruction + commitment loss
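The quantization step can be sketched in plain Python: each encoder output vector is replaced by its nearest codebook entry, and that entry's index becomes the token. The helper names (`quantize`, `ema_update`) are hypothetical, and the EMA step shown is the common moving-average codebook update without Laplace smoothing; TorchWM's actual quantizer may differ in its details.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(latents: List[Vector], codebook: List[Vector]) -> Tuple[List[int], List[Vector], float]:
    """Replace each latent vector by its nearest codebook entry.

    Returns the token ids, the quantized vectors, and the commitment loss
    (mean squared error between each latent and its chosen code).
    """
    tokens: List[int] = []
    quantized: List[Vector] = []
    total = 0.0
    for z in latents:
        # Nearest neighbour under squared Euclidean distance.
        idx = min(range(len(codebook)), key=lambda k: squared_distance(z, codebook[k]))
        tokens.append(idx)
        quantized.append(list(codebook[idx]))
        total += squared_distance(z, codebook[idx])
    # In an autograd framework the code vector would be detached here, so this
    # term only pulls the encoder output towards its code.
    commitment = total / (len(latents) * len(latents[0]))
    return tokens, quantized, commitment


def ema_update(codebook: List[Vector], latents: List[Vector], tokens: List[int],
               cluster_size: List[float], embed_sum: List[Vector], decay: float = 0.99) -> None:
    """One exponential-moving-average codebook step (no Laplace smoothing), updating lists in place."""
    dim = len(codebook[0])
    for k in range(len(codebook)):
        members = [z for z, t in zip(latents, tokens) if t == k]
        # Running count of latents assigned to code k.
        cluster_size[k] = decay * cluster_size[k] + (1 - decay) * len(members)
        # Running per-dimension sum of those latents.
        for d in range(dim):
            embed_sum[k][d] = decay * embed_sum[k][d] + (1 - decay) * sum(z[d] for z in members)
        if cluster_size[k] > 0:
            # Move the code to the running mean of its assigned latents.
            codebook[k] = [embed_sum[k][d] / cluster_size[k] for d in range(dim)]
```

The EMA variant updates the codebook without gradients, so only the reconstruction and commitment terms need backpropagation.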

2. Transformer World Model

  • GPT-style architecture: 10 layers, 4 attention heads, 256 embedding dim

  • Autoregressive: Predicts next tokens one-by-one

  • Heads: Token prediction, reward prediction, termination prediction

  • Self-supervised training: Cross-entropy on next-token prediction, plus losses for the reward and termination heads
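The Transformer's input interleaves 16 frame tokens with one action token per timestep, and imagination generates the next frame one token at a time. A sketch of both, with hypothetical helper names; `next_token` stands in for sampling from the Transformer's token head, and the action-id offset is an illustrative convention rather than TorchWM's actual encoding.

```python
from typing import Callable, List

TOKENS_PER_FRAME = 16


def flatten_trajectory(frames: List[List[int]], actions: List[int], vocab_size: int) -> List[int]:
    """Interleave frame and action tokens: [z_1..z_16, a_1, z_1..z_16, a_2, ...].

    Action ids are offset by `vocab_size` so both kinds share one id space
    (illustrative only; separate embedding tables work just as well).
    """
    sequence: List[int] = []
    for frame, action in zip(frames, actions):
        if len(frame) != TOKENS_PER_FRAME:
            raise ValueError(f"expected {TOKENS_PER_FRAME} tokens per frame, got {len(frame)}")
        sequence.extend(frame)
        sequence.append(vocab_size + action)
    return sequence


def imagine_next_frame(context: List[int], next_token: Callable[[List[int]], int]) -> List[int]:
    """Generate the next frame autoregressively, one token at a time."""
    ctx = list(context)  # copy so the caller's sequence is left untouched
    frame: List[int] = []
    for _ in range(TOKENS_PER_FRAME):
        token = next_token(ctx)  # conditioned on everything generated so far
        frame.append(token)
        ctx.append(token)  # feed the prediction back in as input
    return frame
```

Generating tokens one by one is what makes imagination sequential: a 16-token frame costs 16 forward passes, which is one reason to keep the token count per frame small.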

3. Actor-Critic

  • CNN + LSTM: Processes frames reconstructed by the autoencoder's decoder

  • λ-returns: Balances bias and variance in value estimation

  • REINFORCE: Policy gradient with the critic's value as baseline

  • Entropy bonus: Maintains exploration
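The λ-return and the policy objective can be sketched in plain Python over one imagined rollout of H steps, using the recursion Λ_t = r_t + γ(1 − d_t)[(1 − λ)V(s_{t+1}) + λΛ_{t+1}] with Λ_H = V(s_H). Function names are hypothetical, and `entropy_weight=0.001` is an assumed default, not a value taken from this page.

```python
from typing import List, Sequence, Tuple


def lambda_returns(rewards: Sequence[float], values: Sequence[float], dones: Sequence[bool],
                   gamma: float = 0.995, lam: float = 0.95) -> List[float]:
    """λ-returns computed backwards over an imagined rollout.

    rewards[t] and dones[t] describe the transition out of step t; `values`
    has H + 1 entries, the last one bootstrapping from the final state.
    """
    horizon = len(rewards)
    assert len(values) == horizon + 1 and len(dones) == horizon
    returns = [0.0] * horizon
    next_return = values[horizon]  # bootstrap from the critic
    for t in reversed(range(horizon)):
        not_done = 0.0 if dones[t] else 1.0
        next_return = rewards[t] + gamma * not_done * ((1 - lam) * values[t + 1] + lam * next_return)
        returns[t] = next_return
    return returns


def actor_critic_losses(log_probs: Sequence[float], entropies: Sequence[float], values: Sequence[float],
                        returns: Sequence[float], entropy_weight: float = 0.001) -> Tuple[float, float]:
    """REINFORCE with the critic as baseline, an entropy bonus, and a squared-error critic loss.

    Advantages and returns are plain floats here; in an autograd framework
    they would be detached so only log_probs and values receive gradients.
    """
    horizon = len(returns)
    advantages = [returns[t] - values[t] for t in range(horizon)]
    actor = -sum(lp * adv for lp, adv in zip(log_probs, advantages)) / horizon
    entropy_bonus = entropy_weight * sum(entropies) / horizon
    critic = sum((values[t] - returns[t]) ** 2 for t in range(horizon)) / horizon
    return actor - entropy_bonus, critic
```

With λ = 1 the return becomes a pure discounted Monte Carlo sum bootstrapped only at the horizon; with λ = 0 it becomes the one-step TD target. The 0.95 used here sits between the two.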

Training

# For full IRIS runs, prefer the TorchWM CLI.
# The CLI wires the trainer, Atari environment, replay buffers, and checkpoints.
torchwm train iris --env ALE/Pong-v5 --device cuda

For custom research code, build an agent directly from the public namespace:

import torch
import torchwm

agent = torchwm.create_model(
    "iris",
    action_size=4,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

Warm-start Delays

| Component | Start Epoch | Description |
|---|---|---|
| Autoencoder | 5 | Learn frame compression first |
| Transformer | 25 | Learn dynamics once the tokens are good |
| Actor-Critic | 50 | Learn the policy in imagination |
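The warm-start schedule amounts to a per-epoch gate on each component. A minimal sketch using the start epochs from the table; the dictionary and function name are hypothetical, and treating the start epoch as inclusive (`>=`) is an assumption about how the trainer counts.

```python
from typing import Dict, List

# Start epochs from the table above; this dict is an illustrative stand-in for the real config.
START_EPOCHS: Dict[str, int] = {"autoencoder": 5, "transformer": 25, "actor_critic": 50}


def components_to_train(epoch: int, start_epochs: Dict[str, int] = START_EPOCHS) -> List[str]:
    """Return the components whose warm-start delay has elapsed at `epoch` (1-indexed, inclusive)."""
    return [name for name, start in start_epochs.items() if epoch >= start]
```

Staggering the starts means the Transformer first sees tokens from an autoencoder that has had 20 epochs to settle, and the policy first imagines with a world model that has had 25.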

Key Hyperparameters

  • Frame size: 64x64

  • Tokens per frame: 16 (vocabulary of 512)

  • Transformer sequence length: 20 timesteps

  • Imagination horizon: 20 steps

  • Discount (γ): 0.995

  • λ for λ-return: 0.95
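A few quantities follow directly from these hyperparameters. This sketch assumes each timestep holds 16 frame tokens plus one action token, as in the architecture overview, and uses the standard Atari frame skip of 4 at 60 fps for the "about 2 hours" budget; the frame skip is an assumption, not a value listed above.

```python
TOKENS_PER_FRAME = 16
ACTION_TOKENS_PER_STEP = 1   # assumed: one action token per timestep
SEQUENCE_TIMESTEPS = 20
GAMMA = 0.995

tokens_per_step = TOKENS_PER_FRAME + ACTION_TOKENS_PER_STEP  # 17
context_tokens = SEQUENCE_TIMESTEPS * tokens_per_step          # 340 tokens of Transformer context
effective_horizon = 1.0 / (1.0 - GAMMA)                        # about 200 steps

# Real-time budget: 100k agent steps, each repeating the action for FRAME_SKIP frames at 60 fps.
FRAME_SKIP = 4               # assumed standard Atari 100k setting
hours_of_play = 100_000 * FRAME_SKIP / 60 / 3600               # about 1.85 hours

print(tokens_per_step, context_tokens, round(effective_horizon), round(hours_of_play, 2))
# prints: 17 340 200 1.85
```

The effective horizon of roughly 200 steps is well beyond the 20-step imagination horizon, which is why the critic's bootstrap value at the end of each rollout matters.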

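Benchmark Results

Atari 100k results as reported by Micheli et al. (2023). HNS is the human-normalized score; a game counts as superhuman when the agent outscores the human reference.

| Metric | IRIS | SPR | DrQ | CURL | SimPLe |
|---|---|---|---|---|---|
| Mean HNS | 1.046 | 0.616 | 0.465 | 0.261 | 0.332 |
| Superhuman games | 10/26 | 6/26 | 3/26 | 2/26 | 1/26 |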
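The human-normalized score rescales each game so that a random policy scores 0 and the human reference scores 1: HNS = (agent − random) / (human − random). A small sketch of the two table metrics; the function names are hypothetical, and the scores in the usage check are toy values, not real Atari results.

```python
from statistics import mean
from typing import Dict, Tuple


def human_normalized_score(agent: float, random_score: float, human: float) -> float:
    """0 matches a random policy, 1 matches the human reference."""
    return (agent - random_score) / (human - random_score)


def summarize(scores: Dict[str, Tuple[float, float, float]]) -> Tuple[float, int]:
    """`scores` maps game -> (agent, random, human); returns mean HNS and the superhuman game count."""
    hns = [human_normalized_score(*triple) for triple in scores.values()]
    superhuman = sum(h > 1.0 for h in hns)  # agent outscores the human reference
    return mean(hns), superhuman
```

Because the mean is sensitive to a few very high per-game scores, a mean HNS above 1 does not imply superhuman play on most games; the table's 10/26 count makes that visible.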

Usage

Training

# Single game
python -m world_models.training.train_iris --game "ALE/Pong-v5"

# Benchmark
python -m world_models.benchmarks.cli --agent iris --game "ALE/Pong-v5" \
  --checkpoint path/to/iris.pt --seeds 5

Configuration

from torchwm import IRISConfig

config = IRISConfig()

# Autoencoder
config.vocab_size = 512
config.tokens_per_frame = 16

# Transformer
config.transformer_layers = 10
config.transformer_embed_dim = 256

# Training
config.total_epochs = 600
config.env_steps_per_epoch = 200

References

  • Micheli, Vincent, Eloi Alonso, and François Fleuret. “Transformers are Sample-Efficient World Models.” ICLR 2023.