https://github.com/GPflow/GPflow
Raw File
Tip revision: 3b6200307f2133f3c99930850b7e5106a686d656 authored by uri-granta on 07 February 2024, 13:16:15 UTC
Bump version to 2.9.1 in preparation for release (#2103)
Tip revision: 3b62003
test_multipledispatch.py
import re
import warnings
from typing import Any, Type

import multipledispatch
import pytest
import tensorflow as tf
from _pytest.capture import CaptureFixture
from packaging.version import Version

import gpflow
from gpflow.utilities import Dispatcher


class A1:
    """Root of the first-argument hierarchy used to exercise dispatch."""


class A2(A1):
    """Child of :class:`A1`; more specific first-argument type."""


class B1:
    """Root of the second-argument hierarchy used to exercise dispatch."""


class B2(B1):
    """Child of :class:`B1`; more specific second-argument type."""


def get_test_fn() -> Dispatcher:
    """Build a fresh ``Dispatcher`` named ``"test_fn"`` with three registrations.

    Registered signatures, in this exact order:

    * ``(A1, B1)`` -> ``"a1-b1"``
    * ``(A2, B1)`` -> ``"a2-b1"``
    * ``(A1, B2)`` -> ``"a1-b2"``

    ``(A2, B2)`` is intentionally left unregistered: it matches both
    ``(A2, B1)`` and ``(A1, B2)``, neither of which is more specific.
    Keep the registration order unchanged.
    """
    dispatcher = Dispatcher("test_fn")

    @dispatcher.register(A1, B1)
    def _parent_parent(x: A1, y: B1) -> str:
        return "a1-b1"

    @dispatcher.register(A2, B1)
    def _child_parent(x: A2, y: B1) -> str:
        return "a2-b1"

    @dispatcher.register(A1, B2)
    def _parent_child(x: A1, y: B2) -> str:
        return "a1-b2"

    return dispatcher


def test_our_multipledispatch() -> None:
    """Exact, ambiguous, and disambiguated dispatch behave as expected."""
    dispatcher = get_test_fn()

    # Each registered signature is hit by an exactly-matching call.
    for left_cls, right_cls, expected in [
        (A1, B1, "a1-b1"),
        (A2, B1, "a2-b1"),
        (A1, B2, "a1-b2"),
    ]:
        assert dispatcher(left_cls(), right_cls()) == expected

    # (A2, B2) is ambiguous between (A2, B1) and (A1, B2): exactly one
    # AmbiguityWarning must be emitted, and the (A1, B2) implementation is used.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        assert dispatcher(A2(), B2()) == "a1-b2"  # the last definition wins

        assert len(caught) == 1
        assert issubclass(caught[0].category, multipledispatch.conflict.AmbiguityWarning)

    # Registering the exact child-child signature resolves the ambiguity,
    # so the same call must now succeed without any warning.
    @dispatcher.register(A2, B2)
    def _child_child(x: A2, y: B2) -> str:
        return "a2-b2"

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        assert dispatcher(A2(), B2()) == "a2-b2"

        assert len(caught) == 0


def test_dispatcher__no_match() -> None:
    """Argument types that match no registered signature are rejected."""
    dispatcher = get_test_fn()

    # Calling the dispatcher directly raises...
    with pytest.raises(NotImplementedError):
        dispatcher(3, "foo")

    # ...the lenient lookup reports "no implementation" as None...
    assert dispatcher.dispatch(int, str) is None

    # ...and the strict lookup raises instead.
    with pytest.raises(NotImplementedError):
        dispatcher.dispatch_or_raise(int, str)


@pytest.mark.parametrize(
    "Dispatcher, expect_autograph_warning",
    [
        # Control case: checks the test is written correctly and captures the
        # warning generated by the base multipledispatch Dispatcher.
        (multipledispatch.Dispatcher, True),
        # No warning is expected with GPflow's custom Dispatcher.
        (gpflow.utilities.Dispatcher, False),
    ],
)
def test_dispatcher_autograph_warnings(
    capsys: CaptureFixture[str], Dispatcher: Type[Any], expect_autograph_warning: bool
) -> None:
    """Check whether wrapping a dispatcher in ``tf.function`` triggers AutoGraph's
    "appears to be a generator function" warning.

    Note: the ``Dispatcher`` parameter (name fixed by the parametrize string)
    shadows the module-level ``Dispatcher`` import inside this function.
    """
    if Dispatcher is multipledispatch.Dispatcher and (Version(tf.__version__) >= Version("2.3.0")):
        pytest.skip("In TensorFlow >= 2.3, multipledispatch.Dispatcher no longer works at all")

    # Echo AutoGraph log messages to stdout so that capsys can capture them.
    tf.autograph.set_verbosity(0, alsologtostdout=True)
    try:
        test_fn = Dispatcher("test_fn")

        # Register an implementation for the base class only...
        @test_fn.register(gpflow.inducing_variables.InducingVariables)
        def test_iv(x: gpflow.inducing_variables.InducingVariables) -> tf.Tensor:
            return tf.reduce_sum(x.Z)

        test_fn_compiled = tf.function(test_fn)  # with autograph=True by default

        # ...but call it with a subclass instance, so dispatch has to fall back
        # to the base-class registration.
        result = test_fn_compiled(gpflow.inducing_variables.InducingPoints([[1.0, 2.0]]))
        assert result.numpy() == 3.0  # expect computation to work either way

        captured = capsys.readouterr()
    finally:
        # Stop echoing to stdout so the global setting does not leak into other
        # tests (assumes False is TensorFlow's default; the level stays at 0).
        tf.autograph.set_verbosity(0, alsologtostdout=False)

    tf_warning = (
        r"WARNING:.*Entity .* appears to be a generator function\."
        r" It will not be converted by AutoGraph\."
    )
    # re.search, not re.match: the warning need not be the first thing written
    # to stdout. With re.match, a warning preceded by any other output would go
    # unnoticed and the "no warning" case would pass falsely.
    found = re.search(tf_warning, captured.out) is not None
    assert found == expect_autograph_warning
back to top