# Source code for kaolin.models.OccupancyNetwork

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright 2019 Lars Mescheder, Michael Oechsle, Michael Niemeyer,
# Andreas Geiger, Sebastian Nowozin

# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import functools
import math

import torch
from torch import nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
from torchvision import models
import torch.distributions as dist

import torch
from torch.nn import Parameter

class Resnet18(nn.Module):
    r''' ResNet-18 encoder network for image input.

    The torchvision ResNet-18 backbone is loaded with ``pretrained=True`` and
    its ``fc`` head is replaced by an empty ``nn.Sequential`` so that the
    backbone returns the 512-dimensional features that used to feed that
    head. An optional linear layer then projects them to ``c_dim``.

    Args:
        c_dim (int): output dimension of the latent embedding
        normalize (bool): whether the input images should be normalized
        use_linear (bool): whether a final linear layer should be used

    Raises:
        ValueError: if ``use_linear`` is False and ``c_dim`` is not 512
            (without a projection the output size is fixed at 512).
    '''

    def __init__(self, c_dim, normalize=True, use_linear=True):
        super().__init__()
        self.normalize = normalize
        self.use_linear = use_linear

        # NOTE(review): newer torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; constructing this may also trigger a
        # weight download on first use.
        backbone = models.resnet18(pretrained=True)
        self.features = backbone
        self.features.fc = nn.Sequential()  # drop the classification head

        if use_linear:
            head = nn.Linear(512, c_dim)
        elif c_dim == 512:
            head = nn.Sequential()  # identity: features already have c_dim
        else:
            raise ValueError('c_dim must be 512 if use_linear is False')
        self.fc = head

    def forward(self, x):
        ''' Encodes a batch of images.

        Args:
            x (tensor): input images (normalized first when ``normalize``)

        Returns:
            tensor: latent embedding produced by ``self.fc``
        '''
        if self.normalize:
            x = normalize_imagenet(x)
        features = self.features(x)
        return self.fc(features)
def normalize_imagenet(x):
    ''' Normalize input images according to ImageNet standards.

    Channels 0, 1 and 2 (dimension 1 of ``x``) are shifted by the ImageNet
    per-channel mean and divided by the per-channel standard deviation; any
    further channels are copied unchanged. The input tensor is not modified.

    Args:
        x (tensor): input images, channels on dimension 1

    Returns:
        tensor: normalized copy of ``x``
    '''
    imagenet_stats = (
        (0.485, 0.229),  # channel 0: (mean, std)
        (0.456, 0.224),  # channel 1
        (0.406, 0.225),  # channel 2
    )
    out = x.clone()  # work on a copy so the caller's tensor is untouched
    for channel, (mean, std) in enumerate(imagenet_stats):
        out[:, channel] = (out[:, channel] - mean) / std
    return out
