# Copyright 2017 the GPflow authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf
import gpflow
from gpflow.kullback_leiblers import gauss_kl
from numpy.testing import assert_almost_equal
import pytest
from gpflow import settings
from gpflow.test_util import session_tf
def squareT(A):
    """
    Return the product of ``A`` with its own transpose, i.e. (A Aᵀ).

    For any real matrix ``A`` the result is symmetric positive semi-definite.
    """
    return np.dot(A, A.T)
def make_sqrt_data(rng, N, M):
    """
    Draw ``N`` random lower-triangular ``M x M`` matrices from ``rng``.

    The result is stacked into an array of shape N x M x M.
    """
    lower_factors = []
    for _ in range(N):
        lower_factors.append(np.tril(rng.randn(M, M)))
    return np.array(lower_factors)  # N x M x M
def make_K_batch_data(rng, N, M):
    """
    Build ``N`` symmetric ``M x M`` matrices: identity plus a small random
    symmetric perturbation.

    Each matrix is ``0.1 * (R + Rᵀ) + I``, where ``R`` is drawn from ``rng``.
    The result has shape N x M x M.
    """
    noise = rng.randn(N, M, M)
    symmetric_noise = noise + np.transpose(noise, (0, 2, 1))
    identities = np.array([np.eye(M)] * N)
    return 0.1 * symmetric_noise + identities
class Datum:
    """
    Deterministic numpy test data consumed by the fixtures in this module.

    Every random array is drawn from the single ``RandomState(0)`` below, in
    definition order. Reordering or inserting draws changes all later values.
    """
    M, N = 5, 4  # M: side length of each square matrix; N: number of columns of mu / stacked matrices
    rng = np.random.RandomState(0)
    mu_data = rng.randn(M, N)  # M x N
    K_data = squareT(rng.randn(M, M)) + 1e-6 * np.eye(M)  # M x M; A Aᵀ plus a small diagonal jitter
    I = np.eye(M)  # M x M
    sqrt_data = make_sqrt_data(rng, N, M)  # N x M x M, lower-triangular
    sqrt_diag_data = rng.randn(M, N)  # M x N
    K_batch_data = make_K_batch_data(rng, N, M)  # N x M x M, symmetric
@pytest.fixture
def mu(session_tf):
    """Mean passed to gauss_kl, as a constant tensor of shape M x N."""
    mean_values = Datum.mu_data
    return tf.convert_to_tensor(mean_values)
@pytest.fixture
def sqrt_diag(session_tf):
    """Diagonal (M x N) representation of q's covariance square roots, as a tensor."""
    diag_values = Datum.sqrt_diag_data
    return tf.convert_to_tensor(diag_values)
@pytest.fixture
def K(session_tf):
    """Single shared M x M prior covariance, as a tensor."""
    prior_cov = Datum.K_data
    return tf.convert_to_tensor(prior_cov)
@pytest.fixture
def K_cholesky(session_tf):
    """Cholesky factor of the shared prior covariance ``Datum.K_data``."""
    prior_cov = tf.convert_to_tensor(Datum.K_data)
    return tf.cholesky(prior_cov)
@pytest.fixture
def K_batch(session_tf):
    """Stack of N symmetric M x M prior covariances, as a tensor."""
    batched_cov = Datum.K_batch_data
    return tf.convert_to_tensor(batched_cov)
@pytest.fixture
def sqrt(session_tf):
    """Stack of N lower-triangular M x M square-root factors, as a tensor."""
    lower_factors = Datum.sqrt_data
    return tf.convert_to_tensor(lower_factors)
@pytest.fixture
def I(session_tf):
    """M x M identity matrix, as a tensor."""
    identity = Datum.I
    return tf.convert_to_tensor(identity)
@pytest.mark.parametrize('white', [True, False])
def test_diags(session_tf, white, mu, sqrt_diag, K):
    """
    The covariance of q(x) can be Cholesky matrices or diagonal matrices.
    Here we make sure the behaviours overlap.

    When ``white`` is True the prior is whitened (``K=None``). Otherwise the
    prior covariance ``K`` is passed explicitly. This matches the convention
    used by ``test_oned`` and ``test_whitened``.
    """
    # The chols are diagonal matrices with the same entries as the diag representation.
    chol_from_diag = tf.stack([tf.diag(sqrt_diag[:, i]) for i in range(Datum.N)])  # N x M x M
    # Whitened means no prior covariance is supplied. The original code had
    # this condition inverted (it passed K when white=True).
    K_prior = None if white else K
    kl_diag = gauss_kl(mu, sqrt_diag, K_prior)
    kl_dense = gauss_kl(mu, chol_from_diag, K_prior)
    np.testing.assert_allclose(kl_diag.eval(), kl_dense.eval())
@pytest.mark.parametrize('diag', [True, False])
def test_kl_k_cholesky(session_tf, mu, sqrt, sqrt_diag, K, K_cholesky, diag):
    """
    Check that gauss_kl gives the same answer whether the prior is supplied
    as the covariance ``K`` or as its Cholesky factor ``K_cholesky``.
    """
    q_sqrt = sqrt_diag if diag else sqrt
    kl_from_cov = gauss_kl(mu, q_sqrt, K=K)
    kl_from_chol = gauss_kl(mu, q_sqrt, K_cholesky=K_cholesky)
    np.testing.assert_allclose(kl_from_cov.eval(), kl_from_chol.eval())
