# Source code for world_models.utils.logging_utils

"""Logging, metrics, and numerical-safety helpers for torchwm."""

from __future__ import annotations

import functools
import importlib
import importlib.util
import json
import logging
import os
import time
from collections.abc import Mapping
from typing import Any, Generator, Optional

import torch

_PACKAGE_LOGGER_NAME = "world_models"


def get_package_logger(name: str | None = None) -> logging.Logger:
    """Resolve ``name`` to a logger inside the ``world_models`` namespace.

    Args:
        name: ``None`` or ``""`` selects the package root logger. A name that is
            already the package name, or starts with ``"world_models."``, is used
            verbatim. Any other name is nested under the package prefix.

    Returns:
        The ``logging.Logger`` for the resolved, fully qualified name.
    """
    root = _PACKAGE_LOGGER_NAME
    if not name:
        qualified = root
    elif name == root or name.startswith(root + "."):
        qualified = name
    else:
        qualified = f"{root}.{name}"
    return logging.getLogger(qualified)
def setup_logging(
    name: str = _PACKAGE_LOGGER_NAME,
    level: str | int = "INFO",
    log_file: Optional[str] = None,
    fmt: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
) -> logging.Logger:
    """Set up structured package logging with optional file output.

    The logger is reset on every call: existing handlers are removed *and
    closed*, so re-running this (e.g. once per experiment in a notebook) does
    not leak file descriptors. Propagation to parent loggers is disabled, so
    records are emitted only by the handlers installed here.

    Args:
        name: Logger name to configure. Defaults to the package logger; other
            names are nested under the package namespace.
        level: Logging level name (case-insensitive) or numeric value.
        log_file: Optional file path for a UTF-8 file handler. Missing parent
            directories are created.
        fmt: ``logging.Formatter`` format string shared by all handlers.

    Returns:
        The configured logger.

    Raises:
        ValueError: If ``level`` is a string that is not a known level name.
    """
    logger = get_package_logger(name)

    if isinstance(level, str):
        resolved_level = logging.getLevelName(level.upper())
        # getLevelName maps unknown names to the string "Level <name>".
        if isinstance(resolved_level, str):
            raise ValueError(f"Unknown logging level: {level}")
    else:
        resolved_level = level

    logger.setLevel(resolved_level)
    logger.propagate = False

    # Close as well as detach: an orphaned FileHandler keeps its file open.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    formatter = logging.Formatter(fmt)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    if log_file:
        directory = os.path.dirname(log_file)
        if directory:
            os.makedirs(directory, exist_ok=True)
        file_handler = logging.FileHandler(log_file, encoding="utf-8")
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger
def _to_scalar(value: Any) -> Any:
    """Convert tensors/NumPy values to JSON-serializable Python values.

    Single-element tensors and arrays become Python scalars. Multi-element ones
    become (nested) lists. Values without an ``item`` attribute are returned
    unchanged.
    """
    if isinstance(value, torch.Tensor):
        value = value.detach()
        if value.numel() == 1:
            return value.item()
        return value.cpu().tolist()
    if hasattr(value, "item"):
        try:
            return value.item()
        except ValueError:
            # NumPy's ``ndarray.item()`` raises ValueError unless size == 1.
            if hasattr(value, "tolist"):
                return value.tolist()
            raise
    return value


def _load_summary_writer() -> Any | None:
    """Return a TensorBoard ``SummaryWriter`` class, or ``None`` if unavailable.

    ``torch.utils.tensorboard`` is part of torch even when the ``tensorboard``
    package is not installed. In that case importing it raises ``ImportError``.
    Availability is therefore decided by actually importing each candidate
    (torch's writer first, then ``tensorboardX``), not by ``find_spec``.
    """
    for module_name in ("torch.utils.tensorboard", "tensorboardX"):
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue
        return module.SummaryWriter
    return None