class DecoderCBatchNorm(nn.Module):
    ''' Decoder with conditional batch normalization (CBN) class.

    Maps query points ``p``, an optional latent code ``z`` and a conditioning
    code ``c`` to one output value per point. The points are lifted to
    ``hidden_size`` channels by a 1x1 convolution, passed through five
    CBN residual blocks, a final CBN layer plus activation, and projected to
    a single channel.

    Args:
        dim (int): input dimension
        z_dim (int): dimension of latent code z (0 disables the z branch)
        c_dim (int): dimension of latent conditioned code c
        hidden_size (int): hidden size of Decoder network
        leaky (bool): whether to use leaky ReLUs (negative slope 0.2)
        legacy (bool): whether to use the legacy structure
    '''

    def __init__(self, dim=3, z_dim=128, c_dim=128, hidden_size=256,
                 leaky=False, legacy=False):
        super().__init__()
        self.z_dim = z_dim
        # Submodule names and creation order are kept fixed: they define the
        # state_dict keys and the order in which seeded initialization runs.
        if z_dim != 0:
            self.fc_z = nn.Linear(z_dim, hidden_size)
        self.fc_p = nn.Conv1d(dim, hidden_size, 1)
        self.block0 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
        self.block1 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
        self.block2 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
        self.block3 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
        self.block4 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
        if not legacy:
            self.bn = CBatchNorm1d(c_dim, hidden_size)
        else:
            # NOTE(review): CBatchNorm1d_legacy is not defined in this file as
            # seen here; legacy=True raises NameError unless it is provided
            # elsewhere.
            self.bn = CBatchNorm1d_legacy(c_dim, hidden_size)
        self.fc_out = nn.Conv1d(hidden_size, 1, 1)
        if not leaky:
            self.actvn = F.relu
        else:
            # functools.partial (unlike a lambda) is picklable, so the whole
            # module can still be serialized, e.g. with torch.save(model).
            self.actvn = functools.partial(F.leaky_relu, negative_slope=0.2)

    def forward(self, p, z, c, **kwargs):
        ''' Computes one output value per query point.

        Args:
            p (tensor): query points, batch_size x T x dim (transposed to
                channels-first for the 1x1 convolutions)
            z (tensor): latent code; only used when ``z_dim != 0``, in which
                case it is expected as batch_size x z_dim
            c (tensor): latent conditioned code c, passed to every CBN layer
            **kwargs: accepted for interface compatibility and ignored

        Returns:
            tensor: batch_size x T outputs, one per point
        '''
        p = p.transpose(1, 2)  # -> batch_size x dim x T
        net = self.fc_p(p)

        if self.z_dim != 0:
            # Broadcast the per-sample latent code over all T points.
            net = net + self.fc_z(z).unsqueeze(2)

        for block in (self.block0, self.block1, self.block2,
                      self.block3, self.block4):
            net = block(net, c)

        out = self.fc_out(self.actvn(self.bn(net, c)))
        return out.squeeze(1)  # drop the single output channel
def get_prior_z(device, z_dim=0):
    ''' Returns the standard-normal prior distribution for latent code z.

    Args:
        device (device): pytorch device on which the distribution parameters
            are allocated
        z_dim (int): dimension of the latent code z. Defaults to 0, which
            yields an empty (zero-dimensional) prior.

    Returns:
        torch.distributions.Normal: N(0, I) over a ``z_dim``-dimensional code
    '''
    p0_z = dist.Normal(
        torch.zeros(z_dim, device=device),
        torch.ones(z_dim, device=device),
    )
    return p0_z
class CBatchNorm1d(nn.Module):
    ''' Conditional batch normalization layer class.

    The input is normalized without learned affine parameters; the scale
    (gamma) and shift (beta) are instead predicted from the conditioning code
    c by 1x1 convolutions.

    Args:
        c_dim (int): dimension of latent conditioned code c
        f_dim (int): feature dimension
        norm_method (str): normalization method, one of 'batch_norm',
            'instance_norm' or 'group_norm'
        num_groups (int): number of channel groups used when ``norm_method``
            is 'group_norm'; must divide ``f_dim``. Ignored otherwise.

    Raises:
        ValueError: if ``norm_method`` is not one of the supported methods.
    '''

    def __init__(self, c_dim, f_dim, norm_method='batch_norm', num_groups=1):
        super().__init__()
        self.c_dim = c_dim
        self.f_dim = f_dim
        self.norm_method = norm_method
        # Submodules
        self.conv_gamma = nn.Conv1d(c_dim, f_dim, 1)
        self.conv_beta = nn.Conv1d(c_dim, f_dim, 1)
        if norm_method == 'batch_norm':
            self.bn = nn.BatchNorm1d(f_dim, affine=False)
        elif norm_method == 'instance_norm':
            self.bn = nn.InstanceNorm1d(f_dim, affine=False)
        elif norm_method == 'group_norm':
            # torch.nn has no GroupNorm1d; nn.GroupNorm works on (N, C, *)
            # inputs and takes the number of groups as its first argument.
            self.bn = nn.GroupNorm(num_groups, f_dim, affine=False)
        else:
            raise ValueError('Invalid normalization method!')
        self.reset_parameters()

    def reset_parameters(self):
        ''' Initializes the conditioning convolutions.

        With zero weights, a gamma bias of one and a beta bias of zero, the
        layer starts out predicting gamma == 1 and beta == 0 for every c,
        i.e. it initially behaves like plain (non-affine) normalization.
        '''
        nn.init.zeros_(self.conv_gamma.weight)
        nn.init.zeros_(self.conv_beta.weight)
        nn.init.ones_(self.conv_gamma.bias)
        nn.init.zeros_(self.conv_beta.bias)

    def forward(self, x, c):
        ''' Normalizes x and applies the c-dependent affine transform.

        Args:
            x (tensor): features; assumed batch_size x f_dim x T
                (TODO confirm against callers)
            c (tensor): conditioning code, batch_size x c_dim or
                batch_size x c_dim x T; a 2D code is broadcast over T

        Returns:
            tensor: gamma(c) * norm(x) + beta(c)
        '''
        assert x.size(0) == c.size(0), 'x and c must have the same batch size'
        assert c.size(1) == self.c_dim, 'c must have c_dim channels'

        # c is assumed to be of size batch_size x c_dim x T
        if c.dim() == 2:
            c = c.unsqueeze(2)

        # Affine mapping
        gamma = self.conv_gamma(c)
        beta = self.conv_beta(c)

        # Normalization
        net = self.bn(x)
        return gamma * net + beta
