Source code for world_models.layers.RMSNorm
import torch
import torch.nn as nn
[docs]
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learnable per-feature gain.

    Each vector along the last axis is divided by its root-mean-square
    magnitude (plus ``eps``) and then multiplied elementwise by the learned
    gain ``g``. Unlike LayerNorm, no mean is subtracted and no bias is added,
    which makes this a lighter-weight normalization.

    Args:
        dim: Length of the learned gain vector ``g``; expected to match the
            size of the input's last axis.
        eps: Small constant added to the RMS value itself (not inside a
            square root) to avoid division by zero.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        # Gain is initialised to ones, so a freshly built layer applies pure
        # RMS scaling with no learned reweighting yet.
        self.g = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over its last axis and apply the learned gain.

        Args:
            x: Input tensor; its last axis is expected to have length ``dim``
                so the gain broadcasts elementwise.

        Returns:
            ``x / (rms(x) + eps) * g``, where ``rms`` is computed over the
            last axis with the axis kept for broadcasting.
        """
        # ||x||_2 / sqrt(n) equals sqrt(mean(x ** 2)): the RMS over the last axis.
        inv_sqrt_n = x.size(-1) ** -0.5
        rms = x.norm(2, dim=-1, keepdim=True) * inv_sqrt_n
        scaled = x / (rms + self.eps)
        return scaled * self.g