def _prepare_tensorboard_video(video: Any) -> torch.Tensor:
    """Convert a video to TensorBoard's ``(N, T, C, H, W)`` float layout.

    Accepts ``(T,H,W,C)``, ``(T,C,H,W)``, ``(N,T,H,W,C)`` or ``(N,T,C,H,W)``.
    A 4-D input gets a leading batch dim. A trailing dim of size 1, 3 or 4 is
    treated as channels-last and moved to position 2. ``uint8`` input is
    rescaled by 1/255. The result is float32 clamped to ``[0, 1]``.

    Raises:
        ValueError: If the input is not 4-D or 5-D.
    """
    video_tensor = video if isinstance(video, torch.Tensor) else torch.as_tensor(video)
    video_tensor = video_tensor.detach().cpu()
    if video_tensor.ndim == 4:
        video_tensor = video_tensor.unsqueeze(0)
    if video_tensor.ndim != 5:
        raise ValueError(
            "TensorBoard videos must have shape (T,H,W,C), (T,C,H,W), "
            "(N,T,H,W,C), or (N,T,C,H,W)."
        )
    # Heuristic: a small trailing dim is assumed to be channels-last.
    if video_tensor.shape[-1] in (1, 3, 4):
        video_tensor = video_tensor.permute(0, 1, 4, 2, 3)
    if video_tensor.dtype == torch.uint8:
        video_tensor = video_tensor.to(torch.float32) / 255.0
    else:
        video_tensor = video_tensor.to(torch.float32)
    return torch.clamp(video_tensor, 0.0, 1.0)
