# Source code for models.VoxelSuperresSimple

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class EncoderDecoder(nn.Module):
    """A simple encoder-decoder style voxel superresolution network.

    Layout:

    * Encoder: two stride-2 ``Conv3d`` layers (1 -> 16 -> 32 channels), each
      followed by ``BatchNorm3d`` and ReLU.
    * Decoder: a stride-2 ``ConvTranspose3d`` (32 -> 16) with ``BatchNorm3d``
      and ReLU, then a stride-2 ``ConvTranspose3d`` (16 -> 8) with ReLU.
    * Superres layer: a final stride-2 ``ConvTranspose3d`` (8 -> 1) with no
      activation. Raw values are returned; callers presumably apply a
      sigmoid/threshold themselves (NOTE(review): confirm against callers).

    The input must be single-channel (``conv1`` has one input channel) and 5-D,
    ``(N, 1, D, H, W)``, because ``BatchNorm3d`` requires a 5-D tensor.

    By the standard PyTorch conv/transposed-conv size formulas, each spatial
    extent ``S`` evolves as:

    * conv1: ``(S - 1) // 2 + 1``
    * conv2: ``(S - 1) // 2 + 1``
    * deconv3: ``2*S - 1``
    * deconv4: ``2*S + 1``
    * deconv5: ``2*S + 1``

    For example, an 8^3 grid comes out as 15^3.
    """

    def __init__(self):
        super().__init__()
        # Construction order is kept fixed: it determines parameter
        # registration order (state_dict layout) and the sequence in which
        # the default initializers consume the global RNG.

        # Encoder: halve the resolution twice while widening channels.
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3,
                               stride=2, padding=1)
        self.bn1 = nn.BatchNorm3d(num_features=16)
        self.conv2 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=3,
                               stride=2, padding=1)
        self.bn2 = nn.BatchNorm3d(num_features=32)

        # Decoder: upsample while narrowing channels.
        self.deconv3 = nn.ConvTranspose3d(in_channels=32, out_channels=16,
                                          kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm3d(num_features=16)
        self.deconv4 = nn.ConvTranspose3d(in_channels=16, out_channels=8,
                                          kernel_size=3, stride=2, padding=0)

        # Superres layer: final upsample down to a single output channel.
        self.deconv5 = nn.ConvTranspose3d(in_channels=8, out_channels=1,
                                          kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        """Map a single-channel voxel grid to a higher-resolution grid.

        Args:
            x: input tensor, expected shape ``(N, 1, D, H, W)``.

        Returns:
            Output of ``deconv5`` (one channel, no final activation).
        """
        # conv1/conv2 (encoder) and deconv3 (first decoder stage) all
        # follow the same layer -> batch norm -> ReLU pattern.
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.deconv3, self.bn3),
        )
        h = x
        for layer, norm in stages:
            h = F.relu(norm(layer(h)))

        # The second decoder stage has no batch norm.
        h = F.relu(self.deconv4(h))

        # Superres layer: returned without an activation.
        return self.deconv5(h)