# Copyright 2017 the GPflow authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=W0212
import numpy as np
import tensorflow as tf
import gpflow
from gpflow import settings
from gpflow import session_manager
from gpflow.test_util import GPflowTestCase
class TestSessionConfiguration(GPflowTestCase):
    """Check that session options reach the sessions built by ``session_manager.get_session``.

    Settings are only ever modified on a copy applied through
    ``settings.temp_settings`` so the global configuration seen by other
    tests is never changed.
    """

    def prepare(self):
        """Build a tiny GPR model with a Matern52 kernel under ``defer_build``.

        Not currently used by the tests in this class.
        """
        with gpflow.defer_build():
            return gpflow.models.GPR(
                np.ones((1, 1)),
                np.ones((1, 1)),
                kern=gpflow.kernels.Matern52(1))

    def test_option_persistance(self):
        '''
        Test configuration options from the settings are passed to the
        tensorflow session created by session_manager.get_session
        '''
        dop = 3
        # Work on a copy of the settings so nothing leaks into other tests.
        custom_config = settings.get_settings()
        custom_config.session.intra_op_parallelism_threads = dop
        custom_config.session.inter_op_parallelism_threads = dop
        custom_config.session.allow_soft_placement = True
        with settings.temp_settings(custom_config):
            session = session_manager.get_session()
            try:
                config = session._config
                self.assertEqual(config.intra_op_parallelism_threads, dop)
                self.assertEqual(config.inter_op_parallelism_threads, dop)
                self.assertIsInstance(config.inter_op_parallelism_threads, int)
                self.assertTrue(config.allow_soft_placement)
                self.assertIsInstance(config.allow_soft_placement, bool)
            finally:
                session.close()

    def test_option_mutability(self):
        '''
        Test that session options come from the settings by default, and that
        an explicit `config` passed to get_session overrides them
        '''
        dop = 33
        dop_override = 12
        custom_config = settings.get_settings()
        custom_config.session.intra_op_parallelism_threads = dop
        custom_config.session.inter_op_parallelism_threads = dop
        graph = tf.Graph()
        with settings.temp_settings(custom_config):
            # Profiling arguments shared by both get_session calls below.
            profiling_kwargs = dict(
                output_file_name=settings.profiling.output_file_name + "_objective",
                output_directory=settings.profiling.output_directory,
                each_time=settings.profiling.each_time)

            tf_session = session_manager.get_session(graph=graph, **profiling_kwargs)
            try:
                self.assertEqual(tf_session._config.intra_op_parallelism_threads, dop)
                self.assertEqual(tf_session._config.inter_op_parallelism_threads, dop)
            finally:
                tf_session.close()

            # change maximum degree of parallelism: the explicit config must
            # win over the values still present in the (temporary) settings.
            override = tf.ConfigProto(intra_op_parallelism_threads=dop_override,
                                      inter_op_parallelism_threads=dop_override)
            tf_session = session_manager.get_session(
                graph=graph, config=override, **profiling_kwargs)
            try:
                self.assertEqual(tf_session._config.intra_op_parallelism_threads,
                                 dop_override)
                self.assertEqual(tf_session._config.inter_op_parallelism_threads,
                                 dop_override)
            finally:
                tf_session.close()

    def test_session_default_graph(self):
        '''
        Test a session created without an explicit graph is bound to the
        current default tensorflow graph
        '''
        tf_session = session_manager.get_session()
        try:
            self.assertEqual(tf_session.graph, tf.get_default_graph())
        finally:
            tf_session.close()
# Allow running this test module directly as a script by delegating to
# TensorFlow's test runner entry point.
if __name__ == '__main__':
    tf.test.main()