# Source: https://github.com/liuxinhai/Point2Sequence
# Revision 849b0ab0b5055185712c702749be53973facaaa4 authored by liuxinhai on 13 March 2019, 13:22:01 UTC, committed by liuxinhai on 13 March 2019, 13:22:01 UTC
# 1 parent 37b8b35
# Tip revision: 849b0ab0b5055185712c702749be53973facaaa4 authored by liuxinhai on 13 March 2019, 13:22:01 UTC
# grouping
# grouping
# Tip revision: 849b0ab
# File: modelnet_h5_dataset.py
'''
ModelNet dataset. Support ModelNet40, XYZ channels. Up to 2048 points.
Faster IO than ModelNetDataset in the first epoch.
'''
import os
import sys
import numpy as np
import h5py
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
import provider
# Download dataset for point cloud classification.
# Runs at import time: makes sure <ROOT_DIR>/data exists and, when the
# 'modelnet40_ply_hdf5_2048' folder is missing from it, fetches the archive
# through external shell tools (wget, unzip, mv, rm).
DATA_DIR = os.path.join(ROOT_DIR, 'data')
if not os.path.exists(DATA_DIR):
    os.mkdir(DATA_DIR)
if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
    www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
    zipfile = os.path.basename(www)
    # Ordered shell steps: download + extract in the current working
    # directory, move the extracted folder (archive name minus ".zip") into
    # DATA_DIR, then delete the archive. Exit statuses are not checked.
    for shell_command in (
        'wget %s; unzip %s' % (www, zipfile),
        'mv %s %s' % (zipfile[:-4], DATA_DIR),
        'rm %s' % (zipfile),
    ):
        os.system(shell_command)
def shuffle_data(data, labels):
    """ Reorder data and labels with one shared random permutation.
        Input:
            data: B,N,... numpy array
            labels: B,... numpy array
        Return:
            shuffled data, shuffled labels and the permutation indices that
            were applied (so callers can reorder other per-sample arrays)
    """
    # A single draw from the global NumPy RNG keeps samples and labels paired.
    order = np.random.permutation(len(labels))
    shuffled_data = data[order, ...]
    shuffled_labels = labels[order]
    return shuffled_data, shuffled_labels, order
def getDataFiles(list_filename):
    """ Read a text file that lists one entry per line.
        Input:
            list_filename: path of the list file
        Return:
            list of the file's lines with trailing whitespace (including the
            newline) removed; blank lines are kept as empty strings
    """
    # The context manager closes the handle deterministically instead of
    # leaving it to the garbage collector.
    with open(list_filename) as list_file:
        return [line.rstrip() for line in list_file]
def load_h5(h5_filename):
    """ Load the 'data' and 'label' datasets of an HDF5 file into memory.
        Input:
            h5_filename: path of the HDF5 file
        Return:
            (data, label) tuple holding the full contents of the two datasets
    """
    # Open explicitly read-only: the behaviour of a missing mode argument
    # depends on the h5py version. The with-block closes the file once both
    # datasets have been copied out with [:].
    with h5py.File(h5_filename, 'r') as f:
        data = f['data'][:]
        label = f['label'][:]
    return (data, label)
def loadDataFile(filename):
    """ Alias of load_h5: return whatever load_h5 yields for filename. """
    loaded = load_h5(filename)
    return loaded
