Source code for kaolin.datasets.shrec

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import os
import glob
from torch.utils.data import Dataset


from kaolin.rep import TriangleMesh


[docs]class SHREC16(Dataset): r"""Class to help in loading the SHREC16 dataset. SHREC16 is the dataset used for the "Large-scale 3D shape retrieval from ShapeNet Core55" contest at Eurographics 2016. More details about the challenge and the dataset are available `here <https://shapenet.cs.stanford.edu/shrec16/>`_. Args: root (str): Path to the root directory of the dataset. categories (list): List of categories to load (each class is specified as a string, and must be a valid `SHREC16` category). mode (str, choices=['train', 'test']): Whether to load the 'train' split or the 'test' split Returns: dict: Dictionary with keys: 'vertices' : vertices , 'faces' : faces """ def __init__(self, root: str, categories: list = ['alien'], mode: list = 'train'): super(SHREC16, self).__init__() if mode not in ['train', 'test']: raise ValueError('Argument \'mode\' must be one of \'train\'' 'or \'test\'. Got {0} instead.'.format(mode)) VALID_CATEGORIES = [ 'alien', 'ants', 'armadillo', 'bird1', 'bird2', 'camel', 'cat', 'centaur', 'dinosaur', 'dino_ske', 'dog1', 'dog2', 'flamingo', 'glasses', 'gorilla', 'hand', 'horse', 'lamp', 'laptop', 'man', 'myScissor', 'octopus', 'pliers', 'rabbit', 'santa', 'shark', 'snake', 'spiders', 'two_balls', 'woman' ] for category in categories: if category not in VALID_CATEGORIES: raise ValueError(f'Specified category {category} is not valid. ' 'Valid categories are {VALID_CATEGORIES}') self.mode = mode self.root = root self.categories = categories self.num_samples = 0 self.paths = [] self.categories = [] for cl in self.categories: clsdir = os.path.join(root, cl, self.mode) cur = glob.glob(clsdir + '/*.obj') self.paths = self.paths + cur self.categories += [cl] * len(cur) self.num_samples += len(cur) if len(cur) == 0: raise RuntimeWarning('No .obj files could be read ' f'for category \'{cl}\'. Skipping...') def __len__(self): """Returns the length of the dataset. """ return self.num_samples def __getitem__(self, idx): """Returns the sample at index idx. """ # Read in the list of vertices and faces # from the obj file. obj_location = self.paths[idx] mesh = TriangleMesh.from_obj(obj_location) category = self.categories[idx] # Return these tensors as a dictionary. data = dict() attributes = dict() data['vertices'] = mesh.vertices data['faces'] = mesh.faces attributes['rep'] = 'Mesh' attributes['name'] = obj_location attributes['class']: cagetory return {'attributes': attributes, 'data': data}