swh:1:snp:93431e0de56bff942fc37a8298daad635afceed0
Raw File
Tip revision: ff13e612910cfcb6b83f2e8d4cc8627888bffe92 authored by Sergio Diaz on 21 March 2019, 13:41:09 UTC
fixing imports
Tip revision: ff13e61
kuus.py
import tensorflow as tf

from ..features import InducingPoints, Multiscale
from ..kernels import Kernel, RBF
from .dispatch import Kuu


@Kuu.register(InducingPoints, Kernel)
def _Kuu(feat: InducingPoints, kern: Kernel, *, jitter=0.0):
    """Return ``kern(feat.Z)`` with ``jitter`` added along the diagonal.

    Registered with the ``Kuu`` dispatcher for (InducingPoints, Kernel)
    pairs. The module-level name ``_Kuu`` is rebound by a later definition
    in this file, so this implementation is reached through ``Kuu`` itself.

    Args:
        feat: inducing-point features; ``feat.Z`` holds the inducing inputs
            and ``len(feat)`` sets the size of the identity that is added.
        kern: kernel evaluated on ``feat.Z`` against itself.
        jitter: multiplier for the identity matrix added to the result
            (the default 0.0 adds zeros).

    Returns:
        ``kern(feat.Z) + jitter * I``, where ``I`` has ``len(feat)`` rows and
        the same dtype as the kernel output.
    """
    # NOTE(review): assumes kern(feat.Z) is an [M, M] matrix with
    # M == len(feat), so the identity lines up with its diagonal.
    gram = kern(feat.Z)
    identity = tf.eye(len(feat), dtype=gram.dtype)
    return gram + jitter * identity


@Kuu.register(Multiscale, RBF)
def _Kuu(feat: Multiscale, kern: RBF, *, jitter=0.0):
    """Return the covariance between multiscale inducing features under an RBF kernel.

    Each feature ``i`` has a centre ``Zmu[i]`` and a scale ``Zlen[i]``, both
    taken from ``feat.Z`` / ``feat.scales`` after ``kern.slice``. With
    ``l = kern.lengthscales`` the entries are::

        s_ij = sqrt((l + Zlen[i])**2 + (l + Zlen[j])**2 - l**2)
        K_ij = variance * exp(-d_ij / 2) * prod_axis2(l / s_ij)

    where ``d_ij`` comes from ``feat._cust_square_dist(Zmu, Zmu, s)``
    (presumably the squared centre distance scaled by ``s``; see
    ``Multiscale._cust_square_dist`` to confirm).

    Args:
        feat: multiscale inducing features (``Z`` centres, ``scales`` widths).
        kern: RBF kernel supplying ``lengthscales`` and ``variance``.
        jitter: multiplier for the identity matrix (``len(feat)`` rows,
            dtype of the result) added to the covariance.

    Returns:
        The feature covariance matrix plus ``jitter * I``.
    """
    centres, widths = kern.slice(feat.Z, feat.scales)
    ell = kern.lengthscales
    # Squared effective lengthscale of each feature. This assumes widths is
    # [M, D] (TODO confirm), which makes this [M, D] as well.
    eff_sq = tf.square(ell + widths)
    # Combine every (i, j) pair by broadcasting, which presumably gives [M, M, D].
    pair_scale = tf.sqrt(eff_sq[None, ...] + eff_sq[:, None, ...] - ell ** 2)
    sq_dist = feat._cust_square_dist(centres, centres, pair_scale)
    decay = tf.exp(-sq_dist / 2)
    # Normalising factor: product of l / s_ij over axis 2.
    norm = tf.reduce_prod(ell / pair_scale, 2)
    cov = kern.variance * decay * norm
    return cov + jitter * tf.eye(len(feat), dtype=cov.dtype)
back to top