Source code for world_models.configs.dit_config
from dataclasses import dataclass, replace
[docs]
@dataclass
class DiTConfig:
    """Default configuration values for Diffusion Transformer (DiT) training.

    The fields define dataset selection, model architecture, diffusion schedule,
    optimization hyperparameters, and output paths used by the built-in
    training entrypoints.

    A few structural invariants are checked when an instance is created. This
    also applies to instances produced by ``dataclasses.replace``, which calls
    ``__post_init__``. An inconsistent override therefore fails immediately
    with a clear message instead of deep inside model construction or
    diffusion-schedule setup.

    Raises:
        ValueError: If ``PATCH``, ``HEADS`` or ``TIMESTEPS`` is not positive,
            if ``IMG_SIZE`` is not divisible by ``PATCH``, if ``WIDTH`` is not
            divisible by ``HEADS``, or if ``BETA_START`` / ``BETA_END`` lies
            outside the open interval (0, 1).
    """

    # Dataset selection. NOTE(review): the set of supported dataset names is
    # defined by the training entrypoints, not here.
    DATASET: str = "CIFAR10"

    # Optimization hyperparameters.
    BATCH: int = 128
    EPOCHS: int = 3
    LR: float = 2e-4

    # Input image geometry.
    IMG_SIZE: int = 32
    CHANNELS: int = 3

    # Model architecture.
    PATCH: int = 4
    WIDTH: int = 384
    DEPTH: int = 6
    HEADS: int = 6
    DROP: float = 0.1

    # Diffusion schedule.
    BETA_START: float = 1e-4
    BETA_END: float = 0.02
    TIMESTEPS: int = 1000

    # Presumably an exponential moving average of model weights; confirm in
    # the trainer.
    EMA: bool = True
    EMA_DECAY: float = 0.999

    # Output / data paths.
    WORKDIR: str = "./dit_demo"
    ROOT_PATH: str = "./data"

    def __post_init__(self) -> None:
        """Validate cross-field invariants.

        Runs after ``__init__`` and after ``dataclasses.replace``.
        """
        # Check positivity first so the modulo checks below cannot divide by zero.
        for name in ("PATCH", "HEADS", "TIMESTEPS"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be a positive integer, got {value!r}")

        # Patchification needs the image side to split into a whole number of patches.
        if self.IMG_SIZE % self.PATCH:
            raise ValueError(
                f"IMG_SIZE ({self.IMG_SIZE}) must be divisible by PATCH ({self.PATCH})"
            )

        # Multi-head attention splits the hidden width evenly across heads.
        if self.WIDTH % self.HEADS:
            raise ValueError(
                f"WIDTH ({self.WIDTH}) must be divisible by HEADS ({self.HEADS})"
            )

        # Diffusion variances must lie strictly between 0 and 1 so that
        # 1 - beta is a valid, non-degenerate retention factor.
        for name in ("BETA_START", "BETA_END"):
            value = getattr(self, name)
            if not 0.0 < value < 1.0:
                raise ValueError(f"{name} must lie in (0, 1), got {value!r}")
[docs]
def get_dit_config(**overrides):
    """Build a ``DiTConfig`` from the defaults, with keyword overrides applied.

    Each keyword must name a ``DiTConfig`` field. An unknown name raises
    ``TypeError`` from the dataclass constructor.

    Example:
        cfg = get_dit_config(BATCH=64, EPOCHS=10, LR=1e-3)

    Args:
        **overrides: Field values that take the place of the declared defaults.

    Returns:
        A new ``DiTConfig`` instance.
    """
    # Fields not mentioned in ``overrides`` fall back to their declared
    # defaults, which yields the same instance as replacing fields on a
    # default-constructed config.
    config = DiTConfig(**overrides)
    return config