Source code for world_models.blocks.mhsa

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over sequence tokens.

    The input sequence is projected into query/key/value heads, attention is
    computed independently per head with
    :func:`torch.nn.functional.scaled_dot_product_attention`, and the heads
    are merged and passed through an output projection back to the original
    feature dimension. It is used as a lightweight transformer attention block.

    Args:
        d: Feature dimension of the input tokens (and of the output). Must be
            divisible by ``n_heads``.
        n_heads: Number of attention heads. Must be a positive integer.

    Attributes:
        d: Feature dimension.
        n_heads: Number of heads.
        d_head: Per-head feature dimension, ``d // n_heads``.
        W_q, W_k, W_v: ``d -> d`` linear projections for queries, keys, values.
        W_o: ``d -> d`` linear projection applied to the merged heads.

    Raises:
        ValueError: If ``n_heads`` is not positive or ``d`` is not divisible
            by ``n_heads``.
    """

    def __init__(self, d: int, n_heads: int = 2) -> None:
        super().__init__()
        # Explicit checks rather than ``assert`` so that validation is not
        # stripped when Python runs with ``-O``.
        if n_heads <= 0:
            raise ValueError("n_heads must be a positive integer")
        if d % n_heads != 0:
            raise ValueError("d must be divisible by n_heads")
        self.d = d
        self.n_heads = n_heads
        self.d_head = d // n_heads
        self.W_q = nn.Linear(d, d)
        self.W_k = nn.Linear(d, d)
        self.W_v = nn.Linear(d, d)
        self.W_o = nn.Linear(d, d)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape ``(..., T, d)`` into ``(..., n_heads, T, d_head)``."""
        return t.view(*t.shape[:-1], self.n_heads, self.d_head).transpose(-3, -2)

    def forward(
        self,
        x: torch.Tensor,
        *,
        attn_mask: "torch.Tensor | None" = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Token features of shape ``(..., T, d)``. The common case is
                ``(B, T, d)``. Unbatched ``(T, d)`` and additional leading
                batch dimensions are also accepted.
            attn_mask: Optional mask forwarded unchanged to
                ``F.scaled_dot_product_attention``. It must be broadcastable
                to ``(..., n_heads, T, T)``. A boolean mask marks allowed
                positions with ``True``. A float mask is added to the
                attention scores.
            is_causal: If ``True``, token ``i`` attends only to tokens
                ``<= i``. Must not be combined with ``attn_mask``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        q = self._split_heads(self.W_q(x))
        k = self._split_heads(self.W_k(x))
        v = self._split_heads(self.W_v(x))
        context = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=is_causal
        )
        # (..., n_heads, T, d_head) -> (..., T, n_heads, d_head) -> (..., T, d)
        context = context.transpose(-3, -2).flatten(-2)
        return self.W_o(context)