# Source code for models.PointNet2

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# pointnet2
#
# Copyright (c) 2017, Geometric Computation Group of Stanford University
#
# The MIT License (MIT)
#
# Copyright (c) 2017 Charles R. Qi
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch
import torch.nn as nn
import torch.nn.functional as F

from .PointNet import PointNetFeatureExtractor

import kaolin.cuda as ext
import kaolin.cuda.ball_query
import kaolin.cuda.furthest_point_sampling
import kaolin.cuda.three_nn


class FurthestPointSampling(torch.autograd.Function):
    r"""Autograd wrapper around the CUDA iterative furthest point sampling kernel.

    The output is a set of indices, so no gradient is propagated back.

    .. note::

        If you use this code, please cite the original paper in addition to Kaolin.
        
        .. code-block::

            @article{qi2017pointnet2,
                title = {PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space},
                author = {Qi, Charles R. and Yi, Li and Su, Hao and Guibas, Leonidas J.},
                year = {2017},
                journal={arXiv preprint arXiv:1706.02413},
            }
    """

    @staticmethod
    def forward(ctx, xyz, num_points_out):
        """Select ``num_points_out`` points of ``xyz`` by iterative furthest point sampling.

        The sampled set is chosen so that its points have the largest minimum distance.

        Args:
            xyz (torch.Tensor): (B, N, 3) tensor where N > num_points_out
            num_points_out (int32): number of features in the sampled set

        Returns:
            (torch.Tensor): (B, num_points_out) tensor containing the set
        """
        sampler = ext.furthest_point_sampling
        return sampler.furthest_point_sampling(xyz, num_points_out)

    @staticmethod
    def backward(xyz, a=None):
        # NOTE: despite its name, ``xyz`` is the autograd context (autograd passes
        # it positionally); ``a`` is the incoming gradient. One ``None`` is
        # returned per forward input (xyz, num_points_out).
        return (None,) * 2


furthest_point_sampling = FurthestPointSampling.apply


class FPSGatherByIndex(torch.autograd.Function):
    r"""Autograd wrapper that gathers feature columns selected by furthest point sampling.

    .. note::

        If you use this code, please cite the original paper in addition to Kaolin.
        
        .. code-block::

            @article{qi2017pointnet2,
                title = {PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space},
                author = {Qi, Charles R. and Yi, Li and Su, Hao and Guibas, Leonidas J.},
                year = {2017},
                journal={arXiv preprint arXiv:1706.02413},
            }
    """
    @staticmethod
    def forward(ctx, features, idx):
        """Gather the columns of ``features`` listed in ``idx``.

        Args:
            features (torch.Tensor): (B, C, N) tensor

            idx (torch.Tensor): (B, npoint) tensor of the features to gather

        Returns:
            (torch.Tensor): (B, C, npoint) tensor
        """
        # Unpacking the trailing dims also rejects non-3-D ``features``.
        num_channels, num_points = features.size()[1:]

        # The backward kernel needs the original point count N to size its output.
        ctx.for_backwards = (idx, num_channels, num_points)

        return ext.furthest_point_sampling.gather_by_index(features, idx)

    @staticmethod
    def backward(ctx, grad_out):
        """Scatter ``grad_out`` back to a (B, C, N) gradient; ``idx`` gets none."""
        saved_idx, _, num_points = ctx.for_backwards

        contiguous_grad = grad_out.contiguous()
        grad_features = ext.furthest_point_sampling.gather_by_index_grad(
            contiguous_grad, saved_idx, num_points)
        return grad_features, None


fps_gather_by_index = FPSGatherByIndex.apply


class ThreeNN(torch.autograd.Function):
    r"""Autograd wrapper that finds, for each query point, its three nearest neighbors.

    No gradient is propagated back to either point set.

    .. note::

        If you use this code, please cite the original paper in addition to Kaolin.
        
        .. code-block::

            @article{qi2017pointnet2,
                title = {PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space},
                author = {Qi, Charles R. and Yi, Li and Su, Hao and Guibas, Leonidas J.},
                year = {2017},
                journal={arXiv preprint arXiv:1706.02413},
            }
    """
    @staticmethod
    def forward(ctx, unknown, known):
        # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
        r"""
        Find the three nearest neighbors of each point of ``unknown`` among ``known``.

        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) query points whose neighbors are searched for
        known : torch.Tensor
            (B, m, 3) reference points in which the neighbors are searched

        Returns
        -------
        dist : torch.Tensor
            (B, n, 3) l2 distance to the three nearest neighbors
        idx : torch.Tensor
            (B, n, 3) indices (into ``known``) of the 3 nearest neighbors
        """
        dist2, idx = ext.three_nn.three_nn(unknown, known)

        # The extension's first output is treated as squared distance; take the
        # square root to return l2 distances.
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        # Neighbor search is not differentiated: one ``None`` per forward input
        # (unknown, known).
        return None, None


