Revision e9275c2749d590998082da0544957d33c28d59e5 authored by Mark van der Wilk on 03 May 2017, 17:08:06 UTC, committed by GitHub on 03 May 2017, 17:08:06 UTC
* Add regression test for NaNs in gradient * Fix NaN in gradient if cos_theta is close to one * Use jitter close to machine epsilon
1 parent 5190ada
test_minibatch_data.py
# Copyright 2016 alexggmatthews
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
import numpy as np
import unittest
from GPflow.minibatch import SequenceIndices, MinibatchData
from GPflow.minibatch import ReplacementSampling, NoReplacementSampling
class TestSequentialManager(unittest.TestCase):
def setUp(self):
tf.reset_default_graph()
def testA(self):
minibatch_size = 3
total_points = 5
sequenceManager = SequenceIndices(minibatch_size,total_points)
indecesA = sequenceManager.nextIndices()
self.assertTrue((indecesA==np.arange(0,minibatch_size)).all())
indecesB = sequenceManager.nextIndices()
self.assertTrue((indecesB==np.array([3,4,0 ])).all())
def testB(self):
minibatch_size = 5
total_points = 2
sequenceManager = SequenceIndices(minibatch_size,total_points)
indecesA = sequenceManager.nextIndices()
targetIndicesA = np.array([0,1,0,1,0])
self.assertTrue((indecesA==targetIndicesA).all())
indecesB = sequenceManager.nextIndices()
targetIndicesB = np.array([1,0,1,0,1])
self.assertTrue((indecesB==targetIndicesB).all())
class TestRandomIndexManagers(unittest.TestCase):
def setUp(self):
tf.reset_default_graph()
def checkUniformDist(self, indeces, nChoices):
fTotalPoints = float(len(indeces))
tolerance = 1e-2
for possibleIndex in range(nChoices):
empirical = indeces.count(possibleIndex) / fTotalPoints
error = np.abs(empirical - 1./nChoices)
if error > tolerance:
return False
return True
def testReplacement(self):
minibatch_size = 1000
data_size = 3
tolerance = 1e-2
rs = ReplacementSampling(minibatch_size,data_size)
indeces = rs.nextIndices().tolist()
self.assertTrue(self.checkUniformDist(indeces,data_size))
def testNoReplacement(self):
mini_size_err = 5
data_size_err = 3
constructor = lambda : NoReplacementSampling(mini_size_err,
data_size_err)
self.assertRaises(AssertionError,constructor)
mini_size = 3
data_size = 3
nrs = NoReplacementSampling(mini_size,data_size)
indeces = np.sort(nrs.nextIndices()).tolist()
self.assertEqual(indeces,[0,1,2])
one = 1
nrsb = NoReplacementSampling(one,data_size)
indecesOverall = []
for repeatIndex in range(3000):
indeces = nrsb.nextIndices().tolist()
indecesOverall = indecesOverall + indeces
self.assertTrue(self.checkUniformDist(indecesOverall,data_size))
class TestMinibatchData(unittest.TestCase):
def setUp(self):
tf.reset_default_graph()
self.nDataPoints = 10
self.minibatch_size = 4
self.dummyArray = np.atleast_2d(np.arange(self.nDataPoints)).T
def testA(self):
constructor = lambda : MinibatchData(self.dummyArray,
self.minibatch_size,
rng=None,
batch_manager=[])
self.assertRaises(NotImplementedError,constructor)
def testB(self):
sm = SequenceIndices(self.minibatch_size, self.nDataPoints)
md = MinibatchData(self.dummyArray,
self.minibatch_size,
rng=None,
batch_manager=sm)
output_string = 'test_out'
key_dict = {md: output_string}
feed_dict = {}
md.update_feed_dict( key_dict, feed_dict )
test_out_array = np.atleast_2d(np.arange(self.minibatch_size)).T
test_dict = {output_string : test_out_array }
self.assertEqual(feed_dict.keys(),test_dict.keys())
self.assertTrue((feed_dict[output_string]==test_out_array).all())
def testC(self):
#Test the manager selection logic.
md = MinibatchData(self.dummyArray,
self.minibatch_size,
rng=None)
self.assertEqual(md.index_manager.__class__, ReplacementSampling)
md = MinibatchData(self.dummyArray,
self.minibatch_size*2,
rng=None)
self.assertEqual(md.index_manager.__class__, NoReplacementSampling)
if __name__ == "__main__":
unittest.main()
Computing file changes ...