Source code for world_models.utils.memory_utils
import torch
import torch.nn as nn
from typing import Optional
[docs]
def apply_gradient_checkpointing(model: nn.Module, checkpoint_ratio: float = 0.5):
"""Apply gradient checkpointing to reduce memory usage during training."""
# Some models expose gradient_checkpointing_enable as a callable while
# others may have an attribute with the same name that isn't callable.
# Use getattr and only call when it's actually callable to avoid mypy
# complaining about "Tensor" not callable at type-check time.
gc_enable = getattr(model, "gradient_checkpointing_enable", None)
if callable(gc_enable):
gc_enable()
else:
# For custom modules, apply selective checkpointing
for name, module in model.named_modules():
if isinstance(module, nn.TransformerEncoderLayer):
# Wrap the original forward in a callable that uses
# torch.utils.checkpoint.checkpoint. We capture the original
# method to avoid recursive lookup and assign a plain
# function to the instance attribute (allowed at runtime).
from typing import Callable, Any
orig_forward: Callable[..., Any] = module.forward # capture
def _checkpointed_forward(*args: Any, **kwargs: Any) -> Any:
# Torch's checkpoint API is present at runtime but some stubs
# do not expose it. Ignore attribute errors from type-checker
# here while preserving runtime behavior.
return torch.utils.checkpoint.checkpoint( # type: ignore[attr-defined]
orig_forward, *args, **kwargs, use_reentrant=False
)
# Assigning a function to an instance method is a runtime pattern
# used to wrap behavior; mypy may complain about assigning to
# a method attribute, so silence that specific check.
setattr(module, "forward", _checkpointed_forward) # type: ignore[assignment]
elif hasattr(module, "checkpoint_forward"):
# Create a wrapper that calls checkpoint at runtime. Do not
# call checkpoint here (that would execute the function and
# assign a Tensor to `forward`). Some torch stubs do not
# expose `utils.checkpoint`, so silence attribute checks.
from typing import Any
def _checkpointed_forward2(*args: Any, **kwargs: Any) -> Any:
# Use a targeted ignore for the missing `checkpoint` attr in
# some torch stubs while preserving runtime behaviour.
return torch.utils.checkpoint.checkpoint( # type: ignore[attr-defined]
module.checkpoint_forward, *args, **kwargs, use_reentrant=False
)
setattr(module, "forward", _checkpointed_forward2) # type: ignore[assignment]
[docs]
def enable_mixed_precision(
model: nn.Module, scaler: Optional[torch.amp.GradScaler] = None
):
"""Enable mixed precision training."""
if scaler is None:
scaler = torch.amp.GradScaler()
return scaler
[docs]
def optimize_memory_efficient_ops():
"""Set PyTorch for memory-efficient operations."""
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
if hasattr(torch, "set_float32_matmul_precision"):
torch.set_float32_matmul_precision("high")