https://github.com/GPflow/GPflow
Tip revision: 964cfeeb98d02f9a6356e00beb59819aa7414158 authored by Nicolas Durrande on 11 March 2020, 13:24:49 UTC
Update gpflow/kernels/stationaries.py
Update gpflow/kernels/stationaries.py
Tip revision: 964cfee
test_config.py
# Copyright 2018 the GPflow authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
import numpy as np
import pytest
import tensorflow as tf
import tensorflow_probability as tfp
import gpflow
from gpflow.config import (
default_float,
default_int,
default_jitter,
default_positive_bijector,
default_summary_fmt,
set_default_float,
set_default_int,
set_default_jitter,
set_default_positive_bijector,
set_default_summary_fmt,
)
from gpflow.utilities import to_default_float, to_default_int
# Cases for the environment-variable tests. Each entry is
#   (config attribute name, raw string placed in the `GPFLOW_<NAME>` variable,
#    value the resulting `Config` attribute is expected to hold).
_env_values = [
    # dtype settings: the string names a numpy dtype
    ("int", "int16", np.int16),
    ("int", "int64", np.int64),
    ("float", "float16", np.float16),
    ("float", "float32", np.float32),
    # string-valued settings are expected back unchanged
    ("positive_bijector", "exp", "exp"),
    ("positive_bijector", "softplus", "softplus"),
    ("summary_fmt", "simple", "simple"),
    # numeric settings: the string is expected back as a float
    ("positive_minimum", "1e-3", 1e-3),
    ("jitter", "1e-2", 1e-2),
]
@pytest.mark.parametrize(("attr_name", "value", "expected_value"), _env_values)
def test_env_variables(attr_name, value, expected_value):
    """Check that a `GPFLOW_<ATTR>` environment variable shows up in a new `Config`.

    The variable is injected with `mock.patch.dict`, so it only exists while the
    `with` block is active. A `Config` built inside that block must expose
    `expected_value` under the attribute `attr_name`.
    """
    env_name = f"GPFLOW_{attr_name.upper()}"
    patched_environ = mock.patch.dict("os.environ", {env_name: value})
    with patched_environ:
        # Sanity check: the patched value is visible through `os.environ`.
        assert os.environ[env_name] == value
        fresh_config = gpflow.config.Config()
        observed = getattr(fresh_config, attr_name)
        assert observed == expected_value
# The attribute names are de-duplicated and then sorted: iterating a bare `set`
# of strings yields an order that can differ between interpreter processes
# (hash randomization), which makes the collected test order unstable and
# breaks distributed runners that require identical collection on every worker.
@pytest.mark.parametrize("attr_name", sorted({name for name, _, _ in _env_values}))
def test_env_variables_failures(attr_name):
    """Check that an unparsable `GPFLOW_<ATTR>` value makes `Config()` raise `TypeError`.

    `summary_fmt` is skipped because, per the skip message below, its validation
    cannot be performed.
    """
    if attr_name == "summary_fmt":
        pytest.skip("The `summary_fmt` validation cannot be performed.")

    env_name = f"GPFLOW_{attr_name.upper()}"
    with mock.patch.dict("os.environ", {env_name: "garbage"}), pytest.raises(TypeError):
        gpflow.config.Config()
@pytest.mark.parametrize(
    "getter, setter, valid_type_1, valid_type_2",
    [
        (default_int, set_default_int, tf.int64, np.int32),
        (default_float, set_default_float, tf.float32, np.float64),
    ],
)
def test_dtype_setting(getter, setter, valid_type_1, valid_type_2):
    """Check that each default-dtype setter is reflected by its matching getter.

    Two distinct valid dtypes are set one after the other, so the test cannot
    pass merely because the default already equalled the requested dtype.

    The changes are made inside `gpflow.config.as_context()` so that the global
    defaults are put back afterwards, even if one of the assertions fails, and
    later tests do not inherit a modified configuration.

    Raises:
        ValueError: if the parametrization supplies two equal dtypes, which
            would make the test meaningless.
    """
    if valid_type_1 == valid_type_2:
        raise ValueError("cannot test config setting/getting when both types are equal")

    with gpflow.config.as_context():
        for dtype in (valid_type_1, valid_type_2):
            setter(dtype)
            assert getter() == dtype
@pytest.mark.parametrize(
    ("setter", "invalid_type"),
    [
        # non-dtype Python types
        (set_default_int, str),
        (set_default_float, list),
        # real dtypes of the wrong numeric kind
        (set_default_int, np.float64),
        (set_default_float, tf.int32),
    ],
)
def test_dtype_errorcheck(setter, invalid_type):
    """Check that the default-dtype setters reject unsuitable types with `TypeError`."""
    expecting_type_error = pytest.raises(TypeError)
    with expecting_type_error:
        setter(invalid_type)
