# Revision 8e64ed039275cf1d3b66856277023c9548317bd8 authored by Jesper Nielsen on 11 April 2022, 09:22:07 UTC, committed by GitHub on 11 April 2022, 09:22:07 UTC
# 1 parent 3f74b5c
# Raw File
# uncertain_conditionals.py
# Copyright 2017-2020 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import tensorflow as tf

from .. import covariances, mean_functions
from ..base import MeanAndVariance
from ..config import default_float, default_jitter
from ..expectations import expectation
from ..inducing_variables import InducingPoints, InducingVariables
from ..kernels import Kernel
from ..probability_distributions import Gaussian


def uncertain_conditional(
    Xnew_mu: tf.Tensor,
    Xnew_var: tf.Tensor,
    inducing_variable: InducingVariables,
    kernel: Kernel,
    q_mu: tf.Tensor,
    q_sqrt: tf.Tensor,
    *,
    mean_function: Optional[mean_functions.MeanFunction] = None,
    full_output_cov: bool = False,
    full_cov: bool = False,
    white: bool = False,
) -> MeanAndVariance:
    """
    Calculates the conditional for uncertain inputs Xnew, p(Xnew) = N(Xnew_mu, Xnew_var).
    See ``conditional`` documentation for further reference.

    Shape symbols: N is the number of new inputs, Din the input dimension, M the number
    of inducing points and Dout the number of outputs.

    :param Xnew_mu: mean of the inputs, size [N, Din]
    :param Xnew_var: covariance matrix of the inputs, size [N, Din, Din]
    :param inducing_variable: gpflow.InducingVariable object, only InducingPoints is supported
    :param kernel: gpflow kernel object.
    :param q_mu: mean inducing points, size [M, Dout]
    :param q_sqrt: cholesky of the covariance matrix of the inducing points,
                   size [Dout, M, M]. Only its lower triangle is used.
    :param mean_function: optional mean function. If it is None or an instance of
                          ``mean_functions.Zero`` no mean-function terms are added.
    :param full_output_cov: boolean whether to compute covariance between output dimensions.
                            Influences the shape of return value ``fvar``. Default is False
    :param full_cov: must be False; ``full_cov=True`` is not supported.
    :param white: boolean whether to use whitened representation. Default is False.
    :return fmean, fvar: mean and covariance of the conditional, size ``fmean`` is [N, Dout],
            size ``fvar`` depends on ``full_output_cov``: if True ``fvar`` is [N, Dout, Dout],
            if False then ``fvar`` is [N, Dout]
    :raises NotImplementedError: if ``inducing_variable`` is not an ``InducingPoints``
            instance, or if ``full_cov`` is True.
    """

    if not isinstance(inducing_variable, InducingPoints):
        raise NotImplementedError(
            "uncertain_conditional() only supports InducingPoints inducing variables, "
            f"got {type(inducing_variable).__name__}"
        )

    if full_cov:
        raise NotImplementedError(
            "uncertain_conditional() currently does not support full_cov=True"
        )

    pXnew = Gaussian(Xnew_mu, Xnew_var)

    num_data = tf.shape(Xnew_mu)[0]  # number of new inputs (N)
    num_ind, num_func = tf.unstack(
        tf.shape(q_mu), num=2, axis=0
    )  # number of inducing points (M), output dimension (D)
    # Zero everything above the diagonal so q_sqrt_r is strictly lower triangular.
    q_sqrt_r = tf.linalg.band_part(q_sqrt, -1, 0)  # [D, M, M]

    eKuf = tf.transpose(expectation(pXnew, (kernel, inducing_variable)))  # [M, N] (psi1)
    Kuu = covariances.Kuu(inducing_variable, kernel, jitter=default_jitter())  # [M, M]
    Luu = tf.linalg.cholesky(Kuu)  # [M, M]

    if not white:
        # Bring q_mu and q_sqrt into the whitened representation (pre-multiply by Luu^{-1})
        # so that the remainder of the computation is shared with the white=True case.
        q_mu = tf.linalg.triangular_solve(Luu, q_mu, lower=True)
        Luu_tiled = tf.tile(
            Luu[None, :, :], [num_func, 1, 1]
        )  # remove line once issue 216 is fixed
        q_sqrt_r = tf.linalg.triangular_solve(Luu_tiled, q_sqrt_r, lower=True)

    Li_eKuf = tf.linalg.triangular_solve(Luu, eKuf, lower=True)  # [M, N]
    fmean = tf.linalg.matmul(Li_eKuf, q_mu, transpose_a=True)

    eKff = expectation(pXnew, kernel)  # N (psi0)
    eKuffu = expectation(
        pXnew, (kernel, inducing_variable), (kernel, inducing_variable)
    )  # [N, M, M] (psi2)
    Luu_tiled = tf.tile(
        Luu[None, :, :], [num_data, 1, 1]
    )  # remove this line, once issue 216 is fixed
    # Two solves against Luu (the second on the adjoint of the first result)
    # apply Luu^{-1} to both sides of eKuffu.
    Li_eKuffu = tf.linalg.triangular_solve(Luu_tiled, eKuffu, lower=True)
    Li_eKuffu_Lit = tf.linalg.triangular_solve(
        Luu_tiled, tf.linalg.adjoint(Li_eKuffu), lower=True
    )  # [N, M, M]
    cov = tf.linalg.matmul(q_sqrt_r, q_sqrt_r, transpose_b=True)  # [D, M, M]

    if mean_function is None or isinstance(mean_function, mean_functions.Zero):
        e_related_to_mean = tf.zeros((num_data, num_func, num_func), dtype=default_float())
    else:
        # Update mean: \mu(x) + m(x)
        fmean = fmean + expectation(pXnew, mean_function)

        # Calculate: m(x) m(x)^T + m(x) \mu(x)^T + \mu(x) m(x)^T,
        # where m(x) is the mean_function and \mu(x) is fmean
        e_mean_mean = expectation(pXnew, mean_function, mean_function)  # [N, D, D]
        Lit_q_mu = tf.linalg.triangular_solve(Luu, q_mu, adjoint=True)
        e_mean_Kuf = expectation(pXnew, mean_function, (kernel, inducing_variable))  # [N, D, M]
        # einsum isn't able to infer the rank of e_mean_Kuf, hence we explicitly set the rank of the tensor:
        e_mean_Kuf = tf.reshape(e_mean_Kuf, [num_data, num_func, num_ind])
        e_fmean_mean = tf.einsum("nqm,mz->nqz", e_mean_Kuf, Lit_q_mu)  # [N, D, D]
        e_related_to_mean = e_fmean_mean + tf.linalg.adjoint(e_fmean_mean) + e_mean_mean

    # Terms shared by both output layouts.
    # psi0 minus the trace of the whitened psi2, one value per input point.
    trace_term = eKff - tf.linalg.trace(Li_eKuffu_Lit)  # [N]
    # Per-output trace of (whitened psi2 @ cov), i.e. trace(matmul(Li_eKuffu_Lit, cov[d])).
    cov_term = tf.einsum("nij,dji->nd", Li_eKuffu_Lit, cov)  # [N, D]

    if full_output_cov:
        fvar = (
            tf.linalg.diag(tf.tile(trace_term[:, None], [1, num_func]))
            + tf.linalg.diag(cov_term)
            # q_mu^T (whitened psi2) q_mu for every pair of outputs (g, h).
            + tf.einsum("ig,nij,jh->ngh", q_mu, Li_eKuffu_Lit, q_mu)
            # Outer product of the mean with itself.
            - fmean[:, :, None] * fmean[:, None, :]
            + e_related_to_mean
        )  # [N, D, D]
    else:
        fvar = (
            trace_term[:, None]
            + cov_term
            # Same quadratic form as above, restricted to matching outputs (g == h).
            + tf.einsum("ig,nij,jg->ng", q_mu, Li_eKuffu_Lit, q_mu)
            - fmean ** 2
            + tf.linalg.diag_part(e_related_to_mean)
        )  # [N, D]

    return fmean, fvar
# back to top