Source code for world_models.layers.RMSNorm

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Normalize features by their root-mean-square, then apply a learned gain.

    Unlike LayerNorm, no mean is subtracted: every vector along the last
    dimension is divided by its RMS magnitude (plus ``eps``) and then
    multiplied element-wise by the trainable gain ``g``.

    Args:
        dim: Length of the gain vector ``g``; inputs are expected to have a
            last dimension that broadcasts against it (normally ``dim``).
        eps: Small constant added to the RMS value (not inside the square
            root) before dividing, so all-zero vectors do not divide by zero.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to ones so the layer starts as a pure
        # RMS rescaling.
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x`` RMS-normalized over its last axis and scaled by ``g``.

        Args:
            x: Input tensor; its last dimension must broadcast against ``g``.

        Returns:
            Tensor with the broadcast shape of ``x`` and ``g``.
        """
        # ||x||_2 / sqrt(n) == sqrt(mean(x ** 2)): the RMS over the last axis.
        inv_sqrt_n = x.size(-1) ** -0.5
        denom = x.norm(p=2, dim=-1, keepdim=True) * inv_sqrt_n + self.eps
        return self.g * (x / denom)