Source code for world_models.blocks.mhsa

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over sequence tokens.

    The input sequence is linearly projected into queries, keys and values,
    split into ``n_heads`` heads of width ``d // n_heads``, attended
    independently per head, and the per-head results are concatenated and
    passed through an output projection back to width ``d``.

    No attention mask and no dropout are applied: every token attends to
    every token of the same sequence.

    Args:
        d: Feature width of the input and output tokens.
        n_heads: Number of attention heads; must divide ``d`` evenly.

    Raises:
        AssertionError: If ``d`` is not divisible by ``n_heads``.
    """

    def __init__(self, d: int, n_heads: int = 2) -> None:
        super().__init__()
        # Raised explicitly instead of via ``assert`` so the check still runs
        # under ``python -O``; the exception type is kept as AssertionError so
        # callers observe the same failure as before.
        if d % n_heads != 0:
            raise AssertionError("d must be divisible by n_heads")

        self.d = d
        self.n_heads = n_heads
        self.d_head = d // n_heads

        # Creation order is kept (q, k, v, o) so seeded initialisation yields
        # the same weights as before.
        self.W_q = nn.Linear(d, d)
        self.W_k = nn.Linear(d, d)
        self.W_v = nn.Linear(d, d)
        self.W_o = nn.Linear(d, d)

    def _split_heads(
        self, t: torch.Tensor, batch: int, seq_len: int
    ) -> torch.Tensor:
        """Reshape ``(batch, seq_len, d)`` to ``(batch, n_heads, seq_len, d_head)``."""
        return t.view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to a batch of token sequences.

        Args:
            x: Tensor of shape ``(batch, seq_len, d)``.

        Returns:
            Tensor of shape ``(batch, seq_len, d)``.
        """
        batch, seq_len, dim = x.size()

        q = self._split_heads(self.W_q(x), batch, seq_len)
        k = self._split_heads(self.W_k(x), batch, seq_len)
        v = self._split_heads(self.W_v(x), batch, seq_len)

        # (batch, n_heads, seq_len, seq_len) similarity logits, scaled by
        # sqrt(d_head) before the softmax over the key axis.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)
        attn = F.softmax(scores, dim=-1)

        context = torch.matmul(attn, v)
        # transpose() returns a non-contiguous view, so contiguous() is
        # required before view() can merge the heads back into one axis.
        context = context.transpose(1, 2).contiguous().view(batch, seq_len, dim)

        return self.W_o(context)