# Copyright 2016-2020 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Tuple
import tensorflow as tf
from . import covariances, kernels
from .base import Module, Parameter
from .conditionals.util import expand_independent_outputs, mix_latent_gp
from .config import default_float, default_jitter
from .inducing_variables import (
FallbackSeparateIndependentInducingVariables,
FallbackSharedIndependentInducingVariables,
InducingPoints,
InducingVariables,
SeparateIndependentInducingVariables,
SharedIndependentInducingVariables,
)
from .models.model import MeanAndVariance
from .utilities import Dispatcher
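# The posterior classes below cache the parts of the SVGP predictive equations
# that do not depend on the test inputs: alpha (so the predictive mean is
# Kfu @ alpha) and Qinv (so the predictive covariance is Kff - Kfu @ Qinv @ Kuf).
# Repeated calls to predict_f then reduce to cheap matrix products.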
class DiagNormal(Module):
    """q(u) with diagonal covariance; q_sqrt holds the standard deviations."""

    def __init__(self, q_mu, q_sqrt):
        super().__init__()
        self.q_mu = q_mu  # [M, L]
        self.q_sqrt = q_sqrt  # [M, L]
class MvnNormal(Module):
    """q(u) with full covariance; q_sqrt holds lower-triangular Cholesky factors."""

    def __init__(self, q_mu, q_sqrt):
        super().__init__()
        self.q_mu = q_mu  # [M, L]
        self.q_sqrt = q_sqrt  # [L, M, M], lower-triangular
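# Shape conventions used throughout this module:
#   M = number of inducing points, L = number of latent GPs,
#   N = number of test inputs, K (or P) = number of outputs,
#   (R) = optional leading dimension of Kuu for multi-output kernels.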
class BasePosterior(Module, ABC):
    def __init__(self, kernel, inducing_variable, q_mu, q_sqrt, whiten=True, mean_function=None):
        super().__init__()
        self.inducing_variable = inducing_variable
        self.kernel = kernel
        self.mean_function = mean_function
        self.whiten = whiten
        if len(q_sqrt.shape) == 2:  # q_diag: [M, L] standard deviations
            self.q_dist = DiagNormal(q_mu, q_sqrt)
        else:  # [L, M, M] lower-triangular Cholesky factors
            self.q_dist = MvnNormal(q_mu, q_sqrt)
        self.update_cache()  # populates or updates self.alpha and self.Qinv
    def _precompute(self) -> Tuple[tf.Tensor, tf.Tensor]:
        Kuu = covariances.Kuu(
            self.inducing_variable, self.kernel, jitter=default_jitter()
        )  # [(R), M, M]
        q_mu = self.q_dist.q_mu
        if Kuu.shape.ndims == 4:
            # fully correlated multi-output case: flatten [M, L, M, L] to [M*L, M*L]
            ML = tf.reduce_prod(tf.shape(Kuu)[:2])
            Kuu = tf.reshape(Kuu, [ML, ML])
        if Kuu.shape.ndims == 3:
            # independent multi-output case: move the latent dimension to the front
            q_mu = tf.linalg.adjoint(self.q_dist.q_mu)[..., None]  # [..., R, M, 1]
L = tf.linalg.cholesky(Kuu)
if not self.whiten:
# alpha = Kuu⁻¹ q_mu
alpha = tf.linalg.cholesky_solve(L, q_mu)
else:
# alpha = L⁻ᵀ q_mu
alpha = tf.linalg.triangular_solve(L, q_mu, adjoint=True)
# predictive mean = Kfu alpha
# predictive variance = Kff - Kfu Qinv Kuf
# S = q_sqrt q_sqrtᵀ
if not self.whiten:
# Qinv = Kuu⁻¹ - Kuu⁻¹ S Kuu⁻¹
# = Kuu⁻¹ - L⁻ᵀ L⁻¹ S L⁻ᵀ L⁻¹
# = L⁻ᵀ (I - L⁻¹ S L⁻ᵀ) L⁻¹
# = L⁻ᵀ B L⁻¹
if isinstance(self.q_dist, DiagNormal):
q_sqrt = tf.linalg.diag(tf.linalg.adjoint(self.q_dist.q_sqrt))
else:
q_sqrt = self.q_dist.q_sqrt
Linv_qsqrt = tf.linalg.triangular_solve(L, q_sqrt)
Linv_cov_u_LinvT = tf.matmul(Linv_qsqrt, Linv_qsqrt, transpose_b=True)
else:
if isinstance(self.q_dist, DiagNormal):
Linv_cov_u_LinvT = tf.linalg.diag(tf.linalg.adjoint(self.q_dist.q_sqrt ** 2))
else:
q_sqrt = self.q_dist.q_sqrt
Linv_cov_u_LinvT = tf.matmul(q_sqrt, q_sqrt, transpose_b=True)
        # In both branches, Linv_cov_u_LinvT now equals L⁻¹ cov(u) L⁻ᵀ, so
        # Qinv = Kuu⁻¹ - Kuu⁻¹ cov(u) Kuu⁻¹
        #      = L⁻ᵀ (I - L⁻¹ cov(u) L⁻ᵀ) L⁻¹
        #      = L⁻ᵀ B L⁻¹
I = tf.eye(tf.shape(Linv_cov_u_LinvT)[-1], dtype=Linv_cov_u_LinvT.dtype)
B = I - Linv_cov_u_LinvT
LinvT_B = tf.linalg.triangular_solve(L, B, adjoint=True)
B_Linv = tf.linalg.adjoint(LinvT_B)
Qinv = tf.linalg.triangular_solve(L, B_Linv, adjoint=True)
return alpha, Qinv
    def update_cache(self):
        """Recompute alpha and Qinv and store them as plain tensors."""
        self.alpha, self.Qinv = self._precompute()
    def freeze(self):
        """Store alpha and Qinv as non-trainable Parameters (constant snapshot)."""
        alpha, Qinv = self._precompute()
        self.alpha = Parameter(alpha, trainable=False)
        self.Qinv = Parameter(Qinv, trainable=False)
    def update_cache_with_variables(self):
        """Recompute alpha and Qinv, assigning in place where possible so that
        tf.function-compiled code capturing them does not need to be re-traced."""
        alpha, Qinv = self._precompute()
        if isinstance(self.alpha, Parameter) and isinstance(self.Qinv, Parameter):
            self.alpha.assign(alpha)
            self.Qinv.assign(Qinv)
        else:
            self.alpha = Parameter(alpha, trainable=False)
            self.Qinv = Parameter(Qinv, trainable=False)
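    # Typical cache-refresh pattern (a sketch; `posterior` is any BasePosterior
    # subclass instance built from a trained model):
    #
    #   posterior.update_cache_with_variables()  # refresh alpha / Qinv in place
    #   mean, var = posterior.predict_f(Xnew)    # uses the refreshed cache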
@abstractmethod
def predict_f(
self, Xnew, full_cov: bool = False, full_output_cov: bool = False
) -> MeanAndVariance:
raise NotImplementedError()
class IndependentPosterior(BasePosterior):
def _post_process_mean_and_cov(self, mean, cov, full_cov, full_output_cov):
return mean, expand_independent_outputs(cov, full_cov, full_output_cov)
def predict_f(
self, Xnew, full_cov: bool = False, full_output_cov: bool = False
) -> MeanAndVariance:
        # Qinv: [L, M, M] (independent multi-output) or [M, M] (single-output)
        # alpha: [L, M, 1] or [M, L], respectively
Kuf, Knn = _get_kernels(
Xnew, self.inducing_variable, self.kernel, full_cov, full_output_cov
)
mean = tf.matmul(Kuf, self.alpha, transpose_a=True)
if Kuf.shape.ndims == 3:
mean = tf.linalg.adjoint(tf.squeeze(mean, axis=-1))
if full_cov:
Kfu_Qinv_Kuf = tf.matmul(Kuf, self.Qinv @ Kuf, transpose_a=True)
cov = Knn - Kfu_Qinv_Kuf
else:
            # compute only diag(Kfu Qinv Kuf), without forming the full matrix:
            # [Aᵀ B]_ii = Σ_k A_ki B_ki
            # TODO check whether einsum is faster now?
            Kfu_Qinv_Kuf = tf.reduce_sum(Kuf * tf.matmul(self.Qinv, Kuf), axis=-2)
cov = Knn - Kfu_Qinv_Kuf
cov = tf.linalg.adjoint(cov)
mean, cov = self._post_process_mean_and_cov(mean, cov, full_cov, full_output_cov)
return mean + self.mean_function(Xnew), cov
class LinearCoregionalizationPosterior(IndependentPosterior):
def _post_process_mean_and_cov(self, mean, cov, full_cov, full_output_cov):
cov = expand_independent_outputs(cov, full_cov, full_output_cov=False)
mean, cov = mix_latent_gp(self.kernel.W, mean, cov, full_cov, full_output_cov)
return mean, cov
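# LinearCoregionalizationPosterior mixes the L independent latent GPs into the
# outputs via the coregionalization matrix kernel.W (f = W g), which is what
# mix_latent_gp applies to both the mean and the covariance.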
class FullyCorrelatedPosterior(BasePosterior):
def predict_f(
self, Xnew, full_cov: bool = False, full_output_cov: bool = False
) -> MeanAndVariance:
# TODO: this assumes that Xnew has shape [N, D] and no leading dims
        # the fully correlated case flattens Kuu in _precompute, so here
        # Qinv: [M*L, M*L]
        # alpha: [M*L, 1]
Kuf = covariances.Kuf(self.inducing_variable, self.kernel, Xnew)
assert Kuf.shape.ndims == 4
M, L, N, K = tf.unstack(tf.shape(Kuf), num=Kuf.shape.ndims, axis=0)
Kuf = tf.reshape(Kuf, (M * L, N * K))
Knn = self.kernel(Xnew, full_cov=full_cov, full_output_cov=full_output_cov)
# full_cov=True and full_output_cov=True: [N, P, N, P]
# full_cov=True and full_output_cov=False: [P, N, N]
# full_cov=False and full_output_cov=True: [N, P, P]
# full_cov=False and full_output_cov=False: [N, P]
if full_cov == full_output_cov:
new_shape = (N * K, N * K) if full_cov else (N * K,)
Knn = tf.reshape(Knn, new_shape)
        # N and K are already known from the unstack above, and Kuf is now rank 2
        mean = tf.matmul(Kuf, self.alpha, transpose_a=True)  # [N*K, 1]
if not full_cov and not full_output_cov:
# fully diagonal case in both inputs and outputs
            # compute only diag(Kfu Qinv Kuf), without forming the full matrix:
            # [Aᵀ B]_ii = Σ_k A_ki B_ki
            # TODO check whether einsum is faster now?
            Kfu_Qinv_Kuf = tf.reduce_sum(Kuf * tf.matmul(self.Qinv, Kuf), axis=-2)
else:
Kfu_Qinv_Kuf = tf.matmul(Kuf, self.Qinv @ Kuf, transpose_a=True)
if not (full_cov and full_output_cov):
# diagonal in either inputs or outputs
new_shape = tf.concat([tf.shape(Kfu_Qinv_Kuf)[:-2], (N, K, N, K)], axis=0)
Kfu_Qinv_Kuf = tf.reshape(Kfu_Qinv_Kuf, new_shape)
if full_cov:
# diagonal in outputs: move outputs to end
tmp = tf.linalg.diag_part(tf.einsum("...ijkl->...ikjl", Kfu_Qinv_Kuf))
elif full_output_cov:
# diagonal in inputs: move inputs to end
tmp = tf.linalg.diag_part(tf.einsum("...ijkl->...jlik", Kfu_Qinv_Kuf))
Kfu_Qinv_Kuf = tf.einsum("...ijk->...kij", tmp) # move diagonal dim to [-3]
        cov = Knn - Kfu_Qinv_Kuf
        # in the fully diagonal case cov has rank 1 ([N*K]) and is reshaped
        # below, so no transpose is needed here
mean = tf.reshape(mean, (N, K))
if full_cov == full_output_cov:
cov_shape = (N, K, N, K) if full_cov else (N, K)
else:
cov_shape = (K, N, N) if full_cov else (N, K, K)
cov = tf.reshape(cov, cov_shape)
return mean + self.mean_function(Xnew), cov
def _get_kernels(Xnew, inducing_variable, kernel, full_cov, full_output_cov):
# TODO: this assumes that Xnew has shape [N, D] and no leading dims
Kuf = covariances.Kuf(inducing_variable, kernel, Xnew) # [(R), M, N]
if isinstance(kernel, (kernels.SeparateIndependent, kernels.IndependentLatent)):
# NOTE calling kernel(Xnew, full_cov=full_cov, full_output_cov=False) directly would return
# if full_cov: [P, N, N] -- this is what we want
# else: [N, P] instead of [P, N] as we get from the explicit stack below
Knn = tf.stack([k(Xnew, full_cov=full_cov) for k in kernel.kernels], axis=0)
elif isinstance(kernel, kernels.MultioutputKernel):
# effectively, SharedIndependent path
Knn = kernel.kernel(Xnew, full_cov=full_cov)
# NOTE calling kernel(Xnew, full_cov=full_cov, full_output_cov=False) directly would return
# if full_cov: [P, N, N] instead of [N, N]
# else: [N, P] instead of [N]
else:
# standard ("single-output") kernels
Knn = kernel(Xnew, full_cov=full_cov) # [N, N] if full_cov else [N]
return Kuf, Knn
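# get_posterior_class dispatches on the types of (kernel, inducing_variable);
# the most specific registered combination below picks the posterior class.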
get_posterior_class = Dispatcher("get_posterior_class")
@get_posterior_class.register(kernels.Kernel, InducingVariables)
def _get_posterior_base_case(kernel, inducing_variable):
# independent single output
return IndependentPosterior
@get_posterior_class.register(kernels.MultioutputKernel, InducingPoints)
def _get_posterior_fully_correlated_mo(kernel, inducing_variable):
return FullyCorrelatedPosterior
@get_posterior_class.register(
(kernels.SharedIndependent, kernels.SeparateIndependent),
(SeparateIndependentInducingVariables, SharedIndependentInducingVariables),
)
def _get_posterior_independent_mo(kernel, inducing_variable):
# independent multi-output
return IndependentPosterior
@get_posterior_class.register(
kernels.IndependentLatent,
(FallbackSeparateIndependentInducingVariables, FallbackSharedIndependentInducingVariables),
)
def _get_posterior_independentlatent_mo_fallback(kernel, inducing_variable):
    return FullyCorrelatedPosterior  # XXX fallback: correct but not the most efficient choice
@get_posterior_class.register(
kernels.LinearCoregionalization,
(SeparateIndependentInducingVariables, SharedIndependentInducingVariables),
)
def _get_posterior_linearcoregionalization_mo_efficient(kernel, inducing_variable):
# Linear mixing---efficient multi-output
return LinearCoregionalizationPosterior
def create_posterior(kernel, inducing_variable, q_mu, q_sqrt, whiten=True, mean_function=None):
    posterior_class = get_posterior_class(kernel, inducing_variable)
    return posterior_class(
        kernel, inducing_variable, q_mu, q_sqrt, whiten=whiten, mean_function=mean_function
    )
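# Example usage (a minimal sketch, not part of this module; assumes a trained
# gpflow.models.SVGP named `model` and test inputs `Xnew`):
#
#   posterior = create_posterior(
#       model.kernel,
#       model.inducing_variable,
#       model.q_mu,
#       model.q_sqrt,
#       whiten=model.whiten,
#       mean_function=model.mean_function,
#   )
#   mean, var = posterior.predict_f(Xnew)  # cheap: reuses cached alpha and Qinv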