import os
import re
from collections import OrderedDict
from typing import Callable, Optional

import numpy as np
import torch
import torch.utils.data as data
from torchvision import transforms
class ScanNet(data.Dataset):
    r"""ScanNet dataset http://www.scan-net.org/

    Loads the color frames (and, optionally, the depth frames) of a single
    ScanNet scene together with their semantic label images. Files are
    expected under ``<root_dir>/<scene_id>/{color,depth,label}/`` (see
    :func:`get_filenames_scannet`).

    Args:
        root_dir (str): Path to the base directory of the dataset.
        scene_id (str): ScanNet scene id, i.e. the name of the scene
            directory inside ``root_dir``.
        mode (str, optional): Stored on the instance as ``self.mode``; not
            otherwise used by this class (default: ``'inference'``).
        transform (callable, optional): If given, applied to the loaded
            image (or RGB-D tensor) before it is returned (default: None).
        label_transform (callable, optional): If given, applied to the
            loaded label before it is returned (default: None).
        loader (callable, optional): Function that loads a sample. It is
            called as ``loader(data_path, label_path, color_mean, color_std,
            seg_classes)``, or, when ``load_depth`` is set, as
            ``loader(data_path, depth_path, label_path, color_mean,
            color_std, seg_classes)``, and must return ``(image, label)``.
            Defaults to :meth:`scannet_loader` / :meth:`scannet_loader_depth`.
        color_mean (list, optional): R, G, B channel-wise mean used for
            normalization (default: ``[0., 0., 0.]``).
        color_std (list, optional): R, G, B channel-wise standard deviation
            used for normalization (default: ``[1., 1., 1.]``).
        load_depth (bool, optional): Whether to also load depth images
            (architectures that use depth information need it)
            (default: False).
        seg_classes (str, optional): Palette of classes the network should
            learn: ``'nyu40'`` or ``'scannet20'``, case-insensitive
            (default: ``'nyu40'``).

    Raises:
        ValueError: If ``seg_classes`` is not a supported palette.
        RuntimeError: If ``root_dir`` or ``<root_dir>/<scene_id>/color`` is
            not a directory.
    """

    # (class name, RGB color) for every 'nyu40' label; the position in this
    # tuple is the label id.
    _NYU40_PALETTE = (
        ('unlabeled', (0, 0, 0)),
        ('wall', (174, 199, 232)),
        ('floor', (152, 223, 138)),
        ('cabinet', (31, 119, 180)),
        ('bed', (255, 187, 120)),
        ('chair', (188, 189, 34)),
        ('sofa', (140, 86, 75)),
        ('table', (255, 152, 150)),
        ('door', (214, 39, 40)),
        ('window', (197, 176, 213)),
        ('bookshelf', (148, 103, 189)),
        ('picture', (196, 156, 148)),
        ('counter', (23, 190, 207)),
        ('blinds', (178, 76, 76)),
        ('desk', (247, 182, 210)),
        ('shelves', (66, 188, 102)),
        ('curtain', (219, 219, 141)),
        ('dresser', (140, 57, 197)),
        ('pillow', (202, 185, 52)),
        ('mirror', (51, 176, 203)),
        ('floormat', (200, 54, 131)),
        ('clothes', (92, 193, 61)),
        ('ceiling', (78, 71, 183)),
        ('books', (172, 114, 82)),
        ('refrigerator', (255, 127, 14)),
        ('television', (91, 163, 138)),
        ('paper', (153, 98, 156)),
        ('towel', (140, 153, 101)),
        ('showercurtain', (158, 218, 229)),
        ('box', (100, 125, 154)),
        ('whiteboard', (178, 127, 135)),
        ('person', (120, 185, 128)),
        ('nightstand', (146, 111, 194)),
        ('toilet', (44, 160, 44)),
        ('sink', (112, 128, 144)),
        ('lamp', (96, 207, 209)),
        ('bathtub', (227, 119, 194)),
        ('bag', (213, 92, 176)),
        ('otherstructure', (94, 106, 211)),
        ('otherfurniture', (82, 84, 163)),
        ('otherprop', (100, 85, 144)),
    )

    # Classes kept in the 'scannet20' palette, in 'scannet20' label-id order.
    # Each keeps the color it has in the 'nyu40' palette.
    _SCANNET20_CLASSES = (
        'unlabeled', 'wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa',
        'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'desk',
        'curtain', 'refrigerator', 'showercurtain', 'toilet', 'sink',
        'bathtub', 'otherfurniture',
    )

    # (nyu40 id, scannet20 id) pairs. The nyu40 ids 13, 15, 17-23, 25-27,
    # 29-32, 35, 37, 38 and 40 have no 'scannet20' counterpart and map to 0
    # ('unlabeled'); the remaining ids are packed into the contiguous range
    # [0, 20]. Ids 1-12 are identical in both palettes and are not listed.
    # Detailed class lists:
    # http://kaldir.vc.in.tum.de/scannet_benchmark/labelids_all.txt (nyu40)
    # http://kaldir.vc.in.tum.de/scannet_benchmark/labelids.txt (scannet20)
    _NYU40_TO_SCANNET20 = (
        (0, 0), (13, 0), (15, 0), (17, 0), (18, 0), (19, 0),
        (20, 0), (21, 0), (22, 0), (23, 0), (25, 0), (26, 0),
        (27, 0), (29, 0), (30, 0), (31, 0), (32, 0), (35, 0),
        (37, 0), (38, 0), (40, 0), (14, 13), (16, 14), (24, 15),
        (28, 16), (33, 17), (34, 18), (36, 19), (39, 20),
    )

    def __init__(self, root_dir: str, scene_id: str,
                 mode: Optional[str] = 'inference',
                 transform: Optional[Callable] = None,
                 label_transform: Optional[Callable] = None,
                 loader: Optional[Callable] = None,
                 color_mean: Optional[list] = None,
                 color_std: Optional[list] = None,
                 load_depth: Optional[bool] = False,
                 seg_classes: Optional[str] = 'nyu40'):
        self.root_dir = root_dir
        self.scene_id = scene_id
        self.mode = mode
        self.transform = transform
        self.label_transform = label_transform
        # Fresh lists per instance: a shared mutable default would leak an
        # in-place edit of one dataset's statistics into every other one.
        self.color_mean = [0., 0., 0.] if color_mean is None else color_mean
        self.color_std = [1., 1., 1.] if color_std is None else color_std
        self.load_depth = bool(load_depth)
        self.seg_classes = seg_classes
        # color_encoding has to be initialized AFTER seg_classes; it also
        # rejects unsupported palettes before any file is touched.
        self.color_encoding = self.get_color_encoding()
        if loader is None:
            loader = (self.scannet_loader_depth if self.load_depth
                      else self.scannet_loader)
        self.loader = loader
        # Color / depth / label filepaths of the scene, index-aligned.
        self.data, self.depth, self.labels = get_filenames_scannet(
            self.root_dir, self.scene_id)
        self.length = len(self.data)

    def __getitem__(self, index):
        """Return the sample at ``index``.

        Args:
            index (int): Index of the item in the dataset.

        Returns:
            tuple: ``(image, label, data_path, label_path)``, or
            ``(rgbd, label, data_path, depth_path, label_path)`` when depth is
            loaded. ``image``/``rgbd`` and ``label`` are whatever the loader
            returns, after ``transform``/``label_transform`` (if set).
        """
        data_path, label_path = self.data[index], self.labels[index]
        if self.load_depth:
            depth_path = self.depth[index]
            image, label = self.loader(data_path, depth_path, label_path,
                                       self.color_mean, self.color_std,
                                       self.seg_classes)
            paths = (data_path, depth_path, label_path)
        else:
            image, label = self.loader(data_path, label_path,
                                       self.color_mean, self.color_std,
                                       self.seg_classes)
            paths = (data_path, label_path)
        if self.transform is not None:
            image = self.transform(image)
        if self.label_transform is not None:
            label = self.label_transform(label)
        return (image, label) + paths

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.length

    def get_color_encoding(self):
        """Return the class-name -> RGB color palette for ``seg_classes``.

        Returns:
            OrderedDict: A new mapping whose iteration order is label-id
            order (entry ``i`` describes label id ``i``).

        Raises:
            ValueError: If ``seg_classes`` is not 'nyu40' or 'scannet20'.
        """
        palette = (self.seg_classes.lower()
                   if isinstance(self.seg_classes, str) else None)
        if palette == 'nyu40':
            return OrderedDict(self._NYU40_PALETTE)
        if palette == 'scannet20':
            colors = dict(self._NYU40_PALETTE)
            return OrderedDict(
                (name, colors[name]) for name in self._SCANNET20_CLASSES)
        raise ValueError(
            "Unsupported seg_classes {0!r}; expected 'nyu40' or "
            "'scannet20'.".format(self.seg_classes))

    def get_files(self, folder: str, name_filter: Optional[str] = None,
                  extension_filter: Optional[str] = None):
        """Return the paths of the files under ``folder`` matching filters.

        ``folder`` is walked recursively; inside each directory the file names
        are visited in sorted order.

        Args:
            folder (str): The path to a folder.
            name_filter (str, optional): If given, only files whose name
                contains this substring are returned (default: None, files
                are not filtered).
            extension_filter (str, optional): If given, only files whose name
                ends with this string are returned (default: None, files are
                not filtered).

        Returns:
            list of str: Full paths of the matching files.

        Raises:
            RuntimeError: If ``folder`` is not a directory.
        """
        if not os.path.isdir(folder):
            raise RuntimeError("\"{0}\" is not a folder.".format(folder))
        matches = []
        for dirpath, _, filenames in os.walk(folder):
            for filename in sorted(filenames):
                if name_filter is not None and name_filter not in filename:
                    continue
                if (extension_filter is not None
                        and not filename.endswith(extension_filter)):
                    continue
                matches.append(os.path.join(dirpath, filename))
        return matches

    @staticmethod
    def _imread(path: str) -> np.ndarray:
        """Read the image at ``path`` into a numpy array with ``imageio``."""
        # imageio is not imported at the top of this file, so import it where
        # it is actually needed.
        import imageio
        return np.array(imageio.imread(path))

    def _load_rgb(self, data_path: str, color_mean: list,
                  color_std: list) -> torch.Tensor:
        """Load a color image as a normalized 3 x H x W float tensor.

        Pixel values are scaled from [0, 255] to [0.0, 1.0] and then
        normalized with ``color_mean`` / ``color_std``. Grayscale images are
        replicated to 3 channels and any channel beyond the third (e.g. alpha)
        is dropped, so the 3-element mean/std always match.
        """
        rgb = self._imread(data_path)
        if rgb.ndim == 2:
            rgb = rgb[:, :, np.newaxis]
        if rgb.shape[2] == 1:
            rgb = np.repeat(rgb, 3, axis=2)
        rgb = rgb[:, :, :3]
        # H x W x C -> C x H x W.
        rgb = np.moveaxis(rgb, 2, 0)
        normalize = transforms.Normalize(mean=color_mean, std=color_std)
        return normalize(torch.Tensor(rgb.astype(np.float32) / 255.0))

    def _load_label(self, label_path: str, seg_classes: str) -> np.ndarray:
        """Load a label image (stored in 'nyu40' ids) as a uint8 array,
        remapped to 'scannet20' ids when requested.

        Raises:
            ValueError: If ``seg_classes`` is not 'nyu40' or 'scannet20'.
        """
        palette = seg_classes.lower() if isinstance(seg_classes, str) else None
        if palette not in ('nyu40', 'scannet20'):
            raise ValueError(
                "Unsupported seg_classes {0!r}; expected 'nyu40' or "
                "'scannet20'.".format(seg_classes))
        label = self._imread(label_path).astype(np.uint8)
        if palette == 'scannet20':
            # The label files are in 'nyu40' format.
            label = self.nyu40_to_scannet20(label)
        return label

    def scannet_loader(self, data_path: str, label_path: str,
                       color_mean: Optional[list] = None,
                       color_std: Optional[list] = None,
                       seg_classes: str = 'nyu40'):
        """Load a color image and its label image given their paths.

        Args:
            data_path (str): The filepath to the color image.
            label_path (str): The filepath to the ground-truth label image.
            color_mean (list, optional): R, G, B channel-wise mean
                (default: ``[0., 0., 0.]``).
            color_std (list, optional): R, G, B channel-wise stddev
                (default: ``[1., 1., 1.]``).
            seg_classes (str): Palette of classes to load labels for
                ('nyu40' or 'scannet20').

        Returns:
            tuple: ``(image, label)`` where ``image`` is a normalized
            3 x H x W float ``torch.Tensor`` and ``label`` is the label image
            as a ``uint8`` numpy array.

        Raises:
            ValueError: If ``seg_classes`` is not a supported palette.
        """
        color_mean = [0., 0., 0.] if color_mean is None else color_mean
        color_std = [1., 1., 1.] if color_std is None else color_std
        image = self._load_rgb(data_path, color_mean, color_std)
        label = self._load_label(label_path, seg_classes)
        return image, label

    def scannet_loader_depth(self, data_path: str, depth_path: str,
                             label_path: str,
                             color_mean: Optional[list] = None,
                             color_std: Optional[list] = None,
                             seg_classes: Optional[str] = 'nyu40'):
        """Load a color image, its depth image and its label image.

        Args:
            data_path (str): The filepath to the color image.
            depth_path (str): The filepath to the depth png.
            label_path (str): The filepath to the ground-truth label image.
            color_mean (list, optional): R, G, B channel-wise mean
                (default: ``[0., 0., 0.]``).
            color_std (list, optional): R, G, B channel-wise stddev
                (default: ``[1., 1., 1.]``).
            seg_classes (str): Palette of classes to load labels for
                ('nyu40' or 'scannet20').

        Returns:
            tuple: ``(rgbd, label)`` where ``rgbd`` is a 4 x H x W float
            ``torch.Tensor`` (normalized RGB followed by depth / 1000) and
            ``label`` is the label image as a ``uint8`` numpy array.

        Raises:
            ValueError: If ``seg_classes`` is not a supported palette.
        """
        color_mean = [0., 0., 0.] if color_mean is None else color_mean
        color_std = [1., 1., 1.] if color_std is None else color_std
        rgb = self._load_rgb(data_path, color_mean, color_std)
        # Raw depth values are divided by 1000 (presumably millimetres ->
        # metres; confirm against the data preparation scripts).
        depth = torch.Tensor(
            self._imread(depth_path).astype(np.float32) / 1000.0)
        # 1 x H x W so it can be stacked under the RGB channels; assumes the
        # depth image has the same H x W as the color image.
        depth = torch.unsqueeze(depth, 0)
        rgbd = torch.cat((rgb, depth), 0)
        label = self._load_label(label_path, seg_classes)
        return rgbd, label

    def nyu40_to_scannet20(self, label: np.ndarray) -> np.ndarray:
        """Remap a label image from 'nyu40' ids to 'scannet20' ids.

        nyu40 classes absent from 'scannet20' become 0 ('unlabeled'); ids not
        listed in the remapping table (1-12 and anything outside 0-40) are
        kept unchanged.

        Args:
            label (np.ndarray): Integer label image in 'nyu40' ids. It is not
                modified.

        Returns:
            np.ndarray: A new array of the same shape and dtype.
        """
        label = np.asarray(label)
        remapped = label.copy()
        # Masks are taken from the untouched input, so a freshly written
        # target id can never be picked up again by a later (src, tar) pair.
        for src, tar in self._NYU40_TO_SCANNET20:
            remapped[label == src] = tar
        return remapped