@pytest.mark.parametrize('diag', [True, False])
def test_whitened(session_tf, diag, mu, sqrt_diag, I):
    """
    Check that omitting K (whitened) gives the same KL as passing K=Identity.
    """
    dense_sqrt = tf.stack([tf.diag(sqrt_diag[:, col]) for col in range(Datum.N)])  # N x M x M
    if diag:
        q_sqrt = sqrt_diag
    else:
        q_sqrt = dense_sqrt
    kl_white = gauss_kl(mu, q_sqrt)
    kl_identity = gauss_kl(mu, q_sqrt, I)
    np.testing.assert_allclose(kl_white.eval(), kl_identity.eval())
@pytest.mark.parametrize('shared_k', [True, False])
@pytest.mark.parametrize('diag', [True, False])
def test_sumkl_equals_batchkl(session_tf, shared_k, diag, mu,
                              sqrt, sqrt_diag, K_batch, K):
    """
    gauss_kl implicitly performs a sum of KL divergences over the N outputs.
    This test checks that doing the sum outside of the function is equivalent.

    For q(X) = prod_n q(x_n) and p(X) = prod_n p(x_n), check that
    sum_n KL(q(x_n) || p(x_n)) = KL(q(X) || p(X)).

    Batched call: q(X) has covariance N x M x M (or diagonal M x N);
                  p(X) has covariance N x M x M (or a shared M x M).
    Per output:   q(x_n) has covariance 1 x M x M (or diagonal M x 1);
                  p(x_n) has covariance 1 x M x M (or the shared M x M).
    """
    q_sqrt = sqrt_diag if diag else sqrt
    kl_batch = gauss_kl(mu, q_sqrt, K if shared_k else K_batch)

    def single_output_kl(n):
        mu_n = mu[:, n][:, None]  # M x 1
        if diag:
            sqrt_n = sqrt_diag[:, n][:, None]  # M x 1
        else:
            sqrt_n = sqrt[n, :, :][None, :, :]  # 1 x M x M
        K_n = K if shared_k else K_batch[n, :, :][None, :, :]  # M x M or 1 x M x M
        return gauss_kl(mu_n, sqrt_n, K_n)

    kl_sum = tf.reduce_sum([single_output_kl(n) for n in range(Datum.N)])
    assert_almost_equal(kl_sum.eval(), kl_batch.eval())
def tf_kl_1d(q_mu, q_sigma, p_var=1.0):
    """
    Sum of elementwise univariate KL[N(q_mu, q_sigma²) || N(0, p_var)].

    :param q_mu: means of q.
    :param q_sigma: standard deviations of q.
    :param p_var: prior variance (scalar or tensor). ``None`` means unit
        variance with the same shape as ``q_sigma``.
    :return: scalar tensor holding the summed KL.
    """
    if p_var is None:
        p_var = tf.ones_like(q_sigma)
    q_var = tf.square(q_sigma)
    variance_ratio = q_var / p_var
    mean_term = tf.square(q_mu) / p_var
    log_ratio = tf.log(p_var / q_var)
    elementwise_kl = 0.5 * (variance_ratio + mean_term - 1 + log_ratio)
    return tf.reduce_sum(elementwise_kl)
@pytest.mark.parametrize('white', [True, False])
def test_oned(session_tf, white, mu, sqrt, K_batch):
    """
    Compare gauss_kl on 1 x 1 covariances against a by-hand 1D KL calculation.
    """
    idx = 0
    mu_1d = mu[idx, :][None, :]  # 1 x N
    sqrt_1d = sqrt[:, idx, idx][:, None, None]  # N x 1 x 1
    K_1d = K_batch[:, idx, idx][:, None, None]  # N x 1 x 1

    kl = gauss_kl(mu_1d, sqrt_1d, None if white else K_1d)

    flat_mu = tf.reshape(mu_1d, (-1,))  # N
    flat_sqrt = tf.reshape(sqrt_1d, (-1,))  # N
    flat_var = None if white else tf.reshape(K_1d, (-1,))  # N
    kl_by_hand = tf_kl_1d(flat_mu, flat_sqrt, flat_var)

    np.testing.assert_allclose(kl.eval(), kl_by_hand.eval())
def test_unknown_size_inputs(session_tf):
    """
    Regression test for #725 and #734. `gauss_kl` used to blow up when the
    Gaussian's mean had at least one dimension of unknown size, because
    `tf.size` can only output types `tf.int32` or `tf.int64`.
    """
    mu_ph = tf.placeholder(settings.float_type, [None, None])
    sqrt_ph = tf.placeholder(settings.float_type, [None, None, None])
    mu_value = np.ones([1, 4], dtype=settings.float_type)
    sqrt_value = np.ones([4, 1, 1], dtype=settings.float_type)

    known_shape_tf = gauss_kl(tf.constant(mu_value), tf.constant(sqrt_value))
    unknown_shape_tf = gauss_kl(mu_ph, sqrt_ph)

    known_shape = session_tf.run(known_shape_tf)
    unknown_shape = session_tf.run(unknown_shape_tf,
                                   feed_dict={mu_ph: mu_value, sqrt_ph: sqrt_value})
    np.testing.assert_allclose(known_shape, unknown_shape)
if __name__ == "__main__":
    # The tests in this module are pytest-style functions with fixtures and
    # parametrization. tf.test.main() is a unittest-based runner and does not
    # collect them, so delegate to pytest instead.
    pytest.main([__file__])