# Configs Reference

This page documents all configuration classes in TorchWM.

## DreamerConfig

Configuration for Dreamer agent training.

@dataclass
class DreamerConfig:
    """Configuration for Dreamer agent training.

    Groups environment, model, optimization, exploration, and logging
    hyperparameters. Instances are plain mutable dataclasses: construct
    with defaults and override fields by attribute assignment.
    """

    # Environment
    env_backend: str = "dmc"
    env: str = "walker-walk"
    env_instance: Optional[object] = None  # pre-built env object, overrides `env` name if set — TODO confirm against loader
    image_size: Tuple[int, int] = (64, 64)
    gym_render_mode: str = "rgb_array"

    # Unity ML-Agents (only used when env_backend == "unity_mlagents")
    unity_file_name: Optional[str] = None
    unity_behavior_name: Optional[str] = None
    unity_worker_id: int = 0
    unity_base_port: int = 5005
    unity_no_graphics: bool = True
    unity_time_scale: float = 20.0
    unity_quality_level: int = 1

    # Training
    algo: str = "Dreamerv1"
    exp_name: str = "lr1e-3"
    train: bool = True
    evaluate: bool = False
    seed: int = 1
    no_gpu: bool = False
    max_episode_length: int = 1000
    buffer_size: int = 800000
    time_limit: int = 1000

    # Model
    cnn_activation_function: str = "relu"
    dense_activation_function: str = "elu"
    obs_embed_size: int = 1024
    num_units: int = 400
    deter_size: int = 200
    stoch_size: int = 30
    action_repeat: int = 2
    action_noise: float = 0.3

    # Learning
    model_lr: float = 6e-4
    actor_lr: float = 8e-5
    value_lr: float = 8e-5
    total_steps: int = int(5e6)
    seed_steps: int = 5000
    update_steps: int = 100
    collect_steps: int = 1000
    batch_size: int = 50
    train_seq_len: int = 50
    imagine_horizon: int = 15

    # Loss
    free_nats: float = 3.0
    kl_scale: float = 1.0
    discount: float = 0.99
    td_lambda: float = 0.95
    kl_loss_coeff: float = 1.0
    kl_alpha: float = 0.8
    disc_loss_coeff: float = 10.0

    # Optimization
    adam_epsilon: float = 1e-7
    grad_clip_norm: float = 100.0

    # Exploration
    expl_amount: float = 0.3
    expl_decay: float = 0.0
    expl_min: float = 0.0

    # Logging
    # NOTE: these fields are annotated `int`, so the defaults must be integer
    # literals — `1e4`/`1e5` are floats and would contradict the annotation.
    log_every: int = 10_000
    log_scalars: bool = True
    log_images: bool = True
    log_videos: bool = False
    save_every: int = 100_000
    eval_every: int = 10_000
    eval_episodes: int = 10

    # WandB
    enable_wandb: bool = False
    wandb_api_key: str = ""
    wandb_project: str = "torchwm"
    wandb_entity: str = ""
    log_dir: str = "runs"

    # Operator parameters
    operator_image_size: int = 64
    operator_action_dim: int = 6

## JEPAConfig

Configuration for JEPA training.