def _natural_sort_key(name: str):
    """Sort key ordering embedded digit runs numerically ('2' < '10')."""
    # re.split with a capturing group puts the digit runs at odd indices, so
    # keys of different names always compare str-to-str and int-to-int.
    tokens = [int(tok) if i % 2 else tok
              for i, tok in enumerate(re.split(r'(\d+)', name))]
    # The name itself breaks ties such as '01' vs '1'.
    return tokens, name


def get_filenames_scannet(base_dir: str, scene_id: str):
    """Return the color images, depth images and label images of a scene.

    A color file ``<base_dir>/<scene_id>/color/<stem>.<ext>`` is kept only if
    both ``<base_dir>/<scene_id>/depth/<stem>.png`` and
    ``<base_dir>/<scene_id>/label/<stem>.png`` exist. Samples are returned in
    natural order of the color file names (``2.jpg`` before ``10.jpg``).

    Args:
        base_dir (str): Path to the base directory containing ScanNet
            data, in the directory structure specified in
            https://github.com/angeladai/3DMV/tree/master/prepare_data
        scene_id (str): ScanNet scene id.

    Returns:
        tuple: ``(color_images, depth_images, labels)``, three index-aligned
        lists of file paths.

    Raises:
        RuntimeError: If ``base_dir`` or the scene's ``color`` directory is
            not a directory.
    """
    if not os.path.isdir(base_dir):
        raise RuntimeError('\'{0}\' is not a directory.'.format(base_dir))
    scene_dir = os.path.join(base_dir, scene_id)
    color_dir = os.path.join(scene_dir, 'color')
    if not os.path.isdir(color_dir):
        raise RuntimeError('\'{0}\' is not a directory.'.format(color_dir))
    color_images = []
    depth_images = []
    labels = []
    for file in sorted(os.listdir(color_dir), key=_natural_sort_key):
        color_file = os.path.join(color_dir, file)
        if not os.path.isfile(color_file):
            continue
        stem, _ = os.path.splitext(file)
        depth_file = os.path.join(scene_dir, 'depth', stem + '.png')
        label_file = os.path.join(scene_dir, 'label', stem + '.png')
        # Keep the sample only if its depth and label files both exist.
        if os.path.exists(depth_file) and os.path.exists(label_file):
            color_images.append(color_file)
            depth_images.append(depth_file)
            labels.append(label_file)
    return color_images, depth_images, labels
def create_label_image(output, color_palette):
    """Create an RGB label image from a per-pixel class-index map.

    Args:
        output (np.array): Height x Width array; each pixel holds the integer
            class label for that pixel.
        color_palette (OrderedDict or sequence): Either a mapping whose
            values, in iteration order, are the (R, G, B) colors (uint8) of
            class indices 0, 1, 2, ... (e.g. ``ScanNet.color_encoding``), or
            a plain sequence of such colors.

    Returns:
        np.ndarray: Height x Width x 3 ``uint8`` image. Pixels whose class
        index has no palette entry stay black.
    """
    # Iterating a mapping yields its keys (the class names), not the colors,
    # so use its values when a mapping is given.
    colors = (color_palette.values() if hasattr(color_palette, 'values')
              else color_palette)
    output = np.asarray(output)
    label_image = np.zeros((output.shape[0], output.shape[1], 3),
                           dtype=np.uint8)
    for idx, color in enumerate(colors):
        label_image[output == idx] = color
    return label_image