"""Source code for world_models.models.vit: ViT encoder, JEPA predictor, and their building blocks."""

import math
from functools import partial
from typing import Any
import numpy as np

import torch
import torch.nn as nn
from world_models.utils.jepa_utils import trunc_normal_, repeat_interleave_batch
from world_models.utils.utils import apply_masks


def get_2d_sincos_pos_embed(
    embed_dim: int, grid_size: int, cls_token: bool = False
) -> np.ndarray:
    """Build a fixed 2D sine/cosine positional table for a square patch grid.

    Args:
        embed_dim: Width of each embedding row.
        grid_size: Number of patches along each side of the square grid.
        cls_token: When True, prepend one all-zero row (for a class token).

    Returns:
        Array of shape ``(grid_size**2, embed_dim)``, or
        ``(1 + grid_size**2, embed_dim)`` when ``cls_token`` is True.
    """
    coords = np.arange(grid_size, dtype=float)
    # numpy's default 'xy' meshgrid indexing: plane 0 varies along columns,
    # plane 1 along rows. The extra singleton axis matches the layout the
    # grid helper expects before it flattens each plane.
    planes = np.stack(np.meshgrid(coords, coords), axis=0)
    grid = planes.reshape(2, 1, grid_size, grid_size)
    embedding = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not cls_token:
        return embedding
    cls_row = np.zeros((1, embed_dim))
    return np.concatenate((cls_row, embedding), axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray) -> np.ndarray:
    """Build 2D sine/cosine embeddings from precomputed meshgrid coordinates.

    The first half of the channels encodes the positions in ``grid[0]`` and
    the second half those in ``grid[1]``. Each half is itself split into sine
    and cosine components, so ``embed_dim`` must be divisible by 4.

    Args:
        embed_dim: Total embedding width (multiple of 4).
        grid: Coordinate array whose first axis selects one of two coordinate
            planes; each plane is flattened into a list of positions.

    Returns:
        Array of shape ``(P, embed_dim)``, where ``P`` is the number of
        elements in one coordinate plane.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by 4.
    """
    # Validate here rather than letting the nested 1D call fail on an odd
    # half-width (and rather than assert, which is stripped under -O).
    if embed_dim % 4 != 0:
        raise ValueError(
            f"embed_dim must be divisible by 4 for 2D sin-cos embeddings, got {embed_dim}"
        )
    half = embed_dim // 2
    emb_first = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (P, D/2)
    emb_second = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (P, D/2)
    return np.concatenate([emb_first, emb_second], axis=1)  # (P, D)
