# noqa: F811

import tensorflow as tf

from .. import mean_functions as mfn
from ..probability_distributions import Gaussian
from . import dispatch
from .expectations import expectation

NoneType = type(None)


def _second_moment(p):
    """Return E[x_n x_n^T] = cov_n + mu_n mu_n^T for every Gaussian in `p` (NxDxD)."""
    return p.cov + (p.mu[:, :, None] * p.mu[:, None, :])


@dispatch.expectation.register(Gaussian, (mfn.Linear, mfn.Constant), NoneType, NoneType, NoneType)
def _E(p, mean, _, __, ___, nghp=None):
    """
    Compute the expectation:
    <m(X)>_p(X)
        - m(x) :: Linear, Identity or Constant mean function

    The result is the mean function evaluated at the mean of `p`.

    :param nghp: not used by this closed-form handler; present so that all
        handlers share the dispatcher's call signature.
    :return: NxQ
    """
    return mean(p.mu)


@dispatch.expectation.register(Gaussian, mfn.Constant, NoneType, mfn.Constant, NoneType)
def _E(p, mean1, _, mean2, __, nghp=None):
    """
    Compute the expectation:
    expectation[n] = <m1(x_n) m2(x_n)^T>_p(x_n)
        - m1(.), m2(.) :: Constant mean functions

    Implemented as the outer product of both mean functions evaluated at
    the mean of `p`.

    :param nghp: not used by this closed-form handler.
    :return: NxQ1xQ2
    """
    return mean1(p.mu)[:, :, None] * mean2(p.mu)[:, None, :]


@dispatch.expectation.register(Gaussian, mfn.Constant, NoneType, mfn.MeanFunction, NoneType)
def _E(p, mean1, _, mean2, __, nghp=None):
    """
    Compute the expectation:
    expectation[n] = <m1(x_n) m2(x_n)^T>_p(x_n)
        - m1(.) :: Constant mean function
        - m2(.) :: General mean function

    The constant factor is taken outside the expectation, so only
    <m2(x_n)>_p(x_n) has to be evaluated through `expectation`.

    :param nghp: forwarded to the nested `expectation` call so that a
        caller-supplied value is not silently dropped (presumably the
        quadrature setting used when no closed form exists -- confirm
        against `expectations.expectation`).
    :return: NxQ1xQ2
    """
    e_mean2 = expectation(p, mean2, nghp=nghp)
    return mean1(p.mu)[:, :, None] * e_mean2[:, None, :]


@dispatch.expectation.register(Gaussian, mfn.MeanFunction, NoneType, mfn.Constant, NoneType)
def _E(p, mean1, _, mean2, __, nghp=None):
    """
    Compute the expectation:
    expectation[n] = <m1(x_n) m2(x_n)^T>_p(x_n)
        - m1(.) :: General mean function
        - m2(.) :: Constant mean function

    The constant factor is taken outside the expectation, so only
    <m1(x_n)>_p(x_n) has to be evaluated through `expectation`.

    :param nghp: forwarded to the nested `expectation` call so that a
        caller-supplied value is not silently dropped (presumably the
        quadrature setting used when no closed form exists -- confirm
        against `expectations.expectation`).
    :return: NxQ1xQ2
    """
    e_mean1 = expectation(p, mean1, nghp=nghp)
    return e_mean1[:, :, None] * mean2(p.mu)[:, None, :]


@dispatch.expectation.register(Gaussian, mfn.Identity, NoneType, mfn.Identity, NoneType)
def _E(p, mean1, _, mean2, __, nghp=None):
    """
    Compute the expectation:
    expectation[n] = <x_n x_n^T>_p(x_n)
        - m1(.), m2(.) :: Identity mean functions

    :param nghp: not used by this closed-form handler.
    :return: NxDxD
    """
    return _second_moment(p)


@dispatch.expectation.register(Gaussian, mfn.Identity, NoneType, mfn.Linear, NoneType)
def _E(p, mean1, _, mean2, __, nghp=None):
    """
    Compute the expectation:
    expectation[n] = <x_n m2(x_n)^T>_p(x_n) = <x_n x_n^T> A2 + <x_n> b2^T
        - m1(.) :: Identity mean function
        - m2(.) :: Linear mean function

    :param nghp: not used by this closed-form handler.
    :return: NxDxQ
    """
    e_xxt = _second_moment(p)  # NxDxD
    # einsum contracts against the single DxQ matrix directly, so A does not
    # have to be copied N times before a batched matmul.
    e_xxt_A = tf.einsum("nij,jq->niq", e_xxt, mean2.A)  # NxDxQ
    e_x_bt = p.mu[:, :, None] * mean2.b[None, None, :]  # NxDxQ
    return e_xxt_A + e_x_bt


@dispatch.expectation.register(Gaussian, mfn.Linear, NoneType, mfn.Identity, NoneType)
def _E(p, mean1, _, mean2, __, nghp=None):
    """
    Compute the expectation:
    expectation[n] = <m1(x_n) x_n^T>_p(x_n) = A1^T <x_n x_n^T> + b1 <x_n>^T
        - m1(.) :: Linear mean function
        - m2(.) :: Identity mean function

    :param nghp: not used by this closed-form handler.
    :return: NxQxD
    """
    e_xxt = _second_moment(p)  # NxDxD
    # "iq,nij->nqj" is A1^T @ e_xxt[n] for every n, without tiling A1.
    e_A_xxt = tf.einsum("iq,nij->nqj", mean1.A, e_xxt)  # NxQxD
    e_b_xt = mean1.b[None, :, None] * p.mu[:, None, :]  # NxQxD
    return e_A_xxt + e_b_xt


@dispatch.expectation.register(Gaussian, mfn.Linear, NoneType, mfn.Linear, NoneType)
def _E(p, mean1, _, mean2, __, nghp=None):
    """
    Compute the expectation:
    expectation[n] = <m1(x_n) m2(x_n)^T>_p(x_n)
                   = A1^T <x_n x_n^T> A2 + A1^T <x_n> b2^T
                     + b1 <x_n>^T A2 + b1 b2^T
        - m1(.), m2(.) :: Linear mean functions

    :param nghp: not used by this closed-form handler.
    :return: NxQ1xQ2
    """
    e_xxt = _second_moment(p)  # NxDxD
    e_A1t_xxt_A2 = tf.einsum("iq,nij,jz->nqz", mean1.A, e_xxt, mean2.A)  # NxQ1xQ2
    e_A1t_x_b2t = tf.einsum("iq,ni,z->nqz", mean1.A, p.mu, mean2.b)  # NxQ1xQ2
    e_b1_xt_A2 = tf.einsum("q,ni,iz->nqz", mean1.b, p.mu, mean2.A)  # NxQ1xQ2
    e_b1_b2t = mean1.b[:, None] * mean2.b[None, :]  # Q1xQ2, broadcast over N below
    return e_A1t_xxt_A2 + e_A1t_x_b2t + e_b1_xt_A2 + e_b1_b2t