https://github.com/GPflow/GPflow
Raw File
Tip revision: a382def2e4e25861c500974f6168b2bb4fa9bf94 authored by Artem Artemev on 11 November 2017, 21:54:43 UTC
Merge pull request #547 from GPflow/GPflow-1.0-RC
Tip revision: a382def
kullback_leiblers.py
# Copyright 2016 James Hensman, alexggmatthews
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-


import tensorflow as tf

from . import settings
from .decors import name_scope


@name_scope()
def gauss_kl(q_mu, q_sqrt, K=None):
    """
    Compute the KL divergence KL[q || p] between

          q(x) = N(q_mu, q_sqrt^2)
    and
          p(x) = N(0, K)

    We assume N multiple independent distributions, given by the columns of
    q_mu and the last dimension of q_sqrt. Returns the sum of the divergences.

    q_mu is a matrix (M x N), each column contains a mean.

    q_sqrt can be a 3D tensor (M x M x N), each matrix within is a lower
        triangular square-root matrix of the covariance of q (only the lower
        triangle is used; anything above the diagonal is ignored).
    q_sqrt can be a matrix (M x N), each column represents the diagonal of a
        square-root matrix of the covariance of q.

    K is a positive definite matrix (M x M): the covariance of p.
    If K is None, compute the KL divergence to p(x) = N(0, I) instead.

    The returned scalar tensor is half of the sum of

        mu_q^T K^-1 mu_q  -  N * M  -  log|S_q|  +  tr(K^-1 S_q)  +  N * log|K|

    where S_q is the covariance of q, and the last term is omitted when K
    is None (its value is zero for the identity prior).

    Raises ValueError if the static rank of q_sqrt is neither 2 nor 3.
    """

    # Validate the rank up front so that no ops are added to the graph for an
    # input we are going to reject anyway.
    q_sqrt_ndims = q_sqrt.get_shape().ndims
    if q_sqrt_ndims not in (2, 3):  # pragma: no cover
        raise ValueError("Bad dimension for q_sqrt: {}".format(q_sqrt_ndims))

    diag = q_sqrt_ndims == 2
    white = K is None

    if white:
        # Identity prior: the whitened mean is the mean itself.
        alpha = q_mu
    else:
        Lp = tf.cholesky(K)
        # alpha = Lp^-1 q_mu, so that sum(alpha^2) = q_mu^T K^-1 q_mu.
        alpha = tf.matrix_triangular_solve(Lp, q_mu, lower=True)

    if diag:
        num_latent = tf.shape(q_sqrt)[1]
        NM = tf.size(q_sqrt)
        Lq = Lq_diag = q_sqrt
    else:
        num_latent = tf.shape(q_sqrt)[2]
        NM = tf.reduce_prod(tf.shape(q_sqrt)[1:])
        # Move the latent dimension first (N x M x M) and force lower triangle.
        Lq = tf.matrix_band_part(tf.transpose(q_sqrt, (2, 0, 1)), -1, 0)
        Lq_diag = tf.matrix_diag_part(Lq)

    # Mahalanobis term: mu_q^T K^-1 mu_q
    mahalanobis = tf.reduce_sum(tf.square(alpha))

    # Constant term: - N x M
    constant = - tf.cast(NM, settings.tf_float)

    # Log-determinant of the covariance of q(x):
    logdet_qcov = tf.reduce_sum(tf.log(tf.square(Lq_diag)))

    # Trace term: tr(K^-1 S_q)
    if white:
        trace = tf.reduce_sum(tf.square(Lq))
    elif diag:
        # S_q is diagonal, so only diag(K^-1) is needed. Because
        # K^-1 = Lp^-T Lp^-1, its i-th diagonal entry is the sum of squares of
        # the i-th column of Lp^-1; this avoids a second triangular solve and
        # never materialises the full K^-1.
        M = tf.shape(Lp)[0]
        Lp_inv = tf.matrix_triangular_solve(
            Lp, tf.eye(M, dtype=settings.tf_float), lower=True)
        K_inv_diag = tf.reduce_sum(tf.square(Lp_inv), 0)
        trace = tf.reduce_sum(
            tf.expand_dims(K_inv_diag, 1) * tf.square(q_sqrt))
    else:
        # tr(K^-1 Lq Lq^T) = || Lp^-1 Lq ||_F^2, computed per latent function.
        Lp_tiled = tf.tile(tf.expand_dims(Lp, 0), [num_latent, 1, 1])
        LpiLq = tf.matrix_triangular_solve(Lp_tiled, Lq, lower=True)
        trace = tf.reduce_sum(tf.square(LpiLq))

    twoKL = mahalanobis + constant - logdet_qcov + trace

    # Log-determinant of the covariance of p(x), counted once per latent
    # function; it vanishes for the identity prior.
    if not white:
        log_sqdiag_Lp = tf.log(tf.square(tf.matrix_diag_part(Lp)))
        sum_log_sqdiag_Lp = tf.reduce_sum(log_sqdiag_Lp)
        prior_logdet = tf.cast(num_latent, settings.tf_float) * sum_log_sqdiag_Lp
        twoKL += prior_logdet

    return 0.5 * twoKL

back to top