three_nn = ThreeNN.apply


class ThreeInterpolate(torch.autograd.Function):
    r"""Autograd wrapper for weighted interpolation from three neighboring features.

    Gradients flow to ``features`` only; ``idx`` and ``weight`` receive ``None``.

    .. note::

        If you use this code, please cite the original paper in addition to Kaolin.
        
        .. code-block::

            @article{qi2017pointnet2,
                title = {PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space},
                author = {Qi, Charles R. and Yi, Li and Su, Hao and Guibas, Leonidas J.},
                year = {2017},
                journal={arXiv preprint arXiv:1706.02413},
            }
    """
    @staticmethod
    def forward(ctx, features, idx, weight):
        # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Interpolate each target point as a weighted sum of three source features.

        Parameters
        ----------
        features : torch.Tensor
            (B, c, m) feature descriptors to interpolate from
        idx : torch.Tensor
            (B, n, 3) indices into ``features`` of each target's three neighbors
        weight : torch.Tensor
            (B, n, 3) interpolation weights

        Returns
        -------
        torch.Tensor
            (B, c, n) tensor of the interpolated features
        """
        # These lookups double as rank checks: features must be 3-D and idx
        # must have a second dimension.
        _, _, num_source = features.shape
        _ = idx.size(1)

        # The backward kernel needs m to size the (B, c, m) gradient.
        ctx.three_interpolate_for_backward = (idx, weight, num_source)

        return ext.three_nn.three_interpolate(features, idx, weight)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, None, None]
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, c, n) gradient of the interpolated output

        Returns
        -------
        grad_features : torch.Tensor
            (B, c, m) gradient with respect to ``features``
        None
            no gradient for ``idx``
        None
            no gradient for ``weight``
        """
        saved_idx, saved_weight, num_source = ctx.three_interpolate_for_backward

        grad_features = ext.three_nn.three_interpolate_grad(
            grad_out.contiguous(), saved_idx, saved_weight, num_source)

        return grad_features, None, None


three_interpolate = ThreeInterpolate.apply


class GroupGatherByIndex(torch.autograd.Function):
    r"""Autograd wrapper that gathers features into (npoint, nsample) groups.

    .. note::

        If you use this code, please cite the original paper in addition to Kaolin.
        
        .. code-block::

            @article{qi2017pointnet2,
                title = {PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space},
                author = {Qi, Charles R. and Yi, Li and Su, Hao and Guibas, Leonidas J.},
                year = {2017},
                journal={arXiv preprint arXiv:1706.02413},
            }
    """
    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Gather ``features`` columns according to a grouping index.

        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor of features to group
        idx : torch.Tensor
            (B, npoint, nsample) indices of the features to put in each group

        Returns
        -------
        torch.Tensor
            (B, C, npoint, nsample) tensor
        """
        # Unpacking enforces a 3-D idx and 3-D features before the kernel call.
        _batch, _npoint, _nsample = idx.shape
        _, _channels, num_points = features.shape

        # N is needed by the backward kernel to size the (B, C, N) gradient.
        ctx.for_backwards = (idx, num_points)

        return ext.ball_query.gather_by_index(features, idx)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, None]
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, C, npoint, nsample) gradient of the grouped output

        Returns
        -------
        torch.Tensor
            (B, C, N) gradient with respect to ``features``
        None
            no gradient for ``idx``
        """
        saved_idx, num_points = ctx.for_backwards

        grad_features = ext.ball_query.gather_by_index_grad(
            grad_out.contiguous(), saved_idx, num_points)
        return grad_features, None


group_gather_by_index = GroupGatherByIndex.apply


