# Source code for kaolin.rep.SDF

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Occupancy Networks
#
# Copyright 2019 Lars Mescheder, Michael Oechsle, Michael Niemeyer, Andreas Geiger, Sebastian Nowozin
#
# Permission is hereby granted, free of charge, to any person obtaining a copy 
# of this software and associated documentation files (the "Software"), to deal 
# in the Software without restriction, including without limitation the rights 
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 
# copies of the Software, and to permit persons to whom the Software is 
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all 
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 
# SOFTWARE.

import numpy as np
import torch

import kaolin as kal
from kaolin.triangle_hash import TriangleHash as _TriangleHash
import kaolin.cuda.mesh_intersection as mint


class MeshIntersectionFunction(torch.autograd.Function):
    r"""Autograd wrapper around the CUDA mesh-intersection kernel.

    The forward pass asks ``mint.forward_cuda`` to fill one value per query
    point. That result is not differentiated: the backward pass propagates
    no gradient to any of the inputs.
    """

    @staticmethod
    def forward(ctx, points: torch.Tensor, verts_1: torch.Tensor,
                verts_2: torch.Tensor, verts_3: torch.Tensor):
        r"""Run the intersection kernel for a batch of query points.

        Args:
            points (torch.Tensor): 3-dimensional tensor whose first two
                sizes are read as ``(batchsize, n)``.
            verts_1 (torch.Tensor): first corner of every triangle.
            verts_2 (torch.Tensor): second corner of every triangle.
            verts_3 (torch.Tensor): third corner of every triangle.

        Returns:
            torch.Tensor: ``(batchsize, n)`` tensor written by the kernel
            (presumably a per-point intersection count -- confirm against
            the CUDA source).
        """
        batchsize, n, _ = points.size()
        points = points.contiguous()
        verts_1 = verts_1.contiguous()
        verts_2 = verts_2.contiguous()
        verts_3 = verts_3.contiguous()

        # Allocate the output on the same device as the inputs rather than
        # on whichever CUDA device happens to be current.
        ints = torch.zeros(batchsize, n, device=points.device)

        mint.forward_cuda(points, verts_1, verts_2, verts_3, ints)

        return ints

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        r"""Return no gradient for any of the four forward inputs.

        ``forward`` produces a single output, so autograd passes exactly one
        incoming gradient, and expects one entry per forward input in return.
        """
        return None, None, None, None


class MeshIntersection(torch.nn.Module):
    r"""``torch.nn.Module`` front-end for ``MeshIntersectionFunction``."""

    def forward(self, points: torch.Tensor, verts_1: torch.Tensor,
                verts_2: torch.Tensor, verts_3: torch.Tensor):
        r"""Forward the query points and triangle corners unchanged.

        Args:
            points (torch.Tensor): query points.
            verts_1 (torch.Tensor): first corner of every triangle.
            verts_2 (torch.Tensor): second corner of every triangle.
            verts_3 (torch.Tensor): third corner of every triangle.

        Returns:
            Whatever ``MeshIntersectionFunction.apply`` returns for these
            arguments.
        """
        triangle_corners = (verts_1, verts_2, verts_3)
        return MeshIntersectionFunction.apply(points, *triangle_corners)


def check_sign_fast(mesh, points):
    r"""Test points against a mesh using the ``MeshIntersection`` module.

    Args:
        mesh: object exposing ``vertices`` and ``faces`` tensors; the three
            columns of ``faces`` index into ``vertices`` along dim 0.
        points (torch.Tensor): query points, viewed as ``(1, -1, 3)``.

    Returns:
        torch.Tensor: boolean tensor, true where the intersector's output
        is strictly positive.
    """
    intersector = MeshIntersection()

    # One batch of shape (1, -1, 3) per triangle corner, in face-column order.
    corners = [
        torch.index_select(mesh.vertices, 0, mesh.faces[:, column])
        .view(1, -1, 3)
        for column in range(3)
    ]

    hits = intersector(points.view(1, -1, 3), *corners)
    return hits > 0


def check_sign(mesh, points, hash_resolution=512):
    r"""Checks if a set of points is contained within a mesh.

    Args:
        mesh (kal.rep.Mesh): mesh to check against
        points (torch.Tensor): points to check; must be on the same device
            as ``mesh``
        hash_resolution (int): resolution used to check the points sign;
            only used when the mesh is not on a CUDA device

    Returns:
        bool value for every point indicating if point is inside object.
        For a CUDA mesh this is the result of ``check_sign_fast``; otherwise
        it is the result of ``_MeshIntersector.query`` evaluated on a NumPy
        copy of ``points``.

    Raises:
        AssertionError: if ``mesh`` and ``points`` are on different devices.
    """
    assert mesh.device == points.device
    if mesh.device.type == 'cuda':
        return check_sign_fast(mesh, points)

    # CPU fallback: hash-accelerated intersector working on NumPy data.
    intersector = _MeshIntersector(mesh, hash_resolution)
    return intersector.query(points.data.cpu().numpy())
