https://github.com/GPflow/GPflow
Tip revision: 3b6200307f2133f3c99930850b7e5106a686d656 authored by uri-granta on 07 February 2024, 13:16:15 UTC
Bump version to 2.9.1 in preparation for release (#2103)
Tip revision: 3b62003
test_multipledispatch.py
import re
import warnings
from typing import Any, Type
import multipledispatch
import pytest
import tensorflow as tf
from _pytest.capture import CaptureFixture
from packaging.version import Version
import gpflow
from gpflow.utilities import Dispatcher
class A1:
    """Root of the first test class hierarchy, used as a dispatch key."""
class A2(A1):
    """Subclass of ``A1``, used to exercise dispatch on a more specific type."""
class B1:
    """Root of the second test class hierarchy, used as a dispatch key."""
class B2(B1):
    """Subclass of ``B1``, used to exercise dispatch on a more specific type."""
def get_test_fn() -> Dispatcher:
    """Build a fresh ``Dispatcher`` named ``"test_fn"`` with three implementations.

    Signatures are registered in this order, each returning a fixed string:

    * ``(A1, B1)`` -> ``"a1-b1"``
    * ``(A2, B1)`` -> ``"a2-b1"``
    * ``(A1, B2)`` -> ``"a1-b2"``

    There is deliberately no ``(A2, B2)`` implementation, so a call with
    ``(A2(), B2())`` matches both the second and third entries.
    """
    dispatcher = Dispatcher("test_fn")

    # Keep this registration order: test_our_multipledispatch expects the
    # ambiguous (A2, B2) call to resolve to the last-registered candidate.
    @dispatcher.register(A1, B1)
    def _impl_a1_b1(x: A1, y: B1) -> str:
        return "a1-b1"

    @dispatcher.register(A2, B1)
    def _impl_a2_b1(x: A2, y: B1) -> str:
        return "a2-b1"

    @dispatcher.register(A1, B2)
    def _impl_a1_b2(x: A1, y: B2) -> str:
        return "a1-b2"

    return dispatcher
def test_our_multipledispatch() -> None:
    """Check exact, ambiguous, and then disambiguated dispatch on the test hierarchies."""
    test_fn = get_test_fn()

    # Each of these argument pairs has exactly one most-specific registration.
    unambiguous_cases = [
        ((A1(), B1()), "a1-b1"),
        ((A2(), B1()), "a2-b1"),
        ((A1(), B2()), "a1-b2"),
    ]
    for call_args, expected in unambiguous_cases:
        assert test_fn(*call_args) == expected

    # (A2, B2) matches both (A2, B1) and (A1, B2). Expect exactly one
    # AmbiguityWarning, with the last-registered candidate chosen.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert test_fn(A2(), B2()) == "a1-b2"
    assert len(caught) == 1
    assert issubclass(caught[0].category, multipledispatch.conflict.AmbiguityWarning)

    # Registering the child-child signature should make the call unambiguous.
    @test_fn.register(A2, B2)
    def _impl_a2_b2(x: A2, y: B2) -> str:
        return "a2-b2"

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert test_fn(A2(), B2()) == "a2-b2"
    assert len(caught) == 0
def test_dispatcher__no_match() -> None:
    """An argument signature with no registered implementation fails cleanly on every entry point."""
    test_fn = get_test_fn()

    # A direct call with unmatched argument types raises.
    with pytest.raises(NotImplementedError):
        test_fn(3, "foo")

    # Type-level lookup: ``dispatch`` returns None, ``dispatch_or_raise`` raises.
    assert test_fn.dispatch(int, str) is None
    with pytest.raises(NotImplementedError):
        test_fn.dispatch_or_raise(int, str)
@pytest.mark.parametrize(
    "Dispatcher, expect_autograph_warning",
    [
        # Sanity check on the base multipledispatch.Dispatcher: confirms the
        # test is written correctly and captures the warning when it is emitted.
        (multipledispatch.Dispatcher, True),
        # No AutoGraph generator warning is expected with gpflow's Dispatcher.
        (gpflow.utilities.Dispatcher, False),
    ],
)
def test_dispatcher_autograph_warnings(
    capsys: CaptureFixture[str], Dispatcher: Type[Any], expect_autograph_warning: bool
) -> None:
    """Check whether tf.function-compiling a dispatcher emits AutoGraph's generator warning.

    :param capsys: pytest fixture used to read what was printed to stdout.
    :param Dispatcher: dispatcher class under test. NOTE: this name shadows the
        module-level ``Dispatcher`` import inside this test.
    :param expect_autograph_warning: whether the "appears to be a generator
        function" AutoGraph warning should appear in stdout.
    """
    if Dispatcher is multipledispatch.Dispatcher and (Version(tf.__version__) >= Version("2.3.0")):
        pytest.skip("In TensorFlow >= 2.3, multipledispatch.Dispatcher no longer works at all")

    # Echo AutoGraph logs to stdout so capsys can capture them.
    tf.autograph.set_verbosity(0, alsologtostdout=True)
    try:
        test_fn = Dispatcher("test_fn")

        # The generator would only be invoked when defining for the base class...
        @test_fn.register(gpflow.inducing_variables.InducingVariables)
        def test_iv(x: gpflow.inducing_variables.InducingVariables) -> tf.Tensor:
            return tf.reduce_sum(x.Z)

        test_fn_compiled = tf.function(test_fn)  # with autograph=True by default
        # ...but calling using a subclass.
        result = test_fn_compiled(gpflow.inducing_variables.InducingPoints([[1.0, 2.0]]))
        assert result.numpy() == 3.0  # expect computation to work either way

        captured = capsys.readouterr()
    finally:
        # Switch the stdout echo back off so this global logging change does not
        # leak into later tests. The level stays at 0, as set above.
        # NOTE(review): assumes echoing was off before this test.
        tf.autograph.set_verbosity(0, alsologtostdout=False)

    tf_warning = (
        r"WARNING:.*Entity .* appears to be a generator function."
        r" It will not be converted by AutoGraph."
    )
    # Use re.search, not re.match: the warning may follow other captured output.
    # re.match would only look at the start of stdout, letting the
    # expect-no-warning case pass even when the warning is printed.
    assert bool(re.search(tf_warning, captured.out)) == expect_autograph_warning