Models
Conditioning utils
GNN
GraphSAGE-based, FiLM-conditioned GNN model.
https://arxiv.org/abs/1706.02216
- class simshift.models.mesh.gnn.ModulatedSAGEConv(*args: Any, **kwargs: Any)
SAGEConv layer with modulation: extends pygnn.SAGEConv with FiLM (Feature-wise Linear Modulation) conditioning.
- Parameters:
cond_dim (int) – Dimension of the conditioning vector.
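To show roughly what a FiLM-conditioned SAGEConv layer computes, the NumPy sketch below runs a mean-aggregation GraphSAGE update and then scales and shifts each output channel using parameters predicted from the conditioning vector. This is not the simshift implementation. The function name `film_sage_conv`, its weight arguments, and the choice to apply FiLM after the SAGE update as `gamma * h + beta` are assumptions made for illustration.

```python
import numpy as np


def film_sage_conv(x, edge_index, cond, w_self, w_neigh, w_gamma, w_beta):
    """Mean-aggregation GraphSAGE layer followed by FiLM modulation.

    Hypothetical sketch: applying FiLM after the SAGE update is an assumption.
    x:          (num_nodes, in_dim) node features
    edge_index: (2, num_edges) integer array of (source, target) node pairs
    cond:       (cond_dim,) conditioning vector shared by all nodes
    """
    src, dst = edge_index
    num_nodes = x.shape[0]
    # Sum the features of incoming neighbors at each target node...
    agg = np.zeros_like(x)
    np.add.at(agg, dst, x[src])
    # ...and divide by the in-degree to get the mean (isolated nodes stay zero).
    deg = np.bincount(dst, minlength=num_nodes).astype(x.dtype)
    agg = agg / np.maximum(deg, 1.0)[:, None]
    # GraphSAGE update: separate weights for the node itself and its neighborhood.
    h = x @ w_self + agg @ w_neigh
    # FiLM: the conditioning vector predicts a per-channel scale and shift.
    gamma = cond @ w_gamma
    beta = cond @ w_beta
    return gamma * h + beta


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))                           # 4 nodes, 3 input channels
edge_index = np.array([[0, 1, 2, 3], [1, 2, 3, 0]])   # directed ring 0->1->2->3->0
cond = rng.normal(size=(2,))                          # cond_dim = 2
w_self, w_neigh = rng.normal(size=(3, 5)), rng.normal(size=(3, 5))
w_gamma, w_beta = rng.normal(size=(2, 5)), rng.normal(size=(2, 5))
out = film_sage_conv(x, edge_index, cond, w_self, w_neigh, w_gamma, w_beta)
print(out.shape)  # (4, 5)
```

Because the scale and shift depend only on `cond`, every node is modulated identically. The conditioning changes how features are transformed without adding input channels.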
- class simshift.models.mesh.gnn.GraphSAGE(*args: Any, **kwargs: Any)
Graph-based mesh model with conditioned GraphSAGE layers.
- Parameters:
n_conds (int) – Conditioning input dimensionality.
latent_channels (int) – Dimension of latent conditioning.
output_channels (int) – Number of output prediction channels.
act_fn (nn.Module) – Activation function.
dropout_prob (float) – Dropout rate in MLPs.
space (int) – Spatial dimension of the mesh.
gnn_base (int) – Hidden size for GNN layers.
num_layers (int) – Number of GNN message-passing layers.
conditioning_mode (str) – ‘film’ or ‘cat’ mode.
out_deformation (bool) – If True, predicts mesh coordinate shifts.
n_materials (Optional[int]) – Number of material types (if used).
- __init__(n_conds: int, latent_channels: int = 256, output_channels: int = 17, act_fn: torch.nn.Module = torch.nn.SiLU, dropout_prob: float = 0.1, space: int = 2, gnn_base: int = 64, num_layers: int = 5, conditioning_mode: str = 'film', out_deformation: bool = True, n_materials: int | None = None, conditioning_bn: bool = False, conditioning_ln: bool = False)
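The two `conditioning_mode` options can be compared using a small hypothetical helper, `condition_nodes`. The sketch assumes that 'cat' means concatenating the conditioning vector to every node's features and that 'film' means a per-channel scale and shift. This page does not spell out either meaning, and the actual GraphSAGE module may apply the conditioning at different points in the network.

```python
import numpy as np


def condition_nodes(h, cond, mode, w_gamma=None, w_beta=None):
    """Inject one global conditioning vector into every node's features.

    Hypothetical helper: the meaning of 'cat' and 'film' here is assumed.
    h:    (num_nodes, channels) node features
    cond: (cond_dim,) conditioning vector
    """
    if mode == "cat":
        # Repeat cond for every node and append it as extra channels.
        tiled = np.broadcast_to(cond, (h.shape[0], cond.shape[0]))
        return np.concatenate([h, tiled], axis=1)
    if mode == "film":
        # Per-channel scale and shift predicted from cond; width is unchanged.
        return (cond @ w_gamma) * h + cond @ w_beta
    raise ValueError(f"unknown conditioning mode: {mode!r}")


rng = np.random.default_rng(0)
h = rng.normal(size=(5, 8))
cond = np.array([0.5, -1.0, 2.0])
w_gamma, w_beta = rng.normal(size=(3, 8)), rng.normal(size=(3, 8))
print(condition_nodes(h, cond, "cat").shape)                    # (5, 11)
print(condition_nodes(h, cond, "film", w_gamma, w_beta).shape)  # (5, 8)
```

One practical difference: 'cat' widens each node's feature vector by `cond_dim` channels, so the next layer must accept the larger input. 'film' leaves the width unchanged.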
PointNet
Conditional PointNet model for mesh modeling.
https://arxiv.org/abs/1612.00593
- class simshift.models.mesh.pointnet.PointNet(*args: Any, **kwargs: Any)
Conditional PointNet architecture.
- Parameters:
n_conds (int) – Dimension of conditioning inputs.
latent_channels (int) – Latent embedding size.
output_channels (int) – Number of output channels.
act_fn (nn.Module) – Activation function.
dropout_prob (float) – Dropout rate.
space (int) – Spatial dimension.
pointnet_base (int) – Base feature size for PointNet layers.
out_deformation (bool) – Predict coordinate deformations if True.
n_materials (Optional[int]) – Number of material types (if used).
- __init__(n_conds: int, latent_channels: int = 128, output_channels: int = 17, act_fn: torch.nn.Module = torch.nn.SiLU, dropout_prob: float = 0.1, space: int = 2, pointnet_base: int = 8, out_deformation: bool = True, n_materials: int | None = None, conditioning_bn: bool = False, conditioning_ln: bool = False)
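PointNet's core idea is a shared per-point MLP followed by a symmetric max-pool, which makes the global descriptor independent of point order. Per-point predictions then come from each point's local feature concatenated with that global descriptor. The NumPy sketch below follows this segmentation-style layout. Appending the conditioning vector to the pooled global feature is an assumption about how conditioning could enter. `conditional_pointnet`, `mlp`, and the weight lists are hypothetical names, not the simshift API.

```python
import numpy as np


def mlp(x, weights):
    """Shared MLP applied row-wise: every point uses the same weights."""
    for i, w in enumerate(weights):
        x = x @ w
        if i < len(weights) - 1:
            x = x / (1.0 + np.exp(-x))  # SiLU (x * sigmoid(x)), the documented default act_fn
    return x


def conditional_pointnet(points, cond, local_w, head_w):
    """PointNet-style per-point prediction with a global conditioning vector.

    points: (num_points, space) coordinates
    cond:   (n_conds,) conditioning vector
    """
    local = mlp(points, local_w)       # (num_points, c) per-point features
    global_feat = local.max(axis=0)    # symmetric max-pool: order invariant
    # Hypothetical conditioning: append cond to the pooled global descriptor.
    global_feat = np.concatenate([global_feat, cond])
    tiled = np.broadcast_to(global_feat, (points.shape[0], global_feat.shape[0]))
    return mlp(np.concatenate([local, tiled], axis=1), head_w)


rng = np.random.default_rng(0)
points = rng.normal(size=(100, 2))     # 100 points in 2-D (space=2)
cond = rng.normal(size=(3,))           # n_conds = 3
local_w = [rng.normal(size=(2, 16)), rng.normal(size=(16, 32))]
head_w = [rng.normal(size=(32 + 32 + 3, 64)), rng.normal(size=(64, 17))]
out = conditional_pointnet(points, cond, local_w, head_w)
print(out.shape)  # (100, 17)
```

Because max-pooling ignores order, permuting the input points only permutes the rows of the output.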
Transolver
Transolver: A Fast Transformer Solver for PDEs on General Geometries
Adapted from https://github.com/thuml/Transolver
- class simshift.models.mesh.transolver.TransolverAttention(*args: Any, **kwargs: Any)
Multi-head self-attention with physics-aware slicing for PDE solvers.
- Parameters:
dim (int) – Input feature dimension.
num_heads (int) – Number of attention heads.
dropout_prob (float) – Dropout probability after attention.
attn_dropout_prob (float) – Dropout within attention weights.
slice_base (int) – Base number of slices for token grouping.
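Transolver's Physics-Attention avoids quadratic attention over all mesh points in four steps:

1. Softly assign the points to a small number of slices.
2. Summarize each slice as one token.
3. Run attention among those tokens.
4. Broadcast the result back to the points using the same assignment weights.

The single-head NumPy sketch below follows that structure. It simplifies the reference implementation: one head, the same features used for slice assignment and for token values, and no output projection. It also assumes the number of slices is derived from `slice_base`. `physics_attention` and its weight arguments are hypothetical.

```python
import numpy as np


def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


def physics_attention(x, w_slice, w_q, w_k, w_v, temperature=1.0):
    """Single-head, simplified Transolver-style physics attention (sketch).

    x:       (num_points, dim) mesh point features
    w_slice: (dim, num_slices) projection that assigns points to slices
    """
    # 1. Soft-assign every point to the learnable slices (rows sum to 1).
    weights = softmax(x @ w_slice / temperature, axis=1)             # (N, M)
    # 2. Slice tokens: weighted average of the points in each slice.
    tokens = weights.T @ x / (weights.sum(axis=0)[:, None] + 1e-5)   # (M, dim)
    # 3. Ordinary scaled dot-product attention among the M slice tokens only.
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v
    attn = softmax(q @ k.T / np.sqrt(q.shape[1]), axis=1)           # (M, M)
    tokens = attn @ v
    # 4. De-slice: map updated tokens back to points with the same weights.
    return weights @ tokens                                          # (N, dim)


rng = np.random.default_rng(0)
x = rng.normal(size=(1000, 16))        # 1000 mesh points, 16 channels
w_slice = rng.normal(size=(16, 8))     # 8 slices
w_q, w_k, w_v = (rng.normal(size=(16, 16)) / 4.0 for _ in range(3))
out = physics_attention(x, w_slice, w_q, w_k, w_v)
print(out.shape)  # (1000, 16)
```

With N points and M slices, slicing and de-slicing cost O(N·M) and the token attention costs O(M²). Full self-attention over the mesh would cost O(N²).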
- class simshift.models.mesh.transolver.Transolver(*args: Any, **kwargs: Any)
Transolver: a Transformer-based PDE solver that uses Physics-Attention to model physical states and efficiently capture complex geometries.
- Parameters:
n_conds (int) – Number of conditioning inputs.
latent_channels (int) – Channels in the latent representation.
output_channels (int) – Number of output channels.
act_fn (nn.Module) – Activation function to use.
dropout_prob (float) – Dropout probability.
attn_dropout_prob (float) – Dropout for attention layers.
space (int) – Spatial dimensionality (e.g., 2D or 3D).
transolver_base (int) – Base feature size for the transformer.
num_heads (int) – Number of attention heads.
num_layers (int) – Number of transformer layers.
slice_base (int) – Base number of slices (slice tokens) for Physics-Attention.
mlp_ratio (float) – Ratio for MLP hidden dimension.
conditioning_mode (str) – Type of conditioning (‘dit’, etc.).
out_deformation (bool) – Whether to output deformation fields.
n_materials (Optional[int]) – Number of materials (if applicable).
Paper: https://arxiv.org/abs/2402.02366
- __init__(n_conds: int, latent_channels: int = 8, output_channels: int = 17, act_fn: torch.nn.Module = torch.nn.SiLU, dropout_prob: float = 0.1, attn_dropout_prob: float = 0.1, space: int = 2, transolver_base: int = 128, num_heads: int = 4, num_layers: int = 2, slice_base: int = 64, mlp_ratio: float = 2.0, conditioning_mode: str = 'dit', out_deformation: bool = True, n_materials: int | None = None, conditioning_bn: bool = False, conditioning_ln: bool = False)
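The default `conditioning_mode='dit'` suggests DiT-style adaptive layer normalization (adaLN). In adaLN, the conditioning vector regresses a shift, scale, and gate for each residual branch instead of being concatenated to the tokens. This reading is an assumption that this page does not confirm. The NumPy sketch shows one adaLN-Zero residual branch. `adaln_residual`, `layer_norm`, and `w_mod` are hypothetical names.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    """LayerNorm over channels with no learned affine; cond supplies scale/shift."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def adaln_residual(x, cond, w_mod, sublayer):
    """One DiT-style adaLN-Zero residual branch (assumed meaning of 'dit').

    x:     (num_points, dim) token features
    cond:  (n_conds,) conditioning vector
    w_mod: (n_conds, 3 * dim) regression layer producing shift, scale, gate
    """
    shift, scale, gate = np.split(cond @ w_mod, 3)
    h = layer_norm(x) * (1.0 + scale) + shift   # conditioned normalization
    return x + gate * sublayer(h)               # gated residual update


rng = np.random.default_rng(0)
x = rng.normal(size=(10, 8))
cond = rng.normal(size=(3,))
mlp_w = rng.normal(size=(8, 8))


def sublayer(h):
    # Stand-in for the attention or MLP sublayer of a transformer block.
    return np.tanh(h @ mlp_w)


w_mod = np.zeros((3, 3 * 8))  # adaLN-Zero initialization
print(np.allclose(adaln_residual(x, cond, w_mod, sublayer), x))  # True
```

Initializing `w_mod` to zero makes every gate zero, so each block starts as the identity map. This is the "Zero" in adaLN-Zero.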