# Source: https://github.com/GPflow/GPflow
# Raw File
# Tip revision: 0c3e151a3442ac8cd0eb85b9674203e75c288184 authored by Artem Artemev on 19 October 2018, 16:42:59 UTC
# Resolve merge conflict.
# Tip revision: 0c3e151
# File: test_kldiv.py
# Copyright 2017 the GPflow authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-

import numpy as np
import tensorflow as tf

import gpflow
from gpflow.kullback_leiblers import gauss_kl
from numpy.testing import assert_almost_equal
import pytest
from gpflow import settings
from gpflow.test_util import session_tf

def squareT(A):
    """
    Compute the product of a matrix with its own transpose.

    :param A: NumPy array
    :return: A Aᵀ
    """
    transposed = A.T
    return np.dot(A, transposed)

def make_sqrt_data(rng, N, M):
    """
    Draw N random lower-triangular M x M matrices.

    Each matrix is a separate ``rng.randn(M, M)`` draw with its strictly
    upper triangle zeroed out.

    :param rng: random state providing ``randn``
    :param N: number of matrices
    :param M: size of each square matrix
    :return: array of the stacked matrices, N x M x M
    """
    lower_factors = []
    for _ in range(N):
        dense_draw = rng.randn(M, M)
        lower_factors.append(np.tril(dense_draw))
    return np.array(lower_factors)  # N x M x M

def make_K_batch_data(rng, N, M):
    """
    Build a batch of N symmetric M x M matrices centred on the identity.

    A random N x M x M draw is symmetrised (added to its per-matrix
    transpose), scaled by 0.1 and shifted by the identity matrix.

    :param rng: random state providing ``randn``
    :param N: number of matrices in the batch
    :param M: size of each square matrix
    :return: array of shape N x M x M
    """
    noise = rng.randn(N, M, M)
    symmetric_noise = noise + noise.transpose(0, 2, 1)
    # The M x M identity broadcasts across the leading batch axis.
    return .1 * symmetric_noise + np.eye(M)

class Datum:
    """
    NumPy test data shared by the fixtures and tests in this module.

    All random quantities are drawn from one fixed-seed random state, so the
    order of the draws below determines the generated values.
    """
    M = 5
    N = 4
    rng = np.random.RandomState(0)
    I = np.eye(M)  # M x M identity (no random draw involved)
    mu_data = rng.randn(M, N)  # M x N
    K_data = squareT(rng.randn(M, M)) + 1e-6 * I  # M x M, with a small diagonal jitter
    sqrt_data = make_sqrt_data(rng, N, M)  # N x M x M, lower-triangular matrices
    sqrt_diag_data = rng.randn(M, N)  # M x N
    K_batch_data = make_K_batch_data(rng, N, M)  # N x M x M

@pytest.fixture
def mu(session_tf):
    """
    Tensor version of ``Datum.mu_data`` (M x N).

    NOTE(review): ``session_tf`` is requested but unused here; presumably it
    sets up the graph/session the tensor is created in -- confirm.
    """
    mean_values = Datum.mu_data
    return tf.convert_to_tensor(mean_values)

@pytest.fixture
def sqrt_diag(session_tf):
    """
    Tensor version of ``Datum.sqrt_diag_data`` (M x N).

    NOTE(review): ``session_tf`` is requested but unused here; presumably it
    sets up the graph/session the tensor is created in -- confirm.
    """
    diag_values = Datum.sqrt_diag_data
    return tf.convert_to_tensor(diag_values)

@pytest.fixture
def K(session_tf):
    """
    Tensor version of ``Datum.K_data`` (M x M).

    NOTE(review): ``session_tf`` is requested but unused here; presumably it
    sets up the graph/session the tensor is created in -- confirm.
    """
    covariance = Datum.K_data
    return tf.convert_to_tensor(covariance)