class MetricsLogger:
    """Fan-out metric logger for console, JSONL, TensorBoard, and W&B.

    JSONL output is enabled by default because it is dependency-free and easy
    to reload for offline plots. TensorBoard and W&B are optional and activated
    only when requested and available/configured.

    The instance owns the JSONL file handle, TensorBoard writer and W&B run it
    creates. Release them with :meth:`close`, or use the logger as a context
    manager: ``with MetricsLogger(...) as metrics: ...``.
    """

    def __init__(
        self,
        log_dir: str,
        *,
        logger: logging.Logger | None = None,
        enable_console: bool = True,
        enable_jsonl: bool = True,
        jsonl_filename: str = "metrics.jsonl",
        enable_tensorboard: bool = False,
        enable_wandb: bool = False,
        wandb_api_key: str = "",
        wandb_project: str = "torchwm",
        wandb_entity: str = "",
        run_name: str | None = None,
    ) -> None:
        """Create ``log_dir`` and open every requested sink.

        Args:
            log_dir: Directory for the JSONL file, TensorBoard events and W&B
                files. Created if missing.
            logger: Logger for console output. Defaults to the package
                ``metrics`` logger.
            enable_console: Emit one ``logger.info`` line per :meth:`log` call.
            enable_jsonl: Append one JSON object per :meth:`log` call to
                ``os.path.join(log_dir, jsonl_filename)``.
            jsonl_filename: File name of the JSONL sink inside ``log_dir``.
            enable_tensorboard: Create a TensorBoard ``SummaryWriter`` if one
                is importable; otherwise only log a warning.
            enable_wandb: Start a W&B run.
            wandb_api_key: Key exported as ``WANDB_API_KEY``. Required when
                ``enable_wandb`` is True.
            wandb_project: W&B project name.
            wandb_entity: W&B entity. An empty string is passed as ``None``.
            run_name: W&B run name. Defaults to the basename of ``log_dir``.

        Raises:
            ValueError: ``enable_wandb`` is True and ``wandb_api_key`` is empty.
            ImportError: ``enable_wandb`` is True and ``wandb`` is not installed.
        """
        self.log_dir = log_dir
        self.logger = logger or get_package_logger("metrics")
        self.enable_console = enable_console
        self.enable_jsonl = enable_jsonl
        self.jsonl_path = os.path.join(log_dir, jsonl_filename)
        self._jsonl_file = None
        self._tb_writer = None
        self._wandb_run = None

        # Validate W&B configuration before touching the filesystem, so a bad
        # configuration leaves no directory, open file or writer behind.
        if enable_wandb:
            if not wandb_api_key:
                raise ValueError("WandB API key is required when enable_wandb is True")
            if importlib.util.find_spec("wandb") is None:
                raise ImportError("wandb is not installed")

        os.makedirs(log_dir, exist_ok=True)
        try:
            if self.enable_jsonl:
                self._jsonl_file = open(self.jsonl_path, "a", encoding="utf-8")

            if enable_tensorboard:
                summary_writer = _load_summary_writer()
                if summary_writer is None:
                    self.logger.warning("TensorBoard logging requested but unavailable")
                else:
                    self._tb_writer = summary_writer(log_dir=log_dir)

            if enable_wandb:
                wandb = importlib.import_module("wandb")
                os.environ["WANDB_API_KEY"] = wandb_api_key
                self._wandb_run = wandb.init(
                    project=wandb_project,
                    entity=wandb_entity or None,
                    dir=log_dir,
                    name=run_name or os.path.basename(log_dir),
                )
        except BaseException:
            # Release any sink opened before the failure, then propagate.
            self.close()
            raise

    def __enter__(self) -> MetricsLogger:
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        self.close()

    def log(
        self, metrics: Mapping[str, Any], step: int, prefix: str | None = None
    ) -> dict[str, Any]:
        """Log metrics to every enabled sink.

        Keys become ``"{prefix}/{key}"`` when ``prefix`` is given. Values are
        converted with ``_to_scalar``. TensorBoard only receives values that
        end up as ``int``/``float``; the console, JSONL and W&B get everything.

        Returns:
            The normalized ``{key: value}`` mapping that was logged.
        """
        normalized = {}
        for key, value in metrics.items():
            metric_key = f"{prefix}/{key}" if prefix else str(key)
            normalized[metric_key] = _to_scalar(value)

        if self.enable_console and normalized:
            formatted = ", ".join(f"{key}={value}" for key, value in normalized.items())
            self.logger.info("step=%s %s", step, formatted)

        if self._jsonl_file is not None:
            self._jsonl_file.write(
                json.dumps(
                    {"time": time.time(), "step": int(step), **normalized},
                    sort_keys=True,
                    default=str,
                )
                + "\n"
            )

        if self._tb_writer is not None:
            for key, value in normalized.items():
                if isinstance(value, (int, float)):
                    self._tb_writer.add_scalar(key, value, step)

        if self._wandb_run is not None:
            self._wandb_run.log(normalized, step=step)

        return normalized

    def log_video(self, name: str, video: Any, step: int, fps: int = 20) -> None:
        """Log a video to TensorBoard and W&B when enabled.

        TensorBoard receives the video converted by
        ``_prepare_tensorboard_video``. W&B receives the original layout; torch
        tensors are converted to NumPy first, because ``wandb.Video`` documents
        NumPy arrays (not torch tensors) as its array input.
        """
        if self._tb_writer is not None:
            self._tb_writer.add_video(
                name, _prepare_tensorboard_video(video), global_step=step, fps=fps
            )
        if self._wandb_run is not None:
            wandb = importlib.import_module("wandb")
            wandb_video = (
                video.detach().cpu().numpy() if isinstance(video, torch.Tensor) else video
            )
            self._wandb_run.log({name: wandb.Video(wandb_video, fps=fps)}, step=step)

    def flush(self) -> None:
        """Flush buffered JSONL and TensorBoard output."""
        if self._jsonl_file is not None:
            self._jsonl_file.flush()
        if self._tb_writer is not None:
            self._tb_writer.flush()

    def close(self) -> None:
        """Flush and release every sink, including the W&B run.

        Safe to call more than once.
        """
        self.flush()
        if self._jsonl_file is not None:
            self._jsonl_file.close()
            self._jsonl_file = None
        if self._tb_writer is not None:
            self._tb_writer.close()
            self._tb_writer = None
        if self._wandb_run is not None:
            self._wandb_run.finish()
            self._wandb_run = None
