minibatch.py
# Copyright 2016 James Hensman, Valentine Svensson, alexggmatthews, Mark van der Wilk
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from .param import DataHolder
class IndexManager(object):
"""
Base clase for methods of batch indexing data.
rng is an instance of np.random.RandomState, defaults to seed 0.
"""
def __init__(self, minibatch_size, total_points, rng=None):
self.minibatch_size = minibatch_size
self.total_points = total_points
self.rng = rng or np.random.RandomState(0)
def nextIndices(self):
raise NotImplementedError
class ReplacementSampling(IndexManager):
def nextIndices(self):
return self.rng.randint(self.total_points,
size=self.minibatch_size)
class NoReplacementSampling(IndexManager):
def __init__(self, minibatch_size, total_points, rng=None):
# Can't sample without replacement is minibatch_size is larger
# than total_points
assert(minibatch_size <= total_points)
IndexManager.__init__(self, minibatch_size, total_points, rng)
def nextIndices(self):
permutation = self.rng.permutation(self.total_points)
return permutation[:self.minibatch_size]
class SequenceIndices(IndexManager):
"""
A class that maintains the state necessary to manage
sequential indexing of data holders.
"""
def __init__(self, minibatch_size, total_points, rng=None):
self.counter = 0
IndexManager.__init__(self, minibatch_size, total_points, rng)
def nextIndices(self):
"""
Written so that if total_points
changes this will still work
"""
firstIndex = self.counter
lastIndex = self.counter + self.minibatch_size
self.counter = lastIndex % self.total_points
return np.arange(firstIndex, lastIndex) % self.total_points
class MinibatchData(DataHolder):
"""
A special DataHolder class which feeds a minibatch
to tensorflow via update_feed_dict().
"""
# List of valid specifiers for generation methods.
_generation_methods = ['replace', 'noreplace', 'sequential']
def __init__(self, array, minibatch_size, rng=None, batch_manager=None):
"""
array is a numpy array of data.
minibatch_size (int) is the size of the minibatch
batch_manager specified data sampling scheme and is a subclass
of IndexManager.
Note: you may want to randomize the order of the data
if using sequential generation.
"""
DataHolder.__init__(self, array, on_shape_change='pass')
total_points = self._array.shape[0]
self.parseGenerationMethod(batch_manager,
minibatch_size,
total_points,
rng)
def parseGenerationMethod(self,
input_batch_manager,
minibatch_size,
total_points,
rng):
# Logic for default behaviour.
# When minibatch_size is a small fraction of total_point
# ReplacementSampling should give similar results to
# NoReplacementSampling and the former can be much faster.
if input_batch_manager is None:
fraction = float(minibatch_size) / float(total_points)
if fraction < 0.5:
self.index_manager = ReplacementSampling(minibatch_size,
total_points,
rng)
else:
self.index_manager = NoReplacementSampling(minibatch_size,
total_points,
rng)
else: # Explicitly specified behaviour.
if input_batch_manager.__class__ not in IndexManager.__subclasses__():
raise NotImplementedError
self.index_manager = input_batch_manager
def update_feed_dict(self, key_dict, feed_dict):
next_indices = self.index_manager.nextIndices()
feed_dict[key_dict[self]] = self._array[next_indices]