import tensorflow as tf

from .. import mean_functions
from .. import covariances
from ..expectations import expectation
from ..features import InducingFeature, InducingPoints
from ..kernels import Kernel
from ..probability_distributions import Gaussian
from ..util import default_float, default_jitter


def uncertain_conditional(Xnew_mu: tf.Tensor,
                          Xnew_var: tf.Tensor,
                          feature: InducingFeature,
                          kernel: Kernel,
                          q_mu,
                          q_sqrt,
                          *,
                          mean_function=None,
                          full_output_cov=False,
                          full_cov=False,
                          white=False):
    """
    Calculates the conditional for uncertain inputs Xnew, p(Xnew) = N(Xnew_mu, Xnew_var).
    See ``conditional`` documentation for further reference.

    :param Xnew_mu: mean of the inputs, size [N, Din]
    :param Xnew_var: covariance matrix of the inputs, size [N, Din, Din]
    :param feature: gpflow.InducingFeature object, only InducingPoints is supported
    :param kernel: gpflow kernel object.
    :param q_mu: mean of the inducing variables, size [M, Dout]
    :param q_sqrt: cholesky of the covariance matrix of the inducing variables,
        size [Dout, M, M]; only the lower triangle is used.
    :param mean_function: optional mean function added to the predictive mean.
        ``None`` or a ``mean_functions.Zero`` instance means no mean contribution.
    :param full_output_cov: boolean whether to compute the covariance between output
        dimensions. Influences the shape of return value ``fvar``. Default is False.
    :param full_cov: must be False; covariance between input datapoints is not implemented.
    :param white: boolean whether to use whitened representation. Default is False.
    :return fmean, fvar: mean and covariance of the conditional. ``fmean`` is [N, Dout];
        the size of ``fvar`` depends on ``full_output_cov``: if True ``fvar`` is
        [N, Dout, Dout], if False then ``fvar`` is [N, Dout].
    :raises NotImplementedError: if ``feature`` is not ``InducingPoints`` or ``full_cov`` is True.
    """
    if not isinstance(feature, InducingPoints):
        raise NotImplementedError("uncertain_conditional only supports InducingPoints features")

    if full_cov:
        # TODO(VD): ``full_cov`` True would return a ``fvar`` of shape [N, N, D, D],
        # encoding the covariance between input datapoints as well.
        # This is not implemented as this feature is only used for plotting purposes.
        raise NotImplementedError("full_cov=True is not implemented for uncertain inputs")

    pXnew = Gaussian(Xnew_mu, Xnew_var)

    num_data = Xnew_mu.shape[0]  # number of new inputs (N)
    num_ind = q_mu.shape[0]  # number of inducing points (M)
    num_func = q_mu.shape[1]  # output dimension (D)

    # Only the lower triangle of q_sqrt is used.
    q_sqrt_r = tf.linalg.band_part(q_sqrt, -1, 0)  # [D, M, M]

    eKuf = tf.transpose(expectation(pXnew, (kernel, feature)))  # [M, N] (psi1)
    Kuu = covariances.Kuu(feature, kernel, jitter=default_jitter())  # [M, M]
    Luu = tf.linalg.cholesky(Kuu)  # [M, M]

    if not white:
        # Map q(u) into whitened coordinates (v = L^{-1} u) so that the rest of the
        # computation is identical for the whitened and non-whitened cases.
        q_mu = tf.linalg.triangular_solve(Luu, q_mu, lower=True)
        Luu_tiled = tf.tile(Luu[None, :, :], [num_func, 1, 1])  # remove line once issue 216 is fixed
        q_sqrt_r = tf.linalg.triangular_solve(Luu_tiled, q_sqrt_r, lower=True)

    # Predictive mean: psi1 L^{-T} q_mu.
    Li_eKuf = tf.linalg.triangular_solve(Luu, eKuf, lower=True)  # [M, N]
    fmean = tf.linalg.matmul(Li_eKuf, q_mu, transpose_a=True)  # [N, D]

    eKff = expectation(pXnew, kernel)  # N (psi0)
    eKuffu = expectation(pXnew, (kernel, feature), (kernel, feature))  # [N, M, M] (psi2)
    Luu_tiled = tf.tile(Luu[None, :, :], [num_data, 1, 1])  # remove this line, once issue 216 is fixed
    # L^{-1} psi2 L^{-T} for every input point.
    Li_eKuffu = tf.linalg.triangular_solve(Luu_tiled, eKuffu, lower=True)
    Li_eKuffu_Lit = tf.linalg.triangular_solve(Luu_tiled, tf.linalg.transpose(Li_eKuffu), lower=True)  # [N, M, M]
    cov = tf.linalg.matmul(q_sqrt_r, q_sqrt_r, transpose_b=True)  # [D, M, M]

    if mean_function is None or isinstance(mean_function, mean_functions.Zero):
        e_related_to_mean = tf.zeros((num_data, num_func, num_func), dtype=default_float())
    else:
        # Update mean: \mu(x) + m(x)
        fmean = fmean + expectation(pXnew, mean_function)

        # Calculate: m(x) m(x)^T + m(x) \mu(x)^T + \mu(x) m(x)^T,
        # where m(x) is the mean_function and \mu(x) is fmean
        e_mean_mean = expectation(pXnew, mean_function, mean_function)  # [N, D, D]
        # L^{-T} q_mu, so that \mu(x) = k(x, Z) @ Lit_q_mu.
        Lit_q_mu = tf.linalg.triangular_solve(Luu, q_mu, adjoint=True)
        e_mean_Kuf = expectation(pXnew, mean_function, (kernel, feature))  # [N, D, M]
        # einsum isn't able to infer the rank of e_mean_Kuf, hence we explicitly set the rank of the tensor:
        e_mean_Kuf = tf.reshape(e_mean_Kuf, [num_data, num_func, num_ind])
        e_fmean_mean = tf.einsum("nqm,mz->nqz", e_mean_Kuf, Lit_q_mu)  # [N, D, D]
        e_related_to_mean = e_fmean_mean + tf.linalg.transpose(e_fmean_mean) + e_mean_mean

    if full_output_cov:
        fvar = (
            # E[k(x, x)] - tr(L^{-1} psi2 L^{-T}), shared by every output, on the diagonal.
            tf.linalg.diag(tf.tile((eKff - tf.linalg.trace(Li_eKuffu_Lit))[:, None], [1, num_func]))
            # tr(L^{-1} psi2 L^{-T} S_d) for each output d, on the diagonal.
            + tf.linalg.diag(tf.einsum("nij,dji->nd", Li_eKuffu_Lit, cov))
            # q_mu^T (L^{-1} psi2 L^{-T}) q_mu for every pair of outputs.
            + tf.einsum("ig,nij,jh->ngh", q_mu, Li_eKuffu_Lit, q_mu)
            # Subtract the outer product of the (mean-function-shifted) predictive mean.
            - fmean[:, :, None] * fmean[:, None, :]
            + e_related_to_mean
        )
    else:
        fvar = (
            (eKff - tf.linalg.trace(Li_eKuffu_Lit))[:, None]
            + tf.einsum("nij,dji->nd", Li_eKuffu_Lit, cov)
            + tf.einsum("ig,nij,jg->ng", q_mu, Li_eKuffu_Lit, q_mu)
            - fmean ** 2
            + tf.linalg.diag_part(e_related_to_mean)
        )

    return fmean, fvar