Source code for world_models.utils.train_utils

"""Training utilities for World Models.

This module provides utility classes for training neural networks including
early stopping and learning rate scheduling.
"""

from typing import Any
from functools import partial
from torch.optim import Optimizer


[docs] class EarlyStopping: """Early stopping handler to stop training when validation metric stops improving. This class monitors a validation metric and stops training when no improvement is seen for a specified number of epochs (patience). This helps prevent overfitting and reduces unnecessary computation. Args: mode: One of 'min' or 'max'. In 'min' mode, training stops when the metric stops decreasing; in 'max' mode, when it stops increasing. patience: Number of epochs with no improvement after which to stop training. threshold: Minimum change to qualify as an improvement. threshold_mode: One of 'rel' or 'abs'. In 'rel' mode, dynamic threshold is relative to best value; in 'abs' mode, it's absolute. Attributes: stop: Property that returns True if training should stop. Example: >>> early_stopping = EarlyStopping(mode='min', patience=10) >>> for epoch in range(100): ... val_loss = validate() ... early_stopping.step(val_loss) ... if early_stopping.stop: ... print(f"Stopped at epoch {epoch}") ... break """ def __init__( self, mode: str = "min", patience: int = 10, threshold: float = 1e-4, threshold_mode: str = "rel", ) -> None: self.patience = patience self.mode = mode self.threshold = threshold self.threshold_mode = threshold_mode self.last_epoch = -1 self._init_is_better(mode, threshold, threshold_mode) self._reset() def _reset(self) -> None: """Reset the internal state for a new training run.""" self.best: float = self.mode_worse self.num_bad_epochs: int = 0
[docs] def step(self, metrics: float, epoch: int | None = None) -> None: """Update early stopping state with new metric value. Args: metrics: Current epoch's metric value. epoch: Current epoch number. If None, auto-increments from last epoch. """ current = metrics if epoch is None: epoch = self.last_epoch = self.last_epoch + 1 self.last_epoch = epoch if self.is_better(current, self.best): self.best = current self.num_bad_epochs = 0 else: self.num_bad_epochs += 1
@property def stop(self) -> bool: """bool: True if training should stop due to no improvement.""" return self.num_bad_epochs > self.patience def _cmp( self, mode: str, threshold_mode: str, threshold: float, a: float, best: float ) -> bool: """Compare two values based on mode and threshold settings.""" if mode == "min" and threshold_mode == "rel": rel_epsilon = 1.0 - threshold return a < best * rel_epsilon elif mode == "min" and threshold_mode == "abs": return a < best - threshold elif mode == "max" and threshold_mode == "rel": rel_epsilon = 1.0 + threshold return a > best * rel_epsilon return a > best + threshold def _init_is_better(self, mode: str, threshold: float, threshold_mode: str) -> None: """Initialize the comparison function.""" if mode not in {"min", "max"}: raise ValueError("mode " + mode + " is unknown!") if threshold_mode not in {"rel", "abs"}: raise ValueError("threshold mode " + threshold_mode + " is unknown!") if mode == "min": self.mode_worse = float("inf") else: self.mode_worse = -float("inf") self.is_better = partial(self._cmp, mode, threshold_mode, threshold)
[docs] def state_dict(self) -> dict: """Get state dictionary for checkpointing. Returns: Dictionary containing early stopping state. """ return { key: value for key, value in self.__dict__.items() if key not in ("is_better",) }
[docs] def load_state_dict(self, state_dict: dict) -> None: """Load state from checkpoint. Args: state_dict: Dictionary containing early stopping state. """ self.__dict__.update(state_dict) self._init_is_better(self.mode, self.threshold, self.threshold_mode)
[docs] class ReduceLROnPlateau: """Reduce learning rate when a metric stops improving. This scheduler reduces the learning rate by a factor when a validation metric stops improving for a specified number of epochs. This helps models converge better by reducing the step size as they approach optimal weights. Args: optimizer: The PyTorch optimizer to adjust. mode: One of 'min' or 'max'. In 'min' mode, lr is reduced when metric stops decreasing; in 'max' mode, when it stops increasing. factor: Factor by which to reduce the learning rate. patience: Number of epochs with no improvement after which to reduce lr. threshold: Minimum change to qualify as an improvement. threshold_mode: One of 'rel' or 'abs'. min_lr: Minimum learning rate to reduce to. eps: Minimum decay for lr. Attributes: lr: Current learning rates for each parameter group. Example: >>> optimizer = torch.optim.Adam(model.parameters(), lr=0.001) >>> scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5) >>> for epoch in range(100): ... train_loss = train() ... val_loss = validate() ... scheduler.step(val_loss) ... if scheduler.stop: ... break """ def __init__( self, optimizer: Optimizer, mode: str = "min", factor: float = 0.1, patience: int = 10, threshold: float = 1e-4, threshold_mode: str = "rel", min_lr: float = 0, eps: float = 1e-8, ) -> None: self.optimizer = optimizer self.factor = factor self.min_lr = min_lr self.eps = eps self.mode = mode self.patience = patience self.threshold = threshold self.threshold_mode = threshold_mode self.last_epoch = -1 self._init_is_better(mode, threshold, threshold_mode) self._reset() def _reset(self) -> None: """Reset the internal state for a new training run.""" self.best: float = self.mode_worse self.num_bad_epochs: int = 0
[docs] def step(self, metrics: float, epoch: int | None = None) -> None: """Update learning rate based on metric value. Args: metrics: Current epoch's metric value. epoch: Current epoch number. If None, auto-increments from last epoch. """ current = metrics if epoch is None: epoch = self.last_epoch = self.last_epoch + 1 self.last_epoch = epoch if self.is_better(current, self.best): self.best = current self.num_bad_epochs = 0 else: self.num_bad_epochs += 1 if self.num_bad_epochs > self.patience: self._reduce_lr() self.num_bad_epochs = 0
def _reduce_lr(self) -> None: """Reduce learning rate for all parameter groups.""" for i, param_group in enumerate(self.optimizer.param_groups): old_lr = float(param_group["lr"]) new_lr = max(old_lr * self.factor, self.min_lr) if old_lr - new_lr > self.eps: param_group["lr"] = new_lr @property def lr(self) -> list: """list: Current learning rates for each parameter group.""" return [param_group["lr"] for param_group in self.optimizer.param_groups] def _cmp( self, mode: str, threshold_mode: str, threshold: float, a: float, best: float ) -> bool: """Compare two values based on mode and threshold settings.""" if mode == "min" and threshold_mode == "rel": rel_epsilon = 1.0 - threshold return a < best * rel_epsilon elif mode == "min" and threshold_mode == "abs": return a < best - threshold elif mode == "max" and threshold_mode == "rel": rel_epsilon = 1.0 + threshold return a > best * rel_epsilon return a > best + threshold def _init_is_better(self, mode: str, threshold: float, threshold_mode: str) -> None: """Initialize the comparison function.""" if mode not in {"min", "max"}: raise ValueError("mode " + mode + " is unknown!") if threshold_mode not in {"rel", "abs"}: raise ValueError("threshold mode " + threshold_mode + " is unknown!") self.mode_worse: float if mode == "min": self.mode_worse = float("inf") else: self.mode_worse = -float("inf") self.is_better: Any = partial(self._cmp, mode, threshold_mode, threshold)
[docs] def state_dict(self) -> dict: """Get state dictionary for checkpointing. Returns: Dictionary containing scheduler state. """ return { key: value for key, value in self.__dict__.items() if key not in ("is_better",) }
[docs] def load_state_dict(self, state_dict: dict) -> None: """Load state from checkpoint. Args: state_dict: Dictionary containing scheduler state. """ self.__dict__.update(state_dict) self._init_is_better(self.mode, self.threshold, self.threshold_mode)