def get_1d_sincos_pos_embed(
    embed_dim: int, grid_size: int, cls_token: bool = False
) -> np.ndarray:
    """Build a 1D sine/cosine positional table for positions ``0..grid_size-1``.

    Args:
        embed_dim: Width of each embedding row.
        grid_size: Number of integer positions to encode.
        cls_token: When True, prepend one all-zero row (for a class token).

    Returns:
        Array of shape ``(grid_size, embed_dim)``, or
        ``(1 + grid_size, embed_dim)`` when ``cls_token`` is True.
    """
    positions = np.arange(grid_size, dtype=float)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
    if cls_token:
        table = np.vstack((np.zeros((1, embed_dim)), table))
    return table
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    """Generate 1D sine/cosine positional embeddings from explicit positions.

    For ``i`` in ``[0, embed_dim / 2)`` the frequencies are
    ``omega_i = 1 / 10000 ** (2 * i / embed_dim)``. Each position ``p``
    produces the row ``[sin(p * omega), cos(p * omega)]``.

    Args:
        embed_dim: Output width; must be even.
        pos: Positions of any shape (array or array-like); flattened first.

    Returns:
        Array of shape ``(pos.size, embed_dim)``.

    Raises:
        ValueError: If ``embed_dim`` is odd.
    """
    # Explicit error instead of assert, which is stripped under -O.
    if embed_dim % 2 != 0:
        raise ValueError(
            f"embed_dim must be even for sin-cos embeddings, got {embed_dim}"
        )
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / (10000**omega)  # (D/2,)

    # asarray lets callers pass lists/tuples; it does not copy ndarray input.
    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2) outer product
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)
def drop_path(
    x: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    """Apply stochastic depth (DropPath) to a residual-branch output.

    During training, each sample along dim 0 is zeroed with probability
    ``drop_prob``. Surviving samples are scaled by ``1 / (1 - drop_prob)`` so
    the expected value is unchanged. When not training, or when ``drop_prob``
    is 0, ``x`` is returned unchanged.

    Args:
        x: Input tensor; dim 0 is treated as the sample dimension.
        drop_prob: Probability of dropping each sample's path.
        training: Whether stochastic dropping is active.

    Returns:
        Tensor with the same shape, dtype and device as ``x``.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One draw per sample, broadcast across all remaining dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    if keep_prob <= 0.0:
        # Every path is dropped. Dividing by keep_prob would give
        # inf * 0 = NaN, so multiply by zeros instead (keeps the autograd graph).
        return x * torch.zeros(shape, dtype=x.dtype, device=x.device)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # 1 with probability keep_prob, else 0
    return x.div(keep_prob) * random_tensor
[docs] class DropPath(nn.Module): """Module wrapper around the functional `drop_path` stochastic depth utility.""" def __init__(self, drop_prob: float | None = None) -> None: super(DropPath, self).__init__() self.drop_prob = drop_prob
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: return drop_path(x, self.drop_prob or 0.0, self.training)
[docs] class MLP(nn.Module): """Two-layer feed-forward network used inside transformer blocks. Applies linear projection, activation, dropout, and output projection in the standard Vision Transformer MLP pattern. """ def __init__( self, in_features: int, hidden_features: int | None = None, out_features: int | None = None, act_layer: type[nn.Module] = nn.GELU, drop: float = 0.0, ) -> None: super(MLP, self).__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x
[docs] class Attention(nn.Module): """Multi-head self-attention block for token sequences. Computes QKV projections, scaled dot-product attention, and output projection with configurable dropout. """ def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = False, qk_scale: float | None = None, attn_drop: float = 0.0, proj_drop: float = 0.0, ) -> None: super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: B, N, C = x.shape qkv = ( self.qkv(x) .reshape(B, N, 3, self.num_heads, C // self.num_heads) .permute(2, 0, 3, 1, 4) ) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x
[docs] class Block(nn.Module): """Transformer encoder block combining attention and MLP residual branches. Each branch uses pre-normalization and optional stochastic depth. """ def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = False, qk_scale: float | None = None, drop: float = 0.0, attn_drop: float = 0.0, drop_path: float = 0.0, act_layer: type[nn.Module] = nn.GELU, norm_layer: type[nn.Module] = nn.LayerNorm, ) -> None: super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, ) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = MLP( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, )
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.drop_path(self.attn(self.norm1(x))) x = x + self.drop_path(self.mlp(self.norm2(x))) return x
class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    Splits a ``(B, C, H, W)`` image into non-overlapping ``patch_size`` squares
    using a Conv2d whose kernel and stride both equal ``patch_size``. Each patch
    is embedded to ``embed_dim``, giving a ``(B, num_patches, embed_dim)`` sequence.
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        per_side = img_size // patch_size
        self.num_patches = per_side * per_side
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return patch tokens of shape ``(B, H'*W', embed_dim)``."""
        # Unpacking rejects non-4-D input (e.g. an unbatched image) with ValueError.
        _batch, _chans, _height, _width = x.shape
        feature_map = self.proj(x)  # (B, E, H', W')
        return feature_map.flatten(start_dim=2).transpose(1, 2)
[docs] class ConvEmbed(nn.Module): """ 3x3 Convolution stems for ViT following ViTC models """ def __init__( self, channels: list[int], strides: list[int], img_size: int = 224, in_chans: int = 3, batch_norm: bool = True, ) -> None: super().__init__() # Build the stems stem = [] channels = [in_chans] + channels for i in range(len(channels) - 2): stem += [ nn.Conv2d( channels[i], channels[i + 1], kernel_size=3, stride=strides[i], padding=1, bias=(not batch_norm), ) ] if batch_norm: stem += [nn.BatchNorm2d(channels[i + 1])] # type: ignore[list-item] stem += [nn.ReLU()] # type: ignore[list-item] stem += [ nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1]) ] self.stem = nn.Sequential(*stem) # Comptute the number of patches stride_prod = int(np.prod(strides)) self.num_patches = (int(img_size[0]) // stride_prod) ** 2 # type: ignore[index]
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: p = self.stem(x) return p.flatten(2).transpose(1, 2)
[docs] class VisionTransformerPredictor(nn.Module): """Vision Transformer""" def __init__( self, num_patches: int, embed_dim: int = 768, predictor_embed_dim: int = 384, depth: int = 6, num_heads: int = 12, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_scale: float | None = None, drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, norm_layer: type[nn.Module] = nn.LayerNorm, init_std: float = 0.02, **kwargs: Any, ) -> None: super().__init__() self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True) self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, depth) ] # stochastic depth decay rule # -- self.predictor_pos_embed = nn.Parameter( torch.zeros(1, num_patches, predictor_embed_dim), requires_grad=False ) predictor_pos_embed = get_2d_sincos_pos_embed( self.predictor_pos_embed.shape[-1], int(num_patches**0.5), cls_token=False ) self.predictor_pos_embed.data.copy_( torch.from_numpy(predictor_pos_embed).float().unsqueeze(0) ) # -- self.predictor_blocks = nn.ModuleList( [ Block( dim=predictor_embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, ) for i in range(depth) ] ) self.predictor_norm = norm_layer(predictor_embed_dim) self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True) # ------ self.init_std = init_std trunc_normal_(self.mask_token, std=self.init_std) self.apply(self._init_weights) self.fix_init_weight()
[docs] def fix_init_weight(self) -> None: def rescale(param: torch.Tensor, layer_id: int) -> None: param.div_(math.sqrt(2.0 * layer_id)) for layer_id, layer in enumerate(self.predictor_blocks): rescale(layer.attn.proj.weight.data, layer_id + 1) # type: ignore[arg-type, union-attr] rescale(layer.mlp.fc2.weight.data, layer_id + 1) # type: ignore[arg-type, union-attr]
def _init_weights(self, m: nn.Module) -> None: if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=self.init_std) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): trunc_normal_(m.weight, std=self.init_std) if m.bias is not None: nn.init.constant_(m.bias, 0)
[docs] def forward( self, x: torch.Tensor, masks_x: torch.Tensor | list[torch.Tensor], masks: torch.Tensor | list[torch.Tensor], ) -> torch.Tensor: assert (masks is not None) and (masks_x is not None), ( "Cannot run predictor without mask indices" ) if not isinstance(masks_x, list): masks_x = [masks_x] if not isinstance(masks, list): masks = [masks] # -- Batch Size B = len(x) // len(masks_x) # -- map from encoder-dim to pedictor-dim x = self.predictor_embed(x) # -- add positional embedding to x tokens x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) x += apply_masks(x_pos_embed, masks_x) _, N_ctxt, D = x.shape # -- concat mask tokens to x pos_embs = self.predictor_pos_embed.repeat(B, 1, 1) pos_embs = apply_masks(pos_embs, masks) pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_x)) # -- pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1) # -- pred_tokens += pos_embs x = x.repeat(len(masks), 1, 1) x = torch.cat([x, pred_tokens], dim=1) # -- fwd prop for blk in self.predictor_blocks: x = blk(x) x = self.predictor_norm(x) # -- return preds for mask tokens x = x[:, N_ctxt:] x = self.predictor_proj(x) return x
[docs] class VisionTransformer(nn.Module): """Vision Transformer""" def __init__( self, img_size: list[int] = [224], patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, predictor_embed_dim: int = 384, depth: int = 12, predictor_depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_scale: float | None = None, drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, norm_layer: type[nn.Module] = nn.LayerNorm, init_std: float = 0.02, **kwargs: Any, ) -> None: super().__init__() self.num_features = self.embed_dim = embed_dim self.num_heads = num_heads # -- self.patch_embed = PatchEmbed( img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ) num_patches = self.patch_embed.num_patches # -- self.pos_embed = nn.Parameter( torch.zeros(1, num_patches, embed_dim), requires_grad=False ) pos_embed = get_2d_sincos_pos_embed( self.pos_embed.shape[-1], int(self.patch_embed.num_patches**0.5), cls_token=False, ) self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) # -- dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, depth) ] # stochastic depth decay rule self.blocks = nn.ModuleList( [ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, ) for i in range(depth) ] ) self.norm = norm_layer(embed_dim) # ------ self.init_std = init_std self.apply(self._init_weights) self.fix_init_weight()
[docs] def fix_init_weight(self) -> None: def rescale(param: torch.Tensor, layer_id: int) -> None: param.div_(math.sqrt(2.0 * layer_id)) for layer_id, layer in enumerate(self.blocks): rescale(layer.attn.proj.weight.data, layer_id + 1) # type: ignore[arg-type, union-attr] rescale(layer.mlp.fc2.weight.data, layer_id + 1) # type: ignore[arg-type, union-attr]
def _init_weights(self, m: nn.Module) -> None: if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=self.init_std) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): trunc_normal_(m.weight, std=self.init_std) if m.bias is not None: nn.init.constant_(m.bias, 0)
[docs] def forward( self, x: torch.Tensor, masks: torch.Tensor | list[torch.Tensor] | None = None ) -> torch.Tensor: if masks is not None: if not isinstance(masks, list): masks = [masks] # -- patchify x x = self.patch_embed(x) B, N, D = x.shape # -- add positional embedding to x pos_embed = self.interpolate_pos_encoding(x, self.pos_embed) x = x + pos_embed # -- mask x if masks is not None: x = apply_masks(x, masks) # -- fwd prop for i, blk in enumerate(self.blocks): x = blk(x) if self.norm is not None: x = self.norm(x) return x
[docs] def interpolate_pos_encoding( self, x: torch.Tensor, pos_embed: torch.Tensor ) -> torch.Tensor: npatch = x.shape[1] - 1 N = pos_embed.shape[1] - 1 if npatch == N: return pos_embed class_emb = pos_embed[:, 0] pos_embed = pos_embed[:, 1:] dim = x.shape[-1] pos_embed = nn.functional.interpolate( pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( 0, 3, 1, 2 ), scale_factor=math.sqrt(npatch / N), mode="bicubic", ) pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
def vit_predictor(**kwargs: Any) -> VisionTransformerPredictor:
    """Build a JEPA predictor transformer with the standard block settings.

    Fixes ``mlp_ratio=4``, ``qkv_bias=True`` and an ``eps=1e-6`` LayerNorm.
    Every other keyword (``num_patches`` is required) is forwarded to
    ``VisionTransformerPredictor``. Passing one of the fixed keywords again
    raises ``TypeError``.
    """
    fixed: dict[str, Any] = {
        "mlp_ratio": 4,
        "qkv_bias": True,
        "norm_layer": partial(nn.LayerNorm, eps=1e-6),
    }
    return VisionTransformerPredictor(**fixed, **kwargs)
