# Copyright 2017 the GPflow authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import pytest
import tensorflow as tf
from numpy.testing import assert_almost_equal

from gpflow.config import default_float
from gpflow.kullback_leiblers import gauss_kl

def squareT(A):
    """
    Returns (A Aᵀ)
    """
    return A.dot(A.T)

def make_sqrt_data(rng, N, M):
    return np.array([np.tril(rng.randn(M, M)) for _ in range(N)]) # [N, M, M]

def make_K_batch_data(rng, N, M):
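    # Symmetrise a random matrix (scaled by 0.1) and add the identity, so that
    # each K in the batch is (almost surely) positive definite.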
    K_np = rng.randn(N, M, M)
    beye = np.array([np.eye(M) for _ in range(N)])
    return .1 * (K_np + np.transpose(K_np, (0, 2, 1))) + beye

class Datum:
    M, N = 5, 4
    rng = np.random.RandomState(0)
    mu_data = rng.randn(M, N)  # [M, N]
    K_data = squareT(rng.randn(M, M)) + 1e-6 * np.eye(M)  # [M, M]
    I = np.eye(M) # [M, M]
    sqrt_data = make_sqrt_data(rng, N, M) # [N, M, M]
    sqrt_diag_data = rng.randn(M, N) # [M, N]
    K_batch_data = make_K_batch_data(rng, N, M)
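
# Note on shape conventions: M is the number of input points and N is the
# number of independent Gaussians (the batch dimension) that gauss_kl sums over.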

@pytest.fixture
def mu(session_tf):
    return tf.convert_to_tensor(Datum.mu_data)


@pytest.fixture
def sqrt_diag(session_tf):
    return tf.convert_to_tensor(Datum.sqrt_diag_data)


@pytest.fixture
def K(session_tf):
    return tf.convert_to_tensor(Datum.K_data)


@pytest.fixture
def K_batch(session_tf):
    return tf.convert_to_tensor(Datum.K_batch_data)


@pytest.fixture
def sqrt(session_tf):
    return tf.convert_to_tensor(Datum.sqrt_data)


@pytest.fixture
def I(session_tf):
    return tf.convert_to_tensor(Datum.I)


@pytest.mark.parametrize('white', [True, False])
def test_diags(session_tf, white, mu, sqrt_diag, K):
    """
    The covariance of q(x) can be Cholesky matrices or diagonal matrices.
    Here we make sure the two representations give the same KL.
    """
    # the chols are diagonal matrices, with the same entries as the diag representation.
    chol_from_diag = tf.stack([tf.linalg.diag(sqrt_diag[:, i]) for i in range(Datum.N)]) # [N, M, M]
    # run
    kl_diag = gauss_kl(mu, sqrt_diag, None if white else K)
    kl_dense = gauss_kl(mu, chol_from_diag, None if white else K)

    np.testing.assert_allclose(kl_diag.eval(), kl_dense.eval())


@pytest.mark.parametrize('diag', [True, False])
def test_whitened(session_tf, diag, mu, sqrt_diag, I):
    """
    Check that K = Identity and K = None (the whitened representation) give the same answer.
    """
    chol_from_diag = tf.stack([tf.linalg.diag(sqrt_diag[:, i]) for i in range(Datum.N)]) # [N, M, M]
    s = sqrt_diag if diag else chol_from_diag

    kl_white = gauss_kl(mu, s)
    kl_nonwhite = gauss_kl(mu, s, I)

    np.testing.assert_allclose(kl_white.eval(), kl_nonwhite.eval())


@pytest.mark.parametrize('shared_k', [True, False])
@pytest.mark.parametrize('diag', [True, False])
def test_sumkl_equals_batchkl(session_tf, shared_k, diag, mu,
                              sqrt, sqrt_diag, K_batch, K):
    """
    gauss_kl implicitely performs a sum of KL divergences
    This test checks that doing the sum outside of the function is equivalent
    For q(X)=prod q(x_l) and p(X)=prod p(x_l), check that sum KL(q(x_l)||p(x_l)) = KL(q(X)||p(X))
    Here, q(X) has covariance [L, M, M]
    p(X) has covariance [L, M, M] ( or [M, M] )
    Here, q(x_i) has covariance [1, M, M]
    p(x_i) has covariance [M, M]
    """
    s = sqrt_diag if diag else sqrt
    kl_batch = gauss_kl(mu,s,K if shared_k else K_batch)
    kl_sum = []
    for n in range(Datum.N):
        kl_sum.append(gauss_kl(mu[:, n][:,None], # [M, 1]
            sqrt_diag[:, n][:, None] if diag else sqrt[n, :, :][None, :, :], # [1, M, M] or [M, 1]
            K if shared_k else K_batch[n, :, :][None,:,:])) # [1, M, M] or [M, M]
    kl_sum =tf.reduce_sum(kl_sum)
    assert_almost_equal(kl_sum.eval(), kl_batch.eval())
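

# A minimal NumPy sketch of the dense KL that gauss_kl evaluates for a single
# latent, assuming q = N(mu, L Lᵀ) and p = N(0, K); np_gauss_kl and its
# signature are illustrative only, not part of the GPflow API:
#   KL = 0.5 * (tr(K⁻¹ L Lᵀ) + muᵀ K⁻¹ mu - M + ln|K| - ln|L Lᵀ|)
def np_gauss_kl(mu, L, K):
    M = mu.shape[0]
    Kinv = np.linalg.inv(K)
    mahalanobis = float(mu.T @ Kinv @ mu)  # muᵀ K⁻¹ mu
    trace = np.trace(Kinv @ (L @ L.T))  # tr(K⁻¹ L Lᵀ)
    logdet_p = np.linalg.slogdet(K)[1]  # ln|K|
    logdet_q = 2.0 * np.sum(np.log(np.abs(np.diag(L))))  # ln|L Lᵀ|
    return 0.5 * (mahalanobis + trace - M + logdet_p - logdet_q)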


def tf_kl_1d(q_mu, q_sigma, p_var=1.0):
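    # Closed-form KL between scalar Gaussians, KL(N(q_mu, q_sigma²) || N(0, p_var))
    # = 0.5 * (q_var / p_var + q_mu² / p_var - 1 + ln(p_var / q_var)), summed over elements.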
    p_var = tf.ones_like(q_sigma) if p_var is None else p_var
    q_var = tf.square(q_sigma)
    kl = 0.5 * (q_var / p_var + tf.square(q_mu) / p_var - 1 + tf.math.log(p_var / q_var))
    return tf.reduce_sum(kl)


@pytest.mark.parametrize('white', [True, False])
def test_oned(session_tf, white, mu, sqrt, K_batch):
    """
    Check that the KL divergence matches a 1D by-hand calculation.
    """
    m = 0
    mu1d = mu[m, :][None, :]  # [1, N]
    s1d = sqrt[:, m, m][:, None, None]  # [N, 1, 1]
    K1d = K_batch[:, m, m][:, None, None]  # [N, 1, 1]

    kl = gauss_kl(mu1d, s1d, None if white else K1d)
    kl_tf = tf_kl_1d(tf.reshape(mu1d, (-1,)),  # [N]
                     tf.reshape(s1d, (-1,)),  # [N]
                     None if white else tf.reshape(K1d, (-1,)))  # [N]
    np.testing.assert_allclose(kl.eval(), kl_tf.eval())


def test_unknown_size_inputs(session_tf):
    """
    Test for #725 and #734. When the shape of the Gaussian's mean had at least
    one unknown dimension, `gauss_kl` would blow up. This happened because
    `tf.size` can only output types `tf.int32` or `tf.int64`.
    """
    mu_ph = tf.placeholder(default_float(), [None, None])
    sqrt_ph = tf.placeholder(default_float(), [None, None, None])
    mu = np.ones([1, 4], dtype=default_float())
    sqrt = np.ones([4, 1, 1], dtype=default_float())

    feed_dict = {mu_ph: mu, sqrt_ph: sqrt}
    known_shape_tf = gauss_kl(*map(tf.constant, [mu, sqrt]))
    unknown_shape_tf = gauss_kl(mu_ph, sqrt_ph)

    known_shape = session_tf.run(known_shape_tf)
    unknown_shape = session_tf.run(unknown_shape_tf, feed_dict=feed_dict)

    np.testing.assert_allclose(known_shape, unknown_shape)


if __name__ == "__main__":
    tf.test.main()