https://github.com/GPflow/GPflow
Revision 834ed79b351ed88fcbc51bba3977bf347227f7ec ("improve shape robustness in likelihoods", #1334), authored by James Hensman on 24 March 2020, 18:09:29 UTC, committed by GitHub on 24 March 2020, 18:09:29 UTC
This gives GPflow likelihoods a stronger contract regarding which input shapes are expected and which shapes are returned. In particular, we should obey something akin to tensorflow_probability's event-shape/batch-shape/sample-shape distinction. Very little changes for most users, except that some shapes will now be asserted. Advanced users will benefit from more shape checks and better-defined return shapes for the methods attached to likelihoods.

Likelihoods now need to define:
 - self.observation_dim: what is the last dimension of Y supposed to be?
 - self.latent_dim: what is the last dimension of F, F_mu and F_var expected to be?
We’ll check that the dimensions of tensors passed in match up.
Return shapes for all methods will be the broadcast-shape of the tensors passed in, with the last dimension removed.
Example: likelihood.variational_expectations(F_mu, F_var, Y) might take tensors of shape [..., 2], [..., 2], [..., 2] and return a tensor of shape [...]

The shape checks are handled by the public methods log_prob, predict_mean_and_var, predict_log_density, and variational_expectations; new likelihoods should implement the corresponding private methods with a leading underscore (e.g. _log_prob).

## Standard likelihoods
Most likelihoods in GPflow are univariate and treat the columns of Y and F as independent variables. For these, observation_dim and latent_dim are None, and F and Y are shape-checked on the fly for matching last dimensions. Note that the return-shape contract changes as described above, and the likelihood methods return the sum over the observation dimensions.
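For example, a sketch of what the new contract implies for the standard Gaussian likelihood (the tensor values below are arbitrary placeholders):

```python
import numpy as np
import tensorflow as tf
import gpflow

likelihood = gpflow.likelihoods.Gaussian()

# all inputs share the trailing dimension of 2 (two independent columns)
F_mu = tf.constant(np.random.randn(7, 3, 2))
F_var = tf.constant(np.random.rand(7, 3, 2) + 1e-2)
Y = tf.constant(np.random.randn(7, 3, 2))

ve = likelihood.variational_expectations(F_mu, F_var, Y)
print(ve.shape)  # (7, 3): the broadcast shape with the last (observation) dimension summed out
```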

## Fancy likelihoods
Likelihoods that depart from the univariate standard include:
 - SwitchedLikelihood [we'll check that latent_dim = observation_dim - 1]
 - MultiClass/Softmax [observation_dim = 1, latent_dim = number of classes (e.g. 10 for MNIST)]
 - HeteroskedasticGaussian, e.g. see the GPflow notebook and the sketch after this list [observation_dim = 1, latent_dim = 2]
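A minimal sketch of the heteroskedastic case, assuming the base-class constructor accepts latent_dim and observation_dim and that the first latent function models the mean while the second models the log-variance. The class body, including the analytic variational expectation, is illustrative rather than the exact notebook implementation; the prediction methods are left as stubs:

```python
import numpy as np
import tensorflow as tf
import gpflow


class HeteroskedasticGaussian(gpflow.likelihoods.Likelihood):
    def __init__(self, **kwargs):
        # one observed column, two latent functions (mean and log-variance)
        super().__init__(latent_dim=2, observation_dim=1, **kwargs)

    def _log_prob(self, F, Y):
        # F: [..., 2], Y: [..., 1]  ->  [...]
        mean, log_var = F[..., 0], F[..., 1]
        return gpflow.logdensities.gaussian(Y[..., 0], mean, tf.exp(log_var))

    def _variational_expectations(self, Fmu, Fvar, Y):
        # E_q[log N(y | f1, exp(f2))] with independent marginals q(f1), q(f2):
        # -0.5 log(2 pi) - 0.5 mu2 - 0.5 exp(-mu2 + 0.5 s2) * ((y - mu1)^2 + s1)
        mu1, mu2 = Fmu[..., 0], Fmu[..., 1]
        s1, s2 = Fvar[..., 0], Fvar[..., 1]
        y = Y[..., 0]
        return (
            -0.5 * np.log(2.0 * np.pi)
            - 0.5 * mu2
            - 0.5 * tf.exp(-mu2 + 0.5 * s2) * (tf.square(y - mu1) + s1)
        )

    def _predict_mean_and_var(self, Fmu, Fvar):
        raise NotImplementedError  # e.g. via quadrature

    def _predict_log_density(self, Fmu, Fvar, Y):
        raise NotImplementedError  # e.g. via quadrature
```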

Note that this change deprecates Likelihood.predict_density in favour of Likelihood.predict_log_density.
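For existing code this is essentially a rename (likelihood, F_mu, F_var, Y below are placeholder names):

```python
# old, deprecated:
# log_density = likelihood.predict_density(F_mu, F_var, Y)
# new:
log_density = likelihood.predict_log_density(F_mu, F_var, Y)
```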
kullback_leiblers.py
# Copyright 2016 James Hensman, alexggmatthews
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-

import tensorflow as tf
from .config import default_float, default_jitter
from .covariances.kuus import Kuu
from .inducing_variables import InducingVariables
from .kernels import Kernel
from .utilities import Dispatcher, to_default_float

prior_kl = Dispatcher("prior_kl")


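# prior_kl computes KL[q(u) || p(u)] for the variational distribution
# q(u) = N(q_mu, q_sqrt q_sqrt^T), dispatching on the types of the inducing
# variables and kernel. With whiten=True the prior is the standard normal
# N(0, I) (whitened parameterisation); otherwise the prior covariance is
# Kuu(inducing_variable, kernel) with jitter added to its diagonal.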
@prior_kl.register(InducingVariables, Kernel, object, object)
def _(inducing_variable, kernel, q_mu, q_sqrt, whiten=False):
    if whiten:
        return gauss_kl(q_mu, q_sqrt, None)
    else:
        K = Kuu(inducing_variable, kernel, jitter=default_jitter())  # [P, M, M] or [M, M]
        return gauss_kl(q_mu, q_sqrt, K)


def gauss_kl(q_mu, q_sqrt, K=None, *, K_cholesky=None):
    """
    Compute the KL divergence KL[q || p] between

          q(x) = N(q_mu, q_sqrt^2)
    and
          p(x) = N(0, K)    if K is not None
          p(x) = N(0, I)    if K is None

    We assume L multiple independent distributions, given by the columns of
    q_mu and the first or last dimension of q_sqrt. Returns the *sum* of the
    divergences.

    q_mu is a matrix ([M, L]), each column contains a mean.

    q_sqrt can be a 3D tensor ([L, M, M]), each matrix within is a lower
        triangular square-root matrix of the covariance of q.
    q_sqrt can be a matrix ([M, L]), each column represents the diagonal of a
        square-root matrix of the covariance of q.

    K is the covariance of p (positive-definite matrix).  The K matrix can be
    passed either directly as `K`, or as its Cholesky factor, `K_cholesky`.  In
    either case, it can be a single matrix [M, M], in which case the sum of the
    L KL divergences is computed by broadcasting, or L different covariances
    [L, M, M].

    Note: if no K matrix is given (both `K` and `K_cholesky` are None),
    `gauss_kl` computes the KL divergence from p(x) = N(0, I) instead.
    """

    if (K is not None) and (K_cholesky is not None):
        raise ValueError(
            "Ambiguous arguments: gauss_kl() must only be passed one of `K` or `K_cholesky`."
        )

    is_white = (K is None) and (K_cholesky is None)
    is_diag = len(q_sqrt.shape) == 2

    shape_constraints = [
        (q_mu, ["M", "L"]),
        (q_sqrt, (["M", "L"] if is_diag else ["L", "M", "M"])),
    ]
    if not is_white:
        if K is not None:
            shape_constraints.append((K, (["L", "M", "M"] if len(K.shape) == 3 else ["M", "M"])))
        else:
            shape_constraints.append(
                (K_cholesky, (["L", "M", "M"] if len(K_cholesky.shape) == 3 else ["M", "M"]))
            )
    tf.debugging.assert_shapes(shape_constraints, message="gauss_kl() arguments")

    M, L = tf.shape(q_mu)[0], tf.shape(q_mu)[1]

    if is_white:
        alpha = q_mu  # [M, L]
    else:
        if K is not None:
            Lp = tf.linalg.cholesky(K)  # [L, M, M] or [M, M]
        elif K_cholesky is not None:
            Lp = K_cholesky  # [L, M, M] or [M, M]

        is_batched = len(Lp.shape) == 3

        q_mu = tf.transpose(q_mu)[:, :, None] if is_batched else q_mu  # [L, M, 1] or [M, L]
        alpha = tf.linalg.triangular_solve(Lp, q_mu, lower=True)  # [L, M, 1] or [M, L]

    if is_diag:
        Lq = Lq_diag = q_sqrt
        Lq_full = tf.linalg.diag(tf.transpose(q_sqrt))  # [L, M, M]
    else:
        Lq = Lq_full = tf.linalg.band_part(q_sqrt, -1, 0)  # force lower triangle # [L, M, M]
        Lq_diag = tf.linalg.diag_part(Lq)  # [M, L]

    # Mahalanobis term: μqᵀ Σp⁻¹ μq
    mahalanobis = tf.reduce_sum(tf.square(alpha))

    # Constant term: - L * M
    constant = -to_default_float(tf.size(q_mu, out_type=tf.int64))

    # Log-determinant of the covariance of q(x):
    logdet_qcov = tf.reduce_sum(tf.math.log(tf.square(Lq_diag)))

    # Trace term: tr(Σp⁻¹ Σq)
    if is_white:
        trace = tf.reduce_sum(tf.square(Lq))
    else:
        if is_diag and not is_batched:
            # K is [M, M] and q_sqrt is [M, L]: fast specialisation
            LpT = tf.transpose(Lp)  # [M, M]
            Lp_inv = tf.linalg.triangular_solve(
                Lp, tf.eye(M, dtype=default_float()), lower=True
            )  # [M, M]
            K_inv = tf.linalg.diag_part(tf.linalg.triangular_solve(LpT, Lp_inv, lower=False))[
                :, None
            ]  # [M, M] -> [M, 1]
            trace = tf.reduce_sum(K_inv * tf.square(q_sqrt))
        else:
            # TODO: broadcast instead of tile when tf allows -- tf2.1 segfaults
            # (https://github.com/tensorflow/tensorflow/issues/37584).
            # See https://github.com/GPflow/GPflow/issues/1321
            Lp_full = Lp if is_batched else tf.tile(tf.expand_dims(Lp, 0), [L, 1, 1])
            LpiLq = tf.linalg.triangular_solve(Lp_full, Lq_full, lower=True)
            trace = tf.reduce_sum(tf.square(LpiLq))

    twoKL = mahalanobis + constant - logdet_qcov + trace

    # Log-determinant of the covariance of p(x):
    if not is_white:
        log_sqdiag_Lp = tf.math.log(tf.square(tf.linalg.diag_part(Lp)))
        sum_log_sqdiag_Lp = tf.reduce_sum(log_sqdiag_Lp)
        # If K is [L, M, M], num_latent_gps is no longer implicit, no need to multiply the single kernel logdet
        scale = 1.0 if is_batched else to_default_float(L)
        twoKL += scale * sum_log_sqdiag_Lp

    tf.debugging.assert_shapes([(twoKL, ())], message="gauss_kl() return value")  # returns scalar
    return 0.5 * twoKL
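
# Illustrative usage sketch (not part of the original file): KL for L = 2
# independent GPs over M = 5 inducing points, once against a whitened N(0, I)
# prior and once against an explicit prior covariance K.
#
#     q_mu = tf.random.normal([5, 2], dtype=default_float())        # [M, L]
#     q_sqrt = tf.eye(5, batch_shape=[2], dtype=default_float())    # [L, M, M]
#     kl_white = gauss_kl(q_mu, q_sqrt)       # p(x) = N(0, I), scalar
#     K = 2.0 * tf.eye(5, dtype=default_float())                    # [M, M], positive-definite
#     kl = gauss_kl(q_mu, q_sqrt, K)          # p(x) = N(0, K), scalar
#
# If the Cholesky factor of K is already available, pass it via the
# K_cholesky keyword argument to avoid re-factorising.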