# Source code for world_models.configs.dit_config

from dataclasses import dataclass, replace
from typing import Any

from world_models.configs.serialization import SerializableConfigMixin

_SNAKE_TO_UPPER = {
    "dataset": "DATASET",
    "batch": "BATCH",
    "epochs": "EPOCHS",
    "lr": "LR",
    "img_size": "IMG_SIZE",
    "channels": "CHANNELS",
    "patch_size": "PATCH",
    "width": "WIDTH",
    "depth": "DEPTH",
    "heads": "HEADS",
    "drop": "DROP",
    "beta_start": "BETA_START",
    "beta_end": "BETA_END",
    "timesteps": "TIMESTEPS",
    "ema": "EMA",
    "ema_decay": "EMA_DECAY",
    "workdir": "WORKDIR",
    "root_path": "ROOT_PATH",
}


@dataclass
class DiTConfig(SerializableConfigMixin):
    """Default hyperparameters for training a Diffusion Transformer (DiT).

    The settings are consumed by the built-in training entrypoints and cover
    dataset selection, the transformer architecture, the diffusion schedule,
    optimization hyperparameters, and filesystem paths.

    Field names are UPPER_CASE for backward compatibility with the original
    DiT codebase. Every field can also be read *and assigned* through a
    snake_case alias (for example ``cfg.batch`` or ``cfg.patch_size = 8``);
    the aliases are listed in ``_SNAKE_TO_UPPER``. ``get_dit_config()``
    accepts the same aliases as override keywords.
    """

    # Dataset selection
    DATASET: str = "CIFAR10"

    # Optimization
    BATCH: int = 128
    EPOCHS: int = 3
    LR: float = 2e-4

    # Image geometry
    IMG_SIZE: int = 32
    CHANNELS: int = 3

    # Transformer architecture
    PATCH: int = 4
    WIDTH: int = 384
    DEPTH: int = 6
    HEADS: int = 6
    DROP: float = 0.1

    # Diffusion schedule
    BETA_START: float = 1e-4
    BETA_END: float = 0.02
    TIMESTEPS: int = 1000

    # EMA settings
    EMA: bool = True
    EMA_DECAY: float = 0.999

    # Filesystem paths
    WORKDIR: str = "./dit_demo"
    ROOT_PATH: str = "./data"

    def __getattr__(self, name: str) -> Any:
        # Python calls this only after normal lookup has failed, so real
        # fields never reach here. Known snake_case aliases are forwarded to
        # their UPPER_CASE field; anything else is a genuine missing attribute.
        if name not in _SNAKE_TO_UPPER:
            raise AttributeError(f"{type(self).__name__!r} has no attribute {name!r}")
        return getattr(self, _SNAKE_TO_UPPER[name])

    def __setattr__(self, name: str, value: Any) -> None:
        # Assignments through an alias are stored on the canonical field, so
        # the two spellings can never diverge. Other names are stored as given.
        target = _SNAKE_TO_UPPER[name] if name in _SNAKE_TO_UPPER else name
        super().__setattr__(target, value)
def get_dit_config(**overrides: Any) -> DiTConfig:
    """Return a ``DiTConfig`` whose defaults are replaced by ``overrides``.

    Override keys may be canonical UPPER_CASE field names or their snake_case
    aliases from ``_SNAKE_TO_UPPER``. Keys that are neither are passed through
    unchanged, so the dataclass constructor invoked by
    ``dataclasses.replace`` rejects them.

    Example usage:
        cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)
        cfg = get_dit_config(batch=64, epochs=10, lr=1e-3)

    Raises:
        TypeError: if two keys name the same field (e.g. both ``batch`` and
            ``BATCH``), or if a key is not a ``DiTConfig`` field or alias.
    """
    translated = {}
    # Remembers which caller-supplied key produced each field, for the error message.
    origin = {}
    for key, value in overrides.items():
        field_name = _SNAKE_TO_UPPER.get(key, key)
        if field_name in translated:
            # An alias and its canonical name were both given. Previously the
            # later one silently won. Surface the ambiguity the same way Python
            # reports a duplicated keyword argument.
            raise TypeError(
                f"get_dit_config() got multiple values for field {field_name!r} "
                f"(from {origin[field_name]!r} and {key!r})"
            )
        translated[field_name] = value
        origin[field_name] = key
    return replace(DiTConfig(), **translated)