# Source code for kaolin.models.MeshEncoder

import torch 
from torch import nn
import torch.nn.functional as F

from .SimpleGCN import SimpleGCN


class MeshEncoder(nn.Module):
    r"""A simple graph-convolutional mesh encoder.

    Takes in a polygon mesh (viewed as a graph) and encodes the whole mesh
    into a single compact latent code. A stack of ``SimpleGCN`` layers, each
    followed by an ELU, transforms the per-vertex features. A final
    ``SimpleGCN`` maps them to ``latent_length`` channels, and the result is
    max-pooled over dim 0 (presumably the vertex dimension).

    Args:
        latent_length (int): Number of channels in the output latent code.

    Note:
        The submodule attribute names (``h1``, ``h21``, ..., ``h11``,
        ``reduce``) and their registration order are kept stable. Existing
        ``state_dict`` checkpoints and optimizer states therefore still match.
    """

    # (attribute name, input width, output width) for every hidden layer, in
    # the order the layers are registered and applied. Each layer's output
    # width feeds the next layer's input width. The irregular names are kept
    # because they are the state_dict keys of previously saved models.
    _HIDDEN_LAYERS = (
        ("h1", 3, 60),
        ("h21", 60, 60),
        ("h22", 60, 60),
        ("h23", 60, 60),
        ("h24", 60, 120),
        ("h3", 120, 120),
        ("h4", 120, 120),
        ("h41", 120, 150),
        ("h5", 150, 200),
        ("h6", 200, 210),
        ("h7", 210, 250),
        ("h8", 250, 300),
        ("h81", 300, 300),
        ("h9", 300, 300),
        ("h10", 300, 300),
        ("h11", 300, 300),
    )

    def __init__(self, latent_length):
        super().__init__()
        for name, in_width, out_width in self._HIDDEN_LAYERS:
            # nn.Module.__setattr__ registers the layer as a submodule
            # under `name`, exactly as an explicit `self.<name> = ...` would.
            setattr(self, name, SimpleGCN(in_width, out_width))
        # Final projection from the last hidden width down to the latent size.
        self.reduce = SimpleGCN(self._HIDDEN_LAYERS[-1][2], latent_length)

    def resnet(self, features, res):
        r"""Add a residual to the leading channels of ``features``.

        ``res`` is added element-wise to the first ``res.shape[1]`` columns of
        ``features``. The remaining columns pass through unchanged. This
        helper is not called by :meth:`forward`.

        Args:
            features (torch.Tensor): Features indexed as ``[rows, channels]``.
            res (torch.Tensor): Residual. It is expected to have the same
                number of rows as ``features`` and no more channels.

        Returns:
            tuple: The updated features, returned twice. Presumably this lets
            a caller rebind both its running features and its residual in one
            assignment: ``features, res = self.resnet(features, res)``.
        """
        width = res.shape[1]
        head = features[:, :width] + res
        features = torch.cat((head, features[:, width:]), dim=1)
        return features, features

    def forward(self, positions, adj):
        r"""Encode a mesh into a single latent code.

        Args:
            positions (torch.Tensor): Per-vertex input features. The first
                layer takes 3 input channels (presumably xyz vertex
                positions, one row per vertex).
            adj: Graph connectivity, passed unchanged to every ``SimpleGCN``
                layer. Its format is whatever ``SimpleGCN`` expects.

        Returns:
            torch.Tensor: The element-wise maximum over dim 0 of the
            ELU-activated ``latent_length``-channel per-vertex codes.
        """
        features = positions
        for name, _, _ in self._HIDDEN_LAYERS:
            features = F.elu(getattr(self, name)(features, adj))
        latent = F.elu(self.reduce(features, adj))
        # Pool over dim 0 to obtain one code for the whole mesh. With `dim`
        # given, torch.max returns (values, indices); keep the values.
        return torch.max(latent, dim=0)[0]