Source code for world_models.configs.genie_config

from dataclasses import dataclass
from typing import Literal


[docs] @dataclass class GenieConfig: """Configuration for Genie model.""" num_frames: int = 8 image_size: int = 32 in_channels: int = 3 tokenizer_vocab_size: int = 1024 tokenizer_embedding_dim: int = 32 tokenizer_encoder_dim: int = 256 tokenizer_decoder_dim: int = 512 tokenizer_encoder_depth: int = 4 tokenizer_decoder_depth: int = 8 tokenizer_num_heads: int = 8 action_vocab_size: int = 8 action_embedding_dim: int = 32 action_encoder_dim: int = 256 action_encoder_depth: int = 4 action_num_heads: int = 8 action_pooling: Literal["mean", "windowed_attention"] = "mean" window_attention_heads: int = 1 dynamics_dim: int = 512 dynamics_depth: int = 8 dynamics_num_heads: int = 8 batch_size: int = 4 learning_rate: float = 3e-5 weight_decay: float = 1e-4 warmup_steps: int = 5000 max_steps: int = 125000 mask_prob_min: float = 0.5 mask_prob_max: float = 1.0 sample_temperature: float = 2.0 maskgit_steps: int = 25
[docs] @dataclass class GenieSmallConfig: """Small configuration for development/testing.""" num_frames: int = 16 image_size: int = 64 in_channels: int = 3 tokenizer_vocab_size: int = 1024 tokenizer_embedding_dim: int = 32 tokenizer_encoder_dim: int = 256 tokenizer_decoder_dim: int = 512 tokenizer_encoder_depth: int = 4 tokenizer_decoder_depth: int = 8 tokenizer_num_heads: int = 8 action_vocab_size: int = 8 action_embedding_dim: int = 32 action_encoder_dim: int = 512 action_encoder_depth: int = 8 action_num_heads: int = 8 action_pooling: Literal["mean", "windowed_attention"] = "mean" window_attention_heads: int = 1 dynamics_dim: int = 512 dynamics_depth: int = 8 dynamics_num_heads: int = 8 batch_size: int = 2 learning_rate: float = 1e-4 weight_decay: float = 1e-4 warmup_steps: int = 1000 max_steps: int = 50000 mask_prob_min: float = 0.5 mask_prob_max: float = 1.0
[docs] @dataclass class STTransformerConfig: """Configuration for Spatiotemporal Transformer.""" num_frames: int = 16 num_patches_per_frame: int = 256 dim: int = 768 depth: int = 12 num_heads: int = 12 mlp_ratio: float = 4.0 qkv_bias: bool = True drop_rate: float = 0.0 attn_drop_rate: float = 0.0 drop_path_rate: float = 0.0
[docs] @dataclass class VideoTokenizerConfig: """Configuration for Video Tokenizer.""" num_frames: int = 16 image_size: int = 64 in_channels: int = 3 encoder_dim: int = 512 decoder_dim: int = 1024 encoder_depth: int = 12 decoder_depth: int = 20 num_heads: int = 16 patch_size: int = 4 vocab_size: int = 1024 embedding_dim: int = 32 use_ema: bool = False ema_decay: float = 0.99 commitment_weight: float = 0.25
[docs] @dataclass class LatentActionModelConfig: """Configuration for Latent Action Model.""" num_frames: int = 16 image_size: int = 64 in_channels: int = 3 encoder_dim: int = 1024 encoder_depth: int = 20 num_heads: int = 16 patch_size: int = 16 vocab_size: int = 8 embedding_dim: int = 32 commitment_weight: float = 1.0 action_pooling: Literal["mean", "windowed_attention"] = "mean" window_attention_heads: int = 1
[docs] @dataclass class DynamicsModelConfig: """Configuration for Dynamics Model.""" num_frames: int = 16 image_size: int = 64 vocab_size: int = 1024 embedding_dim: int = 32 action_vocab_size: int = 8 dim: int = 5120 depth: int = 48 num_heads: int = 36 patch_size: int = 4 mlp_ratio: float = 4.0 qkv_bias: bool = True drop_rate: float = 0.0 attn_drop_rate: float = 0.0 drop_path_rate: float = 0.0