def vit_tiny(patch_size: int = 16, **kwargs: Any) -> Any:
    """Build a ViT-Tiny encoder: width 192, 12 blocks, 3 heads, MLP ratio 4.

    Extra keywords are forwarded to ``VisionTransformer``. Re-passing a fixed
    hyper-parameter raises ``TypeError``.
    """
    config: dict[str, Any] = dict(
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(patch_size=patch_size, **config, **kwargs)
def vit_small(patch_size: int = 16, **kwargs: Any) -> Any:
    """Build a ViT-Small encoder: width 384, 12 blocks, 6 heads, MLP ratio 4.

    Extra keywords are forwarded to ``VisionTransformer``. Re-passing a fixed
    hyper-parameter raises ``TypeError``.
    """
    config: dict[str, Any] = dict(
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(patch_size=patch_size, **config, **kwargs)
def vit_base(patch_size: int = 16, **kwargs: Any) -> Any:
    """Build a ViT-Base encoder: width 768, 12 blocks, 12 heads, MLP ratio 4.

    Extra keywords are forwarded to ``VisionTransformer``. Re-passing a fixed
    hyper-parameter raises ``TypeError``.
    """
    config: dict[str, Any] = dict(
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(patch_size=patch_size, **config, **kwargs)
def vit_large(patch_size: int = 16, **kwargs: Any) -> Any:
    """Build a ViT-Large encoder: width 1024, 24 blocks, 16 heads, MLP ratio 4.

    Extra keywords are forwarded to ``VisionTransformer``. Re-passing a fixed
    hyper-parameter raises ``TypeError``.
    """
    config: dict[str, Any] = dict(
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(patch_size=patch_size, **config, **kwargs)
def vit_huge(patch_size: int = 16, **kwargs: Any) -> Any:
    """Build a ViT-Huge encoder: width 1280, 32 blocks, 16 heads, MLP ratio 4.

    Extra keywords are forwarded to ``VisionTransformer``. Re-passing a fixed
    hyper-parameter raises ``TypeError``.
    """
    config: dict[str, Any] = dict(
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(patch_size=patch_size, **config, **kwargs)
def vit_giant(patch_size: int = 16, **kwargs: Any) -> Any:
    """Build a ViT-Giant encoder: width 1408, 40 blocks, 16 heads, MLP ratio 48/11.

    With ratio 48/11 the MLP hidden width is ``int(1408 * 48 / 11)``.
    Extra keywords are forwarded to ``VisionTransformer``. Re-passing a fixed
    hyper-parameter raises ``TypeError``.
    """
    config: dict[str, Any] = dict(
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=48 / 11,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(patch_size=patch_size, **config, **kwargs)
# Token embedding width produced by each encoder factory, keyed by factory name.
VIT_EMBED_DIMS = dict(
    vit_tiny=192,
    vit_small=384,
    vit_base=768,
    vit_large=1024,
    vit_huge=1280,
    vit_giant=1408,
)