def collect_system_stats(device: torch.device | str | None = None) -> dict[str, float]:
    """Collect CPU/GPU memory and CUDA utilization counters when available.

    Every counter is best-effort:

    * CPU/RAM figures are reported only when ``psutil`` is importable.
      psutil documents that the first non-blocking ``cpu_percent`` call
      returns 0.0.
    * GPU memory figures are reported only when CUDA is available and
      ``device`` is ``None`` or a CUDA device.
    * GPU utilization is reported only when ``torch.cuda.utilization``
      succeeds. It relies on optional management bindings (pynvml, per the
      PyTorch docs) and is skipped if that call raises.

    Args:
        device: Device whose CUDA counters are reported. ``None``, or a CUDA
            device without an index, uses ``torch.cuda.current_device()``.

    Returns:
        Mapping of ``"system/..."`` metric names to floats. Memory values are
        in MiB.
    """
    stats: dict[str, float] = {}
    if importlib.util.find_spec("psutil") is not None:
        psutil = importlib.import_module("psutil")
        vm = psutil.virtual_memory()
        stats.update(
            {
                "system/cpu_percent": float(psutil.cpu_percent(interval=None)),
                "system/ram_used_mb": float(vm.used / (1024**2)),
                "system/ram_available_mb": float(vm.available / (1024**2)),
                "system/ram_percent": float(vm.percent),
            }
        )

    torch_device = torch.device(device) if device is not None else None
    if torch.cuda.is_available() and (
        torch_device is None or torch_device.type == "cuda"
    ):
        cuda_index = (
            torch_device.index
            if torch_device is not None and torch_device.index is not None
            else torch.cuda.current_device()
        )
        stats.update(
            {
                "system/gpu_memory_allocated_mb": float(
                    torch.cuda.memory_allocated(cuda_index) / (1024**2)
                ),
                "system/gpu_memory_reserved_mb": float(
                    torch.cuda.memory_reserved(cuda_index) / (1024**2)
                ),
                "system/gpu_max_memory_allocated_mb": float(
                    torch.cuda.max_memory_allocated(cuda_index) / (1024**2)
                ),
            }
        )
        if hasattr(torch.cuda, "utilization"):
            try:
                utilization = torch.cuda.utilization(cuda_index)
            except Exception:  # noqa: BLE001 - optional, backend-specific errors
                # The attribute always exists, but the call raises (e.g.
                # ModuleNotFoundError / RuntimeError / NVML errors) when the
                # management library or driver is unavailable. The counter is
                # optional, so it is omitted rather than failing collection.
                pass
            else:
                stats["system/gpu_utilization_percent"] = float(utilization)
    return stats
def _iter_tensors(value: Any) -> Generator[torch.Tensor, None, None]:
    """Yield every tensor found in ``value``, depth first.

    Recurses into mapping values and tuple/list elements, in iteration order.
    Any other non-tensor object contributes nothing.
    """
    if isinstance(value, torch.Tensor):
        yield value
        return
    if isinstance(value, Mapping):
        children = value.values()
    elif isinstance(value, (tuple, list)):
        children = value
    else:
        return
    for child in children:
        yield from _iter_tensors(child)
def assert_finite_values(value: Any, name: str = "value") -> Any:
    """Check that every tensor nested in ``value`` is free of NaN and Inf.

    Args:
        value: A tensor, or a mapping/tuple/list (possibly nested) containing
            tensors. Non-tensor leaves are ignored.
        name: Label used in the error message.

    Returns:
        ``value`` itself, unchanged, so the call can wrap an expression.

    Raises:
        FloatingPointError: On the first tensor with a non-finite element.
    """
    has_non_finite = any(
        not bool(torch.isfinite(tensor).all()) for tensor in _iter_tensors(value)
    )
    if has_non_finite:
        raise FloatingPointError(f"Non-finite tensor detected in {name}")
    return value
def assert_finite(fn: Any) -> Any:
    """Decorator that validates tensor outputs from loss functions are finite.

    The wrapper calls ``fn`` and passes the result through
    ``assert_finite_values``. It returns the result unchanged, or raises
    ``FloatingPointError`` if any tensor in it (including tensors nested in
    dicts/lists/tuples) contains NaN or Inf.

    Any callable is accepted, including ones without ``__name__`` such as
    ``functools.partial`` objects or callable instances.
    """
    # Resolve the error label once, at decoration time. Fall back lazily from
    # ``__qualname__`` to ``__name__`` to ``repr``, so callables lacking those
    # attributes still work.
    label = getattr(fn, "__qualname__", None) or getattr(fn, "__name__", None) or repr(fn)

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return assert_finite_values(fn(*args, **kwargs), label)

    return wrapper