Revision b4df039c6fe478297e532720e76d1213022410d5 authored by Jesper Nielsen on 26 October 2022, 08:27:38 UTC, committed by GitHub on 26 October 2022, 08:27:38 UTC
Fix mypy error. (#2009)
gpflow/posteriors.py
# Copyright 2016-2020 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple, Type, Union, cast

import tensorflow as tf
from check_shapes import (
    ErrorContext,
    Shape,
    check_shapes,
    get_shape,
    inherit_check_shapes,
    register_get_shape,
)

from . import covariances, kernels, mean_functions
from .base import MeanAndVariance, Module, RegressionData, TensorType
from .conditionals.util import (
    base_conditional,
    base_conditional_with_lm,
    expand_independent_outputs,
    fully_correlated_conditional,
    independent_interdomain_conditional,
    mix_latent_gp,
    separate_independent_conditional_implementation,
)
from .config import default_float, default_jitter
from .covariances import Kuf, Kuu
from .inducing_variables import (
    FallbackSeparateIndependentInducingVariables,
    FallbackSharedIndependentInducingVariables,
    InducingPoints,
    InducingVariables,
    SeparateIndependentInducingVariables,
    SharedIndependentInducingVariables,
)
from .kernels import Kernel
from .likelihoods import Gaussian
from .mean_functions import MeanFunction
from .utilities import Dispatcher, add_likelihood_noise_cov, assert_params_false
from .utilities.ops import eye


class _QDistribution(Module):
    """
    Base class for our parametrization of q(u) in the `AbstractPosterior`.
    Internal - do not rely on this outside of GPflow.
    """


class _DeltaDist(_QDistribution):
    @check_shapes(
        "q_mu: [M, L]",
    )
    def __init__(self, q_mu: TensorType) -> None:
        self.q_mu = q_mu

    @property
    def q_sqrt(self) -> Optional[tf.Tensor]:
        return None


class _DiagNormal(_QDistribution):
    @check_shapes(
        "q_mu: [M, L]",
        "q_sqrt: [M, L]",
    )
    def __init__(self, q_mu: TensorType, q_sqrt: TensorType) -> None:
        self.q_mu = q_mu
        self.q_sqrt = q_sqrt


class _MvNormal(_QDistribution):
    @check_shapes(
        "q_mu: [M, L]",
        "q_sqrt: [L, M, M]  # lower-triangular",
    )
    def __init__(self, q_mu: TensorType, q_sqrt: TensorType) -> None:
        self.q_mu = q_mu
        self.q_sqrt = q_sqrt


class PrecomputeCacheType(enum.Enum):
    """
    - `PrecomputeCacheType.TENSOR` (or `"tensor"`): Precomputes the cached
      quantities and stores them as tensors (which allows differentiating
      through the prediction). This is the default.
    - `PrecomputeCacheType.VARIABLE` (or `"variable"`): Precomputes the cached
      quantities and stores them as variables, which allows for updating
      their values without changing the compute graph (relevant for AOT
      compilation).
    - `PrecomputeCacheType.NOCACHE` (or `"nocache"` or `None`): Avoids
      immediate cache computation. This is useful for avoiding extraneous
      computations when you only want to call the posterior's
      `fused_predict_f` method.
    """

    TENSOR = "tensor"
    VARIABLE = "variable"
    NOCACHE = "nocache"


@dataclass
class PrecomputedValue:
    value: tf.Tensor
    """
    The precomputed value itself.
    """

    axis_dynamic: Tuple[bool, ...]
    """
    A tuple with one element per dimension of `value`. That element is `True` if that dimension
    of `value` might change size.
    """

    def __post_init__(self) -> None:
        tf.debugging.assert_rank(
            self.value,
            len(self.axis_dynamic),
            "axis_dynamic must have one element per dimension of value.",
        )

    @staticmethod
    @check_shapes(
        "alpha: [M_L_or_L_M_M...]",
        "Qinv: [M_M_or_L_M_M...]",
    )
    def wrap_alpha_Qinv(alpha: TensorType, Qinv: TensorType) -> Tuple["PrecomputedValue", ...]:
        """
        Wraps `alpha` and `Qinv` in `PrecomputedValue`\ s.
        """
        one_dynamic = False
        L_dynamic = False
        M_dynamic = False  # TODO(jesper): Support variable number of inducing points?

        alpha_rank = tf.rank(alpha)
        if alpha_rank == 2:
            alpha_dynamic: Tuple[bool, ...] = (M_dynamic, L_dynamic)
        elif alpha_rank == 3:
            alpha_dynamic = (L_dynamic, M_dynamic, one_dynamic)
        else:
            raise AssertionError(f"Unknown rank of alpha {alpha_rank}.")

        Qinv_rank = tf.rank(Qinv)
        if Qinv_rank == 2:
            Qinv_dynamic: Tuple[bool, ...] = (M_dynamic, M_dynamic)
        elif Qinv_rank == 3:
            Qinv_dynamic = (L_dynamic, M_dynamic, M_dynamic)
        else:
            raise AssertionError(f"Unknown rank of Qinv {Qinv_rank}.")

        return (
            PrecomputedValue(alpha, alpha_dynamic),
            PrecomputedValue(Qinv, Qinv_dynamic),
        )
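

# A small hedged sketch (not executed on import; the helper name is illustrative
# only) of how `axis_dynamic` pairs with `value`: one boolean per dimension of
# `value`, True where that dimension may change size between cache updates.
def _example_precomputed_value() -> None:
    err = tf.zeros([4, 1])  # e.g. [N, D], with the number of data points N growing
    PrecomputedValue(err, (True, False))  # N dynamic, D static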


@register_get_shape(PrecomputedValue)
def get_precomputed_value_shape(shaped: PrecomputedValue, context: ErrorContext) -> Shape:
    return get_shape(shaped.value, context)


def _validate_precompute_cache_type(
    value: Union[None, PrecomputeCacheType, str]
) -> PrecomputeCacheType:
    if value is None:
        return PrecomputeCacheType.NOCACHE
    elif isinstance(value, PrecomputeCacheType):
        return value
    elif isinstance(value, str):
        return PrecomputeCacheType(value.lower())
    else:
        raise ValueError(
            f"{value} is not a valid PrecomputeCacheType."
            " Valid options: 'tensor', 'variable', 'nocache' (or None)."
        )
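

# A minimal hedged illustration (never called on import; the helper name is
# illustrative only): strings and None are normalised to `PrecomputeCacheType`
# members by `_validate_precompute_cache_type` above.
def _example_validate_precompute_cache_type() -> None:
    assert _validate_precompute_cache_type("tensor") is PrecomputeCacheType.TENSOR
    assert _validate_precompute_cache_type("VARIABLE") is PrecomputeCacheType.VARIABLE
    assert _validate_precompute_cache_type(None) is PrecomputeCacheType.NOCACHE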


class AbstractPosterior(Module, ABC):
    @check_shapes(
        "X_data: [N_D_or_M_D_P...]",
    )
    def __init__(
        self,
        kernel: Kernel,
        X_data: Union[tf.Tensor, InducingVariables],
        cache: Optional[Tuple[tf.Tensor, ...]] = None,
        mean_function: Optional[mean_functions.MeanFunction] = None,
    ) -> None:
        """
        Users should use `create_posterior` to create instances of concrete
        subclasses of this AbstractPosterior class instead of calling this
        constructor directly. For `create_posterior` to be able to correctly
        instantiate subclasses, developers need to ensure their subclasses
        don't change the constructor signature.
        """
        super().__init__()

        self.kernel = kernel
        self.X_data = X_data
        self.cache = cache
        self.mean_function = mean_function

        self._precompute_cache: Optional[PrecomputeCacheType] = None

    @check_shapes(
        "Xnew: [batch..., D]",
        "mean: [batch..., Q]",
        "return: [batch..., Q]",
    )
    def _add_mean_function(self, Xnew: TensorType, mean: TensorType) -> tf.Tensor:
        if self.mean_function is None:
            return mean
        else:
            return mean + self.mean_function(Xnew)

    @abstractmethod
    def _precompute(self) -> Tuple[PrecomputedValue, ...]:
        """
        Precompute a cache.

        The result of this method will later be passed to `_conditional_with_precompute` as the
        `cache` argument.
        """

    @check_shapes(
        "Xnew: [batch..., N, D]",
        "return[0]: [batch..., N, P]",
        "return[1]: [batch..., N, P, N, P] if full_cov and full_output_cov",
        "return[1]: [batch..., P, N, N] if full_cov and (not full_output_cov)",
        "return[1]: [batch..., N, P, P] if (not full_cov) and full_output_cov",
        "return[1]: [batch..., N, P] if (not full_cov) and (not full_output_cov)",
    )
    def fused_predict_f(
        self, Xnew: TensorType, full_cov: bool = False, full_output_cov: bool = False
    ) -> MeanAndVariance:
        """
        Computes predictive mean and (co)variance at Xnew, including mean_function.
        Does not make use of caching.
        """
        mean, cov = self._conditional_fused(
            Xnew, full_cov=full_cov, full_output_cov=full_output_cov
        )
        return self._add_mean_function(Xnew, mean), cov

    @abstractmethod
    @check_shapes(
        "Xnew: [batch..., N, D]",
        "return[0]: [batch..., N, P]",
        "return[1]: [batch..., N, P, N, P] if full_cov and full_output_cov",
        "return[1]: [batch..., P, N, N] if full_cov and (not full_output_cov)",
        "return[1]: [batch..., N, P, P] if (not full_cov) and full_output_cov",
        "return[1]: [batch..., N, P] if (not full_cov) and (not full_output_cov)",
    )
    def _conditional_fused(
        self, Xnew: TensorType, full_cov: bool = False, full_output_cov: bool = False
    ) -> MeanAndVariance:
        """
        Computes predictive mean and (co)variance at Xnew, *excluding* mean_function.
        Does not make use of caching.
        """

    @check_shapes(
        "Xnew: [batch..., N, D]",
        "return[0]: [batch..., N, P]",
        "return[1]: [batch..., N, P, N, P] if full_cov and full_output_cov",
        "return[1]: [batch..., P, N, N] if full_cov and (not full_output_cov)",
        "return[1]: [batch..., N, P, P] if (not full_cov) and full_output_cov",
        "return[1]: [batch..., N, P] if (not full_cov) and (not full_output_cov)",
    )
    def predict_f(
        self, Xnew: TensorType, full_cov: bool = False, full_output_cov: bool = False
    ) -> MeanAndVariance:
        """
        Computes predictive mean and (co)variance at Xnew, including mean_function.
        Relies on the cache precomputed by the `_precompute` method.
        """
        if self.cache is None:
            raise ValueError(
                "Cache has not been precomputed yet. Call update_cache first or use fused_predict_f"
            )
        mean, cov = self._conditional_with_precompute(
            self.cache, Xnew, full_cov=full_cov, full_output_cov=full_output_cov
        )
        return self._add_mean_function(Xnew, mean), cov

    @abstractmethod
    @check_shapes(
        "Xnew: [batch..., N, D]",
        "return[0]: [batch..., N, P]",
        "return[1]: [batch..., N, P, N, P] if full_cov and full_output_cov",
        "return[1]: [batch..., P, N, N] if full_cov and (not full_output_cov)",
        "return[1]: [batch..., N, P, P] if (not full_cov) and full_output_cov",
        "return[1]: [batch..., N, P] if (not full_cov) and (not full_output_cov)",
    )
    def _conditional_with_precompute(
        self,
        cache: Tuple[tf.Tensor, ...],
        Xnew: TensorType,
        full_cov: bool = False,
        full_output_cov: bool = False,
    ) -> MeanAndVariance:
        """
        Computes predictive mean and (co)variance at Xnew, *excluding* mean_function.
        Relies on the cache precomputed by `_precompute`.
        """

    def update_cache(self, precompute_cache: Optional[PrecomputeCacheType] = None) -> None:
        """
        Sets the cache to a `tf.Tensor` or `tf.Variable`, or clears it, depending
        on the value of `precompute_cache`. If `precompute_cache` is not given,
        the most recently used setting is reused.
        """
        if precompute_cache is None:
            if self._precompute_cache is None:
                raise ValueError(
                    "You must pass precompute_cache explicitly"
                    " (the cache had not been updated before)."
                )
            precompute_cache = self._precompute_cache
        else:
            self._precompute_cache = precompute_cache

        if precompute_cache is PrecomputeCacheType.NOCACHE:
            self.cache = None

        elif precompute_cache is PrecomputeCacheType.TENSOR:
            self.cache = tuple(c.value for c in self._precompute())

        elif precompute_cache is PrecomputeCacheType.VARIABLE:
            cache = self._precompute()

            if self.cache is not None and all(isinstance(c, tf.Variable) for c in self.cache):
                # re-use existing variables
                for cache_var, c in zip(self.cache, cache):
                    cache_var.assign(c.value)
            else:  # create variables
                shapes = [
                    [None if d else s for d, s in zip(c.axis_dynamic, tf.shape(c.value))]
                    for c in cache
                ]
                self.cache = tuple(
                    tf.Variable(c.value, trainable=False, shape=s) for c, s in zip(cache, shapes)
                )
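

# A hedged workflow sketch (not executed on import; `posterior` stands for any
# concrete subclass instance and the helper name is illustrative only):
# `fused_predict_f` recomputes everything per call, while `predict_f` reads the
# cache maintained by `update_cache`.
def _example_cache_workflow(posterior: "AbstractPosterior", Xnew: TensorType) -> None:
    # No cache required; all quantities are recomputed for this single call.
    posterior.fused_predict_f(Xnew)

    # Store the cache in tf.Variables (useful before compilation), then predict
    # from the cache.
    posterior.update_cache(PrecomputeCacheType.VARIABLE)
    posterior.predict_f(Xnew)

    # After updating model parameters, refresh the cached values in place; with no
    # argument the most recently used cache type is reused.
    posterior.update_cache()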


class GPRPosterior(AbstractPosterior):
    @check_shapes(
        "data[0]: [N, D]",
        "data[1]: [N, Q]",
    )
    def __init__(
        self,
        kernel: Kernel,
        data: RegressionData,
        likelihood: Gaussian,
        mean_function: MeanFunction,
        *,
        precompute_cache: Optional[PrecomputeCacheType],
    ) -> None:
        X, Y = data
        super().__init__(kernel, X, mean_function=mean_function)
        self.Y_data = Y
        self.likelihood = likelihood

        if precompute_cache is not None:
            self.update_cache(precompute_cache)

    @inherit_check_shapes
    def _conditional_with_precompute(
        self,
        cache: Tuple[tf.Tensor, ...],
        Xnew: TensorType,
        full_cov: bool = False,
        full_output_cov: bool = False,
    ) -> MeanAndVariance:
        """
        Computes predictive mean and (co)variance at Xnew, *excluding* mean_function.
        Relies on the cached `err` and `Lm` (see `_precompute`).
        """
        assert_params_false(self._conditional_with_precompute, full_output_cov=full_output_cov)
        err, Lm = cache

        Knn = self.kernel(Xnew, full_cov=full_cov)
        Kmn = self.kernel(self.X_data, Xnew)

        return base_conditional_with_lm(
            Kmn=Kmn,
            Lm=Lm,
            Knn=Knn,
            f=err,
            full_cov=full_cov,
            q_sqrt=None,
            white=False,
        )

    @check_shapes(
        "return[0]: [M, D]",
        "return[1]: [M, M]",
    )
    def _precompute(self) -> Tuple[PrecomputedValue, ...]:
        assert self.mean_function is not None
        X_data = cast(tf.Tensor, self.X_data)
        err = self.Y_data - self.mean_function(X_data)

        Kmm = self.kernel(X_data)
        Kmm_plus_s = add_likelihood_noise_cov(Kmm, self.likelihood, X_data)
        Lm = tf.linalg.cholesky(Kmm_plus_s)

        D = err.shape[1]
        M = X_data.shape[0]
        D_dynamic = D is None
        M_dynamic = M is None

        return (
            PrecomputedValue(err, (M_dynamic, D_dynamic)),
            PrecomputedValue(Lm, (M_dynamic, M_dynamic)),
        )

    @inherit_check_shapes
    def _conditional_fused(
        self, Xnew: TensorType, full_cov: bool = False, full_output_cov: bool = False
    ) -> MeanAndVariance:
        """
        Computes predictive mean and (co)variance at Xnew, *excluding* mean_function.
        Does not make use of caching.
        """
        temp_cache = tuple(c.value for c in self._precompute())
        return self._conditional_with_precompute(temp_cache, Xnew, full_cov, full_output_cov)


class SGPRPosterior(AbstractPosterior):
    """
    This class represents posteriors that can be derived from SGPR
    models to compute faster predictions on unseen points.
    """

    @check_shapes(
        "data[0]: [N, D]",
        "data[1]: [N, Q]",
        "inducing_variable: [M, D, 1]",
    )
    def __init__(
        self,
        kernel: Kernel,
        data: RegressionData,
        inducing_variable: InducingPoints,
        likelihood: Gaussian,
        num_latent_gps: int,
        mean_function: MeanFunction,
        *,
        precompute_cache: Optional[PrecomputeCacheType],
    ) -> None:
        X, Y = data
        super().__init__(kernel, X, mean_function=mean_function)
        self.Y_data = Y
        self.likelihood = likelihood
        self.inducing_variable = inducing_variable
        self.num_latent_gps = num_latent_gps

        if precompute_cache is not None:
            self.update_cache(precompute_cache)

    @inherit_check_shapes
    def _conditional_with_precompute(
        self,
        cache: Tuple[tf.Tensor, ...],
        Xnew: TensorType,
        full_cov: bool = False,
        full_output_cov: bool = False,
    ) -> MeanAndVariance:
        """
        Computes predictive mean and (co)variance at Xnew, *excluding* mean_function.
        Relies on the cached `L`, `LB` and `c` (see `_precompute`).
        """
        assert_params_false(self._conditional_with_precompute, full_output_cov=full_output_cov)

        L, LB, c = cache

        Kus = Kuf(self.inducing_variable, self.kernel, Xnew)
        tmp1 = tf.linalg.triangular_solve(L, Kus, lower=True)
        tmp2 = tf.linalg.triangular_solve(LB, tmp1, lower=True)
        mean = tf.linalg.matmul(tmp2, c, transpose_a=True)
        if full_cov:
            var = (
                self.kernel(Xnew)
                + tf.linalg.matmul(tmp2, tmp2, transpose_a=True)
                - tf.linalg.matmul(tmp1, tmp1, transpose_a=True)
            )
            var = tf.tile(var[None, ...], [self.num_latent_gps, 1, 1])  # [P, N, N]
        else:
            var = (
                self.kernel(Xnew, full_cov=False)
                + tf.reduce_sum(tf.square(tmp2), 0)
                - tf.reduce_sum(tf.square(tmp1), 0)
            )
            var = tf.tile(var[:, None], [1, self.num_latent_gps])

        return mean, var

    @check_shapes(
        "return[0]: [M, M]",
        "return[1]: [M, M]",
        "return[2]: [M, D]",
    )
    def _precompute(self) -> Tuple[PrecomputedValue, ...]:
        assert self.mean_function is not None

        X_data = cast(tf.Tensor, self.X_data)
        num_inducing = self.inducing_variable.num_inducing
        err = self.Y_data - self.mean_function(X_data)

        kuf = Kuf(self.inducing_variable, self.kernel, X_data)
        kuu = Kuu(self.inducing_variable, self.kernel, jitter=default_jitter())

        sigma_sq = tf.squeeze(self.likelihood.variance_at(X_data), axis=-1)
        sigma = tf.sqrt(sigma_sq)

        L = tf.linalg.cholesky(kuu)  # cache alpha, qinv
        A = tf.linalg.triangular_solve(L, kuf / sigma, lower=True)
        B = tf.linalg.matmul(A, A, transpose_b=True) + tf.eye(
            num_inducing, dtype=default_float()
        )  # cache qinv
        LB = tf.linalg.cholesky(B)  # cache alpha
        Aerr = tf.linalg.matmul(A, err / sigma[..., None])
        c = tf.linalg.triangular_solve(LB, Aerr, lower=True)

        D = err.shape[1]
        M = X_data.shape[0]
        D_dynamic = D is None
        M_dynamic = M is None

        return (
            PrecomputedValue(L, (M_dynamic, M_dynamic)),
            PrecomputedValue(LB, (M_dynamic, M_dynamic)),
            PrecomputedValue(c, (M_dynamic, D_dynamic)),
        )

    @inherit_check_shapes
    def _conditional_fused(
        self, Xnew: TensorType, full_cov: bool = False, full_output_cov: bool = False
    ) -> MeanAndVariance:
        """
        Compute the mean and variance of the latent function at some new points
        Xnew. Does not make use of caching.
        """
        temp_cache = tuple(c.value for c in self._precompute())
        return self._conditional_with_precompute(temp_cache, Xnew, full_cov, full_output_cov)


class VGPPosterior(AbstractPosterior):
    @check_shapes(
        "X: [N, D]",
        "q_mu: [N, P]",
        "q_sqrt: [N_P_or_P_N_N...]",
    )
    def __init__(
        self,
        kernel: Kernel,
        X: tf.Tensor,
        q_mu: tf.Tensor,
        q_sqrt: tf.Tensor,
        mean_function: Optional[mean_functions.MeanFunction] = None,
        white: bool = True,
        *,
        precompute_cache: Optional[PrecomputeCacheType],
    ) -> None:
        super().__init__(kernel, X, mean_function=mean_function)
        self.q_mu = q_mu
        self.q_sqrt = q_sqrt
        self.white = white

        if precompute_cache is not None:
            self.update_cache(precompute_cache)

    @inherit_check_shapes
    def _conditional_with_precompute(
        self,
        cache: Tuple[tf.Tensor, ...],
        Xnew: TensorType,
        full_cov: bool = False,
        full_output_cov: bool = False,
    ) -> MeanAndVariance:
        assert_params_false(self._conditional_with_precompute, full_output_cov=full_output_cov)

        (Lm,) = cache
        Kmn = self.kernel(self.X_data, Xnew)  # [M, ..., N]
        Knn = self.kernel(
            Xnew, full_cov=full_cov
        )  # [..., N] (full_cov = False) or [..., N, N] (True)

        return base_conditional_with_lm(
            Kmn=Kmn,
            Lm=Lm,
            Knn=Knn,
            f=self.q_mu,
            full_cov=full_cov,
            q_sqrt=self.q_sqrt,
            white=self.white,
        )

    @check_shapes(
        "return[0]: [M, M]",
    )
    def _precompute(self) -> Tuple[PrecomputedValue, ...]:
        X_data = cast(tf.Tensor, self.X_data)
        Kmm = self.kernel(X_data) + eye(
            tf.shape(X_data)[-2], value=default_jitter(), dtype=X_data.dtype
        )  # [..., M, M]
        Lm = tf.linalg.cholesky(Kmm)

        M = X_data.shape[0]
        M_dynamic = M is None

        return (PrecomputedValue(Lm, (M_dynamic, M_dynamic)),)

    @inherit_check_shapes
    def _conditional_fused(
        self, Xnew: TensorType, full_cov: bool = False, full_output_cov: bool = False
    ) -> MeanAndVariance:
        temp_cache = tuple(c.value for c in self._precompute())
        return self._conditional_with_precompute(temp_cache, Xnew, full_cov, full_output_cov)


class BasePosterior(AbstractPosterior):
    @check_shapes(
        "inducing_variable: [M, D, broadcast P]",
        "q_mu: [N, P]",
        "q_sqrt: [N_P_or_P_N_N...]",
    )
    def __init__(
        self,
        kernel: Kernel,
        inducing_variable: InducingVariables,
        q_mu: tf.Tensor,
        q_sqrt: tf.Tensor,
        whiten: bool = True,
        mean_function: Optional[mean_functions.MeanFunction] = None,
        *,
        precompute_cache: Optional[PrecomputeCacheType],
    ):

        super().__init__(kernel, inducing_variable, mean_function=mean_function)
        self.whiten = whiten
        self._set_qdist(q_mu, q_sqrt)

        if precompute_cache is not None:
            self.update_cache(precompute_cache)

    @property  # type: ignore[misc]
    @check_shapes(
        "return: [N, P]",
    )
    def q_mu(self) -> tf.Tensor:
        return self._q_dist.q_mu

    @property  # type: ignore[misc]
    @check_shapes(
        "return: [N_P_or_P_N_N...]",
    )
    def q_sqrt(self) -> tf.Tensor:
        return self._q_dist.q_sqrt

    @check_shapes(
        "q_mu: [N, P]",
        "q_sqrt: [N_P_or_P_N_N...]",
    )
    def _set_qdist(self, q_mu: TensorType, q_sqrt: TensorType) -> None:
        if q_sqrt is None:
            self._q_dist = _DeltaDist(q_mu)
        elif len(q_sqrt.shape) == 2:  # q_diag
            self._q_dist = _DiagNormal(q_mu, q_sqrt)
        else:
            self._q_dist = _MvNormal(q_mu, q_sqrt)

    @check_shapes(
        "return[0]: [M_L_or_L_M_M...]",
        "return[1]: [L, M, M]",
    )
    def _precompute(self) -> Tuple[PrecomputedValue, ...]:
        Kuu = covariances.Kuu(self.X_data, self.kernel, jitter=default_jitter())  # [(R), M, M]
        q_mu = self._q_dist.q_mu

        if Kuu.shape.ndims == 4:
            ML = tf.reduce_prod(tf.shape(Kuu)[:2])
            Kuu = tf.reshape(Kuu, [ML, ML])
        if Kuu.shape.ndims == 3:
            q_mu = tf.linalg.adjoint(self._q_dist.q_mu)[..., None]  # [..., R, M, 1]
        L = tf.linalg.cholesky(Kuu)

        if not self.whiten:
            # alpha = Kuu⁻¹ q_mu
            alpha = tf.linalg.cholesky_solve(L, q_mu)
        else:
            # alpha = L⁻ᵀ q_mu
            alpha = tf.linalg.triangular_solve(L, q_mu, adjoint=True)
        # predictive mean = Kfu alpha
        # predictive variance = Kff - Kfu Qinv Kuf
        # S = q_sqrt q_sqrtᵀ
        I = tf.eye(tf.shape(L)[-1], dtype=L.dtype)
        if isinstance(self._q_dist, _DeltaDist):
            B = I
        else:
            if not self.whiten:
                # Qinv = Kuu⁻¹ - Kuu⁻¹ S Kuu⁻¹
                #      = Kuu⁻¹ - L⁻ᵀ L⁻¹ S L⁻ᵀ L⁻¹
                #      = L⁻ᵀ (I - L⁻¹ S L⁻ᵀ) L⁻¹
                #      = L⁻ᵀ B L⁻¹
                if isinstance(self._q_dist, _DiagNormal):
                    q_sqrt = tf.linalg.diag(tf.linalg.adjoint(self._q_dist.q_sqrt))
                elif isinstance(self._q_dist, _MvNormal):
                    q_sqrt = self._q_dist.q_sqrt
                Linv_qsqrt = tf.linalg.triangular_solve(L, q_sqrt)
                Linv_cov_u_LinvT = tf.matmul(Linv_qsqrt, Linv_qsqrt, transpose_b=True)
            else:
                if isinstance(self._q_dist, _DiagNormal):
                    Linv_cov_u_LinvT = tf.linalg.diag(tf.linalg.adjoint(self._q_dist.q_sqrt ** 2))
                elif isinstance(self._q_dist, _MvNormal):
                    q_sqrt = self._q_dist.q_sqrt
                    Linv_cov_u_LinvT = tf.matmul(q_sqrt, q_sqrt, transpose_b=True)
                # Qinv = Kuu⁻¹ - L⁻ᵀ S L⁻¹
                # Linv = (L⁻¹ I) = solve(L, I)
                # Kinv = Linvᵀ @ Linv
            B = I - Linv_cov_u_LinvT
        LinvT_B = tf.linalg.triangular_solve(L, B, adjoint=True)
        B_Linv = tf.linalg.adjoint(LinvT_B)
        Qinv = tf.linalg.triangular_solve(L, B_Linv, adjoint=True)

        M, L = tf.unstack(tf.shape(self._q_dist.q_mu), num=2)
        Qinv = tf.broadcast_to(Qinv, [L, M, M])

        return PrecomputedValue.wrap_alpha_Qinv(alpha, Qinv)
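

# A hedged numerical sanity check of the algebra in the comments above (not executed
# on import; the helper name is illustrative only), for the non-whitened case:
# Qinv = Kuu⁻¹ - Kuu⁻¹ S Kuu⁻¹ must equal L⁻ᵀ (I - L⁻¹ S L⁻ᵀ) L⁻¹ with Kuu = L Lᵀ
# and S = q_sqrt q_sqrtᵀ.
def _example_check_qinv_identity() -> None:
    M = 3
    A = tf.random.normal([M, M], dtype=tf.float64)
    Kuu_ = tf.matmul(A, A, transpose_b=True) + tf.eye(M, dtype=tf.float64)
    L_ = tf.linalg.cholesky(Kuu_)
    q_sqrt = tf.linalg.band_part(tf.random.normal([M, M], dtype=tf.float64), -1, 0)
    S = tf.matmul(q_sqrt, q_sqrt, transpose_b=True)

    # Direct form: Kuu⁻¹ - Kuu⁻¹ S Kuu⁻¹.
    Kuu_inv = tf.linalg.inv(Kuu_)
    direct = Kuu_inv - Kuu_inv @ S @ Kuu_inv

    # Cholesky-based form, mirroring the triangular solves used in `_precompute`.
    identity = tf.eye(M, dtype=tf.float64)
    Linv_qsqrt = tf.linalg.triangular_solve(L_, q_sqrt)
    B = identity - tf.matmul(Linv_qsqrt, Linv_qsqrt, transpose_b=True)
    LinvT_B = tf.linalg.triangular_solve(L_, B, adjoint=True)
    Qinv = tf.linalg.triangular_solve(L_, tf.linalg.adjoint(LinvT_B), adjoint=True)

    tf.debugging.assert_near(direct, Qinv)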


class IndependentPosterior(BasePosterior):
    @check_shapes(
        "mean: [batch..., N, P]",
        "cov: [batch..., P, N, N] if full_cov",
        "cov: [batch..., N, P] if not full_cov",
        "return[0]: [batch..., N, P]",
        "return[1]: [batch..., N, P, N, P] if full_cov and full_output_cov",
        "return[1]: [batch..., N, P, P] if (not full_cov) and full_output_cov",
        "return[1]: [batch..., P, N, N] if full_cov and (not full_output_cov)",
        "return[1]: [batch..., N, P] if (not full_cov) and (not full_output_cov)",
    )
    def _post_process_mean_and_cov(
        self, mean: TensorType, cov: TensorType, full_cov: bool, full_output_cov: bool
    ) -> MeanAndVariance:
        return mean, expand_independent_outputs(cov, full_cov, full_output_cov)

    @check_shapes(
        "Xnew: [N, D]",
        "return: [broadcast P, N, N] if full_cov",
        "return: [broadcast P, N] if (not full_cov)",
    )
    def _get_Kff(self, Xnew: TensorType, full_cov: bool) -> tf.Tensor:

        # TODO: this assumes that Xnew has shape [N, D] and no leading dims

        if isinstance(self.kernel, (kernels.SeparateIndependent, kernels.IndependentLatent)):
            # NOTE calling kernel(Xnew, full_cov=full_cov, full_output_cov=False) directly would
            # return
            # if full_cov: [P, N, N] -- this is what we want
            # else: [N, P] instead of [P, N] as we get from the explicit stack below
            Kff = tf.stack([k(Xnew, full_cov=full_cov) for k in self.kernel.kernels], axis=0)
        elif isinstance(self.kernel, kernels.MultioutputKernel):
            # effectively, SharedIndependent path
            Kff = self.kernel.kernel(Xnew, full_cov=full_cov)
            # NOTE calling kernel(Xnew, full_cov=full_cov, full_output_cov=False) directly would
            # return
            # if full_cov: [P, N, N] instead of [N, N]
            # else: [N, P] instead of [N]
        else:
            # standard ("single-output") kernels
            Kff = self.kernel(Xnew, full_cov=full_cov)  # [N, N] if full_cov else [N]

        return Kff

    @inherit_check_shapes
    def _conditional_with_precompute(
        self,
        cache: Tuple[tf.Tensor, ...],
        Xnew: TensorType,
        full_cov: bool = False,
        full_output_cov: bool = False,
    ) -> MeanAndVariance:
        # Qinv: [L, M, M]
        # alpha: [M, L]
        alpha, Qinv = cache

        Kuf = covariances.Kuf(self.X_data, self.kernel, Xnew)  # [(R), M, N]
        Kff = self._get_Kff(Xnew, full_cov)

        mean = tf.matmul(Kuf, alpha, transpose_a=True)
        if Kuf.shape.ndims == 3:
            mean = tf.linalg.adjoint(tf.squeeze(mean, axis=-1))

        if full_cov:
            Kfu_Qinv_Kuf = tf.matmul(Kuf, Qinv @ Kuf, transpose_a=True)
            cov = Kff - Kfu_Qinv_Kuf
        else:
            # [Aᵀ B]_ij = Aᵀ_ik B_kj = A_ki B_kj
            # TODO check whether einsum is faster now?
            Kfu_Qinv_Kuf = tf.reduce_sum(Kuf * tf.matmul(Qinv, Kuf), axis=-2)
            cov = Kff - Kfu_Qinv_Kuf
            cov = tf.linalg.adjoint(cov)

        return self._post_process_mean_and_cov(mean, cov, full_cov, full_output_cov)
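

# A hedged note on the einsum TODO above (not executed on import; the helper name is
# illustrative only): the diagonal-only quadratic form computed as
# tf.reduce_sum(Kuf * (Qinv @ Kuf), axis=-2) has an equivalent einsum spelling;
# which one is faster depends on the TensorFlow version and hardware.
def _example_diag_quadratic_form_einsum() -> None:
    Qinv = tf.random.normal([2, 4, 4], dtype=tf.float64)  # [L, M, M]
    Kuf_ = tf.random.normal([2, 4, 3], dtype=tf.float64)  # [L, M, N]
    a = tf.reduce_sum(Kuf_ * tf.matmul(Qinv, Kuf_), axis=-2)
    b = tf.einsum("...mn,...mk,...kn->...n", Kuf_, Qinv, Kuf_)
    tf.debugging.assert_near(a, b)  # both give diag(Kufᵀ Qinv Kuf), shape [L, N]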


class IndependentPosteriorSingleOutput(IndependentPosterior):
    # could almost be the same as IndependentPosteriorMultiOutput ...
    @inherit_check_shapes
    def _conditional_fused(
        self, Xnew: TensorType, full_cov: bool = False, full_output_cov: bool = False
    ) -> MeanAndVariance:
        # same as IndependentPosteriorMultiOutput, Shared~/Shared~ branch, except for following
        # line:
        Knn = self.kernel(Xnew, full_cov=full_cov)

        Kmm = covariances.Kuu(self.X_data, self.kernel, jitter=default_jitter())  # [M, M]
        Kmn = covariances.Kuf(self.X_data, self.kernel, Xnew)  # [M, N]

        fmean, fvar = base_conditional(
            Kmn, Kmm, Knn, self.q_mu, full_cov=full_cov, q_sqrt=self.q_sqrt, white=self.whiten
        )  # [N, P],  [P, N, N] or [N, P]
        return self._post_process_mean_and_cov(fmean, fvar, full_cov, full_output_cov)


class IndependentPosteriorMultiOutput(IndependentPosterior):
    @inherit_check_shapes
    def _conditional_fused(
        self, Xnew: TensorType, full_cov: bool = False, full_output_cov: bool = False
    ) -> MeanAndVariance:
        if isinstance(self.X_data, SharedIndependentInducingVariables) and isinstance(
            self.kernel, kernels.SharedIndependent
        ):
            # same as IndependentPosteriorSingleOutput except for following line
            Knn = self.kernel.kernel(Xnew, full_cov=full_cov)
            # we don't call self.kernel() directly as that would do unnecessary tiling

            Kmm = covariances.Kuu(self.X_data, self.kernel, jitter=default_jitter())  # [M, M]
            Kmn = covariances.Kuf(self.X_data, self.kernel, Xnew)  # [M, N]

            fmean, fvar = base_conditional(
                Kmn, Kmm, Knn, self.q_mu, full_cov=full_cov, q_sqrt=self.q_sqrt, white=self.whiten
            )  # [N, P],  [P, N, N] or [N, P]
        else:
            # this is the messy thing with tf.map_fn, cleaned up by the
            # st/clean_up_broadcasting_conditionals branch

            # Following are: [P, M, M]  -  [P, M, N]  -  [P, N](x N)
            Kmms = covariances.Kuu(self.X_data, self.kernel, jitter=default_jitter())  # [P, M, M]
            Kmns = covariances.Kuf(self.X_data, self.kernel, Xnew)  # [P, M, N]
            if isinstance(self.kernel, kernels.Combination):
                kernel_list = self.kernel.kernels
            else:
                kernel_list = [self.kernel.kernel] * len(self.X_data.inducing_variable_list)
            Knns = tf.stack(
                [k.K(Xnew) if full_cov else k.K_diag(Xnew) for k in kernel_list], axis=0
            )

            fmean, fvar = separate_independent_conditional_implementation(
                Kmns,
                Kmms,
                Knns,
                self.q_mu,
                q_sqrt=self.q_sqrt,
                full_cov=full_cov,
                white=self.whiten,
            )

        return self._post_process_mean_and_cov(fmean, fvar, full_cov, full_output_cov)


class LinearCoregionalizationPosterior(IndependentPosteriorMultiOutput):
    @check_shapes(
        "mean: [batch..., N, L]",
        "cov: [batch..., L, N, N] if full_cov",
        "cov: [batch..., N, L] if not full_cov",
        "return[0]: [batch..., N, P]",
        "return[1]: [batch..., N, P, N, P] if full_cov and full_output_cov",
        "return[1]: [batch..., N, P, P] if (not full_cov) and full_output_cov",
        "return[1]: [batch..., P, N, N] if full_cov and (not full_output_cov)",
        "return[1]: [batch..., N, P] if (not full_cov) and (not full_output_cov)",
    )
    def _post_process_mean_and_cov(
        self, mean: TensorType, cov: TensorType, full_cov: bool, full_output_cov: bool
    ) -> MeanAndVariance:
        cov = expand_independent_outputs(cov, full_cov, full_output_cov=False)
        mean, cov = mix_latent_gp(self.kernel.W, mean, cov, full_cov, full_output_cov)
        return mean, cov


class FullyCorrelatedPosterior(BasePosterior):
    @inherit_check_shapes
    def _conditional_with_precompute(
        self,
        cache: Tuple[tf.Tensor, ...],
        Xnew: TensorType,
        full_cov: bool = False,
        full_output_cov: bool = False,
    ) -> MeanAndVariance:
        # TODO: this assumes that Xnew has shape [N, D] and no leading dims

        # Qinv: [L, M, M]
        # alpha: [M, L]
        alpha, Qinv = cache

        Kuf = covariances.Kuf(self.X_data, self.kernel, Xnew)
        assert Kuf.shape.ndims == 4
        M, L, N, K = tf.unstack(tf.shape(Kuf), num=Kuf.shape.ndims, axis=0)
        Kuf = tf.reshape(Kuf, (M * L, N * K))

        kernel: kernels.MultioutputKernel = self.kernel
        Kff = kernel(Xnew, full_cov=full_cov, full_output_cov=full_output_cov)
        # full_cov=True and full_output_cov=True: [N, P, N, P]
        # full_cov=True and full_output_cov=False: [P, N, N]
        # full_cov=False and full_output_cov=True: [N, P, P]
        # full_cov=False and full_output_cov=False: [N, P]
        if full_cov == full_output_cov:
            new_shape = (N * K, N * K) if full_cov else (N * K,)
            Kff = tf.reshape(Kff, new_shape)

        N = tf.shape(Xnew)[0]
        K = tf.shape(Kuf)[-1] // N

        mean = tf.matmul(Kuf, alpha, transpose_a=True)
        if Kuf.shape.ndims == 3:
            mean = tf.linalg.adjoint(tf.squeeze(mean, axis=-1))

        if not full_cov and not full_output_cov:
            # fully diagonal case in both inputs and outputs
            # [Aᵀ B]_ij = Aᵀ_ik B_kj = A_ki B_kj
            # TODO check whether einsum is faster now?
            Kfu_Qinv_Kuf = tf.reduce_sum(Kuf * tf.matmul(Qinv, Kuf), axis=-2)
        else:
            Kfu_Qinv_Kuf = tf.matmul(Kuf, Qinv @ Kuf, transpose_a=True)
            if not (full_cov and full_output_cov):
                # diagonal in either inputs or outputs
                new_shape = tf.concat([tf.shape(Kfu_Qinv_Kuf)[:-2], (N, K, N, K)], axis=0)
                Kfu_Qinv_Kuf = tf.reshape(Kfu_Qinv_Kuf, new_shape)
                if full_cov:
                    # diagonal in outputs: move outputs to end
                    tmp = tf.linalg.diag_part(tf.einsum("...ijkl->...ikjl", Kfu_Qinv_Kuf))
                elif full_output_cov:
                    # diagonal in inputs: move inputs to end
                    tmp = tf.linalg.diag_part(tf.einsum("...ijkl->...jlik", Kfu_Qinv_Kuf))
                Kfu_Qinv_Kuf = tf.einsum("...ijk->...kij", tmp)  # move diagonal dim to [-3]
        cov = Kff - Kfu_Qinv_Kuf

        if not full_cov and not full_output_cov:
            cov = tf.linalg.adjoint(cov)

        mean = tf.reshape(mean, (N, K))
        if full_cov == full_output_cov:
            cov_shape = (N, K, N, K) if full_cov else (N, K)
        else:
            cov_shape = (K, N, N) if full_cov else (N, K, K)
        cov = tf.reshape(cov, cov_shape)

        return mean, cov

    @inherit_check_shapes
    def _conditional_fused(
        self, Xnew: TensorType, full_cov: bool = False, full_output_cov: bool = False
    ) -> MeanAndVariance:
        Kmm = covariances.Kuu(self.X_data, self.kernel, jitter=default_jitter())  # [M, L, M, L]
        Kmn = covariances.Kuf(self.X_data, self.kernel, Xnew)  # [M, L, N, P]
        kernel: kernels.MultioutputKernel = self.kernel
        Knn = kernel(
            Xnew, full_cov=full_cov, full_output_cov=full_output_cov
        )  # [N, P](x N)x P  or  [N, P](x P)

        M, L, N, K = tf.unstack(tf.shape(Kmn), num=Kmn.shape.ndims, axis=0)
        Kmm = tf.reshape(Kmm, (M * L, M * L))

        if full_cov == full_output_cov:
            Kmn = tf.reshape(Kmn, (M * L, N * K))
            Knn = tf.reshape(Knn, (N * K, N * K)) if full_cov else tf.reshape(Knn, (N * K,))
            mean, cov = base_conditional(
                Kmn, Kmm, Knn, self.q_mu, full_cov=full_cov, q_sqrt=self.q_sqrt, white=self.whiten
            )  # [K, 1], [1, K](x NK)
            mean = tf.reshape(mean, (N, K))
            cov = tf.reshape(cov, (N, K, N, K) if full_cov else (N, K))
        else:
            Kmn = tf.reshape(Kmn, (M * L, N, K))
            mean, cov = fully_correlated_conditional(
                Kmn,
                Kmm,
                Knn,
                self.q_mu,
                full_cov=full_cov,
                full_output_cov=full_output_cov,
                q_sqrt=self.q_sqrt,
                white=self.whiten,
            )
        return mean, cov


class FallbackIndependentLatentPosterior(FullyCorrelatedPosterior):  # XXX
    @inherit_check_shapes
    def _conditional_fused(
        self, Xnew: TensorType, full_cov: bool = False, full_output_cov: bool = False
    ) -> MeanAndVariance:
        Kmm = covariances.Kuu(self.X_data, self.kernel, jitter=default_jitter())  # [L, M, M]
        Kmn = covariances.Kuf(self.X_data, self.kernel, Xnew)  # [M, L, N, P]
        kernel: kernels.IndependentLatent = self.kernel
        Knn = kernel(
            Xnew, full_cov=full_cov, full_output_cov=full_output_cov
        )  # [N, P](x N)x P  or  [N, P](x P)

        return independent_interdomain_conditional(
            Kmn,
            Kmm,
            Knn,
            self.q_mu,
            full_cov=full_cov,
            full_output_cov=full_output_cov,
            q_sqrt=self.q_sqrt,
            white=self.whiten,
        )


get_posterior_class = Dispatcher("get_posterior_class")


@get_posterior_class.register(kernels.Kernel, InducingVariables)
def _get_posterior_base_case(
    kernel: Kernel, inducing_variable: InducingVariables
) -> Type[BasePosterior]:
    # independent single output
    return IndependentPosteriorSingleOutput


@get_posterior_class.register(kernels.MultioutputKernel, InducingPoints)
def _get_posterior_fully_correlated_mo(
    kernel: Kernel, inducing_variable: InducingVariables
) -> Type[BasePosterior]:
    return FullyCorrelatedPosterior


@get_posterior_class.register(
    (kernels.SharedIndependent, kernels.SeparateIndependent),
    (SeparateIndependentInducingVariables, SharedIndependentInducingVariables),
)
def _get_posterior_independent_mo(
    kernel: Kernel, inducing_variable: InducingVariables
) -> Type[BasePosterior]:
    # independent multi-output
    return IndependentPosteriorMultiOutput


@get_posterior_class.register(
    kernels.IndependentLatent,
    (FallbackSeparateIndependentInducingVariables, FallbackSharedIndependentInducingVariables),
)
def _get_posterior_independentlatent_mo_fallback(
    kernel: Kernel, inducing_variable: InducingVariables
) -> Type[BasePosterior]:
    return FallbackIndependentLatentPosterior


@get_posterior_class.register(
    kernels.LinearCoregionalization,
    (SeparateIndependentInducingVariables, SharedIndependentInducingVariables),
)
def _get_posterior_linearcoregionalization_mo_efficient(
    kernel: Kernel, inducing_variable: InducingVariables
) -> Type[BasePosterior]:
    # Linear mixing---efficient multi-output
    return LinearCoregionalizationPosterior
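

# A minimal hedged sketch (not executed on import; the helper name is illustrative
# only) of how the registrations above resolve: a plain single-output kernel with
# `InducingPoints` falls through to the (Kernel, InducingVariables) base case.
def _example_get_posterior_class() -> None:
    Z = tf.zeros([5, 2], dtype=default_float())
    posterior_class = get_posterior_class(kernels.SquaredExponential(), InducingPoints(Z))
    assert posterior_class is IndependentPosteriorSingleOutput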


def create_posterior(
    kernel: Kernel,
    inducing_variable: InducingVariables,
    q_mu: TensorType,
    q_sqrt: TensorType,
    whiten: bool,
    mean_function: Optional[MeanFunction] = None,
    precompute_cache: Union[PrecomputeCacheType, str, None] = PrecomputeCacheType.TENSOR,
) -> BasePosterior:
    posterior_class = get_posterior_class(kernel, inducing_variable)
    precompute_cache = _validate_precompute_cache_type(precompute_cache)
    return posterior_class(  # type: ignore[no-any-return]
        kernel,
        inducing_variable,
        q_mu,
        q_sqrt,
        whiten,
        mean_function,
        precompute_cache=precompute_cache,
    )
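

# An end-to-end hedged usage sketch (not executed on import). The shapes follow the
# check_shapes specs above: q_mu is [M, L] and q_sqrt is [L, M, M]; all names local
# to this helper are illustrative only and not part of the GPflow API.
def _example_create_posterior() -> None:
    import numpy as np

    M, D, L = 5, 2, 1
    Z = np.random.randn(M, D)
    posterior = create_posterior(
        kernels.SquaredExponential(),
        InducingPoints(Z),
        q_mu=np.zeros((M, L)),
        q_sqrt=np.tile(np.eye(M)[None, :, :], (L, 1, 1)),
        whiten=True,
        precompute_cache="tensor",  # normalised to PrecomputeCacheType.TENSOR
    )
    Xnew = np.random.randn(3, D)
    mean, var = posterior.predict_f(Xnew)  # mean, var: [3, L]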