class BallQuery(torch.autograd.Function):
    r"""Autograd wrapper around the CUDA ball-query kernels.

    The output is an index tensor, so no gradient is propagated back.

    .. note::

        If you use this code, please cite the original paper in addition to Kaolin.
        
        .. code-block::

            @article{qi2017pointnet2,
                title = {PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space},
                author = {Qi, Charles R. and Yi, Li and Su, Hao and Guibas, Leonidas J.},
                year = {2017},
                journal={arXiv preprint arXiv:1706.02413},
            }
    """

    @staticmethod
    def forward(ctx, radius, nsample, xyz, new_xyz, use_random=False):
        # type: (Any, float, int, torch.Tensor, torch.Tensor, bool) -> torch.Tensor
        r"""
        For every center in ``new_xyz``, collect up to ``nsample`` indices of
        points of ``xyz`` that lie inside a ball of the given ``radius``.

        Parameters
        ----------
        radius : float
            radius of the balls
        nsample : int
            maximum number of features in the balls
        xyz : torch.Tensor
            (B, N, 3) xyz coordinates of the features
        new_xyz : torch.Tensor
            (B, npoint, 3) centers of the ball query
        use_random : bool
            if True, use the extension's ``ball_random_query`` (seeded from
            torch's global RNG) instead of ``ball_query``

        Returns
        -------
        torch.Tensor
            (B, npoint, nsample) tensor with the indices of the features that form the query balls
        """
        if use_random:
            # Draw the kernel's seed from torch's global RNG so it follows the
            # torch RNG state.
            seed = torch.randint(int(1e9), ()).item()
            return ext.ball_query.ball_random_query(
                seed, new_xyz, xyz, radius, nsample)

        return ext.ball_query.ball_query(new_xyz, xyz, radius, nsample)

    @staticmethod
    def backward(ctx, a=None):
        # autograd requires exactly one gradient per forward() input:
        # radius, nsample, xyz, new_xyz, use_random. (Previously only four
        # were returned, which autograd rejects.)
        return None, None, None, None, None


ball_query = BallQuery.apply

# TODO: improvement: experiment with random sampling instead of current approach.


def separate_xyz_and_features(points):
    """Split a point cloud into its coordinates and its per-point feature vectors.

    The first 3 channels of the last dimension are the xyz coordinates; the
    remaining channels (if any) are features, returned channel-first.

    .. note::

        If you use this code, please cite the original paper in addition to Kaolin.
        
        .. code-block::

            @article{qi2017pointnet2,
                title = {PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space},
                author = {Qi, Charles R. and Yi, Li and Su, Hao and Guibas, Leonidas J.},
                year = {2017},
                journal={arXiv preprint arXiv:1706.02413},
            }

    Args:
        points (torch.Tensor): shape = (batch_size, num_points, 3 + num_features)
            The point cloud to separate.

    Returns:
        xyz (torch.Tensor): shape = (batch_size, num_points, 3)
            The position vectors of the points (contiguous).
        features (torch.Tensor|None): shape = (batch_size, num_features, num_points)
            The feature vectors of the points (contiguous).
            If there are no feature vectors, features will be None.
    """
    shape = points.shape
    assert len(shape) == 3 and shape[2] >= 3, (
        'Expected shape of points to be (batch_size, num_points, 3 + num_features), got {}'
        .format(shape))

    xyz = points[:, :, :3].contiguous()

    # A cloud with exactly 3 channels carries coordinates only.
    if shape[2] <= 3:
        return xyz, None

    # Move the feature channels in front of the points: (B, N, F) -> (B, F, N).
    features = points[:, :, 3:].transpose(1, 2).contiguous()
    return xyz, features


