Raw File
# Copyright 2017 the GPflow authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import numpy as np
import tensorflow as tf

import gpflow
from gpflow import misc
from gpflow.test_util import GPflowTestCase, session_tf


class TestPublicMethods(GPflowTestCase):
    """Tests for public lookup/validation helpers in ``gpflow.misc``."""

    @staticmethod
    def run_case(name, equal, not_equal, fn):
        """Exercise a name-based lookup ``fn`` against the default and a fresh graph.

        The object called ``name`` must be found in the graph of the current
        default session (with ``index`` omitted or ``'0'``). It must not be found
        with ``index='1'``, nor in a freshly created, unrelated graph.

        Args:
            name: Name of an object already created in the default graph.
            equal: Callback asserting that a lookup result is the expected object.
            not_equal: Callback asserting that a lookup result means "not found".
            fn: Lookup function under test, called as
                ``fn(name, index=..., graph=...)``.
        """
        other_graph = tf.Graph()  # fresh graph that cannot contain ``name``
        session = tf.get_default_session()

        equal(fn(name))
        equal(fn(name, index='0'))
        equal(fn(name, graph=session.graph))

        not_equal(fn(name, index='1'))
        not_equal(fn(name, graph=other_graph))
        not_equal(fn(name, graph=other_graph, index='0'))
        not_equal(fn(name, graph=other_graph, index='1'))

    def test_tensor_by_name(self):
        """``get_tensor_by_name`` finds a tensor by name, index and graph."""
        with self.test_context():
            name = 'tensor'
            variable = tf.get_variable(name, shape=())
            self.assertTrue(gpflow.misc.is_initializable_tensor(variable))

            def equal(found):
                # The result shares the variable's name but is expected to be
                # a plain (non-initializable) tensor rather than the variable.
                self.assertFalse(gpflow.misc.is_initializable_tensor(found))
                self.assertEqual(found.name, variable.name)

            def not_equal(found):
                self.assertIsNone(found)

            fn = gpflow.misc.get_tensor_by_name
            self.run_case(name, equal, not_equal, fn)

            # A name that was never created must not be found in any graph.
            fake_name = "foo"
            not_equal(fn(fake_name))
            not_equal(fn(fake_name, graph=tf.Graph()))

    def test_variable_by_name(self):
        """``get_variable_by_name`` returns the variable itself, per graph."""
        with self.test_context():
            name = 'variable'
            variable = tf.get_variable(name, shape=())
            self.assertTrue(gpflow.misc.is_initializable_tensor(variable))

            def equal(found):
                self.assertTrue(gpflow.misc.is_initializable_tensor(found))
                self.assertEqual(found, variable)

            def not_equal(found):
                self.assertIsNone(found)

            fn = gpflow.misc.get_variable_by_name

            other_graph = tf.Graph()  # unrelated graph without ``name``
            session = tf.get_default_session()
            fake_name = "foo"

            equal(fn(name))
            equal(fn(name, graph=session.graph))
            not_equal(fn(name, graph=other_graph))
            not_equal(fn(fake_name))
            not_equal(fn(fake_name, graph=other_graph))

    def test_valid_param(self):
        """``is_valid_param_value`` accepts numeric values and rejects the rest."""
        with self.test_context():
            name = 'tensor'
            tensor = tf.get_variable(name, shape=())
            valid_values = [
                tensor,
                1.0,
                1,
                [1.0],
                [1.0, 1, 1],
                [1, 1.0, 1],
                [[1.0], [1]],
                np.array(1),
                np.array(1.0),
                np.array([[1.0], [1]]),
                [[1.0], np.array(1.0)],
                [np.array(1.0), [1.0]],
            ]
            # Empty lists, strings (even numeric-looking ones), lists mixing
            # plain floats with nested lists, None and arbitrary objects.
            invalid_values = [
                [],
                ["", 1.0],
                [1.0, ""],
                ["a", 1.0],
                [1.0, "a"],
                [1.0, [1.0]],
                [[1.0], 1.0],
                "",
                "1.0",
                "[1.0]",
                "0.1",
                None,
                object(),
                self,
            ]
            # ``msg`` identifies the offending value when an assertion fails.
            for value in valid_values:
                self.assertTrue(gpflow.misc.is_valid_param_value(value),
                                msg=repr(value))
            for value in invalid_values:
                self.assertFalse(gpflow.misc.is_valid_param_value(value),
                                 msg=repr(value))

    def test_remove_trainable(self):
        """``remove_from_trainables`` raises ``ValueError`` for non-trainables.

        This covers a variable looked up in the wrong graph, a variable that
        was already removed, and a variable created with ``trainable=False``.
        """
        with self.test_context():
            other_graph = tf.Graph()
            var1 = tf.get_variable('var1', shape=())
            var2 = tf.get_variable('var2', shape=(), trainable=False)

            # var1 lives in the default graph, not in ``other_graph``.
            with self.assertRaises(ValueError):
                gpflow.misc.remove_from_trainables(var1, graph=other_graph)

            gpflow.misc.remove_from_trainables(var1)
            # A second removal must fail: var1 is no longer trainable.
            with self.assertRaises(ValueError):
                gpflow.misc.remove_from_trainables(var1)

            with self.assertRaises(ValueError):
                gpflow.misc.remove_from_trainables(var2)

def test_leading_transpose(session_tf):
    """``leading_transpose`` moves the listed axes and expands ``...`` in place.

    All input dimensions are distinct. Asserting the full static output shape
    therefore pins down the exact permutation that was applied, including the
    order of the axes covered by ``...``.
    """
    dims = [1, 2, 3, 4]
    a = tf.zeros(dims)
    b = misc.leading_transpose(a, [..., -1, -2])
    c = misc.leading_transpose(a, [-1, ..., -2])
    d = misc.leading_transpose(a, [-1, -2, ...])
    e = misc.leading_transpose(a, [3, 2, ...])
    f = misc.leading_transpose(a, [3, -2, ...])

    # Every permutation spec preserves the rank.
    for t in (b, c, d, e, f):
        assert len(t.shape) == len(a.shape)

    # ``...`` is expected to stand for the remaining axes (0, 1) in their
    # original order; the explicit axes land where they are listed.
    assert b.shape.as_list() == [1, 2, 4, 3]
    assert c.shape.as_list() == [4, 1, 2, 3]
    assert d.shape.as_list() == [4, 3, 1, 2]
    # Positive and negative axis indices must be interchangeable.
    assert e.shape.as_list() == d.shape.as_list()
    assert f.shape.as_list() == d.shape.as_list()


def test_leading_transpose_fail(session_tf):
    """A permutation without ``...`` must be rejected with ``ValueError``.

    ``pytest.raises`` is used rather than ``xfail(raises=ValueError)``. With a
    non-strict xfail, a call that stops raising is only reported as XPASS and
    the regression goes unnoticed; here it fails the test.
    """
    dims = [1, 2, 3, 4]
    a = tf.zeros(dims)
    with pytest.raises(ValueError):
        misc.leading_transpose(a, [-1, -2])
back to top