def _length(points):
    r"""Return the Euclidean norm of ``points`` taken over ``dim=1``.

    Args:
        points (torch.Tensor): tensor with at least two dimensions
            (presumably shaped (N, 3) -- confirm with callers).

    Returns:
        torch.Tensor: ``sqrt(sum(points ** 2, dim=1))``.
    """
    return torch.sqrt((points ** 2).sum(dim=1))


def sphere(r=.5):
    r"""Build a signed distance function for a sphere centred at the origin.

    Args:
        r (float): radius of the sphere.

    Returns:
        callable: ``eval_sdf(points)`` returning the norm of every row minus
        ``r`` (negative inside, positive outside).
    """
    def eval_sdf(points):
        return _length(points) - r
    return eval_sdf


def box(h=.2, w=.4, l=.5):
    r"""Build a signed distance function for an origin-centred, axis-aligned box.

    Args:
        h (float): half extent subtracted from ``|points[:, 0]|``.
        w (float): half extent subtracted from ``|points[:, 1]|``.
        l (float): half extent subtracted from ``|points[:, 2]|``.

    Returns:
        callable: ``eval_sdf(points)`` returning one signed distance per row.
    """
    def eval_sdf(points):
        # torch.abs allocates a new tensor, so the in-place edits below never
        # touch the caller's ``points``.
        d = torch.abs(points)
        d[:, 0] -= h
        d[:, 1] -= w
        d[:, 2] -= l

        # zeros_like keeps the dtype and device of ``d``, so the comparison
        # never mixes dtypes or devices.
        outside_len = _length(torch.max(d, torch.zeros_like(d)))

        # Largest per-axis offset, clamped to be non-positive: this is the
        # (negative) distance to the nearest face for points inside the box.
        inside_res = torch.max(d[:, 0], torch.max(d[:, 1], d[:, 2]))
        inside_res = torch.min(inside_res, torch.zeros_like(inside_res))

        return outside_len + inside_res
    return eval_sdf


class _MeshIntersector:
    r"""Class to determine if a point in space lies within or outside a mesh.
    """

    def __init__(self, mesh, resolution=512):
        r"""Rescale the mesh triangles and build the 2D triangle intersector.

        Args:
            mesh: object exposing ``vertices`` and ``faces`` tensors.
            resolution (int): grid resolution the triangles are scaled to
                and that is passed on to ``_TriangleIntersector2d``.
        """
        vertices = mesh.vertices.data.cpu().numpy()
        faces = mesh.faces.data.cpu().numpy()
        triangles = vertices[faces].astype(np.float64)
        n_tri = triangles.shape[0]

        self.resolution = resolution
        corners = triangles.reshape(3 * n_tri, 3)
        self.bbox_min = corners.min(axis=0)
        self.bbox_max = corners.max(axis=0)
        # Translate and scale it to [0.5, self.resolution - 0.5]^3
        self.scale = (resolution - 1) / (self.bbox_max - self.bbox_min)
        self.translate = 0.5 - self.scale * self.bbox_min

        self._triangles = triangles = self.rescale(triangles)

        # Only the x/y coordinates take part in the 2D containment test.
        triangles2d = triangles[:, :, :2]
        self._tri_intersector2d = _TriangleIntersector2d(
            triangles2d, resolution)

    def query(self, points):
        r"""Return a boolean array telling which points lie inside the mesh.

        Args:
            points (np.ndarray): points in mesh coordinates, one per row.

        Returns:
            np.ndarray: boolean array with ``len(points)`` entries.
        """
        # Rescale points
        points = self.rescale(points)

        # placeholder result with no hits we'll fill in later
        # (builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24)
        contains = np.zeros(len(points), dtype=bool)

        # cull points outside of the axis aligned bounding box
        # this avoids running ray tests unless points are close
        inside_aabb = np.all(
            (0 <= points) & (points <= self.resolution), axis=1)
        if not inside_aabb.any():
            return contains

        # Only consider points inside bounding box
        mask = inside_aabb
        points = points[mask]

        # Compute intersection depth and check order
        points_indices, tri_indices = self._tri_intersector2d.query(
            points[:, :2])

        triangles_intersect = self._triangles[tri_indices]
        points_intersect = points[points_indices]

        depth_intersect, abs_n_2 = self.compute_intersection_depth(
            points_intersect, triangles_intersect)

        # Count number of intersections in both directions.
        # NaN depths (triangles with a zero z-normal) fail both comparisons
        # and are therefore not counted on either side.
        scaled_z = points_intersect[:, 2] * abs_n_2
        smaller_depth = depth_intersect >= scaled_z
        bigger_depth = depth_intersect < scaled_z
        points_indices_0 = points_indices[smaller_depth]
        points_indices_1 = points_indices[bigger_depth]

        nintersect0 = np.bincount(points_indices_0, minlength=points.shape[0])
        nintersect1 = np.bincount(points_indices_1, minlength=points.shape[0])

        # Check if point contained in mesh: the crossing count must be odd
        # in both directions.
        contains1 = (np.mod(nintersect0, 2) == 1)
        contains2 = (np.mod(nintersect1, 2) == 1)
        contains[mask] = (contains1 & contains2)
        return contains

    def compute_intersection_depth(self, points, triangles):
        r"""Compute where the z-line through each point meets its triangle.

        Args:
            points (np.ndarray): one point per row, paired with ``triangles``.
            triangles (np.ndarray): array indexed as
                ``[pair, corner, coordinate]``.

        Returns:
            tuple: ``(depth_intersect, abs_n_2)``. ``depth_intersect`` is the
            intersection depth multiplied by ``abs_n_2`` (NaN where the
            z-component of the normal is zero), and ``abs_n_2`` is the
            absolute z-component of the unnormalised triangle normal.
        """
        t1 = triangles[:, 0, :]
        t2 = triangles[:, 1, :]
        t3 = triangles[:, 2, :]

        v1 = t3 - t1
        v2 = t2 - t1

        normals = np.cross(v1, v2)
        alpha = np.sum(normals[:, :2] * (t1[:, :2] - points[:, :2]), axis=1)

        n_2 = normals[:, 2]
        t1_2 = t1[:, 2]
        s_n_2 = np.sign(n_2)
        abs_n_2 = np.abs(n_2)

        # Triangles with a zero z-normal keep the NaN placeholder.
        mask = (abs_n_2 != 0)

        depth_intersect = np.full(points.shape[0], np.nan)
        depth_intersect[mask] = \
            t1_2[mask] * abs_n_2[mask] + alpha[mask] * s_n_2[mask]

        # TODO: add a unit test checking that the computed depth lies on the
        # triangle plane.
        return depth_intersect, abs_n_2

    def rescale(self, array):
        r"""Apply the affine map ``self.scale * array + self.translate``."""
        array = self.scale * array + self.translate
        return array