class PointNet2GroupingLayer(nn.Module):
    """Group points (and their features) for a PointNet++ set-abstraction layer.

    If ``radius`` is None, all points are grouped together ("group all") and
    ``new_xyz`` is ignored. Otherwise a ball query of ``radius`` around each
    centroid in ``new_xyz`` collects up to ``num_samples`` neighbors, whose
    coordinates are expressed relative to that centroid.

    .. note::

        If you use this code, please cite the original paper in addition to Kaolin.
        
        .. code-block::

            @article{qi2017pointnet2,
                title = {PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space},
                author = {Qi, Charles R. and Yi, Li and Su, Hao and Guibas, Leonidas J.},
                year = {2017},
                journal={arXiv preprint arXiv:1706.02413},
            }

    Args:
        radius (float|None): ball-query radius, or None to group all points.
        num_samples (int|None): maximum number of neighbors per ball (unused
            when ``radius`` is None).
        use_xyz_feature (bool): prepend the (relative) coordinates to the features.
        use_random_ball_query (bool): forwarded to ``ball_query`` as ``use_random``.
    """

    def __init__(self, radius, num_samples, use_xyz_feature=True, use_random_ball_query=False):
        super(PointNet2GroupingLayer, self).__init__()
        self.radius = radius
        self.num_samples = num_samples
        self.use_xyz_feature = use_xyz_feature
        self.use_random_ball_query = use_random_ball_query

    def forward(self, xyz, new_xyz, features=None):
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor|None
            centroids (B, npoint, 3); ignored when ``radius`` is None
        features : torch.Tensor|None
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, npoint, 3 + C, nsample) tensor when ``radius`` is set, or
            (B, 3 + C, N) when ``radius`` is None. The ``3 +`` part is present
            only if ``use_xyz_feature`` is True, and ``C`` only if ``features``
            is given.

        Raises
        ------
        AssertionError
            if ``features`` is None and ``use_xyz_feature`` is False (there
            would be nothing to group).
        """
        if features is None:
            # Same check for both branches: without features, the coordinates
            # are the only available input.
            assert self.use_xyz_feature, "Must have at least one feature or set use_xyz_feature = True"

        if self.radius is None:
            # Group all: a single group containing every point.
            grouped_xyz = xyz.transpose(1, 2)  # (B, 3, N)
            if features is None:
                return grouped_xyz
            if self.use_xyz_feature:
                return torch.cat([grouped_xyz, features], dim=1)  # (B, 3 + C, N)
            return features

        idx = ball_query(self.radius, self.num_samples, xyz,
                         new_xyz, self.use_random_ball_query)  # (B, npoint, nsample)
        xyz_trans = xyz.transpose(1, 2).contiguous()
        grouped_xyz = group_gather_by_index(
            xyz_trans, idx)  # (B, 3, npoint, nsample)
        # Express each neighbor's coordinates relative to its centroid.
        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)

        if features is None:
            new_features = grouped_xyz
        else:
            grouped_features = group_gather_by_index(
                features, idx)  # (B, C, npoint, nsample)
            if self.use_xyz_feature:
                new_features = torch.cat(
                    [grouped_xyz, grouped_features], dim=1
                )  # (B, 3 + C, npoint, nsample)
            else:
                new_features = grouped_features

        # (B, 3 + C, npoint, nsample) -> (B, npoint, 3 + C, nsample)
        return new_features.transpose(1, 2).contiguous()


class PointNet2SetAbstraction(nn.Module):
    """A single set-abstraction layer for the PointNet++ architecture.

    Supports multi-scale grouping (MSG).

    .. note::

        If you use this code, please cite the original paper in addition to Kaolin.

        .. code-block::

            @article{qi2017pointnet2,
                title = {PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space},
                author = {Qi, Charles R. and Yi, Li and Su, Hao and Guibas, Leonidas J.},
                year = {2017},
                journal={arXiv preprint arXiv:1706.02413},
            }

    Args:
        num_points_out (int|None): The number of output points.
            If None, group all points together.
        pointnet_in_features (int): The number of features to input into pointnet.
            Note: if use_xyz_feature is true, this value will be increased by 3.
        pointnet_layer_dims_list (List[List[int]]): The pointnet MLP dimensions list for each scale.
            Note: the first (input) dimension SHOULD NOT be included in each list,
            while the last (output) dimension SHOULD be included in each list.
            If num_points_out is None, each list becomes one pointnet applied
            to the group of all points.
        radii_list (List[float]|None): The grouping radius for each scale.
            If num_points_out is None, this value is ignored.
        num_samples_list (List[int]|None): The number of samples in each ball query for each scale.
            If num_points_out is None, this value is ignored.
        batchnorm (bool): Whether or not to use batch normalization.
        use_xyz_feature (bool): Whether or not to use the coordinates of the
            points as feature.
        use_random_ball_query (bool): Whether or not to use random sampling when
            there are too many points per ball.
    """

    def __init__(self,
                 num_points_out,
                 pointnet_in_features,
                 pointnet_layer_dims_list,
                 radii_list=None,
                 num_samples_list=None,
                 batchnorm=True,
                 use_xyz_feature=True,
                 use_random_ball_query=False):
        super(PointNet2SetAbstraction, self).__init__()

        # TODO: Testing: test if the model works with each of the parameters

        if num_points_out is None:
            # Group-all mode: radii / sample counts are ignored. Build one
            # (radius=None) scale per pointnet dims list so that forward()
            # produces exactly get_num_features_out() channels. (Previously
            # only the first dims list was used while get_num_features_out
            # summed all of them.)
            num_scales = len(pointnet_layer_dims_list)
            radii_list = [None] * num_scales
            num_samples_list = [None] * num_scales
        else:
            assert isinstance(radii_list, list) and isinstance(
                num_samples_list, list), 'radii_list and num_samples_list must be lists'

            assert (len(radii_list) == len(num_samples_list) == len(pointnet_layer_dims_list)), (
                'Dimension of radii_list ({}), num_samples_list ({}), pointnet_layer_dims_list ({}) must match'
                .format(len(radii_list), len(num_samples_list), len(pointnet_layer_dims_list)))

        self.num_points_out = num_points_out
        self.pointnet_layer_dims_list = pointnet_layer_dims_list
        # ModuleList registers the sub-modules' parameters; ``self.layers``
        # is a plain list of (grouper, pointnet, num_samples) used in forward().
        self.sub_modules = nn.ModuleList()
        self.layers = []
        self.pointnet_in_channels = pointnet_in_features + \
            (3 if use_xyz_feature else 0)

        for radius, num_samples, pointnet_layer_dims in zip(
                radii_list, num_samples_list, pointnet_layer_dims_list):
            assert isinstance(pointnet_layer_dims, list), 'Each pointnet_layer_dims must be a list, got {} instead'.format(
                pointnet_layer_dims)
            assert len(
                pointnet_layer_dims) > 0, 'Each pointnet_layer_dims must have at least one element'

            grouper = PointNet2GroupingLayer(
                radius, num_samples, use_xyz_feature=use_xyz_feature,
                use_random_ball_query=use_random_ball_query)

            # TODO: refactor: add dropout parameters
            pointnet = PointNetFeatureExtractor(
                in_channels=self.pointnet_in_channels,
                feat_size=pointnet_layer_dims[-1],
                layer_dims=pointnet_layer_dims[:-1],
                global_feat=True,
                batchnorm=batchnorm,
                transposed_input=True
            )

            # Register sub-modules
            self.sub_modules.append(grouper)
            self.sub_modules.append(pointnet)

            self.layers.append((grouper, pointnet, num_samples))

    def forward(self, xyz, features=None):
        """
        Args:
            xyz (torch.Tensor): shape = (batch_size, num_points_in, 3)
                The 3D coordinates of each point.
            features (torch.Tensor|None): shape = (batch_size, num_features, num_points_in)
                The features of each point.

        Returns:
            new_xyz (torch.Tensor|None): shape = (batch_size, num_points_out, 3)
                The new coordinates of the grouped points.
                If self.num_points_out is None, new_xyz will be None.
            new_features (torch.Tensor): shape = (batch_size, out_num_features, num_points_out)
                The features of each output point.
                If self.num_points_out is None, new_features will have shape:
                (batch_size, num_features_out)
        """
        batch_size = xyz.shape[0]

        new_xyz = None
        if self.num_points_out is not None:
            # TODO: implement: this is flipped here for some reason
            new_xyz_idx = furthest_point_sampling(xyz, self.num_points_out)
            new_xyz = fps_gather_by_index(
                xyz.transpose(1, 2).contiguous(), new_xyz_idx)
            new_xyz = new_xyz.transpose(1, 2).contiguous()

        new_features_list = []
        for grouper, pointnet, num_samples in self.layers:
            new_features = grouper(xyz, new_xyz, features)
            # shape = (batch_size, num_points_out, self.pointnet_in_channels, num_samples)
            # if num_points_out is None:
            # shape = (batch_size, self.pointnet_in_channels, num_points_in)

            if self.num_points_out is not None:
                # Fold the centroids into the batch dimension so pointnet
                # processes each group independently.
                new_features = new_features.view(
                    -1, self.pointnet_in_channels, num_samples)

            new_features = pointnet(new_features)
            # shape = (batch_size * num_points_out, feat_size)
            # if num_points_out is None:
            # shape = (batch_size, feat_size)

            # TODO: Optimization: avoid this packing and unpacking step by refactoring and generalizing pointnet
            if self.num_points_out is not None:
                new_features = new_features.view(
                    batch_size, self.num_points_out, -1).transpose(1, 2)
                # shape = (batch_size, feat_size, num_points_out)

            new_features_list.append(new_features)

        new_features = torch.cat(new_features_list, dim=1)
        # shape = (batch_size, num_features_out, num_points_out)
        # if num_points_out is None:
        # shape = (batch_size, num_features_out)

        return new_xyz, new_features

    def get_num_features_out(self):
        """Return the number of output feature channels (sum over all scales)."""
        return sum(dims[-1] for dims in self.pointnet_layer_dims_list)
class PointNet2FeaturePropagator(nn.Module):
    """A single feature-propagation layer for the PointNet++ architecture.

    Used for segmentation.

    .. note::

        If you use this code, please cite the original paper in addition to Kaolin.

        .. code-block::

            @article{qi2017pointnet2,
                title = {PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space},
                author = {Qi, Charles R. and Yi, Li and Su, Hao and Guibas, Leonidas J.},
                year = {2017},
                journal={arXiv preprint arXiv:1706.02413},
            }

    Args:
        num_features (int): The number of features in the current layer.
            Note: this is the number of output features of the corresponding
            set abstraction layer.
        num_features_prev (int): The number of features from the previous
            feature propagation layer (corresponding to the next layer during
            feature extraction).
            Note: this is the number of output features of the previous feature
            propagation layer (or the number of output features of the final set
            abstraction layer, if this is the very first feature propagation
            layer)
        layer_dims (List[int]): Sizes of the MLP layer.
            Note: the first (input) dimension SHOULD NOT be included in the
            list, while the last (output) dimension SHOULD be included in the
            list.
        batchnorm (bool): Whether or not to use batch normalization.
    """

    def __init__(self, num_features, num_features_prev, layer_dims, batchnorm=True):
        super(PointNet2FeaturePropagator, self).__init__()

        self.layer_dims = layer_dims

        # Shared per-point MLP implemented with 1x1 convolutions.
        unit_pointnets = []
        in_features = num_features + num_features_prev
        for out_features in layer_dims:
            unit_pointnets.append(
                nn.Conv1d(in_features, out_features, 1))

            if batchnorm:
                unit_pointnets.append(nn.BatchNorm1d(out_features))

            unit_pointnets.append(nn.ReLU())
            in_features = out_features

        self.unit_pointnet = nn.Sequential(*unit_pointnets)

    def forward(self, xyz, xyz_prev, features=None, features_prev=None):
        """
        Args:
            xyz (torch.Tensor): shape = (batch_size, num_points, 3)
                The 3D coordinates of each point at current layer,
                computed during feature extraction (i.e. set abstraction).
            xyz_prev (torch.Tensor|None): shape = (batch_size, num_points_prev, 3)
                The 3D coordinates of each point from the previous feature
                propagation layer (corresponding to the next layer during
                feature extraction).
                This value can be None (i.e. for the very first propagator layer).
            features (torch.Tensor|None): shape = (batch_size, num_features, num_points)
                The features of each point at current layer,
                computed during feature extraction (i.e. set abstraction).
            features_prev (torch.Tensor|None): shape = (batch_size, num_features_prev, num_points_prev)
                The features of each point from the previous feature
                propagation layer (corresponding to the next layer during
                feature extraction).
                If xyz_prev is None, this is a single feature vector per cloud
                of shape (batch_size, num_features_prev) (as produced by a
                group-all set abstraction) or (batch_size, num_features_prev, 1).

        Returns:
            (torch.Tensor): shape = (batch_size, num_features_out, num_points)
        """
        num_points = xyz.shape[1]

        if xyz_prev is None:
            # Very first feature propagation layer: broadcast the single
            # per-cloud feature vector to every point. (The previous code
            # added a list to ``features.shape``, which raises TypeError and
            # used the wrong tensor's shape.)
            if features_prev.dim() == 2:
                features_prev = features_prev.unsqueeze(-1)
            new_features = features_prev.expand(-1, -1, num_points)
            # shape = (batch_size, num_features_prev, num_points)
        else:
            dist, idx = three_nn(xyz, xyz_prev)
            # shape = (batch_size, num_points, 3), (batch_size, num_points, 3)

            # Inverse-distance weights normalized to sum to 1 over the three
            # neighbors; the epsilon avoids division by zero.
            inverse_dist = 1.0 / (dist + 1e-8)
            total_inverse_dist = torch.sum(inverse_dist, dim=2, keepdim=True)
            weights = inverse_dist / total_inverse_dist

            new_features = three_interpolate(features_prev, idx, weights)
            # shape = (batch_size, num_features_prev, num_points)

        if features is not None:
            new_features = torch.cat([new_features, features], dim=1)

        return self.unit_pointnet(new_features)

    def get_num_features_out(self):
        """Return the number of output feature channels (last MLP width)."""
        return self.layer_dims[-1]
class PointNet2Classifier(nn.Module):
    r"""PointNet++ classification network.

    Based on the original PointNet++ paper (multi-scale grouping): two MSG
    set-abstraction layers followed by a group-all layer and an MLP head.

    .. note::

        If you use this code, please cite the original paper in addition to Kaolin.

        .. code-block::

            @article{qi2017pointnet2,
                title = {PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space},
                author = {Qi, Charles R. and Yi, Li and Su, Hao and Guibas, Leonidas J.},
                year = {2017},
                journal={arXiv preprint arXiv:1706.02413},
            }

    Args:
        in_features (int): Number of features (not including xyz coordinates) in
            the input point cloud (default: 0).
        num_classes (int): Number of classes (for the classification task)
            (default: 2).
        batchnorm (bool): Whether or not to use batch normalization.
            (default: True)
        use_xyz_feature (bool): Whether or not to use the coordinates of the
            points as feature.
        use_random_ball_query (bool): Whether or not to use random sampling when
            there are too many points per ball.

    TODO: Documentation: add example
    """

    # TODO: Implement: ssg

    def __init__(self,
                 in_features=0,
                 num_classes=2,
                 batchnorm=True,
                 use_xyz_feature=True,
                 use_random_ball_query=False):
        super(PointNet2Classifier, self).__init__()

        self.set_abstractions = nn.ModuleList()
        self.set_abstractions.append(
            PointNet2SetAbstraction(
                num_points_out=512,
                pointnet_in_features=in_features,
                pointnet_layer_dims_list=[
                    [32, 32, 64],
                    [64, 64, 128],
                    [64, 96, 128],
                ],
                radii_list=[0.1, 0.2, 0.4],
                num_samples_list=[16, 32, 128],
                batchnorm=batchnorm,
                use_xyz_feature=use_xyz_feature,
                use_random_ball_query=use_random_ball_query
            )
        )
        self.set_abstractions.append(
            PointNet2SetAbstraction(
                num_points_out=128,
                pointnet_in_features=self.set_abstractions[-1].get_num_features_out(),
                pointnet_layer_dims_list=[
                    [64, 64, 128],
                    [128, 128, 256],
                    [128, 128, 256],
                ],
                radii_list=[0.2, 0.4, 0.8],
                num_samples_list=[32, 64, 128],
                batchnorm=batchnorm,
                use_xyz_feature=use_xyz_feature,
                use_random_ball_query=use_random_ball_query
            )
        )
        # Final group-all layer: produces one global feature vector per cloud.
        self.set_abstractions.append(
            PointNet2SetAbstraction(
                num_points_out=None,
                pointnet_in_features=self.set_abstractions[-1].get_num_features_out(),
                pointnet_layer_dims_list=[
                    [256, 512, 1024],
                ],
                batchnorm=batchnorm,
                use_xyz_feature=use_xyz_feature,
                use_random_ball_query=use_random_ball_query
            )
        )

        # Batch-norm entries are None when batchnorm=False and are filtered out.
        final_layer_modules = [
            module for module in [
                nn.Linear(
                    self.set_abstractions[-1].get_num_features_out(), 512),
                nn.BatchNorm1d(512) if batchnorm else None,
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(512, 256),
                nn.BatchNorm1d(256) if batchnorm else None,
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(256, num_classes)
            ] if module is not None
        ]
        self.final_layers = nn.Sequential(*final_layer_modules)

    def forward(self, points):
        """
        Args:
            points (torch.Tensor): shape = (batch_size, num_points, 3 + in_features)
                The points to classify.

        Returns:
            (torch.Tensor): shape = (batch_size, num_classes)
                The score of the inputs being in each class.
                Note: no softmax or logsoftmax will be applied.
        """
        xyz, features = separate_xyz_and_features(points)

        for module in self.set_abstractions:
            xyz, features = module(xyz, features)

        return self.final_layers(features)
class PointNet2Segmenter(nn.Module):
    """PointNet++ segmentation network.

    Four set-abstraction layers downsample the cloud; four feature-propagation
    layers then bring the features back to every input point, followed by a
    per-point classification head.

    .. note::

        If you use this code, please cite the original paper in addition to Kaolin.

        .. code-block::

            @article{qi2017pointnet2,
                title = {PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space},
                author = {Qi, Charles R. and Yi, Li and Su, Hao and Guibas, Leonidas J.},
                year = {2017},
                journal={arXiv preprint arXiv:1706.02413},
            }

    Args:
        in_features (int): Number of features (not including xyz coordinates) in
            the input point cloud (default: 0).
        num_classes (int): Number of classes (for the segmentation task; one
            score per point and class) (default: 2).
        batchnorm (bool): Whether or not to use batch normalization.
            (default: True)
        use_xyz_feature (bool): Whether or not to use the coordinates of the
            points as feature.
        use_random_ball_query (bool): Whether or not to use random sampling when
            there are too many points per ball.

    TODO: Documentation: add example
    """

    def __init__(self,
                 in_features=0,
                 num_classes=2,
                 batchnorm=True,
                 use_xyz_feature=True,
                 use_random_ball_query=False):
        super(PointNet2Segmenter, self).__init__()

        self.set_abstractions = nn.ModuleList()
        self.set_abstractions.append(
            PointNet2SetAbstraction(
                num_points_out=1024,
                pointnet_in_features=in_features,
                pointnet_layer_dims_list=[
                    [16, 16, 32],
                    [32, 32, 64],
                ],
                radii_list=[0.05, 0.1],
                num_samples_list=[16, 32],
                batchnorm=batchnorm,
                use_xyz_feature=use_xyz_feature,
                use_random_ball_query=use_random_ball_query
            )
        )
        self.set_abstractions.append(
            PointNet2SetAbstraction(
                num_points_out=256,
                pointnet_in_features=self.set_abstractions[-1].get_num_features_out(),
                pointnet_layer_dims_list=[
                    [64, 64, 128],
                    [64, 96, 128],
                ],
                radii_list=[0.1, 0.2],
                num_samples_list=[16, 32],
                batchnorm=batchnorm,
                use_xyz_feature=use_xyz_feature,
                use_random_ball_query=use_random_ball_query
            )
        )
        self.set_abstractions.append(
            PointNet2SetAbstraction(
                num_points_out=64,
                pointnet_in_features=self.set_abstractions[-1].get_num_features_out(),
                pointnet_layer_dims_list=[
                    [128, 196, 256],
                    [128, 196, 256],
                ],
                radii_list=[0.2, 0.4],
                num_samples_list=[16, 32],
                batchnorm=batchnorm,
                use_xyz_feature=use_xyz_feature,
                use_random_ball_query=use_random_ball_query
            )
        )
        self.set_abstractions.append(
            PointNet2SetAbstraction(
                num_points_out=16,
                pointnet_in_features=self.set_abstractions[-1].get_num_features_out(),
                pointnet_layer_dims_list=[
                    [256, 256, 512],
                    [256, 384, 512],
                ],
                radii_list=[0.4, 0.8],
                num_samples_list=[16, 32],
                batchnorm=batchnorm,
                use_xyz_feature=use_xyz_feature,
                use_random_ball_query=use_random_ball_query
            )
        )

        # Propagators run from the coarsest level back to the input points;
        # the k-th one pairs set_abstractions[-(k + 1)] (current level) with
        # the output of the previous propagator (or of the last set
        # abstraction for the first propagator).
        self.feature_propagators = nn.ModuleList()
        # TODO: implement: this is different from the original paper.
        self.feature_propagators.append(
            PointNet2FeaturePropagator(
                num_features=self.set_abstractions[-2].get_num_features_out(),
                num_features_prev=self.set_abstractions[-1].get_num_features_out(),
                layer_dims=[512, 512],
                batchnorm=batchnorm,
            )
        )
        self.feature_propagators.append(
            PointNet2FeaturePropagator(
                num_features=self.set_abstractions[-3].get_num_features_out(),
                num_features_prev=self.feature_propagators[-1].get_num_features_out(),
                layer_dims=[512, 512],
                batchnorm=batchnorm,
            )
        )
        self.feature_propagators.append(
            PointNet2FeaturePropagator(
                num_features=self.set_abstractions[-4].get_num_features_out(),
                num_features_prev=self.feature_propagators[-1].get_num_features_out(),
                layer_dims=[256, 256],
                batchnorm=batchnorm,
            )
        )
        self.feature_propagators.append(
            PointNet2FeaturePropagator(
                num_features=in_features,
                num_features_prev=self.feature_propagators[-1].get_num_features_out(),
                layer_dims=[128, 128],
                batchnorm=batchnorm,
            )
        )

        # Batch-norm entry is None when batchnorm=False and is filtered out.
        final_layer_modules = [
            module for module in [
                nn.Conv1d(
                    self.feature_propagators[-1].get_num_features_out(), 128, 1),
                nn.BatchNorm1d(128) if batchnorm else None,
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(128, num_classes, 1)
            ] if module is not None
        ]
        self.final_layers = nn.Sequential(*final_layer_modules)

    def forward(self, points):
        """
        Args:
            points (torch.Tensor): shape = (batch_size, num_points, 3 + in_features)
                The points to perform segmentation on.

        Returns:
            (torch.Tensor): shape = (batch_size, num_points, num_classes)
                The score of each point being in each class.
                Note: no softmax or logsoftmax will be applied.
        """
        xyz, features = separate_xyz_and_features(points)

        # Keep every level's coordinates/features for the propagation pass.
        xyz_list, features_list = [xyz], [features]

        for module in self.set_abstractions:
            xyz, features = module(xyz, features)
            xyz_list.append(xyz)
            features_list.append(features)

        # Walk back from the second-coarsest level to the input points,
        # overwriting each level's features with the propagated ones.
        target_index = -2
        for module in self.feature_propagators:
            features_list[target_index] = module(
                xyz_list[target_index],
                xyz_list[target_index + 1],
                features_list[target_index],
                features_list[target_index + 1])

            target_index -= 1

        return (self.final_layers(features_list[0])
                .transpose(1, 2)
                .contiguous())