import tensorflow as tf
from .. import mean_functions
from .. import covariances
from ..expectations import expectation
from ..features import InducingFeature, InducingPoints
from ..kernels import Kernel
from ..probability_distributions import Gaussian
from ..util import default_float, default_jitter
def uncertain_conditional(Xnew_mu: tf.Tensor,
                          Xnew_var: tf.Tensor,
                          feature: InducingFeature,
                          kernel: Kernel,
                          q_mu, q_sqrt,
                          *,
                          mean_function=None,
                          full_output_cov=False,
                          full_cov=False,
                          white=False):
    """
    Calculates the conditional for uncertain inputs Xnew, p(Xnew) = N(Xnew_mu, Xnew_var).
    See ``conditional`` documentation for further reference.

    :param Xnew_mu: mean of the inputs, size [N, Din]
    :param Xnew_var: covariance matrix of the inputs, size [N, Din, Din]
    :param feature: gpflow.InducingFeature object, only InducingPoints is supported
    :param kernel: gpflow kernel object.
    :param q_mu: mean inducing points, size [M, Dout]
    :param q_sqrt: cholesky of the covariance matrix of the inducing points, size [Dout, M, M]
    :param mean_function: optional mean function m(x); when given, its kernel expectations
        are added to the returned moments. ``None`` (or ``mean_functions.Zero``) means no mean.
    :param full_output_cov: boolean whether to compute covariance between output dimensions.
        Influences the shape of return value ``fvar``. Default is False
    :param full_cov: must be False; covariance between datapoints is not implemented.
    :param white: boolean whether to use whitened representation. Default is False.

    :return fmean, fvar: mean and covariance of the conditional, size ``fmean`` is [N, Dout],
        size ``fvar`` depends on ``full_output_cov``: if True ``fvar`` is [N, Dout, Dout],
        if False then ``fvar`` is [N, Dout]

    :raises NotImplementedError: if ``feature`` is not an ``InducingPoints`` instance,
        or if ``full_cov`` is True.
    """
    if not isinstance(feature, InducingPoints):
        raise NotImplementedError

    if full_cov:
        # TODO(VD): ``full_cov`` True would return a ``fvar`` of shape [N, N, D, D],
        # encoding the covariance between input datapoints as well.
        # This is not implemented as this feature is only used for plotting purposes.
        raise NotImplementedError

    pXnew = Gaussian(Xnew_mu, Xnew_var)

    num_data = Xnew_mu.shape[0]  # number of new inputs (N)
    num_ind = q_mu.shape[0]  # number of inducing points (M)
    num_func = q_mu.shape[1]  # output dimension (D)

    # Lower-triangular part only: q_sqrt stores the Cholesky factors per output.
    q_sqrt_r = tf.linalg.band_part(q_sqrt, -1, 0)  # [D, M, M]

    # Kernel expectations under p(Xnew) (the "psi statistics").
    eKuf = tf.transpose(expectation(pXnew, (kernel, feature)))  # [M, N] (psi1)
    Kuu = covariances.Kuu(feature, kernel, jitter=default_jitter())  # [M, M]
    Luu = tf.linalg.cholesky(Kuu)  # [M, M]

    if not white:
        # Whiten q(u) manually: u = Luu v, so transform q_mu and q_sqrt by Luu^{-1}.
        q_mu = tf.linalg.triangular_solve(Luu, q_mu, lower=True)
        Luu_tiled = tf.tile(Luu[None, :, :], [num_func, 1, 1])  # remove line once issue 216 is fixed
        q_sqrt_r = tf.linalg.triangular_solve(Luu_tiled, q_sqrt_r, lower=True)

    Li_eKuf = tf.linalg.triangular_solve(Luu, eKuf, lower=True)  # [M, N]
    fmean = tf.linalg.matmul(Li_eKuf, q_mu, transpose_a=True)

    eKff = expectation(pXnew, kernel)  # N (psi0)
    eKuffu = expectation(pXnew, (kernel, feature), (kernel, feature))  # [N, M, M] (psi2)
    Luu_tiled = tf.tile(Luu[None, :, :], [num_data, 1, 1])  # remove this line, once issue 216 is fixed
    Li_eKuffu = tf.linalg.triangular_solve(Luu_tiled, eKuffu, lower=True)
    # Second solve on the transpose gives Luu^{-1} psi2 Luu^{-T}.
    Li_eKuffu_Lit = tf.linalg.triangular_solve(Luu_tiled, tf.linalg.transpose(Li_eKuffu), lower=True)  # [N, M, M]
    cov = tf.linalg.matmul(q_sqrt_r, q_sqrt_r, transpose_b=True)  # [D, M, M]

    if mean_function is None or isinstance(mean_function, mean_functions.Zero):
        e_related_to_mean = tf.zeros((num_data, num_func, num_func), dtype=default_float())
    else:
        # Update mean: \mu(x) + m(x)
        fmean = fmean + expectation(pXnew, mean_function)

        # Calculate: m(x) m(x)^T + m(x) \mu(x)^T + \mu(x) m(x)^T,
        # where m(x) is the mean_function and \mu(x) is fmean
        e_mean_mean = expectation(pXnew, mean_function, mean_function)  # [N, D, D]
        Lit_q_mu = tf.linalg.triangular_solve(Luu, q_mu, adjoint=True)
        e_mean_Kuf = expectation(pXnew, mean_function, (kernel, feature))  # [N, D, M]
        # einsum isn't able to infer the rank of e_mean_Kuf, hence we explicitly set the rank of the tensor:
        e_mean_Kuf = tf.reshape(e_mean_Kuf, [num_data, num_func, num_ind])
        e_fmean_mean = tf.einsum("nqm,mz->nqz", e_mean_Kuf, Lit_q_mu)  # [N, D, D]
        e_related_to_mean = e_fmean_mean + tf.linalg.transpose(e_fmean_mean) + e_mean_mean

    if full_output_cov:
        # Full [N, D, D] output covariance. NOTE: tf.linalg.trace (not the removed
        # TF1 alias tf.trace) matches the tf.linalg.* namespace used throughout.
        fvar = (
            tf.linalg.diag(tf.tile((eKff - tf.linalg.trace(Li_eKuffu_Lit))[:, None], [1, num_func])) +
            tf.linalg.diag(tf.einsum("nij,dji->nd", Li_eKuffu_Lit, cov)) +
            # tf.linalg.diag(tf.linalg.trace(tf.linalg.matmul(Li_eKuffu_Lit, cov))) +
            tf.einsum("ig,nij,jh->ngh", q_mu, Li_eKuffu_Lit, q_mu) -
            # tf.linalg.matmul(q_mu, tf.linalg.matmul(Li_eKuffu_Lit, q_mu), transpose_a=True) -
            fmean[:, :, None] * fmean[:, None, :] +
            e_related_to_mean
        )
    else:
        # Marginal variances only, [N, D].
        fvar = (
            (eKff - tf.linalg.trace(Li_eKuffu_Lit))[:, None] +
            tf.einsum("nij,dji->nd", Li_eKuffu_Lit, cov) +
            tf.einsum("ig,nij,jg->ng", q_mu, Li_eKuffu_Lit, q_mu) -
            fmean ** 2 +
            tf.linalg.diag_part(e_related_to_mean)
        )

    return fmean, fvar