swh:1:snp:93431e0de56bff942fc37a8298daad635afceed0
Tip revision: ff13e612910cfcb6b83f2e8d4cc8627888bffe92 authored by Sergio Diaz on 21 March 2019, 13:41:09 UTC
fixing imports
fixing imports
Tip revision: ff13e61
kuus.py
import tensorflow as tf
from ..features import InducingPoints, Multiscale
from ..kernels import Kernel, RBF
from .dispatch import Kuu
@Kuu.register(InducingPoints, Kernel)
def _Kuu(feat: InducingPoints, kern: Kernel, *, jitter=0.0):
    """Evaluate the kernel on the inducing inputs and add diagonal jitter.

    Registered on the ``Kuu`` dispatcher for ``(InducingPoints, Kernel)``.

    :param feat: feature object; its ``Z`` attribute is handed to the kernel
        and ``len(feat)`` fixes the size of the identity matrix.
    :param kern: kernel, invoked as ``kern(feat.Z)``.
    :param jitter: keyword-only scalar that multiplies the identity matrix
        added to the kernel output (default ``0.0``).
    :return: ``kern(feat.Z) + jitter * I``, with ``I`` built in the dtype of
        the kernel output.
    """
    covariance = kern(feat.Z)
    # Build the identity in the covariance's dtype so the addition below
    # does not mix dtypes.
    identity = tf.eye(len(feat), dtype=covariance.dtype)
    return covariance + jitter * identity
@Kuu.register(Multiscale, RBF)
def _Kuu(feat: Multiscale, kern: RBF, *, jitter=0.0):
    """Compute the inducing covariance for ``Multiscale`` features under RBF.

    Registered on the ``Kuu`` dispatcher for ``(Multiscale, RBF)``.

    :param feat: multiscale feature object; supplies ``Z``, ``scales`` and
        the ``_cust_square_dist`` helper, and ``len(feat)`` fixes the size of
        the identity matrix.
    :param kern: RBF kernel; supplies ``slice``, ``lengthscales`` and
        ``variance``.
    :param jitter: keyword-only scalar that multiplies the identity matrix
        added to the result (default ``0.0``).
    :return: the covariance with ``jitter * I`` added, ``I`` being built in
        the dtype of the covariance.
    """
    centres, widths = kern.slice(feat.Z, feat.scales)

    # Squared effective lengthscale of each inducing feature.
    sq_scales = tf.square(kern.lengthscales + widths)

    # Combine the squared scales of every pair of features; the leading
    # None / [:, None] indexing broadcasts them against each other.
    pairwise_sum = sq_scales[None, ...] + sq_scales[:, None, ...]
    combined = tf.sqrt(pairwise_sum - kern.lengthscales ** 2)

    sq_dist = feat._cust_square_dist(centres, centres, combined)
    scaled_exp = kern.variance * tf.exp(-sq_dist / 2)

    # Product over axis 2 -- presumably the input-dimension axis; confirm
    # against the shapes returned by ``kern.slice``.
    normaliser = tf.reduce_prod(kern.lengthscales / combined, axis=2)
    covariance = scaled_exp * normaliser

    return covariance + jitter * tf.eye(len(feat), dtype=covariance.dtype)