Source code for world_models.layers.AdaLNNorm

import torch.nn as nn
from world_models.layers.RMSNorm import RMSNorm


[docs] class AdaLNNormalization(nn.Module): """Adaptive layer normalization conditioned on an external embedding. The module applies RMS normalization and predicts per-channel scale/shift from a conditioning vector (for example diffusion timestep embeddings). """ def __init__(self, d_model, t_dim): super().__init__() self.norm = RMSNorm(d_model) self.to_scale_shift = nn.Linear(t_dim, d_model * 2)
[docs] def forward(self, x, t_emb): h = self.norm(x) scale_shift = self.to_scale_shift(t_emb).unsqueeze(1) scale, shift = scale_shift.chunk(2, dim=-1) while scale.dim() < h.dim(): scale = scale.unsqueeze(1) shift = shift.unsqueeze(1) return h * (1 + scale) + shift