Revision f4ac4c5a52603afdd12c258958456bd170df7092 authored by ST John on 24 March 2022, 08:32:27 UTC, committed by ST John on 24 March 2022, 08:32:27 UTC
1 parent cc7ed07
Raw File
quadratures.py
# Copyright 2017-2020 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional, Type, Union, cast

import numpy as np
import tensorflow as tf

from .. import kernels
from .. import mean_functions as mfn
from ..base import TensorType
from ..covariances import Kuf
from ..inducing_variables import InducingVariables
from ..probability_distributions import DiagonalGaussian, Gaussian, MarkovGaussian
from ..quadrature import mvnquad
from . import dispatch
from .expectations import ExpectationObject, PackedExpectationObject, quadrature_expectation

register = dispatch.quadrature_expectation.register


NoneType: Type[None] = type(None)
EllipsisType = Any


def get_eval_func(
    obj: ExpectationObject,
    inducing_variable: Optional[InducingVariables],
    slice: Union[slice, EllipsisType, None] = None,
) -> Callable[[TensorType], tf.Tensor]:
    """
    Return the function of interest (kernel or mean) for the expectation
    depending on the type of :obj: and whether any inducing are given
    """

    slice = ... if slice is None else slice
    if inducing_variable is not None:
        # kernel + inducing_variable combination
        if not isinstance(inducing_variable, InducingVariables) or not isinstance(
            obj, kernels.Kernel
        ):
            raise TypeError("If `inducing_variable` is supplied, `obj` must be a kernel.")
        return lambda x: tf.transpose(Kuf(inducing_variable, obj, x))[slice]
    elif isinstance(obj, mfn.MeanFunction):
        return lambda x: obj(x)[slice]  # type: ignore
    elif isinstance(obj, kernels.Kernel):
        return lambda x: obj(x, full_cov=False)  # type: ignore

    raise NotImplementedError()


@dispatch.quadrature_expectation.register(
    (Gaussian, DiagonalGaussian),
    object,
    (InducingVariables, NoneType),
    object,
    (InducingVariables, NoneType),
)
def _quadrature_expectation_gaussian(
    p: Union[Gaussian, DiagonalGaussian],
    obj1: ExpectationObject,
    inducing_variable1: Optional[InducingVariables],
    obj2: ExpectationObject,
    inducing_variable2: Optional[InducingVariables],
    nghp: Optional[int] = None,
) -> tf.Tensor:
    """
    General handling of quadrature expectations for Gaussians and DiagonalGaussians
    Fallback method for missing analytic expectations
    """
    nghp = 100 if nghp is None else nghp

    # logger.warning(
    #     "Quadrature is used to calculate the expectation. This means that "
    #     "an analytical implementations is not available for the given combination."
    # )

    if obj1 is None:
        raise NotImplementedError("First object cannot be None.")

    if not isinstance(p, DiagonalGaussian):
        cov = p.cov
    else:
        if (
            isinstance(obj1, kernels.Kernel)
            and isinstance(obj2, kernels.Kernel)
            and obj1.on_separate_dims(obj2)
        ):  # no joint expectations required
            eKxz1 = quadrature_expectation(
                p, cast(PackedExpectationObject, (obj1, inducing_variable1)), nghp=nghp
            )
            eKxz2 = quadrature_expectation(
                p, cast(PackedExpectationObject, (obj2, inducing_variable2)), nghp=nghp
            )
            return eKxz1[:, :, None] * eKxz2[:, None, :]
        cov = tf.linalg.diag(p.cov)

    if obj2 is None:

        def eval_func(x: TensorType) -> tf.Tensor:
            fn = get_eval_func(obj1, inducing_variable1)
            return fn(x)

    else:

        def eval_func(x: TensorType) -> tf.Tensor:
            fn1 = get_eval_func(obj1, inducing_variable1, np.s_[:, :, None])
            fn2 = get_eval_func(obj2, inducing_variable2, np.s_[:, None, :])
            return fn1(x) * fn2(x)

    return mvnquad(eval_func, p.mu, cov, nghp)


@dispatch.quadrature_expectation.register(
    MarkovGaussian, object, (InducingVariables, NoneType), object, (InducingVariables, NoneType)
)
def _quadrature_expectation_markov(
    p: MarkovGaussian,
    obj1: ExpectationObject,
    inducing_variable1: Optional[InducingVariables],
    obj2: ExpectationObject,
    inducing_variable2: Optional[InducingVariables],
    nghp: Optional[int] = None,
) -> tf.Tensor:
    """
    Handling of quadrature expectations for Markov Gaussians (useful for time series)
    Fallback method for missing analytic expectations wrt Markov Gaussians
    Nota Bene: obj1 is always associated with x_n, whereas obj2 always with x_{n+1}
               if one requires e.g. <x_{n+1} K_{x_n, Z}>_p(x_{n:n+1}), compute the
               transpose and then transpose the result of the expectation
    """
    nghp = 40 if nghp is None else nghp

    # logger.warning(
    #     "Quadrature is used to calculate the expectation. This means that "
    #     "an analytical implementations is not available for the given combination."
    # )

    if obj2 is None:

        def eval_func(x: TensorType) -> tf.Tensor:
            return get_eval_func(obj1, inducing_variable1)(x)

        mu, cov = p.mu[:-1], p.cov[0, :-1]  # cross covariances are not needed
    elif obj1 is None:

        def eval_func(x: TensorType) -> tf.Tensor:
            return get_eval_func(obj2, inducing_variable2)(x)

        mu, cov = p.mu[1:], p.cov[0, 1:]  # cross covariances are not needed
    else:

        def eval_func(x: TensorType) -> tf.Tensor:
            x1 = tf.split(x, 2, 1)[0]
            x2 = tf.split(x, 2, 1)[1]
            res1 = get_eval_func(obj1, inducing_variable1, np.s_[:, :, None])(x1)
            res2 = get_eval_func(obj2, inducing_variable2, np.s_[:, None, :])(x2)
            return res1 * res2

        mu = tf.concat((p.mu[:-1, :], p.mu[1:, :]), 1)  # Nx2D
        cov_top = tf.concat((p.cov[0, :-1, :, :], p.cov[1, :-1, :, :]), 2)  # NxDx2D
        cov_bottom = tf.concat((tf.linalg.adjoint(p.cov[1, :-1, :, :]), p.cov[0, 1:, :, :]), 2)
        cov = tf.concat((cov_top, cov_bottom), 1)  # Nx2Dx2D

    return mvnquad(eval_func, mu, cov, nghp)
back to top