https://github.com/brownvc/deep-synth
Tip revision: b800e11290b763b58e7d3b30329769a7b77cd12a authored by kwang-ether on 14 June 2019, 23:53:57 UTC
remove csv
remove csv
Tip revision: b800e11
continue_dataset.py
from torch.utils import data
from data import ObjectCategories, RenderedScene, RenderedComposite
import random
import math
import torch
import _pickle as pickle
class ShouldContinueDataset:
    """
    Dataset for training/testing the "should continue" network.

    Each item loads one scene and adds either all of its object nodes or a
    strictly shorter prefix of them to a composite. It is labelled with
    whether more objects should be added (True when objects were removed).
    """

    def __init__(self, data_root_dir, data_dir, scene_indices=(0, 4000), num_per_epoch=1,
                 complete_prob=0.5, seed=None, ablation=None):
        """
        Parameters
        ----------
        data_root_dir (String): root dir where all data lives
        data_dir (String): directory where this dataset lives; passed to
            RenderedScene together with data_root_dir (described as relative
            to data_root_dir -- TODO confirm against RenderedScene)
        scene_indices (tuple[int, int]): half-open range [start, end) of scene
            indices (in data_dir) that are considered part of this set
        num_per_epoch (int): number of random variants of each scene that will
            be used per training epoch
        complete_prob (float): probability of sampling the complete scene, as
            opposed to some incomplete variant of it
        seed (int or None): if not None, item ``index`` is sampled from a
            private RNG seeded from (seed, index), so the dataset is
            reproducible while still varying between items; if None, the
            global ``random`` module is used
        ablation: forwarded unchanged as ``ablation`` to
            ``composite.get_composite``
        """
        self.data_root_dir = data_root_dir
        self.data_dir = data_dir
        self.scene_indices = scene_indices
        self.num_per_epoch = num_per_epoch
        self.complete_prob = complete_prob
        self.seed = seed
        self.ablation = ablation

    def __len__(self):
        return (self.scene_indices[1] - self.scene_indices[0]) * self.num_per_epoch

    def _scene_index(self, index):
        """Map a flat dataset index to the scene index it samples from."""
        # Each run of num_per_epoch consecutive indices shares one scene, so the
        # index (not the start offset) is what gets divided.
        return self.scene_indices[0] + int(index) // self.num_per_epoch

    def _rng(self, index):
        """Return the random source used to sample item ``index``."""
        if self.seed is None:
            return random
        # A private generator per item. Re-seeding the global RNG with the same
        # value on every call would make every item draw identical numbers
        # (and hence identical labels). String seeds are hashed with SHA-512,
        # so this is deterministic across runs.
        return random.Random(f"{self.seed}-{index}")

    @staticmethod
    def _sample_num_objects(rng, num_nodes, complete_prob):
        """Return (is_complete, number of leading object nodes to keep)."""
        # Flip a coin for whether we're going to remove objects or treat this
        # as a complete scene.
        is_complete = rng.random() < complete_prob
        if is_complete:
            return True, num_nodes
        # Keep a random prefix strictly shorter than the full scene; max()
        # avoids randint(0, -1) raising ValueError on a scene with no objects.
        return False, rng.randint(0, max(num_nodes - 1, 0))

    def __getitem__(self, index):
        """
        Build one example.

        Returns
        -------
        inputs: whatever ``composite.get_composite`` returns for the kept objects
        output (bool): "should we continue adding objects?" -- True when
            objects were removed, False for the complete scene
        existing_categories (torch.Tensor): length ``len(scene.categories)``;
            entry k counts kept nodes whose ``node["category"]`` is k
        """
        rng = self._rng(index)
        scene = RenderedScene(self._scene_index(index), self.data_dir, self.data_root_dir)
        composite = scene.create_composite()
        existing_categories = torch.zeros(len(scene.categories))
        is_complete, num_objects = self._sample_num_objects(
            rng, len(scene.object_nodes), self.complete_prob)
        # Objects are kept as a prefix of scene.object_nodes, in stored order.
        for node in scene.object_nodes[:num_objects]:
            composite.add_node(node)
            existing_categories[node["category"]] += 1
        inputs = composite.get_composite(num_extra_channels=0, ablation=self.ablation)
        return inputs, not is_complete, existing_categories
