Raw File
# Copyright 2017 the GPflow authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# pylint: disable=W0212

import numpy as np
import tensorflow as tf

import gpflow
from gpflow import settings
from gpflow import session_manager
from gpflow.test_util import GPflowTestCase


class TestSessionConfiguration(GPflowTestCase):

    def prepare(self):
        with gpflow.defer_build():
            return gpflow.models.GPR(
                np.ones((1, 1)),
                np.ones((1, 1)),
                kern=gpflow.kernels.Matern52(1))

    def test_option_persistance(self):
        '''
        Test configuration options are passed to tensorflow session
        '''
        dop = 3
        settings.session.intra_op_parallelism_threads = dop
        settings.session.inter_op_parallelism_threads = dop
        settings.session.allow_soft_placement = True
        session = gpflow.session_manager.get_session()
        self.assertTrue(session._config.inter_op_parallelism_threads == dop)
        self.assertTrue(isinstance(session._config.inter_op_parallelism_threads, int))
        self.assertTrue(session._config.allow_soft_placement)
        self.assertTrue(isinstance(session._config.allow_soft_placement, bool))
        # m = self.prepare()
        # m.compile()
        # opt = gpflow.train.ScipyOptimizer()
        # opt.minimize(m, maxiter=1)

    def test_option_mutability(self):
        '''
        Test configuration options are passed to tensorflow session
        '''
        dop = 33
        settings.session.intra_op_parallelism_threads = dop
        settings.session.inter_op_parallelism_threads = dop
        graph = tf.Graph()
        tf_session = session_manager.get_session(
            graph=graph,
            output_file_name=settings.profiling.output_file_name + "_objective",
            output_directory=settings.profiling.output_directory,
            each_time=settings.profiling.each_time)
        self.assertTrue(tf_session._config.intra_op_parallelism_threads == dop)
        self.assertTrue(tf_session._config.inter_op_parallelism_threads == dop)
        tf_session.close()

        # change maximum degree of parallelism
        dopOverride = 12
        tf_session = session_manager.get_session(
            graph=graph,
            output_file_name=settings.profiling.output_file_name + "_objective",
            output_directory=settings.profiling.output_directory,
            each_time=settings.profiling.each_time,
            config=tf.ConfigProto(intra_op_parallelism_threads=dopOverride,
                                  inter_op_parallelism_threads=dopOverride))
        self.assertTrue(tf_session._config.intra_op_parallelism_threads == dopOverride)
        self.assertTrue(tf_session._config.inter_op_parallelism_threads == dopOverride)
        tf_session.close()

    def test_session_default_graph(self):
        tf_session = session_manager.get_session()
        self.assertEqual(tf_session.graph, tf.get_default_graph())
        tf_session.close()


if __name__ == '__main__':
    tf.test.main()
back to top