def test_jitter_setting():
    """Check that `set_default_jitter` is reflected by `default_jitter`.

    Two different values are set in turn so the test cannot pass just because
    the default already matched. The changes are made inside
    `gpflow.config.as_context()` so that the global jitter is put back
    afterwards, even if an assertion fails.
    """
    with gpflow.config.as_context():
        for jitter in (1e-3, 1e-6):
            set_default_jitter(jitter)
            assert default_jitter() == jitter
def test_jitter_errorcheck():
    """Check that `set_default_jitter` rejects invalid input.

    A non-numeric value must raise `TypeError`; a negative value must raise
    `ValueError`.
    """
    rejected_inputs = (
        (TypeError, "not a float"),
        (ValueError, -1e-10),
    )
    for expected_error, bad_jitter in rejected_inputs:
        with pytest.raises(expected_error):
            set_default_jitter(bad_jitter)
@pytest.mark.parametrize(
    ("value", "error_msg"),
    [
        ("Unknown", r"`unknown` not in set of valid bijectors: \['exp', 'softplus'\]"),
        (1.0, r"`1.0` not in set of valid bijectors: \['exp', 'softplus'\]"),
    ],
)
def test_positive_bijector_error(value, error_msg):
    """Check that an unsupported bijector is rejected with a descriptive `ValueError`.

    `error_msg` is a regular expression that the exception text must match; the
    first case expects the offending name to appear lower-cased in the message.
    """
    expecting_value_error = pytest.raises(ValueError, match=error_msg)
    with expecting_value_error:
        set_default_positive_bijector(value)
@pytest.mark.parametrize("value", ["exp", "SoftPlus"])
def test_positive_bijector_setting(value):
    """Check that a valid bijector name is stored lower-cased, whatever its input case.

    The change is made inside `gpflow.config.as_context()` so that the global
    default bijector is put back afterwards, even if the assertion fails, and
    later tests do not inherit it.
    """
    with gpflow.config.as_context():
        set_default_positive_bijector(value)
        assert default_positive_bijector() == value.lower()
def test_default_summary_fmt_setting():
    """Check that `set_default_summary_fmt` is reflected by `default_summary_fmt`.

    Both a named format (`"html"`) and `None` must be accepted. The changes are
    made inside `gpflow.config.as_context()` so that the global summary format
    is put back afterwards, even if an assertion fails.
    """
    with gpflow.config.as_context():
        set_default_summary_fmt("html")
        assert default_summary_fmt() == "html"

        set_default_summary_fmt(None)
        assert default_summary_fmt() is None
def test_default_summary_fmt_errorcheck():
    """Check that an unknown summary format name is rejected with `ValueError`."""
    unknown_format = "this_format_definitely_does_not_exist"
    with pytest.raises(ValueError):
        set_default_summary_fmt(unknown_format)
@pytest.mark.parametrize(
    ("setter", "getter", "converter", "dtype", "value"),
    [
        # integer defaults: scalar and list inputs, numpy and TensorFlow dtypes
        (set_default_int, default_int, to_default_int, np.int32, 3),
        (set_default_int, default_int, to_default_int, tf.int32, 3),
        (set_default_int, default_int, to_default_int, tf.int64, [3, 1, 4, 1, 5, 9]),
        (set_default_int, default_int, to_default_int, np.int64, [3, 1, 4, 1, 5, 9]),
        # float defaults: scalar and list inputs, numpy and TensorFlow dtypes
        (set_default_float, default_float, to_default_float, np.float32, 3.14159),
        (set_default_float, default_float, to_default_float, tf.float32, [3.14159] * 3),
        (set_default_float, default_float, to_default_float, np.float64, [3.14159] * 3),
        (set_default_float, default_float, to_default_float, tf.float64, [3.14159] * 3),
    ],
)
def test_native_to_default_dtype(setter, getter, converter, dtype, value):
    """Check that native Python values are converted to the configured default dtype.

    With `dtype` installed as the default inside a temporary config context,
    the result of `converter(value)` must report a dtype equal both to `dtype`
    itself and to what `getter()` returns.
    """
    with gpflow.config.as_context():
        setter(dtype)

        dtype_vs_requested = converter(value).dtype
        assert dtype_vs_requested == dtype

        dtype_vs_getter = converter(value).dtype
        assert dtype_vs_getter == getter()