# Source code for world_models.configs.genie_config
from dataclasses import dataclass
from typing import Literal
[docs]
@dataclass
class GenieConfig:
    """Default hyperparameters for the Genie model.

    Every field has a default, so ``GenieConfig()`` yields a complete
    configuration. Fields are grouped below by their name prefix; the
    declaration order is part of the positional ``__init__`` signature
    and must be preserved.
    """

    # --- Input dimensions ---
    num_frames: int = 8
    image_size: int = 32
    in_channels: int = 3

    # --- ``tokenizer_*`` settings ---
    tokenizer_vocab_size: int = 1024
    tokenizer_embedding_dim: int = 32
    tokenizer_encoder_dim: int = 256
    tokenizer_decoder_dim: int = 512
    tokenizer_encoder_depth: int = 4
    tokenizer_decoder_depth: int = 8
    tokenizer_num_heads: int = 8

    # --- ``action_*`` settings ---
    action_vocab_size: int = 8
    action_embedding_dim: int = 32
    action_encoder_dim: int = 256
    action_encoder_depth: int = 4
    action_num_heads: int = 8
    # Only the two pooling modes named in the Literal are accepted by
    # type checkers; nothing enforces this at runtime.
    action_pooling: Literal["mean", "windowed_attention"] = "mean"
    # Presumably only consulted when action_pooling is
    # "windowed_attention" -- TODO confirm against the model code.
    window_attention_heads: int = 1

    # --- ``dynamics_*`` settings ---
    dynamics_dim: int = 512
    dynamics_depth: int = 8
    dynamics_num_heads: int = 8

    # --- Optimisation schedule ---
    batch_size: int = 4
    learning_rate: float = 3e-5
    weight_decay: float = 1e-4
    warmup_steps: int = 5_000
    max_steps: int = 125_000

    # --- Masking probability bounds ---
    mask_prob_min: float = 0.5
    mask_prob_max: float = 1.0

    # --- Sampling ---
    sample_temperature: float = 2.0
    maskgit_steps: int = 25
[docs]
@dataclass
class GenieSmallConfig:
    """Small configuration for development/testing.

    Exposes the same attribute names, in the same order, as
    ``GenieConfig`` so either class can be handed to code that reads
    Genie hyperparameters. Every field has a default, so
    ``GenieSmallConfig()`` yields a complete configuration.
    """

    # --- Input dimensions ---
    num_frames: int = 16
    image_size: int = 64
    in_channels: int = 3

    # --- ``tokenizer_*`` settings ---
    tokenizer_vocab_size: int = 1024
    tokenizer_embedding_dim: int = 32
    tokenizer_encoder_dim: int = 256
    tokenizer_decoder_dim: int = 512
    tokenizer_encoder_depth: int = 4
    tokenizer_decoder_depth: int = 8
    tokenizer_num_heads: int = 8

    # --- ``action_*`` settings ---
    action_vocab_size: int = 8
    action_embedding_dim: int = 32
    action_encoder_dim: int = 512
    action_encoder_depth: int = 8
    action_num_heads: int = 8
    # Only the two pooling modes named in the Literal are accepted by
    # type checkers; nothing enforces this at runtime.
    action_pooling: Literal["mean", "windowed_attention"] = "mean"
    window_attention_heads: int = 1

    # --- ``dynamics_*`` settings ---
    dynamics_dim: int = 512
    dynamics_depth: int = 8
    dynamics_num_heads: int = 8

    # --- Optimisation schedule ---
    batch_size: int = 2
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4
    warmup_steps: int = 1000
    max_steps: int = 50000

    # --- Masking probability bounds ---
    mask_prob_min: float = 0.5
    mask_prob_max: float = 1.0

    # --- Sampling ---
    # These two fields were missing here although GenieConfig defines
    # them; without them, code reading ``config.sample_temperature`` or
    # ``config.maskgit_steps`` fails with AttributeError on the small
    # config. They are appended last (with GenieConfig's defaults) so
    # existing positional construction keeps working.
    sample_temperature: float = 2.0
    maskgit_steps: int = 25
[docs]
@dataclass
class VideoTokenizerConfig:
    """Default hyperparameters for the Video Tokenizer.

    Every field has a default, so ``VideoTokenizerConfig()`` yields a
    complete configuration. The declaration order is part of the
    positional ``__init__`` signature and must be preserved.
    """

    # --- Input dimensions ---
    num_frames: int = 16
    image_size: int = 64
    in_channels: int = 3

    # --- Encoder / decoder sizes ---
    encoder_dim: int = 512
    decoder_dim: int = 1_024
    encoder_depth: int = 12
    decoder_depth: int = 20
    num_heads: int = 16
    patch_size: int = 4

    # --- Codebook (grouping inferred from the field names) ---
    vocab_size: int = 1_024
    embedding_dim: int = 32
    use_ema: bool = False
    # Presumably only relevant when use_ema is True -- TODO confirm.
    ema_decay: float = 0.99
    commitment_weight: float = 0.25
[docs]
@dataclass
class LatentActionModelConfig:
    """Default hyperparameters for the Latent Action Model.

    Every field has a default, so ``LatentActionModelConfig()`` yields
    a complete configuration. The declaration order is part of the
    positional ``__init__`` signature and must be preserved.
    """

    # --- Input dimensions ---
    num_frames: int = 16
    image_size: int = 64
    in_channels: int = 3

    # --- Encoder sizes ---
    encoder_dim: int = 1_024
    encoder_depth: int = 20
    num_heads: int = 16
    patch_size: int = 16

    # --- Codebook (grouping inferred from the field names) ---
    vocab_size: int = 8
    embedding_dim: int = 32
    commitment_weight: float = 1.0

    # --- Pooling ---
    # Only the two modes named in the Literal are accepted by type
    # checkers; nothing enforces this at runtime.
    action_pooling: Literal["mean", "windowed_attention"] = "mean"
    # Presumably only consulted when action_pooling is
    # "windowed_attention" -- TODO confirm against the model code.
    window_attention_heads: int = 1
[docs]
@dataclass
class DynamicsModelConfig:
    """Default hyperparameters for the Dynamics Model.

    Every field has a default, so ``DynamicsModelConfig()`` yields a
    complete configuration. The declaration order is part of the
    positional ``__init__`` signature and must be preserved.
    """

    # --- Input dimensions ---
    num_frames: int = 16
    image_size: int = 64

    # --- Vocabulary / embedding sizes ---
    vocab_size: int = 1_024
    embedding_dim: int = 32
    action_vocab_size: int = 8

    # --- Transformer sizes ---
    # NOTE(review): 5120 is not an integer multiple of 36 (5120 / 36 is
    # about 142.2). If the attention implementation splits ``dim`` evenly
    # across ``num_heads``, these defaults cannot be used together --
    # confirm against the model code before instantiating with them.
    dim: int = 5_120
    depth: int = 48
    num_heads: int = 36
    patch_size: int = 4
    mlp_ratio: float = 4.0
    qkv_bias: bool = True

    # --- Regularisation rates (all disabled by default) ---
    drop_rate: float = 0.0
    attn_drop_rate: float = 0.0
    drop_path_rate: float = 0.0