# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Copyright 2019 Lars Mescheder, Michael Oechsle, Michael Niemeyer,
# Andreas Geiger, Sebastian Nowozin
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import math
import torch
from torch import nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
from torchvision import models
import torch.distributions as dist
import torch
from torch.nn import Parameter
class Resnet18(nn.Module):
    r''' ResNet-18 encoder network for image input.

    Wraps a torchvision ResNet-18 (created with ``pretrained=True``) whose
    classification head is replaced by an empty ``nn.Sequential`` (identity),
    so the backbone outputs its 512-d pooled features. An optional linear
    layer then projects those features to ``c_dim``.

    Args:
        c_dim (int): output dimension of the latent embedding
        normalize (bool): whether the input images should be normalized
            with ``normalize_imagenet`` before entering the backbone
        use_linear (bool): whether a final linear layer should be used;
            when False, ``c_dim`` must be 512

    Raises:
        ValueError: if ``use_linear`` is False and ``c_dim`` is not 512
    '''

    def __init__(self, c_dim, normalize=True, use_linear=True):
        super().__init__()
        self.normalize = normalize
        self.use_linear = use_linear
        self.features = models.resnet18(pretrained=True)
        # An empty Sequential is an identity: strip the ImageNet classifier.
        self.features.fc = nn.Sequential()
        if not use_linear and c_dim != 512:
            raise ValueError('c_dim must be 512 if use_linear is False')
        self.fc = nn.Linear(512, c_dim) if use_linear else nn.Sequential()

    def forward(self, x):
        ''' Encodes a batch of images into latent codes.

        Args:
            x (tensor): input images

        Returns:
            tensor: the backbone features passed through ``self.fc``
        '''
        images = normalize_imagenet(x) if self.normalize else x
        return self.fc(self.features(images))
def normalize_imagenet(x):
    ''' Normalize input images according to ImageNet standards.

    Subtracts the ImageNet per-channel mean and divides by the per-channel
    standard deviation for channels 0, 1 and 2 (indexed along dim 1). Any
    further channels are copied unchanged. The input tensor is not modified.

    Args:
        x (tensor): input images, channels on dim 1

    Returns:
        tensor: a normalized copy of ``x``
    '''
    channel_stats = (
        (0.485, 0.229),
        (0.456, 0.224),
        (0.406, 0.225),
    )
    result = x.clone()
    for idx, (mean, std) in enumerate(channel_stats):
        result[:, idx] = (result[:, idx] - mean) / std
    return result
class DecoderCBatchNorm(nn.Module):
    ''' Decoder with conditional batch normalization (CBN) class.

    Lifts query points to ``hidden_size`` channels with a 1x1 convolution,
    optionally adds a projection of the latent code ``z``, runs five
    conditional ResNet blocks conditioned on ``c``, and maps the result to
    one logit per point.

    Args:
        dim (int): input dimension
        z_dim (int): dimension of latent code z; 0 disables the z branch
        c_dim (int): dimension of latent conditioned code c
        hidden_size (int): hidden size of Decoder network
        leaky (bool): whether to use leaky ReLUs (negative slope 0.2)
        legacy (bool): whether to use the legacy structure
    '''

    def __init__(self, dim=3, z_dim=128, c_dim=128,
                 hidden_size=256, leaky=False, legacy=False):
        import functools

        super().__init__()
        self.z_dim = z_dim
        if z_dim != 0:
            self.fc_z = nn.Linear(z_dim, hidden_size)
        self.fc_p = nn.Conv1d(dim, hidden_size, 1)
        self.block0 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
        self.block1 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
        self.block2 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
        self.block3 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
        self.block4 = CResnetBlockConv1d(c_dim, hidden_size, legacy=legacy)
        if not legacy:
            self.bn = CBatchNorm1d(c_dim, hidden_size)
        else:
            # NOTE(review): CBatchNorm1d_legacy is not defined in this file;
            # confirm it is provided elsewhere, otherwise legacy=True raises
            # NameError here.
            self.bn = CBatchNorm1d_legacy(c_dim, hidden_size)
        self.fc_out = nn.Conv1d(hidden_size, 1, 1)
        if not leaky:
            self.actvn = F.relu
        else:
            # A functools.partial (unlike a lambda) can be pickled, so
            # torch.save(model) also works for leaky decoders.
            self.actvn = functools.partial(F.leaky_relu, negative_slope=0.2)

    def forward(self, p, z, c, **kwargs):
        ''' Computes occupancy logits for the query points.

        Args:
            p (tensor): query points of shape (batch, T, dim); transposed to
                channels-first for the 1x1 convolutions
            z (tensor): latent code z; ignored when ``z_dim == 0``
            c (tensor): latent conditioned code c, passed to every CBN layer
            **kwargs: accepted for interface compatibility; unused

        Returns:
            tensor: logits with the single output channel squeezed away
        '''
        net = self.fc_p(p.transpose(1, 2))
        if self.z_dim != 0:
            # Broadcast the per-sample z projection over all T points.
            net = net + self.fc_z(z).unsqueeze(2)
        for block in (self.block0, self.block1, self.block2,
                      self.block3, self.block4):
            net = block(net, c)
        out = self.fc_out(self.actvn(self.bn(net, c)))
        return out.squeeze(1)
