# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class SuperresNetwork(nn.Module):
    """Upsample a 6-channel image stack by the power-of-two factor ``high // low``.

    Super-resolution network from https://arxiv.org/abs/1802.09987. The six
    input channels are presumably the six orthographic views of the ODM
    representation (TODO confirm against callers).

    Architecture:

    - a 3x3 conv + batch-norm stem (6 -> 128 channels);
    - 16 chained residual blocks (conv-BN-ReLU-conv-BN plus identity skip);
    - a conv + batch-norm layer followed by a global skip back to the stem;
    - ``log2(high // low)`` sub-pixel stages (3x3 conv + ``PixelShuffle(2)``);
    - a 1x1 conv back to 6 channels and a sigmoid.

    Args:
        high (int): Numerator of the upsampling ratio.
        low (int): Denominator of the upsampling ratio. Only ``high // low``
            is used, and it must be a power of two that is at least 2.

    Input shape: B x 6 x H x W
    Output shape: B x 6 x (high//low * H) x (high//low * W), values in [0, 1]

    Raises:
        ValueError: If ``high // low`` is not a power of two >= 2.

    .. note::

        If you use this code, please cite the original paper in addition to Kaolin.

        .. code-block::

            @incollection{ODM,
              title = {Multi-View Silhouette and Depth Decomposition for High Resolution 3D Object Representation},
              author = {Smith, Edward and Fujimoto, Scott and Meger, David},
              booktitle = {Advances in Neural Information Processing Systems 31},
              editor = {S. Bengio and H. Wallach and H. Larochelle and K. Grauman and N. Cesa-Bianchi and R. Garnett},
              pages = {6479--6489},
              year = {2018},
              publisher = {Curran Associates, Inc.},
              url = {http://papers.nips.cc/paper/7883-multi-view-silhouette-and-depth-decomposition-for-high-resolution-3d-object-representation.pdf}
            }
    """

    def __init__(self, high, low):
        super(SuperresNetwork, self).__init__()
        ratio = high // low
        # Exact integer log2. Each upsampling stage is a PixelShuffle(2), so
        # only powers of two can be reached. Previously, other ratios were
        # silently rounded down, and ratio == 1 crashed later in forward().
        num_upsamples = int(ratio).bit_length() - 1
        if ratio < 2 or ratio != 2 ** num_upsamples:
            raise ValueError(
                'high // low must be a power of two >= 2, got {}'.format(ratio))
        self.ratio = ratio

        self.layer1 = nn.Sequential(
            nn.Conv2d(6, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128))
        # Residual block i is: inner_convs_1[i] -> inner_bns_1[i] -> ReLU
        # -> inner_convs_2[i] -> inner_bns_2[i], plus an identity skip.
        self.inner_convs_1 = nn.ModuleList(
            [nn.Conv2d(128, 128, kernel_size=3, padding=1) for _ in range(16)])
        self.inner_bns_1 = nn.ModuleList(
            [nn.BatchNorm2d(128) for _ in range(16)])
        self.inner_convs_2 = nn.ModuleList(
            [nn.Conv2d(128, 128, kernel_size=3, padding=1) for _ in range(16)])
        self.inner_bns_2 = nn.ModuleList(
            [nn.BatchNorm2d(128) for _ in range(16)])
        self.layer2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
        )

        # (conv, PixelShuffle(2)) pairs. PixelShuffle(2) maps 128 channels to
        # 32 and doubles H and W, so only the first conv sees 128 input channels.
        sub_list = []
        for stage in range(num_upsamples):
            in_channels = 128 if stage == 0 else 32
            sub_list.append(nn.Conv2d(in_channels, 128, kernel_size=3, padding=1))
            sub_list.append(nn.PixelShuffle(2))
        self.sub_list = nn.ModuleList(sub_list)

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 6, kernel_size=1, padding=0),
        )

    def forward(self, x):
        """Super-resolve ``x``.

        Args:
            x (torch.Tensor): Input of shape B x 6 x H x W.

        Returns:
            torch.Tensor: B x 6 x (ratio * H) x (ratio * W) tensor passed
            through a sigmoid.
        """
        x = self.layer1(x)
        # ``feat`` is the running state of the residual stack. Each block
        # consumes the previous block's output. The old code fed the stem
        # output to every block, which turned the stack into 16 parallel
        # branches.
        feat = x
        for conv1, bn1, conv2, bn2 in zip(self.inner_convs_1, self.inner_bns_1,
                                          self.inner_convs_2, self.inner_bns_2):
            residual = bn2(conv2(F.relu(bn1(conv1(feat)))))
            feat = feat + residual
        # Global skip connection around the whole residual stack.
        x = x + self.layer2(feat)
        # Each (conv, PixelShuffle(2)) pair doubles the spatial size.
        for module in self.sub_list:
            x = module(x)
        x = self.layer3(x)
        return torch.sigmoid(x)