class _TriangleIntersector2d:
    r"""Finds which 2D triangles contain which 2D points, using a triangle
    hash to produce candidate pairs.
    """

    def __init__(self, triangles, resolution=128):
        r"""Store the triangles and build the hash over them.

        Args:
            triangles (np.ndarray): array indexed as
                ``[triangle, corner, coordinate]`` with 2D coordinates.
            resolution (int): resolution handed to ``_TriangleHash``.
        """
        self.triangles = triangles
        self.tri_hash = _TriangleHash(triangles, resolution)

    def query(self, points):
        r"""Return index pairs ``(point_indices, tri_indices)`` for every
        point that lies strictly inside a triangle.

        Args:
            points (np.ndarray): 2D points, one per row.

        Returns:
            tuple: two ``int64`` arrays of equal length.
        """
        point_indices, tri_indices = self.tri_hash.query(points)
        point_indices = np.array(point_indices, dtype=np.int64)
        tri_indices = np.array(tri_indices, dtype=np.int64)

        # Verify the candidate pairs from the hash with the exact test.
        points = points[point_indices]
        triangles = self.triangles[tri_indices]
        mask = self.check_triangles(points, triangles)
        point_indices = point_indices[mask]
        tri_indices = tri_indices[mask]
        return point_indices, tri_indices

    def check_triangles(self, points, triangles):
        r"""Test each point against the triangle in the same row.

        Args:
            points (np.ndarray): 2D points, one per row.
            triangles (np.ndarray): array indexed as
                ``[pair, corner, coordinate]``.

        Returns:
            np.ndarray: boolean array, true where the point is strictly
            inside its triangle. Points on an edge, and triangles with a
            zero determinant, give false.
        """
        # builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24
        contains = np.zeros(points.shape[0], dtype=bool)

        # Columns of A are the edges from the third corner to the first two.
        A = triangles[:, :2] - triangles[:, 2:]
        A = A.transpose([0, 2, 1])
        y = points - triangles[:, 2]

        detA = A[:, 0, 0] * A[:, 1, 1] - A[:, 0, 1] * A[:, 1, 0]

        mask = (np.abs(detA) != 0.)
        A = A[mask]
        y = y[mask]
        detA = detA[mask]

        s_detA = np.sign(detA)
        abs_detA = np.abs(detA)

        # u and v are the edge coordinates of y scaled by |det A|, which
        # avoids dividing by the determinant.
        u = (A[:, 1, 1] * y[:, 0] - A[:, 0, 1] * y[:, 1]) * s_detA
        v = (-A[:, 1, 0] * y[:, 0] + A[:, 0, 0] * y[:, 1]) * s_detA

        sum_uv = u + v
        contains[mask] = (
            (0 < u) & (u < abs_detA) & (0 < v) & (v < abs_detA)
            & (0 < sum_uv) & (sum_uv < abs_detA)
        )
        return contains