@dataclass
class JEPAConfig:
    """Configuration for JEPA training.

    Covers model selection, ImageNet-style data loading, masking,
    and optimization hyperparameters.
    """

    # Meta
    use_bfloat16: bool = False
    model_name: str = "vit_base"
    load_checkpoint: bool = False
    read_checkpoint: Optional[str] = None
    copy_data: bool = False
    pred_depth: int = 6
    pred_emb_dim: int = 384

    # Data
    dataset: str = "imagenet"
    val_split: Optional[float] = None
    use_gaussian_blur: bool = True
    use_horizontal_flip: bool = True
    use_color_distortion: bool = True
    color_jitter_strength: float = 0.5
    batch_size: int = 64
    pin_mem: bool = True
    num_workers: int = 8
    # NOTE: environment lookup happens once, at class-definition time,
    # not per instance.
    root_path: str = os.environ.get("IMAGENET_ROOT", "/data/imagenet")
    image_folder: str = "train"
    crop_size: int = 224
    crop_scale: Tuple[float, float] = (0.67, 1.0)
    download: bool = False

    # Mask
    allow_overlap: bool = False
    patch_size: int = 16
    num_enc_masks: int = 1
    min_keep: int = 4
    enc_mask_scale: Tuple[float, float] = (0.15, 0.2)
    num_pred_masks: int = 1
    pred_mask_scale: Tuple[float, float] = (0.15, 0.2)
    aspect_ratio: Tuple[float, float] = (0.75, 1.5)

    # Optimization
    ema: Tuple[float, float] = (0.996, 1.0)
    clip_grad: float = 1.0
    weight_decay: float = 0.0
    epochs: int = 100
    warmup_epochs: int = 40
    start_epoch: int = 0
    lr: float = 1.5e-4
    min_lr: float = 1e-6
    accum_iter: int = 1
    output_dir: str = "./output"
    log_dir: str = "./output"
    device: str = "cuda"
    seed: int = 0
    resume: str = ""
    auto_resume: bool = True
    save_ckpt: bool = True
    save_ckpt_freq: int = 20
    save_ckpt_num: int = 3
    start_ckpt: str = ""
    debug: bool = False
    num_debug: int = 10
    wandb: bool = True
    wandb_project: str = "jepa"
    wandb_entity: str = ""
    enable_sweep: bool = False
    # A bare `{}` default is a mutable default, which `dataclass` rejects
    # with ValueError at class creation; use a per-instance factory instead.
    sweep_config: Dict[str, Any] = field(default_factory=dict)
    dist_eval: bool = True
    no_env: bool = False
    eval: bool = False
    disable_rel_pos_bias: bool = True
    disable_masking: bool = False
    enable_deepspeed: bool = False
    gradient_checkpointing: bool = True

    # Operator parameters
    operator_image_size: int = 224
    operator_patch_size: int = 16
    operator_mask_ratio: float = 0.75

## IRISConfig

Configuration for IRIS training.

@dataclass
class IRISConfig:
    """Configuration for IRIS training.

    Bundles the hyperparameters for the three IRIS components — the
    discrete autoencoder, the transformer world model, and the
    actor-critic — plus training-schedule and logging settings.
    """

    # --- Discrete autoencoder (VQ-VAE) ---
    frame_height: int = 64
    frame_width: int = 64
    frame_channels: int = 3
    vocab_size: int = 512
    tokens_per_frame: int = 16
    token_embedding_dim: int = 512
    encoder_channels: int = 64
    encoder_layers: int = 4
    encoder_residual_blocks: int = 2
    decoder_depth: int = 32
    reconstruction_weight: float = 1.0
    commitment_weight: float = 0.25
    perceptual_weight: float = 1.0

    # --- Transformer world model ---
    transformer_timesteps: int = 20
    transformer_embed_dim: int = 256
    transformer_layers: int = 10
    transformer_heads: int = 4
    transformer_dropout: float = 0.1

    # --- Actor-critic ---
    imagination_horizon: int = 15
    discount: float = 0.99
    td_lambda: float = 0.9
    entropy_coef: float = 1e-2
    actor_hidden_size: int = 512
    actor_layers: int = 4
    value_hidden_size: int = 512
    value_layers: int = 3

    # --- Training schedule ---
    total_epochs: int = 600
    collection_epochs: int = 500
    env_steps_per_epoch: int = 200
    training_steps_per_epoch: int = 250
    model_learning_rate: float = 0.0001
    actor_learning_rate: float = 0.0001
    value_learning_rate: float = 0.0001
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    weight_decay: float = 1e-2
    max_grad_norm: float = 10.0
    collect_epsilon: float = 0.1
    eval_temperature: float = 0.1
    # Component warm-up: each part starts training after this many epochs.
    start_autoencoder_after: int = 1
    start_transformer_after: int = 15
    start_actor_critic_after: int = 35
    autoencoder_batch_size: int = 256
    transformer_batch_size: int = 64
    actor_critic_batch_size: int = 64

    # --- Atari 100k benchmark ---
    atari_100k: bool = True
    max_env_steps: int = 100_000
    env_backend: str = "gym"
    env_name: str = "ALE/Pong-v5"
    action_repeat: int = 4
    max_episode_steps: int = 108_000
    num_eval_episodes: int = 10
    eval_max_episode_steps: int = 27_000
    buffer_capacity: int = 100_000
    num_workers: int = 4
    prefetch_factor: int = 2

    # --- Logging ---
    log_dir: str = "./logs"
    checkpoint_dir: str = "./checkpoints"
    use_wandb: bool = True
    wandb_project: str = "iris"
    wandb_entity: str = ""

    # --- Operator parameters ---
    operator_seq_length: int = 512
    operator_vocab_size: int = 32_000

