https://github.com/GPflow/GPflow
Raw File
Tip revision: 59e41536cd612555ec8f1039d09c3d76f5264cab authored by alexggmatthews on 17 February 2017, 10:43:52 UTC
Incorporating sphinx rtd theme in codebase. MIT license.
Tip revision: 59e4153
minibatch.py
# Copyright 2016 James Hensman, Valentine Svensson, alexggmatthews, Mark van der Wilk
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from .param import DataHolder


class IndexManager(object):
    """
    Base class for methods of batch indexing data.

    Subclasses implement nextIndices(), which returns an array of integer
    indices into the data for the next minibatch.

    minibatch_size is the number of indices per minibatch and total_points
    the number of available data points. rng is an instance of
    np.random.RandomState; if None, a RandomState seeded with 0 is used.
    """
    def __init__(self, minibatch_size, total_points, rng=None):
        self.minibatch_size = minibatch_size
        self.total_points = total_points
        # Compare against None explicitly: `rng or ...` would also discard
        # any user-supplied generator object that happens to be falsy.
        self.rng = rng if rng is not None else np.random.RandomState(0)

    def nextIndices(self):
        """Return the indices of the next minibatch; must be overridden."""
        raise NotImplementedError(
            "%s must implement nextIndices()" % type(self).__name__)


class ReplacementSampling(IndexManager):
    """
    Index manager that draws every minibatch uniformly at random with
    replacement, so one minibatch may contain the same index repeatedly.
    """
    def nextIndices(self):
        n_points, n_batch = self.total_points, self.minibatch_size
        # Integers drawn from [0, n_points), n_batch of them.
        return self.rng.randint(n_points, size=n_batch)


class NoReplacementSampling(IndexManager):
    """
    Index manager that draws each minibatch without replacement: a single
    minibatch never contains the same index twice (indices may still
    repeat across different minibatches).

    Raises ValueError if minibatch_size exceeds total_points.
    """
    def __init__(self, minibatch_size, total_points, rng=None):
        # Sampling without replacement is impossible if minibatch_size is
        # larger than total_points. Raise explicitly instead of using
        # `assert`, which is stripped when Python runs with -O.
        if minibatch_size > total_points:
            raise ValueError(
                "Cannot sample %d points without replacement from %d points."
                % (minibatch_size, total_points))
        IndexManager.__init__(self, minibatch_size, total_points, rng)

    def nextIndices(self):
        # A fresh random permutation of all indices on every call; its
        # first minibatch_size entries are distinct.
        permutation = self.rng.permutation(self.total_points)
        return permutation[:self.minibatch_size]


class SequenceIndices(IndexManager):
    """
    Index manager that walks through the data in order, wrapping around to
    the start once the end is passed. The cursor `counter` marks where the
    next minibatch begins.
    """
    def __init__(self, minibatch_size, total_points, rng=None):
        IndexManager.__init__(self, minibatch_size, total_points, rng)
        # Start position of the next minibatch.
        self.counter = 0

    def nextIndices(self):
        """
        Return the next `minibatch_size` consecutive indices, taken modulo
        `total_points`. total_points is re-read on every call, so the
        wrap-around follows later reassignments of that attribute.
        """
        start = self.counter
        stop = start + self.minibatch_size
        n = self.total_points
        # Advance the cursor first, then build the wrapped index range.
        self.counter = stop % n
        return np.arange(start, stop) % n


class MinibatchData(DataHolder):
    """
    A special DataHolder class which feeds a minibatch of rows (first-axis
    slices) of its array to tensorflow via update_feed_dict().
    """

    # List of valid specifiers for generation methods.
    # NOTE(review): not referenced within this class; kept for compatibility.
    _generation_methods = ['replace', 'noreplace', 'sequential']

    def __init__(self, array, minibatch_size, rng=None, batch_manager=None):
        """
        array is a numpy array of data; minibatches are taken along its
        first axis.
        minibatch_size (int) is the size of the minibatch.
        rng is an optional np.random.RandomState passed to the default
        index manager; it is not used when batch_manager is given.

        batch_manager specifies the data sampling scheme and must be an
        instance of a subclass of IndexManager. If None, a scheme is chosen
        from minibatch_size relative to the number of rows
        (see parseGenerationMethod).

        Note: you may want to randomize the order of the data
        if using sequential generation.
        """
        DataHolder.__init__(self, array, on_shape_change='pass')
        total_points = self._array.shape[0]
        self.parseGenerationMethod(batch_manager,
                                   minibatch_size,
                                   total_points,
                                   rng)

    def parseGenerationMethod(self,
                              input_batch_manager,
                              minibatch_size,
                              total_points,
                              rng):
        """
        Set self.index_manager.

        If input_batch_manager is None, a default is built: ReplacementSampling
        when minibatch_size is less than half of total_points, otherwise
        NoReplacementSampling (which itself rejects
        minibatch_size > total_points). Otherwise input_batch_manager is used
        as-is.

        Raises:
            ValueError: the default is requested and total_points is 0.
            NotImplementedError: input_batch_manager is not an instance of a
                (direct or indirect) subclass of IndexManager.
        """
        if input_batch_manager is None:
            if total_points == 0:
                # Avoid an opaque ZeroDivisionError in the ratio below.
                raise ValueError(
                    "Cannot draw minibatches from an empty data array.")
            # When minibatch_size is a small fraction of total_points,
            # ReplacementSampling should give similar results to
            # NoReplacementSampling and the former can be much faster.
            fraction = float(minibatch_size) / float(total_points)
            if fraction < 0.5:
                manager_class = ReplacementSampling
            else:
                manager_class = NoReplacementSampling
            self.index_manager = manager_class(minibatch_size,
                                               total_points,
                                               rng)
        else:  # Explicitly specified behaviour.
            # isinstance also accepts indirect subclasses, which
            # IndexManager.__subclasses__() (direct children only) missed.
            # The abstract base itself is still rejected.
            if (not isinstance(input_batch_manager, IndexManager) or
                    type(input_batch_manager) is IndexManager):
                raise NotImplementedError(
                    "batch_manager must be an instance of an IndexManager "
                    "subclass, got %r" % (input_batch_manager,))
            self.index_manager = input_batch_manager

    def update_feed_dict(self, key_dict, feed_dict):
        """
        Write the next minibatch into feed_dict (modified in place).

        key_dict maps this object to the feed key for its data (presumably
        a tensorflow placeholder — confirm against callers); the value
        stored is the rows of the array selected by the index manager.
        """
        # NOTE(review): index_manager.total_points is fixed at construction,
        # while on_shape_change='pass' presumably lets the array's shape
        # change later — confirm total_points is kept in sync in that case.
        next_indices = self.index_manager.nextIndices()
        feed_dict[key_dict[self]] = self._array[next_indices]
back to top