@pytest.fixture
def K_batch(session_tf):
    """
    Tensor version of ``Datum.K_batch_data`` (N x M x M).

    NOTE(review): ``session_tf`` is requested but unused here; presumably it
    sets up the graph/session the tensor is created in -- confirm.
    """
    batched_covariance = Datum.K_batch_data
    return tf.convert_to_tensor(batched_covariance)

@pytest.fixture
def sqrt(session_tf):
    """
    Tensor version of ``Datum.sqrt_data`` (N x M x M).

    NOTE(review): ``session_tf`` is requested but unused here; presumably it
    sets up the graph/session the tensor is created in -- confirm.
    """
    cholesky_values = Datum.sqrt_data
    return tf.convert_to_tensor(cholesky_values)

@pytest.fixture
def I(session_tf):
    """
    Tensor version of the identity matrix ``Datum.I`` (M x M).

    NOTE(review): ``session_tf`` is requested but unused here; presumably it
    sets up the graph/session the tensor is created in -- confirm.
    """
    identity = Datum.I
    return tf.convert_to_tensor(identity)

@pytest.mark.parametrize('white', [True, False])
def test_diags(session_tf, white, mu, sqrt_diag, K):
    """
    The covariance of q(x) can be Cholesky matrices or diagonal matrices.
    Here we make sure the behaviours overlap.

    The check is run both for the whitened case (no prior covariance passed
    to gauss_kl) and for the non-whitened case (prior covariance K).
    """
    # the chols are diagonal matrices, with the same entries as the diag representation.
    chol_from_diag = tf.stack([tf.diag(sqrt_diag[:, i]) for i in range(Datum.N)]) # N x M x M
    # `white` means that no prior covariance is supplied (K=None); this is the
    # same meaning the flag has in test_oned. Previously the flag was inverted.
    prior_cov = None if white else K
    # run
    kl_diag = gauss_kl(mu, sqrt_diag, prior_cov)
    kl_dense = gauss_kl(mu, chol_from_diag, prior_cov)

    np.testing.assert_allclose(kl_diag.eval(), kl_dense.eval())

@pytest.mark.parametrize('diag', [True, False])
def test_whitened(session_tf, diag, mu, sqrt_diag, I):
    """
    gauss_kl with an identity prior covariance must agree with gauss_kl
    called without any prior covariance (K=None).

    Checked for both the diagonal and the dense representation of q's
    covariance square root.
    """
    # Dense equivalent of the diagonal representation: one diagonal matrix per column.
    diag_matrices = [tf.diag(sqrt_diag[:, n]) for n in range(Datum.N)]
    dense_sqrt = tf.stack(diag_matrices)  # N x M x M
    if diag:
        q_sqrt = sqrt_diag
    else:
        q_sqrt = dense_sqrt

    kl_without_K = gauss_kl(mu, q_sqrt)
    kl_identity_K = gauss_kl(mu, q_sqrt, I)

    np.testing.assert_allclose(kl_without_K.eval(), kl_identity_K.eval())

@pytest.mark.parametrize('shared_k', [True, False])
@pytest.mark.parametrize('diag', [True, False])
def test_sumkl_equals_batchkl(session_tf, shared_k, diag, mu,
                              sqrt, sqrt_diag, K_batch, K):
    """
    gauss_kl implicitly performs a sum of KL divergences.
    This test checks that doing the sum outside of the function is equivalent:
    for q(X)=prod q(x_n) and p(X)=prod p(x_n), check that
    sum KL(q(x_n)||p(x_n)) = KL(q(X)||p(X)).

    In the batched call, q(X) has covariance N x M x M (or M x N when diagonal)
    and p(X) has covariance N x M x M (or a shared M x M).
    In the per-column calls, q(x_n) has covariance 1 x M x M (or M x 1) and
    p(x_n) has covariance 1 x M x M (or the shared M x M).
    """
    q_sqrt = sqrt_diag if diag else sqrt
    prior_cov = K if shared_k else K_batch
    kl_batch = gauss_kl(mu, q_sqrt, prior_cov)

    per_column_kls = []
    for n in range(Datum.N):
        mu_n = mu[:, n][:, None]  # M x 1
        if diag:
            sqrt_n = sqrt_diag[:, n][:, None]  # M x 1
        else:
            sqrt_n = sqrt[n, :, :][None, :, :]  # 1 x M x M
        if shared_k:
            K_n = K  # M x M
        else:
            K_n = K_batch[n, :, :][None, :, :]  # 1 x M x M
        per_column_kls.append(gauss_kl(mu_n, sqrt_n, K_n))

    kl_sum = tf.reduce_sum(per_column_kls)
    assert_almost_equal(kl_sum.eval(), kl_batch.eval())

