"""
PointNet conditional model for mesh modeling.
https://arxiv.org/abs/1612.00593
"""
from typing import Optional
import torch
import torch.nn as nn
from simshift.models.condition import ContinuousSincosEmbed
from simshift.models.registry import register_model
from simshift.models.utils import MLP
[docs]@register_model()
class PointNet(nn.Module):
"""
Conditional PointNet architecture.
Args:
n_conds (int): Dimension of conditioning inputs.
latent_channels (int): Latent embedding size.
output_channels (int): Number of output channels.
act_fn (nn.Module): Activation function.
dropout_prob (float): Dropout rate.
space (int): Spatial dimension.
pointnet_base (int): Base feature size for PointNet layers.
out_deformation (bool): Predict coordinate deformations if True.
n_materials (Optional[int]): Number of material types (if used).
"""
[docs] def __init__(
self,
n_conds: int,
latent_channels: int = 128,
output_channels: int = 17,
act_fn: nn.Module = nn.SiLU,
dropout_prob: float = 0.1,
space: int = 2,
pointnet_base: int = 8,
out_deformation: bool = True,
n_materials: Optional[int] = None,
conditioning_bn: bool = False,
conditioning_ln: bool = False,
):
super().__init__()
self.space = space
self.output_channels = output_channels
self.out_deformation = out_deformation
self.conditioning = nn.Sequential(
ContinuousSincosEmbed(dim=256, ndim=n_conds),
MLP(
[256, 256 // 2, 256 // 4, latent_channels],
act_fn=act_fn,
dropout_prob=dropout_prob,
batchnorm=conditioning_bn,
layernorm=conditioning_ln,
),
)
# encode positions to latent
self.coord_embed = ContinuousSincosEmbed(dim=latent_channels, ndim=space)
self.encoder = MLP(
[latent_channels, latent_channels], act_fn=act_fn, dropout_prob=dropout_prob
)
# material embedding
if n_materials is not None:
self.material_embedding = nn.Embedding(
num_embeddings=n_materials, embedding_dim=latent_channels
)
# pointnet processor
self.in_block = MLP(
[latent_channels, pointnet_base, pointnet_base * 2],
act_fn=act_fn,
dropout_prob=dropout_prob,
)
self.max_block = MLP(
[
pointnet_base * 2,
pointnet_base * 4,
pointnet_base * 8,
pointnet_base * 32,
],
act_fn=act_fn,
dropout_prob=dropout_prob,
)
self.out_block = MLP(
[
pointnet_base * (32 + 2) + latent_channels, # (globals + locals + cond)
pointnet_base * 16,
pointnet_base * 8,
pointnet_base * 4,
latent_channels,
],
act_fn=act_fn,
dropout_prob=dropout_prob,
)
self.decoder = MLP(
[latent_channels, output_channels + (space if out_deformation else 0)],
act_fn,
dropout_prob=dropout_prob,
)
def forward(
self,
cond: torch.Tensor,
mesh_coords: torch.Tensor,
mesh_edges: torch.Tensor,
mesh_material: Optional[torch.Tensor] = None,
batch_index: Optional[torch.Tensor] = None,
):
_ = mesh_edges
latent_vector = self.conditioning(cond)
# encoder
coords = mesh_coords.clone()
x = self.encoder(self.coord_embed(mesh_coords)) # (BxN, C)
if mesh_material is not None:
# add material embedding if we have it
mesh_material_embedding = self.material_embedding(mesh_material.squeeze())
x += mesh_material_embedding
x = self.in_block(x)
x_m = self.max_block(x)
bs = batch_index.max() + 1
global_x = torch.zeros((bs, x_m.shape[-1]), device=x.device, dtype=x.dtype)
global_x.scatter_reduce_(
0, batch_index.unsqueeze(1).repeat(1, x_m.shape[-1]), x_m, reduce="amax"
) # (B, C)
# count points per mesh
points = torch.zeros(bs, device=x.device, dtype=x.dtype)
ones = torch.ones_like(batch_index, device=x.device, dtype=x.dtype)
points.scatter_reduce_(0, batch_index, ones, reduce="sum")
points = points.long()
# concatenate conditioning to globals
global_x = torch.cat([global_x, latent_vector], dim=-1) # (B, gC+C)
global_x = torch.repeat_interleave(global_x, points, dim=0)
x = torch.cat([x, global_x], dim=1) # (BxN, lC+gC+C)
x = self.out_block(x) # (BxN, C)
# decoder
x = self.decoder(x)
if self.out_deformation:
x, dpos = x.split([self.output_channels, self.space], -1)
coords = coords + dpos
return (x, coords), latent_vector