class CResnetBlockConv1d(nn.Module):
    ''' Residual block built from conditionally normalized 1x1 convolutions.

    Computes ``shortcut(x) + fc_1(relu(bn_1(fc_0(relu(bn_0(x, c))), c)))``,
    where both normalization layers are conditioned on the latent code c.

    Args:
        c_dim (int): dimension of latent conditioned code c
        size_in (int): input dimension
        size_h (int): hidden dimension (defaults to size_in)
        size_out (int): output dimension (defaults to size_in)
        norm_method (str): normalization method
        legacy (bool): whether to use legacy blocks
    '''

    def __init__(self, c_dim, size_in, size_h=None, size_out=None,
                 norm_method='batch_norm', legacy=False):
        super().__init__()
        size_h = size_in if size_h is None else size_h
        size_out = size_in if size_out is None else size_out
        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out

        # NOTE(review): CBatchNorm1d_legacy is not defined in this file as
        # seen here; legacy=True raises NameError unless it is provided
        # elsewhere.
        norm_layer = CBatchNorm1d_legacy if legacy else CBatchNorm1d
        self.bn_0 = norm_layer(c_dim, size_in, norm_method=norm_method)
        self.bn_1 = norm_layer(c_dim, size_h, norm_method=norm_method)
        self.fc_0 = nn.Conv1d(size_in, size_h, 1)
        self.fc_1 = nn.Conv1d(size_h, size_out, 1)
        self.actvn = nn.ReLU()

        # Project the skip path only when the channel count changes.
        self.shortcut = (
            None if size_in == size_out
            else nn.Conv1d(size_in, size_out, 1, bias=False)
        )

        # Zero the last projection's weights so that, at initialization, the
        # residual branch contributes only fc_1's bias.
        nn.init.zeros_(self.fc_1.weight)

    def forward(self, x, c):
        ''' Applies the residual block.

        Args:
            x (tensor): input features with size_in channels
            c (tensor): latent conditioned code c

        Returns:
            tensor: features with size_out channels
        '''
        hidden = self.fc_0(self.actvn(self.bn_0(x, c)))
        residual = self.fc_1(self.actvn(self.bn_1(hidden, c)))
        identity = x if self.shortcut is None else self.shortcut(x)
        return identity + residual
