https://github.com/sordonia/hed-qs
Tip revision: b217a94387eca1e37975d7dd770d94981ec7eac1 authored by Alessandro Sordoni on 07 April 2017, 14:07:22 UTC
Update README.md
Update README.md
Tip revision: b217a94
data_iterator.py
__docformat__ = 'restructedtext en'
__authors__ = ("Alessandro Sordoni")
__contact__ = "Alessandro Sordoni <sordonia@iro.umontreal>"
import numpy as np
import theano
import theano.tensor as T
import sys, getopt
import logging
from state import *
from utils import *
from SS_dataset import *
import itertools
import sys
import pickle
import random
import datetime
logger = logging.getLogger(__name__)
def create_padded_batch(state, x, y=None):
    """Pack sessions into fixed-size padded matrices, one session per column.

    Every session is copied down one column of a ``(seqlen, bs)`` matrix.
    A session longer than ``seqlen`` is cut right after the last
    end-of-query symbol found in its first ``seqlen - 2`` tokens and is
    terminated with the end-of-session symbol. If no end-of-query symbol
    is found there, the session is skipped and its column stays all zero.

    The input sequences are never modified; the terminating symbol of a
    truncated session is written into the output matrix only.

    Parameters
    ----------
    state : dict
        Reads ``'seqlen'`` (rows), ``'bs'`` (columns), ``'eoq_sym'`` and
        ``'eos_sym'``.
    x : sequence
        ``x[0]`` is the collection of sessions, each a sequence of token
        ids. It must not hold more than ``state['bs']`` sessions.
    y : sequence, optional
        When given and non-empty, ``y[0][i]`` supplies the values stored in
        the ``'y'`` matrix at the end-of-query positions of session ``i``
        (only the first ``len(eoq positions)`` values are used).

    Returns
    -------
    dict
        ``'x'`` (int32), ``'y'`` (int32) and ``'x_mask'`` (float32), all of
        shape ``(seqlen, bs)``; ``'num_preds'``, the number of unmasked
        positions; and ``'max_length'``, the longest kept session length.
    """
    mx = state['seqlen']
    n = state['bs']
    eoq_sym = state['eoq_sym']
    eos_sym = state['eos_sym']

    X = np.zeros((mx, n), dtype='int32')
    Y = np.zeros((mx, n), dtype='int32')
    Xmask = np.zeros((mx, n), dtype='float32')

    # Explicit test instead of bare truthiness, which is ambiguous for arrays.
    has_targets = y is not None and len(y) > 0

    # Running totals: number of unmasked positions and longest kept session.
    num_preds = 0
    max_length = 0

    for idx, session in enumerate(x[0]):
        session_length = len(session)

        # End-of-query positions, searched only in the first mx - 2 tokens so
        # that a truncated session always leaves room for the final symbol.
        eoq_idx = np.where(np.asarray(session[:mx - 2]) == eoq_sym)[0]

        truncated = mx < session_length
        if truncated:
            # The first query alone does not fit: nothing sensible to keep.
            if not len(eoq_idx):
                continue
            # Keep everything up to the last complete query plus one slot
            # for the end-of-session symbol.
            session_length = int(eoq_idx[-1]) + 2
            assert session_length < mx

        X[:session_length, idx] = session[:session_length]
        if truncated:
            # Terminate the shortened session in the output matrix rather
            # than overwriting a token in the caller's data.
            X[session_length - 1, idx] = eos_sym

        if has_targets:
            Y[eoq_idx, idx] = y[0][idx][:len(eoq_idx)]

        max_length = max(max_length, session_length)
        num_preds += session_length

        # Pad the remainder of the column with the end-of-session symbol.
        if session_length < mx:
            X[session_length:, idx] = eos_sym

        # The mask is one exactly where real session tokens were written.
        Xmask[:session_length, idx] = 1.

    # Sanity check: num_preds must equal the number of unmasked positions.
    assert num_preds == np.sum(Xmask)
    return {'x': X, 'y': Y, 'x_mask': Xmask, 'num_preds': num_preds, 'max_length': max_length}
