Source code for simshift.models.mesh.gnn

"""
GraphSAGE-based, FiLM-conditioned GNN model.

https://arxiv.org/abs/1706.02216
"""


from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch_geometric.nn as pygnn

from simshift.models.condition import ContinuousSincosEmbed, Film
from simshift.models.registry import register_model
from simshift.models.utils import MLP


class ModulatedSAGEConv(pygnn.SAGEConv):
    """
    GraphSAGE convolution whose input node features are FiLM-modulated.

    Thin subclass of `pygnn.SAGEConv`: the node features are first passed
    through a `Film` layer together with a conditioning tensor, and the
    modulated features are then handed to the regular SAGEConv forward pass.

    Args:
        cond_dim (int): Dimension of the conditioning vector.
        *args: Positional arguments forwarded unchanged to `pygnn.SAGEConv`.
        **kwargs: Keyword arguments forwarded unchanged to `pygnn.SAGEConv`.
    """

    def __init__(self, cond_dim: int, *args, **kwargs):
        # The parent constructor must run first: `self.in_channels` is read
        # right below (presumably populated by SAGEConv itself).
        super().__init__(*args, **kwargs)
        self.modulation = Film(cond_dim, self.in_channels)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        cond: torch.Tensor,
        size=None,
    ):
        """
        Modulate `x` with `cond`, then run the parent SAGEConv forward pass.

        Args:
            x (torch.Tensor): Node features.
            edge_index (torch.Tensor): Graph connectivity, passed to SAGEConv.
            cond (torch.Tensor): Conditioning tensor consumed by the Film layer.
            size: Passed through to `pygnn.SAGEConv.forward` unchanged.

        Returns:
            Whatever `pygnn.SAGEConv.forward` returns for the modulated input.
        """
        modulated = self.modulation(x, cond)
        return super().forward(modulated, edge_index, size)