class OccupancyNetwork(nn.Module):
    ''' Occupancy Network class.

    Combines a ResNet-18 image encoder (256-dimensional output) with a
    conditional-batch-norm decoder over 3D points. In this configuration the
    latent code z has dimension 0: the decoder is built with ``z_dim=0`` and
    ignores z, and the prior / inferred posterior over z are empty Normal
    distributions.

    Args:
        device (device): torch device on which the encoder, decoder and prior
            are created

    .. note::

        If you use this code, please cite the original paper in addition to Kaolin.

        .. code-block::

            @inproceedings{Occupancy Networks,
                title = {Occupancy Networks: Learning 3D Reconstruction in Function Space},
                author = {Mescheder, Lars and Oechsle, Michael and Niemeyer, Michael and Nowozin, Sebastian and Geiger, Andreas},
                booktitle = {Proceedings IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)},
                year = {2019}
            }
    '''

    def __init__(self, device):
        super().__init__()
        self.device = device
        self.decoder = DecoderCBatchNorm(
            dim=3, z_dim=0, c_dim=256, hidden_size=256).to(self.device)
        self.encoder = Resnet18(
            256, normalize=True, use_linear=True).to(self.device)
        # NOTE(review): p0_z is a plain attribute rather than a registered
        # buffer, so a later model.to(...) call does not move it.
        self.p0_z = get_prior_z(self.device)

    def forward(self, p, inputs, sample=True, **kwargs):
        ''' Performs a forward pass through the network.

        Args:
            p (tensor): sampled points
            inputs (tensor): conditioning input
            sample (bool): whether to sample z from the prior (otherwise the
                prior mean is used)
            **kwargs: forwarded to the decoder

        Returns:
            torch.distributions.Bernoulli: occupancy distribution for p
        '''
        batch_size = p.size(0)
        c = self.encode_inputs(inputs)
        z = self.get_z_from_prior((batch_size,), sample=sample)
        p_r = self.decode(p, z, c, **kwargs)
        return p_r

    def compute_elbo(self, p, occ, inputs, **kwargs):
        ''' Computes the evidence lower bound.

        Args:
            p (tensor): sampled points
            occ (tensor): occupancy values for p
            inputs (tensor): conditioning input
            **kwargs: forwarded to infer_z and the decoder

        Returns:
            tuple: (elbo, rec_error, kl), where rec_error is the negative
            Bernoulli log-likelihood of occ and kl is KL(q(z) || p0(z)),
            each summed over the last dimension.
        '''
        c = self.encode_inputs(inputs)
        q_z = self.infer_z(p, occ, c, **kwargs)
        z = q_z.rsample()  # reparameterized, so gradients flow through z
        p_r = self.decode(p, z, c, **kwargs)

        rec_error = -p_r.log_prob(occ).sum(dim=-1)
        kl = dist.kl_divergence(q_z, self.p0_z).sum(dim=-1)
        elbo = -rec_error - kl

        return elbo, rec_error, kl

    def encode_inputs(self, inputs):
        ''' Encodes the input.

        Args:
            inputs (tensor): the input

        Returns:
            tensor: latent conditioned code c produced by the encoder
        '''
        c = self.encoder(inputs)
        return c

    def decode(self, p, z, c, **kwargs):
        ''' Returns occupancy probabilities for the sampled points.

        Args:
            p (tensor): points
            z (tensor): latent code z
            c (tensor): latent conditioned code c
            **kwargs: forwarded to the decoder

        Returns:
            torch.distributions.Bernoulli: distribution whose logits are the
            decoder outputs
        '''
        logits = self.decoder(p, z, c, **kwargs)
        p_r = dist.Bernoulli(logits=logits)
        return p_r

    def infer_z(self, p, occ, c, **kwargs):
        ''' Infers z.

        z is zero-dimensional here, so this returns an empty Normal
        distribution of shape batch_size x 0 regardless of the inputs.

        Args:
            p (tensor): points tensor
            occ (tensor): occupancy values for p
            c (tensor): latent conditioned code c

        Returns:
            torch.distributions.Normal: posterior q(z) of shape batch_size x 0
        '''
        batch_size = p.size(0)
        # Allocate directly on the target device instead of creating on the
        # default device and copying.
        mean_z = torch.empty(batch_size, 0, device=self.device)
        logstd_z = torch.empty(batch_size, 0, device=self.device)
        q_z = dist.Normal(mean_z, torch.exp(logstd_z))
        return q_z

    def get_z_from_prior(self, size=torch.Size([]), sample=True):
        ''' Returns z from prior distribution.

        Args:
            size (Size): batch shape of z
            sample (bool): whether to sample; if False, the prior mean is
                expanded to ``size``

        Returns:
            tensor: z of shape ``(*size, *event_shape)`` on ``self.device``
        '''
        if sample:
            z = self.p0_z.sample(size).to(self.device)
        else:
            z = self.p0_z.mean.to(self.device)
            z = z.expand(*size, *z.size())

        return z