# Source code for simshift.models.mesh.transolver

"""
Transolver: A Fast Transformer Solver for PDEs on General Geometries

Adapted from https://github.com/thuml/Transolver
"""

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
from einops import einsum, rearrange
from torch.nn import functional as F
from torch_geometric.utils import to_dense_batch

from simshift.models.condition import ContinuousSincosEmbed, DiT
from simshift.models.registry import register_model
from simshift.models.utils import MLP


class TransolverAttention(nn.Module):
    """
    Multi-head self-attention with physics-aware slicing for PDE solvers.

    Per head, every input token is softly assigned (softmax) to ``slice_base``
    slices. Token features are pooled into one token per slice, attention is
    computed between the slice tokens, and the attended slice tokens are
    scattered back onto the input tokens with the same assignment weights.

    :param dim: Input feature dimension. Must be divisible by ``num_heads``.
    :type dim: int
    :param num_heads: Number of attention heads.
    :type num_heads: int
    :param dropout_prob: Dropout probability after attention.
    :type dropout_prob: float
    :param attn_dropout_prob: Dropout within attention weights.
    :type attn_dropout_prob: float
    :param slice_base: Base number of slices for token grouping.
    :type slice_base: int
    """

    def __init__(
        self,
        dim: int = 128,
        num_heads: int = 4,
        dropout_prob: float = 0.1,
        attn_dropout_prob: float = 0.0,
        slice_base: int = 64,
    ):
        super().__init__()
        assert (dim % num_heads) == 0
        per_head = dim // num_heads

        # plain hyper-parameters
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = per_head
        # stored for reference; not read by forward()
        self.scale = per_head**-0.5
        self.slice_base = slice_base
        self.attn_dropout_prob = attn_dropout_prob

        # learnable per-head softmax temperature, initialised to 0.5
        self.temperature = nn.Parameter(torch.full((1, num_heads, 1, 1), 0.5))

        # input projections (layers are created in this order on purpose so
        # that random initialisation draws stay reproducible)
        self.x_proj = nn.Linear(dim, dim)
        self.fx_proj = nn.Linear(dim, dim)
        self.slice_proj = nn.Linear(per_head, slice_base)
        nn.init.orthogonal_(self.slice_proj.weight)

        # fused q/k/v projection applied per head on the slice tokens
        self.qkv = nn.Linear(per_head, 3 * per_head, bias=False)

        # output projection followed by dropout
        self.readout = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout_prob))

    def forward(self, x: torch.Tensor):
        """
        Apply slice attention.

        :param x: Token features; the rearrange pattern below requires shape
            ``(b, n, dim)``.
        :type x: torch.Tensor
        :return: Tensor of shape ``(b, n, dim)``.
        """
        split_heads = "b n (h c) -> b h n c"
        x_mid = rearrange(self.x_proj(x), split_heads, c=self.head_dim)
        fx_mid = rearrange(self.fx_proj(x), split_heads, c=self.head_dim)

        # soft assignment of every token to the slices
        slice_logits = self.slice_proj(x_mid) / self.temperature
        slice_weights = F.softmax(slice_logits, -1)  # b h n g

        # weighted mean of token features per slice; the epsilon keeps the
        # division finite for slices that receive (almost) no weight
        slice_mass = slice_weights.sum(2) + 1e-5  # b h g
        pooled = einsum(fx_mid, slice_weights, "b h n c, b h n g -> b h g c")
        slice_tokens = pooled / slice_mass.unsqueeze(-1)

        # attention between slice tokens
        q, k, v = rearrange(
            self.qkv(slice_tokens), "b h g (thr c) -> thr b h g c", thr=3
        )
        # attention dropout is only active in training mode
        dropout_p = self.attn_dropout_prob if self.training else 0.0
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        # scatter slice tokens back to the input tokens and merge heads
        out = einsum(attended, slice_weights, "b h g c, b h n g -> b h n c")
        return self.readout(rearrange(out, "b h n d -> b n (h d)"))