class ModelNetH5Dataset(object):
    ''' Batch iterator over the HDF5 files named in a list file.

    Usage pattern: call has_next_batch() until it returns False, calling
    next_batch() after every True; call reset() to start another pass.
    Only one file's arrays are held in memory at a time.
    '''

    def __init__(self, list_filename, batch_size = 32, npoints = 1024, shuffle=True):
        ''' Input:
                list_filename: text file with one HDF5 path per line
                batch_size: maximum number of samples per batch
                npoints: number of leading points kept from every sample
                shuffle: if True, randomize the file order and the sample
                    order inside every file
        '''
        self.list_filename = list_filename
        self.batch_size = batch_size
        self.npoints = npoints
        self.shuffle = shuffle
        self.h5_files = getDataFiles(self.list_filename)
        self.reset()

    def reset(self):
        ''' reset order of h5 files '''
        self.file_idxs = np.arange(0, len(self.h5_files))
        if self.shuffle:
            np.random.shuffle(self.file_idxs)
        # Nothing is loaded until has_next_batch() is called.
        self.current_data = None
        self.current_label = None
        self.current_file_idx = 0
        self.batch_idx = 0

    def _augment_batch_data(self, batch_data):
        ''' Run the batch through the provider augmentation helpers.

        Going by the helper names: rotation and rotation perturbation on the
        whole batch, then random scale, shift and jitter on channels 0:3
        only, and finally a shuffle of the points.
        '''
        rotated_data = provider.rotate_point_cloud(batch_data)
        rotated_data = provider.rotate_perturbation_point_cloud(rotated_data)
        jittered_data = provider.random_scale_point_cloud(rotated_data[:,:,0:3])
        jittered_data = provider.shift_point_cloud(jittered_data)
        jittered_data = provider.jitter_point_cloud(jittered_data)
        # Write the modified first three channels back; any further channels
        # keep the values produced by the rotation steps.
        rotated_data[:,:,0:3] = jittered_data
        return provider.shuffle_points(rotated_data)

    def _get_data_filename(self):
        ''' Name of the file at the current position of the file order. '''
        return self.h5_files[self.file_idxs[self.current_file_idx]]

    def _load_data_file(self, filename):
        ''' Load one HDF5 file as the current data and restart batching. '''
        self.current_data, self.current_label = load_h5(filename)
        # squeeze() alone turns the labels of a single-sample file into a
        # 0-d array, which breaks len() and slicing; keep at least one axis.
        self.current_label = np.atleast_1d(np.squeeze(self.current_label))
        self.batch_idx = 0
        if self.shuffle:
            self.current_data, self.current_label, _ = shuffle_data(self.current_data, self.current_label)

    def _has_next_batch_in_file(self):
        ''' True while the loaded file still has samples not yet served. '''
        return self.batch_idx * self.batch_size < self.current_data.shape[0]

    def num_channel(self):
        ''' Number of channels per point (fixed at 3). '''
        return 3

    def has_next_batch(self):
        ''' Return True if next_batch() can serve another batch.

        Loads the next file(s) as needed. Files without samples are skipped
        rather than ending the pass while later files still hold data.
        '''
        # TODO: add backend thread to load data
        while (self.current_data is None) or (not self._has_next_batch_in_file()):
            if self.current_file_idx >= len(self.h5_files):
                return False
            self._load_data_file(self._get_data_filename())
            # Kept in case a subclass overrides _load_data_file without
            # resetting the counter itself.
            self.batch_idx = 0
            self.current_file_idx += 1
        return True

    def next_batch(self, augment=False):
        ''' Return (data_batch, label_batch) copied from the loaded file.

        Returned dimension may be smaller than self.batch_size (last batch of
        a file). Call has_next_batch() first so that a file is loaded.
        Input:
            augment: if True, pass the data through _augment_batch_data
        '''
        start_idx = self.batch_idx * self.batch_size
        end_idx = min((self.batch_idx + 1) * self.batch_size, self.current_data.shape[0])
        # Copies keep augmentation from modifying the cached file arrays.
        data_batch = self.current_data[start_idx:end_idx, 0:self.npoints, :].copy()
        label_batch = self.current_label[start_idx:end_idx].copy()
        self.batch_idx += 1
        if augment:
            data_batch = self._augment_batch_data(data_batch)
        return data_batch, label_batch
if __name__=='__main__':
d = ModelNetH5Dataset('data/modelnet40_ply_hdf5_2048/train_files.txt')
print(d.shuffle)
print(d.has_next_batch())
ps_batch, cls_batch = d.next_batch(True)
print(ps_batch.shape)
print(cls_batch.shape)
# Computing file changes ...