def _as_object_array(sequences):
    """Return a 1-D object array holding each item of *sequences* as-is.

    Filling the array element by element keeps every session as a single
    entry whether or not the sessions share a length, and avoids the error
    current NumPy raises when asked to build an array from ragged lists.
    """
    items = list(sequences)
    out = np.empty(len(items), dtype=object)
    for position, item in enumerate(items):
        out[position] = item
    return out


def get_batch_iterator(rng, state):
    """Build the training and validation batch iterators.

    Both iterators are instances of a local ``SSIterator`` subclass that
    fetches ``state['sort_k_batches']`` raw batches at a time, orders all of
    their sessions by length, and re-cuts them into padded batches of at
    most ``state['bs']`` sessions so that each batch holds sessions of
    similar length.

    Parameters
    ----------
    rng :
        Passed as the first positional argument to ``SSIterator.__init__``.
    state : dict
        Reads ``'bs'``, ``'seqlen'``, ``'sort_k_batches'``,
        ``'train_session'``, ``'valid_session'`` and, optionally,
        ``'train_rank'`` / ``'valid_rank'``.

    Returns
    -------
    tuple
        ``(train_data, valid_data)``. The training iterator is created with
        ``use_infinite_loop=True``, the validation one with ``False``.
    """

    class Iterator(SSIterator):
        def __init__(self, *args, **kwargs):
            SSIterator.__init__(self, rng, *args, **kwargs)
            # Lazily created generator of padded batches; see next().
            self.batch_iter = None

        def get_homogenous_batch_iter(self):
            """Yield padded batches grouped by session length."""
            while True:
                k_batches = state['sort_k_batches']
                batch_size = state['bs']

                # Collect up to k_batches raw batches from the parent
                # iterator; falsy results are dropped.
                # NOTE(review): presumably SSIterator.next returns a falsy
                # value once the data is exhausted - confirm.
                data = []
                for _ in range(k_batches):
                    batch = SSIterator.next(self)
                    if batch:
                        data.append(batch)

                if not data:
                    return

                sessions = data
                y = None
                if self.has_ranks:
                    # With ranks, each raw batch unpacks as a
                    # (sessions, ranks) pair.
                    sessions, ranks = zip(*data)
                    y = _as_object_array(itertools.chain.from_iterable(ranks))

                x = _as_object_array(itertools.chain.from_iterable(sessions))

                # Sort all sessions by length. A list is built explicitly
                # because map() is lazy on Python 3.
                order = np.argsort([len(session) for session in x])

                for k in range(len(data)):
                    indices = order[k * batch_size:(k + 1) * batch_size]
                    # Fewer sessions than len(data) * batch_size were
                    # fetched: do not emit a batch made only of padding.
                    if not len(indices):
                        break
                    if self.has_ranks:
                        batch = create_padded_batch(state, [x[indices]], [y[indices]])
                    else:
                        batch = create_padded_batch(state, [x[indices]])
                    if batch:
                        yield batch

        def start(self):
            SSIterator.start(self)
            # Drop any partially consumed generator so next() starts afresh.
            self.batch_iter = None

        def next(self):
            """Return the next padded batch, or None when there is none."""
            if self.batch_iter is None:
                self.batch_iter = self.get_homogenous_batch_iter()
            try:
                batch = next(self.batch_iter)
            except StopIteration:
                return None
            return batch

    train_data = Iterator(
        batch_size=int(state['bs']),
        session_file=state['train_session'],
        rank_file=state.get('train_rank', None),
        queue_size=100,
        use_infinite_loop=True,
        max_len=state['seqlen'])

    valid_data = Iterator(
        batch_size=int(state['bs']),
        session_file=state['valid_session'],
        rank_file=state.get('valid_rank', None),
        use_infinite_loop=False,
        queue_size=100,
        max_len=state['seqlen'])

    return train_data, valid_data