class TransolverBlock(nn.Module):
    """
    Pre-norm residual block: slice attention followed by an MLP.

    Args:
        dim (int): Feature dimension of the tokens.
        num_heads (int): Number of attention heads.
        dropout_prob (float): Dropout after the attention read-out.
        attn_dropout_prob (float): Dropout on the attention weights.
        act_fn (nn.Module): Activation forwarded to ``MLP``; the default is
            the ``nn.SiLU`` class itself, not an instance.
        mlp_ratio (float): Hidden width of the MLP as a multiple of ``dim``
            (truncated to ``int``).
        slice_base (int): Number of slices used by the attention layer.
        norm_layer (nn.Module): Callable building a normalisation layer from
            ``dim``; used for both norms.
    """

    def __init__(
        self,
        dim: int = 128,
        num_heads: int = 4,
        dropout_prob: float = 0.1,
        attn_dropout_prob: float = 0.0,
        act_fn: nn.Module = nn.SiLU,
        mlp_ratio: float = 4.0,
        slice_base: int = 64,
        norm_layer: nn.Module = nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        hidden_dim = int(dim * mlp_ratio)
        attention_config = {
            "dim": dim,
            "num_heads": num_heads,
            "dropout_prob": dropout_prob,
            "attn_dropout_prob": attn_dropout_prob,
            "slice_base": slice_base,
        }

        # attention branch
        self.norm1 = norm_layer(dim)
        self.attn = TransolverAttention(**attention_config)

        # feed-forward branch
        self.norm2 = norm_layer(dim)
        self.mlp = MLP([dim, hidden_dim, dim], act_fn=act_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the residual attention and residual MLP steps."""
        # ViT-like structure: normalise, transform, add back
        after_attention = x + self.attn(self.norm1(x))
        return after_attention + self.mlp(self.norm2(after_attention))
class DiTransolverBlock(TransolverBlock):
    """
    ``TransolverBlock`` whose attention and MLP branches are modulated by a
    conditioning vector through a ``DiT`` module.

    Args:
        cond_dim (int): Size of the conditioning vector, passed to ``DiT``.
        *args: Forwarded to ``TransolverBlock``.
        **kwargs: Forwarded to ``TransolverBlock``.
    """

    def __init__(self, cond_dim: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dit = DiT(self.dim, cond_dim)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Run the conditioned block.

        Args:
            x (torch.Tensor): Token features.
            cond (torch.Tensor): Conditioning input handed to ``self.dit``.

        Returns:
            torch.Tensor: Updated token features.
        """
        dit = self.dit
        # the DiT module yields one (scale, shift, gate) triple per branch
        scale1, shift1, gate1, scale2, shift2, gate2 = dit(cond)

        # modulated attention with residual connection
        attn_input = dit.modulate_scale_shift(self.norm1(x), scale1, shift1)
        after_attention = dit.modulate_gate(self.attn(attn_input), gate1) + x

        # modulated mlp with residual connection
        mlp_input = dit.modulate_scale_shift(
            self.norm2(after_attention), scale2, shift2
        )
        return dit.modulate_gate(self.mlp(mlp_input), gate2) + after_attention
@register_model()
class Transolver(nn.Module):
    """
    Transolver: A Transformer-based PDE solver, with Physics-Attention to model
    physical states and efficiently capture complex geometries.

    Args:
        n_conds (int): Number of conditioning inputs.
        latent_channels (int): Channels in the latent representation.
        output_channels (int): Number of output channels.
        act_fn (nn.Module): Activation function to use (the default is the
            ``nn.SiLU`` class, which is forwarded to ``MLP``).
        dropout_prob (float): Dropout probability.
        attn_dropout_prob (float): Dropout for attention layers.
        space (int): Spatial dimensionality (e.g., 2D or 3D).
        transolver_base (int): Base feature size for the transformer. Must be
            divisible by ``num_heads``.
        num_heads (int): Number of attention heads.
        num_layers (int): Number of transformer layers.
        slice_base (int): Base number of tokens for slicing.
        mlp_ratio (float): Ratio for MLP hidden dimension.
        conditioning_mode (str): Either ``"cat"`` (concatenate the latent
            vector to every node feature and project back) or ``"dit"``
            (modulate every block with the latent vector).
        out_deformation (bool): Whether to output deformation fields.
        n_materials (Optional[int]): Number of materials (if applicable).
        conditioning_bn (bool): Forwarded to the conditioning ``MLP`` as
            ``batchnorm``.
        conditioning_ln (bool): Forwarded to the conditioning ``MLP`` as
            ``layernorm``.

    Paper: https://arxiv.org/abs/2402.02366
    """

    def __init__(
        self,
        n_conds: int,
        latent_channels: int = 8,
        output_channels: int = 17,
        act_fn: nn.Module = nn.SiLU,
        dropout_prob: float = 0.1,
        attn_dropout_prob: float = 0.1,
        space: int = 2,
        transolver_base: int = 128,
        num_heads: int = 4,
        num_layers: int = 2,
        slice_base: int = 64,
        mlp_ratio: float = 2.0,
        conditioning_mode: str = "dit",
        out_deformation: bool = True,
        n_materials: Optional[int] = None,
        conditioning_bn: bool = False,
        conditioning_ln: bool = False,
    ):
        super().__init__()
        self.space = space
        self.output_channels = output_channels

        assert (transolver_base % num_heads) == 0
        assert conditioning_mode in ["cat", "dit"]
        self.conditioning_mode = conditioning_mode
        self.out_deformation = out_deformation

        # conditioning inputs -> latent vector
        self.conditioning = nn.Sequential(
            ContinuousSincosEmbed(dim=256, ndim=n_conds),
            MLP(
                [256, 256 // 2, 256 // 4, latent_channels],
                act_fn=act_fn,
                dropout_prob=dropout_prob,
                batchnorm=conditioning_bn,
                layernorm=conditioning_ln,
            ),
        )

        # encode positions to latent
        self.coord_embed = ContinuousSincosEmbed(dim=transolver_base, ndim=space)
        self.encoder = MLP([transolver_base, transolver_base], act_fn=act_fn)

        # material embedding (only created when the number of materials is known)
        if n_materials is not None:
            self.material_embedding = nn.Embedding(
                num_embeddings=n_materials, embedding_dim=transolver_base
            )

        # processor ("physics attention")
        if conditioning_mode == "cat":
            # latent vector is concatenated to the node features and projected
            # back to the working width before the (unconditioned) blocks
            self.proj_cond = nn.Linear(
                transolver_base + latent_channels, transolver_base, bias=False
            )
            block_factory = TransolverBlock
        else:
            # "dit": every block receives the latent vector as conditioning
            block_factory = partial(DiTransolverBlock, latent_channels)

        self.blocks = nn.ModuleList(
            [
                block_factory(
                    dim=transolver_base,
                    num_heads=num_heads,
                    dropout_prob=dropout_prob,
                    attn_dropout_prob=attn_dropout_prob,
                    act_fn=act_fn,
                    mlp_ratio=mlp_ratio,
                    slice_base=slice_base,
                )
                for _ in range(num_layers)
            ]
        )

        # decode latent to fields (+ positions)
        decoder_out = output_channels + (space if out_deformation else 0)
        self.decoder = MLP(
            [transolver_base, decoder_out],
            act_fn,
            dropout_prob=dropout_prob,
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all sub-modules with ``_init_weights``."""
        # NOTE: ``apply`` visits every sub-module, so this also overwrites any
        # custom initialisation done in a sub-module's constructor (for
        # example the orthogonal init of ``TransolverAttention.slice_proj``).
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for linear layers, identity init for norms."""
        # NOTE reproduce original initialization?
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
            # norm layers built without affine parameters have None here
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward(
        self,
        cond: torch.Tensor,
        mesh_coords: torch.Tensor,
        mesh_edges: torch.Tensor,
        mesh_material: Optional[torch.Tensor] = None,
        batch_index: Optional[torch.Tensor] = None,
    ):
        """
        Predict output fields (and optionally deformed coordinates).

        Args:
            cond (torch.Tensor): Conditioning inputs fed to ``self.conditioning``.
            mesh_coords (torch.Tensor): Node coordinates, either sparse
                ``(num_nodes, space)`` together with ``batch_index`` or
                already dense ``(B, N, space)``.
            mesh_edges (torch.Tensor): Unused; accepted for interface
                compatibility.
            mesh_material (Optional[torch.Tensor]): Per-node material indices.
                Requires the model to have been built with ``n_materials``.
            batch_index (Optional[torch.Tensor]): Graph assignment of every
                node for sparse inputs.

        Returns:
            tuple: ``((fields, coords), latent_vector)``. ``fields`` and
            ``coords`` have the same node layout as the ``mesh_coords`` that
            was passed in; ``coords`` has the predicted displacement added
            when ``out_deformation`` is enabled.

        Raises:
            ValueError: If ``mesh_material`` is given but the model has no
                material embedding.
        """
        _ = mesh_edges  # not used by this architecture

        # NOTE: transolver expects inputs of shape (B, N, d)
        pad_mask = None
        coords = mesh_coords.clone()
        # pad to max nodes if in sparse format
        if mesh_coords.ndim == 2:
            mesh_coords, pad_mask = to_dense_batch(mesh_coords, batch_index)

        # conditioning
        latent_vector = self.conditioning(cond)

        # encoder
        x = self.encoder(self.coord_embed(mesh_coords))  # (B, N, C)

        if mesh_material is not None:
            material_embedding = getattr(self, "material_embedding", None)
            if material_embedding is None:
                raise ValueError(
                    "mesh_material was provided, but the model was created with "
                    "n_materials=None and therefore has no material embedding."
                )
            # NOTE(review): the argument-less squeeze() drops every singleton
            # axis, not only a trailing feature axis - confirm the expected
            # shape of mesh_material with the data pipeline.
            mesh_material, _ = to_dense_batch(mesh_material.squeeze(), batch_index)
            x = x + material_embedding(mesh_material.squeeze())

        if self.conditioning_mode == "cat":
            if batch_index is None:
                # Dense input or a single un-batched graph: indexing with
                # ``None`` would only insert an axis, so broadcast each
                # sample's latent vector over its nodes instead.
                # Assumes latent_vector is (B, latent_channels).
                z = latent_vector.unsqueeze(1).expand(-1, x.size(1), -1)
            else:
                z = latent_vector[batch_index]
                if pad_mask is not None:
                    z, _ = to_dense_batch(z, batch_index)
            x = self.proj_cond(torch.cat([x, z], dim=-1))

        # only the DiT blocks take the latent vector as an extra argument
        block_kwargs = {}
        if self.conditioning_mode == "dit":
            block_kwargs = {"cond": latent_vector}

        # transolver
        for block in self.blocks:
            x = block(x, **block_kwargs)

        # decoder
        x = self.decoder(x)

        # unpad
        if pad_mask is not None:
            x = x[pad_mask]

        # dynamic mesh: the trailing `space` channels are a displacement
        if self.out_deformation:
            x, dpos = x.split([self.output_channels, self.space], -1)
            coords = coords + dpos

        return (x, coords), latent_vector