def get_prior_z(device, z_dim=0):
    ''' Returns prior distribution for latent code z.

    The prior is a standard normal with zero mean and unit scale in every one
    of the ``z_dim`` dimensions. With the default ``z_dim=0``, the location and
    scale tensors are empty, so samples have a trailing dimension of size 0.

    Args:
        device (device): pytorch device on which the loc/scale tensors live
        z_dim (int): dimension of the latent code z (default 0)

    Returns:
        torch.distributions.Normal: the prior distribution p0(z)
    '''
    return dist.Normal(
        torch.zeros(z_dim, device=device),
        torch.ones(z_dim, device=device),
    )
class CBatchNorm1d(nn.Module):
    ''' Conditional batch normalization layer class.

    Normalizes ``x`` with a non-affine normalization layer, then applies a
    scale ``gamma`` and shift ``beta`` predicted from the conditioning code
    ``c`` by 1x1 convolutions.

    Args:
        c_dim (int): dimension of latent conditioned code c
        f_dim (int): feature dimension
        norm_method (str): normalization method, one of ``'batch_norm'``,
            ``'instance_norm'`` or ``'group_norm'``
        num_groups (int): number of groups for ``'group_norm'``; ``f_dim``
            must be divisible by it. Ignored for the other methods.

    Raises:
        ValueError: if ``norm_method`` is not one of the supported methods
    '''

    def __init__(self, c_dim, f_dim, norm_method='batch_norm', num_groups=32):
        super().__init__()
        self.c_dim = c_dim
        self.f_dim = f_dim
        self.norm_method = norm_method
        # Submodules
        self.conv_gamma = nn.Conv1d(c_dim, f_dim, 1)
        self.conv_beta = nn.Conv1d(c_dim, f_dim, 1)
        if norm_method == 'batch_norm':
            self.bn = nn.BatchNorm1d(f_dim, affine=False)
        elif norm_method == 'instance_norm':
            self.bn = nn.InstanceNorm1d(f_dim, affine=False)
        elif norm_method == 'group_norm':
            # torch.nn has no GroupNorm1d; nn.GroupNorm accepts (N, C, *)
            # input, which covers the 1-D (N, C, T) case.
            self.bn = nn.GroupNorm(num_groups, f_dim, affine=False)
        else:
            raise ValueError('Invalid normalization method!')
        self.reset_parameters()

    def reset_parameters(self):
        ''' Initializes the conditioning convolutions.

        Zero weights with gamma bias 1 and beta bias 0 make ``gamma == 1``
        and ``beta == 0`` for every ``c``, so a freshly reset layer is plain
        normalization.
        '''
        nn.init.zeros_(self.conv_gamma.weight)
        nn.init.zeros_(self.conv_beta.weight)
        nn.init.ones_(self.conv_gamma.bias)
        nn.init.zeros_(self.conv_beta.bias)

    def forward(self, x, c):
        ''' Applies conditional normalization.

        Args:
            x (tensor): features with ``f_dim`` channels on dim 1 (presumably
                batch_size x f_dim x T)
            c (tensor): conditioning code of size batch_size x c_dim, or
                batch_size x c_dim x T

        Returns:
            tensor: ``gamma(c) * norm(x) + beta(c)``
        '''
        assert(x.size(0) == c.size(0))
        assert(c.size(1) == self.c_dim)
        # Give a 2-D code a trailing length-1 axis so the 1x1 convs apply and
        # gamma/beta broadcast over every position of x.
        if c.dim() == 2:
            c = c.unsqueeze(2)
        gamma = self.conv_gamma(c)
        beta = self.conv_beta(c)
        return gamma * self.bn(x) + beta