## DiamondConfig

Configuration for Diamond (Diffusion + RL) training.

@dataclass
class DiamondConfig:
    """Configuration for Diamond (Diffusion + RL) training.

    List-valued fields use `field(default_factory=...)` so every instance
    gets its own copy rather than sharing one mutable default.
    """

    # Preset
    preset: Optional[str] = None  # one of "small", "medium", "large"

    # Environment
    game: str = "Breakout-v5"
    obs_size: int = 64
    frameskip: int = 4
    max_noop: int = 30
    terminate_on_life_loss: bool = True
    reward_clip: List[int] = field(default_factory=lambda: [-1, 0, 1])
    num_conditioning_frames: int = 4

    # Diffusion model
    diffusion_channels: List[int] = field(default_factory=lambda: [64] * 4)
    diffusion_res_blocks: int = 2
    diffusion_cond_dim: int = 256
    sigma_data: float = 0.5
    sigma_min: float = 0.002
    sigma_max: float = 80.0
    rho: int = 7
    p_mean: float = -0.4
    p_std: float = 1.2
    sampling_method: str = "euler"
    num_sampling_steps: int = 3

    # Reward / termination model
    reward_channels: List[int] = field(default_factory=lambda: [32] * 4)
    reward_res_blocks: int = 2
    reward_cond_dim: int = 128
    reward_lstm_dim: int = 512
    burn_in_length: int = 4

    # RL agent
    actor_channels: List[int] = field(default_factory=lambda: [32, 32, 64, 64])
    actor_res_blocks: int = 1
    actor_lstm_dim: int = 512

    # Training
    num_epochs: int = 1000
    training_steps_per_epoch: int = 400
    batch_size: int = 32
    environment_steps_per_epoch: int = 100
    epsilon_greedy: float = 0.01
    imagination_horizon: int = 15
    discount_factor: float = 0.985
    entropy_weight: float = 0.001
    lambda_returns: float = 0.95
    learning_rate: float = 0.0001
    adam_epsilon: float = 1e-8
    weight_decay_diffusion: float = 0.01
    weight_decay_reward: float = 0.01
    weight_decay_actor: float = 0.0

    # Device
    # NOTE: resolved once, when the class is defined — not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Logging
    log_interval: int = 10
    eval_interval: int = 50
    save_interval: int = 100
    num_seeds: int = 5
    seed: int = 0

    # Operator parameters
    operator_state_dim: int = 32
    operator_action_dim: int = 4

## Usage Patterns

### Basic Configuration

from world_models.configs import DreamerConfig

# Construct with defaults, then override individual fields by attribute
# assignment (the config dataclasses are plain mutable dataclasses).
cfg = DreamerConfig()
cfg.env = "walker-walk"
cfg.total_steps = 1_000_000

### Environment-Specific Configs

# DMC (DeepMind Control Suite), e.g. "walker-walk"
cfg.env_backend = "dmc"
cfg.env = "walker-walk"

# Gym: `env` is a registered Gym environment id
cfg.env_backend = "gym"
cfg.env = "Pendulum-v1"

# Unity ML-Agents: set the unity_* fields of DreamerConfig as needed
cfg.env_backend = "unity_mlagents"
cfg.unity_file_name = "env.exe"

### Training Configs

# Basic training
cfg.batch_size = 50
# DreamerConfig exposes `model_lr` / `actor_lr` / `value_lr`;
# there is no `learning_rate` field.
cfg.model_lr = 6e-4
cfg.total_steps = 5_000_000

# Logging
cfg.enable_wandb = True
cfg.wandb_project = "my_project"

# Checkpointing
cfg.save_every = 100_000

### Advanced Configs

# Custom model sizes (defaults: obs_embed_size=1024, num_units=400,
# deter_size=200)
cfg.obs_embed_size = 2048
cfg.num_units = 600
cfg.deter_size = 300

# Exploration (defaults: expl_amount=0.3, expl_decay=0.0)
cfg.expl_amount = 0.5
cfg.expl_decay = 0.0001

# Loss weights (defaults: kl_scale=1.0, free_nats=3.0)
cfg.kl_scale = 0.1
cfg.free_nats = 1.0