Source code for world_models.layers.rms_norm
import torch
import torch.nn as nn
[docs]
class RMSNorm(nn.Module):
    """RMS layer normalization with a learnable per-feature gain.

    Unlike LayerNorm, the input is not mean-centered: every vector along the
    last axis is divided by its root-mean-square magnitude and then scaled
    elementwise by the learned gain ``g``.

    Args:
        dim: Length of the gain vector ``g``; normally the size of the
            input's last (feature) axis.
        eps: Small constant added to the RMS itself (after the square root)
            before dividing, so an all-zero input yields zeros instead of NaN.
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Gain starts at 1, so a freshly built module performs pure RMS scaling.
        self.g = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by its RMS over the last axis and apply the gain.

        Args:
            x: Input tensor whose last dimension must broadcast against ``g``
                (normally equal to ``dim``).

        Returns:
            ``x / (rms(x) + eps) * g``, where ``rms`` is taken over the last
            axis with the reduced dimension kept for broadcasting.
        """
        features = x.size(-1)
        # ||x||_2 * n**-0.5 == sqrt(mean(x**2)): the RMS along the last axis.
        root_mean_square = x.norm(2, dim=-1, keepdim=True) * (features ** -0.5)
        denominator = root_mean_square + self.eps
        return (x / denominator) * self.g