def tf_kl_1d(q_mu, q_sigma, p_var=1.0):
    """
    By-hand sum of elementwise 1D KL divergences KL(N(q_mu, q_sigma²) || N(0, p_var)).

    :param q_mu: means of q
    :param q_sigma: standard deviations of q (squared internally)
    :param p_var: variances of p; ``None`` is replaced by ones shaped like q_sigma
    :return: scalar tensor, the sum over all elements
    """
    if p_var is None:
        p_var = tf.ones_like(q_sigma)
    q_var = tf.square(q_sigma)
    variance_ratio = q_var / p_var
    scaled_sq_mean = tf.square(q_mu) / p_var
    log_variance_ratio = tf.log(p_var / q_var)
    elementwise_kl = 0.5 * (variance_ratio + scaled_sq_mean - 1 + log_variance_ratio)
    return tf.reduce_sum(elementwise_kl)

@pytest.mark.parametrize('white', [True, False])
def test_oned(session_tf, white, mu, sqrt, K_batch):
    """
    Compare gauss_kl on a one-dimensional slice of the data against the
    by-hand 1D calculation in tf_kl_1d.
    """
    m = 0  # index of the single dimension that is kept
    mu1d = mu[m, :][None, :]  # 1 x N
    s1d = sqrt[:, m, m][:, None, None]  # N x 1 x 1
    K1d = K_batch[:, m, m][:, None, None]  # N x 1 x 1

    # Whitened: no prior covariance for gauss_kl, no prior variance for the reference.
    kl = gauss_kl(mu1d, s1d, None if white else K1d)

    flat_mu = tf.reshape(mu1d, (-1,))  # N
    flat_sigma = tf.reshape(s1d, (-1,))  # N
    flat_p_var = None if white else tf.reshape(K1d, (-1,))  # N
    kl_tf = tf_kl_1d(flat_mu, flat_sigma, flat_p_var)

    np.testing.assert_allclose(kl.eval(), kl_tf.eval())


def test_unknown_size_inputs(session_tf):
    """
    Test for #725 and #734. When the shape of the Gaussian's mean had at least
    one unknown parameter, `gauss_kl` would blow up. This happened because
    `tf.size` can only output types `tf.int32` or `tf.int64`.

    The KL computed from placeholders with fully unknown shapes must match the
    KL computed from constants holding the same values.
    """
    mu_ph = tf.placeholder(settings.float_type, [None, None])
    sqrt_ph = tf.placeholder(settings.float_type, [None, None, None])
    mu_value = np.ones([1, 4], dtype=settings.float_type)
    sqrt_value = np.ones([4, 1, 1], dtype=settings.float_type)

    kl_known_shape = gauss_kl(tf.constant(mu_value), tf.constant(sqrt_value))
    kl_unknown_shape = gauss_kl(mu_ph, sqrt_ph)

    known_result = session_tf.run(kl_known_shape)
    unknown_result = session_tf.run(
        kl_unknown_shape,
        feed_dict={mu_ph: mu_value, sqrt_ph: sqrt_value})

    np.testing.assert_allclose(known_result, unknown_result)


if __name__ == "__main__":
    # The tests in this module are plain pytest functions that rely on pytest
    # fixtures and parametrization. tf.test.main() only runs unittest-style
    # TestCase classes, so it would execute none of them. Delegate to pytest
    # and propagate its exit code to the shell.
    raise SystemExit(pytest.main([__file__]))
# back to top