class CResnetBlockConv1d(nn.Module):
    ''' Conditional batch normalization-based Resnet block class.

    Residual branch: CBN -> ReLU -> 1x1 conv -> CBN -> ReLU -> 1x1 conv, both
    CBN layers conditioned on ``c``. The skip path is the identity when the
    input and output sizes match, and a bias-free 1x1 conv otherwise.
    ``fc_1`` starts with zero weights, so at initialization the residual
    branch contributes only ``fc_1``'s bias.

    Args:
        c_dim (int): dimension of latent conditioned code c
        size_in (int): input dimension
        size_h (int): hidden dimension (defaults to ``size_in``)
        size_out (int): output dimension (defaults to ``size_in``)
        norm_method (str): normalization method
        legacy (bool): whether to use legacy blocks
    '''

    def __init__(self, c_dim, size_in, size_h=None, size_out=None,
                 norm_method='batch_norm', legacy=False):
        super().__init__()
        size_h = size_in if size_h is None else size_h
        size_out = size_in if size_out is None else size_out
        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out

        # Submodules are created in a fixed order (bn_0, bn_1, fc_0, fc_1,
        # shortcut), which fixes the RNG draws of the default initializers.
        norm_layer = CBatchNorm1d_legacy if legacy else CBatchNorm1d
        self.bn_0 = norm_layer(c_dim, size_in, norm_method=norm_method)
        self.bn_1 = norm_layer(c_dim, size_h, norm_method=norm_method)
        self.fc_0 = nn.Conv1d(size_in, size_h, 1)
        self.fc_1 = nn.Conv1d(size_h, size_out, 1)
        self.actvn = nn.ReLU()
        self.shortcut = (
            None if size_in == size_out
            else nn.Conv1d(size_in, size_out, 1, bias=False)
        )
        nn.init.zeros_(self.fc_1.weight)

    def forward(self, x, c):
        ''' Applies the block.

        Args:
            x (tensor): input features with ``size_in`` channels on dim 1
            c (tensor): latent conditioned code c, passed to both CBN layers

        Returns:
            tensor: skip path of ``x`` plus the residual branch
        '''
        hidden = self.fc_0(self.actvn(self.bn_0(x, c)))
        residual = self.fc_1(self.actvn(self.bn_1(hidden, c)))
        skip = x if self.shortcut is None else self.shortcut(x)
        return skip + residual
