Raw File
kufs.py
import tensorflow as tf
from ..features import InducingPoints, Multiscale
from ..kernels import Kernel, RBF
from .dispatch import Kuf


@Kuf.register(InducingPoints, Kernel, object)
def _Kuf(feat: InducingPoints, kern: Kernel, Xnew: tf.Tensor):
    """Cross-covariance between plain inducing points and new inputs.

    The kernel is evaluated with the inducing inputs as its first argument
    and ``Xnew`` as its second, so the result presumably has the inducing
    points along the first axis (TODO confirm against ``Kernel.__call__``).

    :param feat: inducing-point features; ``feat.Z()`` yields their inputs.
    :param kern: any kernel, called directly on (Z, Xnew).
    :param Xnew: new input locations.
    :return: whatever ``kern(Z, Xnew)`` returns.
    """
    inducing_inputs = feat.Z()
    return kern(inducing_inputs, Xnew)


@Kuf.register(Multiscale, RBF, object)
def _Kuf(feat: Multiscale, kern: RBF, Xnew):
    """Cross-covariance between multiscale inducing features and new inputs.

    Each inducing point's kernel lengthscale is widened by that point's
    ``scales()`` value. The squared-exponential term is then multiplied by the
    product of ``lengthscales / (lengthscales + scales)`` taken over axis 1,
    and finally by the kernel variance.

    :param feat: multiscale inducing features supplying centres ``Z()`` and
        widths ``scales()``.
    :param kern: RBF kernel supplying ``variance()``, ``lengthscales()`` and
        the ``slice`` used to select active dimensions.
    :param Xnew: new input locations.
    :return: the transpose of ``variance * exp(-d / 2) * factor``, where ``d``
        comes from ``feat._cust_square_dist``. Assuming ``d`` is [N, M]
        (N new inputs, M inducing points) -- TODO confirm -- the result is
        [M, N].
    """
    X, _ = kern.slice(Xnew, None)
    centres, widths = kern.slice(feat.Z(), feat.scales())
    base_ls = kern.lengthscales()
    # Effective lengthscale for each inducing point: kernel lengthscale plus
    # that point's own width.
    wide_ls = base_ls + widths
    sq_dist = feat._cust_square_dist(X, centres, wide_ls)
    # Product over axis 1 (presumably the input dimensions), reshaped into a
    # single row so it broadcasts across the rows of sq_dist.
    factor = tf.reshape(tf.reduce_prod(base_ls / wide_ls, 1), (1, -1))
    cross_cov = kern.variance() * tf.exp(-sq_dist / 2) * factor
    return tf.transpose(cross_cov)
back to top