@register_model()
class GraphSAGE(nn.Module):
    """
    Graph-based mesh model with conditioned GraphSAGE layers.

    The conditioning input is embedded into a latent vector, node coordinates
    are embedded and encoded into node features, a stack of SAGEConv layers
    performs message passing, and an MLP decodes the node features into output
    fields (and, optionally, coordinate shifts).

    Args:
        n_conds (int): Conditioning input dimensionality.
        latent_channels (int): Dimension of latent conditioning.
        output_channels (int): Number of output prediction channels.
        act_fn (nn.Module): Activation class. It is called without arguments
            to build the activation applied after every message-passing layer
            and is also handed to the MLPs.
        dropout_prob (float): Dropout rate passed to the conditioning MLP.
        space (int): Spatial dimension of the mesh.
        gnn_base (int): Hidden size for GNN layers.
        num_layers (int): Number of GNN message-passing layers.
        conditioning_mode (str): 'film' or 'cat' mode. With 'cat' the latent
            is concatenated to the node features once and projected back to
            `gnn_base`; with 'film' every layer is a `ModulatedSAGEConv`.
        out_deformation (bool): If True, predicts mesh coordinate shifts.
        n_materials (Optional[int]): Number of material types (if used). The
            material embedding only exists when this is not None.
        conditioning_bn (bool): Forwarded as `batchnorm` to the conditioning MLP.
        conditioning_ln (bool): Forwarded as `layernorm` to the conditioning MLP.
    """

    def __init__(
        self,
        n_conds: int,
        latent_channels: int = 256,
        output_channels: int = 17,
        act_fn: nn.Module = nn.SiLU,
        dropout_prob: float = 0.1,
        space: int = 2,
        gnn_base: int = 64,
        num_layers: int = 5,
        conditioning_mode: str = "film",
        out_deformation: bool = True,
        n_materials: Optional[int] = None,
        conditioning_bn: bool = False,
        conditioning_ln: bool = False,
    ):
        super().__init__()
        self.space = space
        self.output_channels = output_channels
        assert conditioning_mode in ["cat", "film"], (
            f"conditioning_mode must be 'cat' or 'film', got {conditioning_mode!r}"
        )
        self.conditioning_mode = conditioning_mode
        self.out_deformation = out_deformation

        self.activation = act_fn()

        # NOTE: submodules are created in the same order as before so that
        # parameter initialisation consumes the RNG identically.
        cond_embed_dim = 256
        self.conditioning = nn.Sequential(
            ContinuousSincosEmbed(dim=cond_embed_dim, ndim=n_conds),
            MLP(
                [
                    cond_embed_dim,
                    cond_embed_dim // 2,
                    cond_embed_dim // 4,
                    latent_channels,
                ],
                act_fn=act_fn,
                dropout_prob=dropout_prob,
                batchnorm=conditioning_bn,
                layernorm=conditioning_ln,
            ),
        )

        # encode positions to latent
        self.coord_embed = ContinuousSincosEmbed(dim=gnn_base, ndim=space)
        self.encoder = MLP([gnn_base, gnn_base], act_fn=act_fn)

        # material embedding
        if n_materials is not None:
            self.material_embedding = nn.Embedding(
                num_embeddings=n_materials, embedding_dim=gnn_base
            )

        # message passing processor
        MPBlockType = pygnn.SAGEConv
        if conditioning_mode == "cat":
            self.proj_cond = nn.Linear(latent_channels + gnn_base, gnn_base, bias=False)
        if conditioning_mode == "film":
            # node modulation layer before message passing
            MPBlockType = partial(ModulatedSAGEConv, latent_channels)
        self.processor = nn.ModuleList(
            [MPBlockType(gnn_base, gnn_base, aggr="mean") for _ in range(num_layers)]
        )

        # decode latent to fields + positions
        self.decoder = MLP(
            [gnn_base, output_channels + (space if out_deformation else 0)],
            act_fn,
        )

    def forward(
        self,
        cond: torch.Tensor,
        mesh_coords: torch.Tensor,
        mesh_edges: torch.Tensor,
        mesh_material: Optional[torch.Tensor] = None,
        batch_index: Optional[torch.Tensor] = None,
    ):
        """
        Predict node fields (and optionally displaced coordinates) on a mesh.

        Args:
            cond (torch.Tensor): Conditioning input fed to the conditioning
                network (last dimension presumably `n_conds`).
            mesh_coords (torch.Tensor): Node coordinates (assumed to be
                `(num_nodes, space)` over the whole batch -- TODO confirm).
            mesh_edges (torch.Tensor): Edge index passed to every GNN layer.
            mesh_material (Optional[torch.Tensor]): Per-node material ids. When
                given, the model must have been built with `n_materials`,
                because the material embedding only exists in that case.
            batch_index (Optional[torch.Tensor]): Index of the graph each node
                belongs to, used to pick the matching latent per node. May be
                omitted only when `cond` yields a single latent vector, which
                is then shared by all nodes.

        Returns:
            tuple: `((fields, coords), latent_vector)`. `coords` is a copy of
            `mesh_coords`, shifted by the predicted displacement when
            `out_deformation` is True.

        Raises:
            ValueError: If `batch_index` is None while `cond` yields more than
                one latent vector.
        """
        latent_vector = self.conditioning(cond)

        # encoder
        coords = mesh_coords.clone()
        x = self.encoder(self.coord_embed(mesh_coords))  # (BxN, C)
        if mesh_material is not None:
            # add material embedding if we have it; out-of-place so the encoder
            # output recorded by autograd is never mutated
            x = x + self.material_embedding(mesh_material.squeeze())

        # per-node conditioning latent
        if batch_index is None:
            # `latent_vector[None]` would merely prepend an axis instead of
            # assigning a latent to every node, so handle the single-graph
            # case explicitly and refuse ambiguous multi-graph input.
            single_latent = latent_vector.reshape(-1, latent_vector.shape[-1])
            if single_latent.shape[0] != 1:
                raise ValueError(
                    "batch_index is required when cond holds more than one "
                    "conditioning vector"
                )
            z = single_latent.expand(*x.shape[:-1], -1)
        else:
            z = latent_vector[batch_index]  # (BxN, C)

        layer_kwargs = {}
        if self.conditioning_mode == "cat":
            x = self.proj_cond(torch.cat([x, z], dim=-1))
        elif self.conditioning_mode == "film":
            layer_kwargs = {"cond": z}

        # message passing layers
        for layer in self.processor:
            x = layer(x, edge_index=mesh_edges, **layer_kwargs)
            x = self.activation(x)

        # decoder
        x = self.decoder(x)
        if self.out_deformation:
            # trailing `space` channels are the predicted coordinate shifts
            x, dpos = x.split([self.output_channels, self.space], -1)
            coords = coords + dpos

        return (x, coords), latent_vector