class OccupancyNetwork(nn.Module):
    ''' Occupancy Network class.

    Builds a ResNet-18 image encoder producing a 256-d conditioning code
    ``c``, a CBN decoder (``z_dim=0``) that maps query points to occupancy
    logits, and the prior ``p0_z`` over the latent code ``z``. The encoder
    and decoder are moved to ``device`` at construction.

    Args:
        device (device): torch device

    .. note::
        If you use this code, please cite the original paper in addition to Kaolin.

        .. code-block::

            @inproceedings{Occupancy Networks,
                title = {Occupancy Networks: Learning 3D Reconstruction in Function Space},
                author = {Mescheder, Lars and Oechsle, Michael and Niemeyer, Michael and Nowozin, Sebastian and Geiger, Andreas},
                booktitle = {Proceedings IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)},
                year = {2019}
            }
    '''

    def __init__(self, device):
        super().__init__()
        self.device = device
        self.decoder = DecoderCBatchNorm(dim=3, z_dim=0, c_dim=256,
                                         hidden_size=256).to(self.device)
        self.encoder = Resnet18(256, normalize=True,
                                use_linear=True).to(self.device)
        self.p0_z = get_prior_z(self.device)

    def forward(self, p, inputs, sample=True, **kwargs):
        ''' Performs a forward pass through the network.

        Args:
            p (tensor): sampled points
            inputs (tensor): conditioning input
            sample (bool): whether to sample for z (otherwise the prior mean
                is used)

        Returns:
            torch.distributions.Bernoulli: occupancy distribution for ``p``
        '''
        batch_size = p.size(0)
        c = self.encode_inputs(inputs)
        z = self.get_z_from_prior((batch_size,), sample=sample)
        p_r = self.decode(p, z, c, **kwargs)
        return p_r

    def compute_elbo(self, p, occ, inputs, **kwargs):
        ''' Computes the expectation lower bound.

        Args:
            p (tensor): sampled points
            occ (tensor): occupancy values for p
            inputs (tensor): conditioning input

        Returns:
            tuple: ``(elbo, rec_error, kl)``, each summed over the last dim
        '''
        c = self.encode_inputs(inputs)
        q_z = self.infer_z(p, occ, c, **kwargs)
        z = q_z.rsample()
        p_r = self.decode(p, z, c, **kwargs)
        rec_error = -p_r.log_prob(occ).sum(dim=-1)
        kl = dist.kl_divergence(q_z, self.p0_z).sum(dim=-1)
        elbo = -rec_error - kl
        return elbo, rec_error, kl

    def encode_inputs(self, inputs):
        ''' Encodes the conditioning input into the latent code c.

        Args:
            inputs (tensor): conditioning input passed to ``self.encoder``

        Returns:
            tensor: latent conditioned code c
        '''
        return self.encoder(inputs)

    def decode(self, p, z, c, **kwargs):
        ''' Returns occupancy probabilities for the sampled points.

        Args:
            p (tensor): points
            z (tensor): latent code z
            c (tensor): latent conditioned code c

        Returns:
            torch.distributions.Bernoulli: built from the decoder's logits
        '''
        logits = self.decoder(p, z, c, **kwargs)
        p_r = dist.Bernoulli(logits=logits)
        return p_r

    def infer_z(self, p, occ, c, **kwargs):
        ''' Infers z.

        Only the batch size of ``p`` is used: the returned posterior is a
        Normal over an empty latent of shape (batch_size, 0).

        Args:
            p (tensor): points tensor
            occ (tensor): occupancy values for p
            c (tensor): latent conditioned code c

        Returns:
            torch.distributions.Normal: the (empty) posterior q(z)
        '''
        batch_size = p.size(0)
        mean_z = torch.empty(batch_size, 0, device=self.device)
        logstd_z = torch.empty(batch_size, 0, device=self.device)
        q_z = dist.Normal(mean_z, torch.exp(logstd_z))
        return q_z

    def get_z_from_prior(self, size=torch.Size([]), sample=True):
        ''' Returns z from prior distribution.

        Args:
            size (Size): size of z
            sample (bool): whether to sample; otherwise the prior mean is
                expanded to ``size``

        Returns:
            tensor: latent code z on ``self.device``
        '''
        if sample:
            z = self.p0_z.sample(size).to(self.device)
        else:
            z = self.p0_z.mean.to(self.device)
            z = z.expand(*size, *z.size())
        return z